diff --git a/src/account/AccountLoupe.sol b/src/account/AccountLoupe.sol index 007036da..d84b6d20 100644 --- a/src/account/AccountLoupe.sol +++ b/src/account/AccountLoupe.sol @@ -2,7 +2,6 @@ pragma solidity ^0.8.25; import {UUPSUpgradeable} from "@openzeppelin/contracts/proxy/utils/UUPSUpgradeable.sol"; -import {EnumerableMap} from "@openzeppelin/contracts/utils/structs/EnumerableMap.sol"; import {EnumerableSet} from "@openzeppelin/contracts/utils/structs/EnumerableSet.sol"; import {IAccountLoupe} from "../interfaces/IAccountLoupe.sol"; @@ -11,7 +10,7 @@ import {IStandardExecutor} from "../interfaces/IStandardExecutor.sol"; import {AccountStorage, getAccountStorage, SelectorData, toFunctionReferenceArray} from "./AccountStorage.sol"; abstract contract AccountLoupe is IAccountLoupe { - using EnumerableMap for EnumerableMap.Bytes32ToUintMap; + using EnumerableSet for EnumerableSet.Bytes32Set; using EnumerableSet for EnumerableSet.AddressSet; /// @inheritdoc IAccountLoupe @@ -41,56 +40,21 @@ abstract contract AccountLoupe is IAccountLoupe { SelectorData storage selectorData = getAccountStorage().selectorData[selector]; uint256 preExecHooksLength = selectorData.preHooks.length(); uint256 postOnlyExecHooksLength = selectorData.postOnlyHooks.length(); - uint256 maxExecHooksLength = postOnlyExecHooksLength; - // There can only be as many associated post hooks to run as there are pre hooks. - for (uint256 i = 0; i < preExecHooksLength; ++i) { - (, uint256 count) = selectorData.preHooks.at(i); - unchecked { - maxExecHooksLength += (count + 1); - } - } - - // Overallocate on length - not all of this may get filled up. We set the correct length later. - execHooks = new ExecutionHooks[](maxExecHooksLength); - uint256 actualExecHooksLength; + execHooks = new ExecutionHooks[](preExecHooksLength + postOnlyExecHooksLength); for (uint256 i = 0; i < preExecHooksLength; ++i) { - (bytes32 key,) = selectorData.preHooks.at(i); + bytes32 key = selectorData.preHooks.at(i); FunctionReference preExecHook = FunctionReference.wrap(bytes21(key)); + FunctionReference associatedPostExecHook = selectorData.associatedPostHooks[preExecHook]; - uint256 associatedPostExecHooksLength = selectorData.associatedPostHooks[preExecHook].length(); - if (associatedPostExecHooksLength > 0) { - for (uint256 j = 0; j < associatedPostExecHooksLength; ++j) { - execHooks[actualExecHooksLength].preExecHook = preExecHook; - (key,) = selectorData.associatedPostHooks[preExecHook].at(j); - execHooks[actualExecHooksLength].postExecHook = FunctionReference.wrap(bytes21(key)); - - unchecked { - ++actualExecHooksLength; - } - } - } else { - execHooks[actualExecHooksLength].preExecHook = preExecHook; - - unchecked { - ++actualExecHooksLength; - } - } + execHooks[i].preExecHook = preExecHook; + execHooks[i].postExecHook = associatedPostExecHook; } for (uint256 i = 0; i < postOnlyExecHooksLength; ++i) { - (bytes32 key,) = selectorData.postOnlyHooks.at(i); - execHooks[actualExecHooksLength].postExecHook = FunctionReference.wrap(bytes21(key)); - - unchecked { - ++actualExecHooksLength; - } - } - - // Trim the exec hooks array to the actual length, since we may have overallocated. - assembly ("memory-safe") { - mstore(execHooks, actualExecHooksLength) + bytes32 key = selectorData.postOnlyHooks.at(i); + execHooks[preExecHooksLength + i].postExecHook = FunctionReference.wrap(bytes21(key)); } } diff --git a/src/account/AccountStorage.sol b/src/account/AccountStorage.sol index 35f4dc3d..728f7f5c 100644 --- a/src/account/AccountStorage.sol +++ b/src/account/AccountStorage.sol @@ -1,7 +1,6 @@ // SPDX-License-Identifier: GPL-3.0 pragma solidity ^0.8.25; -import {EnumerableMap} from "@openzeppelin/contracts/utils/structs/EnumerableMap.sol"; import {EnumerableSet} from "@openzeppelin/contracts/utils/structs/EnumerableSet.sol"; import {IPlugin} from "../interfaces/IPlugin.sol"; @@ -35,15 +34,20 @@ struct SelectorData { // The plugin that implements this execution function. // If this is a native function, the address must remain address(0). address plugin; + // How many times a `PRE_HOOK_ALWAYS_DENY` has been added for this function. + // Since that is the only type of hook that may overlap, we can use this to track the number of times it has + // been applied, and whether or not the deny should apply. The size `uint48` was chosen somewhat arbitrarily, + // but it packs alongside `plugin` while still leaving some other space in the slot for future packing. + uint48 denyExecutionCount; // User operation validation and runtime validation share a function reference. FunctionReference validation; // The pre validation hooks for this function selector. - EnumerableMap.Bytes32ToUintMap preValidationHooks; + EnumerableSet.Bytes32Set preValidationHooks; // The execution hooks for this function selector. - EnumerableMap.Bytes32ToUintMap preHooks; + EnumerableSet.Bytes32Set preHooks; // bytes21 key = pre hook function reference - mapping(FunctionReference => EnumerableMap.Bytes32ToUintMap) associatedPostHooks; - EnumerableMap.Bytes32ToUintMap postOnlyHooks; + mapping(FunctionReference => FunctionReference) associatedPostHooks; + EnumerableSet.Bytes32Set postOnlyHooks; } struct AccountStorage { @@ -73,17 +77,17 @@ function getPermittedCallKey(address addr, bytes4 selector) pure returns (bytes2 return bytes24(bytes20(addr)) | (bytes24(selector) >> 160); } -// Helper function to get all elements of a set into memory. -using EnumerableMap for EnumerableMap.Bytes32ToUintMap; +using EnumerableSet for EnumerableSet.Bytes32Set; -function toFunctionReferenceArray(EnumerableMap.Bytes32ToUintMap storage map) +/// @dev Helper function to get all elements of a set into memory. +function toFunctionReferenceArray(EnumerableSet.Bytes32Set storage set) view returns (FunctionReference[] memory) { - uint256 length = map.length(); + uint256 length = set.length(); FunctionReference[] memory result = new FunctionReference[](length); for (uint256 i = 0; i < length; ++i) { - (bytes32 key,) = map.at(i); + bytes32 key = set.at(i); result[i] = FunctionReference.wrap(bytes21(key)); } return result; diff --git a/src/account/PluginManagerInternals.sol b/src/account/PluginManagerInternals.sol index a3734c4c..1d2ee5b8 100644 --- a/src/account/PluginManagerInternals.sol +++ b/src/account/PluginManagerInternals.sol @@ -2,7 +2,6 @@ pragma solidity ^0.8.25; import {ERC165Checker} from "@openzeppelin/contracts/utils/introspection/ERC165Checker.sol"; -import {EnumerableMap} from "@openzeppelin/contracts/utils/structs/EnumerableMap.sol"; import {EnumerableSet} from "@openzeppelin/contracts/utils/structs/EnumerableSet.sol"; import {FunctionReferenceLib} from "../helpers/FunctionReferenceLib.sol"; @@ -25,7 +24,7 @@ import { } from "./AccountStorage.sol"; abstract contract PluginManagerInternals is IPluginManager { - using EnumerableMap for EnumerableMap.Bytes32ToUintMap; + using EnumerableSet for EnumerableSet.Bytes32Set; using EnumerableSet for EnumerableSet.AddressSet; using FunctionReferenceLib for FunctionReference; @@ -81,7 +80,7 @@ abstract contract PluginManagerInternals is IPluginManager { { SelectorData storage _selectorData = getAccountStorage().selectorData[selector]; - if (!_selectorData.validation.isEmpty()) { + if (_selectorData.validation.notEmpty()) { revert ValidationFunctionAlreadySet(selector, validationFunction); } @@ -102,20 +101,23 @@ abstract contract PluginManagerInternals is IPluginManager { { SelectorData storage _selectorData = getAccountStorage().selectorData[selector]; - if (!preExecHook.isEmpty()) { - _addOrIncrement(_selectorData.preHooks, _toSetValue(preExecHook)); + if (preExecHook.notEmpty()) { + // Don't need to check for duplicates, as the hook can be run at most once. + _selectorData.preHooks.add(_toSetValue(preExecHook)); - if (!postExecHook.isEmpty()) { - _addOrIncrement(_selectorData.associatedPostHooks[preExecHook], _toSetValue(postExecHook)); - } - } else { - if (postExecHook.isEmpty()) { - // both pre and post hooks cannot be null - revert NullFunctionReference(); + if (postExecHook.notEmpty()) { + _selectorData.associatedPostHooks[preExecHook] = postExecHook; } - _addOrIncrement(_selectorData.postOnlyHooks, _toSetValue(postExecHook)); + return; + } + + if (postExecHook.isEmpty()) { + // both pre and post hooks cannot be null + revert NullFunctionReference(); } + + _selectorData.postOnlyHooks.add(_toSetValue(postExecHook)); } function _removeExecHooks(bytes4 selector, FunctionReference preExecHook, FunctionReference postExecHook) @@ -123,37 +125,47 @@ abstract contract PluginManagerInternals is IPluginManager { { SelectorData storage _selectorData = getAccountStorage().selectorData[selector]; - if (!preExecHook.isEmpty()) { - _removeOrDecrement(_selectorData.preHooks, _toSetValue(preExecHook)); + if (preExecHook.notEmpty()) { + _selectorData.preHooks.remove(_toSetValue(preExecHook)); - if (!postExecHook.isEmpty()) { - _removeOrDecrement(_selectorData.associatedPostHooks[preExecHook], _toSetValue(postExecHook)); + if (postExecHook.notEmpty()) { + _selectorData.associatedPostHooks[preExecHook] = FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE; } - } else { - // The case where both pre and post hooks are null was checked during installation. - // May ignore return value, as the manifest hash is validated to ensure that the hook exists. - _removeOrDecrement(_selectorData.postOnlyHooks, _toSetValue(postExecHook)); + return; } + + // The case where both pre and post hooks are null was checked during installation. + + // May ignore return value, as the manifest hash is validated to ensure that the hook exists. + _selectorData.postOnlyHooks.remove(_toSetValue(postExecHook)); } function _addPreValidationHook(bytes4 selector, FunctionReference preValidationHook) internal notNullFunction(preValidationHook) { - _addOrIncrement( - getAccountStorage().selectorData[selector].preValidationHooks, _toSetValue(preValidationHook) - ); + SelectorData storage _selectorData = getAccountStorage().selectorData[selector]; + if (preValidationHook.eq(FunctionReferenceLib._PRE_HOOK_ALWAYS_DENY)) { + // Increment `denyExecutionCount`, because this pre validation hook may be applied multiple times. + _selectorData.denyExecutionCount += 1; + return; + } + _selectorData.preValidationHooks.add(_toSetValue(preValidationHook)); } function _removePreValidationHook(bytes4 selector, FunctionReference preValidationHook) internal notNullFunction(preValidationHook) { + SelectorData storage _selectorData = getAccountStorage().selectorData[selector]; + if (preValidationHook.eq(FunctionReferenceLib._PRE_HOOK_ALWAYS_DENY)) { + // Decrement `denyExecutionCount`, because this pre exec hook may be applied multiple times. + _selectorData.denyExecutionCount -= 1; + return; + } // May ignore return value, as the manifest hash is validated to ensure that the hook exists. - _removeOrDecrement( - getAccountStorage().selectorData[selector].preValidationHooks, _toSetValue(preValidationHook) - ); + _selectorData.preValidationHooks.remove(_toSetValue(preValidationHook)); } function _installPlugin( @@ -290,7 +302,7 @@ abstract contract PluginManagerInternals is IPluginManager { _addExecHooks( mh.executionSelector, _resolveManifestFunction( - mh.preExecHook, plugin, emptyDependencies, ManifestAssociatedFunctionType.PRE_HOOK_ALWAYS_DENY + mh.preExecHook, plugin, emptyDependencies, ManifestAssociatedFunctionType.NONE ), _resolveManifestFunction( mh.postExecHook, plugin, emptyDependencies, ManifestAssociatedFunctionType.NONE @@ -356,7 +368,7 @@ abstract contract PluginManagerInternals is IPluginManager { _removeExecHooks( mh.executionSelector, _resolveManifestFunction( - mh.preExecHook, plugin, emptyDependencies, ManifestAssociatedFunctionType.PRE_HOOK_ALWAYS_DENY + mh.preExecHook, plugin, emptyDependencies, ManifestAssociatedFunctionType.NONE ), _resolveManifestFunction( mh.postExecHook, plugin, emptyDependencies, ManifestAssociatedFunctionType.NONE @@ -449,25 +461,6 @@ abstract contract PluginManagerInternals is IPluginManager { emit PluginUninstalled(plugin, onUninstallSuccess); } - function _addOrIncrement(EnumerableMap.Bytes32ToUintMap storage map, bytes32 key) internal { - (bool success, uint256 value) = map.tryGet(key); - map.set(key, success ? value + 1 : 0); - } - - /// @return True if the key was removed or its value was decremented, false if the key was not found. - function _removeOrDecrement(EnumerableMap.Bytes32ToUintMap storage map, bytes32 key) internal returns (bool) { - (bool success, uint256 value) = map.tryGet(key); - if (!success) { - return false; - } - if (value == 0) { - map.remove(key); - } else { - map.set(key, value - 1); - } - return true; - } - function _toSetValue(FunctionReference functionReference) internal pure returns (bytes32) { return bytes32(FunctionReference.unwrap(functionReference)); } diff --git a/src/account/UpgradeableModularAccount.sol b/src/account/UpgradeableModularAccount.sol index cf9af40f..3161eb45 100644 --- a/src/account/UpgradeableModularAccount.sol +++ b/src/account/UpgradeableModularAccount.sol @@ -6,7 +6,6 @@ import {IEntryPoint} from "@eth-infinitism/account-abstraction/interfaces/IEntry import {PackedUserOperation} from "@eth-infinitism/account-abstraction/interfaces/PackedUserOperation.sol"; import {UUPSUpgradeable} from "@openzeppelin/contracts/proxy/utils/UUPSUpgradeable.sol"; import {IERC165} from "@openzeppelin/contracts/utils/introspection/IERC165.sol"; -import {EnumerableMap} from "@openzeppelin/contracts/utils/structs/EnumerableMap.sol"; import {EnumerableSet} from "@openzeppelin/contracts/utils/structs/EnumerableSet.sol"; import {FunctionReferenceLib} from "../helpers/FunctionReferenceLib.sol"; @@ -32,7 +31,6 @@ contract UpgradeableModularAccount is PluginManagerInternals, UUPSUpgradeable { - using EnumerableMap for EnumerableMap.Bytes32ToUintMap; using EnumerableSet for EnumerableSet.Bytes32Set; using FunctionReferenceLib for FunctionReference; @@ -323,6 +321,12 @@ contract UpgradeableModularAccount is } bytes4 selector = bytes4(userOp.callData); + AccountStorage storage _storage = getAccountStorage(); + + if (_storage.selectorData[selector].denyExecutionCount > 0) { + revert AlwaysDenyRule(); + } + FunctionReference userOpValidationFunction = getAccountStorage().selectorData[selector].validation; validationData = _doUserOpValidation(selector, userOpValidationFunction, userOp, userOpHash); @@ -335,83 +339,72 @@ contract UpgradeableModularAccount is PackedUserOperation calldata userOp, bytes32 userOpHash ) internal returns (uint256 validationData) { - if (userOpValidationFunction.isEmpty()) { + if (userOpValidationFunction.isEmptyOrMagicValue()) { + // If the validation function is empty, then the call cannot proceed. + // Alternatively, the validation function may be set to the RUNTIME_VALIDATION_ALWAYS_ALLOW magic + // value, in which case we also revert. revert UserOpValidationFunctionMissing(selector); } uint256 currentValidationData; // Do preUserOpValidation hooks - EnumerableMap.Bytes32ToUintMap storage preUserOpValidationHooks = + EnumerableSet.Bytes32Set storage preUserOpValidationHooks = getAccountStorage().selectorData[selector].preValidationHooks; uint256 preUserOpValidationHooksLength = preUserOpValidationHooks.length(); for (uint256 i = 0; i < preUserOpValidationHooksLength; ++i) { - (bytes32 key,) = preUserOpValidationHooks.at(i); + bytes32 key = preUserOpValidationHooks.at(i); FunctionReference preUserOpValidationHook = _toFunctionReference(key); - if (!preUserOpValidationHook.isEmptyOrMagicValue()) { - (address plugin, uint8 functionId) = preUserOpValidationHook.unpack(); - currentValidationData = IPlugin(plugin).preUserOpValidationHook(functionId, userOp, userOpHash); + (address plugin, uint8 functionId) = preUserOpValidationHook.unpack(); + currentValidationData = IPlugin(plugin).preUserOpValidationHook(functionId, userOp, userOpHash); - if (uint160(currentValidationData) > 1) { - // If the aggregator is not 0 or 1, it is an unexpected value - revert UnexpectedAggregator(plugin, functionId, address(uint160(currentValidationData))); - } - validationData = _coalescePreValidation(validationData, currentValidationData); - } else { - // Function reference cannot be 0 and _RUNTIME_VALIDATION_ALWAYS_ALLOW is not permitted here. - revert InvalidConfiguration(); + if (uint160(currentValidationData) > 1) { + // If the aggregator is not 0 or 1, it is an unexpected value + revert UnexpectedAggregator(plugin, functionId, address(uint160(currentValidationData))); } + validationData = _coalescePreValidation(validationData, currentValidationData); } // Run the user op validationFunction { - if (!userOpValidationFunction.isEmptyOrMagicValue()) { - (address plugin, uint8 functionId) = userOpValidationFunction.unpack(); - currentValidationData = IPlugin(plugin).userOpValidationFunction(functionId, userOp, userOpHash); - - if (preUserOpValidationHooksLength != 0) { - // If we have other validation data we need to coalesce with - validationData = _coalesceValidation(validationData, currentValidationData); - } else { - validationData = currentValidationData; - } + (address plugin, uint8 functionId) = userOpValidationFunction.unpack(); + currentValidationData = IPlugin(plugin).userOpValidationFunction(functionId, userOp, userOpHash); + + if (preUserOpValidationHooksLength != 0) { + // If we have other validation data we need to coalesce with + validationData = _coalesceValidation(validationData, currentValidationData); } else { - // _PRE_HOOK_ALWAYS_DENY is not permitted here. - // If this is _RUNTIME_VALIDATION_ALWAYS_ALLOW, the call should revert. - revert InvalidConfiguration(); + validationData = currentValidationData; } } } function _doRuntimeValidationIfNotFromEP() internal { + AccountStorage storage _storage = getAccountStorage(); + + if (_storage.selectorData[msg.sig].denyExecutionCount > 0) { + revert AlwaysDenyRule(); + } + if (msg.sender == address(_ENTRY_POINT)) return; - AccountStorage storage _storage = getAccountStorage(); FunctionReference runtimeValidationFunction = _storage.selectorData[msg.sig].validation; // run all preRuntimeValidation hooks - EnumerableMap.Bytes32ToUintMap storage preRuntimeValidationHooks = + EnumerableSet.Bytes32Set storage preRuntimeValidationHooks = getAccountStorage().selectorData[msg.sig].preValidationHooks; uint256 preRuntimeValidationHooksLength = preRuntimeValidationHooks.length(); for (uint256 i = 0; i < preRuntimeValidationHooksLength; ++i) { - (bytes32 key,) = preRuntimeValidationHooks.at(i); + bytes32 key = preRuntimeValidationHooks.at(i); FunctionReference preRuntimeValidationHook = _toFunctionReference(key); - if (!preRuntimeValidationHook.isEmptyOrMagicValue()) { - (address plugin, uint8 functionId) = preRuntimeValidationHook.unpack(); - // solhint-disable-next-line no-empty-blocks - try IPlugin(plugin).preRuntimeValidationHook(functionId, msg.sender, msg.value, msg.data) {} - catch (bytes memory revertReason) { - revert PreRuntimeValidationHookFailed(plugin, functionId, revertReason); - } - } else { - if (preRuntimeValidationHook.eq(FunctionReferenceLib._PRE_HOOK_ALWAYS_DENY)) { - revert AlwaysDenyRule(); - } - // Function reference cannot be 0 or _RUNTIME_VALIDATION_ALWAYS_ALLOW. - revert InvalidConfiguration(); + (address plugin, uint8 functionId) = preRuntimeValidationHook.unpack(); + // solhint-disable-next-line no-empty-blocks + try IPlugin(plugin).preRuntimeValidationHook(functionId, msg.sender, msg.value, msg.data) {} + catch (bytes memory revertReason) { + revert PreRuntimeValidationHookFailed(plugin, functionId, revertReason); } } @@ -427,8 +420,6 @@ contract UpgradeableModularAccount is } else { if (runtimeValidationFunction.isEmpty()) { revert RuntimeValidationFunctionMissing(msg.sig); - } else if (runtimeValidationFunction.eq(FunctionReferenceLib._PRE_HOOK_ALWAYS_DENY)) { - revert InvalidConfiguration(); } // If _RUNTIME_VALIDATION_ALWAYS_ALLOW, just let the function finish. } @@ -440,63 +431,47 @@ contract UpgradeableModularAccount is returns (PostExecToRun[] memory postHooksToRun) { SelectorData storage selectorData = getAccountStorage().selectorData[selector]; + uint256 preExecHooksLength = selectorData.preHooks.length(); uint256 postOnlyHooksLength = selectorData.postOnlyHooks.length(); - uint256 maxPostExecHooksLength = postOnlyHooksLength; - - // There can only be as many associated post hooks to run as there are pre hooks. - for (uint256 i = 0; i < preExecHooksLength; ++i) { - (, uint256 count) = selectorData.preHooks.at(i); - unchecked { - maxPostExecHooksLength += (count + 1); - } - } // Overallocate on length - not all of this may get filled up. We set the correct length later. - postHooksToRun = new PostExecToRun[](maxPostExecHooksLength); - uint256 actualPostHooksToRunLength; + postHooksToRun = new PostExecToRun[](preExecHooksLength + postOnlyHooksLength); - // Copy post-only hooks to the array. + // Copy all post hooks to the array. This happens before any pre hooks are run, so we can + // be sure that the set of hooks to run will not be affected by state changes mid-execution. + + // Copy post-only hooks. for (uint256 i = 0; i < postOnlyHooksLength; ++i) { - (bytes32 key,) = selectorData.postOnlyHooks.at(i); - postHooksToRun[actualPostHooksToRunLength].postExecHook = _toFunctionReference(key); - unchecked { - ++actualPostHooksToRunLength; - } + bytes32 key = selectorData.postOnlyHooks.at(i); + postHooksToRun[i].postExecHook = _toFunctionReference(key); } - // Then run the pre hooks and copy the associated post hooks (along with their pre hook's return data) to - // the array. + // Copy associated post hooks to the array. for (uint256 i = 0; i < preExecHooksLength; ++i) { - (bytes32 key,) = selectorData.preHooks.at(i); - FunctionReference preExecHook = _toFunctionReference(key); + FunctionReference preExecHook = _toFunctionReference(selectorData.preHooks.at(i)); - if (preExecHook.isEmptyOrMagicValue()) { - // The function reference must be PRE_HOOK_ALWAYS_DENY in this case, because zero and any other - // magic value is unassignable here. - revert AlwaysDenyRule(); + FunctionReference associatedPostExecHook = selectorData.associatedPostHooks[preExecHook]; + + if (associatedPostExecHook.notEmpty()) { + postHooksToRun[i + postOnlyHooksLength].postExecHook = associatedPostExecHook; } + } - bytes memory preExecHookReturnData = _runPreExecHook(preExecHook, data); + // Run the pre hooks and copy their return data to the post hooks array, if an associated post-exec hook + // exists. + for (uint256 i = 0; i < preExecHooksLength; ++i) { + bytes32 key = selectorData.preHooks.at(i); + FunctionReference preExecHook = _toFunctionReference(key); - uint256 associatedPostExecHooksLength = selectorData.associatedPostHooks[preExecHook].length(); - if (associatedPostExecHooksLength > 0) { - for (uint256 j = 0; j < associatedPostExecHooksLength; ++j) { - (key,) = selectorData.associatedPostHooks[preExecHook].at(j); - postHooksToRun[actualPostHooksToRunLength].postExecHook = _toFunctionReference(key); - postHooksToRun[actualPostHooksToRunLength].preExecHookReturnData = preExecHookReturnData; + bytes memory preExecHookReturnData = _runPreExecHook(preExecHook, data); - unchecked { - ++actualPostHooksToRunLength; - } - } + // If there is an associated post-exec hook, save the return data. + PostExecToRun memory postExecToRun = postHooksToRun[i + postOnlyHooksLength]; + if (postExecToRun.postExecHook.notEmpty()) { + postExecToRun.preExecHookReturnData = preExecHookReturnData; } } - - // Trim the post hook array to the actual length, since we may have overallocated. - assembly ("memory-safe") { - mstore(postHooksToRun, actualPostHooksToRunLength) - } } function _runPreExecHook(FunctionReference preExecHook, bytes calldata data) @@ -521,6 +496,12 @@ contract UpgradeableModularAccount is --i; PostExecToRun memory postHookToRun = postHooksToRun[i]; + + if (postHookToRun.postExecHook.isEmpty()) { + // This is an empty post hook, from a pre-only hook, so we skip it. + continue; + } + (address plugin, uint8 functionId) = postHookToRun.postExecHook.unpack(); // solhint-disable-next-line no-empty-blocks try IPlugin(plugin).postExecutionHook(functionId, postHookToRun.preExecHookReturnData) {} diff --git a/src/helpers/FunctionReferenceLib.sol b/src/helpers/FunctionReferenceLib.sol index a938eef5..e80992c9 100644 --- a/src/helpers/FunctionReferenceLib.sol +++ b/src/helpers/FunctionReferenceLib.sol @@ -26,6 +26,10 @@ library FunctionReferenceLib { return FunctionReference.unwrap(fr) == bytes21(0); } + function notEmpty(FunctionReference fr) internal pure returns (bool) { + return FunctionReference.unwrap(fr) != bytes21(0); + } + function isEmptyOrMagicValue(FunctionReference fr) internal pure returns (bool) { return FunctionReference.unwrap(fr) <= bytes21(uint168(2)); } diff --git a/src/interfaces/IPlugin.sol b/src/interfaces/IPlugin.sol index 5049d5cd..ca5c7ff7 100644 --- a/src/interfaces/IPlugin.sol +++ b/src/interfaces/IPlugin.sol @@ -19,10 +19,9 @@ enum ManifestAssociatedFunctionType { // setting a hook and is therefore disallowed. RUNTIME_VALIDATION_ALWAYS_ALLOW, // Resolves to a magic value to always fail in a hook for a given function. - // This is only assignable to pre hooks (pre validation and pre execution). It should not be used on - // validation functions themselves, because this is equivalent to leaving the validation functions unset. - // It should not be used in post-exec hooks, because if it is known to always revert, that should happen - // as early as possible to save gas. + // This is only assignable to pre execution hooks. It should not be used on validation functions themselves, because + // this is equivalent to leaving the validation functions unset. It should not be used in post-exec hooks, because + // if it is known to always revert, that should happen as early as possible to save gas. PRE_HOOK_ALWAYS_DENY } // forgefmt: disable-end diff --git a/standard/ERCs/erc-6900.md b/standard/ERCs/erc-6900.md index a9c86dcd..f4be1e0d 100644 --- a/standard/ERCs/erc-6900.md +++ b/standard/ERCs/erc-6900.md @@ -348,10 +348,9 @@ enum ManifestAssociatedFunctionType { // setting a hook and is therefore disallowed. RUNTIME_VALIDATION_ALWAYS_ALLOW, // Resolves to a magic value to always fail in a hook for a given function. - // This is only assignable to pre hooks (pre validation and pre execution). It should not be used on - // validation functions themselves, because this is equivalent to leaving the validation functions unset. - // It should not be used in post-exec hooks, because if it is known to always revert, that should happen - // as early as possible to save gas. + // This is only assignable to pre execution hooks. It should not be used on validation functions themselves, because + // this is equivalent to leaving the validation functions unset. It should not be used in post-exec hooks, because + // if it is known to always revert, that should happen as early as possible to save gas. PRE_HOOK_ALWAYS_DENY } @@ -498,7 +497,7 @@ Finally, the function MUST emit the event `PluginUninstalled` with the plugin's When the function `validateUserOp` is called on modular account by the `EntryPoint`, it MUST find the user operation validation function associated to the function selector in the first four bytes of `userOp.callData`. If there is no function defined for the selector, or if `userOp.callData.length < 4`, then execution MUST revert. -If the function selector has associated pre user operation validation hooks, then those hooks MUST be run sequentially. If any revert, the outer call MUST revert. If any are set to `PRE_HOOK_ALWAYS_DENY`, the call MUST revert. If any return an `authorizer` value other than 0 or 1, execution MUST revert. If any return an `authorizer` value of 1, indicating an invalid signature, the returned validation data of the outer call MUST also be 1. If any return time-bounded validation by specifying either a `validUntil` or `validBefore` value, the resulting validation data MUST be the intersection of all time bounds provided. +If the function selector has associated pre user operation validation hooks, then those hooks MUST be run sequentially. If any revert, the outer call MUST revert. If the selector has any pre execution hooks set to `PRE_HOOK_ALWAYS_DENY`, the call MUST revert. If any return an `authorizer` value other than 0 or 1, execution MUST revert. If any return an `authorizer` value of 1, indicating an invalid signature, the returned validation data of the outer call MUST also be 1. If any return time-bounded validation by specifying either a `validUntil` or `validBefore` value, the resulting validation data MUST be the intersection of all time bounds provided. Then, the modular account MUST execute the validation function with the user operation and its hash as parameters using the `call` opcode. The returned validation data from the user operation validation function MUST be updated, if necessary, by the return values of any pre user operation validation hooks, then returned by `validateUserOp`. @@ -510,7 +509,7 @@ Additionally, when the modular account natively implements functions in `IPlugin The steps to perform are: -- If the call is not from the `EntryPoint`, then find an associated runtime validation function. If one does not exist, execution MUST revert. The modular account MUST execute all pre runtime validation hooks, then the runtime validation function, with the `call` opcode. All of these functions MUST receive the caller, value, and execution function's calldata as parameters. If any of these functions revert, execution MUST revert. If any pre runtime validation hooks are set to `PRE_HOOK_ALWAYS_DENY`, execution MUST revert. If the runtime validation function is set to `RUNTIME_VALIDATION_ALWAYS_ALLOW`, the validation function MUST be bypassed. +- If the call is not from the `EntryPoint`, then find an associated runtime validation function. If one does not exist, execution MUST revert. The modular account MUST execute all pre runtime validation hooks, then the runtime validation function, with the `call` opcode. All of these functions MUST receive the caller, value, and execution function's calldata as parameters. If any of these functions revert, execution MUST revert. If any pre execution hooks are set to `PRE_HOOK_ALWAYS_DENY`, execution MUST revert. If the validation function is set to `RUNTIME_VALIDATION_ALWAYS_ALLOW`, the runtime validation function MUST be bypassed. - If there are pre execution hooks defined for the execution function, execute those hooks with the caller, value, and execution function's calldata as parameters. If any of these hooks returns data, it MUST be preserved until the call to the post execution hook. The operation MUST be done with the `call` opcode. If there are duplicate pre execution hooks (i.e., hooks with identical `FunctionReference`s), run the hook only once. If any of these functions revert, execution MUST revert. - Run the execution function. - If any post execution hooks are defined, run the functions. If a pre execution hook returned data to the account, that data MUST be passed as a parameter to the associated post execution hook. The operation MUST be done with the `call` opcode. If there are duplicate post execution hooks, run them once for each unique associated pre execution hook. For post execution hooks without an associated pre execution hook, run the hook only once. If any of these functions revert, execution MUST revert. diff --git a/test/account/AccountExecHooks.t.sol b/test/account/AccountExecHooks.t.sol index db5df7eb..8e1fef94 100644 --- a/test/account/AccountExecHooks.t.sol +++ b/test/account/AccountExecHooks.t.sol @@ -11,6 +11,7 @@ import { } from "../../src/interfaces/IPlugin.sol"; import {PluginManagerInternals} from "../../src/account/PluginManagerInternals.sol"; import {FunctionReference, FunctionReferenceLib} from "../../src/helpers/FunctionReferenceLib.sol"; +import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAccount.sol"; import {MockPlugin} from "../mocks/MockPlugin.sol"; import {AccountTestBase} from "../utils/AccountTestBase.sol"; @@ -177,54 +178,62 @@ contract AccountExecHooksTest is AccountTestBase { _uninstallPlugin(mockPlugin1); } - function test_overlappingPreExecHooks_install() public { + function test_overlappingPreValidationHooks_install() public { // Install the first plugin. - _installPlugin1WithHooks( + _installPlugin1WithPreValidationHook( _EXEC_SELECTOR, ManifestFunction({ functionType: ManifestAssociatedFunctionType.PRE_HOOK_ALWAYS_DENY, functionId: 0, dependencyIndex: 0 - }), - ManifestFunction(ManifestAssociatedFunctionType.NONE, 0, 0) + }) ); + // Expect the call to fail due to the "always deny" pre hook. + vm.breakpoint("a"); + (bool success, bytes memory retData) = address(account1).call(abi.encodeWithSelector(_EXEC_SELECTOR)); + assertFalse(success); + assertEq(retData, abi.encodeWithSelector(UpgradeableModularAccount.AlwaysDenyRule.selector)); + // Install a second plugin that applies the same pre hook on the same selector. - _installPlugin2WithHooksExpectSuccess( + _installPlugin2WithPreValidationHook( _EXEC_SELECTOR, ManifestFunction({ functionType: ManifestAssociatedFunctionType.PRE_HOOK_ALWAYS_DENY, functionId: 0, dependencyIndex: 0 - }), - ManifestFunction(ManifestAssociatedFunctionType.NONE, 0, 0), - new FunctionReference[](0) + }) ); - vm.stopPrank(); - } - - function test_overlappingPreExecHooks_run() public { - (bool success,) = address(account1).call(abi.encodeWithSelector(_EXEC_SELECTOR)); + // Still expect the call to fail. + (success, retData) = address(account1).call(abi.encodeWithSelector(_EXEC_SELECTOR)); assertFalse(success); + assertEq(retData, abi.encodeWithSelector(UpgradeableModularAccount.AlwaysDenyRule.selector)); + + vm.stopPrank(); } - function test_overlappingPreExecHooks_uninstall() public { - test_overlappingPreExecHooks_install(); + function test_overlappingPreValidationHooks_uninstall() public { + test_overlappingPreValidationHooks_install(); // Uninstall the second plugin. _uninstallPlugin(mockPlugin2); - // Expect the pre hook to still exist. - (bool success,) = address(account1).call(abi.encodeWithSelector(_EXEC_SELECTOR)); + // Expect the pre validation hook of "always deny" to still exist. + (bool success, bytes memory retData) = address(account1).call(abi.encodeWithSelector(_EXEC_SELECTOR)); assertFalse(success); + assertEq(retData, abi.encodeWithSelector(UpgradeableModularAccount.AlwaysDenyRule.selector)); // Uninstall the first plugin. _uninstallPlugin(mockPlugin1); - // Execution selector should no longer exist. - (success,) = address(account1).call(abi.encodeWithSelector(_EXEC_SELECTOR)); + // // Execution selector should no longer exist. + (success, retData) = address(account1).call(abi.encodeWithSelector(_EXEC_SELECTOR)); assertFalse(success); + assertEq( + retData, + abi.encodeWithSelector(UpgradeableModularAccount.UnrecognizedFunction.selector, _EXEC_SELECTOR) + ); } function test_execHookDependencies_failInstall() public { @@ -290,6 +299,29 @@ contract AccountExecHooksTest is AccountTestBase { }); } + function _installPlugin1WithPreValidationHook(bytes4 selector, ManifestFunction memory preValidationHook) + internal + { + m1.preValidationHooks.push( + ManifestAssociatedFunction({executionSelector: selector, associatedFunction: preValidationHook}) + ); + + mockPlugin1 = new MockPlugin(m1); + manifestHash1 = keccak256(abi.encode(mockPlugin1.pluginManifest())); + + vm.expectEmit(true, true, true, true); + emit ReceivedCall(abi.encodeCall(IPlugin.onInstall, (bytes(""))), 0); + vm.expectEmit(true, true, true, true); + emit PluginInstalled(address(mockPlugin1), manifestHash1, new FunctionReference[](0)); + + account1.installPlugin({ + plugin: address(mockPlugin1), + manifestHash: manifestHash1, + pluginInstallData: bytes(""), + dependencies: new FunctionReference[](0) + }); + } + function _installPlugin2WithHooksExpectSuccess( bytes4 selector, ManifestFunction memory preHook, @@ -349,6 +381,29 @@ contract AccountExecHooksTest is AccountTestBase { }); } + function _installPlugin2WithPreValidationHook(bytes4 selector, ManifestFunction memory preValidationHook) + internal + { + m2.preValidationHooks.push( + ManifestAssociatedFunction({executionSelector: selector, associatedFunction: preValidationHook}) + ); + + mockPlugin2 = new MockPlugin(m2); + manifestHash2 = keccak256(abi.encode(mockPlugin2.pluginManifest())); + + vm.expectEmit(true, true, true, true); + emit ReceivedCall(abi.encodeCall(IPlugin.onInstall, (bytes(""))), 0); + vm.expectEmit(true, true, true, true); + emit PluginInstalled(address(mockPlugin2), manifestHash2, new FunctionReference[](0)); + + account1.installPlugin({ + plugin: address(mockPlugin2), + manifestHash: manifestHash2, + pluginInstallData: bytes(""), + dependencies: new FunctionReference[](0) + }); + } + function _uninstallPlugin(MockPlugin plugin) internal { vm.expectEmit(true, true, true, true); emit ReceivedCall(abi.encodeCall(IPlugin.onUninstall, (bytes(""))), 0);