diff --git a/dali/python/nvidia/dali/plugin/jax/fn/_function_transform.py b/dali/python/nvidia/dali/plugin/jax/fn/_function_transform.py index 78f350edfb..94afda1514 100644 --- a/dali/python/nvidia/dali/plugin/jax/fn/_function_transform.py +++ b/dali/python/nvidia/dali/plugin/jax/fn/_function_transform.py @@ -29,7 +29,7 @@ def gpu_to_dlpack(tensor: jax.Array, stream): f"The function returned array residing on the device of " f"kind `{devices[0].platform}`, expected `gpu`." ) - return jax.dlpack.to_dlpack(tensor, stream=stream) + return tensor.__dlpack__(stream=stream) def cpu_to_dlpack(tensor: jax.Array): @@ -44,7 +44,7 @@ def cpu_to_dlpack(tensor: jax.Array): f"The function returned array residing on the device of " f"kind `{devices[0].platform}`, expected `cpu`." ) - return jax.dlpack.to_dlpack(tensor) + return tensor.__dlpack__() def with_gpu_dl_tensors_as_arrays(callback):