feat(pymllm): support Qwen3 Jetson BF16, W4A16, and W8A8 serving #670
Conversation
…nchmarks
Phase 0.1: baseline benchmark scripts for GEMM and activation quant
Phase 1.1: port sglang Triton per_token_quant_int8 to pymllm/quantization/kernels/
Triton kernel correctness: ±1 LSB rounding diff vs torch (0.01% of elements)
Triton kernel performance: 25-67% faster than torch path on Jetson SM87
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
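For context, the per-token quantization the ported kernel performs reduces to the plain-torch sketch below; the Triton version only parallelizes this per row, and the function name and clamp epsilon here are illustrative, not the kernel's actual API.

```python
import torch

def per_token_quant_int8_ref(x: torch.Tensor):
    """Reference semantics of per-token (per-row) symmetric int8 quantization."""
    absmax = x.abs().amax(dim=-1, keepdim=True).clamp_(min=1e-10)  # avoid div-by-zero
    scales = (absmax / 127.0).to(torch.float32)  # one fp32 scale per token/row
    x_q = torch.round(x.to(torch.float32) / scales).clamp_(-128, 127).to(torch.int8)
    return x_q, scales
```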
Phase 1.2:
- Replace _per_token_quant_int8 torch impl with Triton kernel import
- Remove old mllm JIT kernel fallback from _int8_scaled_mm
- GEMM now uses torch._int_mm directly (intermediate state before CUTLASS)
- Update test to verify Triton quant + torch._int_mm path
All 26 tests pass (config + runtime + kernel).
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
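The intermediate torch._int_mm path amounts to the hedged sketch below; shapes assumed are a_q (M, K) int8, b_q (K, N) int8, scales_a (M, 1), scales_b (N,), and the function name is illustrative.

```python
import torch

def int8_scaled_mm_ref(a_q, b_q, scales_a, scales_b, out_dtype=torch.bfloat16):
    """int8 GEMM into an int32 accumulator, then per-row/col scales applied."""
    acc = torch._int_mm(a_q, b_q)                                # (M, N) int32
    return (acc.to(torch.float32) * scales_a * scales_b).to(out_dtype)
```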
Phase 2: Port sglang CUTLASS int8 GEMM to mllm-kernel.
- SM89 tile shapes (100K shared memory safe for Jetson Orin SM87)
- Per-row/col scale epilogue fused into GEMM
- JIT compiled via torch.utils.cpp_extension (~100s first run, cached after)
- Integrated into compressed_tensors.py W8A8 forward path
Performance on SM87 (93, 2048, 6144): CUTLASS 0.295 ms (4.2x vs torch._int_mm, 67.8x vs old JIT)
cutlass_extensions ported from sglang sgl-kernel (Apache 2.0)
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…eprecated 28 test cases: 7 shapes x 2 dtypes x 2 bias configs, all pass. Old naive JIT kernel marked DEPRECATED (kept for reference). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Spike files verify CUTLASS compilation and tile configs on SM87. Baseline doc updated with Phase 1 (Triton) and Phase 2 (CUTLASS) results. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
SM87 (Jetson Orin) has 164K shared memory per SM (same as SM80), not 100K as initially assumed. However, benchmarking shows SM89's 3-stage tiles outperform SM80's 5-stage tiles at large M when the per-row/col scale epilogue visitor is used, due to lower smem pressure and better occupancy. SM80 dispatch kept in source for future use on devices that benefit from larger tiles. Verified: 28 CUTLASS tests pass, performance unchanged. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Spike files kept on disk for reference but excluded from version control. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
process_weights_after_loading now stores weight as (K, N) column-major (stride(0)==1) instead of row-major. This eliminates a full weight matrix copy (~12MB per linear layer) on every forward call. Root cause: CUTLASS requires column-major B (stride(0)==1), but weights were stored row-major, triggering .t().contiguous().t() on every call — ~2.3GB of copies per decode step for Qwen3-VL-2B. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
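A minimal sketch of the layout change described above, assuming the checkpoint weight arrives as a row-major (N, K) int8 tensor; the helper name is illustrative.

```python
import torch

def store_weight_column_major(w_int8: torch.Tensor) -> torch.Tensor:
    """Return the weight as (K, N) with stride(0) == 1 (column-major).

    The transpose is a view, so this costs nothing at load time and removes
    the per-forward .t().contiguous().t() copy that CUTLASS's column-major B
    operand otherwise forces.
    """
    b = w_int8.t()                 # (K, N) view, strides (1, K)
    assert b.stride(0) == 1
    return b
```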
- Cache Triton quant and CUTLASS mm function refs to avoid repeated import lookups (140 calls/decode step)
- Remove redundant .contiguous(), .reshape(-1), .to(float32) in CUTLASS wrapper — scales are already in correct format from Triton quant and process_weights_after_loading
- Only do scales_a.squeeze(-1) to convert (M,1) -> (M,)
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Compares GEMM kernel performance at representative Qwen3-VL-2B shapes. W8A8 columns show activation quant, GEMM, and total separately. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
SM87 (Jetson Orin) has 164KB smem, same as SM80 — not 100KB like SM86/SM89. Both sglang and vllm route SM87 to SM80 dispatch. E2E benchmark confirms SM80 ≈ SM89 tiles on SM87 (<2% diff). Reverts the SM89 override from 5e6c634 and uses SM80 dispatch with deeper pipeline stages (5-6 stage). Co-Authored-By: Claude Sonnet 4.6 (1M context) <noreply@anthropic.com>
When disable_radix_cache=True, ChunkCache is a no-op cache and should not be treated as cache-enabled. Previously cache_enabled only checked cache is not None, which made the insert path report did_insert=True and skip Phase 4 free logic. This change excludes ChunkCache from cache_enabled so KV slots are released correctly.
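A sketch of the check this change implies; ChunkCache is the class named in the message, while the helper and its parameters are illustrative.

```python
def is_cache_enabled(cache, chunk_cache_cls=None) -> bool:
    """Treat a no-op ChunkCache the same as no cache at all.

    Otherwise the insert path reports did_insert=True and the Phase 4 KV free
    logic is skipped, leaking slots. chunk_cache_cls is injected to keep this
    sketch import-free.
    """
    if cache is None:
        return False
    if chunk_cache_cls is not None and isinstance(cache, chunk_cache_cls):
        return False
    return True
```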
Note: Reviews paused — this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior in the settings and use the usual CodeRabbit commands to manage or resume reviews.
📝 Walkthrough
Adds CUTLASS-backed INT8 scaled GEMM with per-row/column scaling, Triton per-token activation quant, Marlin GPTQ repack, a compressed-tensors quant pipeline (W4A16/W8A8), Qwen3/Qwen3-VL model implementations and fixes, many CUDA/JIT kernels and tests, benchmarks, runtime timing propagation, and extensive docs and tooling.
Sequence Diagram(s)
```mermaid
sequenceDiagram
    participant App as Application
    participant Tokenizer as Tokenizer/Processor
    participant Triton as Triton Quant Kernel
    participant CUTLASS as CUTLASS int8_scaled_mm
    participant Marlin as Marlin gptq_marlin_gemm
    participant Detok as Detokenizer/Output
    App->>Tokenizer: prepare activations (fp16)
    Tokenizer->>Triton: per-token quantization -> (int8 activations, scales)
    Triton->>CUTLASS: send int8 activations + scales
    CUTLASS->>CUTLASS: INT8 GEMM + epilogue (per-row/col scaling, bias) -> logits
    alt W4A16 path (Marlin)
        App->>Marlin: prepacked int4 weights + activations
        Marlin->>Marlin: run gptq_marlin_gemm -> logits
    end
    CUTLASS->>Detok: logits
    Marlin->>Detok: logits
    Detok->>App: return final responses (include timing metadata)
```
Estimated Code Review Effort: 🎯 5 (Critical) | ⏱️ ~120 minutes
🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed (1 warning)
✏️ Tip: You can configure your own custom pre-merge checks in the settings.
Actionable comments posted: 12
Caution: Some comments are outside the diff and can't be posted inline due to platform limitations.
⚠️ Outside diff range comments (5)
pymllm/executor/model_runner.py (1)
486-505: ⚠️ Potential issue | 🟠 Major — Falls back to returning the full HF config when `config.json` lacks `quantization_config`.
When `config.json` is one of the `unique` candidate filenames (it is — `compressed-tensors` lists it via `get_config_filenames()`), the new branch only short-circuits the `quantization_config` unwrap when that key exists. If `config.json` is present but does not contain a `quantization_config` field (e.g. an unquantized checkpoint, or a checkpoint where the quant metadata only lives in `quantize_config.json` later in the list), Line 495 returns the full top-level HF config (`architectures`, `hidden_size`, …) and `_resolve_quant_config` will then read a bogus `quant_method` (likely `None`, but in general any string the upstream HF config contains), instead of either continuing the search for `quantize_config.json` or returning `{}`.
This also makes the post-loop fallback at lines 498–505 effectively dead whenever `config.json` is in `unique`.
🐛 Proposed fix
```diff
         for fname in unique:
             fpath = model_path / fname
             if fpath.exists():
                 with open(fpath) as fp:
                     cfg = json.load(fp)
-                # config.json stores model metadata at the top level and
-                # nests quantization details under quantization_config.
-                if fname == "config.json" and "quantization_config" in cfg:
-                    return cfg["quantization_config"]
-                return cfg
+                # config.json stores model metadata at the top level and
+                # nests quantization details under quantization_config.
+                # When the nested key is missing, fall through to the next
+                # candidate filename (e.g. quantize_config.json) instead of
+                # returning the full HF model config.
+                if fname == "config.json":
+                    if "quantization_config" in cfg:
+                        return cfg["quantization_config"]
+                    continue
+                return cfg
```
After this fix you can also drop the duplicated post-loop fallback (lines 497–505), since the loop already handles `config.json` correctly.
As per coding guidelines: "Validate inputs for public APIs and critical internal functions."
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@pymllm/executor/model_runner.py` around lines 486 - 505, In _resolve_quant_config, don't return the full HF config when encountering "config.json" that lacks "quantization_config"; instead, if fname == "config.json" and "quantization_config" in cfg then return cfg["quantization_config"], otherwise continue the loop to allow later candidates (e.g., "quantize_config.json") to be checked; remove the duplicate post-loop fallback that re-reads config.json since the loop now handles it, and keep the final return {} if nothing is found.
pymllm/server/launch.py (1)
707-768: ⚠️ Potential issue | 🟠 Major — `r` may be unbound when `results` is empty in `/v1/completions`.
`r` is the loop variable from `for i, r in enumerate(results):` at line 717. The new `timing` block at lines 743-767 references `r` after the loop. If `_iter_with_disconnect_check` yields nothing (e.g., client disconnect, engine error), `results` is empty, the loop never executes, and `r` is undefined — `r.get(...)` raises `NameError`, which then falls into the generic `except Exception` handler at line 777 and is reported to the client as a 500 with no actionable detail. Compare with `/v1/chat/completions`, which initializes `r = {}` on line 965.
🐛 Suggested fix
```diff
     try:
         results = []
         async for item in _iter_with_disconnect_check(gen, request):
             results.append(item)
+        if not results:
+            raise HTTPException(status_code=500, detail="No output from engine")

         choices = []
         prompt_tokens = 0
         completion_tokens = 0
-        for i, r in enumerate(results):
+        r: Dict[str, Any] = {}
+        for i, r in enumerate(results):
             choices.append(
```
Alternatively, mirror the chat-completions path and bind a single `r = results[-1] if results else {}` before constructing the response.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@pymllm/server/launch.py` around lines 707 - 768, The timing block uses the loop variable r after the for-loop in the completion handler (the results/choices assembly around engine.generate_async), which can be undefined when results is empty; fix by binding r = results[-1] if results else {} (or an empty dict) immediately after the for i, r in enumerate(results): loop and before constructing the ORJSONResponse so all r.get(...) calls are safe (refer to the loop producing choices and the timing dictionary construction).
pymllm/models/qwen3_vl.py (2)
1172-1226: ⚠️ Potential issue | 🟡 Minor — Wall-clock timings here don't include async GPU work.
`time.perf_counter()` around CUDA ops only captures kernel launch time unless preceded/followed by a `torch.cuda.synchronize()`. The `_vit_t0`/`_llm_t0` deltas reported here can wildly under-report real model latency, especially in extend mode where the work being measured is large.
For server-side coarse latency this may be acceptable (and `pymllm/README-ZH.md` does call this out), but the same values get bubbled into OpenAI-compatible responses via the orchestrator timing field, so users may interpret them as actual model time.
If accuracy matters, gate a sync via a config flag, e.g.:
```python
_vit_t0 = time.perf_counter()
vision_features = self.visual(pixel_values, grid_thw=image_grid_thw)
if torch.cuda.is_available():
    torch.cuda.synchronize()
vit_prefill_ms = (time.perf_counter() - _vit_t0) * 1000.0
```
Same applies to the LLM forward timing block at Lines 1206–1216.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@pymllm/models/qwen3_vl.py` around lines 1172 - 1226, The timing measurements around vision and LLM forwards (_vit_t0/_llm_t0) only measure kernel launch time and under-report real GPU latency; modify the blocks that call self.visual(...) and self.model(...) to optionally synchronize the device before stopping the timer when CUDA is available (e.g., check torch.cuda.is_available() and call torch.cuda.synchronize() if a new config flag like self.sync_cuda_timings is true), and propagate this gating to both the vision timing (vit_prefill_ms, vit_prefill_tokens) and LLM timing (llm_prefill_ms/llm_decode_ms) so timings reported from functions like self.visual and self.model reflect real GPU work when configured.
1186-1192: ⚠️ Potential issue | 🟡 Minor — Include video tokens in the image token mask to match `qwen3_5.py` behavior and ensure accurate token counting.
The `image_mask` at line 1188 uses only `self.image_token_id`, but `self.video_token_id` is configured at line 1065 and should be included. The sibling model `qwen3_5.py` at line 517 correctly uses `mask = (input_ids == self.image_token_id) | (input_ids == self.video_token_id)`. Without video tokens in the mask, the `vit_prefill_tokens` count becomes incomplete, and deepstack embeddings at lines 1194–1201 also skip video tokens.
♻️ Suggested fix
```diff
-            image_mask = input_ids == self.image_token_id
+            image_mask = (input_ids == self.image_token_id) | (
+                input_ids == self.video_token_id
+            )
             if image_mask.any():
                 vit_prefill_tokens = int(image_mask.sum().item())
                 input_embeds[image_mask] = vision_embeds.to(input_embeds.dtype)
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@pymllm/models/qwen3_vl.py` around lines 1186 - 1192, The image mask currently only checks for self.image_token_id in the block that computes input_embeds via self.model.embed_tokens; update the mask logic used in image_mask (and any subsequent uses of vit_prefill_tokens) to include self.video_token_id as well (i.e., mask = (input_ids == self.image_token_id) | (input_ids == self.video_token_id)) so that input_embeds[image_mask] = vision_embeds.to(...) and vit_prefill_tokens = int(image_mask.sum().item()) correctly account for video tokens; apply the same mask change wherever vit_prefill_tokens or deepstack embedding replacements are computed so video tokens are handled identically to image tokens.
pymllm/orchestrator/scheduler_process.py (1)
792-837: ⚠️ Potential issue | 🟡 Minor — Conflicting `llm_decode_ms` updates: per-batch accumulation is overwritten by cumulative recompute.
At Line 799 you accumulate the model-runner-reported decode time per batch:
`req.llm_decode_ms = (req.llm_decode_ms or 0.0) + out["llm_decode_ms"]`
but immediately afterwards (lines 833–837) for any decode batch you overwrite it with `(now - decode_start_tic) * 1000.0`, which is a wall-clock cumulative timer including scheduler overhead, queueing, etc.
The result is:
- The accumulation at Line 799 is dead work for decode mode.
- The reported `llm_decode_ms` is wall-clock cumulative since the first decode step, not the sum of per-step model forward time, which conflicts with how `llm_prefill_ms`/`vit_prefill_ms` are reported (single forward duration).
Pick one source of truth. If you want cumulative wall-clock, drop the per-batch accumulation. If you want sum of forward times, drop the cumulative recompute and keep the accumulator.
♻️ Suggested fix (keep model-reported per-step accumulation only)
```diff
-        if batch.forward_mode.is_decode():
-            _decode_now = time.perf_counter()
-            for req in batch.reqs:
-                if req.decode_start_tic is not None:
-                    req.llm_decode_ms = (_decode_now - req.decode_start_tic) * 1000.0
-
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@pymllm/orchestrator/scheduler_process.py` around lines 792 - 837, The code currently both accumulates model-reported per-step decode time into req.llm_decode_ms (using out["llm_decode_ms"]) and later overwrites it for decode batches with a wall-clock recompute in the batch.forward_mode.is_decode() loop; remove the latter overwrite so the per-step model-reported accumulation is the single source of truth: delete or skip the assignment req.llm_decode_ms = (_decode_now - req.decode_start_tic) * 1000.0 inside the batch.forward_mode.is_decode() block (references: req.llm_decode_ms, out["llm_decode_ms"], batch.forward_mode.is_decode(), req.decode_start_tic) so llm_decode_ms remains the sum of model-runner-reported decode times.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@mllm-kernel/benchmarks/bench_w4a16_vs_w8a8.py`:
- Around line 6-9: The docstring in bench_w4a16_vs_w8a8.py contains a hardcoded
developer worktree path; replace that line with a generic, portable instruction
such as "cd to the repository root" or "run from the repository root" and/or
show a relative invocation like "python3
mllm-kernel/benchmarks/bench_w4a16_vs_w8a8.py" so users know to run the script
from the repo root instead of the specific /workspace/.worktrees/... path.
In
`@mllm-kernel/mllm_kernel/cuda/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h`:
- Around line 113-183: The constructor EpilogueVisitorPerRowPerCol has its
member-initializer list out of declaration order causing -Wreorder; reorder the
initializer list to match the class member declaration order (Params const&
params_, SharedStorage& shared_storage_, MatrixCoord extent_, MatrixCoord
extent_real_, ElementwiseFunctor elementwise_, bool with_bias_, bool
per_token_quant_, bool per_channel_quant_, AlphaScaleElementType*
ptr_alpha_row_, AlphaScaleElementType* ptr_alpha_col_, ScaleTileIterator
iterator_alpha_col_, OutputTileIterator iterator_C_, OutputTileIterator
iterator_D_, ...), i.e., move extent_real_ to be initialized immediately after
extent_ and ensure elementwise_ follows extent_real_ (and the rest follow their
declared order) while keeping the existing initialization expressions and
runtime logic unchanged.
In
`@mllm-kernel/mllm_kernel/cuda/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h`:
- Around line 184-202: The Params() constructor in gemm_with_epilogue_visitor.h
initializes params_alpha_col but omits params_alpha_row, leaving
params_alpha_row potentially uninitialized; update the Params() member
initializer list for the Params() constructor to include params_alpha_row(0)
(mirroring params_alpha_col(0)) so ScaleTileIterator::Params for the row scale
is set to the zero/empty state consistent with ptr_alpha_row being nullptr.
- Around line 204-224: The initializer list for Params mistakenly constructs
params_alpha_row from the column layout
(params_alpha_row(args.ref_alpha_col.layout())); change it to use the actual row
reference layout by initializing params_alpha_row with
args.ref_alpha_row.layout() in the Params constructor so the alpha_row scale
iterator gets correct ColumnMajor strides (update the initializer near
params_alpha_col/params_alpha_row and verify ptr_alpha_row uses
args.ref_alpha_row.data()).
In `@mllm-kernel/mllm_kernel/cuda/jit/gptq_marlin_repack.py`:
- Around line 56-75: Add explicit input validation at the start of
gptq_marlin_repack: check num_bits is a positive divisor of 32 (raise ValueError
if num_bits <= 0 or 32 % num_bits != 0), ensure tile_size alignment by
validating size_k % tile_size == 0 (use tile_size = 16 as in the function) and
ensure (size_n * tile_size) % (32 // num_bits) == 0 to avoid truncation, and
verify b_q_weight.dtype is torch.int32 (raise TypeError if not). Use the
existing symbols gptq_marlin_repack, pack_factor (computed as 32 // num_bits),
tile_size, and b_q_weight to locate where to add these checks and raise clear,
descriptive exceptions before calling _normalize_perm or the kernel.
In `@mllm-kernel/mllm_kernel/cuda/jit/int8_scaled_mm_cutlass.py`:
- Around line 74-85: The hardcoded "-arch=sm_87" in the extra_cuda_cflags
restricts the JIT cubin to SM87+, so change the build step that constructs
extra_cuda_cflags (the list passed into the JIT compile call and the
build_directory=cache_dir) to detect the local GPU compute capability at runtime
(e.g., via torch.cuda.get_device_capability() or the CUDA runtime API) and emit
a matching "-arch=sm_{major}{minor}" flag instead of the fixed "-arch=sm_87";
also include that SM string in the cache directory name (e.g., append
"/sm{major}{minor}/" to cache_dir) so the cache is per-arch and avoids
collisions. Ensure the code path that currently builds extra_cuda_cflags and
calls the JIT compile uses the detected major/minor for both the flag and the
cache directory.
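A sketch of the runtime arch detection suggested above; base_cache_dir and the helper name are illustrative.

```python
import os
import torch

def _cuda_arch_flag_and_cache_dir(base_cache_dir: str):
    major, minor = torch.cuda.get_device_capability()
    arch_flag = f"-arch=sm_{major}{minor}"          # e.g. "-arch=sm_87" on Jetson Orin
    cache_dir = os.path.join(base_cache_dir, f"sm{major}{minor}")  # per-arch JIT cache
    os.makedirs(cache_dir, exist_ok=True)
    return arch_flag, cache_dir
```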
In `@pymllm/layers/rms_norm.py`:
- Around line 46-54: The fallback branch in the rms norm function returns the
original unmodified residual after flashinfer.norm.fused_add_rmsnorm fails;
change the fallback to explicitly compute and return the updated residual (e.g.,
compute new_residual = x + residual and return _torch_rmsnorm(new_residual,
self.weight, self.eps), new_residual) instead of returning the old residual, and
narrow the except clause to catch only expected runtime errors (e.g.,
RuntimeError/ValueError) and add a logger call (use a module logger like logger
= logging.getLogger(__name__)) to record unexpected failures before falling
back; locate these changes around flashinfer.norm.fused_add_rmsnorm and the
fallback that currently calls _torch_rmsnorm.
In `@pymllm/models/qwen3.py`:
- Around line 296-304: The timing around self.model(...) currently measures CUDA
launch overhead not actual GPU execution; update the timing in the qwen3.py
block that sets forward_batch.llm_prefill_ms and forward_batch.llm_decode_ms to
either (a) call torch.cuda.synchronize() immediately after self.model(...)
before computing _llm_ms so the measured delta reflects real GPU work, or (b)
use torch.cuda.Event start/stop + elapsed_time for GPU timings; then set
forward_batch.llm_prefill_ms/llm_decode_ms accordingly and add a one-line
comment beside this logic (referencing forward_batch.forward_mode.is_extend(),
forward_batch.llm_prefill_ms, forward_batch.llm_decode_ms, and self.model)
documenting whether the field represents engine-perceived latency (asynchronous
launch time) or actual GPU execution time (synchronized or event-based).
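Option (b) above, sketched with CUDA events; the callable wrapper is illustrative and assumes a CUDA device.

```python
import torch

def timed_forward_gpu_ms(run_forward):
    """Run a forward callable and return (output, elapsed GPU milliseconds)."""
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    out = run_forward()          # e.g. lambda: self.model(input_ids, positions, forward_batch)
    end.record()
    torch.cuda.synchronize()     # wait until both events have completed
    return out, start.elapsed_time(end)
```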
In `@pymllm/orchestrator/tokenizer_process.py`:
- Around line 374-386: When mm_inputs provides multimodal input_ids
(image_inputs["input_ids"]) we currently override input_ids without enforcing
the model context cap; after extracting proc_input_ids in the block handling
mm_inputs/image_inputs, truncate proc_input_ids to self._context_length before
converting to a Python list or assigning to input_ids (e.g., slice ndarray or
list to [:self._context_length] or use equivalent method for objects exposing
tolist), so input_ids never exceed the model's context length; reference
symbols: mm_inputs, image_inputs, proc_input_ids, input_ids,
self._context_length.
In `@pymllm/quantization/kernels/int8_activation_triton.py`:
- Around line 59-82: Replace the brittle assert and the risky negative 2-index
stride access: in the function that calls _per_token_quant_int8 (the block
around the _per_token_quant_int8[(M,)](...) call), replace "assert
x.is_contiguous()" with an explicit check that raises ValueError if not
contiguous, and compute stride_x and stride_xq defensively (use x.stride(-2) /
x_q.stride(-2) when x.ndim >= 2 else fall back to x.stride(0) / x_q.stride(0));
pass those computed stride_x and stride_xq into the _per_token_quant_int8 call.
This keeps the documented "any shape" behavior and avoids IndexError on 1-D
inputs and also fixes the assert being stripped under python -O.
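A sketch of the defensive launch-argument handling described above; only the contiguity check and stride selection are shown, with names following the prompt.

```python
import torch

def _quant_launch_strides(x: torch.Tensor, x_q: torch.Tensor):
    if not x.is_contiguous():
        raise ValueError("per_token_quant_int8 expects a contiguous input")
    # Works for both 1-D and N-D inputs, unlike a bare x.stride(-2).
    stride_x = x.stride(-2) if x.ndim >= 2 else x.stride(0)
    stride_xq = x_q.stride(-2) if x_q.ndim >= 2 else x_q.stride(0)
    return stride_x, stride_xq
```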
In `@pymllm/README-ZH.md`:
- Around line 199-201: The example in the README currently hard-codes a
developer-specific absolute path ("/workspace/xcd_mllm/test.png") inside the
JSON snippet for the "image_url" entry; replace that literal with a neutral
placeholder (e.g., "<path/to/image.png>" or a sample relative/remote URL) so
users won’t copy a non-existent path—update the "image_url" -> "url" value in
the JSON array shown in the README-ZH.md accordingly.
In `@pymllm/README.md`:
- Around line 50-60: The README's environment table lists a non-existent PyPI
package/version "flashinfer: `0.6.7`" which makes the setup unreproducible;
update the README.md entry referencing "flashinfer 0.6.7" by either removing the
flashinfer line, replacing it with the correct source (e.g., a Git URL, private
repo, or build instructions), or adding explicit install instructions and
provenance (e.g., how to build/install from source or a link to the correct
release). Edit the list block in README.md where "flashinfer: `0.6.7`" appears
so the table accurately reflects how to obtain or install flashinfer, and
include the concrete install command or URL if you choose to keep the
dependency.
---
Outside diff comments:
In `@pymllm/executor/model_runner.py`:
- Around line 486-505: In _resolve_quant_config, don't return the full HF config
when encountering "config.json" that lacks "quantization_config"; instead, if
fname == "config.json" and "quantization_config" in cfg then return
cfg["quantization_config"], otherwise continue the loop to allow later
candidates (e.g., "quantize_config.json") to be checked; remove the duplicate
post-loop fallback that re-reads config.json since the loop now handles it, and
keep the final return {} if nothing is found.
In `@pymllm/models/qwen3_vl.py`:
- Around line 1172-1226: The timing measurements around vision and LLM forwards
(_vit_t0/_llm_t0) only measure kernel launch time and under-report real GPU
latency; modify the blocks that call self.visual(...) and self.model(...) to
optionally synchronize the device before stopping the timer when CUDA is
available (e.g., check torch.cuda.is_available() and call
torch.cuda.synchronize() if a new config flag like self.sync_cuda_timings is
true), and propagate this gating to both the vision timing (vit_prefill_ms,
vit_prefill_tokens) and LLM timing (llm_prefill_ms/llm_decode_ms) so timings
reported from functions like self.visual and self.model reflect real GPU work
when configured.
- Around line 1186-1192: The image mask currently only checks for
self.image_token_id in the block that computes input_embeds via
self.model.embed_tokens; update the mask logic used in image_mask (and any
subsequent uses of vit_prefill_tokens) to include self.video_token_id as well
(i.e., mask = (input_ids == self.image_token_id) | (input_ids ==
self.video_token_id)) so that input_embeds[image_mask] = vision_embeds.to(...)
and vit_prefill_tokens = int(image_mask.sum().item()) correctly account for
video tokens; apply the same mask change wherever vit_prefill_tokens or
deepstack embedding replacements are computed so video tokens are handled
identically to image tokens.
In `@pymllm/orchestrator/scheduler_process.py`:
- Around line 792-837: The code currently both accumulates model-reported
per-step decode time into req.llm_decode_ms (using out["llm_decode_ms"]) and
later overwrites it for decode batches with a wall-clock recompute in the
batch.forward_mode.is_decode() loop; remove the latter overwrite so the per-step
model-reported accumulation is the single source of truth: delete or skip the
assignment req.llm_decode_ms = (_decode_now - req.decode_start_tic) * 1000.0
inside the batch.forward_mode.is_decode() block (references: req.llm_decode_ms,
out["llm_decode_ms"], batch.forward_mode.is_decode(), req.decode_start_tic) so
llm_decode_ms remains the sum of model-runner-reported decode times.
In `@pymllm/server/launch.py`:
- Around line 707-768: The timing block uses the loop variable r after the
for-loop in the completion handler (the results/choices assembly around
engine.generate_async), which can be undefined when results is empty; fix by
binding r = results[-1] if results else {} (or an empty dict) immediately after
the for i, r in enumerate(results): loop and before constructing the
ORJSONResponse so all r.get(...) calls are safe (refer to the loop producing
choices and the timing dictionary construction).
---
Nitpick comments:
In `@mllm-kernel/benchmarks/bench_int8_scaled_mm.py`:
- Around line 145-151: The script crashes on CUDA-less machines because the
__main__ block calls torch.cuda.get_device_name(0) and
torch.cuda.get_device_capability(0) (and internal functions like
_torch_int_mm_scaled and _try_load_cutlass_kernel assume CUDA) unconditionally;
fix by guarding CUDA-specific behavior with a torch.cuda.is_available() check in
the __main__ path: if CUDA is available, print device name/capability and call
run_benchmarks() as now; if not, print a clear message and either skip CUDA-only
benchmarks or exit early so _torch_int_mm_scaled and _try_load_cutlass_kernel
are never invoked. Ensure all calls to torch.cuda.get_device_name,
torch.cuda.get_device_capability, _torch_int_mm_scaled, and
_try_load_cutlass_kernel are only reached when torch.cuda.is_available() is
True.
- Around line 81-142: run_benchmarks currently times each backend but never
verifies correctness; add a one-shot correctness check per shape before calling
bench_fn: compute a reference output via _torch_int_mm_scaled(mat_a, mat_b,
scales_a, scales_b, out_dtype=out_dtype), then for each backend in backends
(skip "torch._int_mm") call its fn with b_arg = mat_b_colmaj if name ==
"cutlass" else mat_b and out_dtype=out_dtype, compute a relative
mean-absolute-error (e.g.,
(out.float()-ref.float()).abs().mean()/ref.float().abs().mean().clamp(min=1e-6)),
and if rel > 1e-2 print a clear mismatch message including the backend name and
shape and mark that backend as invalid (skip timing or record error) so bench_fn
is only run for backends that pass the sanity check.
In `@mllm-kernel/benchmarks/bench_w4a16_vs_w8a8.py`:
- Around line 68-71: The unpacking from load_marlin() in prepare_marlin_weights
currently binds gptq_marlin_gemm but never uses it; update the tuple unpack in
prepare_marlin_weights to either drop that value or rename it to
_gptq_marlin_gemm (or simply _ ) so the unused symbol is intentionally ignored
and the linter warning (RUF059) is silenced—keep the other symbols
(gptq_marlin_repack, marlin_make_workspace, marlin_make_empty_g_idx,
marlin_permute_scales, SCALAR_TYPE_UINT4B8) unchanged.
- Around line 158-162: The variable tag computed in the benchmark loop (tag =
"decode" if M <= 8 else "prefill") is never used; either remove that line or
include tag in the formatted output; update the print call (the one formatting
f" ({M:>3},{K:>4},{N:>4}) {desc:<8s} ...") to incorporate tag (e.g., add
{tag:<8s} after {desc:<8s}) if you want rows labeled, otherwise delete the
unused tag assignment to avoid the dead variable.
In `@mllm-kernel/cmake/CPM.cmake`:
- Around line 12-23: The download block using file(DOWNLOAD) may leave a partial
file at ${CPM_DOWNLOAD_LOCATION} on failure; update the failure path in the
CPM.cmake snippet that handles download_status (the list(GET download_status 0
download_status_code) branch) to: if download_status_code is non-zero,
remove/unlink ${CPM_DOWNLOAD_LOCATION} if it exists to avoid poisoning the
cache, include the download_status/details in the message(FATAL_ERROR) for
better diagnostics, and then fail; reference the CPM_DOWNLOAD_LOCATION,
CPM_VERSION, download_status and download_status_code symbols to locate and
modify the file(DOWNLOAD) failure handling.
- Around line 14-18: Add integrity verification to the fallback file(DOWNLOAD
...) call by providing an EXPECTED_HASH for CPM v0.42.0 and failing the build if
the downloaded file's hash doesn't match; update the file(DOWNLOAD ...)
invocation that uses CPM_VERSION, CPM_DOWNLOAD_LOCATION and download_status to
include EXPECTED_HASH "<sha256-or-appropriate-hash-for-v0.42.0>" and check
download_status (and/or use file(DOWNLOAD ... EXPECTED_HASH) behavior) so the
script stops with an error when the hash verification fails rather than silently
using a tampered CPM.cmake.
In `@mllm-kernel/include/mllm_kernel/scalar_type.hpp`:
- Around line 260-262: Remove the global namespace alias `namespace host =
::mllm_kernel::host;` from the public header `scalar_type.hpp` (it pollutes
every TU); instead, delete that line and add a localized alias in each
translation unit that needs the short name (e.g., `marlin.cuh`,
`gptq_marlin.cuh`, `kernel.h`, `marlin_template.h`) by declaring a local
`namespace host = ::mllm_kernel::host;` inside the appropriate enclosing
namespace and, where useful, add selective using declarations like `using
host::div_ceil;` so call sites keep the short form without exporting `host` from
the header.
In `@mllm-kernel/mllm_kernel/cuda/csrc/gemm/int8/int8_scaled_mm_cutlass.cu`:
- Around line 156-217: sm89_dispatch_shape is dead/unreachable because SM80–SM89
flows unconditionally through sm80_dispatch_shape; either remove the unused
template or guard it so it can't silently rot. Fix by either deleting the
sm89_dispatch_shape template entirely, or wrap its entire definition with a
clear opt-in preprocessor guard (e.g. `#ifdef` MLLM_KEEP_SM89_REFERENCE ...
`#endif`) and add a short comment explaining it's intentionally preserved for
reference; reference the sm89_dispatch_shape and sm80_dispatch_shape symbols so
reviewers can find the related code paths.
- Around line 27-34: getSMVersion currently calls cudaGetDevice and
cudaDeviceGetAttribute without checking their cudaError_t results, which
swallows CUDA errors and returns 0; update getSMVersion to check the return
values of cudaGetDevice and both cudaDeviceGetAttribute calls, and on failure
either propagate or surface the cudaError_t (e.g., log/throw a descriptive error
containing the cudaGetErrorString result) instead of silently returning 0 so
callers (and the user) see the real CUDA error; reference the getSMVersion
function and the cudaGetDevice / cudaDeviceGetAttribute calls when making the
change.
- Around line 121-135: Add a one-line comment next to the ldc assignment that
explains why ldc is set to 0 for bias broadcasting: note that setting ldc = 0
with bias_ptr passed into Gemm::Arguments causes CUTLASS to read each row from
the same base pointer (broadcasting the [N] bias across rows), so keep ldc = 0
and the bias_ptr usage as-is in the Gemm::Arguments construction. Reference
variables: ldc, bias_ptr, and Gemm::Arguments (and the
EpilogueOutputOp/EpilogueVisitor parameters) so reviewers can find the exact
spot to add the comment.
In `@mllm-kernel/mllm_kernel/cuda/jit/gptq_marlin_repack.py`:
- Around line 12-25: The function _normalize_perm does two separate torch.any
checks which cause two GPU→CPU synchronizations; replace the dual checks
"torch.any(perm < 0) or torch.any(perm >= size_k)" with a single combined mask
and a single any call (e.g., torch.any((perm < 0) | (perm >= size_k))) so
there's only one GPU→CPU sync when validating perm; update the check in
_normalize_perm accordingly while keeping the surrounding device/dtype/length
validations unchanged.
In `@mllm-kernel/mllm_kernel/cuda/jit/gptq_marlin.py`:
- Around line 32-40: The code redundantly computes cpp_args via
make_cpp_args(dtype) and passes an explicit cuda_wrappers tuple to the `@jit`
decorator even though args=[dtype] already causes the framework to auto-generate
the identical cuda wrapper; remove the local cpp_args assignment and the
cuda_wrappers parameter from the `@jit` call, leaving args=[dtype] and
cpp_wrappers=[] intact so the auto-generated wrapper for gptq_marlin_gemm is
used; ensure you do not change func_name="gptq_marlin_gemm" or the cuda_files
entry ("gemm/marlin/gptq_marlin.cuh").
In `@mllm-kernel/mllm_kernel/cuda/jit/int8_scaled_mm_cutlass.py`:
- Around line 53-86: The _load_module() function has a TOCTOU race on the
globals _module and _CUTLASS_INC when called concurrently; protect the
lazy-initialization by adding a module-level threading.Lock (e.g.,
_load_module_lock) and wrap the critical section in _load_module() with that
lock, using a double-checked pattern (check _module before acquiring and
re-check after acquiring) so only one thread runs torch.utils.cpp_extension.load
and sets _module/_CUTLASS_INC while others return the initialized module.
- Around line 116-118: Validate the out_dtype argument instead of coercing: in
the wrapper that builds dtype_str and calls mod.int8_scaled_mm, check that
out_dtype is either torch.float16 or torch.bfloat16 and raise a clear ValueError
if not; only then map torch.float16 -> "float16" and torch.bfloat16 ->
"bfloat16" and pass dtype_str to mod.int8_scaled_mm (referencing the existing
variables out_dtype, dtype_str and the call to mod.int8_scaled_mm).
In `@mllm-kernel/tests/test_gptq_marlin_repack.py`:
- Around line 47-50: Add a short inline comment above the constants (tc_offsets
and pack_idx) in test_gptq_marlin_repack.py explaining that tc_offsets = [0, 1,
8, 9] and pack_idx = [0, 2, 4, 6, 1, 3, 5, 7] reproduce the ldmatrix / m16n8k16
tensor-core fragment row/element ordering used by Marlin's repack CUDA kernel;
include a note that these are the m16n8k16 fragment row/element offsets and add
a reference to the Marlin kernel source or file name and a brief reminder to
update these when changing bit-width or tile size (variables: tc_offsets,
pack_idx, tile_size, n_tiles).
In `@mllm-kernel/tests/test_gptq_marlin.py`:
- Around line 52-72: The test reimplements permutation logic in _get_scale_perms
and _marlin_permute_scales which duplicates get_scale_perms and
marlin_permute_scales in pymllm.quantization.methods.compressed_tensors and can
drift; replace the local copies by importing get_scale_perms and
marlin_permute_scales from that module (use them directly in the test) or, if
you intentionally want an independent reference implementation, add a concise
comment above _get_scale_perms/_marlin_permute_scales stating that they are a
deliberate independent copy to detect regressions in the production helpers and
must remain separate.
In `@mllm-kernel/tests/test_int8_scaled_mm_cutlass.py`:
- Around line 27-29: The top-level pytest.importorskip("torch") is unreachable
because torch is already imported at module scope; either remove the
importorskip call or make the module import conditional by moving the torch
import into the test setup/fixture so the module can be collected without torch.
Specifically, either delete the pytest.importorskip("torch") line near the start
of test_int8_scaled_mm_cutlass.py, or refactor the code so the torch import is
performed inside the fixture/function that checks torch.cuda.is_available()
(e.g., move the import into the fixture that contains the existing CUDA check),
ensuring collection works on environments without torch.
In `@pymllm/layers/rms_norm.py`:
- Around line 13-21: The current _torch_rmsnorm does a no-op cast and forces an
unnecessary fp32→input-dtype copy by returning x_norm.to(dtype=x.dtype) *
weight; change the sequence to keep computations in fp32, multiply x_norm by
weight converted to fp32 (e.g. weight_fp32 = weight.to(x_fp32.dtype)), then cast
the final product back to x.dtype once before return. Update _torch_rmsnorm to
compute x_fp32, var, x_norm in fp32, do x_norm * weight_fp32, and only at the
end .to(x.dtype) on that product.
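The fp32-throughout variant suggested above, as a standalone sketch; the eps value is illustrative.

```python
import torch

def rmsnorm_fp32_ref(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    x_fp32 = x.to(torch.float32)
    var = x_fp32.pow(2).mean(dim=-1, keepdim=True)
    x_norm = x_fp32 * torch.rsqrt(var + eps)
    # Multiply by the weight in fp32, then cast back to the input dtype once.
    return (x_norm * weight.to(torch.float32)).to(x.dtype)
```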
In `@pymllm/models/qwen3.py`:
- Around line 393-399: The substring check that skips lm_head weights is
intentionally done before _remap_weight_name to handle both VL and non‑VL
checkpoints (e.g. "lm_head.weight", "language_model.lm_head.weight", or
"model.language_model.lm_head.weight"); add a short clarifying comment next to
the existing skip (the code block that tests `"lm_head.weight" in name`)
explaining that the substring check is purposeful and order-sensitive relative
to the _remap_weight_name(name: str) function so future readers understand
parity handling between VL/non‑VL serialized names.
In `@pymllm/orchestrator/model_runner_process.py`:
- Around line 592-595: Replace the fragile hasattr(cache, "page_size") probe
with an explicit type check for ChunkCache and short‑circuit before doing work:
update the caller (or the start of _insert_into_radix_cache) to return early
when isinstance(cache, ChunkCache) so you skip building RadixKey and calling
cache.insert / cache.match_prefix for ChunkCache instances; then remove the
now-dead hasattr-based block that sets self._rid_to_cache_protected_len[rid] =
0. Ensure you reference the ChunkCache type, _insert_into_radix_cache function,
and the code paths that call cache.insert and cache.match_prefix when making the
change.
In `@pymllm/orchestrator/scheduler_process.py`:
- Around line 849-852: The code unconditionally resets req.llm_decode_ms to 0.0
when appending to self._running_batch which can clobber a value set earlier in
process_batch_result (see the extend branch behavior around process_batch_result
and the extend output handling). Change the initialization logic in the block
where you set req.decode_start_tic and append to self._running_batch so that you
only set req.llm_decode_ms = 0.0 if req.llm_decode_ms is currently None (leave
existing non-None values intact); reference req.decode_start_tic,
req.llm_decode_ms, process_batch_result, and self._running_batch to find and
update the exact spot.
In `@pymllm/quantization/methods/compressed_tensors.py`:
- Around line 591-597: get_quant_method currently calls
_validate_supported_signature(self) on every invocation which repeats expensive
checks (including verify_marlin_supported and
torch.cuda.get_device_capability()) per layer; instead compute and cache the
signature once on the config object (e.g. in __init__ as self._cached_signature
or via a `@functools.cached_property`) and have get_quant_method use that cached
value, keeping the early ignore check (any(ignored and
prefix.startswith(ignored) for ignored in self.ignore)) and returning
CompressedTensorsLinearMethod(self, cached_signature).
- Around line 78-91: The error messages in verify_marlin_supports_shape use
hardcoded "64"/"128" which can go out of sync with the constants; update the
ValueError messages to reference GPTQ_MARLIN_MIN_THREAD_N and
GPTQ_MARLIN_MIN_THREAD_K (e.g., format or f-string) so the message reports the
actual constant values, and likewise ensure the third ValueError mentions
group_size and input_size_per_partition consistently; modify the raise calls in
verify_marlin_supports_shape to include the constant names/values instead of
literal numbers.
- Around line 1-13: Remove the eager top-level imports of gptq_marlin_gemm and
gptq_marlin_repack and instead import them lazily inside the runtime methods
that need Marlin (e.g., inside process_weights_after_loading and apply of the
W4A16 quantization path); specifically, delete the module-level import
references and add local imports of mllm_kernel.cuda.jit.gptq_marlin_gemm and
gptq_marlin_repack at the start of process_weights_after_loading (and any
W4A16.apply callsites that perform repacking/GEMM), so the Marlin extension is
only loaded when those functions execute and importing this module no longer
fails for W8A8-only environments.
- Around line 282-310: weight_shape is intentionally registered as a placeholder
to receive the per-tensor shape buffer from compressed-tensors checkpoints but
is never read; add a one-line comment next to the weight_shape
creation/registration explaining this intended purpose (so future readers don't
remove it), and in process_weights_after_loading add a defensive check that
reads the registered weight_shape and asserts or logs if its values differ from
input_size_per_partition/output_size_per_partition (reference weight_shape,
layer.register_parameter("weight_shape", ...), process_weights_after_loading,
input_size_per_partition, output_size_per_partition).
In `@pymllm/server/launch.py`:
- Around line 743-767: Extract the repeated timing logic into a small helper
pair: implement _safe_tps(tokens, ms) to return None when tokens or ms are None
or ms <= 0 and otherwise compute tokens/(ms/1000.0), and implement
_build_timing(r, prompt_tokens, completion_tokens) which reads vit_prefill_ms,
llm_prefill_ms, llm_decode_ms from r and returns the timing dict
(vit_prefill_ms, llm_prefill_ms, llm_decode_ms, prefill_tokens, vit_prefill_tps
via _safe_tps(r.get("vit_prefill_tokens"), vit_ms), llm_prefill_tps via
_safe_tps(prompt_tokens, llm_pre_ms), and llm_decode_tps via
_safe_tps(completion_tokens, llm_dec_ms)); then replace the inline "timing": {
... } dict in the /v1/completions and /v1/chat/completions response construction
with "timing": _build_timing(r, prompt_tokens, completion_tokens) so both
endpoints share the same guarded logic.
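A sketch of the helper pair described above; field names follow the prompt, and the response assembly itself is not reproduced.

```python
from typing import Any, Dict, Optional

def _safe_tps(tokens: Optional[int], ms: Optional[float]) -> Optional[float]:
    if tokens is None or ms is None or ms <= 0:
        return None
    return tokens / (ms / 1000.0)

def _build_timing(r: Dict[str, Any], prompt_tokens: int, completion_tokens: int) -> Dict[str, Any]:
    vit_ms = r.get("vit_prefill_ms")
    llm_pre_ms = r.get("llm_prefill_ms")
    llm_dec_ms = r.get("llm_decode_ms")
    return {
        "vit_prefill_ms": vit_ms,
        "llm_prefill_ms": llm_pre_ms,
        "llm_decode_ms": llm_dec_ms,
        "prefill_tokens": prompt_tokens,
        "vit_prefill_tps": _safe_tps(r.get("vit_prefill_tokens"), vit_ms),
        "llm_prefill_tps": _safe_tps(prompt_tokens, llm_pre_ms),
        "llm_decode_tps": _safe_tps(completion_tokens, llm_dec_ms),
    }
```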
In `@pymllm/tests/bench_w8a8_activation_quant.py`:
- Around line 91-98: The loop over backends currently swallows all exceptions
and prints a vague "ERR" (for name, fn in backends.items() with bench_fn), make
failures visible: remove the bare except so exceptions propagate during
CI/development, or if you must handle them, catch Exception as e and print its
message and stack trace (e.g., via traceback.print_exc or processLogger.error)
instead of discarding e; also if you don't use the loop variable name, iterate
over backends.values() (or rename it to _ ) to avoid the unused-variable
warning.
- Around line 101-107: Guard the CUDA-only entry point by checking
torch.cuda.is_available() at the top of the if __name__ == "__main__" block
before calling torch.cuda.get_device_name(0) or
torch.cuda.get_device_capability(0); if CUDA is not available, print a clear
message and exit (or return) instead of calling those functions. Modify the main
block that calls run_benchmarks() so the device name/capability prints only
after the availability check, referencing the existing __main__ block,
torch.cuda.is_available(), torch.cuda.get_device_name,
torch.cuda.get_device_capability, and run_benchmarks to locate where to add the
guard.
In `@pymllm/tests/test_compressed_tensors_config.py`:
- Around line 121-123: Add a missing blank line between the two top-level test
functions: insert one additional empty line after the end of
test_get_quant_method_respects_ignore so there are two blank lines before the
definition of test_get_quant_method_rejects_unsupported_signature to satisfy PEP
8 E302; locate the functions by their names
test_get_quant_method_respects_ignore and
test_get_quant_method_rejects_unsupported_signature and ensure exactly two blank
lines separate these top-level function definitions.
In `@pymllm/tests/test_compressed_tensors_runtime.py`:
- Around line 306-315: The test uses very loose tolerances (atol=2e-1,
rtol=2e-1) which can hide real regressions; update the assertion comparing out
(from qm.apply) and ref (computed using ct._per_token_quant_int8 and
torch._int_mm) to use tighter tolerances such as atol=5e-2 and rtol=5e-2 (or
stronger when running with controlled K/scales), so that differences from
CUTLASS INT8 GEMM or scale-broadcast bugs are caught early; locate the
comparison around the variables x, bias, out, qm.apply,
ct._per_token_quant_int8, torch._int_mm, and layer.weight_scale and replace the
tolerance values accordingly.
In `@pymllm/tests/test_qwen3_weight_loading.py`:
- Around line 67-115: The two tests
test_load_weights_stacks_qkv_and_gate_up_from_model_prefix and
test_load_weights_accepts_model_language_model_prefix should be merged and
parametrized over layer_prefix to also include the "language_model" (no leading
"model.") case that _remap_weight_name handles; change them into a single
pytest.mark.parametrize test that calls _build_language_weights(cfg,
layer_prefix=layer_prefix), constructs Qwen3ForCausalLM(cfg), calls
model.load_weights(...), and then runs the same q/k/v and gate/up and
embedding/lm_head assertions using the parameterized layer_prefix so all three
remap branches ( "model", "model.language_model", "language_model") are covered
and the duplicated test bodies (referencing Qwen3ForCausalLM,
_build_language_weights, model.load_weights, layerX.self_attn.qkv_proj, and
layerX.mlp.gate_up_proj) are consolidated.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: ff521d21-fe30-425b-8334-d94821c10a06
📒 Files selected for processing (38)
mllm-kernel/benchmarks/bench_int8_scaled_mm.py
mllm-kernel/benchmarks/bench_w4a16_vs_w8a8.py
mllm-kernel/cmake/CPM.cmake
mllm-kernel/include/mllm_kernel/scalar_type.hpp
mllm-kernel/mllm_kernel/cuda/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h
mllm-kernel/mllm_kernel/cuda/csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h
mllm-kernel/mllm_kernel/cuda/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h
mllm-kernel/mllm_kernel/cuda/csrc/gemm/int8/int8_scaled_mm_cutlass.cu
mllm-kernel/mllm_kernel/cuda/csrc/gemm/marlin/marlin.cuh
mllm-kernel/mllm_kernel/cuda/jit/__init__.py
mllm-kernel/mllm_kernel/cuda/jit/gptq_marlin.py
mllm-kernel/mllm_kernel/cuda/jit/gptq_marlin_repack.py
mllm-kernel/mllm_kernel/cuda/jit/int8_scaled_mm_cutlass.py
mllm-kernel/tests/test_gptq_marlin.py
mllm-kernel/tests/test_gptq_marlin_repack.py
mllm-kernel/tests/test_int8_scaled_mm_cutlass.py
pymllm/README-ZH.md
pymllm/README.md
pymllm/executor/model_runner.py
pymllm/layers/rms_norm.py
pymllm/models/__init__.py
pymllm/models/qwen3.py
pymllm/models/qwen3_vl.py
pymllm/orchestrator/detokenizer_process.py
pymllm/orchestrator/model_runner_process.py
pymllm/orchestrator/scheduler_process.py
pymllm/orchestrator/tokenizer_process.py
pymllm/quantization/kernels/__init__.py
pymllm/quantization/kernels/int8_activation_triton.py
pymllm/quantization/methods/__init__.py
pymllm/quantization/methods/compressed_tensors.py
pymllm/server/launch.py
pymllm/tests/bench_w8a8_activation_quant.py
pymllm/tests/test_compressed_tensors_config.py
pymllm/tests/test_compressed_tensors_runtime.py
pymllm/tests/test_qwen3_forward_timing.py
pymllm/tests/test_qwen3_model_registry.py
pymllm/tests/test_qwen3_weight_loading.py
```
Usage:
    cd /workspace/.worktrees/pymllm-qwen3-vl-w8a8
    python3 mllm-kernel/benchmarks/bench_w4a16_vs_w8a8.py
"""
```
Replace the hardcoded developer worktree path in the docstring.
cd /workspace/.worktrees/pymllm-qwen3-vl-w8a8 is specific to a single developer's environment and won't help users running the benchmark from a fresh clone.
📝 Suggested doc fix
```diff
 Usage:
-    cd /workspace/.worktrees/pymllm-qwen3-vl-w8a8
-    python3 mllm-kernel/benchmarks/bench_w4a16_vs_w8a8.py
+    # From the repository root:
+    python3 mllm-kernel/benchmarks/bench_w4a16_vs_w8a8.py
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@mllm-kernel/benchmarks/bench_w4a16_vs_w8a8.py` around lines 6 - 9, The
docstring in bench_w4a16_vs_w8a8.py contains a hardcoded developer worktree
path; replace that line with a generic, portable instruction such as "cd to the
repository root" or "run from the repository root" and/or show a relative
invocation like "python3 mllm-kernel/benchmarks/bench_w4a16_vs_w8a8.py" so users
know to run the script from the repo root instead of the specific
/workspace/.worktrees/... path.
```cpp
  Params const& params_;
  SharedStorage& shared_storage_;
  MatrixCoord extent_;
  MatrixCoord extent_real_;
  ElementwiseFunctor elementwise_;

  bool const with_bias_;
  bool const per_token_quant_;
  bool const per_channel_quant_;

  AlphaScaleElementType* ptr_alpha_row_;
  AlphaScaleElementType* ptr_alpha_col_;
  ScaleTileIterator iterator_alpha_col_;
  OutputTileIterator iterator_C_;
  OutputTileIterator iterator_D_;

  AlphaScaleElementType element_alpha_row_ = 1.0f;
  AlphaScaleElementType element_alpha_col_ = 1.0f;
  typename ScaleTileIterator::Fragment fragment_alpha_col_;
  typename OutputTileIterator::Fragment fragment_C_;
  typename OutputTileIterator::Fragment fragment_D_;

  ElementAccumulator beta_;

  int column_offset_;

  MatrixCoord thread_offset_;

 public:
  CUTLASS_DEVICE
  EpilogueVisitorPerRowPerCol(
      Params const& params,
      SharedStorage& shared_storage,
      cutlass::MatrixCoord const& problem_size,
      int thread_idx,
      int warp_idx,
      int lane_idx,
      typename ScaleTileIterator::Params params_alpha_col,
      typename OutputTileIterator::Params params_C,
      typename OutputTileIterator::Params params_D,
      bool with_bias,
      bool per_token_quant,
      bool per_channel_quant,
      AlphaScaleElementType* ptr_alpha_row,
      AlphaScaleElementType* ptr_alpha_col,
      typename OutputTileIterator::Element* ptr_C,
      typename OutputTileIterator::Element* ptr_D,
      cutlass::MatrixCoord const& threadblock_offset = cutlass::MatrixCoord(0, 0),
      int column_offset = 0,
      cutlass::MatrixCoord const& problem_size_real = cutlass::MatrixCoord(0, 0))
      : params_(params),
        shared_storage_(shared_storage),
        extent_(problem_size),
        elementwise_(params.elementwise),
        with_bias_(with_bias),
        per_token_quant_(per_token_quant),
        per_channel_quant_(per_channel_quant),
        ptr_alpha_row_(ptr_alpha_row),
        ptr_alpha_col_(ptr_alpha_col),
        iterator_alpha_col_(params_alpha_col, ptr_alpha_col, problem_size, thread_idx, threadblock_offset),
        iterator_C_(params_C, ptr_C, problem_size, thread_idx, threadblock_offset),
        iterator_D_(params_D, ptr_D, problem_size, thread_idx, threadblock_offset),
        extent_real_(problem_size_real) {
    if (!per_channel_quant_ && (ptr_alpha_col_ != nullptr)) {
      element_alpha_col_ = *ptr_alpha_col_;
    }

    if (!per_token_quant_ && (ptr_alpha_row_ != nullptr)) {
      element_alpha_row_ = *ptr_alpha_row_;
    }
  }
```
🧩 Analysis chain
🏁 Script executed:
```bash
cat -n mllm-kernel/mllm_kernel/cuda/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h | sed -n '110,185p'
```
Repository: UbiquitousLearning/mllm
Length of output: 3202
Member-initializer list order doesn't match declaration order — emits -Wreorder.
extent_real_ is declared at line 116 (4th member) but initialized last at line 175, and elementwise_ is declared at line 117 but initialized at line 166 before with_bias_/per_token_quant_/per_channel_quant_. Since C++ always initializes members in declaration order, the apparent order in the initializer list is misleading and any compiler with -Wreorder (gcc/clang default in many builds) will warn here. There's no functional bug because none of the earlier initializers read extent_real_, but reordering keeps the intent and the actual behavior in sync.
♻️ Proposed reorder of the initializer list
```diff
       : params_(params),
         shared_storage_(shared_storage),
         extent_(problem_size),
+        extent_real_(problem_size_real),
         elementwise_(params.elementwise),
         with_bias_(with_bias),
         per_token_quant_(per_token_quant),
         per_channel_quant_(per_channel_quant),
         ptr_alpha_row_(ptr_alpha_row),
         ptr_alpha_col_(ptr_alpha_col),
         iterator_alpha_col_(params_alpha_col, ptr_alpha_col, problem_size, thread_idx, threadblock_offset),
         iterator_C_(params_C, ptr_C, problem_size, thread_idx, threadblock_offset),
-        iterator_D_(params_D, ptr_D, problem_size, thread_idx, threadblock_offset),
-        extent_real_(problem_size_real) {
+        iterator_D_(params_D, ptr_D, problem_size, thread_idx, threadblock_offset) {
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In
`@mllm-kernel/mllm_kernel/cuda/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h`
around lines 113 - 183, The constructor EpilogueVisitorPerRowPerCol has its
member-initializer list out of declaration order causing -Wreorder; reorder the
initializer list to match the class member declaration order (Params const&
params_, SharedStorage& shared_storage_, MatrixCoord extent_, MatrixCoord
extent_real_, ElementwiseFunctor elementwise_, bool with_bias_, bool
per_token_quant_, bool per_channel_quant_, AlphaScaleElementType*
ptr_alpha_row_, AlphaScaleElementType* ptr_alpha_col_, ScaleTileIterator
iterator_alpha_col_, OutputTileIterator iterator_C_, OutputTileIterator
iterator_D_, ...), i.e., move extent_real_ to be initialized immediately after
extent_ and ensure elementwise_ follows extent_real_ (and the rest follow their
declared order) while keeping the existing initialization expressions and
runtime logic unchanged.
```cpp
    CUTLASS_HOST_DEVICE
    Params()
        : swizzle_log_tile(0),
          params_A(0),
          params_B(0),
          params_alpha_col(0),
          params_C(0),
          params_D(0),
          batch_count(0),
          gemm_k_size(0),
          mode(cutlass::gemm::GemmUniversalMode::kGemm),
          ptr_A(nullptr),
          ptr_B(nullptr),
          ptr_alpha_col(nullptr),
          ptr_alpha_row(nullptr),
          ptr_C(nullptr),
          ptr_D(nullptr),
          batch_stride_A(0),
          batch_stride_B(0) {}
```
Default Params() constructor omits params_alpha_row.
ptr_alpha_row is explicitly nulled but the corresponding params_alpha_row member is never listed in the initializer list (compare with params_alpha_col(0)). If ScaleTileIterator::Params is not trivially default-constructible with the right zero-stride state, this leaves it in an inconsistent state vs. its _col counterpart. Add the matching initializer for symmetry and to avoid surprises if ScaleTileIterator::Params changes.
🛡️ Proposed fix

```diff
       params_alpha_col(0),
+      params_alpha_row(0),
       params_C(0),
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```cpp
CUTLASS_HOST_DEVICE
Params()
    : swizzle_log_tile(0),
      params_A(0),
      params_B(0),
      params_alpha_col(0),
      params_alpha_row(0),
      params_C(0),
      params_D(0),
      batch_count(0),
      gemm_k_size(0),
      mode(cutlass::gemm::GemmUniversalMode::kGemm),
      ptr_A(nullptr),
      ptr_B(nullptr),
      ptr_alpha_col(nullptr),
      ptr_alpha_row(nullptr),
      ptr_C(nullptr),
      ptr_D(nullptr),
      batch_stride_A(0),
      batch_stride_B(0) {}
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In
`@mllm-kernel/mllm_kernel/cuda/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h`
around lines 184 - 202, The Params() constructor in gemm_with_epilogue_visitor.h
initializes params_alpha_col but omits params_alpha_row, leaving
params_alpha_row potentially uninitialized; update the Params() member
initializer list for the Params() constructor to include params_alpha_row(0)
(mirroring params_alpha_col(0)) so ScaleTileIterator::Params for the row scale
is set to the zero/empty state consistent with ptr_alpha_row being nullptr.
```cpp
Params(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape_, int gemm_k_size_, int* workspace_)
    : problem_size(args.problem_size),
      swizzle_log_tile(0),
      params_A(args.ref_A.layout()),
      params_B(args.ref_B.layout()),
      params_alpha_col(args.ref_alpha_col.layout()),
      params_alpha_row(args.ref_alpha_col.layout()),
      params_C(args.ref_C.layout()),
      params_D(args.ref_D.layout()),
      mode(args.mode),
      batch_count(args.batch_count),
      gemm_k_size(args.problem_size.k()),
      ptr_A(args.ref_A.data()),
      ptr_B(args.ref_B.data()),
      ptr_alpha_col(args.ref_alpha_col.data()),
      ptr_alpha_row(args.ref_alpha_row.data()),
      ptr_C(args.ref_C.data()),
      ptr_D(args.ref_D.data()),
      batch_stride_A(args.batch_stride_A),
      batch_stride_B(args.batch_stride_B),
      epilogue_visitor(args.epilogue_visitor) {
```
🧩 Analysis chain
🏁 Script executed:

```bash
#!/bin/bash
# Confirm both Params have separate row/col layout sources in upstream TRT-LLM.
curl -fsSL \
  "https://raw.githubusercontent.com/NVIDIA/TensorRT-LLM/be1788106245496872d18e702978e59b6bfd50e0/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h" \
  | grep -nE 'params_alpha_(row|col)\s*\('
```

(Repository: UbiquitousLearning/mllm)
Fix copy-paste bug: params_alpha_row initialized from wrong layout reference.
The alpha_row scale iterator's Params are being constructed from the column reference's layout. Since LayoutAlphaCol is RowMajor and LayoutAlphaRow is ColumnMajor (lines 56-57), this produces incorrect strides for the row scale iterator. This bug also exists in the upstream TRT-LLM source (commit be178810...), so it should be fixed here regardless.
Suggested fix

```diff
       params_alpha_col(args.ref_alpha_col.layout()),
-      params_alpha_row(args.ref_alpha_col.layout()),
+      params_alpha_row(args.ref_alpha_row.layout()),
       params_C(args.ref_C.layout()),
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In
`@mllm-kernel/mllm_kernel/cuda/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h`
around lines 204 - 224, The initializer list for Params mistakenly constructs
params_alpha_row from the column layout
(params_alpha_row(args.ref_alpha_col.layout())); change it to use the actual row
reference layout by initializing params_alpha_row with
args.ref_alpha_row.layout() in the Params constructor so the alpha_row scale
iterator gets correct ColumnMajor strides (update the initializer near
params_alpha_col/params_alpha_row and verify ptr_alpha_row uses
args.ref_alpha_row.data()).
```python
def gptq_marlin_repack(
    b_q_weight: torch.Tensor,
    perm: Optional[torch.Tensor],
    size_k: int,
    size_n: int,
    num_bits: int,
) -> torch.Tensor:
    """Repack GPTQ/Compressed-Tensors weights into Marlin layout."""
    pack_factor = 32 // num_bits
    tile_size = 16
    out = torch.empty(
        (size_k // tile_size, size_n * tile_size // pack_factor),
        dtype=b_q_weight.dtype,
        device=b_q_weight.device,
    )
    kernel = _make_gptq_marlin_repack_kernel()
    perm_t = _normalize_perm(perm, size_k, b_q_weight.device)
    kernel(b_q_weight, perm_t, out, size_k, size_n, num_bits)
    return out
```
Add input validation for num_bits, size_k, size_n, and b_q_weight.dtype.
pack_factor = 32 // num_bits and the output shape (size_k // tile_size, size_n * tile_size // pack_factor) silently truncate / divide-by-zero on bad inputs:
- `num_bits == 0` raises a confusing `ZeroDivisionError` (instead of a clear validation message).
- `num_bits` values that don't evenly divide 32 (e.g. 5, 6, 7) silently produce a wrong layout.
- `size_k % tile_size != 0` or `(size_n * tile_size) % pack_factor != 0` silently truncate the output dims.
- The Marlin repack kernel expects `b_q_weight` to be `int32` (matching how `layer.weight_packed` is stored in `pymllm/quantization/methods/compressed_tensors.py` lines 329–344); a wrong dtype would otherwise be caught only deep inside the CUDA kernel.
The function is invoked once per quantized linear at load time, so the cost of these checks is negligible.
🛡️ Proposed validation
def gptq_marlin_repack(
b_q_weight: torch.Tensor,
perm: Optional[torch.Tensor],
size_k: int,
size_n: int,
num_bits: int,
) -> torch.Tensor:
"""Repack GPTQ/Compressed-Tensors weights into Marlin layout."""
- pack_factor = 32 // num_bits
- tile_size = 16
+ if num_bits not in (4, 8):
+ raise ValueError(f"num_bits must be 4 or 8, got {num_bits}")
+ if b_q_weight.dtype != torch.int32:
+ raise ValueError(
+ f"b_q_weight must be int32, got {b_q_weight.dtype}"
+ )
+ tile_size = 16
+ pack_factor = 32 // num_bits
+ if size_k % tile_size != 0:
+ raise ValueError(
+ f"size_k ({size_k}) must be a multiple of tile_size ({tile_size})"
+ )
+ if (size_n * tile_size) % pack_factor != 0:
+ raise ValueError(
+ f"size_n*tile_size ({size_n * tile_size}) must be a multiple of "
+ f"pack_factor ({pack_factor})"
+ )
out = torch.empty(
(size_k // tile_size, size_n * tile_size // pack_factor),
dtype=b_q_weight.dtype,
device=b_q_weight.device,
    )

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@mllm-kernel/mllm_kernel/cuda/jit/gptq_marlin_repack.py` around lines 56 - 75,
Add explicit input validation at the start of gptq_marlin_repack: check num_bits
is a positive divisor of 32 (raise ValueError if num_bits <= 0 or 32 % num_bits
!= 0), ensure tile_size alignment by validating size_k % tile_size == 0 (use
tile_size = 16 as in the function) and ensure (size_n * tile_size) % (32 //
num_bits) == 0 to avoid truncation, and verify b_q_weight.dtype is torch.int32
(raise TypeError if not). Use the existing symbols gptq_marlin_repack,
pack_factor (computed as 32 // num_bits), tile_size, and b_q_weight to locate
where to add these checks and raise clear, descriptive exceptions before calling
_normalize_perm or the kernel.
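For reference, the checks spelled out above can also be read as one standalone helper. This is only a sketch; the name `_validate_repack_args` is hypothetical and not part of the module:

```python
import torch

def _validate_repack_args(
    b_q_weight: torch.Tensor, size_k: int, size_n: int, num_bits: int
) -> None:
    """Hypothetical helper mirroring the validation proposed above."""
    if num_bits <= 0 or 32 % num_bits != 0:
        raise ValueError(f"num_bits must be a positive divisor of 32, got {num_bits}")
    if b_q_weight.dtype != torch.int32:
        raise TypeError(f"b_q_weight must be int32, got {b_q_weight.dtype}")
    tile_size = 16
    pack_factor = 32 // num_bits
    if size_k % tile_size != 0:
        raise ValueError(f"size_k ({size_k}) must be a multiple of {tile_size}")
    if (size_n * tile_size) % pack_factor != 0:
        raise ValueError(
            f"size_n * tile_size ({size_n * tile_size}) must be divisible by "
            f"pack_factor ({pack_factor})"
        )
```

Because the function runs once per quantized linear at load time, the extra branches have no measurable cost.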
```python
_llm_t0 = time.perf_counter()
hidden_states = self.model(input_ids, positions, forward_batch)
_llm_ms = (time.perf_counter() - _llm_t0) * 1000.0

if forward_batch.forward_mode.is_extend():
    forward_batch.llm_prefill_ms = _llm_ms
    forward_batch.llm_decode_ms = None
else:
    forward_batch.llm_decode_ms = _llm_ms
```
Timing without torch.cuda.synchronize() measures launch overhead, not GPU execution.
time.perf_counter() deltas around self.model(...) capture only the time spent issuing CUDA work to the stream — the kernels are still running asynchronously when _llm_ms is computed. For decode (single token), this will systematically under-report by orders of magnitude; for prefill it's noisier. Since these values feed forward_batch.llm_prefill_ms / llm_decode_ms (consumed downstream via the orchestrator/server timing path), the "prefill_ms" reported to operators won't reflect actual prefill latency.
If the goal is engine-level perceived latency (good enough for SLO accounting since something downstream eventually syncs), please document that. If the goal is GPU-only model latency, sync explicitly:
🐛 Proposed fix (sync-based timing)
- _llm_t0 = time.perf_counter()
- hidden_states = self.model(input_ids, positions, forward_batch)
- _llm_ms = (time.perf_counter() - _llm_t0) * 1000.0
+ if torch.cuda.is_available():
+ torch.cuda.synchronize()
+ _llm_t0 = time.perf_counter()
+ hidden_states = self.model(input_ids, positions, forward_batch)
+ if torch.cuda.is_available():
+ torch.cuda.synchronize()
+ _llm_ms = (time.perf_counter() - _llm_t0) * 1000.0(or use torch.cuda.Event + elapsed_time to avoid host syncs in steady state)
As per coding guidelines: "Add comments for complex algorithms or non-obvious logic." Either way, please add a one-line comment clarifying which timing semantics this field carries.
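For reference, a minimal sketch of the event-based variant mentioned in the parenthetical above; the `timed_forward` helper is illustrative, not existing code:

```python
import torch

def timed_forward(model, *args):
    """Return (output, elapsed_ms) using CUDA events instead of a device-wide host sync."""
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    out = model(*args)
    end.record()
    # Wait only for the end event; cheaper than torch.cuda.synchronize() in steady state.
    end.synchronize()
    return out, start.elapsed_time(end)  # milliseconds of GPU time between the events
```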
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@pymllm/models/qwen3.py` around lines 296 - 304, The timing around
self.model(...) currently measures CUDA launch overhead not actual GPU
execution; update the timing in the qwen3.py block that sets
forward_batch.llm_prefill_ms and forward_batch.llm_decode_ms to either (a) call
torch.cuda.synchronize() immediately after self.model(...) before computing
_llm_ms so the measured delta reflects real GPU work, or (b) use
torch.cuda.Event start/stop + elapsed_time for GPU timings; then set
forward_batch.llm_prefill_ms/llm_decode_ms accordingly and add a one-line
comment beside this logic (referencing forward_batch.forward_mode.is_extend(),
forward_batch.llm_prefill_ms, forward_batch.llm_decode_ms, and self.model)
documenting whether the field represents engine-perceived latency (asynchronous
launch time) or actual GPU execution time (synchronized or event-based).
```python
# If AutoProcessor produced multimodal input_ids, they must override
# the plain tokenizer result. Otherwise the prompt contains only a
# single image placeholder token and won't match the visual features.
if mm_inputs is not None:
    image_inputs = mm_inputs.get("image_inputs")
    if image_inputs is not None and "input_ids" in image_inputs:
        proc_input_ids = image_inputs["input_ids"]
        if hasattr(proc_input_ids, "ndim") and proc_input_ids.ndim > 1:
            proc_input_ids = proc_input_ids[0]
        if hasattr(proc_input_ids, "tolist"):
            input_ids = proc_input_ids.tolist()
        else:
            input_ids = list(proc_input_ids)
```
Multimodal input_ids bypass context_length truncation.
The text-only path applies truncation=True, max_length=self._context_length on line 361-364, but when mm_inputs["image_inputs"]["input_ids"] is present it unconditionally overrides input_ids without re-checking the context cap. For models that expand image placeholders into many vision tokens (Qwen3-VL), this can silently produce sequences exceeding the model's maximum position embeddings and trigger downstream errors or OOM.
🛡️ Suggested guard
if hasattr(proc_input_ids, "tolist"):
input_ids = proc_input_ids.tolist()
else:
input_ids = list(proc_input_ids)
+ if (
+ self._context_length is not None
+ and len(input_ids) > self._context_length
+ ):
+ logger.warning(
+ "Multimodal input_ids (%d) exceed context_length=%d; truncating.",
+ len(input_ids),
+ self._context_length,
+ )
+    input_ids = input_ids[: self._context_length]

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@pymllm/orchestrator/tokenizer_process.py` around lines 374 - 386, When
mm_inputs provides multimodal input_ids (image_inputs["input_ids"]) we currently
override input_ids without enforcing the model context cap; after extracting
proc_input_ids in the block handling mm_inputs/image_inputs, truncate
proc_input_ids to self._context_length before converting to a Python list or
assigning to input_ids (e.g., slice ndarray or list to [:self._context_length]
or use equivalent method for objects exposing tolist), so input_ids never exceed
the model's context length; reference symbols: mm_inputs, image_inputs,
proc_input_ids, input_ids, self._context_length.
```python
assert x.is_contiguous(), "Input must be contiguous"

M = x.numel() // x.shape[-1]
N = x.shape[-1]
x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
scales = torch.empty(
    x.shape[:-1] + (1,), device=x.device, dtype=scale_dtype
)

BLOCK = triton.next_power_of_2(N)
num_warps = min(max(BLOCK // 256, 1), 8)

_per_token_quant_int8[(M,)](
    x,
    x_q,
    scales,
    stride_x=x.stride(-2),
    stride_xq=x_q.stride(-2),
    N=N,
    BLOCK=BLOCK,
    num_warps=num_warps,
    num_stages=1,
)
return x_q, scales
```
🧩 Analysis chain
🌐 Web query:
Does Triton 3.6.0 support tl.extra.cuda.libdevice.round on Jetson Orin (SM87) with CUDA 12.6?
💡 Result:
Yes, Triton 3.6.0 supports tl.extra.cuda.libdevice.round on Jetson Orin (SM_87) with CUDA 12.6.
Citations:
- 1: https://triton-lang.org/main/getting-started/tutorials/07-extern-functions.html
- 2: https://github.com/triton-lang/triton/releases/tag/v3.6.0
- 3: PyTorch Nightly's Torch Compile Feature Does Not Work on THOR due to Triton 3.6.x regression. pytorch/pytorch#170117
- 4: [Installation]: PyTorch Version Conflict on NVIDIA Jetson AGX Orin (aarch64) with CUDA 12.6 vllm-project/vllm#15169
- 5: https://docs.nvidia.com/jetson/archives/jetpack-archived/jetpack-62/release-notes/index.html
- 6: https://developer.nvidia.com/embedded/jetson-linux-r363
🏁 Scripts executed: located `pymllm/quantization/kernels/int8_activation_triton.py`, read lines 45-90 around the wrapper and its docstring, and searched the repository for callers of `per_token_quant_int8` / `_per_token_quant_int8`, including the `CompressedTensorsW8A8Int8Scheme.apply` path in `compressed_tensors.py`, `mllm-kernel/benchmarks/bench_w4a16_vs_w8a8.py`, and `pymllm/tests/bench_w8a8_activation_quant.py` (repository: UbiquitousLearning/mllm).
x.stride(-2) will raise IndexError on 1-D inputs despite the docstring claiming "any shape".
The docstring (Line 52) states "Input tensor, any shape with last dim = hidden_dim", but for x.ndim == 1, x.stride(-2) raises IndexError: Dimension out of range. Currently dormant—the sole production caller (compressed_tensors.py line ~156) reshapes to 2-D before calling—but this is an undocumented constraint that becomes a footgun for direct callers or benchmarks.
Also, assert x.is_contiguous() is stripped under python -O; use an explicit ValueError check instead.
Suggested fix
- assert x.is_contiguous(), "Input must be contiguous"
-
- M = x.numel() // x.shape[-1]
+ if not x.is_contiguous():
+ raise ValueError("Input must be contiguous")
+ if x.ndim < 2:
+ x = x.view(1, -1)
+
N = x.shape[-1]
+ M = x.numel() // N
x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
scales = torch.empty(
- x.shape[:-1] + (1,), device=x.device, dtype=scale_dtype
+ (*x.shape[:-1], 1), device=x.device, dtype=scale_dtype
    )

The shape rewrite also picks up Ruff RUF005.
🧰 Tools
🪛 Ruff (0.15.12)
[warning] 65-65: Consider (*x.shape[:-1], 1) instead of concatenation
Replace with (*x.shape[:-1], 1)
(RUF005)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@pymllm/quantization/kernels/int8_activation_triton.py` around lines 59 - 82,
Replace the brittle assert and the risky negative 2-index stride access: in the
function that calls _per_token_quant_int8 (the block around the
_per_token_quant_int8[(M,)](...) call), replace "assert x.is_contiguous()" with
an explicit check that raises ValueError if not contiguous, and compute stride_x
and stride_xq defensively (use x.stride(-2) / x_q.stride(-2) when x.ndim >= 2
else fall back to x.stride(0) / x_q.stride(0)); pass those computed stride_x and
stride_xq into the _per_token_quant_int8 call. This keeps the documented "any
shape" behavior and avoids IndexError on 1-D inputs and also fixes the assert
being stripped under python -O.
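For direct callers and benchmarks, the 1-D case can be handled today by reshaping before the call. A usage sketch, assuming the public wrapper is `per_token_quant_int8` in `pymllm/quantization/kernels/int8_activation_triton` (import path and exact signature are inferred from this PR, not verified):

```python
import torch

from pymllm.quantization.kernels.int8_activation_triton import per_token_quant_int8

x = torch.randn(4096, dtype=torch.float16, device="cuda")  # 1-D activation vector

# The kernel currently indexes stride(-2), so give it an explicit (M, N) view.
x_q, scales = per_token_quant_int8(x.view(1, -1))
x_q = x_q.view_as(x)       # back to the original 1-D shape
scale = scales.squeeze()   # single per-token scale for the one row
```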
| {"type": "text", "text": "请详细描述这张图片。"}, | ||
| {"type": "image_url", "image_url": {"url": "/workspace/xcd_mllm/test.png"}}, | ||
| ], |
Hard-coded developer path leaked into example.
/workspace/xcd_mllm/test.png is a developer-specific absolute path that won't exist for anyone else. Replace with a placeholder so users don't blindly copy and hit a confusing FileNotFound:
📝 Suggested fix
{"type": "text", "text": "请详细描述这张图片。"},
- {"type": "image_url", "image_url": {"url": "/workspace/xcd_mllm/test.png"}},
+ {"type": "image_url", "image_url": {"url": "/path/to/your/image.png"}},📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| {"type": "text", "text": "请详细描述这张图片。"}, | |
| {"type": "image_url", "image_url": {"url": "/workspace/xcd_mllm/test.png"}}, | |
| ], | |
| {"type": "text", "text": "请详细描述这张图片。"}, | |
| {"type": "image_url", "image_url": {"url": "/path/to/your/image.png"}}, | |
| ], |
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@pymllm/README-ZH.md` around lines 199 - 201, The example in the README
currently hard-codes a developer-specific absolute path
("/workspace/xcd_mllm/test.png") inside the JSON snippet for the "image_url"
entry; replace that literal with a neutral placeholder (e.g.,
"<path/to/image.png>" or a sample relative/remote URL) so users won’t copy a
non-existent path—update the "image_url" -> "url" value in the JSON array shown
in the README-ZH.md accordingly.
- JetPack / L4T: `R36.4.4` (`/etc/nv_tegra_release`)
- Python: `3.10.12`
- PyTorch: `2.4.0`
- torchvision: `0.19.0a0+48b1edf`
- transformers: `5.3.0`
- safetensors: `0.7.0`
- flashinfer: `0.6.7`
- Triton Language: official PyPI `triton==3.6.0` manylinux aarch64 wheel
- CUDA: `12.6`
- GPU: Jetson Orin NX, SM87
🧩 Analysis chain
🏁 Scripts executed: queried PyPI for the most recent releases of transformers, safetensors, and flashinfer, then checked the specific versions `transformers==5.3.0`, `safetensors==0.7.0`, and `flashinfer==0.6.7`, and finally whether any `flashinfer` package exists on PyPI at all (repository: UbiquitousLearning/mllm).
Fix flashinfer dependency reference — version and package do not exist on PyPI.
Verification found that transformers 5.3.0 and safetensors 0.7.0 exist on PyPI and are correct. However, flashinfer 0.6.7 cannot be obtained: neither this version nor any flashinfer package is available on PyPI. The validated environment table lists this non-existent dependency, making the documented setup irreproducible. Either remove flashinfer from the table, clarify if it should come from a different source (e.g., custom build, private repository), or provide correct installation instructions.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@pymllm/README.md` around lines 50 - 60, The README's environment table lists
a non-existent PyPI package/version "flashinfer: `0.6.7`" which makes the setup
unreproducible; update the README.md entry referencing "flashinfer 0.6.7" by
either removing the flashinfer line, replacing it with the correct source (e.g.,
a Git URL, private repo, or build instructions), or adding explicit install
instructions and provenance (e.g., how to build/install from source or a link to
the correct release). Edit the list block in README.md where "flashinfer:
`0.6.7`" appears so the table accurately reflects how to obtain or install
flashinfer, and include the concrete install command or URL if you choose to
keep the dependency.
♻️ Duplicate comments (1)
pymllm/layers/rms_norm.py (1)
52-54: ⚠️ Potential issue | 🟠 Major
Avoid bare `Exception` in FlashInfer fallbacks; narrow and log instead.
Line 52 and Line 65 currently swallow all exceptions and silently fall back. In this core serving path, that can hide real defects and makes failures hard to diagnose. Catch expected runtime failures and emit a warning with context.
Proposed patch

```diff
 from typing import Optional, Tuple, Union
+import logging
 import torch
 import flashinfer
 from torch.nn import Parameter
@@
 from pymllm.layers.base import MllmBaseLayer
 from pymllm.layers.utils import set_weight_attrs
+logger = logging.getLogger(__name__)
+
@@
-        except Exception:
+        except (RuntimeError, ValueError) as e:
+            logger.warning(
+                "flashinfer fused_add_rmsnorm failed; using torch fallback: %s", e
+            )
             residual = x + residual
             return _torch_rmsnorm(residual, self.weight, self.eps), residual
@@
-        except Exception:
+        except (RuntimeError, ValueError) as e:
+            logger.warning("flashinfer rmsnorm failed; using torch fallback: %s", e)
             return _torch_rmsnorm(x, self.weight, self.eps)
```

As per coding guidelines, "Ensure functions that can fail return appropriate error codes or raise exceptions" and "Add appropriate logging (e.g., debug, info, warning, error) for significant events and errors, avoiding sensitive data exposure".

For the FlashInfer version used by this repository, what exception classes are documented or commonly raised by `flashinfer.norm.fused_add_rmsnorm` and `flashinfer.norm.rmsnorm` on unsupported shapes/dtypes/runtime failures?

Also applies to: 65-66
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@pymllm/layers/rms_norm.py` around lines 52 - 54, Replace the bare "except Exception:" handlers around the FlashInfer fallbacks (the blocks calling flashinfer.norm.fused_add_rmsnorm and flashinfer.norm.rmsnorm and returning _torch_rmsnorm(...)) with a narrow catch for expected runtime errors (e.g., except (RuntimeError, ValueError, TypeError) as e:), log a warning that includes function/context, tensor shapes/dtypes and the exception message, then perform the fallback; re-raise any other unexpected exceptions (use "except Exception: raise" or avoid catching them). Update both locations that currently swallow all exceptions so only anticipated FlashInfer failures are logged and handled while other errors bubble up.
🧹 Nitpick comments (2)
pymllm/bench_one_batch.py (2)
398-544: Document the runner lifecycle and failure contract.
`PymllmBenchRunner` allocates request/KV slots directly and depends on a specific `create()` → `clear()`/`extend()`/`decode()` → `shutdown()` lifecycle. A class docstring plus method docs for preconditions, ownership, and errors would make this much safer to reuse.
As per coding guidelines, ensure public APIs, classes, and functions have clear docstrings or comments explaining purpose, parameters, returns, and errors.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@pymllm/bench_one_batch.py` around lines 398 - 544, Add a clear class-level docstring to PymllmBenchRunner describing the required lifecycle (call create() to get an instance, then use clear()/extend()/decode() in that order as needed, and finally call shutdown()), the ownership semantics of resources (ModelRunner and its req_to_token_pool and token_to_kv_pool_allocator), and the failure contract (which methods raise RuntimeError or ValueError and in what situations). Also add brief docstrings to create(), clear(), extend(), decode(), and shutdown() noting their parameters, return types (e.g., extend()/decode() return (next_token_ids, DecodeState)), preconditions (e.g., _require_initialized or expected tensor shapes), and side effects (alloc/free of request/KV slots and that extend() will free allocated req slots on allocation failure), referencing the methods by name so callers can find the behavior easily.
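A possible shape for such a class docstring, purely illustrative (the method names come from the review above; the wording and error contract are not taken from the codebase):

```python
class PymllmBenchRunner:
    """One-shot benchmark harness around ModelRunner.

    Lifecycle: obtain an instance via ``create()``, then call ``extend()``
    (prefill) and ``decode()`` as needed, optionally ``clear()`` between
    runs, and always finish with ``shutdown()``.

    Ownership: the runner owns its ModelRunner and the request/KV-cache
    slots it allocates; ``shutdown()`` releases them.

    Errors: ``extend()``/``decode()`` raise RuntimeError when slot
    allocation fails or when called before ``create()``.
    """
```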
43-43: Avoid hardcoded `/tmp` defaults for benchmark artifacts.
Both defaults pin the script to a POSIX temp path and keep outputs in a predictable shared location. Prefer a platform-derived temp/cache directory and per-run names by default so this stays portable and avoids accidental collisions.
As per coding guidelines, ensure code is portable across supported platforms (e.g., Linux, Windows) unless explicitly platform-specific, and identify potential security issues (e.g., insecure temporary files) and recommend using secure alternatives.
Also applies to: 378-378
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@pymllm/bench_one_batch.py` at line 43, The hardcoded POSIX temp path assigned to result_filename should be replaced with a platform-safe, unique temp file path: import tempfile (and uuid.uuid4 or use tempfile.NamedTemporaryFile(delete=False, suffix=".jsonl")) and set result_filename to Path(tempfile.gettempdir()) / f"pymllm_bench_one_batch_{uuid4()}.jsonl" or create a NamedTemporaryFile and use its name, ensuring you add the necessary imports and use uuid4() or tempfile to avoid collisions and insecure predictable names; apply the same change to the other occurrence referenced (result filename at the later location).
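A portable default along those lines, as a sketch only (the filename pattern is illustrative):

```python
import tempfile
import uuid
from pathlib import Path

# Platform-derived temp directory plus a per-run unique name instead of a fixed /tmp path.
result_filename = str(
    Path(tempfile.gettempdir()) / f"pymllm_bench_one_batch_{uuid.uuid4().hex}.jsonl"
)
```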
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 245f8fba-3829-44a3-8397-37411ca4a31c
📒 Files selected for processing (4)
- pymllm/bench_one_batch.py
- pymllm/layers/rms_norm.py
- pymllm/tests/test_bench_one_batch.py
- pymllm/tests/test_rms_norm.py
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
pymllm/models/qwen3_vl.py (1)
1176-1190: ⚠️ Potential issue | 🟠 Major
These timing fields are batch totals, but the scheduler stores them per request.
The values written onto `forward_batch` here are single scalars for the whole batch. Downstream, `pymllm/orchestrator/scheduler_process.py` lines 792-806 copy them onto each `Request`, so multi-request batches will stamp the same aggregate `vit_prefill_ms`, `vit_prefill_tokens`, `llm_prefill_ms`, and `llm_decode_ms` onto every request.
Also applies to: 1224-1232
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@pymllm/models/qwen3_vl.py` around lines 1176 - 1190, The batch-level timing/count fields (vit_prefill_ms, vit_prefill_tokens, llm_prefill_ms, llm_decode_ms) written into forward_batch must be per-request, not single scalars; change the logic around forward_batch assignment so it stores per-request arrays. Specifically, compute per-sample image token counts from image_mask (use image_mask.sum(dim=1) or equivalent) and set vit_prefill_tokens as a list/array per request; split vit_prefill_ms (and llm_prefill_ms / llm_decode_ms at the other region noted around 1224-1232) across requests proportionally to each request's token count (or compute timings per-request earlier before batching) and write per-request values into forward_batch so scheduler_process.py will copy correct per-request metrics. Ensure references: forward_batch, vit_prefill_ms, vit_prefill_tokens, llm_prefill_ms, llm_decode_ms, image_mask, vision_embeds, input_embeds.
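One way the proportional split could look, as a sketch only (the helper name and the assumption that `image_mask` is a per-request boolean token mask come from the review text, not the actual code):

```python
import torch

def split_batch_timing(batch_ms: float, image_mask: torch.Tensor) -> list[float]:
    """Split a batch-level duration across requests, weighted by per-request token count."""
    tokens_per_req = image_mask.sum(dim=1)          # (num_requests,)
    total = int(tokens_per_req.sum().item()) or 1   # avoid division by zero
    return [batch_ms * int(n) / total for n in tokens_per_req.tolist()]
```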
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 056f15e3-9555-41ba-8a82-9d1ef20af346
📒 Files selected for processing (5)
- mllm-kernel/mllm_kernel/cuda/csrc/gemm/int8/int8_scaled_mm_cutlass.cu
- mllm-kernel/mllm_kernel/cuda/jit/int8_scaled_mm_cutlass.py
- mllm-kernel/tests/test_int8_scaled_mm_cutlass.py
- pymllm/models/qwen3_vl.py
- pymllm/tests/test_qwen3_vl_deepstack.py
🚧 Files skipped from review as they are similar to previous changes (1)
- mllm-kernel/mllm_kernel/cuda/jit/int8_scaled_mm_cutlass.py
```cpp
TORCH_CHECK(mat_a.is_cuda(), "mat_a must be CUDA tensor");
TORCH_CHECK(mat_b.is_cuda(), "mat_b must be CUDA tensor");
TORCH_CHECK(mat_a.dim() == 2, "mat_a must be 2D");
TORCH_CHECK(mat_b.dim() == 2, "mat_b must be 2D");
TORCH_CHECK(mat_a.stride(1) == 1, "mat_a must be row-major");
TORCH_CHECK(mat_b.stride(0) == 1, "mat_b must be column-major");
TORCH_CHECK(mat_a.size(1) == mat_b.size(0), "shape mismatch");
TORCH_CHECK(mat_a.size(1) % 16 == 0, "K must be multiple of 16");
TORCH_CHECK(mat_b.size(1) % 8 == 0, "N must be multiple of 8");
TORCH_CHECK(mat_a.scalar_type() == torch::kInt8, "mat_a must be Int8");
TORCH_CHECK(mat_b.scalar_type() == torch::kInt8, "mat_b must be Int8");
TORCH_CHECK(scales_a.numel() == mat_a.size(0), "scales_a size mismatch");
TORCH_CHECK(scales_b.numel() == mat_b.size(1), "scales_b size mismatch");
TORCH_CHECK(scales_a.scalar_type() == torch::kFloat32, "scales_a must be fp32");
TORCH_CHECK(scales_b.scalar_type() == torch::kFloat32, "scales_b must be fp32");
```
🧩 Analysis chain
🏁 Script executed:

```bash
fd -t f "int8_scaled_mm_cutlass.cu" --exec cat -n {} \; | head -400
```

(Repository: UbiquitousLearning/mllm)
Add strict same-device CUDA checks for all kernel inputs.
The current validation at lines 328-342 checks only mat_a/mat_b CUDA residency. scales_a, scales_b, bias, and cross-GPU device equality are not enforced. Since raw pointers are extracted at lines 111-115 and the kernel launches on mat_a's stream (line 141), mixed-device or CPU tensors will cause the kernel to receive invalid pointers, resulting in segfaults or silent corruption.
🔧 Proposed fix
torch::Tensor int8_scaled_mm(
const torch::Tensor& mat_a,
const torch::Tensor& mat_b,
@@
TORCH_CHECK(mat_a.is_cuda(), "mat_a must be CUDA tensor");
TORCH_CHECK(mat_b.is_cuda(), "mat_b must be CUDA tensor");
+ TORCH_CHECK(scales_a.is_cuda(), "scales_a must be CUDA tensor");
+ TORCH_CHECK(scales_b.is_cuda(), "scales_b must be CUDA tensor");
+ TORCH_CHECK(mat_a.get_device() == mat_b.get_device(),
+ "mat_a and mat_b must be on the same CUDA device");
+ TORCH_CHECK(mat_a.get_device() == scales_a.get_device(),
+ "scales_a must be on the same CUDA device as mat_a");
+ TORCH_CHECK(mat_a.get_device() == scales_b.get_device(),
+ "scales_b must be on the same CUDA device as mat_a");
@@
if (bias) {
+ TORCH_CHECK(bias->is_cuda(), "bias must be CUDA tensor");
+ TORCH_CHECK(bias->get_device() == mat_a.get_device(),
+ "bias must be on the same CUDA device as mat_a");
TORCH_CHECK(bias->numel() == mat_b.size(1), "bias size mismatch");
TORCH_CHECK(bias->dtype() == out_dtype, "bias dtype must match out_dtype");
  }

Also applies to: 353-356
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@mllm-kernel/mllm_kernel/cuda/csrc/gemm/int8/int8_scaled_mm_cutlass.cu` around
lines 328 - 343, Add strict same-device CUDA checks for all inputs before
extracting raw pointers and launching the kernel: verify scales_a, scales_b, and
bias (if provided) are CUDA tensors, have correct dtypes/numel as already
checked, and are on the same device as mat_a (use mat_a.device() /
mat_b.device() for comparison); also assert mat_b.device() == mat_a.device().
Update the validation block that currently checks mat_a and mat_b to include
these additional TORCH_CHECKs so the kernel never receives CPU or cross-GPU
pointers when launching on mat_a's stream.
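The same guard can also be mirrored defensively on the Python side of the JIT wrapper before dispatching to the op. A sketch, not the existing wrapper code (`_check_same_cuda_device` is a hypothetical helper):

```python
from typing import Optional

import torch

def _check_same_cuda_device(**tensors: Optional[torch.Tensor]) -> None:
    """Python-side mirror of the TORCH_CHECKs proposed above (hypothetical helper)."""
    devices = {name: t.device for name, t in tensors.items() if t is not None}
    if any(d.type != "cuda" for d in devices.values()):
        raise ValueError(f"all inputs must be CUDA tensors, got {devices}")
    if len(set(devices.values())) > 1:
        raise ValueError(f"all inputs must be on the same CUDA device, got {devices}")

# Example call site (names follow this PR's wrapper arguments):
# _check_same_cuda_device(mat_a=mat_a, mat_b=mat_b, scales_a=scales_a, scales_b=scales_b, bias=bias)
```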
```python
def cutlass_module():
    """Load CUTLASS module once for all tests."""
    pytest.importorskip("torch")
    if not torch.cuda.is_available():
        pytest.skip("CUDA required")
    from mllm_kernel.cuda.jit.int8_scaled_mm_cutlass import int8_scaled_mm
    return int8_scaled_mm
```
🧩 Analysis chain
🏁 Scripts executed: located the `int8_scaled_mm_cutlass` sources, searched for SM-version constraints, read `mllm-kernel/tests/test_int8_scaled_mm_cutlass.py`, and inspected the CUDA source and the JIT wrapper `mllm-kernel/mllm_kernel/cuda/jit/int8_scaled_mm_cutlass.py` for the runtime SM check (repository: UbiquitousLearning/mllm).
Add SM80+ GPU check to fixture to skip on unsupported hardware.
The kernel enforces SM80-SM89 at runtime via _current_cuda_arch() and raises RuntimeError on incompatible GPUs. The fixture only checks CUDA availability, causing tests to fail instead of skip on pre-SM80 GPUs. Add a device capability check in the fixture before importing the kernel module.
Proposed fix

```diff
 @pytest.fixture(scope="module")
 def cutlass_module():
     """Load CUTLASS module once for all tests."""
     pytest.importorskip("torch")
     if not torch.cuda.is_available():
         pytest.skip("CUDA required")
+    major, _ = torch.cuda.get_device_capability()
+    if major < 8:
+        pytest.skip("CUTLASS int8_scaled_mm requires SM80+ GPUs")
     from mllm_kernel.cuda.jit.int8_scaled_mm_cutlass import int8_scaled_mm
     return int8_scaled_mm
```

🤖 Prompt for AI Agents
return int8_scaled_mm🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@mllm-kernel/tests/test_int8_scaled_mm_cutlass.py` around lines 39 - 45, The
fixture cutlass_module should skip tests on GPUs below SM80; before importing
mllm_kernel.cuda.jit.int8_scaled_mm_cutlass, query
torch.cuda.get_device_properties(0).major and .minor, compute compute_capability
= major*10 + minor and if compute_capability < 80 call pytest.skip("SM80+ GPU
required") so the runtime check in _current_cuda_arch() doesn't raise
RuntimeError during import; add this check right after the existing CUDA
availability check in cutlass_module.
```python
_vit_t0 = time.perf_counter()
vision_features = (
    self.visual(pixel_values, grid_thw=image_grid_thw)
)
vit_prefill_ms = (time.perf_counter() - _vit_t0) * 1000.0
```
🧩 Analysis chain
🏁 Scripts executed: read the timing blocks in `pymllm/models/qwen3_vl.py` (around lines 1165-1230), searched for `vit_prefill_ms` / `llm_*_ms` usage and `torch.cuda.synchronize` calls, traced how the `ForwardBatch` timing fields flow through `pymllm/orchestrator/model_runner_process.py` and `scheduler_process.py` (where batch-level values are copied onto each request), and reviewed `pymllm/tests/test_qwen3_forward_timing.py` (repository: UbiquitousLearning/mllm).
Synchronize CUDA before measuring wall-clock timing around model calls.
time.perf_counter() around self.visual(...) and self.model(...) only measures kernel submission time on CUDA. Without device synchronization, these timers capture only the async launch overhead, not actual compute time. On GPU paths, vit_prefill_ms and llm_*_ms will be systematically underestimated.
Additionally, when a batch contains multiple requests, the batch-level timing value is copied to each request object downstream (in scheduler_process.py), misattributing a single batch duration to all individual requests within that batch.
⏱️ Suggested fix
+def _sync_if_cuda(device: torch.device) -> None:
+ if device.type == "cuda":
+ torch.cuda.synchronize(device)
+
...
- _vit_t0 = time.perf_counter()
+ _sync_if_cuda(self.visual.device)
+ _vit_t0 = time.perf_counter()
vision_features = (
self.visual(pixel_values, grid_thw=image_grid_thw)
)
- vit_prefill_ms = (time.perf_counter() - _vit_t0) * 1000.0
+ _sync_if_cuda(vision_features.device)
+ vit_prefill_ms = (time.perf_counter() - _vit_t0) * 1000.0
...
- _llm_t0 = time.perf_counter()
+ _sync_if_cuda(self.model.embed_tokens.weight.device)
+ _llm_t0 = time.perf_counter()
hidden_states = (
self.model(
input_ids,
@@
- _llm_ms = (time.perf_counter() - _llm_t0) * 1000.0
+ _sync_if_cuda(hidden_states.device)
+ _llm_ms = (time.perf_counter() - _llm_t0) * 1000.0Also applies to: 1212-1222
Actionable comments posted: 2
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: ec5a8fe4-ef55-440c-9d17-88355b4ab1fe
⛔ Files ignored due to path filters (1)
- docs/_static/img/pymllm-arch.png is excluded by `!**/*.png`
📒 Files selected for processing (12)
- docs/index.rst
- docs/pymllm_runtime/developer_guide.rst
- docs/pymllm_runtime/index.rst
- docs/pymllm_runtime/kernels_and_acceleration.rst
- docs/pymllm_runtime/models_and_quantization.rst
- docs/pymllm_runtime/runtime_design.rst
- docs/pymllm_runtime/setup_and_usage.rst
- pymllm/README-ZH.md
- pymllm/README.md
- pymllm/configs/server_config.py
- pymllm/server/launch.py
- pymllm/tests/test_server_debug_timing.py
✅ Files skipped from review due to trivial changes (9)
- docs/index.rst
- docs/pymllm_runtime/index.rst
- docs/pymllm_runtime/developer_guide.rst
- docs/pymllm_runtime/models_and_quantization.rst
- docs/pymllm_runtime/setup_and_usage.rst
- docs/pymllm_runtime/kernels_and_acceleration.rst
- pymllm/configs/server_config.py
- pymllm/README.md
- docs/pymllm_runtime/runtime_design.rst
### BF16 原生模型

```bash
cd <repo-root>

python3 -m pymllm.server.launch \
    --server.model_path <model-path> \
    --server.tokenizer_path <model-path> \
    --server.load_format safetensors \
    --server.dtype float16 \
    --server.host 0.0.0.0 \
```
BF16 section uses float16 in the launch example.
The section title says BF16 native serving, but the command at Line 142 sets --server.dtype float16. Please align the flag with the section intent to avoid user confusion.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@pymllm/README-ZH.md` around lines 133 - 143, The BF16 example in the "BF16
原生模型" section is inconsistent: change the launch command flag value from
"--server.dtype float16" to "--server.dtype bf16" (or "bfloat16" if that is the
CLI's accepted token) so the --server.dtype setting matches the BF16 section
title; update the example invocation in the README snippet (the line containing
"--server.dtype float16") to use the bf16 token used by the server CLI.
```python
_maybe_add_debug_timing(
    payload,
    result=results[-1] if results else {},
    prompt_tokens=prompt_tokens,
    completion_tokens=completion_tokens,
)
```
Avoid mixing aggregated token counts with single-result timing.
At Line 795, timing is taken from results[-1], but prompt_tokens/completion_tokens are aggregated across all results. For list prompts, this produces misleading TPS.
Suggested fix
- _maybe_add_debug_timing(
- payload,
- result=results[-1] if results else {},
- prompt_tokens=prompt_tokens,
- completion_tokens=completion_tokens,
- )
+ # debug_timing currently represents a single request/result.
+ # Avoid publishing mixed-scope TPS for batched prompt lists.
+ if len(results) == 1:
+ r0 = results[0]
+ _maybe_add_debug_timing(
+ payload,
+ result=r0,
+ prompt_tokens=r0.get("prompt_tokens", 0),
+ completion_tokens=r0.get("completion_tokens", 0),
+    )

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
# debug_timing currently represents a single request/result.
# Avoid publishing mixed-scope TPS for batched prompt lists.
if len(results) == 1:
    r0 = results[0]
    _maybe_add_debug_timing(
        payload,
        result=r0,
        prompt_tokens=r0.get("prompt_tokens", 0),
        completion_tokens=r0.get("completion_tokens", 0),
    )
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@pymllm/server/launch.py` around lines 793 - 798, The code calls
_maybe_add_debug_timing(payload, result=results[-1],
prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) which mixes
timing from only the last result with token counts aggregated across all
results; instead compute an aggregated timing dict by iterating results and
summing numeric timing fields from each result.get("timing", {}) (or
merge/accumulate per-field totals) and pass that aggregated timing dict as
result, or alternatively call _maybe_add_debug_timing once per individual result
with that result's own prompt/completion token counts; update the call site in
launch.py to replace result=results[-1] with the aggregated timing dict (or loop
per-result) and keep prompt_tokens/completion_tokens consistent with the chosen
aggregation approach.
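For illustration, here is a minimal sketch of the aggregation approach described in the prompt above. The per-result field names (`timing`, `prompt_tokens`, `completion_tokens` inside each result dict) are assumptions about the result layout in launch.py, not verified against the code.

```python
# Sketch only: sum numeric timing fields across all results so the published
# timing matches the aggregated token counts (field names are assumed).
def _aggregate_debug_timing(results):
    agg_timing = {}
    prompt_tokens = 0
    completion_tokens = 0
    for r in results:
        for key, value in r.get("timing", {}).items():
            if isinstance(value, (int, float)):
                agg_timing[key] = agg_timing.get(key, 0) + value
        prompt_tokens += r.get("prompt_tokens", 0)
        completion_tokens += r.get("completion_tokens", 0)
    return {"timing": agg_timing}, prompt_tokens, completion_tokens

# Hypothetical call site:
# agg_result, p, c = _aggregate_debug_timing(results)
# _maybe_add_debug_timing(payload, result=agg_result, prompt_tokens=p, completion_tokens=c)
```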
🧹 Nitpick comments (2)
pymllm/models/qwen3.py (2)
65-96: 💤 Low value
Dead code: non-fused path is unreachable.
`use_fused_qkv = True` is hardcoded, making lines 77-96 unreachable. Either parameterize `use_fused_qkv` in the constructor or remove the dead code path.

♻️ Option A: Remove dead code

-        self.use_fused_qkv = True
-
-        if self.use_fused_qkv:
-            self.qkv_proj = MergedLinear(
-                hidden_size,
-                [self.q_size, self.kv_size, self.kv_size],
-                bias=attention_bias,
-                quant_method=_get_qm("qkv_proj"),
-            )
-            self.q_proj = None
-            self.k_proj = None
-            self.v_proj = None
-        else:
-            self.qkv_proj = None
-            self.q_proj = Linear(
-                hidden_size,
-                self.q_size,
-                bias=attention_bias,
-                quant_method=_get_qm("q_proj"),
-            )
-            self.k_proj = Linear(
-                hidden_size,
-                self.kv_size,
-                bias=attention_bias,
-                quant_method=_get_qm("k_proj"),
-            )
-            self.v_proj = Linear(
-                hidden_size,
-                self.kv_size,
-                bias=attention_bias,
-                quant_method=_get_qm("v_proj"),
-            )
+        self.qkv_proj = MergedLinear(
+            hidden_size,
+            [self.q_size, self.kv_size, self.kv_size],
+            bias=attention_bias,
+            quant_method=_get_qm("qkv_proj"),
+        )

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@pymllm/models/qwen3.py` around lines 65 - 96, The non-fused attention branch is dead because use_fused_qkv is hardcoded True; update the class constructor to accept a parameter (e.g., use_fused_qkv) and set self.use_fused_qkv from that param (or remove the else block if fused-only is intended), then adjust initialization of qkv_proj, q_proj, k_proj, and v_proj accordingly (referencing self.use_fused_qkv, MergedLinear, and Linear to locate the code paths); ensure default behavior preserves current fused behavior by defaulting the new parameter to True and update any callers to pass False if they need the non-fused path.
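For reference, a minimal sketch of the parameterized alternative (Option B) described in the prompt above. The class name and use of plain `nn.Linear` here are illustrative stand-ins; the actual pymllm code uses its own `Linear`/`MergedLinear` layers with a `quant_method` argument.

```python
import torch.nn as nn

class AttentionProjections(nn.Module):
    """Illustrative sketch only: a use_fused_qkv flag selects between a fused
    qkv projection and separate q/k/v projections (default keeps fused behavior)."""

    def __init__(self, hidden_size, q_size, kv_size, bias=False, use_fused_qkv=True):
        super().__init__()
        self.use_fused_qkv = use_fused_qkv
        if use_fused_qkv:
            # One GEMM producing concatenated q/k/v, split after the projection.
            self.qkv_proj = nn.Linear(hidden_size, q_size + 2 * kv_size, bias=bias)
            self.q_proj = self.k_proj = self.v_proj = None
        else:
            self.qkv_proj = None
            self.q_proj = nn.Linear(hidden_size, q_size, bias=bias)
            self.k_proj = nn.Linear(hidden_size, kv_size, bias=bias)
            self.v_proj = nn.Linear(hidden_size, kv_size, bias=bias)
```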
349-351: 💤 Low value
Move import to module level.
Placing the import inside `forward()` incurs lookup overhead on every call. If there's no circular import issue, move it to the top with other imports.

♻️ Proposed fix
At the top of the file (around line 24):

from pymllm.executor.model_runner import LogitsProcessorOutput

Then simplify lines 349-351:

-        from pymllm.executor.model_runner import LogitsProcessorOutput
         return LogitsProcessorOutput(next_token_logits=logits)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@pymllm/models/qwen3.py` around lines 349 - 351, The local import of LogitsProcessorOutput inside the forward() method causes repeated lookup overhead; move "from pymllm.executor.model_runner import LogitsProcessorOutput" to the module-level imports at the top of pymllm/models/qwen3.py and remove the in-method import in forward(), keeping the return as return LogitsProcessorOutput(next_token_logits=logits); if a circular import prevents this, add a short comment explaining the circular dependency and leave the local import as-is.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Nitpick comments:
In `@pymllm/models/qwen3.py`:
- Around line 65-96: The non-fused attention branch is dead because
use_fused_qkv is hardcoded True; update the class constructor to accept a
parameter (e.g., use_fused_qkv) and set self.use_fused_qkv from that param (or
remove the else block if fused-only is intended), then adjust initialization of
qkv_proj, q_proj, k_proj, and v_proj accordingly (referencing
self.use_fused_qkv, MergedLinear, and Linear to locate the code paths); ensure
default behavior preserves current fused behavior by defaulting the new
parameter to True and update any callers to pass False if they need the
non-fused path.
- Around line 349-351: The local import of LogitsProcessorOutput inside the
forward() method causes repeated lookup overhead; move "from
pymllm.executor.model_runner import LogitsProcessorOutput" to the module-level
imports at the top of pymllm/models/qwen3.py and remove the in-method import in
forward(), keeping the return as return
LogitsProcessorOutput(next_token_logits=logits); if a circular import prevents
this, add a short comment explaining the circular dependency and leave the local
import as-is.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 2458cdf9-12ee-43fb-905a-aa8ce2e7126c
📒 Files selected for processing (10)
- pymllm/layers/__init__.py
- pymllm/layers/linear.py
- pymllm/layers/mlp.py
- pymllm/models/qwen3.py
- pymllm/models/qwen3_vl.py
- pymllm/tests/test_linear_merged.py
- pymllm/tests/test_qwen3_residual_carry.py
- pymllm/tests/test_qwen3_vl_deepstack.py
- pymllm/tests/test_qwen3_vl_weight_loading.py
- pymllm/tests/test_qwen3_weight_loading.py
🚧 Files skipped from review as they are similar to previous changes (1)
- pymllm/tests/test_qwen3_vl_deepstack.py
Summary
This PR adds Jetson-oriented `pymllm` support for the Qwen3 family, covering BF16 serving, W4A16/AWQ compressed-tensors serving, and W8A8 INT8 compressed-tensors serving. It also adds a `bench_one_batch` entrypoint for model-level prefill/decode benchmarking in `pymllm`.

The W8A8 path uses:
fp16/bf16 activation
-> Triton per-token INT8 activation quantization
-> CUTLASS int8_scaled_mm with fused per-row/per-column scales
-> fp16/bf16 output
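As a rough illustration of this dataflow, here is a plain-PyTorch reference sketch, not the Triton/CUTLASS kernels themselves. The names `per_token_quant_int8_ref` and `w8a8_linear_ref` are hypothetical; the real path replaces the quant step with the Triton kernel and the matmul with the CUTLASS `int8_scaled_mm`, which fuses the per-row/per-column scales into the GEMM epilogue.

```python
import torch

def per_token_quant_int8_ref(x):
    # Per-token (per-row) symmetric INT8 quantization of an fp16/bf16 activation.
    x_f32 = x.to(torch.float32)
    scales = x_f32.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    x_q = torch.round(x_f32 / scales).clamp(-128, 127).to(torch.int8)
    return x_q, scales.squeeze(-1)  # (M, K) int8, (M,) fp32

def w8a8_linear_ref(x, w_q, w_scales, bias=None):
    # x: (M, K) fp16/bf16, w_q: (K, N) int8, w_scales: (N,) fp32 per output channel.
    x_q, x_scales = per_token_quant_int8_ref(x)
    acc = torch._int_mm(x_q, w_q)                      # (M, N) int32 accumulator
    y = acc.to(torch.float32) * x_scales[:, None] * w_scales[None, :]
    if bias is not None:
        y = y + bias.to(torch.float32)
    return y.to(x.dtype)                               # back to fp16/bf16
```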
What Changed
- Validated on Jetson Orin SM87
- `bench_one_batch` benchmark entrypoint in pymllm: `ModelRunner` execution without HTTP server, scheduler, tokenizer, or detokenizer

Key Implementation Notes
- The CUTLASS int8 GEMM extension is JIT-compiled for the architecture reported by `torch.cuda.get_device_capability()` and cached per architecture under `~/.cache/mllm_kernel/cutlass_int8_scaled_mm/sm_XX/`, for example `sm_87` on Jetson Orin.
- Validated with the `triton==3.6.0` aarch64 wheel.
- `pymllm.bench_one_batch` measures model-level latency directly around prefill and per-token decode with CUDA synchronization; a sketch of this timing pattern follows below.
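A minimal sketch of that synchronize-then-time pattern; `run_prefill` and `run_decode_step` are hypothetical stand-ins for the `ModelRunner` calls, not the actual `bench_one_batch` API.

```python
import time
import torch

def time_one_batch(run_prefill, run_decode_step, decode_steps=32):
    # Synchronize before and after each region so host-side timers measure
    # completed GPU work rather than kernel launch latency.
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    run_prefill()
    torch.cuda.synchronize()
    prefill_ms = (time.perf_counter() - t0) * 1e3

    per_token_ms = []
    for _ in range(decode_steps):
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        run_decode_step()
        torch.cuda.synchronize()
        per_token_ms.append((time.perf_counter() - t0) * 1e3)
    return prefill_ms, sum(per_token_ms) / len(per_token_ms)
```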
Notes for Reviewers

- The CUTLASS int8 GEMM path covers SM80-SM89. The validated target is Jetson Orin SM87. Hopper / SM90 is not included in this PR.
- The first run on a machine without a warm CUTLASS JIT cache may spend extra time compiling the extension. Later runs on the same GPU architecture reuse the cache.
- `bench_one_batch` results are model-level synthetic text-only measurements and should not be mixed with HTTP serving metrics such as TTFT, ITL, TPOT, or E2E latency.

Summary by CodeRabbit
New Features
Bug Fixes
Tests
Documentation
Chores