diff --git a/docs/notebooks/basic-usage.ipynb b/docs/notebooks/basic-usage.ipynb
index 4371b71..088b17d 100644
--- a/docs/notebooks/basic-usage.ipynb
+++ b/docs/notebooks/basic-usage.ipynb
@@ -53,6 +53,152 @@
"binsparse.write(z, \"csr\", csr)\n",
"binsparse.read(z[\"csr\"])"
]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ "\n",
+ "
M0\n",
+ "
\n",
+ " \n",
+ " gb.Matrix | \n",
+ " nvals | \n",
+ " nrows | \n",
+ " ncols | \n",
+ " dtype | \n",
+ " format | \n",
+ "
\n",
+ " \n",
+ " | 10000 | \n",
+ " 1000 | \n",
+ " 100 | \n",
+ " FP64 | \n",
+ " csr | \n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "
(Install pandas to see a preview of the data) "
+ ],
+ "text/plain": [
+ "\"M_0\" nvals nrows ncols dtype format\n",
+ "gb.Matrix 10000 1000 100 FP64 csr"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "z = zarr.group()\n",
+ "\n",
+ "binsparse.write(z, \"csr\", csr)\n",
+ "binsparse.read(z[\"csr\"], struct=\"graphblas\")"
+ ]
}
],
"metadata": {
@@ -71,7 +217,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.10.12"
+ "version": "3.11.6"
},
"vscode": {
"interpreter": {
diff --git a/pyproject.toml b/pyproject.toml
index 5807306..7d9cad6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -29,12 +29,14 @@ dependencies = [
dev = [
# CLI for bumping the version number
"pre-commit",
- "twine>=4.0.2"
+ "twine>=4.0.2",
+ "python-graphblas[default]",  # NOTE(review): graphblas is imported at module top in src/binsparse/_io/methods.py; if a plain `import binsparse` reaches that module, this must also be a runtime dependency - confirm
]
doc = [
"docutils>=0.8,!=0.18.*,!=0.19.*",
"sphinx>=4",
"sphinx-book-theme>=1.0.0",
+ "matplotlib",
"myst-nb",
"sphinxcontrib-bibtex>=1.0.0",
"sphinx-autodoc-typehints",
diff --git a/src/binsparse/_io/api.py b/src/binsparse/_io/api.py
index 0e33318..1cdb7cd 100644
--- a/src/binsparse/_io/api.py
+++ b/src/binsparse/_io/api.py
@@ -15,7 +15,7 @@
from binsparse._types import GroupTypes
-def read(group: GroupTypes) -> sparse.spmatrix:
+def read(group: GroupTypes, _class: str = "scipy") -> object:
"""Read a sparse matrix from a store.
Parameters
@@ -25,11 +25,11 @@ def read(group: GroupTypes) -> sparse.spmatrix:
"""
metadata = read_attr(group, "binsparse")
if metadata["format"] == "CSR":
- return read_csr(group)
+ return read_csr(group, _class=_class)
elif metadata["format"] == "CSC":
- return read_csc(group)
+ return read_csc(group, _class=_class)
elif metadata["format"] == "COO":
- return read_coo(group)
+ return read_coo(group, _class=_class)
else:
raise NotImplementedError(f"no implementation for format {metadata['format']}")
diff --git a/src/binsparse/_io/methods.py b/src/binsparse/_io/methods.py
index a537a9c..489e9e1 100644
--- a/src/binsparse/_io/methods.py
+++ b/src/binsparse/_io/methods.py
@@ -1,6 +1,7 @@
from collections.abc import Mapping
from types import MappingProxyType
+import graphblas as gb
import numpy as np
from scipy import sparse
@@ -59,7 +60,7 @@ def write_csr(store: GroupTypes, key: str, x: sparse.csr_matrix, *, dataset_kwar
store.create_dataset(f"{key}/values", data=x.data, **dataset_kwargs)
-def read_csr(group: GroupTypes) -> sparse.csr_matrix:
+def read_csr(group: GroupTypes, _class: str) -> gb.Matrix | sparse.csr_matrix:
"""Read a CSR matrix from a store.
Parameters
@@ -67,6 +68,9 @@ def read_csr(group: GroupTypes) -> sparse.csr_matrix:
group
A Zarr or h5py group.
+ _class
+ The class we will use to create the sparse matrix. Either "scipy" or "graphblas"
+
Returns
-------
x
@@ -76,14 +80,26 @@ def read_csr(group: GroupTypes) -> sparse.csr_matrix:
     assert metadata["format"] == "CSR"
     shape = tuple(metadata["shape"])
-    return sparse.csr_matrix(
-        (
-            group["values"][()],
-            group["indices_1"][()],
+    if _class.lower() == "scipy":
+        return sparse.csr_matrix(
+            (
+                group["values"][()],
+                group["indices_1"][()],
+                group["pointers_to_1"][()],
+            ),
+            shape=shape,
+        )
+    elif _class.lower() == "graphblas":
+        # Pass the stored shape; inferring it from the indices can drop empty trailing rows/columns.
+        return gb.Matrix.from_csr(
             group["pointers_to_1"][()],
-        ),
-        shape=shape,
-    )
+            group["indices_1"][()],
+            group["values"][()],
+            nrows=shape[0],
+            ncols=shape[1],
+        )
+    else:
+        raise NotImplementedError(f"no implementation for returning data using {_class}")
def write_csc(store: GroupTypes, key: str, x: sparse.csc_matrix, *, dataset_kwargs: Mapping = MappingProxyType({})):
@@ -116,7 +129,7 @@ def write_csc(store: GroupTypes, key: str, x: sparse.csc_matrix, *, dataset_kwar
store.create_dataset(f"{key}/values", data=x.data, **dataset_kwargs)
-def read_csc(group: GroupTypes) -> sparse.csc_matrix:
+def read_csc(group: GroupTypes, _class: str) -> gb.Matrix | sparse.csc_matrix:
"""Read a CSC matrix from a store.
Parameters
@@ -124,6 +137,9 @@ def read_csc(group: GroupTypes) -> sparse.csc_matrix:
group
A Zarr or h5py group.
+ _class
+ The class we will use to create the sparse matrix. Either "scipy" or "graphblas"
+
Returns
-------
x
@@ -133,14 +149,26 @@ def read_csc(group: GroupTypes) -> sparse.csc_matrix:
     assert metadata["format"] == "CSC"
     shape = tuple(metadata["shape"])
-    return sparse.csc_matrix(
-        (
-            group["values"][()],
-            group["indices_1"][()],
+    if _class.lower() == "scipy":
+        return sparse.csc_matrix(
+            (
+                group["values"][()],
+                group["indices_1"][()],
+                group["pointers_to_1"][()],
+            ),
+            shape=shape,
+        )
+    elif _class.lower() == "graphblas":
+        # Pass the stored shape; inferring it from the indices can drop empty trailing rows/columns.
+        return gb.Matrix.from_csc(
             group["pointers_to_1"][()],
-        ),
-        shape=shape,
-    )
+            group["indices_1"][()],
+            group["values"][()],
+            nrows=shape[0],
+            ncols=shape[1],
+        )
+    else:
+        raise NotImplementedError(f"no implementation for returning data using {_class}")
def write_coo(store: GroupTypes, key: str, x: sparse.csc_matrix, *, dataset_kwargs: Mapping = MappingProxyType({})):
@@ -173,7 +198,7 @@ def write_coo(store: GroupTypes, key: str, x: sparse.csc_matrix, *, dataset_kwar
store.create_dataset(f"{key}/values", data=x.data, **dataset_kwargs)
-def read_coo(group: GroupTypes):
+def read_coo(group: GroupTypes, _class: str) -> gb.Matrix | sparse.coo_matrix:
"""Read a COO matrix from a store.
Parameters
@@ -181,6 +206,9 @@ def read_coo(group: GroupTypes):
group
A Zarr or h5py group.
+ _class
+ The class we will use to create the sparse matrix. Either "scipy" or "graphblas"
+
Returns
-------
x
@@ -190,13 +218,25 @@ def read_coo(group: GroupTypes):
     assert metadata["format"] == "COO"
     shape = tuple(metadata["shape"])
-    return sparse.coo_matrix(
-        (
-            group["values"][()],
+    if _class.lower() == "scipy":
+        return sparse.coo_matrix(
             (
-                group["indices_0"][()],
-                group["indices_1"][()],
+                group["values"][()],
+                (
+                    group["indices_0"][()],
+                    group["indices_1"][()],
+                ),
             ),
-        ),
-        shape=shape,
-    )
+            shape=shape,
+        )
+    elif _class.lower() == "graphblas":
+        # Pass the stored shape; inferring it from the indices can drop empty trailing rows/columns.
+        return gb.Matrix.from_coo(
+            group["indices_0"][()],
+            group["indices_1"][()],
+            group["values"][()],
+            nrows=shape[0],
+            ncols=shape[1],
+        )
+    else:
+        raise NotImplementedError(f"no implementation for returning data using {_class}")
diff --git a/src/binsparse/_testing.py b/src/binsparse/_testing.py
index d01da6b..926aeae 100644
--- a/src/binsparse/_testing.py
+++ b/src/binsparse/_testing.py
@@ -1,10 +1,11 @@
from functools import singledispatch
+import graphblas as gb
import numpy as np
from scipy import sparse
-def assert_equal(a: sparse.spmatrix, b):
+def assert_equal(a: gb.Matrix | sparse.spmatrix, b):
assert type(a) == type(b), f"types differ: {type(a)} != {type(b)}"
assert a.shape == b.shape, f"shapes differ: {a.shape} != {b.shape}"
_assert_equal(a, b)
@@ -26,3 +27,9 @@ def _(a, b):
def _(a, b):
for attr in ["row", "col", "data"]:
np.testing.assert_equal(getattr(a, attr), getattr(b, attr))
+
+
+@_assert_equal.register(gb.Matrix)
+def _(a, b):
+    """Assert that two GraphBLAS matrices are equal according to ``a.isequal(b)``."""
+    assert a.isequal(b), "GraphBLAS matrices are not equal"
diff --git a/tests/test_basic.py b/tests/test_basic.py
index 7059e9a..d2c26e5 100644
--- a/tests/test_basic.py
+++ b/tests/test_basic.py
@@ -1,3 +1,4 @@
+import graphblas as gb
import h5py
import pytest
import zarr
@@ -45,3 +46,22 @@ def test_metadata(store, fmt):
for k, v in metadata["data_types"].items():
assert v in _DTYPE_STR_REGISTRY.values(), f"unrecognized dtype for '{k}': {v}"
+
+
+@pytest.mark.parametrize("fmt", ["csr", "csc", "coo"])
+def test_graphblas_structure(store, fmt):
+    """Write a random scipy matrix, read it back as GraphBLAS, and compare with a direct conversion."""
+    scipy_mat = sparse.random(100, 100, density=0.1, format=fmt)
+    binsparse.write(store, "X", scipy_mat)
+    expected = gb.io.from_scipy_sparse(scipy_mat)
+    assert_equal(expected, binsparse.read(store["X"], _class="graphblas"))
+
+
+@pytest.mark.parametrize("fmt", ["csr", "csc", "coo"])
+def test_non_implemented_structure(store, fmt):
+    """Requesting an unsupported container class must raise NotImplementedError."""
+    orig = sparse.random(100, 100, density=0.1, format=fmt)
+    binsparse.write(store, "X", orig)
+    # Keep only the raising call inside the block: statements after it would never execute.
+    with pytest.raises(NotImplementedError):
+        binsparse.read(store["X"], _class="torch")