From a3f4b808e16f67f7dcb52f21198aa899d5201c47 Mon Sep 17 00:00:00 2001 From: adam Date: Mon, 20 May 2024 18:59:41 -0700 Subject: [PATCH 1/6] refactor authorized caller and runtime validation flow --- src/account/UpgradeableModularAccount.sol | 107 ++++++++++++------ src/interfaces/IPluginExecutor.sol | 10 ++ src/interfaces/IValidation.sol | 8 +- src/plugins/owner/SingleOwnerPlugin.sol | 6 +- test/account/AccountExecHooks.t.sol | 4 + test/account/AccountLoupe.t.sol | 1 + test/account/AccountReturnData.t.sol | 21 +++- .../ExecuteFromPluginPermissions.t.sol | 3 + test/account/ManifestValidity.t.sol | 1 + test/account/UpgradeableModularAccount.t.sol | 22 ++-- test/account/ValidationIntersection.t.sol | 2 +- test/mocks/plugins/ComprehensivePlugin.sol | 6 +- test/mocks/plugins/ValidationPluginMocks.sol | 2 +- test/plugin/SingleOwnerPlugin.t.sol | 4 +- test/plugin/TokenReceiverPlugin.t.sol | 7 +- test/utils/AccountTestBase.sol | 13 ++- 16 files changed, 158 insertions(+), 59 deletions(-) diff --git a/src/account/UpgradeableModularAccount.sol b/src/account/UpgradeableModularAccount.sol index 1cd0d98e..3c8b3d80 100644 --- a/src/account/UpgradeableModularAccount.sol +++ b/src/account/UpgradeableModularAccount.sol @@ -81,7 +81,7 @@ contract UpgradeableModularAccount is // Wraps execution of a native function with runtime validation and hooks // Used for upgradeTo, upgradeToAndCall, execute, executeBatch, installPlugin, uninstallPlugin modifier wrapNativeFunction() { - _doRuntimeValidationIfNotFromEP(); + _checkPermittedCallerIfNotFromEP(); PostExecToRun[] memory postExecHooks = _doPreExecHooks(msg.sig, msg.data); @@ -133,7 +133,7 @@ contract UpgradeableModularAccount is revert UnrecognizedFunction(msg.sig); } - _doRuntimeValidationIfNotFromEP(); + _checkPermittedCallerIfNotFromEP(); PostExecToRun[] memory postExecHooks; // Cache post-exec hooks in memory @@ -262,6 +262,41 @@ contract UpgradeableModularAccount is return returnData; } + /// @inheritdoc IPluginExecutor + function executeWithAuthorization(bytes calldata data, bytes calldata authorization) + external + payable + returns (bytes memory) + { + bytes4 execSelector = bytes4(data[0:4]); + + FunctionReference runtimeValidationFunction = FunctionReference.wrap(bytes21(authorization[0:21])); + + AccountStorage storage _storage = getAccountStorage(); + + // check if that runtime validation function is allowed to be called + if (_storage.selectorData[execSelector].denyExecutionCount > 0) { + revert AlwaysDenyRule(); + } + if (_storage.selectorData[execSelector].validation.notEq(runtimeValidationFunction)) { + revert RuntimeValidationFunctionMissing(execSelector); + } + + _doRuntimeValidation(runtimeValidationFunction, data, authorization[21:]); + + // If runtime validation passes, execute the call + + (bool success, bytes memory returnData) = address(this).call(data); + + if (!success) { + assembly ("memory-safe") { + revert(add(returnData, 32), mload(returnData)) + } + } + + return returnData; + } + /// @inheritdoc IPluginManager function installPlugin( address plugin, @@ -412,50 +447,40 @@ contract UpgradeableModularAccount is } } - function _doRuntimeValidationIfNotFromEP() internal { - AccountStorage storage _storage = getAccountStorage(); - - if (_storage.selectorData[msg.sig].denyExecutionCount > 0) { - revert AlwaysDenyRule(); - } - - if (msg.sender == address(_ENTRY_POINT)) return; - - FunctionReference runtimeValidationFunction = _storage.selectorData[msg.sig].validation; + function _doRuntimeValidation( + FunctionReference runtimeValidationFunction, + bytes calldata callData, + bytes calldata authorizationData + ) internal { // run all preRuntimeValidation hooks EnumerableSet.Bytes32Set storage preRuntimeValidationHooks = - getAccountStorage().selectorData[msg.sig].preValidationHooks; + getAccountStorage().selectorData[bytes4(callData[0:4])].preValidationHooks; uint256 preRuntimeValidationHooksLength = preRuntimeValidationHooks.length(); for (uint256 i = 0; i < preRuntimeValidationHooksLength; ++i) { bytes32 key = preRuntimeValidationHooks.at(i); FunctionReference preRuntimeValidationHook = toFunctionReference(key); - (address plugin, uint8 functionId) = preRuntimeValidationHook.unpack(); + (address hookPlugin, uint8 hookFunctionId) = preRuntimeValidationHook.unpack(); + try IValidationHook(hookPlugin).preRuntimeValidationHook( + hookFunctionId, msg.sender, msg.value, callData + ) + // forgefmt: disable-start // solhint-disable-next-line no-empty-blocks - try IValidationHook(plugin).preRuntimeValidationHook(functionId, msg.sender, msg.value, msg.data) {} - catch (bytes memory revertReason) { - revert PreRuntimeValidationHookFailed(plugin, functionId, revertReason); + {} catch (bytes memory revertReason) { + // forgefmt: disable-end + revert PreRuntimeValidationHookFailed(hookPlugin, hookFunctionId, revertReason); } } - // Identifier scope limiting - { - if (_storage.selectorData[msg.sig].isPublic) { - // If the function is public, we don't need to check the runtime validation function. - return; - } - - if (runtimeValidationFunction.isEmpty()) { - revert RuntimeValidationFunctionMissing(msg.sig); - } + (address plugin, uint8 functionId) = runtimeValidationFunction.unpack(); - (address plugin, uint8 functionId) = runtimeValidationFunction.unpack(); - // solhint-disable-next-line no-empty-blocks - try IValidation(plugin).validateRuntime(functionId, msg.sender, msg.value, msg.data) {} - catch (bytes memory revertReason) { - revert RuntimeValidationFunctionReverted(plugin, functionId, revertReason); - } + try IValidation(plugin).validateRuntime(functionId, msg.sender, msg.value, callData, authorizationData) + // forgefmt: disable-start + // solhint-disable-next-line no-empty-blocks + {} catch (bytes memory revertReason) { + // forgefmt: disable-end + revert RuntimeValidationFunctionReverted(plugin, functionId, revertReason); } } @@ -536,4 +561,20 @@ contract UpgradeableModularAccount is // solhint-disable-next-line no-empty-blocks function _authorizeUpgrade(address newImplementation) internal override {} + + function _checkPermittedCallerIfNotFromEP() internal view { + AccountStorage storage _storage = getAccountStorage(); + + if (_storage.selectorData[msg.sig].denyExecutionCount > 0) { + revert AlwaysDenyRule(); + } + if ( + msg.sender == address(_ENTRY_POINT) || msg.sender == address(this) + || _storage.selectorData[msg.sig].isPublic + ) return; + + if (!_storage.callPermitted[msg.sender][msg.sig]) { + revert ExecFromPluginNotPermitted(msg.sender, msg.sig); + } + } } diff --git a/src/interfaces/IPluginExecutor.sol b/src/interfaces/IPluginExecutor.sol index e1989958..2ff5ed11 100644 --- a/src/interfaces/IPluginExecutor.sol +++ b/src/interfaces/IPluginExecutor.sol @@ -20,4 +20,14 @@ interface IPluginExecutor { external payable returns (bytes memory); + + /// @notice Execute a call using a specified runtime validation, as given in the first 21 bytes of + /// `authorization`. + /// @param data The calldata to send to the account. + /// @param authorization The authorization data to use for the call. The first 21 bytes specifies which runtime + /// validation to use, and the rest is sent as a parameter to runtime validation. + function executeWithAuthorization(bytes calldata data, bytes calldata authorization) + external + payable + returns (bytes memory); } diff --git a/src/interfaces/IValidation.sol b/src/interfaces/IValidation.sol index b4edddcc..b3adcd3d 100644 --- a/src/interfaces/IValidation.sol +++ b/src/interfaces/IValidation.sol @@ -23,7 +23,13 @@ interface IValidation is IPlugin { /// @param sender The caller address. /// @param value The call value. /// @param data The calldata sent. - function validateRuntime(uint8 functionId, address sender, uint256 value, bytes calldata data) external; + function validateRuntime( + uint8 functionId, + address sender, + uint256 value, + bytes calldata data, + bytes calldata authorization + ) external; /// @notice Validates a signature using ERC-1271. /// @dev To indicate the entire call should revert, the function MUST revert. diff --git a/src/plugins/owner/SingleOwnerPlugin.sol b/src/plugins/owner/SingleOwnerPlugin.sol index bebbad21..1018d908 100644 --- a/src/plugins/owner/SingleOwnerPlugin.sol +++ b/src/plugins/owner/SingleOwnerPlugin.sol @@ -79,7 +79,11 @@ contract SingleOwnerPlugin is ISingleOwnerPlugin, BasePlugin { } /// @inheritdoc IValidation - function validateRuntime(uint8 functionId, address sender, uint256, bytes calldata) external view override { + function validateRuntime(uint8 functionId, address sender, uint256, bytes calldata, bytes calldata) + external + view + override + { if (functionId == uint8(FunctionId.VALIDATION_OWNER_OR_SELF)) { // Validate that the sender is the owner of the account or self. if (sender != _owners[msg.sender] && sender != msg.sender) { diff --git a/test/account/AccountExecHooks.t.sol b/test/account/AccountExecHooks.t.sol index 3605ddd5..ee229e3e 100644 --- a/test/account/AccountExecHooks.t.sol +++ b/test/account/AccountExecHooks.t.sol @@ -228,6 +228,7 @@ contract AccountExecHooksTest is AccountTestBase { vm.expectEmit(true, true, true, true); emit PluginInstalled(address(mockPlugin1), manifestHash1, new FunctionReference[](0)); + vm.prank(address(entryPoint)); account1.installPlugin({ plugin: address(mockPlugin1), manifestHash: manifestHash1, @@ -251,6 +252,7 @@ contract AccountExecHooksTest is AccountTestBase { vm.expectEmit(true, true, true, true); emit PluginInstalled(address(mockPlugin1), manifestHash1, new FunctionReference[](0)); + vm.prank(address(entryPoint)); account1.installPlugin({ plugin: address(mockPlugin1), manifestHash: manifestHash1, @@ -274,6 +276,7 @@ contract AccountExecHooksTest is AccountTestBase { vm.expectEmit(true, true, true, true); emit PluginInstalled(address(mockPlugin2), manifestHash2, new FunctionReference[](0)); + vm.prank(address(entryPoint)); account1.installPlugin({ plugin: address(mockPlugin2), manifestHash: manifestHash2, @@ -288,6 +291,7 @@ contract AccountExecHooksTest is AccountTestBase { vm.expectEmit(true, true, true, true); emit PluginUninstalled(address(plugin), true); + vm.prank(address(entryPoint)); account1.uninstallPlugin(address(plugin), bytes(""), bytes("")); } } diff --git a/test/account/AccountLoupe.t.sol b/test/account/AccountLoupe.t.sol index 34c914a7..253691af 100644 --- a/test/account/AccountLoupe.t.sol +++ b/test/account/AccountLoupe.t.sol @@ -25,6 +25,7 @@ contract AccountLoupeTest is AccountTestBase { comprehensivePlugin = new ComprehensivePlugin(); bytes32 manifestHash = keccak256(abi.encode(comprehensivePlugin.pluginManifest())); + vm.prank(address(entryPoint)); account1.installPlugin(address(comprehensivePlugin), manifestHash, "", new FunctionReference[](0)); ownerValidation = FunctionReferenceLib.pack( diff --git a/test/account/AccountReturnData.t.sol b/test/account/AccountReturnData.t.sol index 46463ed1..35def6b1 100644 --- a/test/account/AccountReturnData.t.sol +++ b/test/account/AccountReturnData.t.sol @@ -3,6 +3,7 @@ pragma solidity ^0.8.19; import {FunctionReference} from "../../src/helpers/FunctionReferenceLib.sol"; import {Call} from "../../src/interfaces/IStandardExecutor.sol"; +import {ISingleOwnerPlugin} from "../../src/plugins/owner/ISingleOwnerPlugin.sol"; import { RegularResultContract, @@ -26,6 +27,7 @@ contract AccountReturnDataTest is AccountTestBase { // Add the result creator plugin to the account bytes32 resultCreatorManifestHash = keccak256(abi.encode(resultCreatorPlugin.pluginManifest())); + vm.prank(address(entryPoint)); account1.installPlugin({ plugin: address(resultCreatorPlugin), manifestHash: resultCreatorManifestHash, @@ -34,6 +36,7 @@ contract AccountReturnDataTest is AccountTestBase { }); // Add the result consumer plugin to the account bytes32 resultConsumerManifestHash = keccak256(abi.encode(resultConsumerPlugin.pluginManifest())); + vm.prank(address(entryPoint)); account1.installPlugin({ plugin: address(resultConsumerPlugin), manifestHash: resultConsumerManifestHash, @@ -51,10 +54,15 @@ contract AccountReturnDataTest is AccountTestBase { // Tests the ability to read the results of contracts called via IStandardExecutor.execute function test_returnData_singular_execute() public { - bytes memory returnData = - account1.execute(address(regularResultContract), 0, abi.encodeCall(RegularResultContract.foo, ())); + bytes memory returnData = account1.executeWithAuthorization( + abi.encodeCall( + account1.execute, + (address(regularResultContract), 0, abi.encodeCall(RegularResultContract.foo, ())) + ), + abi.encodePacked(singleOwnerPlugin, ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER_OR_SELF) + ); - bytes32 result = abi.decode(returnData, (bytes32)); + bytes32 result = abi.decode(abi.decode(returnData, (bytes)), (bytes32)); assertEq(result, keccak256("bar")); } @@ -73,7 +81,12 @@ contract AccountReturnDataTest is AccountTestBase { data: abi.encodeCall(RegularResultContract.bar, ()) }); - bytes[] memory returnDatas = account1.executeBatch(calls); + bytes memory retData = account1.executeWithAuthorization( + abi.encodeCall(account1.executeBatch, (calls)), + abi.encodePacked(singleOwnerPlugin, ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER_OR_SELF) + ); + + bytes[] memory returnDatas = abi.decode(retData, (bytes[])); bytes32 result1 = abi.decode(returnDatas[0], (bytes32)); bytes32 result2 = abi.decode(returnDatas[1], (bytes32)); diff --git a/test/account/ExecuteFromPluginPermissions.t.sol b/test/account/ExecuteFromPluginPermissions.t.sol index 1e9f17e2..6a95ba18 100644 --- a/test/account/ExecuteFromPluginPermissions.t.sol +++ b/test/account/ExecuteFromPluginPermissions.t.sol @@ -35,6 +35,7 @@ contract ExecuteFromPluginPermissionsTest is AccountTestBase { // Add the result creator plugin to the account bytes32 resultCreatorManifestHash = keccak256(abi.encode(resultCreatorPlugin.pluginManifest())); + vm.prank(address(entryPoint)); account1.installPlugin({ plugin: address(resultCreatorPlugin), manifestHash: resultCreatorManifestHash, @@ -43,6 +44,7 @@ contract ExecuteFromPluginPermissionsTest is AccountTestBase { }); // Add the EFP caller plugin to the account bytes32 efpCallerManifestHash = keccak256(abi.encode(efpCallerPlugin.pluginManifest())); + vm.prank(address(entryPoint)); account1.installPlugin({ plugin: address(efpCallerPlugin), manifestHash: efpCallerManifestHash, @@ -53,6 +55,7 @@ contract ExecuteFromPluginPermissionsTest is AccountTestBase { // Add the EFP caller plugin with any external permissions to the account bytes32 efpCallerAnyExternalManifestHash = keccak256(abi.encode(efpCallerPluginAnyExternal.pluginManifest())); + vm.prank(address(entryPoint)); account1.installPlugin({ plugin: address(efpCallerPluginAnyExternal), manifestHash: efpCallerAnyExternalManifestHash, diff --git a/test/account/ManifestValidity.t.sol b/test/account/ManifestValidity.t.sol index 8200a017..08c2609d 100644 --- a/test/account/ManifestValidity.t.sol +++ b/test/account/ManifestValidity.t.sol @@ -18,6 +18,7 @@ contract ManifestValidityTest is AccountTestBase { bytes32 manifestHash = keccak256(abi.encode(plugin.pluginManifest())); + vm.prank(address(entryPoint)); vm.expectRevert(abi.encodeWithSelector(PluginManagerInternals.InvalidPluginManifest.selector)); account1.installPlugin({ plugin: address(plugin), diff --git a/test/account/UpgradeableModularAccount.t.sol b/test/account/UpgradeableModularAccount.t.sol index c5a061e0..9aab77f6 100644 --- a/test/account/UpgradeableModularAccount.t.sol +++ b/test/account/UpgradeableModularAccount.t.sol @@ -233,7 +233,7 @@ contract UpgradeableModularAccountTest is AccountTestBase { } function test_installPlugin() public { - vm.startPrank(owner1); + vm.startPrank(address(entryPoint)); bytes32 manifestHash = keccak256(abi.encode(tokenReceiverPlugin.pluginManifest())); @@ -253,7 +253,7 @@ contract UpgradeableModularAccountTest is AccountTestBase { } function test_installPlugin_ExecuteFromPlugin_PermittedExecSelectorNotInstalled() public { - vm.startPrank(owner1); + vm.startPrank(address(entryPoint)); PluginManifest memory m; m.permittedExecutionSelectors = new bytes4[](1); @@ -271,7 +271,7 @@ contract UpgradeableModularAccountTest is AccountTestBase { } function test_installPlugin_invalidManifest() public { - vm.startPrank(owner1); + vm.startPrank(address(entryPoint)); vm.expectRevert(abi.encodeWithSelector(PluginManagerInternals.InvalidPluginManifest.selector)); IPluginManager(account1).installPlugin({ @@ -283,7 +283,7 @@ contract UpgradeableModularAccountTest is AccountTestBase { } function test_installPlugin_interfaceNotSupported() public { - vm.startPrank(owner1); + vm.startPrank(address(entryPoint)); address badPlugin = address(1); vm.expectRevert( @@ -298,7 +298,7 @@ contract UpgradeableModularAccountTest is AccountTestBase { } function test_installPlugin_alreadyInstalled() public { - vm.startPrank(owner1); + vm.startPrank(address(entryPoint)); bytes32 manifestHash = keccak256(abi.encode(tokenReceiverPlugin.pluginManifest())); IPluginManager(account1).installPlugin({ @@ -322,7 +322,7 @@ contract UpgradeableModularAccountTest is AccountTestBase { } function test_uninstallPlugin_default() public { - vm.startPrank(owner1); + vm.startPrank(address(entryPoint)); ComprehensivePlugin plugin = new ComprehensivePlugin(); bytes32 manifestHash = keccak256(abi.encode(plugin.pluginManifest())); @@ -342,7 +342,7 @@ contract UpgradeableModularAccountTest is AccountTestBase { } function test_uninstallPlugin_manifestParameter() public { - vm.startPrank(owner1); + vm.startPrank(address(entryPoint)); ComprehensivePlugin plugin = new ComprehensivePlugin(); bytes memory serializedManifest = abi.encode(plugin.pluginManifest()); @@ -367,7 +367,7 @@ contract UpgradeableModularAccountTest is AccountTestBase { } function test_uninstallPlugin_invalidManifestFails() public { - vm.startPrank(owner1); + vm.startPrank(address(entryPoint)); ComprehensivePlugin plugin = new ComprehensivePlugin(); bytes memory serializedManifest = abi.encode(plugin.pluginManifest()); @@ -395,7 +395,7 @@ contract UpgradeableModularAccountTest is AccountTestBase { } function _installPluginWithExecHooks() internal returns (MockPlugin plugin) { - vm.startPrank(owner2); + vm.startPrank(address(entryPoint)); plugin = new MockPlugin(manifest); bytes32 manifestHash = keccak256(abi.encode(plugin.pluginManifest())); @@ -411,7 +411,7 @@ contract UpgradeableModularAccountTest is AccountTestBase { } function test_upgradeToAndCall() public { - vm.startPrank(owner1); + vm.startPrank(address(entryPoint)); UpgradeableModularAccount account3 = new UpgradeableModularAccount(entryPoint); bytes32 slot = account3.proxiableUUID(); @@ -427,7 +427,7 @@ contract UpgradeableModularAccountTest is AccountTestBase { function test_transferOwnership() public { assertEq(singleOwnerPlugin.ownerOf(address(account1)), owner1); - vm.prank(owner1); + vm.prank(address(entryPoint)); account1.execute( address(singleOwnerPlugin), 0, abi.encodeCall(SingleOwnerPlugin.transferOwnership, (owner2)) ); diff --git a/test/account/ValidationIntersection.t.sol b/test/account/ValidationIntersection.t.sol index 7d451730..9315d7e6 100644 --- a/test/account/ValidationIntersection.t.sol +++ b/test/account/ValidationIntersection.t.sol @@ -26,7 +26,7 @@ contract ValidationIntersectionTest is AccountTestBase { oneHookPlugin = new MockUserOpValidation1HookPlugin(); twoHookPlugin = new MockUserOpValidation2HookPlugin(); - vm.startPrank(address(owner1)); + vm.startPrank(address(entryPoint)); account1.installPlugin({ plugin: address(noHookPlugin), manifestHash: keccak256(abi.encode(noHookPlugin.pluginManifest())), diff --git a/test/mocks/plugins/ComprehensivePlugin.sol b/test/mocks/plugins/ComprehensivePlugin.sol index 9aac5545..ec40368f 100644 --- a/test/mocks/plugins/ComprehensivePlugin.sol +++ b/test/mocks/plugins/ComprehensivePlugin.sol @@ -84,7 +84,11 @@ contract ComprehensivePlugin is IValidation, IValidationHook, IExecutionHook, Ba revert NotImplemented(); } - function validateRuntime(uint8 functionId, address, uint256, bytes calldata) external pure override { + function validateRuntime(uint8 functionId, address, uint256, bytes calldata, bytes calldata) + external + pure + override + { if (functionId == uint8(FunctionId.VALIDATION)) { return; } diff --git a/test/mocks/plugins/ValidationPluginMocks.sol b/test/mocks/plugins/ValidationPluginMocks.sol index 554f589f..443ee0bb 100644 --- a/test/mocks/plugins/ValidationPluginMocks.sol +++ b/test/mocks/plugins/ValidationPluginMocks.sol @@ -71,7 +71,7 @@ abstract contract MockBaseUserOpValidationPlugin is IValidation, IValidationHook revert NotImplemented(); } - function validateRuntime(uint8, address, uint256, bytes calldata) external pure override { + function validateRuntime(uint8, address, uint256, bytes calldata, bytes calldata) external pure override { revert NotImplemented(); } } diff --git a/test/plugin/SingleOwnerPlugin.t.sol b/test/plugin/SingleOwnerPlugin.t.sol index 71afbdbf..a6a1900d 100644 --- a/test/plugin/SingleOwnerPlugin.t.sol +++ b/test/plugin/SingleOwnerPlugin.t.sol @@ -114,11 +114,11 @@ contract SingleOwnerPluginTest is OptimizedTest { assertEq(address(0), plugin.owner()); plugin.transferOwnership(owner1); assertEq(owner1, plugin.owner()); - plugin.validateRuntime(uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER_OR_SELF), owner1, 0, ""); + plugin.validateRuntime(uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER_OR_SELF), owner1, 0, "", ""); vm.startPrank(b); vm.expectRevert(ISingleOwnerPlugin.NotAuthorized.selector); - plugin.validateRuntime(uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER_OR_SELF), owner1, 0, ""); + plugin.validateRuntime(uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER_OR_SELF), owner1, 0, "", ""); } function testFuzz_validateUserOpSig(string memory salt, PackedUserOperation memory userOp) public { diff --git a/test/plugin/TokenReceiverPlugin.t.sol b/test/plugin/TokenReceiverPlugin.t.sol index 7a7433af..0e111020 100644 --- a/test/plugin/TokenReceiverPlugin.t.sol +++ b/test/plugin/TokenReceiverPlugin.t.sol @@ -1,7 +1,7 @@ // SPDX-License-Identifier: UNLICENSED pragma solidity ^0.8.19; -import {IEntryPoint} from "@eth-infinitism/account-abstraction/interfaces/IEntryPoint.sol"; +import {EntryPoint} from "@eth-infinitism/account-abstraction/core/EntryPoint.sol"; import {IERC721Receiver} from "@openzeppelin/contracts/token/ERC721/IERC721Receiver.sol"; import {IERC1155Receiver} from "@openzeppelin/contracts/token/ERC1155/IERC1155Receiver.sol"; @@ -15,6 +15,7 @@ import {MockERC1155} from "../mocks/MockERC1155.sol"; import {OptimizedTest} from "../utils/OptimizedTest.sol"; contract TokenReceiverPluginTest is OptimizedTest, IERC1155Receiver { + EntryPoint public entryPoint; UpgradeableModularAccount public acct; TokenReceiverPlugin public plugin; @@ -32,7 +33,8 @@ contract TokenReceiverPluginTest is OptimizedTest, IERC1155Receiver { uint256 internal constant _BATCH_TOKEN_IDS = 5; function setUp() public { - MSCAFactoryFixture factory = new MSCAFactoryFixture(IEntryPoint(address(0)), _deploySingleOwnerPlugin()); + entryPoint = new EntryPoint(); + MSCAFactoryFixture factory = new MSCAFactoryFixture(entryPoint, _deploySingleOwnerPlugin()); acct = factory.createAccount(address(this), 0); plugin = _deployTokenReceiverPlugin(); @@ -53,6 +55,7 @@ contract TokenReceiverPluginTest is OptimizedTest, IERC1155Receiver { function _initPlugin() internal { bytes32 manifestHash = keccak256(abi.encode(plugin.pluginManifest())); + vm.prank(address(entryPoint)); acct.installPlugin(address(plugin), manifestHash, "", new FunctionReference[](0)); } diff --git a/test/utils/AccountTestBase.sol b/test/utils/AccountTestBase.sol index d8c890fe..de9d56c2 100644 --- a/test/utils/AccountTestBase.sol +++ b/test/utils/AccountTestBase.sol @@ -4,6 +4,7 @@ pragma solidity ^0.8.19; import {EntryPoint} from "@eth-infinitism/account-abstraction/core/EntryPoint.sol"; import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAccount.sol"; +import {ISingleOwnerPlugin} from "../../src/plugins/owner/ISingleOwnerPlugin.sol"; import {SingleOwnerPlugin} from "../../src/plugins/owner/SingleOwnerPlugin.sol"; import {OptimizedTest} from "./OptimizedTest.sol"; @@ -37,8 +38,16 @@ abstract contract AccountTestBase is OptimizedTest { function _transferOwnershipToTest() internal { // Transfer ownership to test contract for easier invocation. vm.prank(owner1); - account1.execute( - address(singleOwnerPlugin), 0, abi.encodeCall(SingleOwnerPlugin.transferOwnership, (address(this))) + account1.executeWithAuthorization( + abi.encodeCall( + account1.execute, + ( + address(singleOwnerPlugin), + 0, + abi.encodeCall(SingleOwnerPlugin.transferOwnership, (address(this))) + ) + ), + abi.encodePacked(address(singleOwnerPlugin), ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER_OR_SELF) ); } From e39a3da63d39227284b18d1b490b3831e0576568 Mon Sep 17 00:00:00 2001 From: adam Date: Tue, 21 May 2024 10:21:19 -0700 Subject: [PATCH 2/6] add multi-validation --- src/account/AccountLoupe.sol | 17 +++---- src/account/AccountStorage.sol | 2 +- src/account/PluginManagerInternals.sol | 10 ++-- src/account/UpgradeableModularAccount.sol | 17 +++++-- src/interfaces/IAccountLoupe.sol | 17 +++---- test/account/AccountLoupe.t.sol | 51 ++++++++++---------- test/account/UpgradeableModularAccount.t.sol | 20 +++++--- test/account/ValidationIntersection.t.sol | 31 +++++++++++- 8 files changed, 104 insertions(+), 61 deletions(-) diff --git a/src/account/AccountLoupe.sol b/src/account/AccountLoupe.sol index ca4da1dd..44870e90 100644 --- a/src/account/AccountLoupe.sol +++ b/src/account/AccountLoupe.sol @@ -20,11 +20,7 @@ abstract contract AccountLoupe is IAccountLoupe { using EnumerableSet for EnumerableSet.AddressSet; /// @inheritdoc IAccountLoupe - function getExecutionFunctionConfig(bytes4 selector) - external - view - returns (ExecutionFunctionConfig memory config) - { + function getExecutionFunctionHandler(bytes4 selector) external view returns (address plugin) { AccountStorage storage _storage = getAccountStorage(); if ( @@ -33,12 +29,15 @@ abstract contract AccountLoupe is IAccountLoupe { || selector == IPluginManager.installPlugin.selector || selector == IPluginManager.uninstallPlugin.selector ) { - config.plugin = address(this); - } else { - config.plugin = _storage.selectorData[selector].plugin; + return address(this); } - config.validationFunction = _storage.selectorData[selector].validation; + return _storage.selectorData[selector].plugin; + } + + /// @inheritdoc IAccountLoupe + function getValidationFunctions(bytes4 selector) external view returns (FunctionReference[] memory) { + return toFunctionReferenceArray(getAccountStorage().selectorData[selector].validations); } /// @inheritdoc IAccountLoupe diff --git a/src/account/AccountStorage.sol b/src/account/AccountStorage.sol index 205af107..4319f470 100644 --- a/src/account/AccountStorage.sol +++ b/src/account/AccountStorage.sol @@ -45,7 +45,7 @@ struct SelectorData { // but it packs alongside `plugin` while still leaving some other space in the slot for future packing. uint48 denyExecutionCount; // User operation validation and runtime validation share a function reference. - FunctionReference validation; + EnumerableSet.Bytes32Set validations; // The pre validation hooks for this function selector. EnumerableSet.Bytes32Set preValidationHooks; // The execution hooks for this function selector. diff --git a/src/account/PluginManagerInternals.sol b/src/account/PluginManagerInternals.sol index 7ded3176..7f8337e4 100644 --- a/src/account/PluginManagerInternals.sol +++ b/src/account/PluginManagerInternals.sol @@ -86,11 +86,11 @@ abstract contract PluginManagerInternals is IPluginManager { { SelectorData storage _selectorData = getAccountStorage().selectorData[selector]; - if (_selectorData.validation.notEmpty()) { + // Fail on duplicate definitions - otherwise dependencies could shadow non-depdency + // validation functions, leading to partial uninstalls. + if (!_selectorData.validations.add(toSetValue(validationFunction))) { revert ValidationFunctionAlreadySet(selector, validationFunction); } - - _selectorData.validation = validationFunction; } function _removeValidationFunction(bytes4 selector, FunctionReference validationFunction) @@ -99,7 +99,9 @@ abstract contract PluginManagerInternals is IPluginManager { { SelectorData storage _selectorData = getAccountStorage().selectorData[selector]; - _selectorData.validation = FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE; + // May ignore return value, as the manifest hash is validated to ensure that the validation function + // exists. + _selectorData.validations.remove(toSetValue(validationFunction)); } function _addExecHooks( diff --git a/src/account/UpgradeableModularAccount.sol b/src/account/UpgradeableModularAccount.sol index 3c8b3d80..84a73cd3 100644 --- a/src/account/UpgradeableModularAccount.sol +++ b/src/account/UpgradeableModularAccount.sol @@ -278,7 +278,7 @@ contract UpgradeableModularAccount is if (_storage.selectorData[execSelector].denyExecutionCount > 0) { revert AlwaysDenyRule(); } - if (_storage.selectorData[execSelector].validation.notEq(runtimeValidationFunction)) { + if (!_storage.selectorData[execSelector].validations.contains(toSetValue(runtimeValidationFunction))) { revert RuntimeValidationFunctionMissing(execSelector); } @@ -395,18 +395,27 @@ contract UpgradeableModularAccount is revert AlwaysDenyRule(); } - FunctionReference userOpValidationFunction = getAccountStorage().selectorData[selector].validation; + FunctionReference userOpValidationFunction = FunctionReference.wrap(bytes21(userOp.signature[:21])); - validationData = _doUserOpValidation(selector, userOpValidationFunction, userOp, userOpHash); + if (!getAccountStorage().selectorData[selector].validations.contains(toSetValue(userOpValidationFunction))) + { + revert UserOpValidationFunctionMissing(selector); + } + + validationData = + _doUserOpValidation(selector, userOpValidationFunction, userOp, userOp.signature[21:], userOpHash); } // To support gas estimation, we don't fail early when the failure is caused by a signature failure function _doUserOpValidation( bytes4 selector, FunctionReference userOpValidationFunction, - PackedUserOperation calldata userOp, + PackedUserOperation memory userOp, + bytes calldata signature, bytes32 userOpHash ) internal returns (uint256 validationData) { + userOp.signature = signature; + if (userOpValidationFunction.isEmpty()) { // If the validation function is empty, then the call cannot proceed. revert UserOpValidationFunctionMissing(selector); diff --git a/src/interfaces/IAccountLoupe.sol b/src/interfaces/IAccountLoupe.sol index a1b3c15f..91a648f1 100644 --- a/src/interfaces/IAccountLoupe.sol +++ b/src/interfaces/IAccountLoupe.sol @@ -12,17 +12,16 @@ struct ExecutionHook { } interface IAccountLoupe { - /// @notice Config for an execution function, given a selector. - struct ExecutionFunctionConfig { - address plugin; - FunctionReference validationFunction; - } - - /// @notice Get the validation functions and plugin address for a selector. + /// @notice Get the plugin address for a selector. /// @dev If the selector is a native function, the plugin address will be the address of the account. /// @param selector The selector to get the configuration for. - /// @return The configuration for this selector. - function getExecutionFunctionConfig(bytes4 selector) external view returns (ExecutionFunctionConfig memory); + /// @return plugin The plugin address for this selector. + function getExecutionFunctionHandler(bytes4 selector) external view returns (address plugin); + + /// @notice Get the validation functions for a selector. + /// @param selector The selector to get the validation functions for. + /// @return The validation functions for this selector. + function getValidationFunctions(bytes4 selector) external view returns (FunctionReference[] memory); /// @notice Get the pre and post execution hooks for a selector. /// @param selector The selector to get the hooks for. diff --git a/test/account/AccountLoupe.t.sol b/test/account/AccountLoupe.t.sol index 253691af..846037df 100644 --- a/test/account/AccountLoupe.t.sol +++ b/test/account/AccountLoupe.t.sol @@ -4,7 +4,7 @@ pragma solidity ^0.8.19; import {UUPSUpgradeable} from "@openzeppelin/contracts/proxy/utils/UUPSUpgradeable.sol"; import {FunctionReference, FunctionReferenceLib} from "../../src/helpers/FunctionReferenceLib.sol"; -import {IAccountLoupe, ExecutionHook} from "../../src/interfaces/IAccountLoupe.sol"; +import {ExecutionHook} from "../../src/interfaces/IAccountLoupe.sol"; import {IPluginManager} from "../../src/interfaces/IPluginManager.sol"; import {IStandardExecutor} from "../../src/interfaces/IStandardExecutor.sol"; import {ISingleOwnerPlugin} from "../../src/plugins/owner/ISingleOwnerPlugin.sol"; @@ -42,60 +42,59 @@ contract AccountLoupeTest is AccountTestBase { assertEq(plugins[1], address(comprehensivePlugin)); } - function test_pluginLoupe_getExecutionFunctionConfig_native() public { + function test_pluginLoupe_getExecutionFunctionHandler_native() public { bytes4[] memory selectorsToCheck = new bytes4[](5); - FunctionReference[] memory expectedValidations = new FunctionReference[](5); selectorsToCheck[0] = IStandardExecutor.execute.selector; - expectedValidations[0] = ownerValidation; selectorsToCheck[1] = IStandardExecutor.executeBatch.selector; - expectedValidations[1] = ownerValidation; selectorsToCheck[2] = UUPSUpgradeable.upgradeToAndCall.selector; - expectedValidations[2] = ownerValidation; selectorsToCheck[3] = IPluginManager.installPlugin.selector; - expectedValidations[3] = ownerValidation; selectorsToCheck[4] = IPluginManager.uninstallPlugin.selector; - expectedValidations[4] = ownerValidation; for (uint256 i = 0; i < selectorsToCheck.length; i++) { - IAccountLoupe.ExecutionFunctionConfig memory config = - account1.getExecutionFunctionConfig(selectorsToCheck[i]); + address plugin = account1.getExecutionFunctionHandler(selectorsToCheck[i]); - assertEq(config.plugin, address(account1)); - assertEq( - FunctionReference.unwrap(config.validationFunction), - FunctionReference.unwrap(expectedValidations[i]) - ); + assertEq(plugin, address(account1)); } } function test_pluginLoupe_getExecutionFunctionConfig_plugin() public { bytes4[] memory selectorsToCheck = new bytes4[](1); address[] memory expectedPluginAddress = new address[](1); - FunctionReference[] memory expectedValidations = new FunctionReference[](1); selectorsToCheck[0] = comprehensivePlugin.foo.selector; expectedPluginAddress[0] = address(comprehensivePlugin); - expectedValidations[0] = FunctionReferenceLib.pack( - address(comprehensivePlugin), uint8(ComprehensivePlugin.FunctionId.VALIDATION) - ); for (uint256 i = 0; i < selectorsToCheck.length; i++) { - IAccountLoupe.ExecutionFunctionConfig memory config = - account1.getExecutionFunctionConfig(selectorsToCheck[i]); + address plugin = account1.getExecutionFunctionHandler(selectorsToCheck[i]); - assertEq(config.plugin, expectedPluginAddress[i]); - assertEq( - FunctionReference.unwrap(config.validationFunction), - FunctionReference.unwrap(expectedValidations[i]) - ); + assertEq(plugin, expectedPluginAddress[i]); } } + function test_pluginLoupe_getValidationFunctions() public { + FunctionReference[] memory validations = account1.getValidationFunctions(comprehensivePlugin.foo.selector); + + assertEq(validations.length, 1); + assertEq( + FunctionReference.unwrap(validations[0]), + FunctionReference.unwrap( + FunctionReferenceLib.pack( + address(comprehensivePlugin), uint8(ComprehensivePlugin.FunctionId.VALIDATION) + ) + ) + ); + + validations = account1.getValidationFunctions(account1.execute.selector); + + assertEq(validations.length, 1); + assertEq(FunctionReference.unwrap(validations[0]), FunctionReference.unwrap(ownerValidation)); + } + function test_pluginLoupe_getExecutionHooks() public { ExecutionHook[] memory hooks = account1.getExecutionHooks(comprehensivePlugin.foo.selector); ExecutionHook[3] memory expectedHooks = [ diff --git a/test/account/UpgradeableModularAccount.t.sol b/test/account/UpgradeableModularAccount.t.sol index 9aab77f6..01337691 100644 --- a/test/account/UpgradeableModularAccount.t.sol +++ b/test/account/UpgradeableModularAccount.t.sol @@ -10,7 +10,7 @@ import {IERC1271} from "@openzeppelin/contracts/interfaces/IERC1271.sol"; import {PluginManagerInternals} from "../../src/account/PluginManagerInternals.sol"; import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAccount.sol"; -import {FunctionReference} from "../../src/helpers/FunctionReferenceLib.sol"; +import {FunctionReference, FunctionReferenceLib} from "../../src/helpers/FunctionReferenceLib.sol"; import {IPlugin, PluginManifest} from "../../src/interfaces/IPlugin.sol"; import {IAccountLoupe} from "../../src/interfaces/IAccountLoupe.sol"; import {IPluginManager} from "../../src/interfaces/IPluginManager.sol"; @@ -39,6 +39,8 @@ contract UpgradeableModularAccountTest is AccountTestBase { Counter public counter; PluginManifest public manifest; + FunctionReference public ownerValidation; + uint256 public constant CALL_GAS_LIMIT = 50000; uint256 public constant VERIFICATION_GAS_LIMIT = 1200000; @@ -59,6 +61,10 @@ contract UpgradeableModularAccountTest is AccountTestBase { vm.deal(ethRecipient, 1 wei); counter = new Counter(); counter.increment(); // amoritze away gas cost of zero->nonzero transition + + ownerValidation = FunctionReferenceLib.pack( + address(singleOwnerPlugin), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER_OR_SELF) + ); } function test_deployAccount() public { @@ -81,7 +87,7 @@ contract UpgradeableModularAccountTest is AccountTestBase { // Generate signature bytes32 userOpHash = entryPoint.getUserOpHash(userOp); (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash()); - userOp.signature = abi.encodePacked(r, s, v); + userOp.signature = abi.encodePacked(ownerValidation, r, s, v); PackedUserOperation[] memory userOps = new PackedUserOperation[](1); userOps[0] = userOp; @@ -110,7 +116,7 @@ contract UpgradeableModularAccountTest is AccountTestBase { // Generate signature bytes32 userOpHash = entryPoint.getUserOpHash(userOp); (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner2Key, userOpHash.toEthSignedMessageHash()); - userOp.signature = abi.encodePacked(r, s, v); + userOp.signature = abi.encodePacked(ownerValidation, r, s, v); PackedUserOperation[] memory userOps = new PackedUserOperation[](1); userOps[0] = userOp; @@ -136,7 +142,7 @@ contract UpgradeableModularAccountTest is AccountTestBase { // Generate signature bytes32 userOpHash = entryPoint.getUserOpHash(userOp); (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner2Key, userOpHash.toEthSignedMessageHash()); - userOp.signature = abi.encodePacked(r, s, v); + userOp.signature = abi.encodePacked(ownerValidation, r, s, v); PackedUserOperation[] memory userOps = new PackedUserOperation[](1); userOps[0] = userOp; @@ -162,7 +168,7 @@ contract UpgradeableModularAccountTest is AccountTestBase { // Generate signature bytes32 userOpHash = entryPoint.getUserOpHash(userOp); (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash()); - userOp.signature = abi.encodePacked(r, s, v); + userOp.signature = abi.encodePacked(ownerValidation, r, s, v); PackedUserOperation[] memory userOps = new PackedUserOperation[](1); userOps[0] = userOp; @@ -190,7 +196,7 @@ contract UpgradeableModularAccountTest is AccountTestBase { // Generate signature bytes32 userOpHash = entryPoint.getUserOpHash(userOp); (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash()); - userOp.signature = abi.encodePacked(r, s, v); + userOp.signature = abi.encodePacked(ownerValidation, r, s, v); PackedUserOperation[] memory userOps = new PackedUserOperation[](1); userOps[0] = userOp; @@ -221,7 +227,7 @@ contract UpgradeableModularAccountTest is AccountTestBase { // Generate signature bytes32 userOpHash = entryPoint.getUserOpHash(userOp); (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash()); - userOp.signature = abi.encodePacked(r, s, v); + userOp.signature = abi.encodePacked(ownerValidation, r, s, v); PackedUserOperation[] memory userOps = new PackedUserOperation[](1); userOps[0] = userOp; diff --git a/test/account/ValidationIntersection.t.sol b/test/account/ValidationIntersection.t.sol index 9315d7e6..9c54e7b0 100644 --- a/test/account/ValidationIntersection.t.sol +++ b/test/account/ValidationIntersection.t.sol @@ -4,7 +4,7 @@ pragma solidity ^0.8.19; import {PackedUserOperation} from "@eth-infinitism/account-abstraction/interfaces/PackedUserOperation.sol"; import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAccount.sol"; -import {FunctionReference} from "../../src/helpers/FunctionReferenceLib.sol"; +import {FunctionReference, FunctionReferenceLib} from "../../src/helpers/FunctionReferenceLib.sol"; import { MockBaseUserOpValidationPlugin, @@ -21,11 +21,30 @@ contract ValidationIntersectionTest is AccountTestBase { MockUserOpValidation1HookPlugin public oneHookPlugin; MockUserOpValidation2HookPlugin public twoHookPlugin; + FunctionReference public noHookValidation; + FunctionReference public oneHookValidation; + FunctionReference public twoHookValidation; + function setUp() public { noHookPlugin = new MockUserOpValidationPlugin(); oneHookPlugin = new MockUserOpValidation1HookPlugin(); twoHookPlugin = new MockUserOpValidation2HookPlugin(); + noHookValidation = FunctionReferenceLib.pack({ + addr: address(noHookPlugin), + functionId: uint8(MockBaseUserOpValidationPlugin.FunctionId.USER_OP_VALIDATION) + }); + + oneHookValidation = FunctionReferenceLib.pack({ + addr: address(oneHookPlugin), + functionId: uint8(MockBaseUserOpValidationPlugin.FunctionId.USER_OP_VALIDATION) + }); + + twoHookValidation = FunctionReferenceLib.pack({ + addr: address(twoHookPlugin), + functionId: uint8(MockBaseUserOpValidationPlugin.FunctionId.USER_OP_VALIDATION) + }); + vm.startPrank(address(entryPoint)); account1.installPlugin({ plugin: address(noHookPlugin), @@ -53,6 +72,7 @@ contract ValidationIntersectionTest is AccountTestBase { PackedUserOperation memory userOp; userOp.callData = bytes.concat(noHookPlugin.foo.selector); + userOp.signature = abi.encodePacked(noHookValidation); bytes32 uoHash = entryPoint.getUserOpHash(userOp); vm.prank(address(entryPoint)); @@ -69,6 +89,7 @@ contract ValidationIntersectionTest is AccountTestBase { PackedUserOperation memory userOp; userOp.callData = bytes.concat(oneHookPlugin.bar.selector); + userOp.signature = abi.encodePacked(oneHookValidation); bytes32 uoHash = entryPoint.getUserOpHash(userOp); vm.prank(address(entryPoint)); @@ -86,6 +107,7 @@ contract ValidationIntersectionTest is AccountTestBase { PackedUserOperation memory userOp; userOp.callData = bytes.concat(oneHookPlugin.bar.selector); + userOp.signature = abi.encodePacked(oneHookValidation); bytes32 uoHash = entryPoint.getUserOpHash(userOp); vm.prank(address(entryPoint)); @@ -108,6 +130,7 @@ contract ValidationIntersectionTest is AccountTestBase { PackedUserOperation memory userOp; userOp.callData = bytes.concat(oneHookPlugin.bar.selector); + userOp.signature = abi.encodePacked(oneHookValidation); bytes32 uoHash = entryPoint.getUserOpHash(userOp); vm.prank(address(entryPoint)); @@ -129,6 +152,7 @@ contract ValidationIntersectionTest is AccountTestBase { PackedUserOperation memory userOp; userOp.callData = bytes.concat(oneHookPlugin.bar.selector); + userOp.signature = abi.encodePacked(oneHookValidation); bytes32 uoHash = entryPoint.getUserOpHash(userOp); vm.prank(address(entryPoint)); @@ -148,6 +172,7 @@ contract ValidationIntersectionTest is AccountTestBase { PackedUserOperation memory userOp; userOp.callData = bytes.concat(oneHookPlugin.bar.selector); + userOp.signature = abi.encodePacked(oneHookValidation); bytes32 uoHash = entryPoint.getUserOpHash(userOp); vm.prank(address(entryPoint)); @@ -172,6 +197,7 @@ contract ValidationIntersectionTest is AccountTestBase { PackedUserOperation memory userOp; userOp.callData = bytes.concat(oneHookPlugin.bar.selector); + userOp.signature = abi.encodePacked(oneHookValidation); bytes32 uoHash = entryPoint.getUserOpHash(userOp); vm.prank(address(entryPoint)); @@ -195,6 +221,7 @@ contract ValidationIntersectionTest is AccountTestBase { PackedUserOperation memory userOp; userOp.callData = bytes.concat(oneHookPlugin.bar.selector); + userOp.signature = abi.encodePacked(oneHookValidation); bytes32 uoHash = entryPoint.getUserOpHash(userOp); vm.prank(address(entryPoint)); @@ -218,6 +245,7 @@ contract ValidationIntersectionTest is AccountTestBase { PackedUserOperation memory userOp; userOp.callData = bytes.concat(twoHookPlugin.baz.selector); + userOp.signature = abi.encodePacked(twoHookValidation); bytes32 uoHash = entryPoint.getUserOpHash(userOp); vm.prank(address(entryPoint)); @@ -236,6 +264,7 @@ contract ValidationIntersectionTest is AccountTestBase { PackedUserOperation memory userOp; userOp.callData = bytes.concat(twoHookPlugin.baz.selector); + userOp.signature = abi.encodePacked(twoHookValidation); bytes32 uoHash = entryPoint.getUserOpHash(userOp); vm.prank(address(entryPoint)); From 8b282cf4b89d6e82f7d115794bb2ca686f9aa329 Mon Sep 17 00:00:00 2001 From: adam Date: Tue, 21 May 2024 10:33:10 -0700 Subject: [PATCH 3/6] Add multi validation test --- test/account/MultiValidation.t.sol | 123 +++++++++++++++++++++++++++++ 1 file changed, 123 insertions(+) create mode 100644 test/account/MultiValidation.t.sol diff --git a/test/account/MultiValidation.t.sol b/test/account/MultiValidation.t.sol new file mode 100644 index 00000000..4102932d --- /dev/null +++ b/test/account/MultiValidation.t.sol @@ -0,0 +1,123 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {ECDSA} from "@openzeppelin/contracts/utils/cryptography/ECDSA.sol"; +import {PackedUserOperation} from "@eth-infinitism/account-abstraction/interfaces/PackedUserOperation.sol"; +import {MessageHashUtils} from "@openzeppelin/contracts/utils/cryptography/MessageHashUtils.sol"; + +import {IEntryPoint} from "@eth-infinitism/account-abstraction/interfaces/IEntryPoint.sol"; + +import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAccount.sol"; +import {FunctionReference} from "../../src/interfaces/IPluginManager.sol"; +import {IStandardExecutor} from "../../src/interfaces/IStandardExecutor.sol"; +import {FunctionReferenceLib} from "../../src/helpers/FunctionReferenceLib.sol"; +import {SingleOwnerPlugin} from "../../src/plugins/owner/SingleOwnerPlugin.sol"; +import {ISingleOwnerPlugin} from "../../src/plugins/owner/ISingleOwnerPlugin.sol"; + +import {AccountTestBase} from "../utils/AccountTestBase.sol"; + +contract MultiValidationTest is AccountTestBase { + using ECDSA for bytes32; + using MessageHashUtils for bytes32; + + SingleOwnerPlugin public validator2; + + address public owner2; + uint256 public owner2Key; + + uint256 public constant CALL_GAS_LIMIT = 50000; + uint256 public constant VERIFICATION_GAS_LIMIT = 1200000; + + function setUp() public { + validator2 = new SingleOwnerPlugin(); + + (owner2, owner2Key) = makeAddrAndKey("owner2"); + } + + function test_overlappingValidationInstall() public { + bytes32 manifestHash = keccak256(abi.encode(validator2.pluginManifest())); + vm.prank(address(entryPoint)); + account1.installPlugin(address(validator2), manifestHash, abi.encode(owner2), new FunctionReference[](0)); + + FunctionReference[] memory validations = new FunctionReference[](2); + validations[0] = FunctionReferenceLib.pack( + address(singleOwnerPlugin), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER_OR_SELF) + ); + validations[1] = FunctionReferenceLib.pack( + address(validator2), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER_OR_SELF) + ); + FunctionReference[] memory validations2 = + account1.getValidationFunctions(IStandardExecutor.execute.selector); + assertEq(validations2.length, 2); + assertEq(FunctionReference.unwrap(validations2[0]), FunctionReference.unwrap(validations[0])); + assertEq(FunctionReference.unwrap(validations2[1]), FunctionReference.unwrap(validations[1])); + } + + function test_runtimeValidation_specify() public { + test_overlappingValidationInstall(); + + // Assert that the runtime validation can be specified. + + vm.prank(owner1); + vm.expectRevert( + abi.encodeWithSelector( + UpgradeableModularAccount.RuntimeValidationFunctionReverted.selector, + address(validator2), + 0, + abi.encodeWithSignature("NotAuthorized()") + ) + ); + account1.executeWithAuthorization( + abi.encodeCall(IStandardExecutor.execute, (address(0), 0, "")), + abi.encodePacked( + address(validator2), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER_OR_SELF) + ) + ); + + vm.prank(owner2); + account1.executeWithAuthorization( + abi.encodeCall(IStandardExecutor.execute, (address(0), 0, "")), + abi.encodePacked( + address(validator2), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER_OR_SELF) + ) + ); + } + + function test_userOpValidation_specify() public { + test_overlappingValidationInstall(); + + // Assert that the userOp validation can be specified. + + PackedUserOperation memory userOp = PackedUserOperation({ + sender: address(account1), + nonce: 0, + initCode: "", + callData: abi.encodeCall(UpgradeableModularAccount.execute, (address(0), 0, "")), + accountGasLimits: _encodeGas(VERIFICATION_GAS_LIMIT, CALL_GAS_LIMIT), + preVerificationGas: 0, + gasFees: _encodeGas(1, 1), + paymasterAndData: "", + signature: "" + }); + + // Generate signature + bytes32 userOpHash = entryPoint.getUserOpHash(userOp); + (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner2Key, userOpHash.toEthSignedMessageHash()); + userOp.signature = abi.encodePacked(address(validator2), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER_OR_SELF), r, s, v); + + PackedUserOperation[] memory userOps = new PackedUserOperation[](1); + userOps[0] = userOp; + + entryPoint.handleOps(userOps, beneficiary); + + // Sign with owner 1, expect fail + + userOp.nonce = 1; + (v, r, s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash()); + userOp.signature = abi.encodePacked(address(validator2), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER_OR_SELF), r, s, v); + + userOps[0] = userOp; + vm.expectRevert(abi.encodeWithSelector(IEntryPoint.FailedOp.selector, 0, "AA24 signature error")); + entryPoint.handleOps(userOps, beneficiary); + } +} \ No newline at end of file From c8d2b456d903ad9e820c9679cf188a6e505747bc Mon Sep 17 00:00:00 2001 From: adam Date: Tue, 21 May 2024 10:36:34 -0700 Subject: [PATCH 4/6] Update FunctionId naming --- src/plugins/owner/ISingleOwnerPlugin.sol | 2 +- src/plugins/owner/SingleOwnerPlugin.sol | 6 ++--- test/account/AccountLoupe.t.sol | 2 +- test/account/AccountReturnData.t.sol | 4 ++-- test/account/MultiValidation.t.sol | 23 +++++++++----------- test/account/UpgradeableModularAccount.t.sol | 2 +- test/plugin/SingleOwnerPlugin.t.sol | 13 +++++------ test/utils/AccountTestBase.sol | 2 +- 8 files changed, 24 insertions(+), 30 deletions(-) diff --git a/src/plugins/owner/ISingleOwnerPlugin.sol b/src/plugins/owner/ISingleOwnerPlugin.sol index 6d80eb50..57bcac80 100644 --- a/src/plugins/owner/ISingleOwnerPlugin.sol +++ b/src/plugins/owner/ISingleOwnerPlugin.sol @@ -5,7 +5,7 @@ import {IValidation} from "../../interfaces/IValidation.sol"; interface ISingleOwnerPlugin is IValidation { enum FunctionId { - VALIDATION_OWNER_OR_SELF, + VALIDATION_OWNER, SIG_VALIDATION } diff --git a/src/plugins/owner/SingleOwnerPlugin.sol b/src/plugins/owner/SingleOwnerPlugin.sol index 1018d908..dbdd41b2 100644 --- a/src/plugins/owner/SingleOwnerPlugin.sol +++ b/src/plugins/owner/SingleOwnerPlugin.sol @@ -84,7 +84,7 @@ contract SingleOwnerPlugin is ISingleOwnerPlugin, BasePlugin { view override { - if (functionId == uint8(FunctionId.VALIDATION_OWNER_OR_SELF)) { + if (functionId == uint8(FunctionId.VALIDATION_OWNER)) { // Validate that the sender is the owner of the account or self. if (sender != _owners[msg.sender] && sender != msg.sender) { revert NotAuthorized(); @@ -101,7 +101,7 @@ contract SingleOwnerPlugin is ISingleOwnerPlugin, BasePlugin { override returns (uint256) { - if (functionId == uint8(FunctionId.VALIDATION_OWNER_OR_SELF)) { + if (functionId == uint8(FunctionId.VALIDATION_OWNER)) { // Validate the user op signature against the owner. (address signer,,) = (userOpHash.toEthSignedMessageHash()).tryRecover(userOp.signature); if (signer == address(0) || signer != _owners[msg.sender]) { @@ -158,7 +158,7 @@ contract SingleOwnerPlugin is ISingleOwnerPlugin, BasePlugin { ManifestFunction memory ownerValidationFunction = ManifestFunction({ functionType: ManifestAssociatedFunctionType.SELF, - functionId: uint8(FunctionId.VALIDATION_OWNER_OR_SELF), + functionId: uint8(FunctionId.VALIDATION_OWNER), dependencyIndex: 0 // Unused. }); manifest.validationFunctions = new ManifestAssociatedFunction[](5); diff --git a/test/account/AccountLoupe.t.sol b/test/account/AccountLoupe.t.sol index 846037df..43c7187f 100644 --- a/test/account/AccountLoupe.t.sol +++ b/test/account/AccountLoupe.t.sol @@ -29,7 +29,7 @@ contract AccountLoupeTest is AccountTestBase { account1.installPlugin(address(comprehensivePlugin), manifestHash, "", new FunctionReference[](0)); ownerValidation = FunctionReferenceLib.pack( - address(singleOwnerPlugin), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER_OR_SELF) + address(singleOwnerPlugin), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER) ); } diff --git a/test/account/AccountReturnData.t.sol b/test/account/AccountReturnData.t.sol index 35def6b1..085fa4bf 100644 --- a/test/account/AccountReturnData.t.sol +++ b/test/account/AccountReturnData.t.sol @@ -59,7 +59,7 @@ contract AccountReturnDataTest is AccountTestBase { account1.execute, (address(regularResultContract), 0, abi.encodeCall(RegularResultContract.foo, ())) ), - abi.encodePacked(singleOwnerPlugin, ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER_OR_SELF) + abi.encodePacked(singleOwnerPlugin, ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER) ); bytes32 result = abi.decode(abi.decode(returnData, (bytes)), (bytes32)); @@ -83,7 +83,7 @@ contract AccountReturnDataTest is AccountTestBase { bytes memory retData = account1.executeWithAuthorization( abi.encodeCall(account1.executeBatch, (calls)), - abi.encodePacked(singleOwnerPlugin, ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER_OR_SELF) + abi.encodePacked(singleOwnerPlugin, ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER) ); bytes[] memory returnDatas = abi.decode(retData, (bytes[])); diff --git a/test/account/MultiValidation.t.sol b/test/account/MultiValidation.t.sol index 4102932d..8d552d4e 100644 --- a/test/account/MultiValidation.t.sol +++ b/test/account/MultiValidation.t.sol @@ -41,11 +41,10 @@ contract MultiValidationTest is AccountTestBase { FunctionReference[] memory validations = new FunctionReference[](2); validations[0] = FunctionReferenceLib.pack( - address(singleOwnerPlugin), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER_OR_SELF) - ); - validations[1] = FunctionReferenceLib.pack( - address(validator2), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER_OR_SELF) + address(singleOwnerPlugin), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER) ); + validations[1] = + FunctionReferenceLib.pack(address(validator2), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER)); FunctionReference[] memory validations2 = account1.getValidationFunctions(IStandardExecutor.execute.selector); assertEq(validations2.length, 2); @@ -69,17 +68,13 @@ contract MultiValidationTest is AccountTestBase { ); account1.executeWithAuthorization( abi.encodeCall(IStandardExecutor.execute, (address(0), 0, "")), - abi.encodePacked( - address(validator2), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER_OR_SELF) - ) + abi.encodePacked(address(validator2), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER)) ); vm.prank(owner2); account1.executeWithAuthorization( abi.encodeCall(IStandardExecutor.execute, (address(0), 0, "")), - abi.encodePacked( - address(validator2), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER_OR_SELF) - ) + abi.encodePacked(address(validator2), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER)) ); } @@ -103,7 +98,8 @@ contract MultiValidationTest is AccountTestBase { // Generate signature bytes32 userOpHash = entryPoint.getUserOpHash(userOp); (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner2Key, userOpHash.toEthSignedMessageHash()); - userOp.signature = abi.encodePacked(address(validator2), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER_OR_SELF), r, s, v); + userOp.signature = + abi.encodePacked(address(validator2), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER), r, s, v); PackedUserOperation[] memory userOps = new PackedUserOperation[](1); userOps[0] = userOp; @@ -114,10 +110,11 @@ contract MultiValidationTest is AccountTestBase { userOp.nonce = 1; (v, r, s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash()); - userOp.signature = abi.encodePacked(address(validator2), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER_OR_SELF), r, s, v); + userOp.signature = + abi.encodePacked(address(validator2), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER), r, s, v); userOps[0] = userOp; vm.expectRevert(abi.encodeWithSelector(IEntryPoint.FailedOp.selector, 0, "AA24 signature error")); entryPoint.handleOps(userOps, beneficiary); } -} \ No newline at end of file +} diff --git a/test/account/UpgradeableModularAccount.t.sol b/test/account/UpgradeableModularAccount.t.sol index 01337691..1cbcb78d 100644 --- a/test/account/UpgradeableModularAccount.t.sol +++ b/test/account/UpgradeableModularAccount.t.sol @@ -63,7 +63,7 @@ contract UpgradeableModularAccountTest is AccountTestBase { counter.increment(); // amoritze away gas cost of zero->nonzero transition ownerValidation = FunctionReferenceLib.pack( - address(singleOwnerPlugin), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER_OR_SELF) + address(singleOwnerPlugin), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER) ); } diff --git a/test/plugin/SingleOwnerPlugin.t.sol b/test/plugin/SingleOwnerPlugin.t.sol index a6a1900d..41997591 100644 --- a/test/plugin/SingleOwnerPlugin.t.sol +++ b/test/plugin/SingleOwnerPlugin.t.sol @@ -114,11 +114,11 @@ contract SingleOwnerPluginTest is OptimizedTest { assertEq(address(0), plugin.owner()); plugin.transferOwnership(owner1); assertEq(owner1, plugin.owner()); - plugin.validateRuntime(uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER_OR_SELF), owner1, 0, "", ""); + plugin.validateRuntime(uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER), owner1, 0, "", ""); vm.startPrank(b); vm.expectRevert(ISingleOwnerPlugin.NotAuthorized.selector); - plugin.validateRuntime(uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER_OR_SELF), owner1, 0, "", ""); + plugin.validateRuntime(uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER), owner1, 0, "", ""); } function testFuzz_validateUserOpSig(string memory salt, PackedUserOperation memory userOp) public { @@ -133,9 +133,8 @@ contract SingleOwnerPluginTest is OptimizedTest { userOp.signature = abi.encodePacked(r, s, v); // sig check should fail - uint256 success = plugin.validateUserOp( - uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER_OR_SELF), userOp, userOpHash - ); + uint256 success = + plugin.validateUserOp(uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER), userOp, userOpHash); assertEq(success, 1); // transfer ownership to signer @@ -143,9 +142,7 @@ contract SingleOwnerPluginTest is OptimizedTest { assertEq(signer, plugin.owner()); // sig check should pass - success = plugin.validateUserOp( - uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER_OR_SELF), userOp, userOpHash - ); + success = plugin.validateUserOp(uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER), userOp, userOpHash); assertEq(success, 0); } diff --git a/test/utils/AccountTestBase.sol b/test/utils/AccountTestBase.sol index de9d56c2..6eca6626 100644 --- a/test/utils/AccountTestBase.sol +++ b/test/utils/AccountTestBase.sol @@ -47,7 +47,7 @@ abstract contract AccountTestBase is OptimizedTest { abi.encodeCall(SingleOwnerPlugin.transferOwnership, (address(this))) ) ), - abi.encodePacked(address(singleOwnerPlugin), ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER_OR_SELF) + abi.encodePacked(address(singleOwnerPlugin), ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER) ); } From 6d48a68021bff6f3a50a508873595af9fb1939a6 Mon Sep 17 00:00:00 2001 From: adam Date: Thu, 30 May 2024 15:53:40 -0400 Subject: [PATCH 5/6] rename loupe function --- src/account/AccountLoupe.sol | 14 ++++++++++---- src/interfaces/IAccountLoupe.sol | 2 +- test/account/AccountLoupe.t.sol | 4 ++-- test/account/MultiValidation.t.sol | 3 +-- 4 files changed, 14 insertions(+), 9 deletions(-) diff --git a/src/account/AccountLoupe.sol b/src/account/AccountLoupe.sol index 44870e90..0593298a 100644 --- a/src/account/AccountLoupe.sol +++ b/src/account/AccountLoupe.sol @@ -20,7 +20,7 @@ abstract contract AccountLoupe is IAccountLoupe { using EnumerableSet for EnumerableSet.AddressSet; /// @inheritdoc IAccountLoupe - function getExecutionFunctionHandler(bytes4 selector) external view returns (address plugin) { + function getExecutionFunctionHandler(bytes4 selector) external view override returns (address plugin) { AccountStorage storage _storage = getAccountStorage(); if ( @@ -36,12 +36,17 @@ abstract contract AccountLoupe is IAccountLoupe { } /// @inheritdoc IAccountLoupe - function getValidationFunctions(bytes4 selector) external view returns (FunctionReference[] memory) { + function getValidations(bytes4 selector) external view override returns (FunctionReference[] memory) { return toFunctionReferenceArray(getAccountStorage().selectorData[selector].validations); } /// @inheritdoc IAccountLoupe - function getExecutionHooks(bytes4 selector) external view returns (ExecutionHook[] memory execHooks) { + function getExecutionHooks(bytes4 selector) + external + view + override + returns (ExecutionHook[] memory execHooks) + { SelectorData storage selectorData = getAccountStorage().selectorData[selector]; uint256 executionHooksLength = selectorData.executionHooks.length(); @@ -58,6 +63,7 @@ abstract contract AccountLoupe is IAccountLoupe { function getPreValidationHooks(bytes4 selector) external view + override returns (FunctionReference[] memory preValidationHooks) { preValidationHooks = @@ -65,7 +71,7 @@ abstract contract AccountLoupe is IAccountLoupe { } /// @inheritdoc IAccountLoupe - function getInstalledPlugins() external view returns (address[] memory pluginAddresses) { + function getInstalledPlugins() external view override returns (address[] memory pluginAddresses) { pluginAddresses = getAccountStorage().plugins.values(); } } diff --git a/src/interfaces/IAccountLoupe.sol b/src/interfaces/IAccountLoupe.sol index 91a648f1..b474149c 100644 --- a/src/interfaces/IAccountLoupe.sol +++ b/src/interfaces/IAccountLoupe.sol @@ -21,7 +21,7 @@ interface IAccountLoupe { /// @notice Get the validation functions for a selector. /// @param selector The selector to get the validation functions for. /// @return The validation functions for this selector. - function getValidationFunctions(bytes4 selector) external view returns (FunctionReference[] memory); + function getValidations(bytes4 selector) external view returns (FunctionReference[] memory); /// @notice Get the pre and post execution hooks for a selector. /// @param selector The selector to get the hooks for. diff --git a/test/account/AccountLoupe.t.sol b/test/account/AccountLoupe.t.sol index 43c7187f..a6ed44eb 100644 --- a/test/account/AccountLoupe.t.sol +++ b/test/account/AccountLoupe.t.sol @@ -77,7 +77,7 @@ contract AccountLoupeTest is AccountTestBase { } function test_pluginLoupe_getValidationFunctions() public { - FunctionReference[] memory validations = account1.getValidationFunctions(comprehensivePlugin.foo.selector); + FunctionReference[] memory validations = account1.getValidations(comprehensivePlugin.foo.selector); assertEq(validations.length, 1); assertEq( @@ -89,7 +89,7 @@ contract AccountLoupeTest is AccountTestBase { ) ); - validations = account1.getValidationFunctions(account1.execute.selector); + validations = account1.getValidations(account1.execute.selector); assertEq(validations.length, 1); assertEq(FunctionReference.unwrap(validations[0]), FunctionReference.unwrap(ownerValidation)); diff --git a/test/account/MultiValidation.t.sol b/test/account/MultiValidation.t.sol index 8d552d4e..9ca70857 100644 --- a/test/account/MultiValidation.t.sol +++ b/test/account/MultiValidation.t.sol @@ -45,8 +45,7 @@ contract MultiValidationTest is AccountTestBase { ); validations[1] = FunctionReferenceLib.pack(address(validator2), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER)); - FunctionReference[] memory validations2 = - account1.getValidationFunctions(IStandardExecutor.execute.selector); + FunctionReference[] memory validations2 = account1.getValidations(IStandardExecutor.execute.selector); assertEq(validations2.length, 2); assertEq(FunctionReference.unwrap(validations2[0]), FunctionReference.unwrap(validations[0])); assertEq(FunctionReference.unwrap(validations2[1]), FunctionReference.unwrap(validations[1])); From def2267177625552f3f09f01761714e650c0edde Mon Sep 17 00:00:00 2001 From: adam Date: Mon, 10 Jun 2024 11:01:42 -0400 Subject: [PATCH 6/6] review fixes --- src/account/AccountLoupe.sol | 12 ++---------- src/account/PluginManagerInternals.sol | 5 +++-- src/account/UpgradeableModularAccount.sol | 8 +++++--- 3 files changed, 10 insertions(+), 15 deletions(-) diff --git a/src/account/AccountLoupe.sol b/src/account/AccountLoupe.sol index 0593298a..9e9053fe 100644 --- a/src/account/AccountLoupe.sol +++ b/src/account/AccountLoupe.sol @@ -7,13 +7,7 @@ import {EnumerableSet} from "@openzeppelin/contracts/utils/structs/EnumerableSet import {IAccountLoupe, ExecutionHook} from "../interfaces/IAccountLoupe.sol"; import {FunctionReference, IPluginManager} from "../interfaces/IPluginManager.sol"; import {IStandardExecutor} from "../interfaces/IStandardExecutor.sol"; -import { - AccountStorage, - getAccountStorage, - SelectorData, - toFunctionReferenceArray, - toExecutionHook -} from "./AccountStorage.sol"; +import {getAccountStorage, SelectorData, toFunctionReferenceArray, toExecutionHook} from "./AccountStorage.sol"; abstract contract AccountLoupe is IAccountLoupe { using EnumerableSet for EnumerableSet.Bytes32Set; @@ -21,8 +15,6 @@ abstract contract AccountLoupe is IAccountLoupe { /// @inheritdoc IAccountLoupe function getExecutionFunctionHandler(bytes4 selector) external view override returns (address plugin) { - AccountStorage storage _storage = getAccountStorage(); - if ( selector == IStandardExecutor.execute.selector || selector == IStandardExecutor.executeBatch.selector || selector == UUPSUpgradeable.upgradeToAndCall.selector @@ -32,7 +24,7 @@ abstract contract AccountLoupe is IAccountLoupe { return address(this); } - return _storage.selectorData[selector].plugin; + return getAccountStorage().selectorData[selector].plugin; } /// @inheritdoc IAccountLoupe diff --git a/src/account/PluginManagerInternals.sol b/src/account/PluginManagerInternals.sol index 7f8337e4..cfdbef16 100644 --- a/src/account/PluginManagerInternals.sol +++ b/src/account/PluginManagerInternals.sol @@ -86,8 +86,9 @@ abstract contract PluginManagerInternals is IPluginManager { { SelectorData storage _selectorData = getAccountStorage().selectorData[selector]; - // Fail on duplicate definitions - otherwise dependencies could shadow non-depdency - // validation functions, leading to partial uninstalls. + // Fail on duplicate validation functions. Otherwise, dependency validation functions could shadow + // non-depdency validation functions. Then, if a either plugin is uninstall, it would cause a partial + // uninstall of the other. if (!_selectorData.validations.add(toSetValue(validationFunction))) { revert ValidationFunctionAlreadySet(selector, validationFunction); } diff --git a/src/account/UpgradeableModularAccount.sol b/src/account/UpgradeableModularAccount.sol index 84a73cd3..69a11cbd 100644 --- a/src/account/UpgradeableModularAccount.sol +++ b/src/account/UpgradeableModularAccount.sol @@ -268,9 +268,10 @@ contract UpgradeableModularAccount is payable returns (bytes memory) { - bytes4 execSelector = bytes4(data[0:4]); + bytes4 execSelector = bytes4(data[:4]); - FunctionReference runtimeValidationFunction = FunctionReference.wrap(bytes21(authorization[0:21])); + // Revert if the provided `authorization` less than 21 bytes long, rather than right-padding. + FunctionReference runtimeValidationFunction = FunctionReference.wrap(bytes21(authorization[:21])); AccountStorage storage _storage = getAccountStorage(); @@ -395,6 +396,7 @@ contract UpgradeableModularAccount is revert AlwaysDenyRule(); } + // Revert if the provided `authorization` less than 21 bytes long, rather than right-padding. FunctionReference userOpValidationFunction = FunctionReference.wrap(bytes21(userOp.signature[:21])); if (!getAccountStorage().selectorData[selector].validations.contains(toSetValue(userOpValidationFunction))) @@ -463,7 +465,7 @@ contract UpgradeableModularAccount is ) internal { // run all preRuntimeValidation hooks EnumerableSet.Bytes32Set storage preRuntimeValidationHooks = - getAccountStorage().selectorData[bytes4(callData[0:4])].preValidationHooks; + getAccountStorage().selectorData[bytes4(callData[:4])].preValidationHooks; uint256 preRuntimeValidationHooksLength = preRuntimeValidationHooks.length(); for (uint256 i = 0; i < preRuntimeValidationHooksLength; ++i) {