From ad1d22558d5847c9b741cf600641a0b2e425393b Mon Sep 17 00:00:00 2001 From: Phil Schaf Date: Mon, 25 Aug 2025 15:33:20 +0200 Subject: [PATCH 1/2] Finalize dask patch --- src/fast_array_utils/_plugins/dask.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/fast_array_utils/_plugins/dask.py b/src/fast_array_utils/_plugins/dask.py index f9f93ca..5eef179 100644 --- a/src/fast_array_utils/_plugins/dask.py +++ b/src/fast_array_utils/_plugins/dask.py @@ -2,9 +2,7 @@ from __future__ import annotations import numpy as np - -# Other lookup candidates: tensordot_lookup and take_lookup -from dask.array.dispatch import concatenate_lookup +from dask.array.dispatch import concatenate_lookup, take_lookup, tensordot_lookup from scipy.sparse import sparray, spmatrix @@ -13,11 +11,12 @@ def patch() -> None: # pragma: no cover """Patch dask to support sparse arrays. - See + See """ # Avoid patch if already patched or upstream support has been added if concatenate_lookup.dispatch(sparray) is not np.concatenate: return - concatenate = concatenate_lookup.dispatch(spmatrix) - concatenate_lookup.register(sparray, concatenate) + concatenate_lookup.register(sparray, concatenate_lookup.dispatch(spmatrix)) + tensordot_lookup.register(sparray, tensordot_lookup.dispatch(spmatrix)) + take_lookup.register(sparray, take_lookup.dispatch(spmatrix)) From c06ddbffca7fb9720b50761eb37904fa55b1a902 Mon Sep 17 00:00:00 2001 From: Phil Schaf Date: Mon, 25 Aug 2025 15:38:07 +0200 Subject: [PATCH 2/2] cleaner --- src/fast_array_utils/_plugins/dask.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/fast_array_utils/_plugins/dask.py b/src/fast_array_utils/_plugins/dask.py index 5eef179..535a29c 100644 --- a/src/fast_array_utils/_plugins/dask.py +++ b/src/fast_array_utils/_plugins/dask.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: MPL-2.0 from __future__ import annotations -import numpy as np from dask.array.dispatch import concatenate_lookup, take_lookup, tensordot_lookup from scipy.sparse import sparray, spmatrix @@ -14,7 +13,7 @@ def patch() -> None: # pragma: no cover See """ # Avoid patch if already patched or upstream support has been added - if concatenate_lookup.dispatch(sparray) is not np.concatenate: + if concatenate_lookup.dispatch(sparray) is not concatenate_lookup.dispatch(spmatrix): return concatenate_lookup.register(sparray, concatenate_lookup.dispatch(spmatrix))