From fc1f1ee5c345e65c783d50f877057ae95b768580 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Fri, 3 Apr 2026 20:52:47 +0000 Subject: [PATCH 1/5] Adds a Jupyter notebook tutorial --- README.md | 2 + examples/amd_flashinfer_rocm_tutorial.ipynb | 446 ++++++++++++++++++++ examples/run_jupyter_server.sh | 40 ++ 3 files changed, 488 insertions(+) create mode 100644 examples/amd_flashinfer_rocm_tutorial.ipynb create mode 100755 examples/run_jupyter_server.sh diff --git a/README.md b/README.md index e24b372dbf..f1f0789623 100644 --- a/README.md +++ b/README.md @@ -119,6 +119,8 @@ done * `single_prefill_example.py` - Single-sequence prefill attention * `batch_prefill_example.py` - Batched prefill attention * `batch_decode_example.py` - Batched decode attention +* `amd_flashinfer_rocm_tutorial.ipynb` - Jupyter tutorial: environment verification (`hip_utils`), AITER-backed prefill examples, and `logits_processor` on ROCm +* `run_jupyter_server.sh` - Start JupyterLab from the repo root (run inside your ROCm/FlashInfer environment or Docker container) ## Build from Source diff --git a/examples/amd_flashinfer_rocm_tutorial.ipynb b/examples/amd_flashinfer_rocm_tutorial.ipynb new file mode 100644 index 0000000000..6422cc9422 --- /dev/null +++ b/examples/amd_flashinfer_rocm_tutorial.ipynb @@ -0,0 +1,446 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "4fd518dd", + "metadata": {}, + "source": [ + "\n", + "\n", + "# Hands-On Tutorial: amd-flashinfer on AMD Instinct GPUs\n", + "\n", + "**amd-flashinfer** is the ROCm port of [FlashInfer](https://github.com/flashinfer-ai/flashinfer), a library of high-performance GPU operators for large language model (LLM) inference and serving. 
It provides fused attention, paged key–value cache layouts, sampling-related utilities, and related kernels adapted for AMD Instinct accelerators under the HIP programming model.\n", + "\n", + "**Purpose of this tutorial** \n", + "We demonstrate two practical pieces of an inference stack on ROCm: (1) **multi-head / grouped-query attention prefill**—the phase where query, key, and value tensors are combined to produce contextual representations for prompt tokens—and (2) the **`logits_processor`** pipeline, which applies temperature scaling, filtering (e.g. top-*k*), and sampling to raw logits from the final linear layer. Together, these illustrate how amd-flashinfer fits into a typical decode-time workflow alongside frameworks such as vLLM or SGLang.\n", + "\n", + "**Target audience** \n", + "Engineers and researchers who already use PyTorch on ROCm and want a concise, runnable introduction to amd-flashinfer APIs before integrating them into larger serving systems. Familiarity with attention mechanics (queries, keys, values, heads) is assumed.\n", + "\n", + "**Scope** \n", + "The exercises below use supported prefill and logits APIs in a single-node GPU environment. They do not cover every operator in the library. For broader feature coverage and installation matrices, see the [FlashInfer+ROCm README](https://github.com/ROCm/flashinfer/blob/main/README.md)." + ] + }, + { + "cell_type": "markdown", + "id": "44497ec3", + "metadata": {}, + "source": [ + "## Setting up amd-flashinfer\n", + "\n", + "The [project README](https://github.com/ROCm/flashinfer/blob/main/README.md) describes supported ROCm and PyTorch versions (e.g. ROCm 7.0.2–7.2; Torch+ROCm 2.8.0 / 2.9.1) and **Instinct** architectures (**gfx942**, **gfx950**). 
Align your environment with one of the following before running the cells below.\n", + "\n", + "**Option 1 — Pre-built Docker image (recommended for reproducibility)** \n", + "AMD publishes images on Docker Hub under [`rocm/flashinfer`](https://hub.docker.com/r/rocm/flashinfer/tags). Start a container with GPU devices exposed, for example:\n", + "\n", + "```bash\n", + "docker run -it --privileged --network=host --device=/dev/kfd --device=/dev/dri \\\n", + " --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \\\n", + " --ipc=host --shm-size 128G --name=<container_name> rocm/flashinfer:<tag>\n", + "```\n", + "\n", + "Inside the image, activate the provided environment (often `micromamba activate base`) and confirm `python -c \"import flashinfer; print(flashinfer.__version__)\"`.\n", + "\n", + "**Option 2 — Wheel install on your own ROCm system** \n", + "Install the package from AMD’s index and pair it with a **ROCm-enabled** PyTorch wheel from [repo.radeon.com](https://repo.radeon.com) (match major/minor ROCm to your stack; pin the torch version so a CPU-only wheel is not pulled from PyPI):\n", + "\n", + "```bash\n", + "pip install amd-flashinfer --index-url https://pypi.amd.com/simple/\n", + "pip install torch==<torch_version> -f https://repo.radeon.com/rocm/manylinux/rocm-rel-<rocm_version>/\n", + "```\n", + "\n", + "**Option 3 — Build from source** \n", + "For development or custom builds, follow **Build from Source** in the README (development Dockerfile, editable install, or wheel build).\n", + "\n", + "Once one of these paths is in place, continue with the next section to **verify** the interpreter you selected for this notebook."
+ ] + }, + { + "cell_type": "markdown", + "id": "2c83b553", + "metadata": {}, + "source": [ + "## Verifying the runtime: PyTorch ROCm, FlashInfer, and GPU compatibility\n", + "\n", + "The following cell imports amd-flashinfer and uses helpers from [`flashinfer/hip_utils.py`](https://github.com/ROCm/flashinfer/blob/main/flashinfer/hip_utils.py) to confirm that PyTorch was built with ROCm support, to report the detected ROCm version when possible, and to list GPU agents that match **FlashInfer-supported** architecture names (via `rocminfo`). It also checks that at least one accelerator is visible through the HIP runtime (`torch.cuda` on ROCm).\n", + "\n", + "Run this cell **before** the worked examples; if it raises or warns, resolve the installation mismatch (README links above) before proceeding." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0ebe68e6", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "\n", + "import flashinfer\n", + "import flashinfer.hip_utils as hip_utils\n", + "\n", + "device = \"cuda:0\"\n", + "\n", + "hip_utils.check_torch_rocm_compatibility()\n", + "\n", + "assert torch.cuda.is_available(), (\n", + " \"No HIP device visible to PyTorch. 
Check drivers, container device flags, \"\n", + " \"and HIP_VISIBLE_DEVICES / CUDA_VISIBLE_DEVICES.\"\n", + ")\n", + "\n", + "n_gpus = hip_utils.get_available_gpu_count()\n", + "assert n_gpus >= 1, f\"Expected at least one visible GPU, got {n_gpus}\"\n", + "\n", + "print(\"flashinfer:\", getattr(flashinfer, \"__version__\", \"unknown\"))\n", + "print(\"torch:\", torch.__version__)\n", + "if hasattr(torch.version, \"hip\") and torch.version.hip:\n", + " print(\"PyTorch HIP / ROCm build:\", torch.version.hip)\n", + "\n", + "rocm_ver = hip_utils.get_system_rocm_version()\n", + "print(\"Detected system ROCm version:\", rocm_ver if rocm_ver else \"(could not detect)\")\n", + "\n", + "print(\n", + " \"Architectures with AMD FlashInfer ports:\",\n", + " \", \".join(hip_utils.FLASHINFER_SUPPORTED_ROCM_ARCHS),\n", + ")\n", + "\n", + "supported_idx = hip_utils.get_supported_device_indices()\n", + "print(\"GPU count (torch):\", n_gpus)\n", + "print(\"Device indices FlashInfer treats as supported Instinct (rocminfo):\", supported_idx)\n", + "if not supported_idx:\n", + " print(\n", + " \"Note: rocminfo did not report a matching agent, or rocminfo is unavailable. \"\n", + " \"You may still run on a supported GPU if the stack is correct.\"\n", + " )\n", + "\n", + "print(\"Using device:\", device)" + ] + }, + { + "cell_type": "markdown", + "id": "ac466beb", + "metadata": {}, + "source": [ + "## Part A — Attention prefill\n", + "\n", + "In transformer inference, **prefill** computes attention over prompt tokens so each position receives context from the rest of the sequence (subject to masking). 
amd-flashinfer exposes both **single-sequence** prefill and **batched prefill with a paged KV cache**, which is the layout many servers use to manage memory efficiently across layers and batches.\n", + "\n", + "**Backend integration.** amd-flashinfer offers [AITER](https://github.com/ROCm/aiter)—AMD’s state-of-the-art kernel library for large language model operators on ROCm—as an optional **`backend=\"aiter\"`**. When the `aiter` package is present and the requested operator is integrated with AITER in this port, work **delegates** to AITER; otherwise the default HIP path applies. The following cells exercise the AITER-backed prefill path.\n", + "\n", + "**Requirements and layout.** Install the `aiter` package (e.g. `amd-aiter` from AMD’s PyPI); see [AITER Support](https://github.com/ROCm/flashinfer/blob/main/README.md#aiter-support) in the README. These calls expect **NHD** key/value layout (`[sequence, heads, dim]`). For batched paged prefill, allowed **page sizes** depend on your AITER build (typical CK values are listed in the README, e.g. 1, 16, 1024)." ] }, { "cell_type": "code", "execution_count": null, "id": "680ef917", "metadata": {}, "outputs": [], "source": [ + "import importlib.util\n", + "\n", + "if importlib.util.find_spec(\"aiter\") is None or importlib.util.find_spec(\"aiter.ops\") is None:\n", + " raise RuntimeError(\n", + " \"The AITER Python package is required for the prefill examples in this tutorial.\\n\"\n", + " \"Install e.g.: pip install amd-aiter --index-url https://pypi.amd.com/simple/\\n\"\n", + " \"See README → AITER Support.\"\n", + " )\n", + "print(\"AITER package available (aiter.ops).\")" ] }, { "cell_type": "markdown", "id": "b57de4ec", "metadata": {}, "source": [ + "### Warm-up run\n", + "\n", + "The first HIP kernel launch may trigger just-in-time compilation. Execute a small prefill once so later cells reflect steady-state timing if you benchmark."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b12fb8f6", + "metadata": {}, + "outputs": [], + "source": [ + "import math\n", + "\n", + "dtype = torch.float16\n", + "\n", + "_q = torch.randn(4, 8, 64, device=device, dtype=dtype)\n", + "_k = torch.randn(4, 4, 64, device=device, dtype=dtype)\n", + "_v = torch.randn(4, 4, 64, device=device, dtype=dtype)\n", + "_ = flashinfer.single_prefill_with_kv_cache(\n", + " _q,\n", + " _k,\n", + " _v,\n", + " causal=True,\n", + " kv_layout=\"NHD\",\n", + " pos_encoding_mode=\"NONE\",\n", + " backend=\"aiter\",\n", + ")\n", + "torch.cuda.synchronize()\n", + "print(\"Warm-up complete.\")" + ] + }, + { + "cell_type": "markdown", + "id": "e81a2c61", + "metadata": {}, + "source": [ + "### Single-sequence prefill\n", + "\n", + "We use **grouped-query attention** (eight query heads, four KV heads): `k` and `v` are repeated logically along the head axis inside the kernel. **Causal** masking keeps each query position from attending to future keys. The call below sets `backend=\"aiter\"` so prefill runs through AITER’s fused path; numerics are compared to a short PyTorch reference in float32." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d9545581", + "metadata": {}, + "outputs": [], + "source": [ + "def naive_attention_prefill(q, k, v, causal: bool):\n", + " qo_len, num_qo_heads, head_dim = q.shape\n", + " kv_len, num_kv_heads, _ = k.shape\n", + " sm_scale = 1.0 / math.sqrt(head_dim)\n", + " g = num_qo_heads // num_kv_heads\n", + " k = k.repeat_interleave(g, dim=1)\n", + " v = v.repeat_interleave(g, dim=1)\n", + " qt = q.transpose(0, 1)\n", + " kt = k.transpose(0, 1)\n", + " vt = v.transpose(0, 1)\n", + " scores = torch.matmul(qt, kt.transpose(1, 2)) * sm_scale\n", + " if causal:\n", + " mask = torch.tril(\n", + " torch.ones((qo_len, kv_len), device=q.device, dtype=torch.bool),\n", + " diagonal=(kv_len - qo_len),\n", + " )\n", + " scores = scores.masked_fill(~mask.unsqueeze(0), float(\"-inf\"))\n", + " attn = torch.softmax(scores, dim=-1)\n", + " return torch.matmul(attn, vt).transpose(0, 1)\n", + "\n", + "\n", + "qo_len, kv_len = 15, 127\n", + "num_qo_heads, num_kv_heads, head_dim = 8, 4, 64\n", + "\n", + "q = torch.randn(qo_len, num_qo_heads, head_dim, device=device, dtype=dtype)\n", + "k = torch.randn(kv_len, num_kv_heads, head_dim, device=device, dtype=dtype)\n", + "v = torch.randn(kv_len, num_kv_heads, head_dim, device=device, dtype=dtype)\n", + "\n", + "o = flashinfer.single_prefill_with_kv_cache(\n", + " q,\n", + " k,\n", + " v,\n", + " causal=True,\n", + " kv_layout=\"NHD\",\n", + " pos_encoding_mode=\"NONE\",\n", + " backend=\"aiter\",\n", + ")\n", + "\n", + "o_ref = naive_attention_prefill(q.float(), k.float(), v.float(), causal=True).to(dtype)\n", + "torch.testing.assert_close(o, o_ref, rtol=1e-2, atol=1e-2)\n", + "print(\"Single-sequence prefill: OK — output shape\", tuple(o.shape))" + ] + }, + { + "cell_type": "markdown", + "id": "0e81b7e9", + "metadata": {}, + "source": [ + "### Batched prefill with a paged KV cache\n", + "\n", + "Serving systems often store keys and values in **fixed-size pages** and index them 
with indirection tables. The same logical batch can then be planned once and reused across layers. Here we build a minimal paged layout—`[total_pages, 2, page_size, num_kv_heads, head_dim]` for NHD, slot `0` for K and `1` for V—and run **`BatchPrefillWithPagedKVCacheWrapper`** with `backend=\"aiter\"` in the constructor. We use `page_size=16` and a 512 MiB workspace buffer, matching the [batch prefill example](https://github.com/ROCm/flashinfer/blob/main/examples/batch_prefill_example.py). To build confidence, the first sequence in the batch is checked against the single-sequence prefill call above." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "022a1442", + "metadata": {}, + "outputs": [], + "source": [ + "batch_size = 4\n", + "kv_len = 128\n", + "qo_len = 128\n", + "page_size = 16\n", + "num_kv_heads = 4\n", + "num_qo_heads = 32\n", + "head_dim = 64\n", + "kv_layout = \"NHD\"\n", + "\n", + "q = torch.randn(\n", + " batch_size * qo_len, num_qo_heads, head_dim, device=device, dtype=dtype\n", + ")\n", + "q_indptr = torch.arange(0, batch_size + 1, device=device, dtype=torch.int32) * qo_len\n", + "\n", + "num_pages_per_seq = (kv_len + page_size - 1) // page_size\n", + "total_num_pages = num_pages_per_seq * batch_size\n", + "kv_shape = [total_num_pages, 2, page_size, num_kv_heads, head_dim]\n", + "kv_data = torch.randn(*kv_shape, device=device, dtype=dtype)\n", + "kv_indptr = torch.arange(0, batch_size + 1, device=device, dtype=torch.int32) * num_pages_per_seq\n", + "kv_indices = torch.arange(0, total_num_pages, device=device, dtype=torch.int32)\n", + "kv_last_page_len = torch.full(\n", + " (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32, device=device\n", + ")\n", + "\n", + "workspace_buffer = torch.empty(512 * 1024 * 1024, dtype=torch.int8, device=device)\n", + "wrapper = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper(\n", + " workspace_buffer, kv_layout, backend=\"aiter\"\n", + ")\n", + "wrapper.plan(\n", + " 
q_indptr,\n", + " kv_indptr,\n", + " kv_indices,\n", + " kv_last_page_len,\n", + " num_qo_heads,\n", + " num_kv_heads,\n", + " head_dim,\n", + " page_size,\n", + " causal=False,\n", + " pos_encoding_mode=\"NONE\",\n", + ")\n", + "\n", + "o_batch = wrapper.run(q, kv_data)\n", + "print(\"Batched paged prefill: output shape\", tuple(o_batch.shape))\n", + "\n", + "\n", + "def reconstruct_seq_from_paged_nhd(kv_tensor, kv_ip, kv_lpl, seq_idx, kv_slot):\n", + " chunks = []\n", + " start = int(kv_ip[seq_idx].item())\n", + " end = int(kv_ip[seq_idx + 1].item())\n", + " last_tokens = int(kv_lpl[seq_idx].item())\n", + " for p in range(start, end - 1):\n", + " chunks.append(kv_tensor[p, kv_slot, :, :, :].reshape(-1, num_kv_heads, head_dim))\n", + " p_last = end - 1\n", + " chunks.append(\n", + " kv_tensor[p_last, kv_slot, :last_tokens, :, :].reshape(-1, num_kv_heads, head_dim)\n", + " )\n", + " return torch.cat(chunks, dim=0)\n", + "\n", + "\n", + "k0 = reconstruct_seq_from_paged_nhd(kv_data, kv_indptr, kv_last_page_len, 0, 0)\n", + "v0 = reconstruct_seq_from_paged_nhd(kv_data, kv_indptr, kv_last_page_len, 0, 1)\n", + "assert k0.shape[0] == kv_len\n", + "\n", + "q0 = q[:qo_len]\n", + "o0 = flashinfer.single_prefill_with_kv_cache(\n", + " q0, k0, v0, causal=False, kv_layout=\"NHD\", pos_encoding_mode=\"NONE\", backend=\"aiter\"\n", + ")\n", + "torch.testing.assert_close(o_batch[:qo_len], o0, rtol=1e-2, atol=1e-2)\n", + "print(\"Consistency check (batch vs single for sequence 0): OK\")" + ] + }, + { + "cell_type": "markdown", + "id": "6a241cb4", + "metadata": {}, + "source": [ + "## Part B — Logits processing after the last layer\n", + "\n", + "After the model produces logits of shape `[batch, vocabulary]`, serving code often applies temperature, optional softmax, constrained sampling (top-*k*, top-*p*), and discrete **sampling**. 
amd-flashinfer exposes these steps as composable **`LogitsProcessor`** objects inside a **`LogitsPipe`**, which can be **compiled** for lower Python overhead. The snippet below uses random logits only to show the API; in production, logits would come from your transformer output projection.\n", + "\n", + "This portion is independent of Part A and does not require AITER." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "64de5c6f", + "metadata": {}, + "outputs": [], + "source": [ + "from flashinfer.logits_processor import LogitsPipe, Softmax, Temperature, TopK, Sample\n", + "\n", + "torch.manual_seed(0)\n", + "batch_size_lp, vocab_size = 32, 4096\n", + "logits = torch.randn(batch_size_lp, vocab_size, device=device, dtype=torch.float32)\n", + "\n", + "pipe_eager = LogitsPipe([Temperature(), Softmax(), TopK(50), Sample()], compile=False)\n", + "pipe_compiled = LogitsPipe([Temperature(), Softmax(), TopK(50), Sample()], compile=True)\n", + "\n", + "samples_eager = pipe_eager(logits, temperature=0.8)\n", + "samples_compiled = pipe_compiled(logits, temperature=0.8)\n", + "\n", + "assert samples_eager.shape == (batch_size_lp,)\n", + "assert samples_compiled.shape == (batch_size_lp,)\n", + "assert (samples_eager >= 0).all() and (samples_eager < vocab_size).all()\n", + "assert (samples_compiled >= 0).all() and (samples_compiled < vocab_size).all()\n", + "print(\"LogitsPipe (eager vs compiled): sample indices in range — OK\")" + ] + }, + { + "cell_type": "markdown", + "id": "7428ddb4", + "metadata": {}, + "source": [ + "## Optional timing and further reading\n", + "\n", + "If you measure latency, keep the warm-up from Part A in mind. **Decode-phase** attention (one new token per step against a growing cache) is available in amd-flashinfer’s HIP port; tuning it for production is typically done inside **serving frameworks** and complementary libraries. 
For MLA decode kernels on Instinct, see the [AITER MLA notebook](https://rocm.docs.amd.com/projects/ai-developer-hub/en/latest/notebooks/gpu_dev_optimize/aiter_mla_decode_kernel.html).\n", + "\n", + "The cell below reports a rough wall-clock time for repeated single-sequence **AITER** prefill using the tensor sizes from Part A (not a rigorous benchmark)." ] }, { "cell_type": "code", "execution_count": null, "id": "0088eec7", "metadata": {}, "outputs": [], "source": [ + "import time\n", + "\n", + "\n", + "def bench_single_prefill(n_warmup=3, n_iter=20):\n", + " q = torch.randn(qo_len, num_qo_heads, head_dim, device=device, dtype=dtype)\n", + " k = torch.randn(kv_len, num_kv_heads, head_dim, device=device, dtype=dtype)\n", + " v = torch.randn(kv_len, num_kv_heads, head_dim, device=device, dtype=dtype)\n", + " for _ in range(n_warmup):\n", + " flashinfer.single_prefill_with_kv_cache(\n", + " q, k, v, causal=True, kv_layout=\"NHD\", pos_encoding_mode=\"NONE\", backend=\"aiter\"\n", + " )\n", + " torch.cuda.synchronize()\n", + " t0 = time.perf_counter()\n", + " for _ in range(n_iter):\n", + " flashinfer.single_prefill_with_kv_cache(\n", + " q, k, v, causal=True, kv_layout=\"NHD\", pos_encoding_mode=\"NONE\", backend=\"aiter\"\n", + " )\n", + " torch.cuda.synchronize()\n", + " t1 = time.perf_counter()\n", + " return (t1 - t0) / n_iter * 1000\n", + "\n", + "\n", + "ms = bench_single_prefill()\n", + "print(f\"~{ms:.3f} ms per single-sequence prefill (AITER path), mean over iterations\")" ] }, { "cell_type": "markdown", "id": "ee64cba8", "metadata": {}, "source": [ + "## Summary\n", + "\n", + "This notebook introduced procedures to confirm a ROCm-enabled PyTorch stack with amd-flashinfer, including GPU and ROCm compatibility checks via `flashinfer.hip_utils`. 
It demonstrated **attention prefill** using the **AITER** backend for a single sequence and for **batched paged** key–value cache layouts, and it illustrated the **`LogitsPipe`** API for temperature scaling, filtering, and sampling on model logits. For supported platforms, container images, and the complete operator matrix, consult the [FlashInfer+ROCm README](https://github.com/ROCm/flashinfer/blob/main/README.md); for API details shared with upstream FlashInfer, see the [FlashInfer documentation](https://docs.flashinfer.ai)." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "pygments_lexer": "ipython3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/run_jupyter_server.sh b/examples/run_jupyter_server.sh new file mode 100755 index 0000000000..c594395d39 --- /dev/null +++ b/examples/run_jupyter_server.sh @@ -0,0 +1,40 @@ +#!/usr/bin/env bash +# SPDX-FileCopyrightText: 2026 Advanced Micro Devices, Inc. +# SPDX-License-Identifier: Apache-2.0 +# +# Start JupyterLab for amd_flashinfer_rocm_tutorial.ipynb (or any notebook in the repo). +# +# Usage (activate your ROCm/flashinfer env first, e.g. inside rocm/flashinfer Docker): +# cd /path/to/flashinfer +# ./examples/run_jupyter_server.sh +# +# Remote / SSH: listen on all interfaces (default) and forward the port: +# ssh -L 8888:localhost:8888 user@node +# Then open the printed http://127.0.0.1:8888/lab?token=... in a browser. +# +# Override port: JUPYTER_PORT=8890 ./examples/run_jupyter_server.sh + +set -euo pipefail + +ROOT="$(cd "$(dirname "$0")/.." && pwd)" +cd "$ROOT" + +if ! python -c "import jupyterlab" 2>/dev/null; then + echo "Installing jupyterlab into the current Python environment..." 
+ pip install jupyterlab +fi + +PORT="${JUPYTER_PORT:-8888}" +IP="${JUPYTER_IP:-0.0.0.0}" + +echo "Starting JupyterLab from: $ROOT" +echo " URL: http://127.0.0.1:${PORT}/lab (use SSH -L if remote)" +echo " Stop: Ctrl+C" +echo "" + +exec python -m jupyter lab \ + --no-browser \ + --ip="$IP" \ + --port="$PORT" \ + --notebook-dir="$ROOT" \ + "$@" From 3a0d41493ae5f0c38544619ecd7899961a06a6c9 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Mon, 13 Apr 2026 14:02:46 -0500 Subject: [PATCH 2/5] Update README.md Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Diptorup Deb --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index f1f0789623..a30d42d942 100644 --- a/README.md +++ b/README.md @@ -119,8 +119,8 @@ done * `single_prefill_example.py` - Single-sequence prefill attention * `batch_prefill_example.py` - Batched prefill attention * `batch_decode_example.py` - Batched decode attention -* `amd_flashinfer_rocm_tutorial.ipynb` - Jupyter tutorial: environment verification (`hip_utils`), AITER-backed prefill examples, and `logits_processor` on ROCm -* `run_jupyter_server.sh` - Start JupyterLab from the repo root (run inside your ROCm/FlashInfer environment or Docker container) +* `examples/amd_flashinfer_rocm_tutorial.ipynb` - Jupyter tutorial: environment verification (`hip_utils`), AITER-backed prefill examples, and `logits_processor` on ROCm +* `examples/run_jupyter_server.sh` - Start JupyterLab from the repo root (run inside your ROCm/FlashInfer environment or Docker container) ## Build from Source From 4d00bf982ea5c8383de82d282d8fcdf000c62836 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Mon, 13 Apr 2026 14:03:01 -0500 Subject: [PATCH 3/5] Update examples/run_jupyter_server.sh Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Diptorup Deb --- examples/run_jupyter_server.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/examples/run_jupyter_server.sh b/examples/run_jupyter_server.sh index c594395d39..681f721781 100755 --- a/examples/run_jupyter_server.sh +++ b/examples/run_jupyter_server.sh @@ -21,7 +21,7 @@ cd "$ROOT" if ! python -c "import jupyterlab" 2>/dev/null; then echo "Installing jupyterlab into the current Python environment..." - pip install jupyterlab + python -m pip install jupyterlab fi PORT="${JUPYTER_PORT:-8888}" From 1fd58ad1b75d7c92dc9d3b80f7583d72460464db Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Mon, 13 Apr 2026 14:03:42 -0500 Subject: [PATCH 4/5] Update examples/amd_flashinfer_rocm_tutorial.ipynb Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Diptorup Deb --- examples/amd_flashinfer_rocm_tutorial.ipynb | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/amd_flashinfer_rocm_tutorial.ipynb b/examples/amd_flashinfer_rocm_tutorial.ipynb index 6422cc9422..bf8990ed9b 100644 --- a/examples/amd_flashinfer_rocm_tutorial.ipynb +++ b/examples/amd_flashinfer_rocm_tutorial.ipynb @@ -361,11 +361,11 @@ "batch_size_lp, vocab_size = 32, 4096\n", "logits = torch.randn(batch_size_lp, vocab_size, device=device, dtype=torch.float32)\n", "\n", - "pipe_eager = LogitsPipe([Temperature(), Softmax(), TopK(50), Sample()], compile=False)\n", - "pipe_compiled = LogitsPipe([Temperature(), Softmax(), TopK(50), Sample()], compile=True)\n", + "pipe_eager = LogitsPipe([Temperature(), Softmax(), TopK(), Sample()], compile=False)\n", + "pipe_compiled = LogitsPipe([Temperature(), Softmax(), TopK(), Sample()], compile=True)\n", "\n", - "samples_eager = pipe_eager(logits, temperature=0.8)\n", - "samples_compiled = pipe_compiled(logits, temperature=0.8)\n", + "samples_eager = pipe_eager(logits, temperature=0.8, top_k=50)\n", + "samples_compiled = pipe_compiled(logits, temperature=0.8, top_k=50)\n", "\n", "assert samples_eager.shape == (batch_size_lp,)\n", "assert samples_compiled.shape == (batch_size_lp,)\n", From 
d6ee47fe6064a5519f1938013057976e24eb7c5f Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Mon, 13 Apr 2026 20:25:33 +0000 Subject: [PATCH 5/5] Updated notebook --- examples/amd_flashinfer_rocm_tutorial.ipynb | 225 +++++++++++++++++--- examples/run_jupyter_server.sh | 17 +- 2 files changed, 211 insertions(+), 31 deletions(-) diff --git a/examples/amd_flashinfer_rocm_tutorial.ipynb b/examples/amd_flashinfer_rocm_tutorial.ipynb index bf8990ed9b..5f159fc3af 100644 --- a/examples/amd_flashinfer_rocm_tutorial.ipynb +++ b/examples/amd_flashinfer_rocm_tutorial.ipynb @@ -72,10 +72,39 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "id": "0ebe68e6", - "metadata": {}, - "outputs": [], + "metadata": { + "execution": { + "iopub.execute_input": "2026-04-13T19:55:15.549091Z", + "iopub.status.busy": "2026-04-13T19:55:15.548883Z", + "iopub.status.idle": "2026-04-13T19:55:19.421222Z", + "shell.execute_reply": "2026-04-13T19:55:19.420842Z" + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[aiter] import [module_aiter_enum] under /home/AMD/diptodeb/micromamba/envs/flashinfer-rocm-devel/lib/python3.12/site-packages/aiter/jit/module_aiter_enum.so\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "flashinfer: 0.5.3+amd.1.dev9\n", + "torch: 2.9.1+rocm7.2.0.git7e1940d4\n", + "PyTorch HIP / ROCm build: 7.2.26015-fc0010cf6a\n", + "Detected system ROCm version: 7.2.0\n", + "Architectures with AMD FlashInfer ports: gfx942, gfx950\n", + "GPU count (torch): 1\n", + "Device indices FlashInfer treats as supported Instinct (rocminfo): (0,)\n", + "Using device: cuda:0\n" + ] + } + ], "source": [ "import torch\n", "\n", @@ -135,10 +164,25 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "id": "680ef917", - "metadata": {}, - "outputs": [], + "metadata": { + "execution": { + "iopub.execute_input": "2026-04-13T19:55:19.423607Z", + "iopub.status.busy": 
"2026-04-13T19:55:19.423451Z", + "iopub.status.idle": "2026-04-13T19:55:19.425186Z", + "shell.execute_reply": "2026-04-13T19:55:19.425020Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "AITER package available (aiter.ops).\n" + ] + } + ], "source": [ "import importlib.util\n", "\n", @@ -163,18 +207,51 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "id": "b12fb8f6", - "metadata": {}, - "outputs": [], + "metadata": { + "execution": { + "iopub.execute_input": "2026-04-13T19:55:19.427227Z", + "iopub.status.busy": "2026-04-13T19:55:19.427151Z", + "iopub.status.idle": "2026-04-13T19:55:19.475351Z", + "shell.execute_reply": "2026-04-13T19:55:19.474796Z" + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[aiter] import [module_fmha_v3_varlen_fwd] under /home/AMD/diptodeb/micromamba/envs/flashinfer-rocm-devel/lib/python3.12/site-packages/aiter/jit/module_fmha_v3_varlen_fwd.so\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[aiter] type hints mismatch, override to --> fmha_v3_varlen_fwd(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, cu_seqlens_q: torch.Tensor, cu_seqlens_k: torch.Tensor, max_seqlen_q: int, max_seqlen_k: int, min_seqlen_q: int, dropout_p: float, softmax_scale: float, logits_soft_cap: float, zero_tensors: bool, is_causal: bool, window_size_left: int, window_size_right: int, return_softmax_lse: bool, return_dropout_randval: bool, how_v3_bf16_cvt: int, out: Optional[torch.Tensor] = None, block_table: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None, gen: Optional[torch.Generator] = None, cu_seqlens_q_padded: Optional[torch.Tensor] = None, cu_seqlens_k_padded: Optional[torch.Tensor] = None) -> List[torch.Tensor]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[aiter] hipModuleLoad: 
/home/AMD/diptodeb/micromamba/envs/flashinfer-rocm-devel/lib/python3.12/site-packages/aiter_meta/hsa//gfx942/fmha_v3_fwd/MI300/fwd_hd128_bf16_causal_rtna_group.co GetFunction: _ZN5aiter37fmha_fwd_hd128_bf16_causal_rtna_groupE Success\n", + "Warm-up complete.\n" + ] + } + ], "source": [ "import math\n", "\n", - "dtype = torch.float16\n", + "# bfloat16 + head_dim=128 routes through the pre-built fmha_v3_varlen_fwd kernel\n", + "# (module_fmha_v3_varlen_fwd.so), which ships with the amd_aiter wheel and requires\n", + "# no JIT compilation at startup.\n", + "dtype = torch.bfloat16\n", "\n", - "_q = torch.randn(4, 8, 64, device=device, dtype=dtype)\n", - "_k = torch.randn(4, 4, 64, device=device, dtype=dtype)\n", - "_v = torch.randn(4, 4, 64, device=device, dtype=dtype)\n", + "_q = torch.randn(4, 8, 128, device=device, dtype=dtype)\n", + "_k = torch.randn(4, 4, 128, device=device, dtype=dtype)\n", + "_v = torch.randn(4, 4, 128, device=device, dtype=dtype)\n", "_ = flashinfer.single_prefill_with_kv_cache(\n", " _q,\n", " _k,\n", @@ -200,10 +277,25 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "id": "d9545581", - "metadata": {}, - "outputs": [], + "metadata": { + "execution": { + "iopub.execute_input": "2026-04-13T19:55:19.476989Z", + "iopub.status.busy": "2026-04-13T19:55:19.476906Z", + "iopub.status.idle": "2026-04-13T19:55:23.370214Z", + "shell.execute_reply": "2026-04-13T19:55:23.369779Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Single-sequence prefill: OK — output shape (15, 8, 128)\n" + ] + } + ], "source": [ "def naive_attention_prefill(q, k, v, causal: bool):\n", " qo_len, num_qo_heads, head_dim = q.shape\n", @@ -227,7 +319,7 @@ "\n", "\n", "qo_len, kv_len = 15, 127\n", - "num_qo_heads, num_kv_heads, head_dim = 8, 4, 64\n", + "num_qo_heads, num_kv_heads, head_dim = 8, 4, 128\n", "\n", "q = torch.randn(qo_len, num_qo_heads, head_dim, device=device, dtype=dtype)\n", "k = 
torch.randn(kv_len, num_kv_heads, head_dim, device=device, dtype=dtype)\n", @@ -260,10 +352,41 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "id": "022a1442", - "metadata": {}, - "outputs": [], + "metadata": { + "execution": { + "iopub.execute_input": "2026-04-13T19:55:23.372087Z", + "iopub.status.busy": "2026-04-13T19:55:23.371972Z", + "iopub.status.idle": "2026-04-13T19:55:23.426578Z", + "shell.execute_reply": "2026-04-13T19:55:23.426314Z" + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2026-04-13 14:55:23,409 - WARNING - prefill_rocm.py:2112 - flashinfer.jit: enable_pdl is not supported in the HIP/ROCm backend and will be ignored. This parameter is only effective on CUDA devices with sm_90+.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2026-04-13 14:55:23,410 - INFO - prefill_rocm.py:389 - flashinfer.jit: ###### AITER backend is used for batch prefill ######\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[aiter] hipModuleLoad: /home/AMD/diptodeb/micromamba/envs/flashinfer-rocm-devel/lib/python3.12/site-packages/aiter_meta/hsa//gfx942/fmha_v3_fwd/MI300/fwd_hd128_bf16_rtna_group.co GetFunction: _ZN5aiter30fmha_fwd_hd128_bf16_rtna_groupE Success\n", + "Batched paged prefill: output shape (512, 32, 128)\n", + "Consistency check (batch vs single for sequence 0): OK\n" + ] + } + ], "source": [ "batch_size = 4\n", "kv_len = 128\n", @@ -271,7 +394,7 @@ "page_size = 16\n", "num_kv_heads = 4\n", "num_qo_heads = 32\n", - "head_dim = 64\n", + "head_dim = 128\n", "kv_layout = \"NHD\"\n", "\n", "q = torch.randn(\n", @@ -304,6 +427,7 @@ " page_size,\n", " causal=False,\n", " pos_encoding_mode=\"NONE\",\n", + " q_data_type=torch.bfloat16,\n", ")\n", "\n", "o_batch = wrapper.run(q, kv_data)\n", @@ -350,10 +474,32 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "id": "64de5c6f", - "metadata": {}, - "outputs": 
[], + "metadata": { + "execution": { + "iopub.execute_input": "2026-04-13T19:55:23.428223Z", + "iopub.status.busy": "2026-04-13T19:55:23.428123Z", + "iopub.status.idle": "2026-04-13T19:55:41.469169Z", + "shell.execute_reply": "2026-04-13T19:55:41.468832Z" + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Pipeline is not compiled, running discrete ops.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "LogitsPipe (eager vs compiled): sample indices in range — OK\n" + ] + } + ], "source": [ "from flashinfer.logits_processor import LogitsPipe, Softmax, Temperature, TopK, Sample\n", "\n", @@ -388,10 +534,25 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "id": "0088eec7", - "metadata": {}, - "outputs": [], + "metadata": { + "execution": { + "iopub.execute_input": "2026-04-13T19:55:41.470606Z", + "iopub.status.busy": "2026-04-13T19:55:41.470521Z", + "iopub.status.idle": "2026-04-13T19:55:41.474077Z", + "shell.execute_reply": "2026-04-13T19:55:41.473879Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "~0.032 ms per single-sequence prefill (AITER path), mean over iterations\n" + ] + } + ], "source": [ "import time\n", "\n", @@ -437,8 +598,16 @@ "name": "python3" }, "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", "name": "python", - "pygments_lexer": "ipython3" + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.11" } }, "nbformat": 4, diff --git a/examples/run_jupyter_server.sh b/examples/run_jupyter_server.sh index 681f721781..289aa35e5d 100755 --- a/examples/run_jupyter_server.sh +++ b/examples/run_jupyter_server.sh @@ -8,11 +8,14 @@ # cd /path/to/flashinfer # ./examples/run_jupyter_server.sh # -# Remote / SSH: listen on all interfaces (default) and forward the port: +# By default the server listens on 
127.0.0.1 (localhost only).
+# Remote / SSH port-forwarding: forward the port from your local machine:
 #   ssh -L 8888:localhost:8888 user@node
 # Then open the printed http://127.0.0.1:8888/lab?token=... in a browser.
 #
-# Override port: JUPYTER_PORT=8890 ./examples/run_jupyter_server.sh
+# Override port: JUPYTER_PORT=8890 ./examples/run_jupyter_server.sh
+# Override IP:   JUPYTER_IP=0.0.0.0 ./examples/run_jupyter_server.sh
+#   (setting JUPYTER_IP=0.0.0.0 binds to all interfaces; only do this intentionally)
 
 set -euo pipefail
 
@@ -25,7 +28,15 @@ if ! python -c "import jupyterlab" 2>/dev/null; then
 fi
 
 PORT="${JUPYTER_PORT:-8888}"
-IP="${JUPYTER_IP:-0.0.0.0}"
+IP="${JUPYTER_IP:-127.0.0.1}"
+
+if [[ "$IP" == "0.0.0.0" ]]; then
+  echo "WARNING: JUPYTER_IP=0.0.0.0 binds JupyterLab to ALL network interfaces."
+  echo "         This exposes the server beyond localhost, which is risky on shared"
+  echo "         machines or when --network=host is in use."
+  echo "         Prefer the default 127.0.0.1 and use SSH -L for remote access."
+  echo ""
+fi
 
 echo "Starting JupyterLab from: $ROOT"
 echo "  URL: http://127.0.0.1:${PORT}/lab  (use SSH -L if remote)"