Skip to content

Fix JAX DLPack export to use __dlpack__ protocol directly#6247

Open
JanuszL wants to merge 1 commit intoNVIDIA:mainfrom
JanuszL:fix_dlpack_jax
Open

Fix JAX DLPack export to use __dlpack__ protocol directly#6247
JanuszL wants to merge 1 commit intoNVIDIA:mainfrom
JanuszL:fix_dlpack_jax

Conversation

@JanuszL
Copy link
Contributor

@JanuszL JanuszL commented Mar 9, 2026

  • Replaces deprecated jax.dlpack.to_dlpack() calls with the standard
    tensor.dlpack() method, which is the correct DLPack protocol
    interface for JAX 0.6+.

Category:

Bug fix (non-breaking change which fixes an issue)

Description:

  • Replaces deprecated jax.dlpack.to_dlpack() calls with the standard
    tensor.dlpack() method, which is the correct DLPack protocol
    interface for JAX 0.6+.

Additional information:

Affected modules and functionalities:

  • _function_transform.py

Key points relevant for the review:

  • NA

Tests:

  • Existing tests apply
    • TL0_jupyter
  • New tests added
    • Python tests
    • GTests
    • Benchmark
    • Other
  • N/A

Checklist

Documentation

  • Existing documentation applies
  • Documentation updated
    • Docstring
    • Doxygen
    • RST
    • Jupyter
    • Other
  • N/A

DALI team only

Requirements

  • Implements new requirements
  • Affects existing requirements
  • N/A

REQ IDs: N/A

JIRA TASK: N/A

- Replaces deprecated jax.dlpack.to_dlpack() calls with the standard
  tensor.__dlpack__() method, which is the correct DLPack protocol
  interface for JAX 0.6+.

Signed-off-by: Janusz Lisiecki <jlisiecki@nvidia.com>
@dali-automaton
Copy link
Collaborator

CI MESSAGE: [45688155]: BUILD STARTED

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Mar 9, 2026

Greptile Summary

This PR fixes a deprecation issue in the JAX DLPack integration by replacing jax.dlpack.to_dlpack() calls with the standard tensor.__dlpack__() protocol method, which is the correct producer-side interface in JAX 0.6+.

  • gpu_to_dlpack: jax.dlpack.to_dlpack(tensor, stream=stream)tensor.__dlpack__(stream=stream) — correct; stream synchronization hint is preserved.
  • cpu_to_dlpack: jax.dlpack.to_dlpack(tensor)tensor.__dlpack__() — correct; no stream needed for CPU arrays.
  • The import jax.dlpack on line 16 is intentionally retained because jax.dlpack.from_dlpack is still used on lines 53 and 67 for the consumer (import) direction, which remains the idiomatic approach.
  • The change is minimal, targeted, and consistent with the DLPack protocol specification.

Confidence Score: 5/5

  • This PR is safe to merge — it is a minimal, correct migration from a deprecated API to the standard DLPack protocol.
  • The change is a two-line substitution replacing a deprecated JAX helper with the equivalent lower-level protocol call. The semantics are identical, the stream parameter is correctly forwarded for GPU arrays, and the existing jax.dlpack.from_dlpack usages are left untouched. No new logic is introduced and the fix is consistent with the DLPack protocol specification.
  • No files require special attention.

Sequence Diagram

sequenceDiagram
    participant DALI
    participant JAX

    Note over DALI,JAX: Import direction (unchanged)
    DALI->>JAX: jax.dlpack.from_dlpack(dl_capsule)
    JAX-->>DALI: jax.Array

    Note over DALI,JAX: Export direction (this PR)
    DALI->>JAX: tensor.__dlpack__(stream=stream)
    JAX-->>DALI: DLPack PyCapsule (GPU, with stream sync)

    DALI->>JAX: tensor.__dlpack__()
    JAX-->>DALI: DLPack PyCapsule (CPU, no stream)
Loading

Last reviewed commit: bb3fb17

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants