From bb3fb172c5ec90cdf4b789cf2a814b4122ccaccf Mon Sep 17 00:00:00 2001 From: Janusz Lisiecki Date: Mon, 9 Mar 2026 07:51:06 +0100 Subject: [PATCH] Fix JAX DLPack export to use __dlpack__ protocol directly - Replaces deprecated jax.dlpack.to_dlpack() calls with the standard tensor.__dlpack__() method, which is the correct DLPack protocol interface for JAX 0.6+. Signed-off-by: Janusz Lisiecki --- dali/python/nvidia/dali/plugin/jax/fn/_function_transform.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dali/python/nvidia/dali/plugin/jax/fn/_function_transform.py b/dali/python/nvidia/dali/plugin/jax/fn/_function_transform.py index 78f350edfb..94afda1514 100644 --- a/dali/python/nvidia/dali/plugin/jax/fn/_function_transform.py +++ b/dali/python/nvidia/dali/plugin/jax/fn/_function_transform.py @@ -29,7 +29,7 @@ def gpu_to_dlpack(tensor: jax.Array, stream): f"The function returned array residing on the device of " f"kind `{devices[0].platform}`, expected `gpu`." ) - return jax.dlpack.to_dlpack(tensor, stream=stream) + return tensor.__dlpack__(stream=stream) def cpu_to_dlpack(tensor: jax.Array): @@ -44,7 +44,7 @@ def cpu_to_dlpack(tensor: jax.Array): f"The function returned array residing on the device of " f"kind `{devices[0].platform}`, expected `cpu`." ) - return jax.dlpack.to_dlpack(tensor) + return tensor.__dlpack__() def with_gpu_dl_tensors_as_arrays(callback):