Merged
57 changes: 43 additions & 14 deletions docs/tutorials/sca/tiny_aes.py
Expand Up @@ -197,6 +197,23 @@ def convert_to_sedpack(dataset_path: Path, original_files: Path) -> None:
dataset.write_config()


def process_batch(batch: dict[str, Any]) -> tuple[Any, dict[str, Any]]:
    """Process a batch of records. The input is a dictionary mapping
    strings to tensors; the output is a tuple of the neural network's
    input (trace) and a dictionary of one-hot encoded expected outputs.
    """
    # The first neural network used just the first half of the trace:
    inputs = batch["trace1"]
    outputs = {
        "sub_bytes_in_0":
            keras.ops.one_hot(
                batch["sub_bytes_in"][:, 0],
                num_classes=256,
            ),
    }
    return (inputs, outputs)
Comment on lines +200 to +214
medium
The process_batch function is introduced to handle batched processing for the JAX backend. It correctly extracts the inputs and one-hot encodes sub_bytes_in for the first byte. This is a good addition for multi-backend support.
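As a sketch of what this batch transform produces, the one-hot step can be mimicked with plain NumPy. The batch dictionary below is hypothetical (real batches come from the sedpack dataset iterator), and `one_hot` is a minimal stand-in for `keras.ops.one_hot`:

```python
import numpy as np


def one_hot(indices: np.ndarray, num_classes: int) -> np.ndarray:
    # Minimal stand-in for keras.ops.one_hot: each integer index
    # becomes a row with a single 1.0 at that position.
    out = np.zeros((indices.shape[0], num_classes), dtype=np.float32)
    out[np.arange(indices.shape[0]), indices] = 1.0
    return out


# Hypothetical batch of 4 records mimicking the layout used above:
# a trace tensor and the 16 AES state bytes entering SubBytes.
batch = {
    "trace1": np.zeros((4, 1000), dtype=np.float32),
    "sub_bytes_in": np.arange(64).reshape(4, 16) % 256,
}

# Same shape of output as process_batch: the trace as the network
# input, and a 256-class one-hot target for state byte 0.
inputs = batch["trace1"]
outputs = {"sub_bytes_in_0": one_hot(batch["sub_bytes_in"][:, 0], 256)}
```

Each row of `outputs["sub_bytes_in_0"]` has shape `(256,)` with exactly one `1.0`, matching the 256 possible values of an AES state byte.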



def process_record(record: dict[str, Any]) -> tuple[Any, dict[str, Any]]:
    """Process a single record. The input is a dictionary mapping strings
    to tensors; the output is a tuple of the neural network's input
@@ -256,20 +273,32 @@ def train(dataset_path: Path) -> None:
)
model.summary()

    train_ds = dataset.as_tfdataset(
        split="train",
        process_record=process_record,
        batch_size=batch_size,
        # file_parallelism=4,
        # parallelism=4,
    )
    validation_ds = dataset.as_tfdataset(
        split="test",
        process_record=process_record,
        batch_size=batch_size,
        # file_parallelism=4,
        # parallelism=4,
    )
    match keras.backend.backend():
        case "tensorflow":
            train_ds = dataset.as_tfdataset(
                split="train",
                process_record=process_record,
                batch_size=batch_size,
            )
            validation_ds = dataset.as_tfdataset(
                split="test",
                process_record=process_record,
                batch_size=batch_size,
            )
        case "jax" | "torch":
            train_ds = dataset.as_numpy_iterator_rust_batched(
                split="train",
                process_batch=process_batch,
                batch_size=batch_size,
            )
            validation_ds = dataset.as_numpy_iterator_rust_batched(
                split="test",
                process_batch=process_batch,
                batch_size=batch_size,
            )
        case _:
            print(f"TODO support {keras.backend.backend() = }")
            return
Comment on lines +276 to +301
medium

The match statement on keras.backend.backend() is a clean way to handle different backends. The 'tensorflow' case correctly uses as_tfdataset, while the 'jax' and 'torch' cases use as_numpy_iterator_rust_batched. The TODO for unsupported backends is also appropriate.


# Train the model.
_ = model.fit(