From 68a5847281d591fe8e921d3952a455346f311eac Mon Sep 17 00:00:00 2001 From: adam Date: Wed, 5 Jun 2024 15:50:32 -0400 Subject: [PATCH 1/5] per validation hook data --- .solhint-test.json | 1 + src/account/AccountLoupe.sol | 3 +- src/account/AccountStorage.sol | 2 +- src/account/PluginManager2.sol | 32 +- src/account/UpgradeableModularAccount.sol | 114 ++++-- src/helpers/SparseCalldataSegmentLib.sol | 51 +++ src/interfaces/IValidation.sol | 1 + src/interfaces/IValidationHook.sol | 9 +- test/account/AccountReturnData.t.sol | 18 +- test/account/DefaultValidationTest.t.sol | 4 +- test/account/MultiValidation.t.sol | 36 +- test/account/PerHookData.t.sol | 361 ++++++++++++++++++ test/account/UpgradeableModularAccount.t.sol | 18 +- test/account/ValidationIntersection.t.sol | 20 +- test/libraries/SparseCalldataSegmentLib.t.sol | 111 ++++++ test/mocks/plugins/ComprehensivePlugin.sol | 6 +- .../plugins/MockAccessControlHookPlugin.sol | 78 ++++ test/mocks/plugins/ReturnDataPluginMocks.sol | 3 +- test/mocks/plugins/ValidationPluginMocks.sol | 6 +- test/utils/AccountTestBase.sol | 62 ++- 20 files changed, 836 insertions(+), 100 deletions(-) create mode 100644 src/helpers/SparseCalldataSegmentLib.sol create mode 100644 test/account/PerHookData.t.sol create mode 100644 test/libraries/SparseCalldataSegmentLib.t.sol create mode 100644 test/mocks/plugins/MockAccessControlHookPlugin.sol diff --git a/.solhint-test.json b/.solhint-test.json index cbd7bf02..fd2b1007 100644 --- a/.solhint-test.json +++ b/.solhint-test.json @@ -5,6 +5,7 @@ "immutable-vars-naming": ["error"], "no-unused-import": ["error"], "compiler-version": ["error", ">=0.8.19"], + "custom-errors": "off", "func-visibility": ["error", { "ignoreConstructors": true }], "max-line-length": ["error", 120], "max-states-count": ["warn", 30], diff --git a/src/account/AccountLoupe.sol b/src/account/AccountLoupe.sol index 9cde18b4..89ffb04b 100644 --- a/src/account/AccountLoupe.sol +++ b/src/account/AccountLoupe.sol @@ -58,8 +58,7 @@ abstract contract AccountLoupe is IAccountLoupe { override returns (FunctionReference[] memory preValidationHooks) { - preValidationHooks = - toFunctionReferenceArray(getAccountStorage().validationData[validationFunction].preValidationHooks); + preValidationHooks = getAccountStorage().validationData[validationFunction].preValidationHooks; } /// @inheritdoc IAccountLoupe diff --git a/src/account/AccountStorage.sol b/src/account/AccountStorage.sol index 0a992a33..ffdaff26 100644 --- a/src/account/AccountStorage.sol +++ b/src/account/AccountStorage.sol @@ -39,7 +39,7 @@ struct ValidationData { // Whether or not this validation is a signature validator. bool isSignatureValidation; // The pre validation hooks for this function selector. - EnumerableSet.Bytes32Set preValidationHooks; + FunctionReference[] preValidationHooks; } struct AccountStorage { diff --git a/src/account/PluginManager2.sol b/src/account/PluginManager2.sol index 9e73e306..0e860848 100644 --- a/src/account/PluginManager2.sol +++ b/src/account/PluginManager2.sol @@ -6,16 +6,20 @@ import {EnumerableSet} from "@openzeppelin/contracts/utils/structs/EnumerableSet import {IPlugin} from "../interfaces/IPlugin.sol"; import {FunctionReference} from "../interfaces/IPluginManager.sol"; import {FunctionReferenceLib} from "../helpers/FunctionReferenceLib.sol"; -import {AccountStorage, getAccountStorage, toSetValue, toFunctionReference} from "./AccountStorage.sol"; +import {AccountStorage, getAccountStorage, toSetValue} from "./AccountStorage.sol"; // Temporary additional functions for a user-controlled install flow for validation functions. abstract contract PluginManager2 { using EnumerableSet for EnumerableSet.Bytes32Set; + // Index marking the start of the data for the validation function. + uint8 internal constant _RESERVED_VALIDATION_DATA_INDEX = 255; + error DefaultValidationAlreadySet(FunctionReference validationFunction); error PreValidationAlreadySet(FunctionReference validationFunction, FunctionReference preValidationFunction); error ValidationAlreadySet(bytes4 selector, FunctionReference validationFunction); error ValidationNotSet(bytes4 selector, FunctionReference validationFunction); + error PreValidationHookLimitExceeded(); function _installValidation( FunctionReference validationFunction, @@ -36,19 +40,21 @@ abstract contract PluginManager2 { for (uint256 i = 0; i < preValidationFunctions.length; ++i) { FunctionReference preValidationFunction = preValidationFunctions[i]; - if ( - !_storage.validationData[validationFunction].preValidationHooks.add( - toSetValue(preValidationFunction) - ) - ) { - revert PreValidationAlreadySet(validationFunction, preValidationFunction); - } + _storage.validationData[validationFunction].preValidationHooks.push(preValidationFunction); if (initDatas[i].length > 0) { (address preValidationPlugin,) = FunctionReferenceLib.unpack(preValidationFunction); IPlugin(preValidationPlugin).onInstall(initDatas[i]); } } + + // Avoid collision between reserved index and actual indices + if ( + _storage.validationData[validationFunction].preValidationHooks.length + > _RESERVED_VALIDATION_DATA_INDEX + ) { + revert PreValidationHookLimitExceeded(); + } } if (isDefault) { @@ -85,16 +91,16 @@ abstract contract PluginManager2 { bytes[] memory preValidationHookUninstallDatas = abi.decode(preValidationHookUninstallData, (bytes[])); // Clear pre validation hooks - EnumerableSet.Bytes32Set storage preValidationHooks = + FunctionReference[] storage preValidationHooks = _storage.validationData[validationFunction].preValidationHooks; - while (preValidationHooks.length() > 0) { - FunctionReference preValidationFunction = toFunctionReference(preValidationHooks.at(0)); - preValidationHooks.remove(toSetValue(preValidationFunction)); - (address preValidationPlugin,) = FunctionReferenceLib.unpack(preValidationFunction); + for (uint256 i = 0; i < preValidationHooks.length; ++i) { + FunctionReference preValidationFunction = preValidationHooks[i]; if (preValidationHookUninstallDatas[0].length > 0) { + (address preValidationPlugin,) = FunctionReferenceLib.unpack(preValidationFunction); IPlugin(preValidationPlugin).onUninstall(preValidationHookUninstallDatas[0]); } } + delete _storage.validationData[validationFunction].preValidationHooks; // Because this function also calls `onUninstall`, and removes the default flag from validation, we must // assume these selectors passed in to be exhaustive. diff --git a/src/account/UpgradeableModularAccount.sol b/src/account/UpgradeableModularAccount.sol index 169b17a3..335cb762 100644 --- a/src/account/UpgradeableModularAccount.sol +++ b/src/account/UpgradeableModularAccount.sol @@ -10,6 +10,7 @@ import {IERC1271} from "@openzeppelin/contracts/interfaces/IERC1271.sol"; import {EnumerableSet} from "@openzeppelin/contracts/utils/structs/EnumerableSet.sol"; import {FunctionReferenceLib} from "../helpers/FunctionReferenceLib.sol"; +import {SparseCalldataSegmentLib} from "../helpers/SparseCalldataSegmentLib.sol"; import {_coalescePreValidation, _coalesceValidation} from "../helpers/ValidationDataHelpers.sol"; import {IPlugin, PluginManifest} from "../interfaces/IPlugin.sol"; import {IValidation} from "../interfaces/IValidation.sol"; @@ -19,14 +20,7 @@ import {FunctionReference, IPluginManager} from "../interfaces/IPluginManager.so import {IStandardExecutor, Call} from "../interfaces/IStandardExecutor.sol"; import {AccountExecutor} from "./AccountExecutor.sol"; import {AccountLoupe} from "./AccountLoupe.sol"; -import { - AccountStorage, - getAccountStorage, - SelectorData, - toSetValue, - toFunctionReference, - toExecutionHook -} from "./AccountStorage.sol"; +import {AccountStorage, getAccountStorage, SelectorData, toSetValue, toExecutionHook} from "./AccountStorage.sol"; import {AccountStorageInitializable} from "./AccountStorageInitializable.sol"; import {PluginManagerInternals} from "./PluginManagerInternals.sol"; import {PluginManager2} from "./PluginManager2.sol"; @@ -45,6 +39,7 @@ contract UpgradeableModularAccount is { using EnumerableSet for EnumerableSet.Bytes32Set; using FunctionReferenceLib for FunctionReference; + using SparseCalldataSegmentLib for bytes; struct PostExecToRun { bytes preExecHookReturnData; @@ -67,6 +62,7 @@ contract UpgradeableModularAccount is error ExecFromPluginNotPermitted(address plugin, bytes4 selector); error ExecFromPluginExternalNotPermitted(address plugin, address target, uint256 value, bytes data); error NativeTokenSpendingNotPermitted(address plugin); + error NonCanonicalEncoding(); error PostExecHookReverted(address plugin, uint8 functionId, bytes revertReason); error PreExecHookReverted(address plugin, uint8 functionId, bytes revertReason); error PreRuntimeValidationHookFailed(address plugin, uint8 functionId, bytes revertReason); @@ -77,6 +73,8 @@ contract UpgradeableModularAccount is error UnrecognizedFunction(bytes4 selector); error UserOpValidationFunctionMissing(bytes4 selector); error ValidationDoesNotApply(bytes4 selector, address plugin, uint8 functionId, bool isDefault); + error ValidationSignatureSegmentMissing(); + error SignatureSegmentOutOfOrder(); // Wraps execution of a native function with runtime validation and hooks // Used for upgradeTo, upgradeToAndCall, execute, executeBatch, installPlugin, uninstallPlugin @@ -350,38 +348,50 @@ contract UpgradeableModularAccount is _checkIfValidationApplies(selector, userOpValidationFunction, isDefaultValidation); - validationData = - _doUserOpValidation(selector, userOpValidationFunction, userOp, userOp.signature[22:], userOpHash); + validationData = _doUserOpValidation(userOpValidationFunction, userOp, userOp.signature[22:], userOpHash); } // To support gas estimation, we don't fail early when the failure is caused by a signature failure function _doUserOpValidation( - bytes4 selector, FunctionReference userOpValidationFunction, PackedUserOperation memory userOp, bytes calldata signature, bytes32 userOpHash - ) internal returns (uint256 validationData) { - userOp.signature = signature; + ) internal returns (uint256) { + // Set up the per-hook data tracking fields + bytes calldata signatureSegment; + (signatureSegment, signature) = signature.getNextSegment(); - if (userOpValidationFunction.isEmpty()) { - // If the validation function is empty, then the call cannot proceed. - revert UserOpValidationFunctionMissing(selector); - } - - uint256 currentValidationData; + uint256 validationData; // Do preUserOpValidation hooks - EnumerableSet.Bytes32Set storage preUserOpValidationHooks = + FunctionReference[] memory preUserOpValidationHooks = getAccountStorage().validationData[userOpValidationFunction].preValidationHooks; - uint256 preUserOpValidationHooksLength = preUserOpValidationHooks.length(); - for (uint256 i = 0; i < preUserOpValidationHooksLength; ++i) { - bytes32 key = preUserOpValidationHooks.at(i); - FunctionReference preUserOpValidationHook = toFunctionReference(key); + for (uint256 i = 0; i < preUserOpValidationHooks.length; ++i) { + // Load per-hook data, if any is present + // The segment index is the first byte of the signature + if (signatureSegment.getIndex() == i) { + // Use the current segment + userOp.signature = signatureSegment.getBody(); + + if (userOp.signature.length == 0) { + revert NonCanonicalEncoding(); + } + + // Load the next per-hook data segment + (signatureSegment, signature) = signature.getNextSegment(); + + if (signatureSegment.getIndex() <= i) { + revert SignatureSegmentOutOfOrder(); + } + } else { + userOp.signature = ""; + } - (address plugin, uint8 functionId) = preUserOpValidationHook.unpack(); - currentValidationData = IValidationHook(plugin).preUserOpValidationHook(functionId, userOp, userOpHash); + (address plugin, uint8 functionId) = preUserOpValidationHooks[i].unpack(); + uint256 currentValidationData = + IValidationHook(plugin).preUserOpValidationHook(functionId, userOp, userOpHash); if (uint160(currentValidationData) > 1) { // If the aggregator is not 0 or 1, it is an unexpected value @@ -392,16 +402,24 @@ contract UpgradeableModularAccount is // Run the user op validationFunction { + if (signatureSegment.getIndex() != _RESERVED_VALIDATION_DATA_INDEX) { + revert ValidationSignatureSegmentMissing(); + } + + userOp.signature = signatureSegment.getBody(); + (address plugin, uint8 functionId) = userOpValidationFunction.unpack(); - currentValidationData = IValidation(plugin).validateUserOp(functionId, userOp, userOpHash); + uint256 currentValidationData = IValidation(plugin).validateUserOp(functionId, userOp, userOpHash); - if (preUserOpValidationHooksLength != 0) { + if (preUserOpValidationHooks.length != 0) { // If we have other validation data we need to coalesce with validationData = _coalesceValidation(validationData, currentValidationData); } else { validationData = currentValidationData; } } + + return validationData; } function _doRuntimeValidation( @@ -409,18 +427,38 @@ contract UpgradeableModularAccount is bytes calldata callData, bytes calldata authorizationData ) internal { + // Set up the per-hook data tracking fields + bytes calldata authSegment; + (authSegment, authorizationData) = authorizationData.getNextSegment(); + // run all preRuntimeValidation hooks - EnumerableSet.Bytes32Set storage preRuntimeValidationHooks = + FunctionReference[] memory preRuntimeValidationHooks = getAccountStorage().validationData[runtimeValidationFunction].preValidationHooks; - uint256 preRuntimeValidationHooksLength = preRuntimeValidationHooks.length(); - for (uint256 i = 0; i < preRuntimeValidationHooksLength; ++i) { - bytes32 key = preRuntimeValidationHooks.at(i); - FunctionReference preRuntimeValidationHook = toFunctionReference(key); + for (uint256 i = 0; i < preRuntimeValidationHooks.length; ++i) { + bytes memory currentAuthData; + + if (authSegment.getIndex() == i) { + // Use the current segment + currentAuthData = authSegment.getBody(); + + if (currentAuthData.length == 0) { + revert NonCanonicalEncoding(); + } + + // Load the next per-hook data segment + (authSegment, authorizationData) = authorizationData.getNextSegment(); - (address hookPlugin, uint8 hookFunctionId) = preRuntimeValidationHook.unpack(); + if (authSegment.getIndex() <= i) { + revert SignatureSegmentOutOfOrder(); + } + } else { + currentAuthData = ""; + } + + (address hookPlugin, uint8 hookFunctionId) = preRuntimeValidationHooks[i].unpack(); try IValidationHook(hookPlugin).preRuntimeValidationHook( - hookFunctionId, msg.sender, msg.value, callData + hookFunctionId, msg.sender, msg.value, callData, currentAuthData ) // forgefmt: disable-start // solhint-disable-next-line no-empty-blocks @@ -430,9 +468,13 @@ contract UpgradeableModularAccount is } } + if (authSegment.getIndex() != _RESERVED_VALIDATION_DATA_INDEX) { + revert ValidationSignatureSegmentMissing(); + } + (address plugin, uint8 functionId) = runtimeValidationFunction.unpack(); - try IValidation(plugin).validateRuntime(functionId, msg.sender, msg.value, callData, authorizationData) + try IValidation(plugin).validateRuntime(functionId, msg.sender, msg.value, callData, authSegment.getBody()) // forgefmt: disable-start // solhint-disable-next-line no-empty-blocks {} catch (bytes memory revertReason) { diff --git a/src/helpers/SparseCalldataSegmentLib.sol b/src/helpers/SparseCalldataSegmentLib.sol new file mode 100644 index 00000000..0a6cc541 --- /dev/null +++ b/src/helpers/SparseCalldataSegmentLib.sol @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: GPL-3.0 +pragma solidity ^0.8.25; + +/// @title Sparse Calldata Segment Library +/// @notice Library for working with sparsely-packed calldata segments, identified with an index. +/// @dev The first byte of each segment is the index of the segment. +/// To prevent accidental stack-to-deep errors, the body and index of the segment are extracted separately, rather +/// than inline as part of the tuple returned by `getNextSegment`. +library SparseCalldataSegmentLib { + /// @notice Splits out a segment of calldata, sparsely-packed. + /// The expected format is: + /// [uint32(len(segment0)), segment0, uint32(len(segment1)), segment1, ... uint32(len(segmentN)), segmentN] + /// @param source The calldata to extract the segment from. + /// @return segment The extracted segment. Using the above example, this would be segment0. + /// @return remainder The remaining calldata. Using the above example, + /// this would start at uint32(len(segment1)) and continue to the end at segmentN. + function getNextSegment(bytes calldata source) + internal + pure + returns (bytes calldata segment, bytes calldata remainder) + { + // The first 4 bytes hold the length of the segment, excluding the index. + uint32 length = uint32(bytes4(source[:4])); + + // The offset of the remainder of the calldata. + uint256 remainderOffset = 4 + length; + + // The segment is the next `length` + 1 bytes, to account for the index. + // By convention, the first byte of each segment is the index of the segment. + segment = source[4:remainderOffset]; + + // The remainder is the rest of the calldata. + remainder = source[remainderOffset:]; + } + + /// @notice Extracts the index from a segment. + /// @dev The first byte of the segment is the index. + /// @param segment The segment to extract the index from + /// @return The index of the segment + function getIndex(bytes calldata segment) internal pure returns (uint8) { + return uint8(segment[0]); + } + + /// @notice Extracts the body from a segment. + /// @dev The body is the segment without the index. + /// @param segment The segment to extract the body from + /// @return The body of the segment. + function getBody(bytes calldata segment) internal pure returns (bytes calldata) { + return segment[1:]; + } +} diff --git a/src/interfaces/IValidation.sol b/src/interfaces/IValidation.sol index b3adcd3d..38c8a139 100644 --- a/src/interfaces/IValidation.sol +++ b/src/interfaces/IValidation.sol @@ -23,6 +23,7 @@ interface IValidation is IPlugin { /// @param sender The caller address. /// @param value The call value. /// @param data The calldata sent. + /// @param authorization Additional data for the validation function to use. function validateRuntime( uint8 functionId, address sender, diff --git a/src/interfaces/IValidationHook.sol b/src/interfaces/IValidationHook.sol index 8eb7a61d..8300bbb8 100644 --- a/src/interfaces/IValidationHook.sol +++ b/src/interfaces/IValidationHook.sol @@ -24,8 +24,13 @@ interface IValidationHook is IPlugin { /// @param sender The caller address. /// @param value The call value. /// @param data The calldata sent. - function preRuntimeValidationHook(uint8 functionId, address sender, uint256 value, bytes calldata data) - external; + function preRuntimeValidationHook( + uint8 functionId, + address sender, + uint256 value, + bytes calldata data, + bytes calldata authorization + ) external; // TODO: support this hook type within the account & in the manifest diff --git a/test/account/AccountReturnData.t.sol b/test/account/AccountReturnData.t.sol index 8e8f3215..fc9fd615 100644 --- a/test/account/AccountReturnData.t.sol +++ b/test/account/AccountReturnData.t.sol @@ -1,7 +1,7 @@ // SPDX-License-Identifier: UNLICENSED pragma solidity ^0.8.19; -import {FunctionReference} from "../../src/helpers/FunctionReferenceLib.sol"; +import {FunctionReference, FunctionReferenceLib} from "../../src/helpers/FunctionReferenceLib.sol"; import {Call} from "../../src/interfaces/IStandardExecutor.sol"; import {ISingleOwnerPlugin} from "../../src/plugins/owner/ISingleOwnerPlugin.sol"; @@ -59,8 +59,12 @@ contract AccountReturnDataTest is AccountTestBase { account1.execute, (address(regularResultContract), 0, abi.encodeCall(RegularResultContract.foo, ())) ), - abi.encodePacked( - singleOwnerPlugin, ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER, SELECTOR_ASSOCIATED_VALIDATION + _encodeSignature( + FunctionReferenceLib.pack( + address(singleOwnerPlugin), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER) + ), + SELECTOR_ASSOCIATED_VALIDATION, + "" ) ); @@ -85,8 +89,12 @@ contract AccountReturnDataTest is AccountTestBase { bytes memory retData = account1.executeWithAuthorization( abi.encodeCall(account1.executeBatch, (calls)), - abi.encodePacked( - singleOwnerPlugin, ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER, SELECTOR_ASSOCIATED_VALIDATION + _encodeSignature( + FunctionReferenceLib.pack( + address(singleOwnerPlugin), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER) + ), + SELECTOR_ASSOCIATED_VALIDATION, + "" ) ); diff --git a/test/account/DefaultValidationTest.t.sol b/test/account/DefaultValidationTest.t.sol index fc93060d..c2f118de 100644 --- a/test/account/DefaultValidationTest.t.sol +++ b/test/account/DefaultValidationTest.t.sol @@ -57,7 +57,7 @@ contract DefaultValidationTest is AccountTestBase { // Generate signature bytes32 userOpHash = entryPoint.getUserOpHash(userOp); (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash()); - userOp.signature = abi.encodePacked(ownerValidation, DEFAULT_VALIDATION, r, s, v); + userOp.signature = _encodeSignature(ownerValidation, DEFAULT_VALIDATION, abi.encodePacked(r, s, v)); PackedUserOperation[] memory userOps = new PackedUserOperation[](1); userOps[0] = userOp; @@ -74,7 +74,7 @@ contract DefaultValidationTest is AccountTestBase { vm.prank(owner1); account1.executeWithAuthorization( abi.encodeCall(UpgradeableModularAccount.execute, (ethRecipient, 1 wei, "")), - abi.encodePacked(ownerValidation, DEFAULT_VALIDATION) + _encodeSignature(ownerValidation, DEFAULT_VALIDATION, "") ); assertEq(ethRecipient.balance, 2 wei); diff --git a/test/account/MultiValidation.t.sol b/test/account/MultiValidation.t.sol index 9b22f5a0..e80d022c 100644 --- a/test/account/MultiValidation.t.sol +++ b/test/account/MultiValidation.t.sol @@ -67,20 +67,24 @@ contract MultiValidationTest is AccountTestBase { ); account1.executeWithAuthorization( abi.encodeCall(IStandardExecutor.execute, (address(0), 0, "")), - abi.encodePacked( - address(validator2), - uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER), - SELECTOR_ASSOCIATED_VALIDATION + _encodeSignature( + FunctionReferenceLib.pack( + address(validator2), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER) + ), + SELECTOR_ASSOCIATED_VALIDATION, + "" ) ); vm.prank(owner2); account1.executeWithAuthorization( abi.encodeCall(IStandardExecutor.execute, (address(0), 0, "")), - abi.encodePacked( - address(validator2), - uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER), - SELECTOR_ASSOCIATED_VALIDATION + _encodeSignature( + FunctionReferenceLib.pack( + address(validator2), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER) + ), + SELECTOR_ASSOCIATED_VALIDATION, + "" ) ); } @@ -105,13 +109,10 @@ contract MultiValidationTest is AccountTestBase { // Generate signature bytes32 userOpHash = entryPoint.getUserOpHash(userOp); (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner2Key, userOpHash.toEthSignedMessageHash()); - userOp.signature = abi.encodePacked( - address(validator2), + userOp.signature = _encodeSignature( + FunctionReferenceLib.pack(address(validator2), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER)), SELECTOR_ASSOCIATED_VALIDATION, - uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER), - r, - s, - v + abi.encodePacked(r, s, v) ); PackedUserOperation[] memory userOps = new PackedUserOperation[](1); @@ -123,8 +124,11 @@ contract MultiValidationTest is AccountTestBase { userOp.nonce = 1; (v, r, s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash()); - userOp.signature = - abi.encodePacked(address(validator2), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER), r, s, v); + userOp.signature = _encodeSignature( + FunctionReferenceLib.pack(address(validator2), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER)), + SELECTOR_ASSOCIATED_VALIDATION, + abi.encodePacked(r, s, v) + ); userOps[0] = userOp; vm.expectRevert(abi.encodeWithSelector(IEntryPoint.FailedOp.selector, 0, "AA24 signature error")); diff --git a/test/account/PerHookData.t.sol b/test/account/PerHookData.t.sol new file mode 100644 index 00000000..77432854 --- /dev/null +++ b/test/account/PerHookData.t.sol @@ -0,0 +1,361 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.25; + +import {PackedUserOperation} from "@eth-infinitism/account-abstraction/interfaces/PackedUserOperation.sol"; +import {IEntryPoint} from "@eth-infinitism/account-abstraction/interfaces/IEntryPoint.sol"; +import {ERC1967Proxy} from "@openzeppelin/contracts/proxy/ERC1967/ERC1967Proxy.sol"; +import {MessageHashUtils} from "@openzeppelin/contracts/utils/cryptography/MessageHashUtils.sol"; + +import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAccount.sol"; +import {ISingleOwnerPlugin} from "../../src/plugins/owner/ISingleOwnerPlugin.sol"; +import {FunctionReference, FunctionReferenceLib} from "../../src/helpers/FunctionReferenceLib.sol"; + +import {MockAccessControlHookPlugin} from "../mocks/plugins/MockAccessControlHookPlugin.sol"; +import {Counter} from "../mocks/Counter.sol"; +import {AccountTestBase} from "../utils/AccountTestBase.sol"; + +contract PerHookDataTest is AccountTestBase { + using MessageHashUtils for bytes32; + + MockAccessControlHookPlugin internal _accessControlHookPlugin; + + Counter internal _counter; + + FunctionReference internal _ownerValidation; + + uint256 public constant CALL_GAS_LIMIT = 50000; + uint256 public constant VERIFICATION_GAS_LIMIT = 1200000; + + function setUp() public { + _counter = new Counter(); + + _accessControlHookPlugin = new MockAccessControlHookPlugin(); + + // Write over `account1` with a new account proxy, with different initialization. + + address accountImplementation = address(factory.accountImplementation()); + + account1 = UpgradeableModularAccount(payable(new ERC1967Proxy(accountImplementation, ""))); + + _ownerValidation = FunctionReferenceLib.pack( + address(singleOwnerPlugin), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER) + ); + + FunctionReference accessControlHook = FunctionReferenceLib.pack( + address(_accessControlHookPlugin), uint8(MockAccessControlHookPlugin.FunctionId.PRE_VALIDATION_HOOK) + ); + + FunctionReference[] memory preValidationHooks = new FunctionReference[](1); + preValidationHooks[0] = accessControlHook; + + bytes[] memory preValidationHookData = new bytes[](1); + // Access control is restricted to only the _counter + preValidationHookData[0] = abi.encode(_counter); + + bytes memory packedPreValidationHooks = abi.encode(preValidationHooks, preValidationHookData); + + vm.prank(address(entryPoint)); + account1.installValidation( + _ownerValidation, true, new bytes4[](0), abi.encode(owner1), packedPreValidationHooks + ); + + vm.deal(address(account1), 100 ether); + } + + function test_passAccessControl_userOp() public { + assertEq(_counter.number(), 0); + + (PackedUserOperation memory userOp, bytes32 userOpHash) = _getCounterUserOP(); + + (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash()); + + PreValidationHookData[] memory preValidationHookData = new PreValidationHookData[](1); + preValidationHookData[0] = PreValidationHookData({index: 0, validationData: abi.encodePacked(_counter)}); + + userOp.signature = _encodeSignature( + _ownerValidation, DEFAULT_VALIDATION, preValidationHookData, abi.encodePacked(r, s, v) + ); + + PackedUserOperation[] memory userOps = new PackedUserOperation[](1); + userOps[0] = userOp; + + entryPoint.handleOps(userOps, beneficiary); + + assertEq(_counter.number(), 1); + } + + function test_failAccessControl_badSigData_userOp() public { + (PackedUserOperation memory userOp, bytes32 userOpHash) = _getCounterUserOP(); + + (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash()); + + PreValidationHookData[] memory preValidationHookData = new PreValidationHookData[](1); + preValidationHookData[0] = PreValidationHookData({ + index: 0, + validationData: abi.encodePacked(address(0x1234123412341234123412341234123412341234)) + }); + + userOp.signature = _encodeSignature( + _ownerValidation, DEFAULT_VALIDATION, preValidationHookData, abi.encodePacked(r, s, v) + ); + + PackedUserOperation[] memory userOps = new PackedUserOperation[](1); + userOps[0] = userOp; + + vm.expectRevert( + abi.encodeWithSelector( + IEntryPoint.FailedOpWithRevert.selector, + 0, + "AA23 reverted", + abi.encodeWithSignature("Error(string)", "Proof doesn't match target") + ) + ); + entryPoint.handleOps(userOps, beneficiary); + } + + function test_failAccessControl_noSigData_userOp() public { + (PackedUserOperation memory userOp, bytes32 userOpHash) = _getCounterUserOP(); + (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash()); + + userOp.signature = _encodeSignature(_ownerValidation, DEFAULT_VALIDATION, abi.encodePacked(r, s, v)); + + PackedUserOperation[] memory userOps = new PackedUserOperation[](1); + userOps[0] = userOp; + + vm.expectRevert( + abi.encodeWithSelector( + IEntryPoint.FailedOpWithRevert.selector, + 0, + "AA23 reverted", + abi.encodeWithSignature("Error(string)", "Proof doesn't match target") + ) + ); + entryPoint.handleOps(userOps, beneficiary); + } + + function test_failAccessControl_badIndexProvided_userOp() public { + (PackedUserOperation memory userOp, bytes32 userOpHash) = _getCounterUserOP(); + (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash()); + + PreValidationHookData[] memory preValidationHookData = new PreValidationHookData[](2); + preValidationHookData[0] = PreValidationHookData({index: 0, validationData: abi.encodePacked(_counter)}); + preValidationHookData[1] = PreValidationHookData({index: 1, validationData: abi.encodePacked(_counter)}); + + userOp.signature = _encodeSignature( + _ownerValidation, DEFAULT_VALIDATION, preValidationHookData, abi.encodePacked(r, s, v) + ); + + PackedUserOperation[] memory userOps = new PackedUserOperation[](1); + userOps[0] = userOp; + + vm.expectRevert( + abi.encodeWithSelector( + IEntryPoint.FailedOpWithRevert.selector, + 0, + "AA23 reverted", + abi.encodeWithSelector(UpgradeableModularAccount.ValidationSignatureSegmentMissing.selector) + ) + ); + entryPoint.handleOps(userOps, beneficiary); + } + + // todo: index out of order failure case with 2 pre hooks + + function test_failAccessControl_badTarget_userOp() public { + PackedUserOperation memory userOp = PackedUserOperation({ + sender: address(account1), + nonce: 0, + initCode: "", + callData: abi.encodeCall(UpgradeableModularAccount.execute, (beneficiary, 1 wei, "")), + accountGasLimits: _encodeGas(VERIFICATION_GAS_LIMIT, CALL_GAS_LIMIT), + preVerificationGas: 0, + gasFees: _encodeGas(1, 1), + paymasterAndData: "", + signature: "" + }); + + bytes32 userOpHash = entryPoint.getUserOpHash(userOp); + (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash()); + + PreValidationHookData[] memory preValidationHookData = new PreValidationHookData[](1); + preValidationHookData[0] = PreValidationHookData({index: 0, validationData: abi.encodePacked(beneficiary)}); + + userOp.signature = _encodeSignature( + _ownerValidation, DEFAULT_VALIDATION, preValidationHookData, abi.encodePacked(r, s, v) + ); + + PackedUserOperation[] memory userOps = new PackedUserOperation[](1); + userOps[0] = userOp; + + vm.expectRevert( + abi.encodeWithSelector( + IEntryPoint.FailedOpWithRevert.selector, + 0, + "AA23 reverted", + abi.encodeWithSignature("Error(string)", "Target not allowed") + ) + ); + entryPoint.handleOps(userOps, beneficiary); + } + + function test_failPerHookData_nonCanonicalEncoding_userOp() public { + (PackedUserOperation memory userOp, bytes32 userOpHash) = _getCounterUserOP(); + (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash()); + + PreValidationHookData[] memory preValidationHookData = new PreValidationHookData[](1); + preValidationHookData[0] = PreValidationHookData({index: 0, validationData: ""}); + + userOp.signature = _encodeSignature( + _ownerValidation, DEFAULT_VALIDATION, preValidationHookData, abi.encodePacked(r, s, v) + ); + + PackedUserOperation[] memory userOps = new PackedUserOperation[](1); + userOps[0] = userOp; + + vm.expectRevert( + abi.encodeWithSelector( + IEntryPoint.FailedOpWithRevert.selector, + 0, + "AA23 reverted", + abi.encodeWithSelector(UpgradeableModularAccount.NonCanonicalEncoding.selector) + ) + ); + entryPoint.handleOps(userOps, beneficiary); + } + + function test_passAccessControl_runtime() public { + assertEq(_counter.number(), 0); + + PreValidationHookData[] memory preValidationHookData = new PreValidationHookData[](1); + preValidationHookData[0] = PreValidationHookData({index: 0, validationData: abi.encodePacked(_counter)}); + + vm.prank(owner1); + account1.executeWithAuthorization( + abi.encodeCall( + UpgradeableModularAccount.execute, + (address(_counter), 0 wei, abi.encodeCall(Counter.increment, ())) + ), + _encodeSignature(_ownerValidation, DEFAULT_VALIDATION, preValidationHookData, "") + ); + + assertEq(_counter.number(), 1); + } + + function test_failAccessControl_badSigData_runtime() public { + PreValidationHookData[] memory preValidationHookData = new PreValidationHookData[](1); + preValidationHookData[0] = PreValidationHookData({ + index: 0, + validationData: abi.encodePacked(address(0x1234123412341234123412341234123412341234)) + }); + + vm.prank(owner1); + vm.expectRevert( + abi.encodeWithSelector( + UpgradeableModularAccount.PreRuntimeValidationHookFailed.selector, + _accessControlHookPlugin, + uint8(MockAccessControlHookPlugin.FunctionId.PRE_VALIDATION_HOOK), + abi.encodeWithSignature("Error(string)", "Proof doesn't match target") + ) + ); + account1.executeWithAuthorization( + abi.encodeCall( + UpgradeableModularAccount.execute, + (address(_counter), 0 wei, abi.encodeCall(Counter.increment, ())) + ), + _encodeSignature(_ownerValidation, DEFAULT_VALIDATION, preValidationHookData, "") + ); + } + + function test_failAccessControl_noSigData_runtime() public { + vm.prank(owner1); + vm.expectRevert( + abi.encodeWithSelector( + UpgradeableModularAccount.PreRuntimeValidationHookFailed.selector, + _accessControlHookPlugin, + uint8(MockAccessControlHookPlugin.FunctionId.PRE_VALIDATION_HOOK), + abi.encodeWithSignature("Error(string)", "Proof doesn't match target") + ) + ); + account1.executeWithAuthorization( + abi.encodeCall( + UpgradeableModularAccount.execute, + (address(_counter), 0 wei, abi.encodeCall(Counter.increment, ())) + ), + _encodeSignature(_ownerValidation, DEFAULT_VALIDATION, "") + ); + } + + function test_failAccessControl_badIndexProvided_runtime() public { + PreValidationHookData[] memory preValidationHookData = new PreValidationHookData[](2); + preValidationHookData[0] = PreValidationHookData({index: 0, validationData: abi.encodePacked(_counter)}); + preValidationHookData[1] = PreValidationHookData({index: 1, validationData: abi.encodePacked(_counter)}); + + vm.prank(owner1); + vm.expectRevert( + abi.encodeWithSelector(UpgradeableModularAccount.ValidationSignatureSegmentMissing.selector) + ); + account1.executeWithAuthorization( + abi.encodeCall( + UpgradeableModularAccount.execute, + (address(_counter), 0 wei, abi.encodeCall(Counter.increment, ())) + ), + _encodeSignature(_ownerValidation, DEFAULT_VALIDATION, preValidationHookData, "") + ); + } + + //todo: index out of order failure case with 2 pre hooks + + function test_failAccessControl_badTarget_runtime() public { + PreValidationHookData[] memory preValidationHookData = new PreValidationHookData[](1); + preValidationHookData[0] = PreValidationHookData({index: 0, validationData: abi.encodePacked(beneficiary)}); + + vm.prank(owner1); + vm.expectRevert( + abi.encodeWithSelector( + UpgradeableModularAccount.PreRuntimeValidationHookFailed.selector, + _accessControlHookPlugin, + uint8(MockAccessControlHookPlugin.FunctionId.PRE_VALIDATION_HOOK), + abi.encodeWithSignature("Error(string)", "Target not allowed") + ) + ); + account1.executeWithAuthorization( + abi.encodeCall(UpgradeableModularAccount.execute, (beneficiary, 1 wei, "")), + _encodeSignature(_ownerValidation, DEFAULT_VALIDATION, preValidationHookData, "") + ); + } + + function test_failPerHookData_nonCanonicalEncoding_runtime() public { + PreValidationHookData[] memory preValidationHookData = new PreValidationHookData[](1); + preValidationHookData[0] = PreValidationHookData({index: 0, validationData: ""}); + + vm.prank(owner1); + vm.expectRevert(abi.encodeWithSelector(UpgradeableModularAccount.NonCanonicalEncoding.selector)); + account1.executeWithAuthorization( + abi.encodeCall( + UpgradeableModularAccount.execute, + (address(_counter), 0 wei, abi.encodeCall(Counter.increment, ())) + ), + _encodeSignature(_ownerValidation, DEFAULT_VALIDATION, preValidationHookData, "") + ); + } + + function _getCounterUserOP() internal view returns (PackedUserOperation memory, bytes32) { + PackedUserOperation memory userOp = PackedUserOperation({ + sender: address(account1), + nonce: 0, + initCode: "", + callData: abi.encodeCall( + UpgradeableModularAccount.execute, (address(_counter), 0 wei, abi.encodeCall(Counter.increment, ())) + ), + accountGasLimits: _encodeGas(VERIFICATION_GAS_LIMIT, CALL_GAS_LIMIT), + preVerificationGas: 0, + gasFees: _encodeGas(1, 1), + paymasterAndData: "", + signature: "" + }); + + bytes32 userOpHash = entryPoint.getUserOpHash(userOp); + + return (userOp, userOpHash); + } +} diff --git a/test/account/UpgradeableModularAccount.t.sol b/test/account/UpgradeableModularAccount.t.sol index e3d09e7f..2484b933 100644 --- a/test/account/UpgradeableModularAccount.t.sol +++ b/test/account/UpgradeableModularAccount.t.sol @@ -87,7 +87,8 @@ contract UpgradeableModularAccountTest is AccountTestBase { // Generate signature bytes32 userOpHash = entryPoint.getUserOpHash(userOp); (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash()); - userOp.signature = abi.encodePacked(ownerValidation, SELECTOR_ASSOCIATED_VALIDATION, r, s, v); + userOp.signature = + _encodeSignature(ownerValidation, SELECTOR_ASSOCIATED_VALIDATION, abi.encodePacked(r, s, v)); PackedUserOperation[] memory userOps = new PackedUserOperation[](1); userOps[0] = userOp; @@ -116,7 +117,8 @@ contract UpgradeableModularAccountTest is AccountTestBase { // Generate signature bytes32 userOpHash = entryPoint.getUserOpHash(userOp); (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner2Key, userOpHash.toEthSignedMessageHash()); - userOp.signature = abi.encodePacked(ownerValidation, SELECTOR_ASSOCIATED_VALIDATION, r, s, v); + userOp.signature = + _encodeSignature(ownerValidation, SELECTOR_ASSOCIATED_VALIDATION, abi.encodePacked(r, s, v)); PackedUserOperation[] memory userOps = new PackedUserOperation[](1); userOps[0] = userOp; @@ -142,7 +144,8 @@ contract UpgradeableModularAccountTest is AccountTestBase { // Generate signature bytes32 userOpHash = entryPoint.getUserOpHash(userOp); (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner2Key, userOpHash.toEthSignedMessageHash()); - userOp.signature = abi.encodePacked(ownerValidation, SELECTOR_ASSOCIATED_VALIDATION, r, s, v); + userOp.signature = + _encodeSignature(ownerValidation, SELECTOR_ASSOCIATED_VALIDATION, abi.encodePacked(r, s, v)); PackedUserOperation[] memory userOps = new PackedUserOperation[](1); userOps[0] = userOp; @@ -168,7 +171,8 @@ contract UpgradeableModularAccountTest is AccountTestBase { // Generate signature bytes32 userOpHash = entryPoint.getUserOpHash(userOp); (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash()); - userOp.signature = abi.encodePacked(ownerValidation, SELECTOR_ASSOCIATED_VALIDATION, r, s, v); + userOp.signature = + _encodeSignature(ownerValidation, SELECTOR_ASSOCIATED_VALIDATION, abi.encodePacked(r, s, v)); PackedUserOperation[] memory userOps = new PackedUserOperation[](1); userOps[0] = userOp; @@ -196,7 +200,8 @@ contract UpgradeableModularAccountTest is AccountTestBase { // Generate signature bytes32 userOpHash = entryPoint.getUserOpHash(userOp); (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash()); - userOp.signature = abi.encodePacked(ownerValidation, SELECTOR_ASSOCIATED_VALIDATION, r, s, v); + userOp.signature = + _encodeSignature(ownerValidation, SELECTOR_ASSOCIATED_VALIDATION, abi.encodePacked(r, s, v)); PackedUserOperation[] memory userOps = new PackedUserOperation[](1); userOps[0] = userOp; @@ -227,7 +232,8 @@ contract UpgradeableModularAccountTest is AccountTestBase { // Generate signature bytes32 userOpHash = entryPoint.getUserOpHash(userOp); (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash()); - userOp.signature = abi.encodePacked(ownerValidation, SELECTOR_ASSOCIATED_VALIDATION, r, s, v); + userOp.signature = + _encodeSignature(ownerValidation, SELECTOR_ASSOCIATED_VALIDATION, abi.encodePacked(r, s, v)); PackedUserOperation[] memory userOps = new PackedUserOperation[](1); userOps[0] = userOp; diff --git a/test/account/ValidationIntersection.t.sol b/test/account/ValidationIntersection.t.sol index 877f8ff9..07c91948 100644 --- a/test/account/ValidationIntersection.t.sol +++ b/test/account/ValidationIntersection.t.sol @@ -97,7 +97,7 @@ contract ValidationIntersectionTest is AccountTestBase { PackedUserOperation memory userOp; userOp.callData = bytes.concat(noHookPlugin.foo.selector); - userOp.signature = abi.encodePacked(noHookValidation, SELECTOR_ASSOCIATED_VALIDATION); + userOp.signature = _encodeSignature(noHookValidation, SELECTOR_ASSOCIATED_VALIDATION, ""); bytes32 uoHash = entryPoint.getUserOpHash(userOp); vm.prank(address(entryPoint)); @@ -114,7 +114,7 @@ contract ValidationIntersectionTest is AccountTestBase { PackedUserOperation memory userOp; userOp.callData = bytes.concat(oneHookPlugin.bar.selector); - userOp.signature = abi.encodePacked(oneHookValidation, SELECTOR_ASSOCIATED_VALIDATION); + userOp.signature = _encodeSignature(oneHookValidation, SELECTOR_ASSOCIATED_VALIDATION, ""); bytes32 uoHash = entryPoint.getUserOpHash(userOp); vm.prank(address(entryPoint)); @@ -132,7 +132,7 @@ contract ValidationIntersectionTest is AccountTestBase { PackedUserOperation memory userOp; userOp.callData = bytes.concat(oneHookPlugin.bar.selector); - userOp.signature = abi.encodePacked(oneHookValidation, SELECTOR_ASSOCIATED_VALIDATION); + userOp.signature = _encodeSignature(oneHookValidation, SELECTOR_ASSOCIATED_VALIDATION, ""); bytes32 uoHash = entryPoint.getUserOpHash(userOp); vm.prank(address(entryPoint)); @@ -155,7 +155,7 @@ contract ValidationIntersectionTest is AccountTestBase { PackedUserOperation memory userOp; userOp.callData = bytes.concat(oneHookPlugin.bar.selector); - userOp.signature = abi.encodePacked(oneHookValidation, SELECTOR_ASSOCIATED_VALIDATION); + userOp.signature = _encodeSignature(oneHookValidation, SELECTOR_ASSOCIATED_VALIDATION, ""); bytes32 uoHash = entryPoint.getUserOpHash(userOp); vm.prank(address(entryPoint)); @@ -177,7 +177,7 @@ contract ValidationIntersectionTest is AccountTestBase { PackedUserOperation memory userOp; userOp.callData = bytes.concat(oneHookPlugin.bar.selector); - userOp.signature = abi.encodePacked(oneHookValidation, SELECTOR_ASSOCIATED_VALIDATION); + userOp.signature = _encodeSignature(oneHookValidation, SELECTOR_ASSOCIATED_VALIDATION, ""); bytes32 uoHash = entryPoint.getUserOpHash(userOp); vm.prank(address(entryPoint)); @@ -197,7 +197,7 @@ contract ValidationIntersectionTest is AccountTestBase { PackedUserOperation memory userOp; userOp.callData = bytes.concat(oneHookPlugin.bar.selector); - userOp.signature = abi.encodePacked(oneHookValidation, SELECTOR_ASSOCIATED_VALIDATION); + userOp.signature = _encodeSignature(oneHookValidation, SELECTOR_ASSOCIATED_VALIDATION, ""); bytes32 uoHash = entryPoint.getUserOpHash(userOp); vm.prank(address(entryPoint)); @@ -222,7 +222,7 @@ contract ValidationIntersectionTest is AccountTestBase { PackedUserOperation memory userOp; userOp.callData = bytes.concat(oneHookPlugin.bar.selector); - userOp.signature = abi.encodePacked(oneHookValidation, SELECTOR_ASSOCIATED_VALIDATION); + userOp.signature = _encodeSignature(oneHookValidation, SELECTOR_ASSOCIATED_VALIDATION, ""); bytes32 uoHash = entryPoint.getUserOpHash(userOp); vm.prank(address(entryPoint)); @@ -246,7 +246,7 @@ contract ValidationIntersectionTest is AccountTestBase { PackedUserOperation memory userOp; userOp.callData = bytes.concat(oneHookPlugin.bar.selector); - userOp.signature = abi.encodePacked(oneHookValidation, SELECTOR_ASSOCIATED_VALIDATION); + userOp.signature = _encodeSignature(oneHookValidation, SELECTOR_ASSOCIATED_VALIDATION, ""); bytes32 uoHash = entryPoint.getUserOpHash(userOp); vm.prank(address(entryPoint)); @@ -270,7 +270,7 @@ contract ValidationIntersectionTest is AccountTestBase { PackedUserOperation memory userOp; userOp.callData = bytes.concat(twoHookPlugin.baz.selector); - userOp.signature = abi.encodePacked(twoHookValidation, SELECTOR_ASSOCIATED_VALIDATION); + userOp.signature = _encodeSignature(twoHookValidation, SELECTOR_ASSOCIATED_VALIDATION, ""); bytes32 uoHash = entryPoint.getUserOpHash(userOp); vm.prank(address(entryPoint)); @@ -289,7 +289,7 @@ contract ValidationIntersectionTest is AccountTestBase { PackedUserOperation memory userOp; userOp.callData = bytes.concat(twoHookPlugin.baz.selector); - userOp.signature = abi.encodePacked(twoHookValidation, SELECTOR_ASSOCIATED_VALIDATION); + userOp.signature = _encodeSignature(twoHookValidation, SELECTOR_ASSOCIATED_VALIDATION, ""); bytes32 uoHash = entryPoint.getUserOpHash(userOp); vm.prank(address(entryPoint)); diff --git a/test/libraries/SparseCalldataSegmentLib.t.sol b/test/libraries/SparseCalldataSegmentLib.t.sol new file mode 100644 index 00000000..7edea4e4 --- /dev/null +++ b/test/libraries/SparseCalldataSegmentLib.t.sol @@ -0,0 +1,111 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.25; + +import {Test} from "forge-std/Test.sol"; + +import {SparseCalldataSegmentLib} from "../../src/helpers/SparseCalldataSegmentLib.sol"; + +contract SparseCalldataSegmentLibTest is Test { + using SparseCalldataSegmentLib for bytes; + + function testFuzz_sparseCalldataSegmentLib_encodeDecode_simple(bytes[] memory segments) public { + bytes memory encoded = _encodeSimple(segments); + bytes[] memory decoded = this.decodeSimple(encoded, segments.length); + + assertEq(decoded.length, segments.length, "decoded.length != segments.length"); + + for (uint256 i = 0; i < segments.length; i++) { + assertEq(decoded[i], segments[i]); + } + } + + function testFuzz_sparseCalldataSegmentLib_encodeDecode_withIndex(bytes[] memory segments, uint256 indexSeed) + public + { + // Generate random indices + uint8[] memory indices = new uint8[](segments.length); + for (uint256 i = 0; i < segments.length; i++) { + uint8 nextIndex = uint8(uint256(keccak256(abi.encodePacked(indexSeed, i)))); + indices[i] = nextIndex; + } + + // Encode + bytes memory encoded = _encodeWithIndex(segments, indices); + + // Decode + (bytes[] memory decodedBodies, uint8[] memory decodedIndices) = + this.decodeWithIndex(encoded, segments.length); + + assertEq(decodedBodies.length, segments.length, "decodedBodies.length != segments.length"); + assertEq(decodedIndices.length, segments.length, "decodedIndices.length != segments.length"); + + for (uint256 i = 0; i < segments.length; i++) { + assertEq(decodedBodies[i], segments[i]); + assertEq(decodedIndices[i], indices[i]); + } + } + + function _encodeSimple(bytes[] memory segments) internal pure returns (bytes memory) { + bytes memory result = ""; + + for (uint256 i = 0; i < segments.length; i++) { + result = abi.encodePacked(result, uint32(segments[i].length), segments[i]); + } + + return result; + } + + function _encodeWithIndex(bytes[] memory segments, uint8[] memory indices) + internal + pure + returns (bytes memory) + { + require(segments.length == indices.length, "segments len != indices len"); + + bytes memory result = ""; + + for (uint256 i = 0; i < segments.length; i++) { + result = abi.encodePacked(result, uint32(segments[i].length + 1), indices[i], segments[i]); + } + + return result; + } + + function decodeSimple(bytes calldata encoded, uint256 capacityHint) external pure returns (bytes[] memory) { + bytes[] memory result = new bytes[](capacityHint); + + bytes calldata remainder = encoded; + + uint256 index = 0; + while (remainder.length > 0) { + bytes calldata segment; + (segment, remainder) = remainder.getNextSegment(); + result[index] = segment; + index++; + } + + return result; + } + + function decodeWithIndex(bytes calldata encoded, uint256 capacityHint) + external + pure + returns (bytes[] memory, uint8[] memory) + { + bytes[] memory bodies = new bytes[](capacityHint); + uint8[] memory indices = new uint8[](capacityHint); + + bytes calldata remainder = encoded; + + uint256 index = 0; + while (remainder.length > 0) { + bytes calldata segment; + (segment, remainder) = remainder.getNextSegment(); + bodies[index] = segment.getBody(); + indices[index] = segment.getIndex(); + index++; + } + + return (bodies, indices); + } +} diff --git a/test/mocks/plugins/ComprehensivePlugin.sol b/test/mocks/plugins/ComprehensivePlugin.sol index 6ef654c7..4062218b 100644 --- a/test/mocks/plugins/ComprehensivePlugin.sol +++ b/test/mocks/plugins/ComprehensivePlugin.sol @@ -74,7 +74,11 @@ contract ComprehensivePlugin is IValidation, IValidationHook, IExecutionHook, Ba revert NotImplemented(); } - function preRuntimeValidationHook(uint8 functionId, address, uint256, bytes calldata) external pure override { + function preRuntimeValidationHook(uint8 functionId, address, uint256, bytes calldata, bytes calldata) + external + pure + override + { if (functionId == uint8(FunctionId.PRE_VALIDATION_HOOK_1)) { return; } else if (functionId == uint8(FunctionId.PRE_VALIDATION_HOOK_2)) { diff --git a/test/mocks/plugins/MockAccessControlHookPlugin.sol b/test/mocks/plugins/MockAccessControlHookPlugin.sol new file mode 100644 index 00000000..c17868a8 --- /dev/null +++ b/test/mocks/plugins/MockAccessControlHookPlugin.sol @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.25; + +import {PackedUserOperation} from "@eth-infinitism/account-abstraction/interfaces/PackedUserOperation.sol"; + +import {PluginMetadata, PluginManifest} from "../../../src/interfaces/IPlugin.sol"; +import {IValidationHook} from "../../../src/interfaces/IValidationHook.sol"; +import {IStandardExecutor} from "../../../src/interfaces/IStandardExecutor.sol"; +import {BasePlugin} from "../../../src/plugins/BasePlugin.sol"; + +// A pre validaiton hook plugin that uses per-hook data. +// This example enforces that the target of an `execute` call must only be the previously specified address. +// This is just a mock - it does not enforce this over `executeBatch` and other methods of making calls, and should +// not be used in production.. +contract MockAccessControlHookPlugin is IValidationHook, BasePlugin { + enum FunctionId { + PRE_VALIDATION_HOOK + } + + mapping(address account => address allowedTarget) public allowedTargets; + + function onInstall(bytes calldata data) external override { + address allowedTarget = abi.decode(data, (address)); + allowedTargets[msg.sender] = allowedTarget; + } + + function onUninstall(bytes calldata) external override { + delete allowedTargets[msg.sender]; + } + + function preUserOpValidationHook(uint8 functionId, PackedUserOperation calldata userOp, bytes32) + external + view + override + returns (uint256) + { + if (functionId == uint8(FunctionId.PRE_VALIDATION_HOOK)) { + if (bytes4(userOp.callData[:4]) == IStandardExecutor.execute.selector) { + address target = abi.decode(userOp.callData[4:36], (address)); + + // Simulate a merkle proof - require that the target address is also provided in the signature + address proof = address(bytes20(userOp.signature)); + require(proof == target, "Proof doesn't match target"); + require(target == allowedTargets[msg.sender], "Target not allowed"); + return 0; + } + } + revert NotImplemented(); + } + + function preRuntimeValidationHook( + uint8 functionId, + address, + uint256, + bytes calldata data, + bytes calldata authorization + ) external view override { + if (functionId == uint8(FunctionId.PRE_VALIDATION_HOOK)) { + if (bytes4(data[:4]) == IStandardExecutor.execute.selector) { + address target = abi.decode(data[4:36], (address)); + + // Simulate a merkle proof - require that the target address is also provided in the authorization + // data + address proof = address(bytes20(authorization)); + require(proof == target, "Proof doesn't match target"); + require(target == allowedTargets[msg.sender], "Target not allowed"); + + return; + } + } + + revert NotImplemented(); + } + + function pluginMetadata() external pure override returns (PluginMetadata memory) {} + + function pluginManifest() external pure override returns (PluginManifest memory) {} +} diff --git a/test/mocks/plugins/ReturnDataPluginMocks.sol b/test/mocks/plugins/ReturnDataPluginMocks.sol index dae2c8e4..031ca68d 100644 --- a/test/mocks/plugins/ReturnDataPluginMocks.sol +++ b/test/mocks/plugins/ReturnDataPluginMocks.sol @@ -101,7 +101,8 @@ contract ResultConsumerPlugin is BasePlugin, IValidation { // This result should be allowed based on the manifest permission request bytes memory returnData = IStandardExecutor(msg.sender).executeWithAuthorization( abi.encodeCall(IStandardExecutor.execute, (target, 0, abi.encodeCall(RegularResultContract.foo, ()))), - abi.encodePacked(this, uint8(0), uint8(0)) // Validation function of self, selector-associated + abi.encodePacked(this, uint8(0), uint8(0), uint32(1), uint8(255)) // Validation function of self, + // selector-associated, with no auth data ); bytes32 actual = abi.decode(abi.decode(returnData, (bytes)), (bytes32)); diff --git a/test/mocks/plugins/ValidationPluginMocks.sol b/test/mocks/plugins/ValidationPluginMocks.sol index d5f75e99..f6ed4a5f 100644 --- a/test/mocks/plugins/ValidationPluginMocks.sol +++ b/test/mocks/plugins/ValidationPluginMocks.sol @@ -67,7 +67,11 @@ abstract contract MockBaseUserOpValidationPlugin is IValidation, IValidationHook // Empty stubs function pluginMetadata() external pure override returns (PluginMetadata memory) {} - function preRuntimeValidationHook(uint8, address, uint256, bytes calldata) external pure override { + function preRuntimeValidationHook(uint8, address, uint256, bytes calldata, bytes calldata) + external + pure + override + { revert NotImplemented(); } diff --git a/test/utils/AccountTestBase.sol b/test/utils/AccountTestBase.sol index 059e9cac..736b6041 100644 --- a/test/utils/AccountTestBase.sol +++ b/test/utils/AccountTestBase.sol @@ -3,6 +3,7 @@ pragma solidity ^0.8.19; import {EntryPoint} from "@eth-infinitism/account-abstraction/core/EntryPoint.sol"; +import {FunctionReference, FunctionReferenceLib} from "../../src/helpers/FunctionReferenceLib.sol"; import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAccount.sol"; import {ISingleOwnerPlugin} from "../../src/plugins/owner/ISingleOwnerPlugin.sol"; import {SingleOwnerPlugin} from "../../src/plugins/owner/SingleOwnerPlugin.sol"; @@ -14,6 +15,8 @@ import {MSCAFactoryFixture} from "../mocks/MSCAFactoryFixture.sol"; /// @dev This contract handles common boilerplate setup for tests using UpgradeableModularAccount with /// SingleOwnerPlugin. abstract contract AccountTestBase is OptimizedTest { + using FunctionReferenceLib for FunctionReference; + EntryPoint public entryPoint; address payable public beneficiary; SingleOwnerPlugin public singleOwnerPlugin; @@ -26,6 +29,11 @@ abstract contract AccountTestBase is OptimizedTest { uint8 public constant SELECTOR_ASSOCIATED_VALIDATION = 0; uint8 public constant DEFAULT_VALIDATION = 1; + struct PreValidationHookData { + uint8 index; + bytes validationData; + } + constructor() { entryPoint = new EntryPoint(); (owner1, owner1Key) = makeAddrAndKey("owner1"); @@ -50,10 +58,12 @@ abstract contract AccountTestBase is OptimizedTest { abi.encodeCall(SingleOwnerPlugin.transferOwnership, (address(this))) ) ), - abi.encodePacked( - address(singleOwnerPlugin), - ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER, - SELECTOR_ASSOCIATED_VALIDATION + _encodeSignature( + FunctionReferenceLib.pack( + address(singleOwnerPlugin), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER) + ), + SELECTOR_ASSOCIATED_VALIDATION, + "" ) ); } @@ -62,4 +72,48 @@ abstract contract AccountTestBase is OptimizedTest { function _encodeGas(uint256 g1, uint256 g2) internal pure returns (bytes32) { return bytes32(uint256((g1 << 128) + uint128(g2))); } + + // helper function to encode a signature, according to the per-hook and per-validation data format. + function _encodeSignature( + FunctionReference validationFunction, + uint8 defaultOrNot, + PreValidationHookData[] memory preValidationHookData, + bytes memory validationData + ) internal pure returns (bytes memory) { + bytes memory sig = abi.encodePacked(validationFunction, defaultOrNot); + + for (uint256 i = 0; i < preValidationHookData.length; ++i) { + sig = abi.encodePacked( + sig, + _packValidationDataWithIndex( + preValidationHookData[i].index, preValidationHookData[i].validationData + ) + ); + } + + // Index of the actual validation data is the length of the preValidationHooksRetrieved - aka + // one-past-the-end + sig = abi.encodePacked(sig, _packValidationDataWithIndex(255, validationData)); + + return sig; + } + + // overload for the case where there are no pre-validation hooks + function _encodeSignature( + FunctionReference validationFunction, + uint8 defaultOrNot, + bytes memory validationData + ) internal pure returns (bytes memory) { + PreValidationHookData[] memory emptyPreValidationHookData = new PreValidationHookData[](0); + return _encodeSignature(validationFunction, defaultOrNot, emptyPreValidationHookData, validationData); + } + + // helper function to pack validation data with an index, according to the sparse calldata segment spec. + function _packValidationDataWithIndex(uint8 index, bytes memory validationData) + internal + pure + returns (bytes memory) + { + return abi.encodePacked(uint32(validationData.length + 1), index, validationData); + } } From 78ae1eb28f7b506035faebe5fa2adf18641e02b4 Mon Sep 17 00:00:00 2001 From: adam Date: Tue, 11 Jun 2024 14:32:35 -0400 Subject: [PATCH 2/5] Add Allowlist sample plugin, refactor test base --- src/account/UpgradeableModularAccount.sol | 13 +- .../permissionhooks/AllowlistPlugin.sol | 142 ++++++++ test/account/AccountLoupe.t.sol | 13 +- test/account/DefaultValidationTest.t.sol | 15 +- test/account/MultiValidation.t.sol | 3 - test/account/PerHookData.t.sol | 65 ++-- test/account/UpgradeableModularAccount.t.sol | 23 +- .../mocks/DefaultValidationFactoryFixture.sol | 7 +- test/samples/AllowlistPlugin.t.sol | 318 ++++++++++++++++++ test/utils/AccountTestBase.sol | 134 ++++++++ test/utils/CustomValidationTestBase.sol | 44 +++ 11 files changed, 690 insertions(+), 87 deletions(-) create mode 100644 src/samples/permissionhooks/AllowlistPlugin.sol create mode 100644 test/samples/AllowlistPlugin.t.sol create mode 100644 test/utils/CustomValidationTestBase.sol diff --git a/src/account/UpgradeableModularAccount.sol b/src/account/UpgradeableModularAccount.sol index 335cb762..a642aec7 100644 --- a/src/account/UpgradeableModularAccount.sol +++ b/src/account/UpgradeableModularAccount.sol @@ -245,11 +245,14 @@ contract UpgradeableModularAccount is /// with user install configs. /// @dev This function is only callable once, and only by the EntryPoint. - function initializeDefaultValidation(FunctionReference validationFunction, bytes calldata installData) - external - initializer - { - _installValidation(validationFunction, true, new bytes4[](0), installData, bytes("")); + function initializeWithValidation( + FunctionReference validationFunction, + bool shared, + bytes4[] memory selectors, + bytes calldata installData, + bytes calldata preValidationHooks + ) external initializer { + _installValidation(validationFunction, shared, selectors, installData, preValidationHooks); emit ModularAccountInitialized(_ENTRY_POINT); } diff --git a/src/samples/permissionhooks/AllowlistPlugin.sol b/src/samples/permissionhooks/AllowlistPlugin.sol new file mode 100644 index 00000000..209d8370 --- /dev/null +++ b/src/samples/permissionhooks/AllowlistPlugin.sol @@ -0,0 +1,142 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.25; + +import {PackedUserOperation} from "@eth-infinitism/account-abstraction/interfaces/PackedUserOperation.sol"; + +import {PluginMetadata, PluginManifest} from "../../interfaces/IPlugin.sol"; +import {IValidationHook} from "../../interfaces/IValidationHook.sol"; +import {IStandardExecutor, Call} from "../../interfaces/IStandardExecutor.sol"; +import {BasePlugin} from "../../plugins/BasePlugin.sol"; + +contract AllowlistPlugin is IValidationHook, BasePlugin { + enum FunctionId { + PRE_VALIDATION_HOOK + } + + struct AllowlistInit { + address target; + bool hasSelectorAllowlist; + bytes4[] selectors; + } + + struct AllowlistEntry { + bool allowed; + bool hasSelectorAllowlist; + } + + mapping(address target => mapping(address account => AllowlistEntry)) public targetAllowlist; + mapping(address target => mapping(bytes4 selector => mapping(address account => bool))) public + selectorAllowlist; + + error TargetNotAllowed(); + error SelectorNotAllowed(); + error NoSelectorSpecified(); + + function onInstall(bytes calldata data) external override { + AllowlistInit[] memory init = abi.decode(data, (AllowlistInit[])); + + for (uint256 i = 0; i < init.length; i++) { + targetAllowlist[init[i].target][msg.sender] = AllowlistEntry(true, init[i].hasSelectorAllowlist); + + if (init[i].hasSelectorAllowlist) { + for (uint256 j = 0; j < init[i].selectors.length; j++) { + selectorAllowlist[init[i].target][init[i].selectors[j]][msg.sender] = true; + } + } + } + } + + function onUninstall(bytes calldata data) external override { + AllowlistInit[] memory init = abi.decode(data, (AllowlistInit[])); + + for (uint256 i = 0; i < init.length; i++) { + delete targetAllowlist[init[i].target][msg.sender]; + + if (init[i].hasSelectorAllowlist) { + for (uint256 j = 0; j < init[i].selectors.length; j++) { + delete selectorAllowlist[init[i].target][init[i].selectors[j]][msg.sender]; + } + } + } + } + + function setAllowlistTarget(address target, bool allowed, bool hasSelectorAllowlist) external { + targetAllowlist[target][msg.sender] = AllowlistEntry(allowed, hasSelectorAllowlist); + } + + function setAllowlistSelector(address target, bytes4 selector, bool allowed) external { + selectorAllowlist[target][selector][msg.sender] = allowed; + } + + function preUserOpValidationHook(uint8 functionId, PackedUserOperation calldata userOp, bytes32) + external + view + override + returns (uint256) + { + if (functionId == uint8(FunctionId.PRE_VALIDATION_HOOK)) { + _checkAllowlistCalldata(userOp.callData); + return 0; + } + revert NotImplemented(); + } + + function preRuntimeValidationHook(uint8 functionId, address, uint256, bytes calldata data, bytes calldata) + external + view + override + { + if (functionId == uint8(FunctionId.PRE_VALIDATION_HOOK)) { + _checkAllowlistCalldata(data); + return; + } + + revert NotImplemented(); + } + + function pluginMetadata() external pure override returns (PluginMetadata memory) { + PluginMetadata memory metadata; + metadata.name = "Allowlist Plugin"; + metadata.version = "v0.0.1"; + metadata.author = "ERC-6900 Working Group"; + + return metadata; + } + + // solhint-disable-next-line no-empty-blocks + function pluginManifest() external pure override returns (PluginManifest memory) {} + + function _checkAllowlistCalldata(bytes calldata callData) internal view { + if (bytes4(callData[:4]) == IStandardExecutor.execute.selector) { + (address target,, bytes memory data) = abi.decode(callData[4:], (address, uint256, bytes)); + _checkCallPermission(msg.sender, target, data); + } else if (bytes4(callData[:4]) == IStandardExecutor.executeBatch.selector) { + Call[] memory calls = abi.decode(callData[4:], (Call[])); + + for (uint256 i = 0; i < calls.length; i++) { + _checkCallPermission(msg.sender, calls[i].target, calls[i].data); + } + } + } + + function _checkCallPermission(address account, address target, bytes memory data) internal view { + AllowlistEntry storage entry = targetAllowlist[target][account]; + (bool allowed, bool hasSelectorAllowlist) = (entry.allowed, entry.hasSelectorAllowlist); + + if (!allowed) { + revert TargetNotAllowed(); + } + + if (hasSelectorAllowlist) { + if (data.length < 4) { + revert NoSelectorSpecified(); + } + + bytes4 selector = bytes4(data); + + if (!selectorAllowlist[target][selector][account]) { + revert SelectorNotAllowed(); + } + } + } +} diff --git a/test/account/AccountLoupe.t.sol b/test/account/AccountLoupe.t.sol index fa92ab00..c16ed1c6 100644 --- a/test/account/AccountLoupe.t.sol +++ b/test/account/AccountLoupe.t.sol @@ -7,7 +7,6 @@ import {FunctionReference, FunctionReferenceLib} from "../../src/helpers/Functio import {ExecutionHook} from "../../src/interfaces/IAccountLoupe.sol"; import {IPluginManager} from "../../src/interfaces/IPluginManager.sol"; import {IStandardExecutor} from "../../src/interfaces/IStandardExecutor.sol"; -import {ISingleOwnerPlugin} from "../../src/plugins/owner/ISingleOwnerPlugin.sol"; import {ComprehensivePlugin} from "../mocks/plugins/ComprehensivePlugin.sol"; import {AccountTestBase} from "../utils/AccountTestBase.sol"; @@ -15,8 +14,6 @@ import {AccountTestBase} from "../utils/AccountTestBase.sol"; contract AccountLoupeTest is AccountTestBase { ComprehensivePlugin public comprehensivePlugin; - FunctionReference public ownerValidation; - event ReceivedCall(bytes msgData, uint256 msgValue); function setUp() public { @@ -28,10 +25,6 @@ contract AccountLoupeTest is AccountTestBase { vm.prank(address(entryPoint)); account1.installPlugin(address(comprehensivePlugin), manifestHash, "", new FunctionReference[](0)); - ownerValidation = FunctionReferenceLib.pack( - address(singleOwnerPlugin), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER) - ); - FunctionReference[] memory preValidationHooks = new FunctionReference[](2); preValidationHooks[0] = FunctionReferenceLib.pack( address(comprehensivePlugin), uint8(ComprehensivePlugin.FunctionId.PRE_VALIDATION_HOOK_1) @@ -43,7 +36,7 @@ contract AccountLoupeTest is AccountTestBase { bytes[] memory installDatas = new bytes[](2); vm.prank(address(entryPoint)); account1.installValidation( - ownerValidation, true, new bytes4[](0), bytes(""), abi.encode(preValidationHooks, installDatas) + _ownerValidation, true, new bytes4[](0), bytes(""), abi.encode(preValidationHooks, installDatas) ); } @@ -106,7 +99,7 @@ contract AccountLoupeTest is AccountTestBase { validations = account1.getValidations(account1.execute.selector); assertEq(validations.length, 1); - assertEq(FunctionReference.unwrap(validations[0]), FunctionReference.unwrap(ownerValidation)); + assertEq(FunctionReference.unwrap(validations[0]), FunctionReference.unwrap(_ownerValidation)); } function test_pluginLoupe_getExecutionHooks() public { @@ -147,7 +140,7 @@ contract AccountLoupeTest is AccountTestBase { } function test_pluginLoupe_getValidationHooks() public { - FunctionReference[] memory hooks = account1.getPreValidationHooks(ownerValidation); + FunctionReference[] memory hooks = account1.getPreValidationHooks(_ownerValidation); assertEq(hooks.length, 2); assertEq( diff --git a/test/account/DefaultValidationTest.t.sol b/test/account/DefaultValidationTest.t.sol index c2f118de..7324e176 100644 --- a/test/account/DefaultValidationTest.t.sol +++ b/test/account/DefaultValidationTest.t.sol @@ -5,8 +5,6 @@ import {PackedUserOperation} from "@eth-infinitism/account-abstraction/interface import {MessageHashUtils} from "@openzeppelin/contracts/utils/cryptography/MessageHashUtils.sol"; import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAccount.sol"; -import {FunctionReference, FunctionReferenceLib} from "../../src/helpers/FunctionReferenceLib.sol"; -import {ISingleOwnerPlugin} from "../../src/plugins/owner/ISingleOwnerPlugin.sol"; import {AccountTestBase} from "../utils/AccountTestBase.sol"; import {DefaultValidationFactoryFixture} from "../mocks/DefaultValidationFactoryFixture.sol"; @@ -16,11 +14,6 @@ contract DefaultValidationTest is AccountTestBase { DefaultValidationFactoryFixture public defaultValidationFactoryFixture; - uint256 public constant CALL_GAS_LIMIT = 50000; - uint256 public constant VERIFICATION_GAS_LIMIT = 1200000; - - FunctionReference public ownerValidation; - address public ethRecipient; function setUp() public { @@ -32,10 +25,6 @@ contract DefaultValidationTest is AccountTestBase { ethRecipient = makeAddr("ethRecipient"); vm.deal(ethRecipient, 1 wei); - - ownerValidation = FunctionReferenceLib.pack( - address(singleOwnerPlugin), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER) - ); } function test_defaultValidation_userOp_simple() public { @@ -57,7 +46,7 @@ contract DefaultValidationTest is AccountTestBase { // Generate signature bytes32 userOpHash = entryPoint.getUserOpHash(userOp); (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash()); - userOp.signature = _encodeSignature(ownerValidation, DEFAULT_VALIDATION, abi.encodePacked(r, s, v)); + userOp.signature = _encodeSignature(_ownerValidation, DEFAULT_VALIDATION, abi.encodePacked(r, s, v)); PackedUserOperation[] memory userOps = new PackedUserOperation[](1); userOps[0] = userOp; @@ -74,7 +63,7 @@ contract DefaultValidationTest is AccountTestBase { vm.prank(owner1); account1.executeWithAuthorization( abi.encodeCall(UpgradeableModularAccount.execute, (ethRecipient, 1 wei, "")), - _encodeSignature(ownerValidation, DEFAULT_VALIDATION, "") + _encodeSignature(_ownerValidation, DEFAULT_VALIDATION, "") ); assertEq(ethRecipient.balance, 2 wei); diff --git a/test/account/MultiValidation.t.sol b/test/account/MultiValidation.t.sol index e80d022c..78867f55 100644 --- a/test/account/MultiValidation.t.sol +++ b/test/account/MultiValidation.t.sol @@ -25,9 +25,6 @@ contract MultiValidationTest is AccountTestBase { address public owner2; uint256 public owner2Key; - uint256 public constant CALL_GAS_LIMIT = 50000; - uint256 public constant VERIFICATION_GAS_LIMIT = 1200000; - function setUp() public { validator2 = new SingleOwnerPlugin(); diff --git a/test/account/PerHookData.t.sol b/test/account/PerHookData.t.sol index 77432854..17635ebf 100644 --- a/test/account/PerHookData.t.sol +++ b/test/account/PerHookData.t.sol @@ -3,63 +3,28 @@ pragma solidity ^0.8.25; import {PackedUserOperation} from "@eth-infinitism/account-abstraction/interfaces/PackedUserOperation.sol"; import {IEntryPoint} from "@eth-infinitism/account-abstraction/interfaces/IEntryPoint.sol"; -import {ERC1967Proxy} from "@openzeppelin/contracts/proxy/ERC1967/ERC1967Proxy.sol"; import {MessageHashUtils} from "@openzeppelin/contracts/utils/cryptography/MessageHashUtils.sol"; import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAccount.sol"; -import {ISingleOwnerPlugin} from "../../src/plugins/owner/ISingleOwnerPlugin.sol"; import {FunctionReference, FunctionReferenceLib} from "../../src/helpers/FunctionReferenceLib.sol"; import {MockAccessControlHookPlugin} from "../mocks/plugins/MockAccessControlHookPlugin.sol"; import {Counter} from "../mocks/Counter.sol"; -import {AccountTestBase} from "../utils/AccountTestBase.sol"; +import {CustomValidationTestBase} from "../utils/CustomValidationTestBase.sol"; -contract PerHookDataTest is AccountTestBase { +contract PerHookDataTest is CustomValidationTestBase { using MessageHashUtils for bytes32; MockAccessControlHookPlugin internal _accessControlHookPlugin; Counter internal _counter; - FunctionReference internal _ownerValidation; - - uint256 public constant CALL_GAS_LIMIT = 50000; - uint256 public constant VERIFICATION_GAS_LIMIT = 1200000; - function setUp() public { _counter = new Counter(); _accessControlHookPlugin = new MockAccessControlHookPlugin(); - // Write over `account1` with a new account proxy, with different initialization. - - address accountImplementation = address(factory.accountImplementation()); - - account1 = UpgradeableModularAccount(payable(new ERC1967Proxy(accountImplementation, ""))); - - _ownerValidation = FunctionReferenceLib.pack( - address(singleOwnerPlugin), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER) - ); - - FunctionReference accessControlHook = FunctionReferenceLib.pack( - address(_accessControlHookPlugin), uint8(MockAccessControlHookPlugin.FunctionId.PRE_VALIDATION_HOOK) - ); - - FunctionReference[] memory preValidationHooks = new FunctionReference[](1); - preValidationHooks[0] = accessControlHook; - - bytes[] memory preValidationHookData = new bytes[](1); - // Access control is restricted to only the _counter - preValidationHookData[0] = abi.encode(_counter); - - bytes memory packedPreValidationHooks = abi.encode(preValidationHooks, preValidationHookData); - - vm.prank(address(entryPoint)); - account1.installValidation( - _ownerValidation, true, new bytes4[](0), abi.encode(owner1), packedPreValidationHooks - ); - - vm.deal(address(account1), 100 ether); + _customValidationSetup(); } function test_passAccessControl_userOp() public { @@ -358,4 +323,28 @@ contract PerHookDataTest is AccountTestBase { return (userOp, userOpHash); } + + // Test config + + function _initialValidationConfig() + internal + virtual + override + returns (FunctionReference, bool, bytes4[] memory, bytes memory, bytes memory) + { + FunctionReference accessControlHook = FunctionReferenceLib.pack( + address(_accessControlHookPlugin), uint8(MockAccessControlHookPlugin.FunctionId.PRE_VALIDATION_HOOK) + ); + + FunctionReference[] memory preValidationHooks = new FunctionReference[](1); + preValidationHooks[0] = accessControlHook; + + bytes[] memory preValidationHookData = new bytes[](1); + // Access control is restricted to only the counter + preValidationHookData[0] = abi.encode(_counter); + + bytes memory packedPreValidationHooks = abi.encode(preValidationHooks, preValidationHookData); + + return (_ownerValidation, true, new bytes4[](0), abi.encode(owner1), packedPreValidationHooks); + } } diff --git a/test/account/UpgradeableModularAccount.t.sol b/test/account/UpgradeableModularAccount.t.sol index 2484b933..3cc9810d 100644 --- a/test/account/UpgradeableModularAccount.t.sol +++ b/test/account/UpgradeableModularAccount.t.sol @@ -10,7 +10,7 @@ import {IERC1271} from "@openzeppelin/contracts/interfaces/IERC1271.sol"; import {PluginManagerInternals} from "../../src/account/PluginManagerInternals.sol"; import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAccount.sol"; -import {FunctionReference, FunctionReferenceLib} from "../../src/helpers/FunctionReferenceLib.sol"; +import {FunctionReference} from "../../src/helpers/FunctionReferenceLib.sol"; import {IPlugin, PluginManifest} from "../../src/interfaces/IPlugin.sol"; import {IAccountLoupe} from "../../src/interfaces/IAccountLoupe.sol"; import {IPluginManager} from "../../src/interfaces/IPluginManager.sol"; @@ -39,11 +39,6 @@ contract UpgradeableModularAccountTest is AccountTestBase { Counter public counter; PluginManifest internal _manifest; - FunctionReference public ownerValidation; - - uint256 public constant CALL_GAS_LIMIT = 50000; - uint256 public constant VERIFICATION_GAS_LIMIT = 1200000; - event PluginInstalled(address indexed plugin, bytes32 manifestHash, FunctionReference[] dependencies); event PluginUninstalled(address indexed plugin, bool indexed callbacksSucceeded); event ReceivedCall(bytes msgData, uint256 msgValue); @@ -61,10 +56,6 @@ contract UpgradeableModularAccountTest is AccountTestBase { vm.deal(ethRecipient, 1 wei); counter = new Counter(); counter.increment(); // amoritze away gas cost of zero->nonzero transition - - ownerValidation = FunctionReferenceLib.pack( - address(singleOwnerPlugin), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER) - ); } function test_deployAccount() public { @@ -88,7 +79,7 @@ contract UpgradeableModularAccountTest is AccountTestBase { bytes32 userOpHash = entryPoint.getUserOpHash(userOp); (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash()); userOp.signature = - _encodeSignature(ownerValidation, SELECTOR_ASSOCIATED_VALIDATION, abi.encodePacked(r, s, v)); + _encodeSignature(_ownerValidation, SELECTOR_ASSOCIATED_VALIDATION, abi.encodePacked(r, s, v)); PackedUserOperation[] memory userOps = new PackedUserOperation[](1); userOps[0] = userOp; @@ -118,7 +109,7 @@ contract UpgradeableModularAccountTest is AccountTestBase { bytes32 userOpHash = entryPoint.getUserOpHash(userOp); (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner2Key, userOpHash.toEthSignedMessageHash()); userOp.signature = - _encodeSignature(ownerValidation, SELECTOR_ASSOCIATED_VALIDATION, abi.encodePacked(r, s, v)); + _encodeSignature(_ownerValidation, SELECTOR_ASSOCIATED_VALIDATION, abi.encodePacked(r, s, v)); PackedUserOperation[] memory userOps = new PackedUserOperation[](1); userOps[0] = userOp; @@ -145,7 +136,7 @@ contract UpgradeableModularAccountTest is AccountTestBase { bytes32 userOpHash = entryPoint.getUserOpHash(userOp); (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner2Key, userOpHash.toEthSignedMessageHash()); userOp.signature = - _encodeSignature(ownerValidation, SELECTOR_ASSOCIATED_VALIDATION, abi.encodePacked(r, s, v)); + _encodeSignature(_ownerValidation, SELECTOR_ASSOCIATED_VALIDATION, abi.encodePacked(r, s, v)); PackedUserOperation[] memory userOps = new PackedUserOperation[](1); userOps[0] = userOp; @@ -172,7 +163,7 @@ contract UpgradeableModularAccountTest is AccountTestBase { bytes32 userOpHash = entryPoint.getUserOpHash(userOp); (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash()); userOp.signature = - _encodeSignature(ownerValidation, SELECTOR_ASSOCIATED_VALIDATION, abi.encodePacked(r, s, v)); + _encodeSignature(_ownerValidation, SELECTOR_ASSOCIATED_VALIDATION, abi.encodePacked(r, s, v)); PackedUserOperation[] memory userOps = new PackedUserOperation[](1); userOps[0] = userOp; @@ -201,7 +192,7 @@ contract UpgradeableModularAccountTest is AccountTestBase { bytes32 userOpHash = entryPoint.getUserOpHash(userOp); (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash()); userOp.signature = - _encodeSignature(ownerValidation, SELECTOR_ASSOCIATED_VALIDATION, abi.encodePacked(r, s, v)); + _encodeSignature(_ownerValidation, SELECTOR_ASSOCIATED_VALIDATION, abi.encodePacked(r, s, v)); PackedUserOperation[] memory userOps = new PackedUserOperation[](1); userOps[0] = userOp; @@ -233,7 +224,7 @@ contract UpgradeableModularAccountTest is AccountTestBase { bytes32 userOpHash = entryPoint.getUserOpHash(userOp); (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash()); userOp.signature = - _encodeSignature(ownerValidation, SELECTOR_ASSOCIATED_VALIDATION, abi.encodePacked(r, s, v)); + _encodeSignature(_ownerValidation, SELECTOR_ASSOCIATED_VALIDATION, abi.encodePacked(r, s, v)); PackedUserOperation[] memory userOps = new PackedUserOperation[](1); userOps[0] = userOp; diff --git a/test/mocks/DefaultValidationFactoryFixture.sol b/test/mocks/DefaultValidationFactoryFixture.sol index a4836ad8..54663a7c 100644 --- a/test/mocks/DefaultValidationFactoryFixture.sol +++ b/test/mocks/DefaultValidationFactoryFixture.sol @@ -55,11 +55,14 @@ contract DefaultValidationFactoryFixture is OptimizedTest { new ERC1967Proxy{salt: getSalt(owner, salt)}(address(accountImplementation), ""); // point proxy to actual implementation and init plugins - UpgradeableModularAccount(payable(addr)).initializeDefaultValidation( + UpgradeableModularAccount(payable(addr)).initializeWithValidation( FunctionReferenceLib.pack( address(singleOwnerPlugin), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER) ), - pluginInstallData + true, + new bytes4[](0), + pluginInstallData, + "" ); } diff --git a/test/samples/AllowlistPlugin.t.sol b/test/samples/AllowlistPlugin.t.sol new file mode 100644 index 00000000..6501256a --- /dev/null +++ b/test/samples/AllowlistPlugin.t.sol @@ -0,0 +1,318 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.25; + +import {IEntryPoint} from "@eth-infinitism/account-abstraction/interfaces/IEntryPoint.sol"; + +import {Call} from "../../src/interfaces/IStandardExecutor.sol"; +import {FunctionReference, FunctionReferenceLib} from "../../src/helpers/FunctionReferenceLib.sol"; +import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAccount.sol"; +import {AllowlistPlugin} from "../../src/samples/permissionhooks/AllowlistPlugin.sol"; + +import {CustomValidationTestBase} from "../utils/CustomValidationTestBase.sol"; +import {Counter} from "../mocks/Counter.sol"; + +contract AllowlistPluginTest is CustomValidationTestBase { + AllowlistPlugin public allowlistPlugin; + + AllowlistPlugin.AllowlistInit[] public allowlistInit; + + Counter[] public counters; + + function setUp() public { + allowlistPlugin = new AllowlistPlugin(); + + counters = new Counter[](10); + + for (uint256 i = 0; i < counters.length; i++) { + counters[i] = new Counter(); + } + + // Don't call `_customValidationSetup` here, as we want to test various configurations of install data. + } + + function testFuzz_allowlistHook_userOp_single(uint256 seed) public { + AllowlistPlugin.AllowlistInit[] memory inits; + (inits, seed) = _generateRandomizedAllowlistInit(seed); + + _copyInitToStorage(inits); + _customValidationSetup(); + + Call[] memory calls = new Call[](1); + (calls[0], seed) = _generateRandomCall(seed); + bytes memory expectedError = _getExpectedUserOpError(calls); + + _runExecUserOp(calls[0].target, calls[0].data, expectedError); + } + + function testFuzz_allowlistHook_userOp_batch(uint256 seed) public { + AllowlistPlugin.AllowlistInit[] memory inits; + (inits, seed) = _generateRandomizedAllowlistInit(seed); + + _copyInitToStorage(inits); + _customValidationSetup(); + + Call[] memory calls; + (calls, seed) = _generateRandomCalls(seed); + bytes memory expectedError = _getExpectedUserOpError(calls); + + _runExecBatchUserOp(calls, expectedError); + } + + function testFuzz_allowlistHook_runtime_single(uint256 seed) public { + AllowlistPlugin.AllowlistInit[] memory inits; + (inits, seed) = _generateRandomizedAllowlistInit(seed); + + _copyInitToStorage(inits); + _customValidationSetup(); + + Call[] memory calls = new Call[](1); + (calls[0], seed) = _generateRandomCall(seed); + bytes memory expectedError = _getExpectedRuntimeError(calls); + + if (keccak256(expectedError) == keccak256("emptyrevert")) { + _runtimeExecExpFail(calls[0].target, calls[0].data, ""); + } else { + _runtimeExec(calls[0].target, calls[0].data, expectedError); + } + } + + function testFuzz_allowlistHook_runtime_batch(uint256 seed) public { + AllowlistPlugin.AllowlistInit[] memory inits; + (inits, seed) = _generateRandomizedAllowlistInit(seed); + + _copyInitToStorage(inits); + _customValidationSetup(); + + Call[] memory calls; + (calls, seed) = _generateRandomCalls(seed); + bytes memory expectedError = _getExpectedRuntimeError(calls); + + if (keccak256(expectedError) == keccak256("emptyrevert")) { + _runtimeExecBatchExpFail(calls, ""); + } else { + _runtimeExecBatch(calls, expectedError); + } + } + + function _generateRandomCalls(uint256 seed) internal view returns (Call[] memory, uint256) { + uint256 length = seed % 10; + seed = _next(seed); + + Call[] memory calls = new Call[](length); + + for (uint256 i = 0; i < length; i++) { + (calls[i], seed) = _generateRandomCall(seed); + } + + return (calls, seed); + } + + function _generateRandomCall(uint256 seed) internal view returns (Call memory call, uint256 newSeed) { + // Half of the time, the target is a random counter, the other half, it's a random address. + bool isCounter = seed % 2 == 0; + seed = _next(seed); + + call.target = isCounter ? address(counters[seed % counters.length]) : address(uint160(uint256(seed))); + seed = _next(seed); + + bool validSelector = seed % 2 == 0; + seed = _next(seed); + + if (validSelector) { + uint256 selectorIndex = seed % 3; + seed = _next(seed); + + if (selectorIndex == 0) { + call.data = abi.encodeCall(Counter.setNumber, (seed % 100)); + } else if (selectorIndex == 1) { + call.data = abi.encodeCall(Counter.increment, ()); + } else { + call.data = abi.encodeWithSignature("number()"); + } + + seed = _next(seed); + } else { + call.data = abi.encodePacked(bytes4(uint32(uint256(seed)))); + seed = _next(seed); + } + + return (call, seed); + } + + function _getExpectedUserOpError(Call[] memory calls) internal view returns (bytes memory) { + for (uint256 i = 0; i < calls.length; i++) { + Call memory call = calls[i]; + + (bool allowed, bool hasSelectorAllowlist) = + allowlistPlugin.targetAllowlist(call.target, address(account1)); + if (allowed) { + if ( + hasSelectorAllowlist + && !allowlistPlugin.selectorAllowlist(call.target, bytes4(call.data), address(account1)) + ) { + return abi.encodeWithSelector( + IEntryPoint.FailedOpWithRevert.selector, + 0, + "AA23 reverted", + abi.encodeWithSelector(AllowlistPlugin.SelectorNotAllowed.selector) + ); + } + } else { + return abi.encodeWithSelector( + IEntryPoint.FailedOpWithRevert.selector, + 0, + "AA23 reverted", + abi.encodeWithSelector(AllowlistPlugin.TargetNotAllowed.selector) + ); + } + } + + return ""; + } + + function _getExpectedRuntimeError(Call[] memory calls) internal view returns (bytes memory) { + for (uint256 i = 0; i < calls.length; i++) { + Call memory call = calls[i]; + + (bool allowed, bool hasSelectorAllowlist) = + allowlistPlugin.targetAllowlist(call.target, address(account1)); + if (allowed) { + if ( + hasSelectorAllowlist + && !allowlistPlugin.selectorAllowlist(call.target, bytes4(call.data), address(account1)) + ) { + return abi.encodeWithSelector( + UpgradeableModularAccount.PreRuntimeValidationHookFailed.selector, + address(allowlistPlugin), + uint8(AllowlistPlugin.FunctionId.PRE_VALIDATION_HOOK), + abi.encodeWithSelector(AllowlistPlugin.SelectorNotAllowed.selector) + ); + } + } else { + return abi.encodeWithSelector( + UpgradeableModularAccount.PreRuntimeValidationHookFailed.selector, + address(allowlistPlugin), + uint8(AllowlistPlugin.FunctionId.PRE_VALIDATION_HOOK), + abi.encodeWithSelector(AllowlistPlugin.TargetNotAllowed.selector) + ); + } + } + + // At this point, we have returned any error that would come from the AllowlistPlugin. + // But, because this is in the runtime path, the Counter itself may throw if it is not a valid selector. + + for (uint256 i = 0; i < calls.length; i++) { + Call memory call = calls[i]; + bytes4 selector = bytes4(call.data); + + if ( + selector != Counter.setNumber.selector && selector != Counter.increment.selector + && selector != bytes4(abi.encodeWithSignature("number()")) + ) { + //todo: better define a way to handle empty reverts. + return "emptyrevert"; + } + } + + return ""; + } + + function _generateRandomizedAllowlistInit(uint256 seed) + internal + view + returns (AllowlistPlugin.AllowlistInit[] memory, uint256) + { + uint256 length = seed % 10; + seed = _next(seed); + + AllowlistPlugin.AllowlistInit[] memory init = new AllowlistPlugin.AllowlistInit[](length); + + for (uint256 i = 0; i < length; i++) { + // Half the time, the target is a random counter, the other half, it's a random address. + bool isCounter = seed % 2 == 0; + seed = _next(seed); + + address target = + isCounter ? address(counters[seed % counters.length]) : address(uint160(uint256(seed))); + + bool hasSelectorAllowlist = seed % 2 == 0; + seed = _next(seed); + + uint256 selectorLength = seed % 10; + seed = _next(seed); + + bytes4[] memory selectors = new bytes4[](selectorLength); + + for (uint256 j = 0; j < selectorLength; j++) { + // half of the time, the selector is a valid selector on counter, the other half it's a random + // selector + + bool isCounterSelector = seed % 2 == 0; + seed = _next(seed); + + if (isCounterSelector) { + uint256 selectorIndex = seed % 3; + seed = _next(seed); + + if (selectorIndex == 0) { + selectors[j] = Counter.setNumber.selector; + } else if (selectorIndex == 1) { + selectors[j] = Counter.increment.selector; + } else { + selectors[j] = bytes4(abi.encodeWithSignature("number()")); + } + } else { + selectors[j] = bytes4(uint32(uint256(seed))); + seed = _next(seed); + } + + selectors[j] = bytes4(uint32(uint256(keccak256(abi.encodePacked(seed, j))))); + seed = _next(seed); + } + + init[i] = AllowlistPlugin.AllowlistInit(target, hasSelectorAllowlist, selectors); + } + + return (init, seed); + } + + // todo: runtime paths + + // fuzz targets, fuzz target selectors. + + // Maybe pull out the helper function for running user ops and possibly expect a failure? + + function _next(uint256 seed) internal pure returns (uint256) { + return uint256(keccak256(abi.encodePacked(seed))); + } + + function _initialValidationConfig() + internal + virtual + override + returns (FunctionReference, bool, bytes4[] memory, bytes memory, bytes memory) + { + FunctionReference accessControlHook = FunctionReferenceLib.pack( + address(allowlistPlugin), uint8(AllowlistPlugin.FunctionId.PRE_VALIDATION_HOOK) + ); + + FunctionReference[] memory preValidationHooks = new FunctionReference[](1); + preValidationHooks[0] = accessControlHook; + + bytes[] memory preValidationHookData = new bytes[](1); + // Access control is restricted to only the counter + preValidationHookData[0] = abi.encode(allowlistInit); + + bytes memory packedPreValidationHooks = abi.encode(preValidationHooks, preValidationHookData); + + return (_ownerValidation, true, new bytes4[](0), abi.encode(owner1), packedPreValidationHooks); + } + + // Unfortunately, this is a feature that solidity has only implemented in via-ir, so we need to do it manually + // to be able to run the tests in lite mode. + function _copyInitToStorage(AllowlistPlugin.AllowlistInit[] memory init) internal { + for (uint256 i = 0; i < init.length; i++) { + allowlistInit.push(init[i]); + } + } +} diff --git a/test/utils/AccountTestBase.sol b/test/utils/AccountTestBase.sol index 736b6041..f5fe033b 100644 --- a/test/utils/AccountTestBase.sol +++ b/test/utils/AccountTestBase.sol @@ -2,8 +2,11 @@ pragma solidity ^0.8.19; import {EntryPoint} from "@eth-infinitism/account-abstraction/core/EntryPoint.sol"; +import {PackedUserOperation} from "@eth-infinitism/account-abstraction/interfaces/PackedUserOperation.sol"; +import {MessageHashUtils} from "@openzeppelin/contracts/utils/cryptography/MessageHashUtils.sol"; import {FunctionReference, FunctionReferenceLib} from "../../src/helpers/FunctionReferenceLib.sol"; +import {IStandardExecutor, Call} from "../../src/interfaces/IStandardExecutor.sol"; import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAccount.sol"; import {ISingleOwnerPlugin} from "../../src/plugins/owner/ISingleOwnerPlugin.sol"; import {SingleOwnerPlugin} from "../../src/plugins/owner/SingleOwnerPlugin.sol"; @@ -16,6 +19,7 @@ import {MSCAFactoryFixture} from "../mocks/MSCAFactoryFixture.sol"; /// SingleOwnerPlugin. abstract contract AccountTestBase is OptimizedTest { using FunctionReferenceLib for FunctionReference; + using MessageHashUtils for bytes32; EntryPoint public entryPoint; address payable public beneficiary; @@ -26,9 +30,14 @@ abstract contract AccountTestBase is OptimizedTest { uint256 public owner1Key; UpgradeableModularAccount public account1; + FunctionReference internal _ownerValidation; + uint8 public constant SELECTOR_ASSOCIATED_VALIDATION = 0; uint8 public constant DEFAULT_VALIDATION = 1; + uint256 public constant CALL_GAS_LIMIT = 50000; + uint256 public constant VERIFICATION_GAS_LIMIT = 1200000; + struct PreValidationHookData { uint8 index; bytes validationData; @@ -44,6 +53,131 @@ abstract contract AccountTestBase is OptimizedTest { account1 = factory.createAccount(owner1, 0); vm.deal(address(account1), 100 ether); + + _ownerValidation = FunctionReferenceLib.pack( + address(singleOwnerPlugin), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER) + ); + } + + function _runExecUserOp(address target, bytes memory callData) internal { + _runUserOp(abi.encodeCall(IStandardExecutor.execute, (target, 0, callData))); + } + + function _runExecUserOp(address target, bytes memory callData, bytes memory revertReason) internal { + _runUserOp(abi.encodeCall(IStandardExecutor.execute, (target, 0, callData)), revertReason); + } + + function _runExecBatchUserOp(Call[] memory calls) internal { + _runUserOp(abi.encodeCall(IStandardExecutor.executeBatch, (calls))); + } + + function _runExecBatchUserOp(Call[] memory calls, bytes memory revertReason) internal { + _runUserOp(abi.encodeCall(IStandardExecutor.executeBatch, (calls)), revertReason); + } + + function _runUserOp(bytes memory callData) internal { + // Run user op without expecting a revert + _runUserOp(callData, hex""); + } + + function _runUserOp(bytes memory callData, bytes memory expectedRevertData) internal { + uint256 nonce = entryPoint.getNonce(address(account1), 0); + + PackedUserOperation memory userOp = PackedUserOperation({ + sender: address(account1), + nonce: nonce, + initCode: hex"", + callData: callData, + accountGasLimits: _encodeGas(VERIFICATION_GAS_LIMIT, CALL_GAS_LIMIT), + preVerificationGas: 0, + gasFees: _encodeGas(1, 1), + paymasterAndData: hex"", + signature: hex"" + }); + + bytes32 userOpHash = entryPoint.getUserOpHash(userOp); + (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash()); + + userOp.signature = _encodeSignature( + FunctionReferenceLib.pack( + address(singleOwnerPlugin), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER) + ), + DEFAULT_VALIDATION, + abi.encodePacked(r, s, v) + ); + + PackedUserOperation[] memory userOps = new PackedUserOperation[](1); + userOps[0] = userOp; + + if (expectedRevertData.length > 0) { + vm.expectRevert(expectedRevertData); + } + entryPoint.handleOps(userOps, beneficiary); + } + + function _runtimeExec(address target, bytes memory callData) internal { + _runtimeCall(abi.encodeCall(IStandardExecutor.execute, (target, 0, callData))); + } + + function _runtimeExec(address target, bytes memory callData, bytes memory expectedRevertData) internal { + _runtimeCall(abi.encodeCall(IStandardExecutor.execute, (target, 0, callData)), expectedRevertData); + } + + function _runtimeExecExpFail(address target, bytes memory callData, bytes memory expectedRevertData) + internal + { + _runtimeCallExpFail(abi.encodeCall(IStandardExecutor.execute, (target, 0, callData)), expectedRevertData); + } + + function _runtimeExecBatch(Call[] memory calls) internal { + _runtimeCall(abi.encodeCall(IStandardExecutor.executeBatch, (calls))); + } + + function _runtimeExecBatch(Call[] memory calls, bytes memory expectedRevertData) internal { + _runtimeCall(abi.encodeCall(IStandardExecutor.executeBatch, (calls)), expectedRevertData); + } + + function _runtimeExecBatchExpFail(Call[] memory calls, bytes memory expectedRevertData) internal { + _runtimeCallExpFail(abi.encodeCall(IStandardExecutor.executeBatch, (calls)), expectedRevertData); + } + + function _runtimeCall(bytes memory callData) internal { + _runtimeCall(callData, ""); + } + + function _runtimeCall(bytes memory callData, bytes memory expectedRevertData) internal { + if (expectedRevertData.length > 0) { + vm.expectRevert(expectedRevertData); + } + + vm.prank(owner1); + account1.executeWithAuthorization( + callData, + _encodeSignature( + FunctionReferenceLib.pack( + address(singleOwnerPlugin), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER) + ), + DEFAULT_VALIDATION, + "" + ) + ); + } + + // Always expects a revert, even if the revert data is zero-length. + function _runtimeCallExpFail(bytes memory callData, bytes memory expectedRevertData) internal { + vm.expectRevert(expectedRevertData); + + vm.prank(owner1); + account1.executeWithAuthorization( + callData, + _encodeSignature( + FunctionReferenceLib.pack( + address(singleOwnerPlugin), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER) + ), + DEFAULT_VALIDATION, + "" + ) + ); } function _transferOwnershipToTest() internal { diff --git a/test/utils/CustomValidationTestBase.sol b/test/utils/CustomValidationTestBase.sol new file mode 100644 index 00000000..8bcdd406 --- /dev/null +++ b/test/utils/CustomValidationTestBase.sol @@ -0,0 +1,44 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.25; + +import {ERC1967Proxy} from "@openzeppelin/contracts/proxy/ERC1967/ERC1967Proxy.sol"; + +import {FunctionReference} from "../../src/helpers/FunctionReferenceLib.sol"; +import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAccount.sol"; + +import {AccountTestBase} from "./AccountTestBase.sol"; + +/// @dev This test contract base is used to test custom validation logic. +/// To use this, override the _initialValidationConfig function to return the desired validation configuration. +/// Then, call _customValidationSetup in the test setup. +/// Make sure to do so after any state variables that `_initialValidationConfig` relies on are set. +abstract contract CustomValidationTestBase is AccountTestBase { + function _customValidationSetup() internal { + ( + FunctionReference validationFunction, + bool shared, + bytes4[] memory selectors, + bytes memory installData, + bytes memory preValidationHooks + ) = _initialValidationConfig(); + + address accountImplementation = address(factory.accountImplementation()); + + account1 = UpgradeableModularAccount(payable(new ERC1967Proxy{salt: 0}(accountImplementation, ""))); + + account1.initializeWithValidation(validationFunction, shared, selectors, installData, preValidationHooks); + + vm.deal(address(account1), 100 ether); + } + + function _initialValidationConfig() + internal + virtual + returns ( + FunctionReference validationFunction, + bool shared, + bytes4[] memory selectors, + bytes memory installData, + bytes memory preValidationHooks + ); +} From 5fda9a22e018c4450e0ff4fd18e7494b306a8a80 Mon Sep 17 00:00:00 2001 From: adam Date: Wed, 19 Jun 2024 11:00:12 -0400 Subject: [PATCH 3/5] self-call protection --- src/account/UpgradeableModularAccount.sol | 64 +++++- test/account/SelfCallAuthorization.t.sol | 239 ++++++++++++++++++++++ 2 files changed, 294 insertions(+), 9 deletions(-) create mode 100644 test/account/SelfCallAuthorization.t.sol diff --git a/src/account/UpgradeableModularAccount.sol b/src/account/UpgradeableModularAccount.sol index a642aec7..ae53ed66 100644 --- a/src/account/UpgradeableModularAccount.sol +++ b/src/account/UpgradeableModularAccount.sol @@ -68,6 +68,7 @@ contract UpgradeableModularAccount is error PreRuntimeValidationHookFailed(address plugin, uint8 functionId, bytes revertReason); error RuntimeValidationFunctionMissing(bytes4 selector); error RuntimeValidationFunctionReverted(address plugin, uint8 functionId, bytes revertReason); + error SelfCallRecursionDepthExceeded(); error SignatureValidationInvalid(address plugin, uint8 functionId); error UnexpectedAggregator(address plugin, uint8 functionId, address aggregator); error UnrecognizedFunction(bytes4 selector); @@ -187,14 +188,12 @@ contract UpgradeableModularAccount is payable returns (bytes memory) { - bytes4 execSelector = bytes4(data[:4]); - // Revert if the provided `authorization` less than 21 bytes long, rather than right-padding. FunctionReference runtimeValidationFunction = FunctionReference.wrap(bytes21(authorization[:21])); // Check if the runtime validation function is allowed to be called bool isDefaultValidation = uint8(authorization[21]) == 1; - _checkIfValidationApplies(execSelector, runtimeValidationFunction, isDefaultValidation); + _checkIfValidationAppliesCallData(data, runtimeValidationFunction, isDefaultValidation); _doRuntimeValidation(runtimeValidationFunction, data, authorization[22:]); @@ -343,13 +342,12 @@ contract UpgradeableModularAccount is if (userOp.callData.length < 4) { revert UnrecognizedFunction(bytes4(userOp.callData)); } - bytes4 selector = bytes4(userOp.callData); // Revert if the provided `authorization` less than 21 bytes long, rather than right-padding. FunctionReference userOpValidationFunction = FunctionReference.wrap(bytes21(userOp.signature[:21])); bool isDefaultValidation = uint8(userOp.signature[21]) == 1; - _checkIfValidationApplies(selector, userOpValidationFunction, isDefaultValidation); + _checkIfValidationAppliesCallData(userOp.callData, userOpValidationFunction, isDefaultValidation); validationData = _doUserOpValidation(userOpValidationFunction, userOp, userOp.signature[22:], userOpHash); } @@ -564,10 +562,58 @@ contract UpgradeableModularAccount is // solhint-disable-next-line no-empty-blocks function _authorizeUpgrade(address newImplementation) internal override {} - function _checkIfValidationApplies(bytes4 selector, FunctionReference validationFunction, bool isDefault) - internal - view - { + function _checkIfValidationAppliesCallData( + bytes calldata callData, + FunctionReference validationFunction, + bool isDefault + ) internal view { + bytes4 outerSelector = bytes4(callData[:4]); + + _checkIfValidationAppliesSelector(outerSelector, validationFunction, isDefault); + + if (outerSelector == IStandardExecutor.execute.selector) { + (address target,,) = abi.decode(callData[4:], (address, uint256, bytes)); + + if (target == address(this)) { + // There is no point to call `execute` to recurse exactly once - this is equivalent to just having + // the calldata as a top-level call. + revert SelfCallRecursionDepthExceeded(); + } + } else if (outerSelector == IStandardExecutor.executeBatch.selector) { + // executeBatch may be used to batch account actions together, by targetting the account itself. + // If this is done, we must ensure all of the inner calls are allowed by the provided validation + // function. + + (Call[] memory calls) = abi.decode(callData[4:], (Call[])); + + for (uint256 i = 0; i < calls.length; ++i) { + if (calls[i].target == address(this)) { + bytes4 nestedSelector = bytes4(calls[i].data); + + if ( + nestedSelector == IStandardExecutor.execute.selector + || nestedSelector == IStandardExecutor.executeBatch.selector + ) { + // To prevent arbitrarily-deep recursive checking, we limit the depth of self-calls to one + // for the purposes of batching. + // This means that all self-calls must occur at the top level of the batch. + // Note that plugins of other contracts using `executeWithAuthorization` may still + // independently call into this account with a different validation function, allowing + // composition of multiple batches. + revert SelfCallRecursionDepthExceeded(); + } + + _checkIfValidationAppliesSelector(nestedSelector, validationFunction, isDefault); + } + } + } + } + + function _checkIfValidationAppliesSelector( + bytes4 selector, + FunctionReference validationFunction, + bool isDefault + ) internal view { AccountStorage storage _storage = getAccountStorage(); // Check that the provided validation function is applicable to the selector diff --git a/test/account/SelfCallAuthorization.t.sol b/test/account/SelfCallAuthorization.t.sol new file mode 100644 index 00000000..955cf112 --- /dev/null +++ b/test/account/SelfCallAuthorization.t.sol @@ -0,0 +1,239 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.25; + +import {IEntryPoint} from "@eth-infinitism/account-abstraction/interfaces/IEntryPoint.sol"; +import {PackedUserOperation} from "@eth-infinitism/account-abstraction/interfaces/PackedUserOperation.sol"; + +import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAccount.sol"; +import {IStandardExecutor, Call} from "../../src/interfaces/IStandardExecutor.sol"; +import {FunctionReference, FunctionReferenceLib} from "../../src/helpers/FunctionReferenceLib.sol"; + +import {AccountTestBase} from "../utils/AccountTestBase.sol"; +import {DefaultValidationFactoryFixture} from "../mocks/DefaultValidationFactoryFixture.sol"; +import {ComprehensivePlugin} from "../mocks/plugins/ComprehensivePlugin.sol"; + +contract SelfCallAuthorizationTest is AccountTestBase { + DefaultValidationFactoryFixture public defaultValidationFactoryFixture; + + ComprehensivePlugin public comprehensivePlugin; + + FunctionReference public comprehensivePluginValidation; + + function setUp() public { + defaultValidationFactoryFixture = new DefaultValidationFactoryFixture(entryPoint, singleOwnerPlugin); + + account1 = UpgradeableModularAccount(payable(defaultValidationFactoryFixture.createAccount(owner1, 0))); + + vm.deal(address(account1), 100 ether); + + // install the comprehensive plugin to get new exec functions with different validations configured. + + comprehensivePlugin = new ComprehensivePlugin(); + + bytes32 manifestHash = keccak256(abi.encode(comprehensivePlugin.pluginManifest())); + vm.prank(address(entryPoint)); + account1.installPlugin(address(comprehensivePlugin), manifestHash, "", new FunctionReference[](0)); + + comprehensivePluginValidation = FunctionReferenceLib.pack( + address(comprehensivePlugin), uint8(ComprehensivePlugin.FunctionId.VALIDATION) + ); + } + + function test_selfCallFails_userOp() public { + // Uses default validation + _runUserOp( + abi.encodeCall(ComprehensivePlugin.foo, ()), + abi.encodeWithSelector( + IEntryPoint.FailedOpWithRevert.selector, + 0, + "AA23 reverted", + abi.encodeWithSelector( + UpgradeableModularAccount.UserOpValidationFunctionMissing.selector, + ComprehensivePlugin.foo.selector + ) + ) + ); + } + + function test_selfCallFails_runtime() public { + // Uses default validation + _runtimeCall( + abi.encodeCall(ComprehensivePlugin.foo, ()), + abi.encodeWithSelector( + UpgradeableModularAccount.UserOpValidationFunctionMissing.selector, + ComprehensivePlugin.foo.selector + ) + ); + } + + function test_selfCallPrivilegeEscalation_prevented_userOp() public { + // Using default validation, self-call bypasses custom validation needed for ComprehensivePlugin.foo + _runUserOp( + abi.encodeCall( + UpgradeableModularAccount.execute, + (address(account1), 0, abi.encodeCall(ComprehensivePlugin.foo, ())) + ), + abi.encodeWithSelector( + IEntryPoint.FailedOpWithRevert.selector, + 0, + "AA23 reverted", + abi.encodeWithSelector(UpgradeableModularAccount.SelfCallRecursionDepthExceeded.selector) + ) + ); + + Call[] memory calls = new Call[](1); + calls[0] = Call(address(account1), 0, abi.encodeCall(ComprehensivePlugin.foo, ())); + + _runUserOp( + abi.encodeCall(IStandardExecutor.executeBatch, (calls)), + abi.encodeWithSelector( + IEntryPoint.FailedOpWithRevert.selector, + 0, + "AA23 reverted", + abi.encodeWithSelector( + UpgradeableModularAccount.UserOpValidationFunctionMissing.selector, + ComprehensivePlugin.foo.selector + ) + ) + ); + } + + function test_selfCallPrivilegeEscalation_prevented_runtime() public { + // Using default validation, self-call bypasses custom validation needed for ComprehensivePlugin.foo + _runtimeCall( + abi.encodeCall( + UpgradeableModularAccount.execute, + (address(account1), 0, abi.encodeCall(ComprehensivePlugin.foo, ())) + ), + abi.encodeWithSelector(UpgradeableModularAccount.SelfCallRecursionDepthExceeded.selector) + ); + + Call[] memory calls = new Call[](1); + calls[0] = Call(address(account1), 0, abi.encodeCall(ComprehensivePlugin.foo, ())); + + _runtimeExecBatchExpFail( + calls, + abi.encodeWithSelector( + UpgradeableModularAccount.UserOpValidationFunctionMissing.selector, + ComprehensivePlugin.foo.selector + ) + ); + } + + function test_batchAction_allowed_userOp() public { + _enableBatchValidation(); + + Call[] memory calls = new Call[](2); + calls[0] = Call(address(account1), 0, abi.encodeCall(ComprehensivePlugin.foo, ())); + calls[1] = Call(address(account1), 0, abi.encodeCall(ComprehensivePlugin.foo, ())); + + PackedUserOperation memory userOp = _generateUserOpWithComprehensivePluginValidation( + abi.encodeCall(IStandardExecutor.executeBatch, (calls)) + ); + + PackedUserOperation[] memory userOps = new PackedUserOperation[](1); + userOps[0] = userOp; + + vm.expectCall(address(comprehensivePlugin), abi.encodeCall(ComprehensivePlugin.foo, ()), 2); + entryPoint.handleOps(userOps, beneficiary); + } + + function test_batchAction_allowed_runtime() public { + _enableBatchValidation(); + + Call[] memory calls = new Call[](2); + calls[0] = Call(address(account1), 0, abi.encodeCall(ComprehensivePlugin.foo, ())); + calls[1] = Call(address(account1), 0, abi.encodeCall(ComprehensivePlugin.foo, ())); + + vm.expectCall(address(comprehensivePlugin), abi.encodeCall(ComprehensivePlugin.foo, ()), 2); + account1.executeWithAuthorization( + abi.encodeCall(IStandardExecutor.executeBatch, (calls)), + _encodeSignature(comprehensivePluginValidation, SELECTOR_ASSOCIATED_VALIDATION, "") + ); + } + + function test_recursiveDepthCapped_userOp() public { + _enableBatchValidation(); + + Call[] memory innerCalls = new Call[](1); + innerCalls[0] = Call(address(account1), 0, abi.encodeCall(ComprehensivePlugin.foo, ())); + + Call[] memory outerCalls = new Call[](1); + outerCalls[0] = Call(address(account1), 0, abi.encodeCall(IStandardExecutor.executeBatch, (innerCalls))); + + PackedUserOperation memory userOp = _generateUserOpWithComprehensivePluginValidation( + abi.encodeCall(IStandardExecutor.executeBatch, (outerCalls)) + ); + + PackedUserOperation[] memory userOps = new PackedUserOperation[](1); + userOps[0] = userOp; + + vm.expectRevert( + abi.encodeWithSelector( + IEntryPoint.FailedOpWithRevert.selector, + 0, + "AA23 reverted", + abi.encodeWithSelector(UpgradeableModularAccount.SelfCallRecursionDepthExceeded.selector) + ) + ); + entryPoint.handleOps(userOps, beneficiary); + } + + function test_recursiveDepthCapped_runtime() public { + _enableBatchValidation(); + + Call[] memory innerCalls = new Call[](1); + innerCalls[0] = Call(address(account1), 0, abi.encodeCall(ComprehensivePlugin.foo, ())); + + Call[] memory outerCalls = new Call[](1); + outerCalls[0] = Call(address(account1), 0, abi.encodeCall(IStandardExecutor.executeBatch, (innerCalls))); + + vm.expectRevert(abi.encodeWithSelector(UpgradeableModularAccount.SelfCallRecursionDepthExceeded.selector)); + account1.executeWithAuthorization( + abi.encodeCall(IStandardExecutor.executeBatch, (outerCalls)), + _encodeSignature(comprehensivePluginValidation, SELECTOR_ASSOCIATED_VALIDATION, "") + ); + } + + function _enableBatchValidation() internal { + // Extend ComprehensivePlugin's validation function to also validate `executeBatch`, to allow the + // self-call. + + bytes4[] memory selectors = new bytes4[](1); + selectors[0] = IStandardExecutor.executeBatch.selector; + + vm.prank(owner1); + account1.executeWithAuthorization( + abi.encodeCall( + UpgradeableModularAccount.installValidation, + (comprehensivePluginValidation, false, selectors, "", "") + ), + _encodeSignature(_ownerValidation, DEFAULT_VALIDATION, "") + ); + } + + function _generateUserOpWithComprehensivePluginValidation(bytes memory callData) + internal + view + returns (PackedUserOperation memory) + { + uint256 nonce = entryPoint.getNonce(address(account1), 0); + return PackedUserOperation({ + sender: address(account1), + nonce: nonce, + initCode: hex"", + callData: callData, + accountGasLimits: _encodeGas(VERIFICATION_GAS_LIMIT, CALL_GAS_LIMIT), + preVerificationGas: 0, + gasFees: _encodeGas(1, 1), + paymasterAndData: hex"", + signature: _encodeSignature( + comprehensivePluginValidation, + SELECTOR_ASSOCIATED_VALIDATION, + // Comprehensive plugin's validation function doesn't actually check anything, so we don't need to + // sign anything. + "" + ) + }); + } +} From 509ba0391f2906004fccd90de14bb7f4b6e01d1d Mon Sep 17 00:00:00 2001 From: adam Date: Mon, 24 Jun 2024 15:05:38 -0400 Subject: [PATCH 4/5] refactor validation mapping --- src/account/AccountLoupe.sol | 14 +++++++++++--- src/account/AccountStorage.sol | 12 ++++++++++-- src/account/PluginManager2.sol | 15 +++++---------- src/account/PluginManagerInternals.sol | 8 ++------ src/account/UpgradeableModularAccount.sol | 5 ++--- src/helpers/KnownSelectors.sol | 3 +-- src/interfaces/IAccountLoupe.sol | 8 ++++---- src/interfaces/IPluginManager.sol | 2 -- test/account/AccountLoupe.t.sol | 20 ++++++-------------- test/account/MultiValidation.t.sol | 11 +++++++---- 10 files changed, 48 insertions(+), 50 deletions(-) diff --git a/src/account/AccountLoupe.sol b/src/account/AccountLoupe.sol index 89ffb04b..39ecc54b 100644 --- a/src/account/AccountLoupe.sol +++ b/src/account/AccountLoupe.sol @@ -7,7 +7,7 @@ import {EnumerableSet} from "@openzeppelin/contracts/utils/structs/EnumerableSet import {IAccountLoupe, ExecutionHook} from "../interfaces/IAccountLoupe.sol"; import {FunctionReference, IPluginManager} from "../interfaces/IPluginManager.sol"; import {IStandardExecutor} from "../interfaces/IStandardExecutor.sol"; -import {getAccountStorage, SelectorData, toFunctionReferenceArray, toExecutionHook} from "./AccountStorage.sol"; +import {getAccountStorage, SelectorData, toExecutionHook, toSelector} from "./AccountStorage.sol"; abstract contract AccountLoupe is IAccountLoupe { using EnumerableSet for EnumerableSet.Bytes32Set; @@ -28,8 +28,16 @@ abstract contract AccountLoupe is IAccountLoupe { } /// @inheritdoc IAccountLoupe - function getValidations(bytes4 selector) external view override returns (FunctionReference[] memory) { - return toFunctionReferenceArray(getAccountStorage().selectorData[selector].validations); + function getSelectors(FunctionReference validationFunction) external view returns (bytes4[] memory) { + uint256 length = getAccountStorage().validationData[validationFunction].selectors.length(); + + bytes4[] memory selectors = new bytes4[](length); + + for (uint256 i = 0; i < length; ++i) { + selectors[i] = toSelector(getAccountStorage().validationData[validationFunction].selectors.at(i)); + } + + return selectors; } /// @inheritdoc IAccountLoupe diff --git a/src/account/AccountStorage.sol b/src/account/AccountStorage.sol index ffdaff26..7f76e269 100644 --- a/src/account/AccountStorage.sol +++ b/src/account/AccountStorage.sol @@ -29,8 +29,6 @@ struct SelectorData { bool allowDefaultValidation; // The execution hooks for this function selector. EnumerableSet.Bytes32Set executionHooks; - // Which validation functions are associated with this function selector. - EnumerableSet.Bytes32Set validations; } struct ValidationData { @@ -40,6 +38,8 @@ struct ValidationData { bool isSignatureValidation; // The pre validation hooks for this function selector. FunctionReference[] preValidationHooks; + // The set of selectors that may be validated by this validation function. + EnumerableSet.Bytes32Set selectors; } struct AccountStorage { @@ -93,6 +93,14 @@ function toExecutionHook(bytes32 setValue) isPostHook = (uint256(setValue) >> 72) & 0xFF == 1; } +function toSetValue(bytes4 selector) pure returns (bytes32) { + return bytes32(selector); +} + +function toSelector(bytes32 setValue) pure returns (bytes4) { + return bytes4(setValue); +} + /// @dev Helper function to get all elements of a set into memory. function toFunctionReferenceArray(EnumerableSet.Bytes32Set storage set) view diff --git a/src/account/PluginManager2.sol b/src/account/PluginManager2.sol index 0e860848..aa09cffc 100644 --- a/src/account/PluginManager2.sol +++ b/src/account/PluginManager2.sol @@ -66,7 +66,7 @@ abstract contract PluginManager2 { for (uint256 i = 0; i < selectors.length; ++i) { bytes4 selector = selectors[i]; - if (!_storage.selectorData[selector].validations.add(toSetValue(validationFunction))) { + if (!_storage.validationData[validationFunction].selectors.add(toSetValue(selector))) { revert ValidationAlreadySet(selector, validationFunction); } } @@ -79,7 +79,6 @@ abstract contract PluginManager2 { function _uninstallValidation( FunctionReference validationFunction, - bytes4[] calldata selectors, bytes calldata uninstallData, bytes calldata preValidationHookUninstallData ) internal { @@ -102,14 +101,10 @@ abstract contract PluginManager2 { } delete _storage.validationData[validationFunction].preValidationHooks; - // Because this function also calls `onUninstall`, and removes the default flag from validation, we must - // assume these selectors passed in to be exhaustive. - // TODO: consider enforcing this from user-supplied install config. - for (uint256 i = 0; i < selectors.length; ++i) { - bytes4 selector = selectors[i]; - if (!_storage.selectorData[selector].validations.remove(toSetValue(validationFunction))) { - revert ValidationNotSet(selector, validationFunction); - } + // Clear selectors + while (_storage.validationData[validationFunction].selectors.length() > 0) { + bytes32 selector = _storage.validationData[validationFunction].selectors.at(0); + _storage.validationData[validationFunction].selectors.remove(selector); } if (uninstallData.length > 0) { diff --git a/src/account/PluginManagerInternals.sol b/src/account/PluginManagerInternals.sol index 2ec06f81..f2f25c95 100644 --- a/src/account/PluginManagerInternals.sol +++ b/src/account/PluginManagerInternals.sol @@ -103,12 +103,10 @@ abstract contract PluginManagerInternals is IPluginManager { internal notNullFunction(validationFunction) { - SelectorData storage _selectorData = getAccountStorage().selectorData[selector]; - // Fail on duplicate validation functions. Otherwise, dependency validation functions could shadow // non-depdency validation functions. Then, if a either plugin is uninstalled, it would cause a partial // uninstall of the other. - if (!_selectorData.validations.add(toSetValue(validationFunction))) { + if (!getAccountStorage().validationData[validationFunction].selectors.add(toSetValue(selector))) { revert ValidationFunctionAlreadySet(selector, validationFunction); } } @@ -117,11 +115,9 @@ abstract contract PluginManagerInternals is IPluginManager { internal notNullFunction(validationFunction) { - SelectorData storage _selectorData = getAccountStorage().selectorData[selector]; - // May ignore return value, as the manifest hash is validated to ensure that the validation function // exists. - _selectorData.validations.remove(toSetValue(validationFunction)); + getAccountStorage().validationData[validationFunction].selectors.remove(toSetValue(selector)); } function _addExecHooks( diff --git a/src/account/UpgradeableModularAccount.sol b/src/account/UpgradeableModularAccount.sol index ae53ed66..72a2a26f 100644 --- a/src/account/UpgradeableModularAccount.sol +++ b/src/account/UpgradeableModularAccount.sol @@ -271,11 +271,10 @@ contract UpgradeableModularAccount is /// @notice May be validated by a default validation. function uninstallValidation( FunctionReference validationFunction, - bytes4[] calldata selectors, bytes calldata uninstallData, bytes calldata preValidationHookUninstallData ) external wrapNativeFunction { - _uninstallValidation(validationFunction, selectors, uninstallData, preValidationHookUninstallData); + _uninstallValidation(validationFunction, uninstallData, preValidationHookUninstallData); } /// @notice ERC165 introspection @@ -623,7 +622,7 @@ contract UpgradeableModularAccount is } } else { // Not default validation, but per-selector - if (!getAccountStorage().selectorData[selector].validations.contains(toSetValue(validationFunction))) { + if (!getAccountStorage().validationData[validationFunction].selectors.contains(toSetValue(selector))) { revert UserOpValidationFunctionMissing(selector); } } diff --git a/src/helpers/KnownSelectors.sol b/src/helpers/KnownSelectors.sol index e5244d2c..1d02d2a3 100644 --- a/src/helpers/KnownSelectors.sol +++ b/src/helpers/KnownSelectors.sol @@ -34,8 +34,7 @@ library KnownSelectors { || selector == IStandardExecutor.executeWithAuthorization.selector // check against IAccountLoupe methods || selector == IAccountLoupe.getExecutionFunctionHandler.selector - || selector == IAccountLoupe.getValidations.selector - || selector == IAccountLoupe.getExecutionHooks.selector + || selector == IAccountLoupe.getSelectors.selector || selector == IAccountLoupe.getExecutionHooks.selector || selector == IAccountLoupe.getPreValidationHooks.selector || selector == IAccountLoupe.getInstalledPlugins.selector; } diff --git a/src/interfaces/IAccountLoupe.sol b/src/interfaces/IAccountLoupe.sol index 490b216c..a02d71a8 100644 --- a/src/interfaces/IAccountLoupe.sol +++ b/src/interfaces/IAccountLoupe.sol @@ -18,10 +18,10 @@ interface IAccountLoupe { /// @return plugin The plugin address for this selector. function getExecutionFunctionHandler(bytes4 selector) external view returns (address plugin); - /// @notice Get the validation functions for a selector. - /// @param selector The selector to get the validation functions for. - /// @return The validation functions for this selector. - function getValidations(bytes4 selector) external view returns (FunctionReference[] memory); + /// @notice Get the selectors for a validation function. + /// @param validationFunction The validation function to get the selectors for. + /// @return The allowed selectors for this validation function. + function getSelectors(FunctionReference validationFunction) external view returns (bytes4[] memory); /// @notice Get the pre and post execution hooks for a selector. /// @param selector The selector to get the hooks for. diff --git a/src/interfaces/IPluginManager.sol b/src/interfaces/IPluginManager.sol index 717e1fa0..d0b9400d 100644 --- a/src/interfaces/IPluginManager.sol +++ b/src/interfaces/IPluginManager.sol @@ -43,12 +43,10 @@ interface IPluginManager { /// @notice Uninstall a validation function from a set of execution selectors. /// TODO: remove or update. /// @param validationFunction The validation function to uninstall. - /// @param selectors The selectors to uninstall the validation function for. /// @param uninstallData Optional data to be decoded and used by the plugin to clear plugin data for the /// account. function uninstallValidation( FunctionReference validationFunction, - bytes4[] calldata selectors, bytes calldata uninstallData, bytes calldata preValidationHookUninstallData ) external; diff --git a/test/account/AccountLoupe.t.sol b/test/account/AccountLoupe.t.sol index c16ed1c6..9113c325 100644 --- a/test/account/AccountLoupe.t.sol +++ b/test/account/AccountLoupe.t.sol @@ -83,23 +83,15 @@ contract AccountLoupeTest is AccountTestBase { } } - function test_pluginLoupe_getValidationFunctions() public { - FunctionReference[] memory validations = account1.getValidations(comprehensivePlugin.foo.selector); - - assertEq(validations.length, 1); - assertEq( - FunctionReference.unwrap(validations[0]), - FunctionReference.unwrap( - FunctionReferenceLib.pack( - address(comprehensivePlugin), uint8(ComprehensivePlugin.FunctionId.VALIDATION) - ) - ) + function test_pluginLoupe_getSelectors() public { + FunctionReference comprehensivePluginValidation = FunctionReferenceLib.pack( + address(comprehensivePlugin), uint8(ComprehensivePlugin.FunctionId.VALIDATION) ); - validations = account1.getValidations(account1.execute.selector); + bytes4[] memory selectors = account1.getSelectors(comprehensivePluginValidation); - assertEq(validations.length, 1); - assertEq(FunctionReference.unwrap(validations[0]), FunctionReference.unwrap(_ownerValidation)); + assertEq(selectors.length, 1); + assertEq(selectors[0], comprehensivePlugin.foo.selector); } function test_pluginLoupe_getExecutionHooks() public { diff --git a/test/account/MultiValidation.t.sol b/test/account/MultiValidation.t.sol index 78867f55..9c79be9d 100644 --- a/test/account/MultiValidation.t.sol +++ b/test/account/MultiValidation.t.sol @@ -42,10 +42,13 @@ contract MultiValidationTest is AccountTestBase { ); validations[1] = FunctionReferenceLib.pack(address(validator2), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER)); - FunctionReference[] memory validations2 = account1.getValidations(IStandardExecutor.execute.selector); - assertEq(validations2.length, 2); - assertEq(FunctionReference.unwrap(validations2[0]), FunctionReference.unwrap(validations[0])); - assertEq(FunctionReference.unwrap(validations2[1]), FunctionReference.unwrap(validations[1])); + + bytes4[] memory selectors0 = account1.getSelectors(validations[0]); + bytes4[] memory selectors1 = account1.getSelectors(validations[1]); + assertEq(selectors0.length, selectors1.length); + for (uint256 i = 0; i < selectors0.length; i++) { + assertEq(selectors0[i], selectors1[i]); + } } function test_runtimeValidation_specify() public { From 2e49357cab5f47f07b36e1b1fd6e9eed89523e78 Mon Sep 17 00:00:00 2001 From: adam Date: Fri, 28 Jun 2024 15:49:01 -0400 Subject: [PATCH 5/5] composable validation example --- src/plugins/owner/ECDSAValidationPlugin.sol | 112 ++++++++ src/plugins/owner/MultisigPlugin.sol | 143 ++++++++++ test/account/ComposableValidation.t.sol | 273 ++++++++++++++++++++ 3 files changed, 528 insertions(+) create mode 100644 src/plugins/owner/ECDSAValidationPlugin.sol create mode 100644 src/plugins/owner/MultisigPlugin.sol create mode 100644 test/account/ComposableValidation.t.sol diff --git a/src/plugins/owner/ECDSAValidationPlugin.sol b/src/plugins/owner/ECDSAValidationPlugin.sol new file mode 100644 index 00000000..86b89bbb --- /dev/null +++ b/src/plugins/owner/ECDSAValidationPlugin.sol @@ -0,0 +1,112 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.25; + +import {ECDSA} from "@openzeppelin/contracts/utils/cryptography/ECDSA.sol"; +import {PackedUserOperation} from "@eth-infinitism/account-abstraction/interfaces/PackedUserOperation.sol"; +import {MessageHashUtils} from "@openzeppelin/contracts/utils/cryptography/MessageHashUtils.sol"; + +import {IPlugin} from "../../interfaces/IPlugin.sol"; +import {IValidation} from "../../interfaces/IValidation.sol"; +import {BasePlugin} from "../BasePlugin.sol"; +import {PluginManifest, PluginMetadata} from "../../interfaces/IPlugin.sol"; + +contract ECDSAValidationPlugin is IValidation, BasePlugin { + using ECDSA for bytes32; + using MessageHashUtils for bytes32; + + uint256 internal constant _SIG_VALIDATION_PASSED = 0; + uint256 internal constant _SIG_VALIDATION_FAILED = 1; + + // bytes4(keccak256("isValidSignature(bytes32,bytes)")) + bytes4 internal constant _1271_MAGIC_VALUE = 0x1626ba7e; + bytes4 internal constant _1271_INVALID = 0xffffffff; + + mapping(uint8 id => mapping(address account => address)) public owners; + + error AlreadyInitialized(); + error NotAuthorized(); + error NotInitialized(); + + /// @inheritdoc IPlugin + function onInstall(bytes calldata data) external override { + uint8 id = uint8(bytes1(data[:1])); + + if (owners[id][msg.sender] != address(0)) { + revert AlreadyInitialized(); + } + + address owner = abi.decode(data[1:], (address)); + owners[id][msg.sender] = owner; + } + + /// @inheritdoc IPlugin + function onUninstall(bytes calldata data) external override { + uint8 id = uint8(bytes1(data[:1])); + + if (owners[id][msg.sender] == address(0)) { + revert NotInitialized(); + } + + delete owners[id][msg.sender]; + } + + /// @inheritdoc IValidation + function validateRuntime(uint8 functionId, address sender, uint256, bytes calldata, bytes calldata) + external + view + override + { + // TODO: not composable here, need to add a param to `validateRuntime` to pass in the account. + if (sender != owners[functionId][msg.sender]) { + revert NotAuthorized(); + } + return; + } + + /// @inheritdoc IValidation + function validateUserOp(uint8 functionId, PackedUserOperation calldata userOp, bytes32 userOpHash) + external + view + override + returns (uint256) + { + // Validate the user op signature against the owner. + (address signer,,) = (userOpHash.toEthSignedMessageHash()).tryRecover(userOp.signature); + if (signer == address(0) || signer != owners[functionId][userOp.sender]) { + return _SIG_VALIDATION_FAILED; + } + return _SIG_VALIDATION_PASSED; + } + + // ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ + // ┃ Execution view functions ┃ + // ┗━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┛ + + /// @inheritdoc IValidation + /// @dev The signature is valid if it is signed by the owner's private key + /// (if the owner is an EOA) or if it is a valid ERC-1271 signature from the + /// owner (if the owner is a contract). Note that unlike the signature + /// validation used in `validateUserOp`, this does **not** wrap the digest in + /// an "Ethereum Signed Message" envelope before checking the signature in + /// the EOA-owner case. + function validateSignature(uint8 functionId, address, bytes32 digest, bytes calldata signature) + external + view + override + returns (bytes4) + { + // TODO: not composable here, need to add a param to `validateSignature` to pass in the account. + if (digest.recover(signature) == owners[functionId][msg.sender]) { + return _1271_MAGIC_VALUE; + } + return _1271_INVALID; + } + + /// @inheritdoc IPlugin + // solhint-disable-next-line no-empty-blocks + function pluginManifest() external pure override returns (PluginManifest memory) {} + + /// @inheritdoc IPlugin + // solhint-disable-next-line no-empty-blocks + function pluginMetadata() external pure virtual override returns (PluginMetadata memory) {} +} diff --git a/src/plugins/owner/MultisigPlugin.sol b/src/plugins/owner/MultisigPlugin.sol new file mode 100644 index 00000000..22cac637 --- /dev/null +++ b/src/plugins/owner/MultisigPlugin.sol @@ -0,0 +1,143 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.25; + +import {PackedUserOperation} from "@eth-infinitism/account-abstraction/interfaces/PackedUserOperation.sol"; + +import {FunctionReference, FunctionReferenceLib} from "../../helpers/FunctionReferenceLib.sol"; +import {_coalescePreValidation} from "../../helpers/ValidationDataHelpers.sol"; +import {IPlugin} from "../../interfaces/IPlugin.sol"; +import {IValidation} from "../../interfaces/IValidation.sol"; +import {BasePlugin} from "../BasePlugin.sol"; +import {PluginManifest, PluginMetadata} from "../../interfaces/IPlugin.sol"; + +// Non-threshold based multisig plugin - all owners must sign. +// Supports up to 100 owners per id. +contract MultisigPlugin is IValidation, BasePlugin { + struct OwnerInfo { + uint256 length; + FunctionReference[100] validations; + } + + uint256 internal constant _SIG_VALIDATION_PASSED = 0; + uint256 internal constant _SIG_VALIDATION_FAILED = 1; + + mapping(uint8 id => mapping(address account => OwnerInfo)) public ownerInfo; + + error AlreadyInitialized(); + error ArrayLengthMismatch(); + error NotAuthorized(); + error NotInitialized(); + error InvalidOwners(); + error ValidationReturnedAuthorizer(); + + /// @inheritdoc IPlugin + function onInstall(bytes calldata data) external override { + uint8 id = uint8(bytes1(data[:1])); + + if (ownerInfo[id][msg.sender].length != 0) { + revert AlreadyInitialized(); + } + + FunctionReference[] memory validations = abi.decode(data[1:], (FunctionReference[])); + + if (validations.length == 0 || validations.length > 100) { + revert InvalidOwners(); + } + + ownerInfo[id][msg.sender].length = validations.length; + + for (uint256 i = 0; i < validations.length; i++) { + ownerInfo[id][msg.sender].validations[i] = validations[i]; + } + } + + /// @inheritdoc IPlugin + function onUninstall(bytes calldata data) external override { + uint8 id = uint8(bytes1(data[:1])); + + uint256 length = ownerInfo[id][msg.sender].length; + + if (length == 0) { + revert NotInitialized(); + } + + for (uint256 i = 0; i < length; i++) { + ownerInfo[id][msg.sender].validations[i] = FunctionReference.wrap(bytes21(0)); + } + + ownerInfo[id][msg.sender].length = 0; + } + + /// @inheritdoc IValidation + function validateUserOp(uint8 functionId, PackedUserOperation calldata userOp, bytes32 userOpHash) + external + override + returns (uint256) + { + OwnerInfo storage info = ownerInfo[functionId][userOp.sender]; + + if (info.length == 0) { + revert NotInitialized(); + } + + FunctionReference[] memory validations = new FunctionReference[](info.length); + + for (uint256 i = 0; i < info.length; i++) { + validations[i] = info.validations[i]; + } + + uint256 result = _SIG_VALIDATION_PASSED; + + //decode the inner signatures from the userOp + bytes[] memory innerSignatures = abi.decode(userOp.signature, (bytes[])); + + if (innerSignatures.length != validations.length) { + revert ArrayLengthMismatch(); + } + + PackedUserOperation memory innerUserOp = userOp; + + for (uint256 i = 0; i < validations.length; i++) { + innerUserOp.signature = innerSignatures[i]; + (address validationPlugin, uint8 validationId) = FunctionReferenceLib.unpack(validations[i]); + uint256 validationResult = + IValidation(validationPlugin).validateUserOp(validationId, innerUserOp, userOpHash); + // Ensure no authorizer is returned + if (uint160(validationResult) > 1) { + revert ValidationReturnedAuthorizer(); + } + + result = _coalescePreValidation(result, validationResult); + } + + return result; + } + + /// @inheritdoc IValidation + function validateRuntime(uint8, address, uint256, bytes calldata, bytes calldata) external pure override { + revert NotImplemented(); + } + + // ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ + // ┃ Execution view functions ┃ + // ┗━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┛ + + /// @inheritdoc IValidation + /// @dev The signature is valid if it is signed by the owner's private key + /// (if the owner is an EOA) or if it is a valid ERC-1271 signature from the + /// owner (if the owner is a contract). Note that unlike the signature + /// validation used in `validateUserOp`, this does **not** wrap the digest in + /// an "Ethereum Signed Message" envelope before checking the signature in + /// the EOA-owner case. + function validateSignature(uint8, address, bytes32, bytes calldata) external pure override returns (bytes4) { + revert NotImplemented(); + } + + /// @inheritdoc IPlugin + // solhint-disable-next-line no-empty-blocks + function pluginManifest() external pure override returns (PluginManifest memory) {} + + /// @inheritdoc IPlugin + // solhint-disable-next-line no-empty-blocks + function pluginMetadata() external pure virtual override returns (PluginMetadata memory) {} +} diff --git a/test/account/ComposableValidation.t.sol b/test/account/ComposableValidation.t.sol new file mode 100644 index 00000000..903e4ccf --- /dev/null +++ b/test/account/ComposableValidation.t.sol @@ -0,0 +1,273 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.25; + +import {PackedUserOperation} from "@eth-infinitism/account-abstraction/interfaces/PackedUserOperation.sol"; +import {MessageHashUtils} from "@openzeppelin/contracts/utils/cryptography/MessageHashUtils.sol"; + +import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAccount.sol"; +import {ECDSAValidationPlugin} from "../../src/plugins/owner/ECDSAValidationPlugin.sol"; +import {IStandardExecutor, Call} from "../../src/interfaces/IStandardExecutor.sol"; +import {MultisigPlugin} from "../../src/plugins/owner/MultisigPlugin.sol"; +import {FunctionReference, FunctionReferenceLib} from "../../src/helpers/FunctionReferenceLib.sol"; + +import {CustomValidationTestBase} from "../utils/CustomValidationTestBase.sol"; + +contract ComposableValidationTest is CustomValidationTestBase { + using MessageHashUtils for bytes32; + + ECDSAValidationPlugin public ecdsaValidationPlugin; + MultisigPlugin public multisigPlugin; + + function setUp() public { + ecdsaValidationPlugin = new ECDSAValidationPlugin(); + multisigPlugin = new MultisigPlugin(); + + _ownerValidation = FunctionReferenceLib.pack(address(ecdsaValidationPlugin), uint8(123)); + } + + function test_basicUserOp_withECDSAValidation() public { + _customValidationSetup(); + + // Now that the account is set up with the ECDSAValidationPlugin, we can test the basic user op + PackedUserOperation memory userOp = PackedUserOperation({ + sender: address(account1), + nonce: 0, + initCode: hex"", + callData: abi.encodeCall(IStandardExecutor.execute, (beneficiary, 0, hex"")), + accountGasLimits: _encodeGas(VERIFICATION_GAS_LIMIT, CALL_GAS_LIMIT), + preVerificationGas: 0, + gasFees: _encodeGas(1, 1), + paymasterAndData: hex"", + signature: hex"" + }); + + bytes32 userOpHash = entryPoint.getUserOpHash(userOp); + (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash()); + + userOp.signature = _encodeSignature(_ownerValidation, DEFAULT_VALIDATION, abi.encodePacked(r, s, v)); + + PackedUserOperation[] memory userOps = new PackedUserOperation[](1); + userOps[0] = userOp; + + entryPoint.handleOps(userOps, beneficiary); + } + + function test_basicUserOp_withComposableMultisig_oneLayer() public { + (address owner2, uint256 owner2Key) = makeAddrAndKey("owner2"); + (address owner3, uint256 owner3Key) = makeAddrAndKey("owner3"); + + _customValidationSetup(); + + // Install the multisig plugin with signers 2 and 3 + + FunctionReference composableMultisigValidation = + FunctionReferenceLib.pack(address(multisigPlugin), uint8(0)); + FunctionReference owner2Validation = FunctionReferenceLib.pack(address(ecdsaValidationPlugin), uint8(2)); + FunctionReference owner3Validation = FunctionReferenceLib.pack(address(ecdsaValidationPlugin), uint8(3)); + + FunctionReference[] memory multisigSigners = new FunctionReference[](2); + multisigSigners[0] = owner2Validation; + multisigSigners[1] = owner3Validation; + + // Set up the composable MultisigPlugin + Call[] memory calls = new Call[](3); + calls[0] = Call( + address(ecdsaValidationPlugin), + 0, + abi.encodeCall(ECDSAValidationPlugin.onInstall, (abi.encodePacked(uint8(2), abi.encode(owner2)))) + ); + calls[1] = Call( + address(ecdsaValidationPlugin), + 0, + abi.encodeCall(ECDSAValidationPlugin.onInstall, (abi.encodePacked(uint8(3), abi.encode(owner3)))) + ); + calls[2] = Call( + address(account1), + 0, + abi.encodeCall( + UpgradeableModularAccount.installValidation, + ( + composableMultisigValidation, + true, + new bytes4[](0), + abi.encodePacked(uint8(0), abi.encode(multisigSigners)), + "" + ) + ) + ); + + vm.prank(owner1); + account1.executeWithAuthorization( + abi.encodeCall(IStandardExecutor.executeBatch, (calls)), + _encodeSignature(_ownerValidation, DEFAULT_VALIDATION, "") + ); + + // test the multisig validation + + PackedUserOperation memory userOp = PackedUserOperation({ + sender: address(account1), + nonce: 0, + initCode: hex"", + callData: abi.encodeCall(IStandardExecutor.execute, (beneficiary, 0, hex"")), + accountGasLimits: _encodeGas(VERIFICATION_GAS_LIMIT, CALL_GAS_LIMIT), + preVerificationGas: 0, + gasFees: _encodeGas(1, 1), + paymasterAndData: hex"", + signature: hex"" + }); + + bytes32 userOpHash = entryPoint.getUserOpHash(userOp); + + (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner2Key, userOpHash.toEthSignedMessageHash()); + bytes memory owner2Signature = abi.encodePacked(r, s, v); + + (v, r, s) = vm.sign(owner3Key, userOpHash.toEthSignedMessageHash()); + bytes memory owner3Signature = abi.encodePacked(r, s, v); + + bytes[] memory signatures = new bytes[](2); + signatures[0] = owner2Signature; + signatures[1] = owner3Signature; + + userOp.signature = + _encodeSignature(composableMultisigValidation, DEFAULT_VALIDATION, abi.encode(signatures)); + + PackedUserOperation[] memory userOps = new PackedUserOperation[](1); + userOps[0] = userOp; + + entryPoint.handleOps(userOps, beneficiary); + } + + function test_basicUserOp_withComposableMultisig_twoLayers() public { + (address owner2, uint256 owner2Key) = makeAddrAndKey("owner2"); + (address owner3, uint256 owner3Key) = makeAddrAndKey("owner3"); + (address owner4, uint256 owner4Key) = makeAddrAndKey("owner4"); + + _customValidationSetup(); + + // create signers 2, 3, 4 + // Install the multisig plugin with [signer 2, another multisig [signer 3, signer 4]] + + // To prevent stack too deep, put it in memory. + // 0 = outerMultisigValidation + // 1 = owner2Validation + // 2 = innerMultisigValidation + // 3 = owner3Validation + // 4 = owner4Validation + + FunctionReference[5] memory validations; + + validations[0] = FunctionReferenceLib.pack(address(multisigPlugin), uint8(0)); + validations[1] = FunctionReferenceLib.pack(address(ecdsaValidationPlugin), uint8(2)); + validations[2] = FunctionReferenceLib.pack(address(multisigPlugin), uint8(1)); + validations[3] = FunctionReferenceLib.pack(address(ecdsaValidationPlugin), uint8(3)); + validations[4] = FunctionReferenceLib.pack(address(ecdsaValidationPlugin), uint8(4)); + + FunctionReference[] memory innerMultisigSigners = new FunctionReference[](2); + innerMultisigSigners[0] = validations[3]; + innerMultisigSigners[1] = validations[4]; + + FunctionReference[] memory outerMultisigSigners = new FunctionReference[](2); + outerMultisigSigners[0] = validations[1]; + outerMultisigSigners[1] = validations[2]; + + // Set up the ComposableMultisigPlugin + Call[] memory calls = new Call[](5); + calls[0] = Call( + address(ecdsaValidationPlugin), + 0, + abi.encodeCall(ECDSAValidationPlugin.onInstall, (abi.encodePacked(uint8(2), abi.encode(owner2)))) + ); + calls[1] = Call( + address(ecdsaValidationPlugin), + 0, + abi.encodeCall(ECDSAValidationPlugin.onInstall, (abi.encodePacked(uint8(3), abi.encode(owner3)))) + ); + calls[2] = Call( + address(ecdsaValidationPlugin), + 0, + abi.encodeCall(ECDSAValidationPlugin.onInstall, (abi.encodePacked(uint8(4), abi.encode(owner4)))) + ); + calls[3] = Call( + address(multisigPlugin), + 0, + abi.encodeCall( + ECDSAValidationPlugin.onInstall, (abi.encodePacked(uint8(1), abi.encode(innerMultisigSigners))) + ) + ); + calls[4] = Call( + address(account1), + 0, + abi.encodeCall( + UpgradeableModularAccount.installValidation, + ( + validations[0], + true, + new bytes4[](0), + abi.encodePacked(uint8(0), abi.encode(outerMultisigSigners)), + "" + ) + ) + ); + + vm.prank(owner1); + account1.executeWithAuthorization( + abi.encodeCall(IStandardExecutor.executeBatch, (calls)), + _encodeSignature(_ownerValidation, DEFAULT_VALIDATION, "") + ); + + // test the multisig of multisigs validation + + PackedUserOperation memory userOp = PackedUserOperation({ + sender: address(account1), + nonce: 0, + initCode: hex"", + callData: abi.encodeCall(IStandardExecutor.execute, (beneficiary, 0, hex"")), + accountGasLimits: _encodeGas(VERIFICATION_GAS_LIMIT, CALL_GAS_LIMIT), + preVerificationGas: 0, + gasFees: _encodeGas(1, 1), + paymasterAndData: hex"", + signature: hex"" + }); + + bytes32 userOpHash = entryPoint.getUserOpHash(userOp); + + (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner2Key, userOpHash.toEthSignedMessageHash()); + bytes memory owner2Signature = abi.encodePacked(r, s, v); + + (v, r, s) = vm.sign(owner3Key, userOpHash.toEthSignedMessageHash()); + bytes memory owner3Signature = abi.encodePacked(r, s, v); + + (v, r, s) = vm.sign(owner4Key, userOpHash.toEthSignedMessageHash()); + bytes memory owner4Signature = abi.encodePacked(r, s, v); + + bytes[] memory innerSignatures = new bytes[](2); + innerSignatures[0] = owner3Signature; + innerSignatures[1] = owner4Signature; + + bytes[] memory outerSignatures = new bytes[](2); + outerSignatures[0] = owner2Signature; + outerSignatures[1] = abi.encode(innerSignatures); + + userOp.signature = _encodeSignature(validations[0], DEFAULT_VALIDATION, abi.encode(outerSignatures)); + + PackedUserOperation[] memory userOps = new PackedUserOperation[](1); + userOps[0] = userOp; + + entryPoint.handleOps(userOps, beneficiary); + } + + function _initialValidationConfig() + internal + virtual + override + returns (FunctionReference, bool, bytes4[] memory, bytes memory, bytes memory) + { + return ( + _ownerValidation, + true, + new bytes4[](0), + abi.encodePacked(uint8(123), abi.encode(owner1)), + abi.encodePacked("") + ); + } +}