diff --git a/docs/notebooks/basic-usage.ipynb b/docs/notebooks/basic-usage.ipynb index 4371b71..088b17d 100644 --- a/docs/notebooks/basic-usage.ipynb +++ b/docs/notebooks/basic-usage.ipynb @@ -53,6 +53,152 @@ "binsparse.write(z, \"csr\", csr)\n", "binsparse.read(z[\"csr\"])" ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "
M0
\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
gb.Matrix
nvals
nrows
ncols
dtype
format
100001000100FP64csr
\n", + "
\n", + "
(Install pandas to see a preview of the data)
" + ], + "text/plain": [ + "\"M_0\" nvals nrows ncols dtype format\n", + "gb.Matrix 10000 1000 100 FP64 csr" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "z = zarr.group()\n", + "\n", + "binsparse.write(z, \"csr\", csr)\n", + "binsparse.read(z[\"csr\"], _class=\"graphblas\")" + ] } ], "metadata": { @@ -71,7 +217,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.12" + "version": "3.11.6" }, "vscode": { "interpreter": { diff --git a/pyproject.toml b/pyproject.toml index 5807306..7d9cad6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,12 +29,14 @@ dependencies = [ dev = [ # CLI for bumping the version number "pre-commit", - "twine>=4.0.2" + "twine>=4.0.2", + "python-graphblas[default]", ] doc = [ "docutils>=0.8,!=0.18.*,!=0.19.*", "sphinx>=4", "sphinx-book-theme>=1.0.0", + "matplotlib", "myst-nb", "sphinxcontrib-bibtex>=1.0.0", "sphinx-autodoc-typehints", diff --git a/src/binsparse/_io/api.py b/src/binsparse/_io/api.py index 0e33318..1cdb7cd 100644 --- a/src/binsparse/_io/api.py +++ b/src/binsparse/_io/api.py @@ -15,7 +15,7 @@ from binsparse._types import GroupTypes -def read(group: GroupTypes) -> sparse.spmatrix: +def read(group: GroupTypes, _class: str = "scipy") -> object: """Read a sparse matrix from a store. 
Parameters @@ -25,11 +25,11 @@ def read(group: GroupTypes) -> sparse.spmatrix: """ metadata = read_attr(group, "binsparse") if metadata["format"] == "CSR": - return read_csr(group) + return read_csr(group, _class=_class) elif metadata["format"] == "CSC": - return read_csc(group) + return read_csc(group, _class=_class) elif metadata["format"] == "COO": - return read_coo(group) + return read_coo(group, _class=_class) else: raise NotImplementedError(f"no implementation for format {metadata['format']}") diff --git a/src/binsparse/_io/methods.py b/src/binsparse/_io/methods.py index a537a9c..489e9e1 100644 --- a/src/binsparse/_io/methods.py +++ b/src/binsparse/_io/methods.py @@ -1,6 +1,7 @@ from collections.abc import Mapping from types import MappingProxyType +import graphblas as gb import numpy as np from scipy import sparse @@ -59,7 +60,7 @@ def write_csr(store: GroupTypes, key: str, x: sparse.csr_matrix, *, dataset_kwar store.create_dataset(f"{key}/values", data=x.data, **dataset_kwargs) -def read_csr(group: GroupTypes) -> sparse.csr_matrix: +def read_csr(group: GroupTypes, _class: str) -> gb.Matrix | sparse.csr_matrix: """Read a CSR matrix from a store. Parameters @@ -67,6 +68,9 @@ def read_csr(group: GroupTypes) -> sparse.csr_matrix: group A Zarr or h5py group. + _class + The class we will use to create the sparse matrix. 
Either "scipy" or "graphblas" + Returns ------- x @@ -76,14 +80,23 @@ def read_csr(group: GroupTypes) -> sparse.csr_matrix: assert metadata["format"] == "CSR" shape = tuple(metadata["shape"]) - return sparse.csr_matrix( - ( - group["values"][()], - group["indices_1"][()], + if _class.lower() == "scipy": + return sparse.csr_matrix( + ( + group["values"][()], + group["indices_1"][()], + group["pointers_to_1"][()], + ), + shape=shape, + ) + elif _class.lower() == "graphblas": + return gb.Matrix.from_csr( group["pointers_to_1"][()], - ), - shape=shape, - ) + group["indices_1"][()], group["values"][()], + nrows=shape[0], ncols=shape[1], # pass the stored shape instead of letting GraphBLAS infer it from the indices + ) + else: + raise NotImplementedError(f"no implementation for returning data using {_class}") def write_csc(store: GroupTypes, key: str, x: sparse.csc_matrix, *, dataset_kwargs: Mapping = MappingProxyType({})): @@ -116,7 +129,7 @@ def write_csc(store: GroupTypes, key: str, x: sparse.csc_matrix, *, dataset_kwar store.create_dataset(f"{key}/values", data=x.data, **dataset_kwargs) -def read_csc(group: GroupTypes) -> sparse.csc_matrix: +def read_csc(group: GroupTypes, _class: str) -> gb.Matrix | sparse.csc_matrix: """Read a CSC matrix from a store. Parameters @@ -124,6 +137,9 @@ def read_csc(group: GroupTypes) -> sparse.csc_matrix: group A Zarr or h5py group. + _class + The class we will use to create the sparse matrix. 
Either "scipy" or "graphblas" + Returns ------- x @@ -133,14 +149,23 @@ def read_csc(group: GroupTypes) -> sparse.csc_matrix: assert metadata["format"] == "CSC" shape = tuple(metadata["shape"]) - return sparse.csc_matrix( - ( - group["values"][()], - group["indices_1"][()], + if _class.lower() == "scipy": + return sparse.csc_matrix( + ( + group["values"][()], + group["indices_1"][()], + group["pointers_to_1"][()], + ), + shape=shape, + ) + elif _class.lower() == "graphblas": + return gb.Matrix.from_csc( group["pointers_to_1"][()], - ), - shape=shape, - ) + group["indices_1"][()], group["values"][()], + nrows=shape[0], ncols=shape[1], # pass the stored shape instead of letting GraphBLAS infer it from the indices + ) + else: + raise NotImplementedError(f"no implementation for returning data using {_class}") def write_coo(store: GroupTypes, key: str, x: sparse.csc_matrix, *, dataset_kwargs: Mapping = MappingProxyType({})): @@ -173,7 +198,7 @@ def write_coo(store: GroupTypes, key: str, x: sparse.csc_matrix, *, dataset_kwar store.create_dataset(f"{key}/values", data=x.data, **dataset_kwargs) -def read_coo(group: GroupTypes): +def read_coo(group: GroupTypes, _class: str) -> gb.Matrix | sparse.coo_matrix: """Read a COO matrix from a store. Parameters @@ -181,6 +206,9 @@ def read_coo(group: GroupTypes): group A Zarr or h5py group. + _class + The class we will use to create the sparse matrix. 
Either "scipy" or "graphblas" + Returns ------- x @@ -190,13 +218,22 @@ def read_coo(group: GroupTypes): assert metadata["format"] == "COO" shape = tuple(metadata["shape"]) - return sparse.coo_matrix( - ( - group["values"][()], + if _class.lower() == "scipy": + return sparse.coo_matrix( ( - group["indices_0"][()], - group["indices_1"][()], + group["values"][()], + ( + group["indices_0"][()], + group["indices_1"][()], + ), ), - ), - shape=shape, - ) + shape=shape, + ) + elif _class.lower() == "graphblas": + return gb.Matrix.from_coo( + group["indices_0"][()], group["indices_1"][()], + group["values"][()], + nrows=shape[0], ncols=shape[1], # pass the stored shape instead of letting GraphBLAS infer it from the indices + ) + else: + raise NotImplementedError(f"no implementation for returning data using {_class}") diff --git a/src/binsparse/_testing.py b/src/binsparse/_testing.py index d01da6b..926aeae 100644 --- a/src/binsparse/_testing.py +++ b/src/binsparse/_testing.py @@ -1,10 +1,11 @@ from functools import singledispatch +import graphblas as gb import numpy as np from scipy import sparse -def assert_equal(a: sparse.spmatrix, b): +def assert_equal(a: gb.Matrix | sparse.spmatrix, b): assert type(a) == type(b), f"types differ: {type(a)} != {type(b)}" assert a.shape == b.shape, f"shapes differ: {a.shape} != {b.shape}" _assert_equal(a, b) @@ -26,3 +27,9 @@ def _(a, b): def _(a, b): for attr in ["row", "col", "data"]: np.testing.assert_equal(getattr(a, attr), getattr(b, attr)) + + +@_assert_equal.register(gb.Matrix) +def _(a, b): + equal = a.isequal(b) + assert equal, "GraphBLAS matrices are not equal" diff --git a/tests/test_basic.py b/tests/test_basic.py index 7059e9a..d2c26e5 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -1,3 +1,4 @@ +import graphblas as gb import h5py import pytest import zarr @@ -45,3 +46,22 @@ def test_metadata(store, fmt): for k, v in metadata["data_types"].items(): assert v in _DTYPE_STR_REGISTRY.values(), f"unrecognized dtype for '{k}': {v}" + + +@pytest.mark.parametrize("fmt", ["csr", "csc", "coo"]) +def test_graphblas_structure(store, 
fmt): + orig = sparse.random(100, 100, density=0.1, format=fmt) + binsparse.write(store, "X", orig) + orig = gb.io.from_scipy_sparse(orig) + from_disk = binsparse.read(store["X"], _class="graphblas") + assert_equal(orig, from_disk) + + +@pytest.mark.parametrize("fmt", ["csr", "csc", "coo"]) +def test_non_implemented_structure(store, fmt): + """An unsupported ``_class`` value must raise ``NotImplementedError`` for every stored format.""" + orig = sparse.random(100, 100, density=0.1, format=fmt) + binsparse.write(store, "X", orig) + # Only the read belongs inside the block: statements after the raising call would never run. + with pytest.raises(NotImplementedError): + binsparse.read(store["X"], _class="torch")