
Enable torch.compile under a flag #210

Merged
diptorupd merged 10 commits into ROCm:amd-integration from demandal25:enable-torch-compile
Apr 3, 2026

Conversation

Collaborator

@demandal25 demandal25 commented Mar 31, 2026

Adds an environment flag, FLASHINFER_USE_TORCH_CUSTOM_OPS; setting it enables torch.compile support in flashinfer. By default this is disabled upstream (see the upstream comment: `# NOTE(Zihao): torch.library.custom_op has significant overhead as mentioned in the following link`).
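For context, a minimal sketch of what such an env-gated wrapper can look like (illustrative only; the real helpers live in flashinfer/utils.py and also handle register_fake and a _guard_compile fallback):

```python
import os

import torch

# Read the flag once at import time; unset or "0" keeps the upstream default (disabled).
_USE_TORCH_CUSTOM_OPS = os.environ.get("FLASHINFER_USE_TORCH_CUSTOM_OPS", "0") == "1"


def use_torch_custom_ops_enabled() -> bool:
    """Report whether custom-op wrapping is in effect for this process."""
    return _USE_TORCH_CUSTOM_OPS


def register_custom_op(name, fn, mutates_args=()):
    # Flag off: return the kernel unchanged so eager behavior matches upstream exactly.
    if not _USE_TORCH_CUSTOM_OPS:
        return fn
    # Flag on: wrap the kernel as an opaque custom op so Dynamo does not trace
    # into the extension code (requires torch >= 2.4).
    return torch.library.custom_op(name, fn, mutates_args=mutates_args)
```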

How to use it

`FLASHINFER_USE_TORCH_CUSTOM_OPS=1 <your command>`
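Since the flag is read when flashinfer is imported, it can equivalently be set from Python before the import (a small usage sketch):

```python
import os

os.environ["FLASHINFER_USE_TORCH_CUSTOM_OPS"] = "1"  # must be set before importing flashinfer

import flashinfer

assert flashinfer.use_torch_custom_ops_enabled()
```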

Testing

Adds a pytest for it: tests/rocm_tests/test_torch_compile_hip.py, which runs automatically as part of the regular pytest run.
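The test isolates each case in a fresh interpreter because the flag is read at import time (see the commit notes later in this thread). A compact sketch of that subprocess-isolation pattern, with illustrative helper names (the real test file may be structured differently):

```python
import os
import subprocess
import sys
import textwrap


def _run_snippet(code: str, extra_env: dict) -> subprocess.CompletedProcess:
    # Fresh Python process per case, since the env flag is read when flashinfer is imported.
    return subprocess.run(
        [sys.executable, "-c", code],
        env={**os.environ, **extra_env},
        capture_output=True,
        text=True,
    )


def test_flag_visible_when_enabled():
    snippet = textwrap.dedent(
        """\
        import flashinfer
        assert flashinfer.use_torch_custom_ops_enabled()
        print("OK")
        """
    )
    result = _run_snippet(snippet, {"FLASHINFER_USE_TORCH_CUSTOM_OPS": "1"})
    assert result.returncode == 0, result.stderr
```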

All pytests pass for both cases:

  • FLASHINFER_USE_TORCH_CUSTOM_OPS=0 pytest, or simply pytest, and
  • FLASHINFER_USE_TORCH_CUSTOM_OPS=1 pytest
(screenshot: pytest output)

Copilot AI review requested due to automatic review settings March 31, 2026 14:40
@demandal25 demandal25 force-pushed the enable-torch-compile branch from 4518269 to 8f2cd78 Compare March 31, 2026 14:45

Copilot AI left a comment


Pull request overview

Adds opt-in support for torch.compile by wrapping FlashInfer kernels in opaque torch.library.custom_ops (to avoid Dynamo tracing into extension code), plus a small ROCm benchmark/verification script and optional local Git hook scaffolding.
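To illustrate the mechanism with a self-contained toy (not FlashInfer's actual ops), an opaque custom op plus a fake implementation lets torch.compile call the kernel without tracing into its body:

```python
import torch


# The library/op name "demo_lib::scale" and this toy kernel are purely illustrative.
@torch.library.custom_op("demo_lib::scale", mutates_args=())
def scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    # Stand-in for a call into extension code; Dynamo treats the op as opaque.
    return x * factor


@scale.register_fake
def _(x: torch.Tensor, factor: float) -> torch.Tensor:
    # Fake/meta implementation so torch.compile can reason about shapes and dtypes.
    return torch.empty_like(x)


@torch.compile
def fn(x: torch.Tensor) -> torch.Tensor:
    return scale(x, 2.0)
```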

Changes:

  • Introduce FLASHINFER_USE_TORCH_CUSTOM_OPS env flag (checked at import time) to enable/disable torch.library.custom_op wrapping.
  • Export use_torch_custom_ops_enabled() from the top-level flashinfer package for both CUDA and ROCm/HIP builds.
  • Add a ROCm micro-benchmark script and optional .githooks pre-push protection for amd-integration.

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 3 comments.

Summary per file:

  • flashinfer/utils.py: Adds env-gated custom-op enablement plus a helper to query the setting; conditionally registers custom_op/register_fake.
  • flashinfer/__init__.py: Re-exports use_torch_custom_ops_enabled() for the CUDA and HIP branches.
  • scripts/verify_enable_torch_compile.py: New ROCm-focused eager vs torch.compile micro-benchmark for append_paged_kv_cache.
  • .githooks/README: Documents how to enable optional local Git hooks and their intended behavior.
  • .githooks/pre-push: New local pre-push hook intended to block pushing to/updating amd-integration.


Comment thread .githooks/pre-push Outdated
Comment thread flashinfer/utils.py Outdated
Comment thread scripts/verify_enable_torch_compile.py Outdated
@demandal25 demandal25 force-pushed the enable-torch-compile branch from 8f2cd78 to 768e946 Compare March 31, 2026 14:54
@demandal25 demandal25 force-pushed the enable-torch-compile branch from 768e946 to 4e3bb90 Compare March 31, 2026 15:00
@demandal25 demandal25 marked this pull request as draft March 31, 2026 15:01
@demandal25 demandal25 force-pushed the enable-torch-compile branch from 9eadd10 to c9ef604 Compare April 2, 2026 01:31
@demandal25 demandal25 marked this pull request as ready for review April 3, 2026 02:48
Copilot AI review requested due to automatic review settings April 3, 2026 02:48
@demandal25 demandal25 changed the title Enable torch compile Enable torch.compile under a flag Apr 3, 2026

Copilot AI left a comment


Pull request overview

Copilot reviewed 6 out of 6 changed files in this pull request and generated 2 comments.



Comment thread flashinfer/utils.py
Comment thread scripts/verify_enable_torch_compile.py Outdated
@demandal25 demandal25 requested a review from diptorupd April 3, 2026 03:03
Comment thread flashinfer/utils.py
if not _USE_TORCH_CUSTOM_OPS:
    return _guard_compile(f, name)
with contextlib.suppress(ValueError, TypeError):
    return torch.library.custom_op(
Collaborator


Why do we need to suppress ValueError and TypeError here? If the custom_op creation fails, shouldn't we raise? Copilot's review comment makes sense to me.

Comment thread scripts/verify_enable_torch_compile.py Outdated
Collaborator


Why not convert it into a pytest? You can run a specific test with the env var set, and then again with it unset.

Comment thread flashinfer/get_include_paths.py Outdated
"""
include_dir = os.path.join(_get_package_root_dir(), "include")
return str(include_dir)
package_dir = pathlib.Path(_get_package_root_dir()).resolve()
Collaborator


Why is this change required? Running from a source tree was never supported, even before the CMake editable-install fix, so why do we need to enable it? IMO it is a non-standard approach; a better way to handle this would be to test sdist packages and make sure they work. The necessary sdist tooling using scikit-build-core is already there, though I have never verified it.

…ytest

- Remove contextlib.suppress in register_custom_op so registration
  failures surface instead of silently falling back (Copilot + diptorupd)
- Fold torch >= 2.4 check into _USE_TORCH_CUSTOM_OPS so
  use_torch_custom_ops_enabled() reports effective behavior (Copilot)
- Replace scripts/verify_enable_torch_compile.py with
  tests/rocm_tests/test_torch_compile_hip.py covering eager and
  torch.compile paths in subprocess isolation (diptorupd)

Made-with: Cursor
Drop the repo-root include fallback per reviewer feedback: running from
a source tree without an install is not a supported workflow.

Made-with: Cursor
- Use device="cuda" (not "hip") — PyTorch on ROCm uses "cuda" as the
  device string
- Check HIP availability properly in pytestmark (see the sketch after this commit message)
- Accept any exception from torch.compile when custom ops are disabled,
  since the actual error (TorchRuntimeError from FakeTensor) differs
  from the _guard_compile message

Made-with: Cursor
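A sketch of the HIP-availability check mentioned in the second bullet above (assumed form; the actual pytestmark in tests/rocm_tests/test_torch_compile_hip.py may differ):

```python
import pytest
import torch

# On ROCm builds torch.version.hip is a version string (it is None on CUDA builds),
# and HIP devices are exposed through the "cuda" device interface.
pytestmark = pytest.mark.skipif(
    not (torch.cuda.is_available() and torch.version.hip is not None),
    reason="requires a ROCm/HIP GPU",
)
```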
@demandal25 demandal25 requested a review from diptorupd April 3, 2026 18:48
torch.library.custom_op cannot infer schemas for parameters like
Optional[torch.Generator]. Instead of crashing, fall back to
_guard_compile so these ops still get torch.compile protection.

Made-with: Cursor
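A minimal sketch of that fallback, with _guard_compile stubbed out so the snippet is self-contained (the real helper in flashinfer/utils.py may behave differently):

```python
import torch


def _guard_compile(fn, name):
    # Stand-in for the real _guard_compile; shown only to keep the sketch runnable.
    return fn


def _register(name, fn, mutates_args=()):
    """Try to wrap fn as an opaque custom op; fall back when schema inference fails."""
    try:
        return torch.library.custom_op(name, fn, mutates_args=mutates_args)
    except (ValueError, TypeError):
        # torch.library.custom_op cannot infer a schema for some parameter types
        # (e.g. Optional[torch.Generator]); keep torch.compile protection via the
        # guarded eager path instead of failing at import time.
        return _guard_compile(fn, name)
```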
Copilot AI review requested due to automatic review settings April 3, 2026 19:09

Copilot AI left a comment


Pull request overview

Copilot reviewed 5 out of 5 changed files in this pull request and generated 2 comments.



Comment on lines +90 to +100
def test_eager_with_custom_ops():
    """append_paged_kv_cache works in eager mode with custom ops enabled."""
    snippet = _PREAMBLE + textwrap.dedent(
        """\
        assert flashinfer.use_torch_custom_ops_enabled()
        append(k, v)
        print("OK")
        """
    )
    result = _run_snippet(snippet, {"FLASHINFER_USE_TORCH_CUSTOM_OPS": "1"})
    assert result.returncode == 0, f"eager (custom ops on) failed:\n{result.stderr}"

Copilot AI Apr 3, 2026


test_eager_with_custom_ops asserts flashinfer.use_torch_custom_ops_enabled() is true when the env var is set, but _USE_TORCH_CUSTOM_OPS in flashinfer/utils.py is additionally gated on torch>=2.4. On torch 2.3.x this test will fail even though the env var is "1". Consider adding the same torch>=2.4 skip condition as the torch.compile tests, or relaxing the assertion to account for the version gate.
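One possible form of that version-gated skip (a sketch; the actual tests may gate differently):

```python
import pytest
import torch

# Parse "major.minor" from strings like "2.4.0+rocm6.2"; assumes a release-style version.
_TORCH_GE_2_4 = tuple(int(v) for v in torch.__version__.split("+")[0].split(".")[:2]) >= (2, 4)

requires_torch_2_4 = pytest.mark.skipif(
    not _TORCH_GE_2_4, reason="torch.library.custom_op wrapping requires torch >= 2.4"
)
```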

Comment thread flashinfer/decode_rocm.py
Comment on lines 1255 to 1262
plan_info = self._plan_info
assert plan_info is not None, "plan info is not initialized"

run_args = [
self._float_workspace_buffer,
self._int_workspace_buffer,
self._plan_info,
plan_info,
q,

Copilot AI Apr 3, 2026


The non-tensor-core branch now asserts self._plan_info is initialized before building run_args, but the tensor-core branch still passes self._plan_info directly without a guard. This can lead to a less clear failure mode if run() is called before plan() (or if planning failed). Consider adding the same assert self._plan_info is not None (or raising a ValueError) in the tensor-core path for consistency.

Collaborator

@diptorupd diptorupd left a comment


Thank you! LGTM.

@diptorupd diptorupd merged commit 2ecbe16 into ROCm:amd-integration Apr 3, 2026
7 of 8 checks passed

Copilot AI left a comment


Pull request overview

Copilot reviewed 5 out of 5 changed files in this pull request and generated 5 comments.



Comment thread flashinfer/utils.py
Comment on lines +38 to +47
"""Normalize JIT ``plan`` output to int64 on ``device`` for custom-op boundaries.

``plan()`` returns a tensor on HIP/CUDA; some call sites may still use a Python
sequence — accept both.
"""
if isinstance(plan_info, torch.Tensor):
if plan_info.dtype == torch.int64 and plan_info.device == device:
return plan_info
return plan_info.to(device=device, dtype=torch.int64)
return torch.tensor(list(plan_info), dtype=torch.int64, device=device)

Copilot AI Apr 3, 2026


plan_info_vec_as_tensor() currently converts an existing plan_info tensor to device=device. On ROCm, the underlying C++ ops (e.g. flashinfer/csrc_rocm/batch_prefill.cu, flashinfer/csrc_rocm/batch_decode.cu) read plan_info_vec.data_ptr<int64_t>() on the host side to build a std::vector<int64_t>, which requires plan_info_vec to be a CPU tensor. Moving it to a CUDA/HIP device will make data_ptr point to device memory and can lead to crashes or incorrect reads. Please keep plan_info on CPU (and optionally add a TORCH_CHECK(plan_info.is_cpu()) in the C++ path / adjust this helper’s API+doc to reflect that HIP plan() returns a CPU tensor).

Suggested change

-    """Normalize JIT ``plan`` output to int64 on ``device`` for custom-op boundaries.
-
-    ``plan()`` returns a tensor on HIP/CUDA; some call sites may still use a Python
-    sequence — accept both.
-    """
-    if isinstance(plan_info, torch.Tensor):
-        if plan_info.dtype == torch.int64 and plan_info.device == device:
-            return plan_info
-        return plan_info.to(device=device, dtype=torch.int64)
-    return torch.tensor(list(plan_info), dtype=torch.int64, device=device)
+    """Normalize JIT ``plan`` output to a CPU int64 tensor for custom-op boundaries.
+
+    Some downstream custom-op paths read ``plan_info`` via host-side ``data_ptr()``,
+    so the returned tensor must stay on CPU. ``device`` is accepted for API
+    compatibility but is intentionally ignored here.
+    """
+    if isinstance(plan_info, torch.Tensor):
+        if plan_info.dtype == torch.int64 and plan_info.device.type == "cpu":
+            return plan_info
+        return plan_info.to(device="cpu", dtype=torch.int64)
+    return torch.tensor(list(plan_info), dtype=torch.int64)

causal,
)
self._plan_info = plan_info_vec_as_tensor(
self._plan_info, device=self._float_workspace_buffer.device

Copilot AI Apr 3, 2026


plan_info_vec_as_tensor(..., device=self._float_workspace_buffer.device) will move self._plan_info to the GPU. The ROCm C++ implementation expects plan_info_vec to be a CPU int64 tensor (it reads data_ptr<int64_t>() on the host to build a std::vector). Please keep plan_info on CPU here (e.g. pass device=torch.device("cpu"), or adjust the helper so it never migrates HIP plan info to the device).

Suggested change

-    self._plan_info, device=self._float_workspace_buffer.device
+    self._plan_info, device=torch.device("cpu")

Comment on lines +2790 to +2792
self._plan_info = plan_info_vec_as_tensor(
self._plan_info, device=self._float_workspace_buffer.device
)

Copilot AI Apr 3, 2026


Same issue as the other plan_info_vec_as_tensor call site: passing device=self._float_workspace_buffer.device can move the plan-info tensor to GPU memory, but ROCm kernels expect plan_info_vec to be a CPU tensor (host reads its data_ptr<int64_t>()). Keep this tensor on CPU.

Suggested change

-    self._plan_info = plan_info_vec_as_tensor(
-        self._plan_info, device=self._float_workspace_buffer.device
-    )
+    self._plan_info = plan_info_vec_as_tensor(self._plan_info)

Comment thread flashinfer/decode_rocm.py
Comment on lines +976 to +978
self._plan_info = plan_info_vec_as_tensor(
self._plan_info, device=self._float_workspace_buffer.device
)

Copilot AI Apr 3, 2026


plan_info_vec_as_tensor(..., device=self._float_workspace_buffer.device) can move the plan-info tensor to the GPU. The ROCm C++ decode op builds a std::vector<int64_t> from plan_info_vec.data_ptr<int64_t>() on the host, so plan_info_vec must stay on CPU. Please keep plan info on CPU at this call site.

Comment thread flashinfer/decode_rocm.py
Comment on lines +1012 to +1014
self._plan_info = plan_info_vec_as_tensor(
self._plan_info, device=self._float_workspace_buffer.device
)

Copilot AI Apr 3, 2026


Same issue as the earlier plan-info normalization: moving self._plan_info to the GPU here breaks the ROCm C++ decode op which expects plan_info_vec to be a CPU int64 tensor (host reads it via data_ptr<int64_t>()). Keep the plan info tensor on CPU at this call site as well.

Suggested change

-    self._plan_info = plan_info_vec_as_tensor(
-        self._plan_info, device=self._float_workspace_buffer.device
-    )
+    self._plan_info = plan_info_vec_as_tensor(self._plan_info)

