From abd0543318bf684a12170e8134833c6a32f0f844 Mon Sep 17 00:00:00 2001 From: Karel Date: Fri, 20 Feb 2026 15:43:23 +0000 Subject: [PATCH 1/3] Support Keras multi-backend in tinyAES tutorial --- docs/tutorials/sca/tiny_aes.py | 58 ++++++++++++++++++++++++++-------- 1 file changed, 44 insertions(+), 14 deletions(-) diff --git a/docs/tutorials/sca/tiny_aes.py b/docs/tutorials/sca/tiny_aes.py index 4c0b2a25..ce69248e 100644 --- a/docs/tutorials/sca/tiny_aes.py +++ b/docs/tutorials/sca/tiny_aes.py @@ -197,6 +197,23 @@ def convert_to_sedpack(dataset_path: Path, original_files: Path) -> None: dataset.write_config() +def process_batch(batch: dict[str, Any]) -> tuple[Any, dict[str, Any]]: + """Processing of a batch of records. The input is a dictionary of string + and tensor, the output of this function is a tuple the neural network's + input (trace) and a dictionary of one-hot encoded expected outputs. + """ + # The first neural network was using just the first half of the trace: + inputs = batch["trace1"] + outputs = { + "sub_bytes_in_0": + keras.ops.one_hot( + batch["sub_bytes_in"][:, 0], + num_classes=256, + ), + } + return (inputs, outputs) + + def process_record(record: dict[str, Any]) -> tuple[Any, dict[str, Any]]: """Processing of a single record. The input is a dictionary of string and tensor, the output of this function is a tuple the neural network's input @@ -256,20 +273,33 @@ def train(dataset_path: Path) -> None: ) model.summary() - train_ds = dataset.as_tfdataset( - split="train", - process_record=process_record, - batch_size=batch_size, - #file_parallelism=4, - #parallelism=4, - ) - validation_ds = dataset.as_tfdataset( - split="test", - process_record=process_record, - batch_size=batch_size, - #file_parallelism=4, - #parallelism=4, - ) + + match keras.backend.backend(): + case "tensorflow": + train_ds = dataset.as_tfdataset( + split="train", + process_record=process_record, + batch_size=batch_size, + ) + validation_ds = dataset.as_tfdataset( + split="test", + process_record=process_record, + batch_size=batch_size, + ) + case "jax": + train_ds = dataset.as_numpy_iterator_rust_batched( + split="train", + process_batch=process_batch, + batch_size=batch_size, + ) + validation_ds = dataset.as_numpy_iterator_rust_batched( + split="test", + process_batch=process_batch, + batch_size=batch_size, + ) + case _: + print(f"TODO support {keras.backend.backend() = }") + return # Train the model. _ = model.fit( From ca467b0e7e92d2c2bfaebd21c0b37ed16f654a4e Mon Sep 17 00:00:00 2001 From: Karel Date: Fri, 20 Feb 2026 15:54:52 +0000 Subject: [PATCH 2/3] [squash] support Keras with PyTorch --- docs/tutorials/sca/tiny_aes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/tutorials/sca/tiny_aes.py b/docs/tutorials/sca/tiny_aes.py index ce69248e..e8909a00 100644 --- a/docs/tutorials/sca/tiny_aes.py +++ b/docs/tutorials/sca/tiny_aes.py @@ -286,7 +286,7 @@ def train(dataset_path: Path) -> None: process_record=process_record, batch_size=batch_size, ) - case "jax": + case "jax" | "torch": train_ds = dataset.as_numpy_iterator_rust_batched( split="train", process_batch=process_batch, From f90f4b1d745d3a6286a4c55f2897dfe566411bdf Mon Sep 17 00:00:00 2001 From: Karel Date: Fri, 20 Feb 2026 16:00:00 +0000 Subject: [PATCH 3/3] [squash] fix lint --- docs/tutorials/sca/tiny_aes.py | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/tutorials/sca/tiny_aes.py b/docs/tutorials/sca/tiny_aes.py index e8909a00..1c57f3ae 100644 --- a/docs/tutorials/sca/tiny_aes.py +++ b/docs/tutorials/sca/tiny_aes.py @@ -273,7 +273,6 @@ def train(dataset_path: Path) -> None: ) model.summary() - match keras.backend.backend(): case "tensorflow": train_ds = dataset.as_tfdataset(