diff --git a/docs/tutorials/sca/tiny_aes.py b/docs/tutorials/sca/tiny_aes.py index 4c0b2a25..1c57f3ae 100644 --- a/docs/tutorials/sca/tiny_aes.py +++ b/docs/tutorials/sca/tiny_aes.py @@ -197,6 +197,23 @@ def convert_to_sedpack(dataset_path: Path, original_files: Path) -> None: dataset.write_config() +def process_batch(batch: dict[str, Any]) -> tuple[Any, dict[str, Any]]: + """Processing of a batch of records. The input is a dictionary of string + and tensor, the output of this function is a tuple the neural network's + input (trace) and a dictionary of one-hot encoded expected outputs. + """ + # The first neural network was using just the first half of the trace: + inputs = batch["trace1"] + outputs = { + "sub_bytes_in_0": + keras.ops.one_hot( + batch["sub_bytes_in"][:, 0], + num_classes=256, + ), + } + return (inputs, outputs) + + def process_record(record: dict[str, Any]) -> tuple[Any, dict[str, Any]]: """Processing of a single record. The input is a dictionary of string and tensor, the output of this function is a tuple the neural network's input @@ -256,20 +273,32 @@ def train(dataset_path: Path) -> None: ) model.summary() - train_ds = dataset.as_tfdataset( - split="train", - process_record=process_record, - batch_size=batch_size, - #file_parallelism=4, - #parallelism=4, - ) - validation_ds = dataset.as_tfdataset( - split="test", - process_record=process_record, - batch_size=batch_size, - #file_parallelism=4, - #parallelism=4, - ) + match keras.backend.backend(): + case "tensorflow": + train_ds = dataset.as_tfdataset( + split="train", + process_record=process_record, + batch_size=batch_size, + ) + validation_ds = dataset.as_tfdataset( + split="test", + process_record=process_record, + batch_size=batch_size, + ) + case "jax" | "torch": + train_ds = dataset.as_numpy_iterator_rust_batched( + split="train", + process_batch=process_batch, + batch_size=batch_size, + ) + validation_ds = dataset.as_numpy_iterator_rust_batched( + split="test", + process_batch=process_batch, + batch_size=batch_size, + ) + case _: + print(f"TODO support {keras.backend.backend() = }") + return # Train the model. _ = model.fit(