Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion rust/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion rust/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "sedpack_rs"
version = "0.1.3"
version = "0.1.4"
edition = "2021"
description = "Rust bindings for sedpack a general ML dataset package"
authors = [
Expand Down
30 changes: 22 additions & 8 deletions src/sedpack/io/flatbuffer/iterate.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

from sedpack.io.compress import CompressedFile
from sedpack.io.metadata import Attribute
from sedpack.io.types import ExampleT
from sedpack.io.types import AttributeValueT, ExampleT
from sedpack.io.shard import IterateShardBase
from sedpack.io.shard.iterate_shard_base import T
from sedpack.io.utils import func_or_identity
Expand Down Expand Up @@ -82,18 +82,16 @@ def _iterate_content(self, content: bytes) -> Iterable[ExampleT]:
attribute=attribute,
)

# Copy otherwise the arrays are immutable and keep the whole
# file content from being garbage collected.
np_array = np.copy(np_array)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy issues should be fine since Rust is allocating the memory and
passing ownership to Python.


example_dictionary[attribute.name] = np_array

yield example_dictionary

@staticmethod
def decode_array(np_bytes: npt.NDArray[np.uint8],
attribute: Attribute,
batch_size: int = 0) -> npt.NDArray[np.generic]:
def decode_array(
np_bytes: npt.NDArray[np.uint8],
attribute: Attribute,
batch_size: int = 0,
) -> AttributeValueT:
"""Decode an array. See `sedpack.io.shard.shard_writer_flatbuffer
.ShardWriterFlatBuffer.save_numpy_vector_as_bytearray`
for format description. The code tries to avoid unnecessary copies.
Expand All @@ -112,6 +110,22 @@ def decode_array(np_bytes: npt.NDArray[np.uint8],

Returns: the parsed np.ndarray of the correct dtype and shape.
"""
match attribute.dtype:
case "str":
return np_bytes.tobytes().decode("utf-8")
case "bytes":
return np_bytes.tobytes()
case "int":
array = np.frombuffer(
buffer=np_bytes,
dtype=np.dtype("int64").newbyteorder("<"),
)
assert array.shape == (1,), f"{array.shape = }"
return int(array[0])
case _:
# The rest is interpreted as NumPy array.
pass

dt = np.dtype(attribute.dtype)
# FlatBuffers are little-endian. There is no byteswap by
# `np.frombuffer` but the array will be interpreted correctly.
Expand Down
138 changes: 127 additions & 11 deletions src/sedpack/io/npz/iterate_npz.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024 Google LLC
# Copyright 2024-2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -22,6 +22,7 @@
import aiofiles
import numpy as np

from sedpack.io.metadata import Attribute
from sedpack.io.shard import IterateShardBase
from sedpack.io.shard.iterate_shard_base import T
from sedpack.io.types import AttributeValueT, ExampleT
Expand All @@ -32,20 +33,129 @@ class IterateShardNP(IterateShardBase[T]):
"""Iterate a shard saved in the npz format.
"""

@staticmethod
def decode_attribute(
attribute: Attribute,
example_index: int,
prefixed_name: str,
shard_content: dict[str, list[AttributeValueT]],
) -> AttributeValueT:
"""Choose the correct way to decode the given attribute.

Args:

attribute (Attribute): Information about the attribute being decoded.

example_index (int): Which example from this shard is being decoded.

prefixed_name (str): For the case of `bytes` attributes we need to
store them in continuous array otherwise variable length would require
allowing pickling and result in a potential arbitrary code execution.
This name is the prefix-sum encoded lengths of the attribute values.

shard_content (dict[str, list[AttributeValueT]]): The shard values.
"""
if attribute.dtype == "bytes":
return IterateShardNP.decode_bytes_attribute(
value=shard_content[attribute.name][0],
indexes=shard_content[prefixed_name],
attribute=attribute,
index=example_index,
)

return IterateShardNP.decode_non_bytes_attribute(
np_value=shard_content[attribute.name][example_index],
attribute=attribute,
)

@staticmethod
def decode_non_bytes_attribute(
np_value: AttributeValueT,
attribute: Attribute,
) -> AttributeValueT:
match attribute.dtype:
case "str":
return str(np_value)
case "bytes":
raise ValueError("One needs to use decode_bytes_attribute")
case "int":
return int(np_value)
case _:
return np_value

@staticmethod
def decode_bytes_attribute(
value: AttributeValueT,
indexes: list[AttributeValueT],
attribute: Attribute,
index: int,
) -> AttributeValueT:
"""Decode a bytes attribute. We are saving the byte attributes as a
continuous array across multiple examples and on the side we also save
the indexes into this array.

Args:

value (AttributeValueT): The NumPy array of np.uint8 containing
concatenated bytes values.

indexes (list[AttributeValueT]): Indexes into this array.

attribute (Attribute): The attribute description.

index (int): Which example out of this shard to return.
"""
if attribute.dtype != "bytes":
raise ValueError("One needs to use decode_attribute")
# Help with type-checking:
my_value = np.array(value, np.uint8)
my_indexes = np.array(indexes, np.int64)
del value
del indexes

begin: int = my_indexes[index]
end: int = my_indexes[index + 1]
return bytes(my_value[begin:end])

    def iterate_shard(self, file_path: Path) -> Iterable[ExampleT]:
        """Iterate a shard saved in the NumPy format npz.
        """
        # NOTE(review): this block appears to contain interleaved old and new
        # versions of the code from a diff (shard_content is loaded twice and
        # the example loop appears twice below) -- reconcile before relying on
        # it.
        shard_content: dict[str, list[AttributeValueT]] = np.load(file_path)
        # A prefix such that prepended it creates a new name without collision
        # with any attribute name.
        counting_prefix: str = "len" + "_" * max(
            len(attribute.name)
            for attribute in self.dataset_structure.saved_data_description)
        self._prefixed_names: dict[str, str] = {
            attribute.name: counting_prefix + attribute.name
            for attribute in self.dataset_structure.saved_data_description
        }

        # allow_pickle=False avoids potential arbitrary code execution when
        # loading untrusted npz files.
        shard_content: dict[str, list[AttributeValueT]] = np.load(
            file_path,
            allow_pickle=False,
        )

        # A given shard contains the same number of elements for each
        # attribute.
        elements: int = 0
        for values in shard_content.values():
            elements = len(values)
            break
        elements: int
        first_attribute = self.dataset_structure.saved_data_description[0]
        if self._prefixed_names[first_attribute.name] in shard_content:
            # Bytes attributes store N+1 prefix-sum offsets for N examples,
            # hence the -1.
            elements = len(
                shard_content[self._prefixed_names[first_attribute.name]]) - 1
        else:
            elements = len(shard_content[first_attribute.name])

        for i in range(elements):
            yield {name: value[i] for name, value in shard_content.items()}
        for example_index in range(elements):
            yield {
                attribute.name:
                    IterateShardNP.decode_attribute(
                        attribute=attribute,
                        example_index=example_index,
                        prefixed_name=self._prefixed_names[attribute.name],
                        shard_content=shard_content,
                    )
                for attribute in self.dataset_structure.saved_data_description
            }

# TODO(issue #85) fix and test async iterator typing
async def iterate_shard_async( # pylint: disable=invalid-overridden-method
Expand All @@ -58,7 +168,10 @@ async def iterate_shard_async( # pylint: disable=invalid-overridden-method
content_bytes: bytes = await f.read()
content_io = io.BytesIO(content_bytes)

shard_content: dict[str, list[AttributeValueT]] = np.load(content_io)
shard_content: dict[str, list[AttributeValueT]] = np.load(
content_io,
allow_pickle=False,
)

# A given shard contains the same number of elements for each
# attribute.
Expand All @@ -67,8 +180,11 @@ async def iterate_shard_async( # pylint: disable=invalid-overridden-method
elements = len(values)
break

for i in range(elements):
yield {name: value[i] for name, value in shard_content.items()}
for example_index in range(elements):
yield {
name: value[example_index]
for name, value in shard_content.items()
}

def process_and_list(self, shard_file: Path) -> list[T]:
process_record = func_or_identity(self.process_record)
Expand Down
116 changes: 83 additions & 33 deletions src/sedpack/io/shard/shard_writer_flatbuffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,45 +101,28 @@ def _write(self, values: ExampleT) -> None:
self._examples.append(fbapi_Example.ExampleEnd(self._builder))

@staticmethod
def save_numpy_vector_as_bytearray( # type: ignore[no-any-unimported]
builder: Builder, attribute: Attribute,
value: AttributeValueT) -> int:
"""Save a given array into a FlatBuffer as bytes. This is to ensure
compatibility with types which are not supported by FlatBuffers (e.g.,
np.float16). The FlatBuffers schema must mark this vector as type
bytes [byte] (see src/sedpack/io/flatbuffer/shard.fbs) since there is a
distinction of how the length is being saved. The inverse of this
function is
`sedpack.io.flatbuffer.iterate.IterateShardFlatBuffer.decode_array`.

If we have an array of np.int32 of 10 elements the FlatBuffers library
would save the length as 10. Which is then impossible to read in Rust
since the length and itemsize (sizeof of the type) are private. Thus we
could not get the full array back. Thus we are saving the array as 40
bytes. This function does not modify the `value`, and saves a flattened
version of it. This function also saves the exact dtype as given by
`attribute`. Bytes are being saved in little endian ("<") and
c_contiguous ("C") order, same as with FlatBuffers. Alignment is set to
`dtype.itemsize` as opposed to FlatBuffers choice of `dtype.alignment`.
def _np_to_bytes( # type: ignore[no-any-unimported]
builder: Builder,
attribute: Attribute,
value: AttributeValueT,
) -> bytes:
"""Turn a single AttributeValueT value into a sequence of bytes to be
saved in a FlatBuffer shard.

Args:

builder (flatbuffers.Builder): The byte buffer being constructed.
Must be initialized.
builder (Builder): The FlatBuffer builder.

attribute (Attribute): Description of this attribute (shape and
dtype).
attribute (Attribute): Description of the attribute defining dtype and
shape.

value (AttributeValueT): The array to be saved. The shape should be
as defined in `attribute` (will be flattened).
value (AttributeValueT): The actual value which is being represented.

Returns: The offset returned by `flatbuffers.Builder.EndVector`.
Raises: ValueError if the value cannot be cast to the correct dtype with
the correct byteorder.
"""
# Not sure about flatbuffers.Builder __bool__ semantics.
assert builder is not None

# See `flatbuffers.builder.Builder.CreateNumpyVector`.

# Copy the value in order not to modify the original and flatten for
# better saving.
value_flattened = np.copy(value).flatten()
Expand Down Expand Up @@ -178,7 +161,75 @@ def save_numpy_vector_as_bytearray( # type: ignore[no-any-unimported]
f" {attribute.name}")

# This is going to be saved, ensure c_contiguous ordering.
byte_representation = value_np.tobytes(order="C")
return value_np.tobytes(order="C")

@staticmethod
def save_numpy_vector_as_bytearray( # type: ignore[no-any-unimported]
builder: Builder,
attribute: Attribute,
value: AttributeValueT,
) -> int:
"""Save a given array into a FlatBuffer as bytes. This is to ensure
compatibility with types which are not supported by FlatBuffers (e.g.,
np.float16). The FlatBuffers schema must mark this vector as type
bytes [byte] (see src/sedpack/io/flatbuffer/shard.fbs) since there is a
distinction of how the length is being saved. The inverse of this
function is
`sedpack.io.flatbuffer.iterate.IterateShardFlatBuffer.decode_array`.

If we have an array of np.int32 of 10 elements the FlatBuffers library
would save the length as 10. Which is then impossible to read in Rust
since the length and itemsize (sizeof of the type) are private. Thus we
could not get the full array back. Thus we are saving the array as 40
bytes. This function does not modify the `value`, and saves a flattened
version of it. This function also saves the exact dtype as given by
`attribute`. Bytes are being saved in little endian ("<") and
c_contiguous ("C") order, same as with FlatBuffers. Alignment is set to
`dtype.itemsize` as opposed to FlatBuffers choice of `dtype.alignment`.

Args:

builder (flatbuffers.Builder): The byte buffer being constructed.
Must be initialized.

attribute (Attribute): Description of this attribute (shape and
dtype).

value (AttributeValueT): The array to be saved. The shape should be
as defined in `attribute` (will be flattened).

Returns: The offset returned by `flatbuffers.Builder.EndVector`.

Raises: ValueError if the dtype is unknown to NumPy.
"""
# Not sure about flatbuffers.Builder __bool__ semantics.
assert builder is not None

# See `flatbuffers.builder.Builder.CreateNumpyVector`.

byte_representation: bytes
alignment: int
match attribute.dtype:
case "str":
byte_representation = str(value).encode("utf-8")
alignment = 16
case "bytes":
byte_representation = np.array(value).tobytes()
alignment = 16
case _:
try:
_ = np.dtype(attribute.dtype)
except Exception as exc:
raise ValueError(
f"Unsupported dtype {attribute.dtype}") from exc

byte_representation = ShardWriterFlatBuffer._np_to_bytes(
builder=builder,
attribute=attribute,
value=value,
)
value_np = np.array(value, dtype=attribute.dtype)
alignment = value_np.dtype.itemsize

# Total length of the array (in bytes).
length: int = len(byte_representation)
Expand All @@ -187,8 +238,7 @@ def save_numpy_vector_as_bytearray( # type: ignore[no-any-unimported]
builder.StartVector(
elemSize=1, # Storing bytes.
numElems=length,
alignment=value_np.dtype.
itemsize, # Cautious alignment of the array.
alignment=alignment, # Cautious alignment of the array.
)
builder.head = int(builder.Head() - length)

Expand Down
Loading
Loading