Enable torch.compile under a flag #210
Conversation
Pull request overview
Adds opt-in support for torch.compile by wrapping FlashInfer kernels as opaque torch.library.custom_op registrations (so Dynamo does not trace into extension code), plus a small ROCm benchmark/verification script and optional local Git hook scaffolding.
Changes:
- Introduce the `FLASHINFER_USE_TORCH_CUSTOM_OPS` env flag (checked at import time) to enable/disable `torch.library.custom_op` wrapping (a minimal sketch follows this list).
- Export `use_torch_custom_ops_enabled()` from the top-level `flashinfer` package for both CUDA and ROCm/HIP builds.
- Add a ROCm micro-benchmark script and optional `.githooks` pre-push protection for `amd-integration`.
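For orientation, a minimal sketch of what the env-gated wrapping described above could look like. This is illustrative only: the helper name `register_custom_op` and its exact signature are assumptions, not the PR's actual code.

```python
import os

import torch

# Illustrative: read the opt-in flag once, at import time.
_USE_TORCH_CUSTOM_OPS = os.environ.get("FLASHINFER_USE_TORCH_CUSTOM_OPS", "0") == "1"


def use_torch_custom_ops_enabled() -> bool:
    """Report whether custom-op wrapping is in effect for this process."""
    return _USE_TORCH_CUSTOM_OPS


def register_custom_op(name, fn, mutates_args=()):
    """Wrap ``fn`` as an opaque torch.library.custom_op when the flag is set.

    With the flag off, the kernel is returned unwrapped, so behavior matches
    the existing eager path; with it on, Dynamo treats the op as opaque and
    never traces into the extension code.
    """
    if not _USE_TORCH_CUSTOM_OPS:
        return fn
    return torch.library.custom_op(name, fn, mutates_args=mutates_args)
```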
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| flashinfer/utils.py | Adds env-gated custom-op enablement plus a helper to query the setting; conditionally registers custom_op/register_fake. |
| flashinfer/__init__.py | Re-exports use_torch_custom_ops_enabled() for both the CUDA and HIP branches. |
| scripts/verify_enable_torch_compile.py | New ROCm-focused eager vs. torch.compile micro-benchmark for append_paged_kv_cache. |
| .githooks/README | Documents how to enable the optional local Git hooks and their intended behavior. |
| .githooks/pre-push | New local pre-push hook intended to block pushing to/updating amd-integration. |
Pull request overview
Copilot reviewed 6 out of 6 changed files in this pull request and generated 2 comments.
```python
if not _USE_TORCH_CUSTOM_OPS:
    return _guard_compile(f, name)
with contextlib.suppress(ValueError, TypeError):
    return torch.library.custom_op(
```
Why do we need to suppress the ValueError and TypeError here? If the custom_op creation fails, shouldn't we raise? Copilot's review comment makes sense to me.
Why not convert it into a pytest? You can run a specific test that has the env var set and then unset it.
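As a rough illustration of that suggestion, a test along these lines could drive the flag through subprocesses so the import-time check sees a controlled environment (the helper and test names here are hypothetical):

```python
import os
import subprocess
import sys
import textwrap


def _run_snippet(code: str, extra_env: dict) -> subprocess.CompletedProcess:
    # Run in a fresh interpreter so the import-time env check in flashinfer
    # sees exactly the variables we pass in.
    env = {**os.environ, **extra_env}
    return subprocess.run(
        [sys.executable, "-c", code], env=env, capture_output=True, text=True
    )


def test_flag_toggles_custom_ops():
    code = textwrap.dedent(
        """
        import flashinfer
        print(flashinfer.use_torch_custom_ops_enabled())
        """
    )
    on = _run_snippet(code, {"FLASHINFER_USE_TORCH_CUSTOM_OPS": "1"})
    off = _run_snippet(code, {"FLASHINFER_USE_TORCH_CUSTOM_OPS": "0"})
    assert on.stdout.strip() == "True", on.stderr
    assert off.stdout.strip() == "False", off.stderr
```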
| """ | ||
| include_dir = os.path.join(_get_package_root_dir(), "include") | ||
| return str(include_dir) | ||
| package_dir = pathlib.Path(_get_package_root_dir()).resolve() |
Why is this change required? Running from a source tree was never supported, even before the CMake editable-install fix. Why do we need to enable it? IMO it is a non-standard approach; a better way to deal with this would be to test sdist packages and make sure they work. The necessary sdist tooling using scikit-build-core is already there, although I have never verified it.
…ytest
- Remove contextlib.suppress in register_custom_op so registration failures surface instead of silently falling back (Copilot + diptorupd)
- Fold the torch >= 2.4 check into _USE_TORCH_CUSTOM_OPS so use_torch_custom_ops_enabled() reports effective behavior (Copilot; sketched below)
- Replace scripts/verify_enable_torch_compile.py with tests/rocm_tests/test_torch_compile_hip.py covering eager and torch.compile paths in subprocess isolation (diptorupd)
Made-with: Cursor
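A sketch of the "fold the version check into the flag" item (illustrative; the variable names mirror the diff excerpts above, but the version handling is an assumption):

```python
import os

import torch
from packaging.version import Version

# Effective setting: the env var alone is not enough; torch >= 2.4 is required
# for torch.library.custom_op, so the version gate is folded into the flag and
# use_torch_custom_ops_enabled() then reports what actually happens at runtime.
_TORCH_HAS_CUSTOM_OP = Version(torch.__version__.split("+")[0]) >= Version("2.4")
_USE_TORCH_CUSTOM_OPS = (
    os.environ.get("FLASHINFER_USE_TORCH_CUSTOM_OPS", "0") == "1"
    and _TORCH_HAS_CUSTOM_OP
)
```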
Drop the repo-root include fallback per reviewer feedback: running from a source tree without an install is not a supported workflow.
Made-with: Cursor
- Use device="cuda" (not "hip"); PyTorch on ROCm uses "cuda" as the device string
- Check HIP availability properly in pytestmark
- Accept any exception from torch.compile when custom ops are disabled, since the actual error (TorchRuntimeError from FakeTensor) differs from the _guard_compile message
Made-with: Cursor
torch.library.custom_op cannot infer schemas for parameters like Optional[torch.Generator]. Instead of crashing, fall back to _guard_compile so these ops still get torch.compile protection.
Made-with: Cursor
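A minimal sketch of that fallback, assuming the `_USE_TORCH_CUSTOM_OPS` flag and `_guard_compile` helper shown in the diff excerpts above (the surrounding structure is illustrative):

```python
def register_custom_op(name, fn, mutates_args=()):
    if not _USE_TORCH_CUSTOM_OPS:
        return _guard_compile(fn, name)
    try:
        return torch.library.custom_op(name, fn, mutates_args=mutates_args)
    except (ValueError, TypeError):
        # Schema inference fails for parameters such as Optional[torch.Generator];
        # fall back so the op still gets torch.compile protection instead of crashing.
        return _guard_compile(fn, name)
```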
Pull request overview
Copilot reviewed 5 out of 5 changed files in this pull request and generated 2 comments.
```python
def test_eager_with_custom_ops():
    """append_paged_kv_cache works in eager mode with custom ops enabled."""
    snippet = _PREAMBLE + textwrap.dedent(
        """\
        assert flashinfer.use_torch_custom_ops_enabled()
        append(k, v)
        print("OK")
        """
    )
    result = _run_snippet(snippet, {"FLASHINFER_USE_TORCH_CUSTOM_OPS": "1"})
    assert result.returncode == 0, f"eager (custom ops on) failed:\n{result.stderr}"
```
test_eager_with_custom_ops asserts flashinfer.use_torch_custom_ops_enabled() is true when the env var is set, but _USE_TORCH_CUSTOM_OPS in flashinfer/utils.py is additionally gated on torch>=2.4. On torch 2.3.x this test will fail even though the env var is "1". Consider adding the same torch>=2.4 skip condition as the torch.compile tests, or relaxing the assertion to account for the version gate.
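One way to add the suggested version gate (hypothetical; it mirrors the skip condition the comment refers to):

```python
import pytest
import torch
from packaging.version import Version

requires_torch_24 = pytest.mark.skipif(
    Version(torch.__version__.split("+")[0]) < Version("2.4"),
    reason="torch.library.custom_op wrapping requires torch >= 2.4",
)


@requires_torch_24
def test_eager_with_custom_ops():
    ...
```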
```diff
+plan_info = self._plan_info
+assert plan_info is not None, "plan info is not initialized"
+
 run_args = [
     self._float_workspace_buffer,
     self._int_workspace_buffer,
-    self._plan_info,
+    plan_info,
     q,
```
The non-tensor-core branch now asserts self._plan_info is initialized before building run_args, but the tensor-core branch still passes self._plan_info directly without a guard. This can lead to a less clear failure mode if run() is called before plan() (or if planning failed). Consider adding the same assert self._plan_info is not None (or raising a ValueError) in the tensor-core path for consistency.
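A possible shape for that guard in the tensor-core path (illustrative):

```python
# Mirror the non-tensor-core branch: fail clearly if plan() has not run yet.
plan_info = self._plan_info
if plan_info is None:
    raise ValueError("plan() must be called successfully before run()")
```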
Pull request overview
Copilot reviewed 5 out of 5 changed files in this pull request and generated 5 comments.
| """Normalize JIT ``plan`` output to int64 on ``device`` for custom-op boundaries. | ||
|
|
||
| ``plan()`` returns a tensor on HIP/CUDA; some call sites may still use a Python | ||
| sequence — accept both. | ||
| """ | ||
| if isinstance(plan_info, torch.Tensor): | ||
| if plan_info.dtype == torch.int64 and plan_info.device == device: | ||
| return plan_info | ||
| return plan_info.to(device=device, dtype=torch.int64) | ||
| return torch.tensor(list(plan_info), dtype=torch.int64, device=device) |
plan_info_vec_as_tensor() currently converts an existing plan_info tensor to device=device. On ROCm, the underlying C++ ops (e.g. flashinfer/csrc_rocm/batch_prefill.cu, flashinfer/csrc_rocm/batch_decode.cu) read plan_info_vec.data_ptr<int64_t>() on the host side to build a std::vector<int64_t>, which requires plan_info_vec to be a CPU tensor. Moving it to a CUDA/HIP device will make data_ptr point to device memory and can lead to crashes or incorrect reads. Please keep plan_info on CPU (and optionally add a TORCH_CHECK(plan_info.is_cpu()) in the C++ path / adjust this helper’s API+doc to reflect that HIP plan() returns a CPU tensor).
| """Normalize JIT ``plan`` output to int64 on ``device`` for custom-op boundaries. | |
| ``plan()`` returns a tensor on HIP/CUDA; some call sites may still use a Python | |
| sequence — accept both. | |
| """ | |
| if isinstance(plan_info, torch.Tensor): | |
| if plan_info.dtype == torch.int64 and plan_info.device == device: | |
| return plan_info | |
| return plan_info.to(device=device, dtype=torch.int64) | |
| return torch.tensor(list(plan_info), dtype=torch.int64, device=device) | |
| """Normalize JIT ``plan`` output to a CPU int64 tensor for custom-op boundaries. | |
| Some downstream custom-op paths read ``plan_info`` via host-side ``data_ptr()``, | |
| so the returned tensor must stay on CPU. ``device`` is accepted for API | |
| compatibility but is intentionally ignored here. | |
| """ | |
| if isinstance(plan_info, torch.Tensor): | |
| if plan_info.dtype == torch.int64 and plan_info.device.type == "cpu": | |
| return plan_info | |
| return plan_info.to(device="cpu", dtype=torch.int64) | |
| return torch.tensor(list(plan_info), dtype=torch.int64) |
```python
    causal,
)
self._plan_info = plan_info_vec_as_tensor(
    self._plan_info, device=self._float_workspace_buffer.device
```
plan_info_vec_as_tensor(..., device=self._float_workspace_buffer.device) will move self._plan_info to the GPU. The ROCm C++ implementation expects plan_info_vec to be a CPU int64 tensor (it reads data_ptr<int64_t>() on the host to build a std::vector). Please keep plan_info on CPU here (e.g. pass device=torch.device("cpu"), or adjust the helper so it never migrates HIP plan info to the device).
Suggested change:
```diff
-    self._plan_info, device=self._float_workspace_buffer.device
+    self._plan_info, device=torch.device("cpu")
```
```python
self._plan_info = plan_info_vec_as_tensor(
    self._plan_info, device=self._float_workspace_buffer.device
)
```
Same issue as the other plan_info_vec_as_tensor call site: passing device=self._float_workspace_buffer.device can move the plan-info tensor to GPU memory, but ROCm kernels expect plan_info_vec to be a CPU tensor (host reads its data_ptr<int64_t>()). Keep this tensor on CPU.
Suggested change:
```diff
-self._plan_info = plan_info_vec_as_tensor(
-    self._plan_info, device=self._float_workspace_buffer.device
-)
+self._plan_info = plan_info_vec_as_tensor(self._plan_info)
```
```python
self._plan_info = plan_info_vec_as_tensor(
    self._plan_info, device=self._float_workspace_buffer.device
)
```
plan_info_vec_as_tensor(..., device=self._float_workspace_buffer.device) can move the plan-info tensor to the GPU. The ROCm C++ decode op builds a std::vector<int64_t> from plan_info_vec.data_ptr<int64_t>() on the host, so plan_info_vec must stay on CPU. Please keep plan info on CPU at this call site.
```python
self._plan_info = plan_info_vec_as_tensor(
    self._plan_info, device=self._float_workspace_buffer.device
)
```
Same issue as the earlier plan-info normalization: moving self._plan_info to the GPU here breaks the ROCm C++ decode op which expects plan_info_vec to be a CPU int64 tensor (host reads it via data_ptr<int64_t>()). Keep the plan info tensor on CPU at this call site as well.
Suggested change:
```diff
-self._plan_info = plan_info_vec_as_tensor(
-    self._plan_info, device=self._float_workspace_buffer.device
-)
+self._plan_info = plan_info_vec_as_tensor(self._plan_info)
```
Adds an environment flag, FLASHINFER_USE_TORCH_CUSTOM_OPS; setting it enables torch.compile support in flashinfer. By default it is disabled, matching upstream (flashinfer/flashinfer/utils.py, line 273 in 61a9b74).
How to use it
`FLASHINFER_USE_TORCH_CUSTOM_OPS=1 <your command>`
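A hedged usage sketch (the kernel call itself is elided; the point is only that the flag is read at import time and can be checked via the exported helper):

```python
# Run as: FLASHINFER_USE_TORCH_CUSTOM_OPS=1 python example.py
import torch

import flashinfer

# The flag is evaluated when flashinfer is imported, so set it in the
# environment before launching the process.
assert flashinfer.use_torch_custom_ops_enabled()


@torch.compile
def step(*args):
    # flashinfer kernels called here are opaque custom ops under the flag,
    # so torch.compile does not trace into the extension code.
    ...
```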
Testing
Adds a pytest for it: tests/rocm_tests/test_torch_compile_hip.py, which runs automatically as part of pytest. All pytests passed in both cases:
- FLASHINFER_USE_TORCH_CUSTOM_OPS=0 pytest (or simply pytest), and
- FLASHINFER_USE_TORCH_CUSTOM_OPS=1 pytest