diff --git a/.github/workflows/test_ci.yml b/.github/workflows/test_ci.yml index 0a00288..45f545d 100644 --- a/.github/workflows/test_ci.yml +++ b/.github/workflows/test_ci.yml @@ -81,3 +81,28 @@ jobs: for exe in build/sha3_224_example build/sha3_256_example build/sha3_384_example build/sha3_512_example build/shake128_example build/shake256_example build/turboshake128_example build/turboshake256_example; do ./$exe done + + test-avx2: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + compiler: [g++, clang++] + steps: + - uses: actions/checkout@v6 + + - name: Configure (AVX2) + run: > + cmake -B build + -DCMAKE_CXX_COMPILER=${{ matrix.compiler }} + -DCMAKE_BUILD_TYPE=Release + -DSHA3_BUILD_TESTS=ON + -DSHA3_BUILD_BENCHMARKS=ON + -DSHA3_FETCH_DEPS=ON + -DSHA3_NATIVE_OPT=ON + + - name: Build + run: cmake --build build -j + + - name: Test + run: ctest --test-dir build --output-on-failure -j diff --git a/benches/bench_hashing.cpp b/benches/bench_hashing.cpp index 715dfda..d603463 100644 --- a/benches/bench_hashing.cpp +++ b/benches/bench_hashing.cpp @@ -25,7 +25,7 @@ bench_sha3_224(benchmark::State& state) benchmark::ClobberMemory(); } - const size_t bytes_processed = state.iterations() * (msg.size() + sha3_224::DIGEST_LEN); + const size_t bytes_processed = static_cast(state.iterations()) * (msg.size() + sha3_224::DIGEST_LEN); state.SetBytesProcessed(static_cast(bytes_processed)); #ifdef CYCLES_PER_BYTE @@ -50,7 +50,7 @@ bench_sha3_256(benchmark::State& state) benchmark::ClobberMemory(); } - const size_t bytes_processed = state.iterations() * (msg.size() + sha3_256::DIGEST_LEN); + const size_t bytes_processed = static_cast(state.iterations()) * (msg.size() + sha3_256::DIGEST_LEN); state.SetBytesProcessed(static_cast(bytes_processed)); #ifdef CYCLES_PER_BYTE @@ -75,7 +75,7 @@ bench_sha3_384(benchmark::State& state) benchmark::ClobberMemory(); } - const size_t bytes_processed = state.iterations() * (msg.size() + sha3_384::DIGEST_LEN); + const size_t 
bytes_processed = static_cast(state.iterations()) * (msg.size() + sha3_384::DIGEST_LEN); state.SetBytesProcessed(static_cast(bytes_processed)); #ifdef CYCLES_PER_BYTE @@ -100,7 +100,7 @@ bench_sha3_512(benchmark::State& state) benchmark::ClobberMemory(); } - const size_t bytes_processed = state.iterations() * (msg.size() + sha3_512::DIGEST_LEN); + const size_t bytes_processed = static_cast(state.iterations()) * (msg.size() + sha3_512::DIGEST_LEN); state.SetBytesProcessed(static_cast(bytes_processed)); #ifdef CYCLES_PER_BYTE diff --git a/benches/bench_keccak.cpp b/benches/bench_keccak.cpp index ed7edda..7bd0e24 100644 --- a/benches/bench_keccak.cpp +++ b/benches/bench_keccak.cpp @@ -3,6 +3,10 @@ #include #include +#if defined(__AVX2__) +#include "sha3/internals/keccak_x4.hpp" +#endif + namespace { // Benchmarks Keccak-p[1600, 12] or Keccak-p[1600, 24] permutation. @@ -20,7 +24,38 @@ bench_keccak_permutation(benchmark::State& state) benchmark::ClobberMemory(); } - const size_t bytes_processed = state.iterations() * sizeof(st); + const size_t bytes_processed = static_cast(state.iterations()) * sizeof(st); + state.SetBytesProcessed(static_cast(bytes_processed)); + +#ifdef CYCLES_PER_BYTE + state.counters["CYCLES/ BYTE"] = state.counters["CYCLES"] / static_cast(bytes_processed); +#endif +} + +#if defined(__AVX2__) + +// Benchmarks 4-way parallel Keccak-p[1600, 12] or Keccak-p[1600, 24] permutation using AVX2. 
+template +void +bench_keccak_x4_permutation(benchmark::State& state) +{ + using vec = keccak_x4::vec; + + std::array st{}; + for (auto& lane : st) { + std::array tmp{}; + generate_random_data(tmp); + lane = vec::load(tmp); + } + + for (auto _ : state) { + keccak_x4::permute(st); + + benchmark::DoNotOptimize(st); + benchmark::ClobberMemory(); + } + + const size_t bytes_processed = static_cast(state.iterations()) * sizeof(st); state.SetBytesProcessed(static_cast(bytes_processed)); #ifdef CYCLES_PER_BYTE @@ -28,7 +63,14 @@ bench_keccak_permutation(benchmark::State& state) #endif } +#endif + } BENCHMARK(bench_keccak_permutation<12>)->Name("keccak-p[1600, 12]")->ComputeStatistics("min", compute_min)->ComputeStatistics("max", compute_max); BENCHMARK(bench_keccak_permutation<24>)->Name("keccak-p[1600, 24]")->ComputeStatistics("min", compute_min)->ComputeStatistics("max", compute_max); + +#if defined(__AVX2__) +BENCHMARK(bench_keccak_x4_permutation<12>)->Name("keccak-p[1600, 12] x4/avx2")->ComputeStatistics("min", compute_min)->ComputeStatistics("max", compute_max); +BENCHMARK(bench_keccak_x4_permutation<24>)->Name("keccak-p[1600, 24] x4/avx2")->ComputeStatistics("min", compute_min)->ComputeStatistics("max", compute_max); +#endif diff --git a/benches/bench_xof.cpp b/benches/bench_xof.cpp index 84609d0..13c91de 100644 --- a/benches/bench_xof.cpp +++ b/benches/bench_xof.cpp @@ -6,6 +6,11 @@ #include #include +#if defined(__AVX2__) +#include "sha3/shake128_x4.hpp" +#include "sha3/shake256_x4.hpp" +#endif + namespace { /** @@ -37,7 +42,7 @@ bench_shake128(benchmark::State& state) benchmark::ClobberMemory(); } - const size_t bytes_processed = state.iterations() * (msg.size() + out.size()); + const size_t bytes_processed = static_cast(state.iterations()) * (msg.size() + out.size()); state.SetBytesProcessed(static_cast(bytes_processed)); #ifdef CYCLES_PER_BYTE @@ -74,7 +79,7 @@ bench_shake256(benchmark::State& state) benchmark::ClobberMemory(); } - const size_t bytes_processed = 
state.iterations() * (msg.size() + out.size()); + const size_t bytes_processed = static_cast(state.iterations()) * (msg.size() + out.size()); state.SetBytesProcessed(static_cast(bytes_processed)); #ifdef CYCLES_PER_BYTE @@ -111,7 +116,7 @@ bench_turboshake128(benchmark::State& state) benchmark::ClobberMemory(); } - const size_t bytes_processed = state.iterations() * (msg.size() + out.size()); + const size_t bytes_processed = static_cast(state.iterations()) * (msg.size() + out.size()); state.SetBytesProcessed(static_cast(bytes_processed)); #ifdef CYCLES_PER_BYTE @@ -148,7 +153,7 @@ bench_turboshake256(benchmark::State& state) benchmark::ClobberMemory(); } - const size_t bytes_processed = state.iterations() * (msg.size() + out.size()); + const size_t bytes_processed = static_cast(state.iterations()) * (msg.size() + out.size()); state.SetBytesProcessed(static_cast(bytes_processed)); #ifdef CYCLES_PER_BYTE @@ -156,6 +161,98 @@ bench_turboshake256(benchmark::State& state) #endif } +#if defined(__AVX2__) + +// Benchmarks 4-way parallel SHAKE128 XOF using AVX2. 
+void +bench_shake128_x4(benchmark::State& state) +{ + const auto mlen = static_cast(state.range(0)); + const auto olen = static_cast(state.range(1)); + + std::vector msg0(mlen); + std::vector msg1(mlen); + std::vector msg2(mlen); + std::vector msg3(mlen); + + std::vector out0(olen); + std::vector out1(olen); + std::vector out2(olen); + std::vector out3(olen); + + generate_random_data(msg0); + generate_random_data(msg1); + generate_random_data(msg2); + generate_random_data(msg3); + + for (auto _ : state) { + shake128_x4::shake128_x4_t hasher; + hasher.absorb(msg0, msg1, msg2, msg3); + hasher.finalize(); + hasher.squeeze(out0, out1, out2, out3); + + benchmark::DoNotOptimize(hasher); + benchmark::DoNotOptimize(out0); + benchmark::DoNotOptimize(out1); + benchmark::DoNotOptimize(out2); + benchmark::DoNotOptimize(out3); + benchmark::ClobberMemory(); + } + + const size_t bytes_processed = static_cast(state.iterations()) * 4 * (mlen + olen); + state.SetBytesProcessed(static_cast(bytes_processed)); + +#ifdef CYCLES_PER_BYTE + state.counters["CYCLES/ BYTE"] = state.counters["CYCLES"] / static_cast(bytes_processed); +#endif +} + +// Benchmarks 4-way parallel SHAKE256 XOF using AVX2. 
+void +bench_shake256_x4(benchmark::State& state) +{ + const auto mlen = static_cast(state.range(0)); + const auto olen = static_cast(state.range(1)); + + std::vector msg0(mlen); + std::vector msg1(mlen); + std::vector msg2(mlen); + std::vector msg3(mlen); + + std::vector out0(olen); + std::vector out1(olen); + std::vector out2(olen); + std::vector out3(olen); + + generate_random_data(msg0); + generate_random_data(msg1); + generate_random_data(msg2); + generate_random_data(msg3); + + for (auto _ : state) { + shake256_x4::shake256_x4_t hasher; + hasher.absorb(msg0, msg1, msg2, msg3); + hasher.finalize(); + hasher.squeeze(out0, out1, out2, out3); + + benchmark::DoNotOptimize(hasher); + benchmark::DoNotOptimize(out0); + benchmark::DoNotOptimize(out1); + benchmark::DoNotOptimize(out2); + benchmark::DoNotOptimize(out3); + benchmark::ClobberMemory(); + } + + const size_t bytes_processed = static_cast(state.iterations()) * 4 * (mlen + olen); + state.SetBytesProcessed(static_cast(bytes_processed)); + +#ifdef CYCLES_PER_BYTE + state.counters["CYCLES/ BYTE"] = state.counters["CYCLES"] / static_cast(bytes_processed); +#endif +} + +#endif + } BENCHMARK(bench_shake128) @@ -178,3 +275,16 @@ BENCHMARK(bench_turboshake256) ->Name("turboshake256") ->ComputeStatistics("min", compute_min) ->ComputeStatistics("max", compute_max); + +#if defined(__AVX2__) +BENCHMARK(bench_shake128_x4) + ->ArgsProduct({ benchmark::CreateRange(64, 16384, 4), { 64 } }) + ->Name("shake128_x4/avx2") + ->ComputeStatistics("min", compute_min) + ->ComputeStatistics("max", compute_max); +BENCHMARK(bench_shake256_x4) + ->ArgsProduct({ benchmark::CreateRange(64, 16384, 4), { 64 } }) + ->Name("shake256_x4/avx2") + ->ComputeStatistics("min", compute_min) + ->ComputeStatistics("max", compute_max); +#endif diff --git a/include/sha3/internals/keccak_x4.hpp b/include/sha3/internals/keccak_x4.hpp new file mode 100644 index 0000000..355eca9 --- /dev/null +++ b/include/sha3/internals/keccak_x4.hpp @@ -0,0 +1,455 @@ 
+#pragma once + +#if defined(__AVX2__) + +#include "sha3/internals/force_inline.hpp" +#include "sha3/internals/keccak.hpp" +#include "sha3/internals/simd/avx2.hpp" +#include +#include + +// 4-way parallel Keccak-p[1600, 12] and Keccak-p[1600, 24] permutation using AVX2. +// Each vec lane holds one of 4 independent Keccak-f[1600] states. +namespace keccak_x4 { + +using vec = sha3_simd::avx2::vec; + +/** + * 4-way parallel Keccak-f[1600] round function, applying all five step mapping functions, updating 4 states simultaneously. + * Applies 4 consecutive rounds (ridx, ridx+1, ridx+2, ridx+3) in a single call. + * + * This is a direct AVX2 translation of keccak::roundx4() from keccak.hpp. + */ +static forceinline void +roundx4(std::array& state, const size_t ridx) +{ + std::array bc{}; + std::array d{}; + vec t; + + // Round ridx + 0 + bc[0] = bc[1] = bc[2] = bc[3] = bc[4] = vec{}; + + for (size_t i = 0; i < keccak::LANE_CNT; i += 5) { + bc[0] ^= state[i + 0]; + bc[1] ^= state[i + 1]; + bc[2] ^= state[i + 2]; + bc[3] ^= state[i + 3]; + bc[4] ^= state[i + 4]; + } + + d[0] = bc[4] ^ rotl64(bc[1], 1); + d[1] = bc[0] ^ rotl64(bc[2], 1); + d[2] = bc[1] ^ rotl64(bc[3], 1); + d[3] = bc[2] ^ rotl64(bc[4], 1); + d[4] = bc[3] ^ rotl64(bc[0], 1); + + bc[0] = state[0] ^ d[0]; + t = state[6] ^ d[1]; + bc[1] = rotl64(t, keccak::ROT[6]); + t = state[12] ^ d[2]; + bc[2] = rotl64(t, keccak::ROT[12]); + t = state[18] ^ d[3]; + bc[3] = rotl64(t, keccak::ROT[18]); + t = state[24] ^ d[4]; + bc[4] = rotl64(t, keccak::ROT[24]); + + state[0] = (bc[0] ^ andnot(bc[1], bc[2])) ^ vec::broadcast(keccak::RC[ridx]); + state[6] = bc[1] ^ andnot(bc[2], bc[3]); + state[12] = bc[2] ^ andnot(bc[3], bc[4]); + state[18] = bc[3] ^ andnot(bc[4], bc[0]); + state[24] = bc[4] ^ andnot(bc[0], bc[1]); + + t = state[10] ^ d[0]; + bc[2] = rotl64(t, keccak::ROT[10]); + t = state[16] ^ d[1]; + bc[3] = rotl64(t, keccak::ROT[16]); + t = state[22] ^ d[2]; + bc[4] = rotl64(t, keccak::ROT[22]); + t = state[3] ^ d[3]; + 
bc[0] = rotl64(t, keccak::ROT[3]); + t = state[9] ^ d[4]; + bc[1] = rotl64(t, keccak::ROT[9]); + + state[10] = bc[0] ^ andnot(bc[1], bc[2]); + state[16] = bc[1] ^ andnot(bc[2], bc[3]); + state[22] = bc[2] ^ andnot(bc[3], bc[4]); + state[3] = bc[3] ^ andnot(bc[4], bc[0]); + state[9] = bc[4] ^ andnot(bc[0], bc[1]); + + t = state[20] ^ d[0]; + bc[4] = rotl64(t, keccak::ROT[20]); + t = state[1] ^ d[1]; + bc[0] = rotl64(t, keccak::ROT[1]); + t = state[7] ^ d[2]; + bc[1] = rotl64(t, keccak::ROT[7]); + t = state[13] ^ d[3]; + bc[2] = rotl64(t, keccak::ROT[13]); + t = state[19] ^ d[4]; + bc[3] = rotl64(t, keccak::ROT[19]); + + state[20] = bc[0] ^ andnot(bc[1], bc[2]); + state[1] = bc[1] ^ andnot(bc[2], bc[3]); + state[7] = bc[2] ^ andnot(bc[3], bc[4]); + state[13] = bc[3] ^ andnot(bc[4], bc[0]); + state[19] = bc[4] ^ andnot(bc[0], bc[1]); + + t = state[5] ^ d[0]; + bc[1] = rotl64(t, keccak::ROT[5]); + t = state[11] ^ d[1]; + bc[2] = rotl64(t, keccak::ROT[11]); + t = state[17] ^ d[2]; + bc[3] = rotl64(t, keccak::ROT[17]); + t = state[23] ^ d[3]; + bc[4] = rotl64(t, keccak::ROT[23]); + t = state[4] ^ d[4]; + bc[0] = rotl64(t, keccak::ROT[4]); + + state[5] = bc[0] ^ andnot(bc[1], bc[2]); + state[11] = bc[1] ^ andnot(bc[2], bc[3]); + state[17] = bc[2] ^ andnot(bc[3], bc[4]); + state[23] = bc[3] ^ andnot(bc[4], bc[0]); + state[4] = bc[4] ^ andnot(bc[0], bc[1]); + + t = state[15] ^ d[0]; + bc[3] = rotl64(t, keccak::ROT[15]); + t = state[21] ^ d[1]; + bc[4] = rotl64(t, keccak::ROT[21]); + t = state[2] ^ d[2]; + bc[0] = rotl64(t, keccak::ROT[2]); + t = state[8] ^ d[3]; + bc[1] = rotl64(t, keccak::ROT[8]); + t = state[14] ^ d[4]; + bc[2] = rotl64(t, keccak::ROT[14]); + + state[15] = bc[0] ^ andnot(bc[1], bc[2]); + state[21] = bc[1] ^ andnot(bc[2], bc[3]); + state[2] = bc[2] ^ andnot(bc[3], bc[4]); + state[8] = bc[3] ^ andnot(bc[4], bc[0]); + state[14] = bc[4] ^ andnot(bc[0], bc[1]); + + // Round ridx + 1 + bc[0] = bc[1] = bc[2] = bc[3] = bc[4] = vec{}; + + for (size_t i = 0; i < 
keccak::LANE_CNT; i += 5) { + bc[0] ^= state[i + 0]; + bc[1] ^= state[i + 1]; + bc[2] ^= state[i + 2]; + bc[3] ^= state[i + 3]; + bc[4] ^= state[i + 4]; + } + + d[0] = bc[4] ^ rotl64(bc[1], 1); + d[1] = bc[0] ^ rotl64(bc[2], 1); + d[2] = bc[1] ^ rotl64(bc[3], 1); + d[3] = bc[2] ^ rotl64(bc[4], 1); + d[4] = bc[3] ^ rotl64(bc[0], 1); + + bc[0] = state[0] ^ d[0]; + t = state[16] ^ d[1]; + bc[1] = rotl64(t, keccak::ROT[6]); + t = state[7] ^ d[2]; + bc[2] = rotl64(t, keccak::ROT[12]); + t = state[23] ^ d[3]; + bc[3] = rotl64(t, keccak::ROT[18]); + t = state[14] ^ d[4]; + bc[4] = rotl64(t, keccak::ROT[24]); + + state[0] = (bc[0] ^ andnot(bc[1], bc[2])) ^ vec::broadcast(keccak::RC[ridx + 1]); + state[16] = bc[1] ^ andnot(bc[2], bc[3]); + state[7] = bc[2] ^ andnot(bc[3], bc[4]); + state[23] = bc[3] ^ andnot(bc[4], bc[0]); + state[14] = bc[4] ^ andnot(bc[0], bc[1]); + + t = state[20] ^ d[0]; + bc[2] = rotl64(t, keccak::ROT[10]); + t = state[11] ^ d[1]; + bc[3] = rotl64(t, keccak::ROT[16]); + t = state[2] ^ d[2]; + bc[4] = rotl64(t, keccak::ROT[22]); + t = state[18] ^ d[3]; + bc[0] = rotl64(t, keccak::ROT[3]); + t = state[9] ^ d[4]; + bc[1] = rotl64(t, keccak::ROT[9]); + + state[20] = bc[0] ^ andnot(bc[1], bc[2]); + state[11] = bc[1] ^ andnot(bc[2], bc[3]); + state[2] = bc[2] ^ andnot(bc[3], bc[4]); + state[18] = bc[3] ^ andnot(bc[4], bc[0]); + state[9] = bc[4] ^ andnot(bc[0], bc[1]); + + t = state[15] ^ d[0]; + bc[4] = rotl64(t, keccak::ROT[20]); + t = state[6] ^ d[1]; + bc[0] = rotl64(t, keccak::ROT[1]); + t = state[22] ^ d[2]; + bc[1] = rotl64(t, keccak::ROT[7]); + t = state[13] ^ d[3]; + bc[2] = rotl64(t, keccak::ROT[13]); + t = state[4] ^ d[4]; + bc[3] = rotl64(t, keccak::ROT[19]); + + state[15] = bc[0] ^ andnot(bc[1], bc[2]); + state[6] = bc[1] ^ andnot(bc[2], bc[3]); + state[22] = bc[2] ^ andnot(bc[3], bc[4]); + state[13] = bc[3] ^ andnot(bc[4], bc[0]); + state[4] = bc[4] ^ andnot(bc[0], bc[1]); + + t = state[10] ^ d[0]; + bc[1] = rotl64(t, keccak::ROT[5]); + t = 
state[1] ^ d[1]; + bc[2] = rotl64(t, keccak::ROT[11]); + t = state[17] ^ d[2]; + bc[3] = rotl64(t, keccak::ROT[17]); + t = state[8] ^ d[3]; + bc[4] = rotl64(t, keccak::ROT[23]); + t = state[24] ^ d[4]; + bc[0] = rotl64(t, keccak::ROT[4]); + + state[10] = bc[0] ^ andnot(bc[1], bc[2]); + state[1] = bc[1] ^ andnot(bc[2], bc[3]); + state[17] = bc[2] ^ andnot(bc[3], bc[4]); + state[8] = bc[3] ^ andnot(bc[4], bc[0]); + state[24] = bc[4] ^ andnot(bc[0], bc[1]); + + t = state[5] ^ d[0]; + bc[3] = rotl64(t, keccak::ROT[15]); + t = state[21] ^ d[1]; + bc[4] = rotl64(t, keccak::ROT[21]); + t = state[12] ^ d[2]; + bc[0] = rotl64(t, keccak::ROT[2]); + t = state[3] ^ d[3]; + bc[1] = rotl64(t, keccak::ROT[8]); + t = state[19] ^ d[4]; + bc[2] = rotl64(t, keccak::ROT[14]); + + state[5] = bc[0] ^ andnot(bc[1], bc[2]); + state[21] = bc[1] ^ andnot(bc[2], bc[3]); + state[12] = bc[2] ^ andnot(bc[3], bc[4]); + state[3] = bc[3] ^ andnot(bc[4], bc[0]); + state[19] = bc[4] ^ andnot(bc[0], bc[1]); + + // Round ridx + 2 + bc[0] = bc[1] = bc[2] = bc[3] = bc[4] = vec{}; + + for (size_t i = 0; i < keccak::LANE_CNT; i += 5) { + bc[0] ^= state[i + 0]; + bc[1] ^= state[i + 1]; + bc[2] ^= state[i + 2]; + bc[3] ^= state[i + 3]; + bc[4] ^= state[i + 4]; + } + + d[0] = bc[4] ^ rotl64(bc[1], 1); + d[1] = bc[0] ^ rotl64(bc[2], 1); + d[2] = bc[1] ^ rotl64(bc[3], 1); + d[3] = bc[2] ^ rotl64(bc[4], 1); + d[4] = bc[3] ^ rotl64(bc[0], 1); + + bc[0] = state[0] ^ d[0]; + t = state[11] ^ d[1]; + bc[1] = rotl64(t, keccak::ROT[6]); + t = state[22] ^ d[2]; + bc[2] = rotl64(t, keccak::ROT[12]); + t = state[8] ^ d[3]; + bc[3] = rotl64(t, keccak::ROT[18]); + t = state[19] ^ d[4]; + bc[4] = rotl64(t, keccak::ROT[24]); + + state[0] = (bc[0] ^ andnot(bc[1], bc[2])) ^ vec::broadcast(keccak::RC[ridx + 2]); + state[11] = bc[1] ^ andnot(bc[2], bc[3]); + state[22] = bc[2] ^ andnot(bc[3], bc[4]); + state[8] = bc[3] ^ andnot(bc[4], bc[0]); + state[19] = bc[4] ^ andnot(bc[0], bc[1]); + + t = state[15] ^ d[0]; + bc[2] = 
rotl64(t, keccak::ROT[10]); + t = state[1] ^ d[1]; + bc[3] = rotl64(t, keccak::ROT[16]); + t = state[12] ^ d[2]; + bc[4] = rotl64(t, keccak::ROT[22]); + t = state[23] ^ d[3]; + bc[0] = rotl64(t, keccak::ROT[3]); + t = state[9] ^ d[4]; + bc[1] = rotl64(t, keccak::ROT[9]); + + state[15] = bc[0] ^ andnot(bc[1], bc[2]); + state[1] = bc[1] ^ andnot(bc[2], bc[3]); + state[12] = bc[2] ^ andnot(bc[3], bc[4]); + state[23] = bc[3] ^ andnot(bc[4], bc[0]); + state[9] = bc[4] ^ andnot(bc[0], bc[1]); + + t = state[5] ^ d[0]; + bc[4] = rotl64(t, keccak::ROT[20]); + t = state[16] ^ d[1]; + bc[0] = rotl64(t, keccak::ROT[1]); + t = state[2] ^ d[2]; + bc[1] = rotl64(t, keccak::ROT[7]); + t = state[13] ^ d[3]; + bc[2] = rotl64(t, keccak::ROT[13]); + t = state[24] ^ d[4]; + bc[3] = rotl64(t, keccak::ROT[19]); + + state[5] = bc[0] ^ andnot(bc[1], bc[2]); + state[16] = bc[1] ^ andnot(bc[2], bc[3]); + state[2] = bc[2] ^ andnot(bc[3], bc[4]); + state[13] = bc[3] ^ andnot(bc[4], bc[0]); + state[24] = bc[4] ^ andnot(bc[0], bc[1]); + + t = state[20] ^ d[0]; + bc[1] = rotl64(t, keccak::ROT[5]); + t = state[6] ^ d[1]; + bc[2] = rotl64(t, keccak::ROT[11]); + t = state[17] ^ d[2]; + bc[3] = rotl64(t, keccak::ROT[17]); + t = state[3] ^ d[3]; + bc[4] = rotl64(t, keccak::ROT[23]); + t = state[14] ^ d[4]; + bc[0] = rotl64(t, keccak::ROT[4]); + + state[20] = bc[0] ^ andnot(bc[1], bc[2]); + state[6] = bc[1] ^ andnot(bc[2], bc[3]); + state[17] = bc[2] ^ andnot(bc[3], bc[4]); + state[3] = bc[3] ^ andnot(bc[4], bc[0]); + state[14] = bc[4] ^ andnot(bc[0], bc[1]); + + t = state[10] ^ d[0]; + bc[3] = rotl64(t, keccak::ROT[15]); + t = state[21] ^ d[1]; + bc[4] = rotl64(t, keccak::ROT[21]); + t = state[7] ^ d[2]; + bc[0] = rotl64(t, keccak::ROT[2]); + t = state[18] ^ d[3]; + bc[1] = rotl64(t, keccak::ROT[8]); + t = state[4] ^ d[4]; + bc[2] = rotl64(t, keccak::ROT[14]); + + state[10] = bc[0] ^ andnot(bc[1], bc[2]); + state[21] = bc[1] ^ andnot(bc[2], bc[3]); + state[7] = bc[2] ^ andnot(bc[3], bc[4]); + 
state[18] = bc[3] ^ andnot(bc[4], bc[0]); + state[4] = bc[4] ^ andnot(bc[0], bc[1]); + + // Round ridx + 3 + bc[0] = bc[1] = bc[2] = bc[3] = bc[4] = vec{}; + + for (size_t i = 0; i < keccak::LANE_CNT; i += 5) { + bc[0] ^= state[i + 0]; + bc[1] ^= state[i + 1]; + bc[2] ^= state[i + 2]; + bc[3] ^= state[i + 3]; + bc[4] ^= state[i + 4]; + } + + d[0] = bc[4] ^ rotl64(bc[1], 1); + d[1] = bc[0] ^ rotl64(bc[2], 1); + d[2] = bc[1] ^ rotl64(bc[3], 1); + d[3] = bc[2] ^ rotl64(bc[4], 1); + d[4] = bc[3] ^ rotl64(bc[0], 1); + + bc[0] = state[0] ^ d[0]; + t = state[1] ^ d[1]; + bc[1] = rotl64(t, keccak::ROT[6]); + t = state[2] ^ d[2]; + bc[2] = rotl64(t, keccak::ROT[12]); + t = state[3] ^ d[3]; + bc[3] = rotl64(t, keccak::ROT[18]); + t = state[4] ^ d[4]; + bc[4] = rotl64(t, keccak::ROT[24]); + + state[0] = (bc[0] ^ andnot(bc[1], bc[2])) ^ vec::broadcast(keccak::RC[ridx + 3]); + state[1] = bc[1] ^ andnot(bc[2], bc[3]); + state[2] = bc[2] ^ andnot(bc[3], bc[4]); + state[3] = bc[3] ^ andnot(bc[4], bc[0]); + state[4] = bc[4] ^ andnot(bc[0], bc[1]); + + t = state[5] ^ d[0]; + bc[2] = rotl64(t, keccak::ROT[10]); + t = state[6] ^ d[1]; + bc[3] = rotl64(t, keccak::ROT[16]); + t = state[7] ^ d[2]; + bc[4] = rotl64(t, keccak::ROT[22]); + t = state[8] ^ d[3]; + bc[0] = rotl64(t, keccak::ROT[3]); + t = state[9] ^ d[4]; + bc[1] = rotl64(t, keccak::ROT[9]); + + state[5] = bc[0] ^ andnot(bc[1], bc[2]); + state[6] = bc[1] ^ andnot(bc[2], bc[3]); + state[7] = bc[2] ^ andnot(bc[3], bc[4]); + state[8] = bc[3] ^ andnot(bc[4], bc[0]); + state[9] = bc[4] ^ andnot(bc[0], bc[1]); + + t = state[10] ^ d[0]; + bc[4] = rotl64(t, keccak::ROT[20]); + t = state[11] ^ d[1]; + bc[0] = rotl64(t, keccak::ROT[1]); + t = state[12] ^ d[2]; + bc[1] = rotl64(t, keccak::ROT[7]); + t = state[13] ^ d[3]; + bc[2] = rotl64(t, keccak::ROT[13]); + t = state[14] ^ d[4]; + bc[3] = rotl64(t, keccak::ROT[19]); + + state[10] = bc[0] ^ andnot(bc[1], bc[2]); + state[11] = bc[1] ^ andnot(bc[2], bc[3]); + state[12] = bc[2] ^ 
andnot(bc[3], bc[4]); + state[13] = bc[3] ^ andnot(bc[4], bc[0]); + state[14] = bc[4] ^ andnot(bc[0], bc[1]); + + t = state[15] ^ d[0]; + bc[1] = rotl64(t, keccak::ROT[5]); + t = state[16] ^ d[1]; + bc[2] = rotl64(t, keccak::ROT[11]); + t = state[17] ^ d[2]; + bc[3] = rotl64(t, keccak::ROT[17]); + t = state[18] ^ d[3]; + bc[4] = rotl64(t, keccak::ROT[23]); + t = state[19] ^ d[4]; + bc[0] = rotl64(t, keccak::ROT[4]); + + state[15] = bc[0] ^ andnot(bc[1], bc[2]); + state[16] = bc[1] ^ andnot(bc[2], bc[3]); + state[17] = bc[2] ^ andnot(bc[3], bc[4]); + state[18] = bc[3] ^ andnot(bc[4], bc[0]); + state[19] = bc[4] ^ andnot(bc[0], bc[1]); + + t = state[20] ^ d[0]; + bc[3] = rotl64(t, keccak::ROT[15]); + t = state[21] ^ d[1]; + bc[4] = rotl64(t, keccak::ROT[21]); + t = state[22] ^ d[2]; + bc[0] = rotl64(t, keccak::ROT[2]); + t = state[23] ^ d[3]; + bc[1] = rotl64(t, keccak::ROT[8]); + t = state[24] ^ d[4]; + bc[2] = rotl64(t, keccak::ROT[14]); + + state[20] = bc[0] ^ andnot(bc[1], bc[2]); + state[21] = bc[1] ^ andnot(bc[2], bc[3]); + state[22] = bc[2] ^ andnot(bc[3], bc[4]); + state[23] = bc[3] ^ andnot(bc[4], bc[0]); + state[24] = bc[4] ^ andnot(bc[0], bc[1]); +} + +/** + * 4-way parallel Keccak-f[1600] permutation, applying either 12 or 24 rounds on 4 independent states simultaneously. 
+ */ +template +forceinline void +permute(std::array& state) + requires((num_rounds == 12) || (num_rounds == keccak::MAX_NUM_ROUNDS)) +{ + constexpr size_t start_at_round = keccak::MAX_NUM_ROUNDS - num_rounds; + constexpr size_t STEP_BY = 4; + + static_assert(num_rounds % STEP_BY == 0, "Requested number of keccak-p[1600] rounds need to be a multiple of 4 for manual unrolling to work."); + + for (size_t i = start_at_round; i < keccak::MAX_NUM_ROUNDS; i += STEP_BY) { + roundx4(state, i); + } +} + +} + +#endif diff --git a/include/sha3/internals/simd/avx2.hpp b/include/sha3/internals/simd/avx2.hpp new file mode 100644 index 0000000..ed4db66 --- /dev/null +++ b/include/sha3/internals/simd/avx2.hpp @@ -0,0 +1,95 @@ +#pragma once + +#if defined(__AVX2__) + +#include "sha3/internals/force_inline.hpp" +#include +#include +#include +#include + +// AVX2 SIMD vector wrapper for 4x64-bit lanes. +// Zero-cost abstraction: every method is forceinline and wraps a single intrinsic. +namespace sha3_simd::avx2 { + +struct vec +{ +private: + __m256i v; + +public: + forceinline vec() + : v(_mm256_setzero_si256()) + { + } + + forceinline explicit vec(__m256i x) + : v(x) + { + } + + forceinline explicit operator __m256i() const { return v; } + + // --- Factories --- + + static forceinline vec broadcast(uint64_t x) { return vec(_mm256_set1_epi64x(static_cast(x))); } + + // e0 = bits[63:0], e1 = bits[127:64], e2 = bits[191:128], e3 = bits[255:192]. 
+ static forceinline vec set(uint64_t e0, uint64_t e1, uint64_t e2, uint64_t e3) + { + return vec(_mm256_set_epi64x(static_cast(e3), static_cast(e2), static_cast(e1), static_cast(e0))); + } + + static forceinline vec load(std::span data) + { + __m256i result; + std::memcpy(&result, data.data(), sizeof(__m256i)); + + return vec(result); + } + + // --- Bitwise operators --- + + forceinline vec operator^(vec rhs) const { return vec(_mm256_xor_si256(v, rhs.v)); } + forceinline vec& operator^=(vec rhs) + { + *this = *this ^ rhs; + return *this; + } + + forceinline vec operator|(vec rhs) const { return vec(_mm256_or_si256(v, rhs.v)); } + forceinline vec& operator|=(vec rhs) + { + *this = *this | rhs; + return *this; + } + + forceinline vec operator&(vec rhs) const { return vec(_mm256_and_si256(v, rhs.v)); } + forceinline vec& operator&=(vec rhs) + { + *this = *this & rhs; + return *this; + } + + // Per-lane 64-bit shifts. + forceinline vec operator<<(int n) const { return vec(_mm256_slli_epi64(v, n)); } + forceinline vec operator>>(int n) const { return vec(_mm256_srli_epi64(v, n)); } + + // --- Lane extraction --- + + template + [[nodiscard]] forceinline uint64_t extract() const + { + return static_cast(_mm256_extract_epi64(v, lane)); + } + + // ~a & b (note: _mm256_andnot_si256 computes ~first & second). + friend forceinline vec andnot(vec a, vec b) { return vec(_mm256_andnot_si256(a.v, b.v)); } + + // Rotate left each 64-bit lane by n bits. 
+ friend forceinline vec rotl64(vec x, int n) { return (x << n) | (x >> (64 - n)); } +}; + +} + +#endif diff --git a/include/sha3/internals/sponge_x4.hpp b/include/sha3/internals/sponge_x4.hpp new file mode 100644 index 0000000..8743a54 --- /dev/null +++ b/include/sha3/internals/sponge_x4.hpp @@ -0,0 +1,181 @@ +#pragma once + +#if defined(__AVX2__) + +#include "sha3/internals/force_inline.hpp" +#include "sha3/internals/keccak.hpp" +#include "sha3/internals/keccak_x4.hpp" +#include "sha3/internals/simd/avx2.hpp" +#include "sha3/internals/utils.hpp" +#include +#include +#include +#include +#include +#include + +// 4-way parallel Keccak sponge functions using AVX2. +// All 4 instances must absorb same-length messages and squeeze same-length outputs. +namespace sponge_x4 { + +using vec = sha3_simd::avx2::vec; + +static constexpr size_t KECCAK_WORD_BYTE_LEN = keccak::LANE_BW / std::numeric_limits::digits; + +/** + * Absorb 4 same-length messages into 4 parallel Keccak[c] sponge states. + * All 4 messages must have the same length. 
+ */ +template +static forceinline void +absorb_x4(std::array& state, + size_t& offset, + std::span msg0, + std::span msg1, + std::span msg2, + std::span msg3) +{ + constexpr size_t num_bytes_in_rate = num_bits_in_rate / std::numeric_limits::digits; + + std::array blk0{}; + std::array blk1{}; + std::array blk2{}; + std::array blk3{}; + + auto blk0_span = std::span(blk0); + auto blk1_span = std::span(blk1); + auto blk2_span = std::span(blk2); + auto blk3_span = std::span(blk3); + + const size_t mlen = msg0.size(); + size_t msg_offset = 0; + + while (msg_offset < mlen) { + const size_t remaining_num_bytes = mlen - msg_offset; + const size_t absorbable_num_bytes = std::min(remaining_num_bytes, num_bytes_in_rate - offset); + const size_t effective_block_byte_len = offset + absorbable_num_bytes; + const size_t padded_effective_block_byte_len = (effective_block_byte_len + (KECCAK_WORD_BYTE_LEN - 1)) & (-KECCAK_WORD_BYTE_LEN); + const size_t padded_effective_block_begins_at = offset & (-KECCAK_WORD_BYTE_LEN); + const size_t fill_len = padded_effective_block_byte_len - padded_effective_block_begins_at; + + std::fill_n(blk0_span.subspan(padded_effective_block_begins_at).begin(), fill_len, 0); + std::fill_n(blk1_span.subspan(padded_effective_block_begins_at).begin(), fill_len, 0); + std::fill_n(blk2_span.subspan(padded_effective_block_begins_at).begin(), fill_len, 0); + std::fill_n(blk3_span.subspan(padded_effective_block_begins_at).begin(), fill_len, 0); + + std::copy_n(msg0.subspan(msg_offset).begin(), absorbable_num_bytes, blk0_span.subspan(offset).begin()); + std::copy_n(msg1.subspan(msg_offset).begin(), absorbable_num_bytes, blk1_span.subspan(offset).begin()); + std::copy_n(msg2.subspan(msg_offset).begin(), absorbable_num_bytes, blk2_span.subspan(offset).begin()); + std::copy_n(msg3.subspan(msg_offset).begin(), absorbable_num_bytes, blk3_span.subspan(offset).begin()); + + size_t state_word_index = padded_effective_block_begins_at / KECCAK_WORD_BYTE_LEN; + for (size_t i 
= padded_effective_block_begins_at; i < padded_effective_block_byte_len; i += KECCAK_WORD_BYTE_LEN) { + const auto w0 = sha3_utils::le_bytes_to_u64(std::span(blk0_span.subspan(i, KECCAK_WORD_BYTE_LEN))); + const auto w1 = sha3_utils::le_bytes_to_u64(std::span(blk1_span.subspan(i, KECCAK_WORD_BYTE_LEN))); + const auto w2 = sha3_utils::le_bytes_to_u64(std::span(blk2_span.subspan(i, KECCAK_WORD_BYTE_LEN))); + const auto w3 = sha3_utils::le_bytes_to_u64(std::span(blk3_span.subspan(i, KECCAK_WORD_BYTE_LEN))); + + state[state_word_index] ^= vec::set(w0, w1, w2, w3); + state_word_index++; + } + + offset += absorbable_num_bytes; + msg_offset += absorbable_num_bytes; + + if (offset == num_bytes_in_rate) [[unlikely]] { + keccak_x4::permute(state); + offset = 0; + } + } +} + +/** + * Finalize 4 parallel Keccak[c] sponge states with domain separator and 10*1 padding. + * All 4 instances share the same offset (since they absorbed same-length messages). + */ +template +static forceinline void +finalize_x4(std::array& state, size_t& offset) + requires(ds_bit_len <= 6U) +{ + constexpr size_t num_bytes_in_rate = num_bits_in_rate / std::numeric_limits::digits; + constexpr size_t num_words_in_rate = num_bytes_in_rate / KECCAK_WORD_BYTE_LEN; + + const auto state_word_index = offset / KECCAK_WORD_BYTE_LEN; + const auto byte_index_in_state_word = offset % KECCAK_WORD_BYTE_LEN; + const auto shl_bit_offset = byte_index_in_state_word * std::numeric_limits::digits; + + constexpr uint8_t mask = (1U << ds_bit_len) - 1U; + constexpr uint8_t pad_byte = (1U << ds_bit_len) | (domain_separator & mask); + + state[state_word_index] ^= vec::broadcast(static_cast(pad_byte) << shl_bit_offset); + state[num_words_in_rate - 1] ^= vec::broadcast(UINT64_C(0x80) << 56); + + keccak_x4::permute(state); + offset = 0; +} + +/** + * Squeeze from 4 parallel Keccak[c] sponge states into 4 same-length output buffers. + * All 4 output buffers must have the same length. 
+ */ /* squeezable is in/out: the number of bytes still unread in the current rate block; whenever it reaches 0 the state is permuted and it is refilled to the rate size. Only out0.size() is consulted for the length, so out1..out3 must be at least that long. */ +template +static forceinline void +squeeze_x4(std::array& state, + size_t& squeezable, + std::span out0, + std::span out1, + std::span out2, + std::span out3) +{ + constexpr size_t num_bytes_in_rate = num_bits_in_rate / std::numeric_limits::digits; + + std::array blk0{}; /* per-lane scratch blocks that receive the serialized state words */ + std::array blk1{}; + std::array blk2{}; + std::array blk3{}; + + auto blk0_span = std::span(blk0); + auto blk1_span = std::span(blk1); + auto blk2_span = std::span(blk2); + auto blk3_span = std::span(blk3); + + const size_t olen = out0.size(); + size_t out_offset = 0; + + while (out_offset < olen) { + const size_t state_byte_offset = num_bytes_in_rate - squeezable; /* bytes of the current rate block already handed out */ + const size_t remaining_num_bytes = olen - out_offset; + const size_t squeezable_num_bytes = std::min(remaining_num_bytes, squeezable); + const size_t effective_block_byte_len = state_byte_offset + squeezable_num_bytes; + const size_t padded_effective_block_byte_len = (effective_block_byte_len + (KECCAK_WORD_BYTE_LEN - 1)) & (-KECCAK_WORD_BYTE_LEN); /* end of the range rounded up to a whole word; assumes KECCAK_WORD_BYTE_LEN is an unsigned power of two - TODO confirm */ + const size_t padded_effective_block_begins_at = state_byte_offset & (-KECCAK_WORD_BYTE_LEN); /* start of the range rounded down to a whole word */ + + size_t state_word_index = padded_effective_block_begins_at / KECCAK_WORD_BYTE_LEN; + for (size_t i = padded_effective_block_begins_at; i < padded_effective_block_byte_len; i += KECCAK_WORD_BYTE_LEN) { /* serialize only the words that overlap the requested bytes */ + sha3_utils::u64_to_le_bytes(state[state_word_index].extract<0>(), std::span(blk0_span.subspan(i, KECCAK_WORD_BYTE_LEN))); + sha3_utils::u64_to_le_bytes(state[state_word_index].extract<1>(), std::span(blk1_span.subspan(i, KECCAK_WORD_BYTE_LEN))); + sha3_utils::u64_to_le_bytes(state[state_word_index].extract<2>(), std::span(blk2_span.subspan(i, KECCAK_WORD_BYTE_LEN))); + sha3_utils::u64_to_le_bytes(state[state_word_index].extract<3>(), std::span(blk3_span.subspan(i, KECCAK_WORD_BYTE_LEN))); + + state_word_index++; + } + + std::copy_n(blk0_span.subspan(state_byte_offset).begin(), squeezable_num_bytes, out0.subspan(out_offset).begin()); /* hand out exactly the requested slice of each lane */ +
std::copy_n(blk1_span.subspan(state_byte_offset).begin(), squeezable_num_bytes, out1.subspan(out_offset).begin()); + std::copy_n(blk2_span.subspan(state_byte_offset).begin(), squeezable_num_bytes, out2.subspan(out_offset).begin()); + std::copy_n(blk3_span.subspan(state_byte_offset).begin(), squeezable_num_bytes, out3.subspan(out_offset).begin()); + + squeezable -= squeezable_num_bytes; + out_offset += squeezable_num_bytes; + + if (squeezable == 0) [[unlikely]] { /* rate block used up: permute right away so a full block is ready for the next read */ + keccak_x4::permute(state); + squeezable = num_bytes_in_rate; + } + } +} + +} + +#endif diff --git a/include/sha3/shake128_x4.hpp b/include/sha3/shake128_x4.hpp new file mode 100644 index 0000000..2953baf --- /dev/null +++ b/include/sha3/shake128_x4.hpp @@ -0,0 +1,81 @@ +#pragma once + +#if defined(__AVX2__) + +#include "sha3/internals/keccak.hpp" +#include "sha3/internals/simd/avx2.hpp" +#include "sha3/internals/sponge_x4.hpp" +#include +#include +#include +#include +#include +#include +#include + +// 4-way parallel SHAKE128 Extendable Output Function using AVX2. +namespace shake128_x4 { + +using vec = sha3_simd::avx2::vec; + +static constexpr size_t NUM_KECCAK_ROUNDS = keccak::MAX_NUM_ROUNDS; +static constexpr size_t TARGET_BIT_SECURITY_LEVEL = 128; +static constexpr size_t CAPACITY = 2 * TARGET_BIT_SECURITY_LEVEL; /* 256 */ +static constexpr size_t RATE = 1600 - CAPACITY; /* 1344 */ +static constexpr uint8_t DOM_SEP = 0b00001111; +static constexpr size_t DOM_SEP_BW = std::bit_width(DOM_SEP); /* 4 */ + +/** + * 4-way parallel SHAKE128 XOF using AVX2. + * + * Operates on 4 independent SHAKE128 instances simultaneously, sharing the same + * absorb/finalize/squeeze schedule. All 4 instances must absorb same-length messages + * and squeeze the same number of bytes.
+ */ +struct shake128_x4_t +{ +private: + std::array state{}; /* Keccak state words; each vec carries one word per instance */ + size_t offset = 0; /* byte position within the rate while absorbing */ + alignas(4) bool finalized = false; /* false while absorbing, true once finalize() has run */ + size_t squeezable = 0; /* bytes still unread in the current rate block while squeezing */ + +public: + forceinline shake128_x4_t() = default; + + forceinline void absorb(std::span msg0, std::span msg1, std::span msg2, std::span msg3) /* feeds one message per instance; a silent no-op once finalized */ + { + if (!finalized) { + sponge_x4::absorb_x4(state, offset, msg0, msg1, msg2, msg3); + } + } + + forceinline void finalize() /* pads and switches to squeezing; only the first call has an effect */ + { + if (!finalized) { + sponge_x4::finalize_x4(state, offset); + + finalized = true; + squeezable = RATE / std::numeric_limits::digits; /* RATE converted from bits to bytes: a whole rate block is readable right after finalization */ + } + } + + forceinline void squeeze(std::span out0, std::span out1, std::span out2, std::span out3) /* fills one output per instance; a silent no-op until finalize() has run */ + { + if (finalized) { + sponge_x4::squeeze_x4(state, squeezable, out0, out1, out2, out3); + } + } + + forceinline void reset() /* restores the member values of a default-constructed object */ + { + std::fill(state.begin(), state.end(), vec{}); + offset = 0; + finalized = false; + squeezable = 0; + } +}; + +} + +#endif diff --git a/include/sha3/shake256_x4.hpp b/include/sha3/shake256_x4.hpp new file mode 100644 index 0000000..8bc3bf4 --- /dev/null +++ b/include/sha3/shake256_x4.hpp @@ -0,0 +1,80 @@ +#pragma once + +#if defined(__AVX2__) + +#include "sha3/internals/keccak.hpp" +#include "sha3/internals/simd/avx2.hpp" +#include "sha3/internals/sponge_x4.hpp" +#include +#include +#include +#include +#include +#include +#include + +// 4-way parallel SHAKE256 Extendable Output Function using AVX2. +namespace shake256_x4 { + +using vec = sha3_simd::avx2::vec; + +static constexpr size_t NUM_KECCAK_ROUNDS = keccak::MAX_NUM_ROUNDS; +static constexpr size_t TARGET_BIT_SECURITY_LEVEL = 256; +static constexpr size_t CAPACITY = 2 * TARGET_BIT_SECURITY_LEVEL; /* 512 */ +static constexpr size_t RATE = 1600 - CAPACITY; /* 1088 */ +static constexpr uint8_t DOM_SEP = 0b00001111; +static constexpr size_t DOM_SEP_BW = std::bit_width(DOM_SEP); /* 4 */ + +/** + * 4-way parallel SHAKE256 XOF using AVX2. + * + * Operates on 4 independent SHAKE256 instances simultaneously, sharing the same + * absorb/finalize/squeeze schedule.
All 4 instances must absorb same-length messages + * and squeeze the same number of bytes. + */ +struct shake256_x4_t +{ +private: + std::array state{}; /* Keccak state words; each vec carries one word per instance */ + size_t offset = 0; /* byte position within the rate while absorbing */ + alignas(4) bool finalized = false; /* false while absorbing, true once finalize() has run */ + size_t squeezable = 0; /* bytes still unread in the current rate block while squeezing */ + +public: + forceinline shake256_x4_t() = default; + + forceinline void absorb(std::span msg0, std::span msg1, std::span msg2, std::span msg3) /* feeds one message per instance; a silent no-op once finalized */ + { + if (!finalized) { + sponge_x4::absorb_x4(state, offset, msg0, msg1, msg2, msg3); + } + } + + forceinline void finalize() /* pads and switches to squeezing; only the first call has an effect */ + { + if (!finalized) { + sponge_x4::finalize_x4(state, offset); + finalized = true; + squeezable = RATE / std::numeric_limits::digits; /* RATE converted from bits to bytes: a whole rate block is readable right after finalization */ + } + } + + forceinline void squeeze(std::span out0, std::span out1, std::span out2, std::span out3) /* fills one output per instance; a silent no-op until finalize() has run */ + { + if (finalized) { + sponge_x4::squeeze_x4(state, squeezable, out0, out1, out2, out3); + } + } + + forceinline void reset() /* restores the member values of a default-constructed object */ + { + std::fill(state.begin(), state.end(), vec{}); + offset = 0; + finalized = false; + squeezable = 0; + } +}; + +} + +#endif diff --git a/tests/test_shake128_x4.cpp b/tests/test_shake128_x4.cpp new file mode 100644 index 0000000..d28869f --- /dev/null +++ b/tests/test_shake128_x4.cpp @@ -0,0 +1,190 @@ +#if defined(__AVX2__) + +#include "sha3/shake128.hpp" +#include "sha3/shake128_x4.hpp" +#include "test_conf.hpp" +#include "test_utils.hpp" +#include +#include +#include + +// Verify each lane of SHAKE128 x4 matches the scalar SHAKE128 output.
+TEST(Sha3XOF, SHAKE128x4ParityWithScalar) +{ + for (size_t mlen = MIN_MSG_LEN; mlen < MAX_MSG_LEN; mlen++) { /* every message length in [MIN_MSG_LEN, MAX_MSG_LEN) */ + for (size_t olen = MIN_OUT_LEN; olen < MAX_OUT_LEN; olen++) { /* paired with every output length in [MIN_OUT_LEN, MAX_OUT_LEN) */ + std::vector msg0(mlen); + std::vector msg1(mlen); + std::vector msg2(mlen); + std::vector msg3(mlen); + + sha3_test_utils::random_data(msg0); /* each lane gets its own message, so a lane mix-up would show up as a mismatch */ + sha3_test_utils::random_data(msg1); + sha3_test_utils::random_data(msg2); + sha3_test_utils::random_data(msg3); + + // x4 path + std::vector out0_x4(olen); + std::vector out1_x4(olen); + std::vector out2_x4(olen); + std::vector out3_x4(olen); + + { + shake128_x4::shake128_x4_t h; + + h.absorb(msg0, msg1, msg2, msg3); + h.finalize(); + h.squeeze(out0_x4, out1_x4, out2_x4, out3_x4); + } + + // scalar path + auto shake128x1 = [&](std::span msg, std::span out) { /* one-shot scalar SHAKE128, the reference for a single lane */ + shake128::shake128_t h; + + h.absorb(msg); + h.finalize(); + h.squeeze(out); + }; + + std::vector out0_sc(olen); + std::vector out1_sc(olen); + std::vector out2_sc(olen); + std::vector out3_sc(olen); + + shake128x1(msg0, out0_sc); + shake128x1(msg1, out1_sc); + shake128x1(msg2, out2_sc); + shake128x1(msg3, out3_sc); + + EXPECT_EQ(out0_x4, out0_sc); + EXPECT_EQ(out1_x4, out1_sc); + EXPECT_EQ(out2_x4, out2_sc); + EXPECT_EQ(out3_x4, out3_sc); + } + } +} + +/** + * Test that absorbing the same message bytes and squeezing the same number of output bytes, + * once in one shot and once incrementally, yields the same output for the SHAKE128 x4 XOF. + * + * Uses msg0's content to determine absorb chunk sizes and the first lane's output bytes to determine squeeze chunk sizes (same across all 4 lanes).
+ */ +TEST(Sha3XOF, SHAKE128x4IncrementalAbsorptionAndSqueezing) +{ + for (size_t mlen = MIN_MSG_LEN; mlen < MAX_MSG_LEN; mlen++) { + for (size_t olen = MIN_OUT_LEN; olen < MAX_OUT_LEN; olen++) { + std::vector msg0(mlen); + std::vector msg1(mlen); + std::vector msg2(mlen); + std::vector msg3(mlen); + + sha3_test_utils::random_data(msg0); + sha3_test_utils::random_data(msg1); + sha3_test_utils::random_data(msg2); + sha3_test_utils::random_data(msg3); + + auto msg0_span = std::span(msg0); + auto msg1_span = std::span(msg1); + auto msg2_span = std::span(msg2); + auto msg3_span = std::span(msg3); + + std::vector oneshot_out0(olen); + std::vector oneshot_out1(olen); + std::vector oneshot_out2(olen); + std::vector oneshot_out3(olen); + + std::vector multishot_out0(olen); + std::vector multishot_out1(olen); + std::vector multishot_out2(olen); + std::vector multishot_out3(olen); + + shake128_x4::shake128_x4_t hasher; + + // Oneshot absorption and squeezing + hasher.absorb(msg0_span, msg1_span, msg2_span, msg3_span); + hasher.finalize(); + hasher.squeeze(oneshot_out0, oneshot_out1, oneshot_out2, oneshot_out3); + + hasher.reset(); /* the same object is reused for the incremental run, so reset() is exercised too */ + + // Incremental absorption (chunk sizes driven by msg0 content) + size_t off = 0; + while (off < mlen) { + auto tmp = std::max(msg0[off], 1); /* chunk length comes from the current message byte; at least 1 so the loop always advances */ + auto elen = std::min(tmp, mlen - off); /* clamped to the bytes that are left */ + + hasher.absorb(msg0_span.subspan(off, elen), msg1_span.subspan(off, elen), msg2_span.subspan(off, elen), msg3_span.subspan(off, elen)); + off += elen; + } + + hasher.finalize(); + + // Incremental squeezing (chunk sizes driven by multishot_out0 content) + off = 0; + while (off < olen) { + auto s0 = std::span(multishot_out0).subspan(off, 1); + auto s1 = std::span(multishot_out1).subspan(off, 1); + auto s2 = std::span(multishot_out2).subspan(off, 1); + auto s3 = std::span(multishot_out3).subspan(off, 1); + hasher.squeeze(s0, s1, s2, s3); /* one byte first: its value picks the size of the following chunk */ + + auto elen = std::min(multishot_out0[off], olen - (off + 1)); /* 0 when that byte is 0 or the output is already full */ + + off += 1; +
hasher.squeeze(std::span(multishot_out0).subspan(off, elen), + std::span(multishot_out1).subspan(off, elen), + std::span(multishot_out2).subspan(off, elen), + std::span(multishot_out3).subspan(off, elen)); + off += elen; + } + + EXPECT_EQ(oneshot_out0, multishot_out0); + EXPECT_EQ(oneshot_out1, multishot_out1); + EXPECT_EQ(oneshot_out2, multishot_out2); + EXPECT_EQ(oneshot_out3, multishot_out3); + } + } +} + +// All 4 lanes get identical input; verify all 4 outputs are identical. +TEST(Sha3XOF, SHAKE128x4AllSameInput) +{ + constexpr size_t MLEN = 200; /* more than one 168-byte SHAKE128 rate block */ + constexpr size_t OLEN = 256; /* likewise spans more than one rate block */ + + std::vector msg(MLEN); + sha3_test_utils::random_data(msg); + + std::vector out0(OLEN); + std::vector out1(OLEN); + std::vector out2(OLEN); + std::vector out3(OLEN); + + { + shake128_x4::shake128_x4_t h; + + h.absorb(msg, msg, msg, msg); + h.finalize(); + h.squeeze(out0, out1, out2, out3); + } + + EXPECT_EQ(out0, out1); /* pairwise checks cover all 4 lanes */ + EXPECT_EQ(out1, out2); + EXPECT_EQ(out2, out3); + + // Also verify against scalar + std::vector out_sc(OLEN); + + { + shake128::shake128_t h; + + h.absorb(msg); + h.finalize(); + h.squeeze(out_sc); + } + + EXPECT_EQ(out0, out_sc); +} + +#endif diff --git a/tests/test_shake256_x4.cpp b/tests/test_shake256_x4.cpp new file mode 100644 index 0000000..1901b3c --- /dev/null +++ b/tests/test_shake256_x4.cpp @@ -0,0 +1,190 @@ +#if defined(__AVX2__) + +#include "sha3/shake256.hpp" +#include "sha3/shake256_x4.hpp" +#include "test_conf.hpp" +#include "test_utils.hpp" +#include +#include +#include + +// Verify each lane of SHAKE256 x4 matches the scalar SHAKE256 output.
+TEST(Sha3XOF, SHAKE256x4ParityWithScalar) +{ + for (size_t mlen = MIN_MSG_LEN; mlen < MAX_MSG_LEN; mlen++) { /* every message length in [MIN_MSG_LEN, MAX_MSG_LEN) */ + for (size_t olen = MIN_OUT_LEN; olen < MAX_OUT_LEN; olen++) { /* paired with every output length in [MIN_OUT_LEN, MAX_OUT_LEN) */ + std::vector msg0(mlen); + std::vector msg1(mlen); + std::vector msg2(mlen); + std::vector msg3(mlen); + + sha3_test_utils::random_data(msg0); /* each lane gets its own message, so a lane mix-up would show up as a mismatch */ + sha3_test_utils::random_data(msg1); + sha3_test_utils::random_data(msg2); + sha3_test_utils::random_data(msg3); + + // x4 path + std::vector out0_x4(olen); + std::vector out1_x4(olen); + std::vector out2_x4(olen); + std::vector out3_x4(olen); + + { + shake256_x4::shake256_x4_t h; + + h.absorb(msg0, msg1, msg2, msg3); + h.finalize(); + h.squeeze(out0_x4, out1_x4, out2_x4, out3_x4); + } + + // scalar path + auto shake256x1 = [&](std::span msg, std::span out) { /* one-shot scalar SHAKE256, the reference for a single lane */ + shake256::shake256_t h; + + h.absorb(msg); + h.finalize(); + h.squeeze(out); + }; + + std::vector out0_sc(olen); + std::vector out1_sc(olen); + std::vector out2_sc(olen); + std::vector out3_sc(olen); + + shake256x1(msg0, out0_sc); + shake256x1(msg1, out1_sc); + shake256x1(msg2, out2_sc); + shake256x1(msg3, out3_sc); + + EXPECT_EQ(out0_x4, out0_sc); + EXPECT_EQ(out1_x4, out1_sc); + EXPECT_EQ(out2_x4, out2_sc); + EXPECT_EQ(out3_x4, out3_sc); + } + } +} + +/** + * Test that absorbing the same message bytes and squeezing the same number of output bytes, + * once in one shot and once incrementally, yields the same output for the SHAKE256 x4 XOF. + * + * Uses msg0's content to determine absorb chunk sizes and the first lane's output bytes to determine squeeze chunk sizes (same across all 4 lanes).
+ */ +TEST(Sha3XOF, SHAKE256x4IncrementalAbsorptionAndSqueezing) +{ + for (size_t mlen = MIN_MSG_LEN; mlen < MAX_MSG_LEN; mlen++) { + for (size_t olen = MIN_OUT_LEN; olen < MAX_OUT_LEN; olen++) { + std::vector msg0(mlen); + std::vector msg1(mlen); + std::vector msg2(mlen); + std::vector msg3(mlen); + + sha3_test_utils::random_data(msg0); + sha3_test_utils::random_data(msg1); + sha3_test_utils::random_data(msg2); + sha3_test_utils::random_data(msg3); + + auto msg0_span = std::span(msg0); + auto msg1_span = std::span(msg1); + auto msg2_span = std::span(msg2); + auto msg3_span = std::span(msg3); + + std::vector oneshot_out0(olen); + std::vector oneshot_out1(olen); + std::vector oneshot_out2(olen); + std::vector oneshot_out3(olen); + + std::vector multishot_out0(olen); + std::vector multishot_out1(olen); + std::vector multishot_out2(olen); + std::vector multishot_out3(olen); + + shake256_x4::shake256_x4_t hasher; + + // Oneshot absorption and squeezing + hasher.absorb(msg0_span, msg1_span, msg2_span, msg3_span); + hasher.finalize(); + hasher.squeeze(oneshot_out0, oneshot_out1, oneshot_out2, oneshot_out3); + + hasher.reset(); /* the same object is reused for the incremental run, so reset() is exercised too */ + + // Incremental absorption (chunk sizes driven by msg0 content) + size_t off = 0; + while (off < mlen) { + auto tmp = std::max(msg0[off], 1); /* chunk length comes from the current message byte; at least 1 so the loop always advances */ + auto elen = std::min(tmp, mlen - off); /* clamped to the bytes that are left */ + + hasher.absorb(msg0_span.subspan(off, elen), msg1_span.subspan(off, elen), msg2_span.subspan(off, elen), msg3_span.subspan(off, elen)); + off += elen; + } + + hasher.finalize(); + + // Incremental squeezing (chunk sizes driven by multishot_out0 content) + off = 0; + while (off < olen) { + auto s0 = std::span(multishot_out0).subspan(off, 1); + auto s1 = std::span(multishot_out1).subspan(off, 1); + auto s2 = std::span(multishot_out2).subspan(off, 1); + auto s3 = std::span(multishot_out3).subspan(off, 1); + hasher.squeeze(s0, s1, s2, s3); /* one byte first: its value picks the size of the following chunk */ + + auto elen = std::min(multishot_out0[off], olen - (off + 1)); /* 0 when that byte is 0 or the output is already full */ + + off += 1; +
hasher.squeeze(std::span(multishot_out0).subspan(off, elen), + std::span(multishot_out1).subspan(off, elen), + std::span(multishot_out2).subspan(off, elen), + std::span(multishot_out3).subspan(off, elen)); + off += elen; + } + + EXPECT_EQ(oneshot_out0, multishot_out0); + EXPECT_EQ(oneshot_out1, multishot_out1); + EXPECT_EQ(oneshot_out2, multishot_out2); + EXPECT_EQ(oneshot_out3, multishot_out3); + } + } +} + +// All 4 lanes get identical input; verify all 4 outputs are identical. +TEST(Sha3XOF, SHAKE256x4AllSameInput) +{ + constexpr size_t MLEN = 200; /* more than one 136-byte SHAKE256 rate block */ + constexpr size_t OLEN = 256; /* likewise spans more than one rate block */ + + std::vector msg(MLEN); + sha3_test_utils::random_data(msg); + + std::vector out0(OLEN); + std::vector out1(OLEN); + std::vector out2(OLEN); + std::vector out3(OLEN); + + { + shake256_x4::shake256_x4_t h; + + h.absorb(msg, msg, msg, msg); + h.finalize(); + h.squeeze(out0, out1, out2, out3); + } + + EXPECT_EQ(out0, out1); /* pairwise checks cover all 4 lanes */ + EXPECT_EQ(out1, out2); + EXPECT_EQ(out2, out3); + + // Also verify against scalar + std::vector out_sc(OLEN); + + { + shake256::shake256_t h; + + h.absorb(msg); + h.finalize(); + h.squeeze(out_sc); + } + + EXPECT_EQ(out0, out_sc); +} + +#endif