diff --git a/src/validator/stealthAddressValidator/EllipticCurve.sol b/src/validator/stealthAddressValidator/EllipticCurve.sol new file mode 100644 index 00000000..94a8eedf --- /dev/null +++ b/src/validator/stealthAddressValidator/EllipticCurve.sol @@ -0,0 +1,328 @@ +// SPDX-License-Identifier: MIT + +pragma solidity ^0.8.0; + +/** + * @title Elliptic Curve Library + * @dev Library providing arithmetic operations over elliptic curves. + * This library does not check whether the inserted points belong to the curve + * `isOnCurve` function should be used by the library user to check the aforementioned statement. + * @author Witnet Foundation + */ +library EllipticCurve { + // Pre-computed constant for 2 ** 255 + uint256 private constant U255_MAX_PLUS_1 = + 57896044618658097711785492504343953926634992332820282019728792003956564819968; + + /// @dev Modular euclidean inverse of a number (mod p). + /// @param _x The number + /// @param _pp The modulus + /// @return q such that x*q = 1 (mod _pp) + function invMod(uint256 _x, uint256 _pp) internal pure returns (uint256) { + require(_x != 0 && _x != _pp && _pp != 0, "Invalid number"); + uint256 q = 0; + uint256 newT = 1; + uint256 r = _pp; + uint256 t; + while (_x != 0) { + t = r / _x; + (q, newT) = (newT, addmod(q, (_pp - mulmod(t, newT, _pp)), _pp)); + (r, _x) = (_x, r - t * _x); + } + + return q; + } + + /// @dev Modular exponentiation, b^e % _pp. + /// Source: https://github.com/androlo/standard-contracts/blob/master/contracts/src/crypto/ECCMath.sol + /// @param _base base + /// @param _exp exponent + /// @param _pp modulus + /// @return r such that r = b**e (mod _pp) + function expMod(uint256 _base, uint256 _exp, uint256 _pp) internal pure returns (uint256) { + require(_pp != 0, "EllipticCurve: modulus is zero"); + + if (_base == 0) return 0; + if (_exp == 0) return 1; + + uint256 r = 1; + uint256 bit = U255_MAX_PLUS_1; + assembly { + for {} gt(bit, 0) {} { + r := mulmod(mulmod(r, r, _pp), exp(_base, iszero(iszero(and(_exp, bit)))), _pp) + r := mulmod(mulmod(r, r, _pp), exp(_base, iszero(iszero(and(_exp, div(bit, 2))))), _pp) + r := mulmod(mulmod(r, r, _pp), exp(_base, iszero(iszero(and(_exp, div(bit, 4))))), _pp) + r := mulmod(mulmod(r, r, _pp), exp(_base, iszero(iszero(and(_exp, div(bit, 8))))), _pp) + bit := div(bit, 16) + } + } + + return r; + } + + /// @dev Converts a point (x, y, z) expressed in Jacobian coordinates to affine coordinates (x', y', 1). + /// @param _x coordinate x + /// @param _y coordinate y + /// @param _z coordinate z + /// @param _pp the modulus + /// @return (x', y') affine coordinates + function toAffine(uint256 _x, uint256 _y, uint256 _z, uint256 _pp) internal pure returns (uint256, uint256) { + uint256 zInv = invMod(_z, _pp); + uint256 zInv2 = mulmod(zInv, zInv, _pp); + uint256 x2 = mulmod(_x, zInv2, _pp); + uint256 y2 = mulmod(_y, mulmod(zInv, zInv2, _pp), _pp); + + return (x2, y2); + } + + /// @dev Derives the y coordinate from a compressed-format point x [[SEC-1]](https://www.secg.org/SEC1-Ver-1.0.pdf). + /// @param _prefix parity byte (0x02 even, 0x03 odd) + /// @param _x coordinate x + /// @param _aa constant of curve + /// @param _bb constant of curve + /// @param _pp the modulus + /// @return y coordinate y + function deriveY(uint8 _prefix, uint256 _x, uint256 _aa, uint256 _bb, uint256 _pp) + internal + pure + returns (uint256) + { + require(_prefix == 0x02 || _prefix == 0x03, "EllipticCurve:innvalid compressed EC point prefix"); + + // x^3 + ax + b + uint256 y2 = addmod(mulmod(_x, mulmod(_x, _x, _pp), _pp), addmod(mulmod(_x, _aa, _pp), _bb, _pp), _pp); + y2 = expMod(y2, (_pp + 1) / 4, _pp); + // uint256 cmp = yBit ^ y_ & 1; + uint256 y = (y2 + _prefix) % 2 == 0 ? y2 : _pp - y2; + + return y; + } + + /// @dev Check whether point (x,y) is on curve defined by a, b, and _pp. + /// @param _x coordinate x of P1 + /// @param _y coordinate y of P1 + /// @param _aa constant of curve + /// @param _bb constant of curve + /// @param _pp the modulus + /// @return true if x,y in the curve, false else + function isOnCurve(uint256 _x, uint256 _y, uint256 _aa, uint256 _bb, uint256 _pp) internal pure returns (bool) { + if (0 == _x || _x >= _pp || 0 == _y || _y >= _pp) { + return false; + } + // y^2 + uint256 lhs = mulmod(_y, _y, _pp); + // x^3 + uint256 rhs = mulmod(mulmod(_x, _x, _pp), _x, _pp); + if (_aa != 0) { + // x^3 + a*x + rhs = addmod(rhs, mulmod(_x, _aa, _pp), _pp); + } + if (_bb != 0) { + // x^3 + a*x + b + rhs = addmod(rhs, _bb, _pp); + } + + return lhs == rhs; + } + + /// @dev Calculate inverse (x, -y) of point (x, y). + /// @param _x coordinate x of P1 + /// @param _y coordinate y of P1 + /// @param _pp the modulus + /// @return (x, -y) + function ecInv(uint256 _x, uint256 _y, uint256 _pp) internal pure returns (uint256, uint256) { + return (_x, (_pp - _y) % _pp); + } + + /// @dev Add two points (x1, y1) and (x2, y2) in affine coordinates. + /// @param _x1 coordinate x of P1 + /// @param _y1 coordinate y of P1 + /// @param _x2 coordinate x of P2 + /// @param _y2 coordinate y of P2 + /// @param _aa constant of the curve + /// @param _pp the modulus + /// @return (qx, qy) = P1+P2 in affine coordinates + function ecAdd(uint256 _x1, uint256 _y1, uint256 _x2, uint256 _y2, uint256 _aa, uint256 _pp) + internal + pure + returns (uint256, uint256) + { + uint256 x = 0; + uint256 y = 0; + uint256 z = 0; + + // Double if x1==x2 else add + if (_x1 == _x2) { + // y1 = -y2 mod p + if (addmod(_y1, _y2, _pp) == 0) { + return (0, 0); + } else { + // P1 = P2 + (x, y, z) = jacDouble(_x1, _y1, 1, _aa, _pp); + } + } else { + (x, y, z) = jacAdd(_x1, _y1, 1, _x2, _y2, 1, _pp); + } + // Get back to affine + return toAffine(x, y, z, _pp); + } + + /// @dev Substract two points (x1, y1) and (x2, y2) in affine coordinates. + /// @param _x1 coordinate x of P1 + /// @param _y1 coordinate y of P1 + /// @param _x2 coordinate x of P2 + /// @param _y2 coordinate y of P2 + /// @param _aa constant of the curve + /// @param _pp the modulus + /// @return (qx, qy) = P1-P2 in affine coordinates + function ecSub(uint256 _x1, uint256 _y1, uint256 _x2, uint256 _y2, uint256 _aa, uint256 _pp) + internal + pure + returns (uint256, uint256) + { + // invert square + (uint256 x, uint256 y) = ecInv(_x2, _y2, _pp); + // P1-square + return ecAdd(_x1, _y1, x, y, _aa, _pp); + } + + /// @dev Multiply point (x1, y1, z1) times d in affine coordinates. + /// @param _k scalar to multiply + /// @param _x coordinate x of P1 + /// @param _y coordinate y of P1 + /// @param _aa constant of the curve + /// @param _pp the modulus + /// @return (qx, qy) = d*P in affine coordinates + function ecMul(uint256 _k, uint256 _x, uint256 _y, uint256 _aa, uint256 _pp) + internal + pure + returns (uint256, uint256) + { + // Jacobian multiplication + (uint256 x1, uint256 y1, uint256 z1) = jacMul(_k, _x, _y, 1, _aa, _pp); + // Get back to affine + return toAffine(x1, y1, z1, _pp); + } + + /// @dev Adds two points (x1, y1, z1) and (x2 y2, z2). + /// @param _x1 coordinate x of P1 + /// @param _y1 coordinate y of P1 + /// @param _z1 coordinate z of P1 + /// @param _x2 coordinate x of square + /// @param _y2 coordinate y of square + /// @param _z2 coordinate z of square + /// @param _pp the modulus + /// @return (qx, qy, qz) P1+square in Jacobian + function jacAdd(uint256 _x1, uint256 _y1, uint256 _z1, uint256 _x2, uint256 _y2, uint256 _z2, uint256 _pp) + internal + pure + returns (uint256, uint256, uint256) + { + if (_x1 == 0 && _y1 == 0) return (_x2, _y2, _z2); + if (_x2 == 0 && _y2 == 0) return (_x1, _y1, _z1); + + // We follow the equations described in https://pdfs.semanticscholar.org/5c64/29952e08025a9649c2b0ba32518e9a7fb5c2.pdf Section 5 + uint256[4] memory zs; // z1^2, z1^3, z2^2, z2^3 + zs[0] = mulmod(_z1, _z1, _pp); + zs[1] = mulmod(_z1, zs[0], _pp); + zs[2] = mulmod(_z2, _z2, _pp); + zs[3] = mulmod(_z2, zs[2], _pp); + + // u1, s1, u2, s2 + zs = [mulmod(_x1, zs[2], _pp), mulmod(_y1, zs[3], _pp), mulmod(_x2, zs[0], _pp), mulmod(_y2, zs[1], _pp)]; + + // In case of zs[0] == zs[2] && zs[1] == zs[3], double function should be used + require(zs[0] != zs[2] || zs[1] != zs[3], "Use jacDouble function instead"); + + uint256[4] memory hr; + //h + hr[0] = addmod(zs[2], _pp - zs[0], _pp); + //r + hr[1] = addmod(zs[3], _pp - zs[1], _pp); + //h^2 + hr[2] = mulmod(hr[0], hr[0], _pp); + // h^3 + hr[3] = mulmod(hr[2], hr[0], _pp); + // qx = -h^3 -2u1h^2+r^2 + uint256 qx = addmod(mulmod(hr[1], hr[1], _pp), _pp - hr[3], _pp); + qx = addmod(qx, _pp - mulmod(2, mulmod(zs[0], hr[2], _pp), _pp), _pp); + // qy = -s1*z1*h^3+r(u1*h^2 -x^3) + uint256 qy = mulmod(hr[1], addmod(mulmod(zs[0], hr[2], _pp), _pp - qx, _pp), _pp); + qy = addmod(qy, _pp - mulmod(zs[1], hr[3], _pp), _pp); + // qz = h*z1*z2 + uint256 qz = mulmod(hr[0], mulmod(_z1, _z2, _pp), _pp); + return (qx, qy, qz); + } + + /// @dev Doubles a points (x, y, z). + /// @param _x coordinate x of P1 + /// @param _y coordinate y of P1 + /// @param _z coordinate z of P1 + /// @param _aa the a scalar in the curve equation + /// @param _pp the modulus + /// @return (qx, qy, qz) 2P in Jacobian + function jacDouble(uint256 _x, uint256 _y, uint256 _z, uint256 _aa, uint256 _pp) + internal + pure + returns (uint256, uint256, uint256) + { + if (_z == 0) return (_x, _y, _z); + + // We follow the equations described in https://pdfs.semanticscholar.org/5c64/29952e08025a9649c2b0ba32518e9a7fb5c2.pdf Section 5 + // Note: there is a bug in the paper regarding the m parameter, M=3*(x1^2)+a*(z1^4) + // x, y, z at this point represent the squares of _x, _y, _z + uint256 x = mulmod(_x, _x, _pp); //x1^2 + uint256 y = mulmod(_y, _y, _pp); //y1^2 + uint256 z = mulmod(_z, _z, _pp); //z1^2 + + // s + uint256 s = mulmod(4, mulmod(_x, y, _pp), _pp); + // m + uint256 m = addmod(mulmod(3, x, _pp), mulmod(_aa, mulmod(z, z, _pp), _pp), _pp); + + // x, y, z at this point will be reassigned and rather represent qx, qy, qz from the paper + // This allows to reduce the gas cost and stack footprint of the algorithm + // qx + x = addmod(mulmod(m, m, _pp), _pp - addmod(s, s, _pp), _pp); + // qy = -8*y1^4 + M(S-T) + y = addmod(mulmod(m, addmod(s, _pp - x, _pp), _pp), _pp - mulmod(8, mulmod(y, y, _pp), _pp), _pp); + // qz = 2*y1*z1 + z = mulmod(2, mulmod(_y, _z, _pp), _pp); + + return (x, y, z); + } + + /// @dev Multiply point (x, y, z) times d. + /// @param _d scalar to multiply + /// @param _x coordinate x of P1 + /// @param _y coordinate y of P1 + /// @param _z coordinate z of P1 + /// @param _aa constant of curve + /// @param _pp the modulus + /// @return (qx, qy, qz) d*P1 in Jacobian + function jacMul(uint256 _d, uint256 _x, uint256 _y, uint256 _z, uint256 _aa, uint256 _pp) + internal + pure + returns (uint256, uint256, uint256) + { + // Early return in case that `_d == 0` + if (_d == 0) { + return (_x, _y, _z); + } + + uint256 remaining = _d; + uint256 qx = 0; + uint256 qy = 0; + uint256 qz = 1; + + // Double and add algorithm + while (remaining != 0) { + if ((remaining & 1) != 0) { + (qx, qy, qz) = jacAdd(qx, qy, qz, _x, _y, _z, _pp); + } + remaining = remaining / 2; + (_x, _y, _z) = jacDouble(_x, _y, _z, _aa, _pp); + } + return (qx, qy, qz); + } +} diff --git a/src/validator/stealthAddressValidator/StealthAddressValidator.sol b/src/validator/stealthAddressValidator/StealthAddressValidator.sol new file mode 100644 index 00000000..efca0763 --- /dev/null +++ b/src/validator/stealthAddressValidator/StealthAddressValidator.sol @@ -0,0 +1,182 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.0; + +import {UserOperation} from "I4337/interfaces/UserOperation.sol"; +import {ECDSA} from "solady/utils/ECDSA.sol"; +import {EIP712} from "solady/utils/EIP712.sol"; +import {StealthAggreagteSignature} from "./StealthAggreagteSignature.sol"; +import {IKernelValidator} from "../../interfaces/IKernelValidator.sol"; +import {ValidationData} from "../../common/Types.sol"; +import {SIG_VALIDATION_FAILED} from "../../common/Constants.sol"; + +/** + * @dev Storage structure for Stealth Address Registry Module. + * StealthPubkey, dhkey are used in aggregated signature. + * EphemeralPubkey is used to recover private key of stealth address. + */ +struct StealthAddressValidatorStorage { + uint256 stealthPubkey; + uint256 dhkey; + uint256 ephemeralPubkey; + address stealthAddress; + uint8 stealthPubkeyPrefix; + uint8 dhkeyPrefix; + uint8 ephemeralPrefix; +} + +/** + * @author Justin Zen - + * @title Stealth Address Validator for ZeroDev Kernel. + * @notice This validator uses the Stealth address to validate signatures. + */ +contract StealthAddressValidator is IKernelValidator, EIP712 { + /// @notice The type hash used for kernel user op validation + bytes32 constant USER_OP_TYPEHASH = + keccak256("AllowUserOp(address owner,address kernelWallet,bytes32 userOpHash)"); + /// @notice The type hash used for kernel signature validation + bytes32 constant SIGNATURE_TYPEHASH = + keccak256("KernelSignature(address owner,address kernelWallet,bytes32 hash)"); + + /// @notice Emitted when the stealth address of a kernel is changed. + event StealthAddressChanged( + address indexed kernel, address indexed oldStealthAddress, address indexed newStealthAddress + ); + + /* -------------------------------------------------------------------------- */ + /* Storage */ + /* -------------------------------------------------------------------------- */ + mapping(address => StealthAddressValidatorStorage) public stealthAddressValidatorStorage; + + /* -------------------------------------------------------------------------- */ + /* EIP-712 Methods */ + /* -------------------------------------------------------------------------- */ + + /// @dev Get the current name & version of the validator, used for the EIP-712 domain separator from Solady + function _domainNameAndVersion() internal pure override returns (string memory, string memory) { + return ("Kernel:StealthAddressValidator", "1.0.0"); + } + + /// @dev Tell to solady that the current name & version of the validator won't change, so no need to recompute the eip-712 domain separator + function _domainNameAndVersionMayChange() internal pure override returns (bool) { + return false; + } + + /// @dev Export the current domain seperator + function getDomainSeperator() public view returns (bytes32) { + return _domainSeparator(); + } + + /* -------------------------------------------------------------------------- */ + /* Kernel validator Methods */ + /* -------------------------------------------------------------------------- */ + + /// @dev Enable this validator for a given `kernel` (msg.sender) + function enable(bytes calldata _data) external payable override { + address stealthAddress = address(bytes20(_data[0:20])); + uint256 stealthAddressPubkey = uint256(bytes32(_data[20:52])); + uint256 stealthAddressDhkey = uint256(bytes32(_data[52:84])); + uint8 stealthAddressPubkeyPrefix = uint8(_data[84]); + uint8 stealthAddressDhkeyPrefix = uint8(_data[85]); + uint256 ephemeralPubkey = uint256(bytes32(_data[86:118])); + uint8 ephemeralPrefix = uint8(_data[118]); + + address oldStealthAddress = stealthAddressValidatorStorage[msg.sender].stealthAddress; + stealthAddressValidatorStorage[msg.sender] = StealthAddressValidatorStorage({ + stealthPubkey: stealthAddressPubkey, + dhkey: stealthAddressDhkey, + ephemeralPubkey: ephemeralPubkey, + stealthAddress: stealthAddress, + stealthPubkeyPrefix: stealthAddressPubkeyPrefix, + dhkeyPrefix: stealthAddressDhkeyPrefix, + ephemeralPrefix: ephemeralPrefix + }); + emit StealthAddressChanged(msg.sender, oldStealthAddress, stealthAddress); + } + + /// @dev Disable this validator for a given `kernel` (msg.sender) + function disable(bytes calldata) external payable override { + address stealthAddress; + delete stealthAddressValidatorStorage[msg.sender]; + emit StealthAddressChanged(msg.sender, stealthAddress, address(0)); + } + + /// @dev Validate a `_userOp` using a EIP-712 signature, signed by the owner of the kernel account who is the `_userOp` sender + function validateUserOp(UserOperation calldata _userOp, bytes32 _userOpHash, uint256) + external + payable + override + returns (ValidationData validationData) + { + bytes1 mode = _userOp.signature[0]; + StealthAddressValidatorStorage storage stealthData = stealthAddressValidatorStorage[_userOp.sender]; + address stealthAddress = stealthData.stealthAddress; + bytes32 typedDataHash = + _hashTypedData(keccak256(abi.encode(USER_OP_TYPEHASH, stealthAddress, _userOp.sender, _userOpHash))); + + // 0x00: signature from spending key + // 0x01: aggregated signature from owner and shared secret + if (mode == 0x00) { + return stealthAddress == ECDSA.recover(typedDataHash, _userOp.signature[1:]) + ? ValidationData.wrap(0) + : SIG_VALIDATION_FAILED; + } else if (mode == 0x01) { + return StealthAggreagteSignature.validateAggregatedSignature( + stealthData.stealthPubkey, + stealthData.dhkey, + stealthData.stealthPubkeyPrefix, + stealthData.dhkeyPrefix, + typedDataHash, + _userOp.signature[1:] + ) ? ValidationData.wrap(0) : SIG_VALIDATION_FAILED; + } else { + return SIG_VALIDATION_FAILED; + } + } + + /// @dev Validate a `_signature` of the `_hash` ofor the given `kernel` (msg.sender) + function validateSignature(bytes32 _hash, bytes calldata _signature) + external + view + override + returns (ValidationData validationData) + { + bytes1 mode = _signature[0]; + StealthAddressValidatorStorage storage stealthData = stealthAddressValidatorStorage[msg.sender]; + address stealthAddress = stealthData.stealthAddress; + bytes32 typedDataHash = + _hashTypedData(keccak256(abi.encode(SIGNATURE_TYPEHASH, stealthAddress, msg.sender, _hash))); + + // 0x00: signature from spending key + // 0x01: aggregated signature from owner and shared secret + if (mode == 0x00) { + return stealthAddress == ECDSA.recover(typedDataHash, _signature[1:]) + ? ValidationData.wrap(0) + : SIG_VALIDATION_FAILED; + } else if (mode == 0x01) { + return StealthAggreagteSignature.validateAggregatedSignature( + stealthData.stealthPubkey, + stealthData.dhkey, + stealthData.stealthPubkeyPrefix, + stealthData.dhkeyPrefix, + typedDataHash, + _signature[1:] + ) ? ValidationData.wrap(0) : SIG_VALIDATION_FAILED; + } else { + return SIG_VALIDATION_FAILED; + } + } + + /// @dev Check if the caller is a valid signer for this kernel account + function validCaller(address _caller, bytes calldata) external view override returns (bool) { + return stealthAddressValidatorStorage[msg.sender].stealthAddress == _caller; + } + + /* -------------------------------------------------------------------------- */ + /* Public view methods */ + /* -------------------------------------------------------------------------- */ + + /// @dev Get the owner of a given `kernel` + function getOwner(address _kernel) public view returns (address) { + return stealthAddressValidatorStorage[_kernel].stealthAddress; + } +} diff --git a/src/validator/stealthAddressValidator/StealthAggreagteSignature.sol b/src/validator/stealthAddressValidator/StealthAggreagteSignature.sol new file mode 100644 index 00000000..d184a5f7 --- /dev/null +++ b/src/validator/stealthAddressValidator/StealthAggreagteSignature.sol @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.0; + +import {EllipticCurve} from "./EllipticCurve.sol"; + +library StealthAggreagteSignature { + uint256 public constant GX = 0x79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798; + uint256 public constant GY = 0x483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8; + uint256 public constant AA = 0; + uint256 public constant BB = 7; + uint256 public constant PP = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F; + uint256 public constant N = 0xfffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141; + + function validateAggregatedSignature( + uint256 _pubkey, + uint256 _dhkey, + uint8 _pubkeyPrefix, + uint8 _dhkeyPrefix, + bytes32 _message, + bytes calldata _signature + ) external pure returns (bool) { + uint256 aggh2; + uint256 aggpb; + uint256 aggdh; + + uint256 sigr = uint256(bytes32(_signature[0:32])); + uint256 sigs = uint256(bytes32(_signature[32:64])); + uint256 sinv = EllipticCurve.invMod(sigs, N); + uint256 num_message = uint256(_message); + + assembly { + aggh2 := mulmod(mulmod(sinv, num_message, N), num_message, N) + aggpb := mulmod(mulmod(sinv, sigr, N), num_message, N) + aggdh := mulmod(mulmod(sinv, sigr, N), sigr, N) + } + (uint256 p1x, uint256 p1y) = EllipticCurve.ecMul(aggh2, GX, GY, AA, PP); + uint256 pubY = EllipticCurve.deriveY(_pubkeyPrefix, _pubkey, AA, BB, PP); + uint256 pubdhY = EllipticCurve.deriveY(_dhkeyPrefix, _dhkey, AA, BB, PP); + + (uint256 p2x, uint256 p2y) = EllipticCurve.ecMul(aggpb, _pubkey, pubY, AA, PP); + (uint256 p3x, uint256 p3y) = EllipticCurve.ecMul(aggdh, _dhkey, pubdhY, AA, PP); + (uint256 aggp1x, uint256 aggp1y) = EllipticCurve.ecAdd(p1x, p1y, p2x, p2y, AA, PP); + (uint256 aggpx,) = EllipticCurve.ecAdd(aggp1x, aggp1y, p3x, p3y, AA, PP); + + return aggpx % N == sigr; + } +} diff --git a/test/foundry/validator/StealthAddressValidator.t.sol b/test/foundry/validator/StealthAddressValidator.t.sol new file mode 100644 index 00000000..ed612038 --- /dev/null +++ b/test/foundry/validator/StealthAddressValidator.t.sol @@ -0,0 +1,311 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.0; + +import {FixedPointMathLib} from "solady/utils/FixedPointMathLib.sol"; +import "src/Kernel.sol"; +import {EllipticCurve} from "src/validator/stealthAddressValidator/EllipticCurve.sol"; +import {IKernel} from "src/interfaces/IKernel.sol"; +import {StealthAddressValidator} from "src/validator/stealthAddressValidator/StealthAddressValidator.sol"; +// test utils +import {KernelTestBase} from "../KernelTestBase.sol"; +import {TestExecutor} from "../mock/TestExecutor.sol"; +import {TestValidator} from "../mock/TestValidator.sol"; +import "forge-std/Vm.sol"; + +struct StealthAddressKey { + address stealthAddress; + uint256 stealthPub; + uint256 dhPub; + uint8 stealthPrefix; + uint8 dhPrefix; + uint256 ephemeralPub; + uint8 ephemeralPrefix; + uint256 hashSecret; + uint256 stealthPrivate; +} + +contract StealthAddressValidatorTest is KernelTestBase { + StealthAddressValidator private stealthAddressValidator; + VmSafe.Wallet private wallet; + VmSafe.Wallet private ephemeralWallet; + uint256 private stealthPrivateKey; + + function setUp() public { + _initialize(); + wallet = vm.createWallet(uint256(keccak256(bytes("owner")))); + ephemeralWallet = vm.createWallet(uint256(keccak256(bytes("ephemeral")))); + + StealthAddressKey memory stealthAddressKey = getStealthAddress(wallet, ephemeralWallet); + owner = stealthAddressKey.stealthAddress; + ownerKey = stealthAddressKey.stealthPrivate; + stealthAddressValidator = new StealthAddressValidator(); + defaultValidator = stealthAddressValidator; + _setAddress(); + _setExecutionDetail(); + } + + function _setExecutionDetail() internal virtual override { + executionDetail.executor = address(new TestExecutor()); + executionSig = TestExecutor.doNothing.selector; + executionDetail.validator = new TestValidator(); + } + + function getEnableData() internal view virtual override returns (bytes memory) { + return ""; + } + + function getValidatorSignature(UserOperation memory) internal view virtual override returns (bytes memory) { + return ""; + } + + function getOwners() internal view override returns (address[] memory) { + address[] memory owners = new address[](1); + owners[0] = owner; + return owners; + } + + function getInitializeData() internal view override returns (bytes memory) { + StealthAddressKey memory stealthAddressKey = getStealthAddress(wallet, ephemeralWallet); + return abi.encodeWithSelector( + KernelStorage.initialize.selector, + defaultValidator, + abi.encodePacked( + stealthAddressKey.stealthAddress, + stealthAddressKey.stealthPub, + stealthAddressKey.dhPub, + stealthAddressKey.stealthPrefix, + stealthAddressKey.dhPrefix, + stealthAddressKey.ephemeralPub, + stealthAddressKey.ephemeralPrefix + ) + ); + } + + function signUserOp(UserOperation memory op) internal view override returns (bytes memory) { + StealthAddressKey memory stealthAddressKey = getStealthAddress(wallet, ephemeralWallet); + return abi.encodePacked( + bytes4(0x00000000), bytes1(0x00), _generateUserOpSignature(entryPoint, op, stealthAddressKey.stealthPrivate) + ); + } + + function getWrongSignature(UserOperation memory op) internal view override returns (bytes memory) { + StealthAddressKey memory stealthAddressKey = getStealthAddress(wallet, ephemeralWallet); + return abi.encodePacked( + bytes4(0x00000000), + bytes1(0x00), + _generateUserOpSignature(entryPoint, op, stealthAddressKey.stealthPrivate + 1) + ); + } + + function signHash(bytes32 _hash) internal view override returns (bytes memory) { + StealthAddressKey memory stealthAddressKey = getStealthAddress(wallet, ephemeralWallet); + return _generateHashSignature(_hash, address(kernel), stealthAddressKey.stealthPrivate); + } + + function getWrongSignature(bytes32 _hash) internal view override returns (bytes memory) { + StealthAddressKey memory stealthAddressKey = getStealthAddress(wallet, ephemeralWallet); + return _generateHashSignature(_hash, address(kernel), stealthAddressKey.stealthPrivate + 1); + } + + function test_default_validator_enable() external override { + StealthAddressKey memory stealthAddressKey = getStealthAddress(wallet, ephemeralWallet); + + UserOperation memory op = buildUserOperation( + abi.encodeWithSelector( + IKernel.execute.selector, + address(defaultValidator), + 0, + abi.encodeWithSelector( + StealthAddressValidator.enable.selector, + abi.encodePacked( + stealthAddressKey.stealthAddress, + stealthAddressKey.stealthPub, + stealthAddressKey.dhPub, + stealthAddressKey.stealthPrefix, + stealthAddressKey.dhPrefix, + stealthAddressKey.ephemeralPub, + stealthAddressKey.ephemeralPrefix + ) + ), + Operation.Call + ) + ); + performUserOperationWithSig(op); + address owner = stealthAddressValidator.getOwner(address(kernel)); + assertEq(owner, stealthAddressKey.stealthAddress, "owner should be stealthAddress"); + } + + function test_default_validator_disable() external override { + UserOperation memory op = buildUserOperation( + abi.encodeWithSelector( + IKernel.execute.selector, + address(defaultValidator), + 0, + abi.encodeWithSelector(StealthAddressValidator.disable.selector, ""), + Operation.Call + ) + ); + performUserOperationWithSig(op); + address owner = stealthAddressValidator.getOwner(address(kernel)); + assertEq(owner, address(0), "owner should be 0"); + } + + function test_stealth_validate_userop_aggsig() external { + UserOperation memory userOp = UserOperation({ + sender: address(kernel), + nonce: 0, + initCode: bytes(""), + callData: bytes(""), + callGasLimit: 1, + verificationGasLimit: 1, + preVerificationGas: 1, + maxFeePerGas: 1, + maxPriorityFeePerGas: 1, + paymasterAndData: bytes(""), + signature: bytes("") + }); + bytes32 userOpHash = entryPoint.getUserOpHash(userOp); + + // Get the validator domain separator + bytes32 domainSeparator = stealthAddressValidator.getDomainSeperator(); + bytes32 typedMsgHash = keccak256( + abi.encodePacked( + "\x19\x01", domainSeparator, keccak256(abi.encode(USER_OP_TYPEHASH, owner, address(kernel), userOpHash)) + ) + ); + bytes memory aggregatedSignature = getAggregatedSignature(typedMsgHash, wallet); + userOp.signature = aggregatedSignature; + + (,, address result) = parseValidationData(defaultValidator.validateUserOp(userOp, userOpHash, 0)); + assertEq(result, address(0)); + } + + function test_stealth_validate_sig_aggsig() external { + bytes32 message = bytes32(uint256(0x102030405060708090a)); + + // Get the validator domain separator + bytes32 domainSeparator = stealthAddressValidator.getDomainSeperator(); + bytes32 typedMsgHash = keccak256( + abi.encodePacked( + "\x19\x01", domainSeparator, keccak256(abi.encode(SIGNATURE_TYPEHASH, owner, address(kernel), message)) + ) + ); + bytes memory aggregatedSignature = getAggregatedSignature(typedMsgHash, wallet); + + vm.prank(address(kernel)); + (,, address result) = parseValidationData(defaultValidator.validateSignature(message, aggregatedSignature)); + assertEq(result, address(0)); + } + + /* -------------------------------------------------------------------------- */ + /* Helper methods */ + /* -------------------------------------------------------------------------- */ + + /// @notice The type hash used for kernel user op validation + bytes32 constant USER_OP_TYPEHASH = keccak256("AllowUserOp(address owner,address kernelWallet,bytes32 userOpHash)"); + + /// @dev Generate the signature for a user op + function _generateUserOpSignature(IEntryPoint _entryPoint, UserOperation memory _op, uint256 _privateKey) + internal + view + returns (bytes memory) + { + // Get the kernel private key owner address + address owner = vm.addr(_privateKey); + + // Get the user op hash + bytes32 userOpHash = _entryPoint.getUserOpHash(_op); + // Get the validator domain separator + bytes32 domainSeparator = stealthAddressValidator.getDomainSeperator(); + bytes32 typedMsgHash = keccak256( + abi.encodePacked( + "\x19\x01", domainSeparator, keccak256(abi.encode(USER_OP_TYPEHASH, owner, _op.sender, userOpHash)) + ) + ); + (uint8 v, bytes32 r, bytes32 s) = vm.sign(_privateKey, typedMsgHash); + return abi.encodePacked(r, s, v); + } + + /// @notice The type hash used for kernel signature validation + bytes32 constant SIGNATURE_TYPEHASH = keccak256("KernelSignature(address owner,address kernelWallet,bytes32 hash)"); + + /// @dev Generate the signature for a given hash for a kernel account + function _generateHashSignature(bytes32 _hash, address _kernel, uint256 _privateKey) + internal + view + returns (bytes memory) + { + // Get the kernel private key owner address + address owner = vm.addr(_privateKey); + + // Get the validator domain separator + bytes32 domainSeparator = stealthAddressValidator.getDomainSeperator(); + bytes32 typedMsgHash = keccak256( + abi.encodePacked( + "\x19\x01", domainSeparator, keccak256(abi.encode(SIGNATURE_TYPEHASH, owner, _kernel, _hash)) + ) + ); + + (uint8 v, bytes32 r, bytes32 s) = vm.sign(_privateKey, typedMsgHash); + return abi.encodePacked(bytes1(0), r, s, v); + } + + /// @notice The parameter used in the elliptic curve + uint256 GX = 0x79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798; + uint256 GY = 0x483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8; + uint256 AA = 0; + uint256 PP = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F; + uint256 N = 0xfffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141; + + /// @dev Generate stealth address + function getStealthAddress(VmSafe.Wallet memory _ownerWallet, VmSafe.Wallet memory _ephemeralWallet) + public + view + returns (StealthAddressKey memory) + { + (uint256 ephemeralPub, uint256 ephemeralPrefix) = + (_ephemeralWallet.publicKeyX, _ephemeralWallet.publicKeyY % 2 + 2); + + (uint256 sharedSecretX, uint256 sharedSecretY) = + EllipticCurve.ecMul(_ephemeralWallet.privateKey, _ownerWallet.publicKeyX, _ownerWallet.publicKeyY, AA, PP); + uint256 hashSecret = uint256(keccak256(abi.encode(sharedSecretX, sharedSecretY))); + (uint256 pubX, uint256 pubY) = EllipticCurve.ecMul(hashSecret, GX, GY, AA, PP); + uint256 stealthPrivate = _ownerWallet.privateKey + hashSecret % N; + (uint256 stealthPubX, uint256 stealthPubY) = + EllipticCurve.ecAdd(_ownerWallet.publicKeyX, _ownerWallet.publicKeyY, pubX, pubY, AA, PP); + address stealthAddress = address(uint160(uint256(keccak256(abi.encode(stealthPubX, stealthPubY))))); + (uint256 dhkx, uint256 dhky) = + EllipticCurve.ecMul(hashSecret, _ownerWallet.publicKeyX, _ownerWallet.publicKeyY, AA, PP); + return StealthAddressKey( + stealthAddress, + stealthPubX, + dhkx, + uint8(stealthPubY % 2 + 2), + uint8(dhky % 2 + 2), + ephemeralPub, + uint8(ephemeralPrefix), + hashSecret, + stealthPrivate + ); + } + + function getAggregatedSignature(bytes32 _hash, Vm.Wallet memory _wallet) internal view returns (bytes memory) { + StealthAddressKey memory stelathAddressKey = getStealthAddress(_wallet, ephemeralWallet); + (, bytes32 r, bytes32 s) = vm.sign(_wallet.privateKey, _hash); + uint256 numR = uint256(r); + uint256 numS = uint256(s); + + // aggregatedSig = numS * (stelathAddressKey.hashSecret * numR + typedMsgHash) + bytes32 aggregatedSig = bytes32( + FixedPointMathLib.rawMulMod( + FixedPointMathLib.rawAddMod( + FixedPointMathLib.rawMulMod(stelathAddressKey.hashSecret, numR, N), uint256(_hash), N + ), + numS, + N + ) + ); + + return abi.encodePacked(bytes1(uint8(1)), r, aggregatedSig); + } +}