From 1bfd893170da6f1bc22b143098bd4fcc1009b5cc Mon Sep 17 00:00:00 2001 From: adam Date: Mon, 24 Jun 2024 15:05:38 -0400 Subject: [PATCH] refactor validation mapping --- src/account/AccountLoupe.sol | 14 +++++++++++--- src/account/AccountStorage.sol | 12 ++++++++++-- src/account/PluginManager2.sol | 15 +++++---------- src/account/PluginManagerInternals.sol | 8 ++------ src/account/UpgradeableModularAccount.sol | 9 ++------- src/helpers/KnownSelectors.sol | 3 +-- src/interfaces/IAccountLoupe.sol | 8 ++++---- src/interfaces/IPluginManager.sol | 2 -- test/account/AccountLoupe.t.sol | 20 ++++++-------------- test/account/MultiValidation.t.sol | 11 +++++++---- 10 files changed, 48 insertions(+), 54 deletions(-) diff --git a/src/account/AccountLoupe.sol b/src/account/AccountLoupe.sol index 3e62fe9a..32721550 100644 --- a/src/account/AccountLoupe.sol +++ b/src/account/AccountLoupe.sol @@ -7,7 +7,7 @@ import {EnumerableSet} from "@openzeppelin/contracts/utils/structs/EnumerableSet import {IAccountLoupe, ExecutionHook} from "../interfaces/IAccountLoupe.sol"; import {FunctionReference, IPluginManager} from "../interfaces/IPluginManager.sol"; import {IStandardExecutor} from "../interfaces/IStandardExecutor.sol"; -import {getAccountStorage, toFunctionReferenceArray, toExecutionHook} from "./AccountStorage.sol"; +import {getAccountStorage, toExecutionHook, toSelector} from "./AccountStorage.sol"; abstract contract AccountLoupe is IAccountLoupe { using EnumerableSet for EnumerableSet.Bytes32Set; @@ -28,8 +28,16 @@ abstract contract AccountLoupe is IAccountLoupe { } /// @inheritdoc IAccountLoupe - function getValidations(bytes4 selector) external view override returns (FunctionReference[] memory) { - return toFunctionReferenceArray(getAccountStorage().selectorData[selector].validations); + function getSelectors(FunctionReference validationFunction) external view returns (bytes4[] memory) { + uint256 length = getAccountStorage().validationData[validationFunction].selectors.length(); + + bytes4[] memory selectors = new bytes4[](length); + + for (uint256 i = 0; i < length; ++i) { + selectors[i] = toSelector(getAccountStorage().validationData[validationFunction].selectors.at(i)); + } + + return selectors; } /// @inheritdoc IAccountLoupe diff --git a/src/account/AccountStorage.sol b/src/account/AccountStorage.sol index ddd8f900..4f90169f 100644 --- a/src/account/AccountStorage.sol +++ b/src/account/AccountStorage.sol @@ -29,8 +29,6 @@ struct SelectorData { bool allowDefaultValidation; // The execution hooks for this function selector. EnumerableSet.Bytes32Set executionHooks; - // Which validation functions are associated with this function selector. - EnumerableSet.Bytes32Set validations; } struct ValidationData { @@ -44,6 +42,8 @@ struct ValidationData { FunctionReference[] preValidationHooks; // Permission hooks for this validation function. EnumerableSet.Bytes32Set permissionHooks; + // The set of selectors that may be validated by this validation function. + EnumerableSet.Bytes32Set selectors; } struct AccountStorage { @@ -96,6 +96,14 @@ function toExecutionHook(bytes32 setValue) isPostHook = (uint256(setValue) >> 72) & 0xFF == 1; } +function toSetValue(bytes4 selector) pure returns (bytes32) { + return bytes32(selector); +} + +function toSelector(bytes32 setValue) pure returns (bytes4) { + return bytes4(setValue); +} + /// @dev Helper function to get all elements of a set into memory. function toFunctionReferenceArray(EnumerableSet.Bytes32Set storage set) view diff --git a/src/account/PluginManager2.sol b/src/account/PluginManager2.sol index effa1a15..dbac17c1 100644 --- a/src/account/PluginManager2.sol +++ b/src/account/PluginManager2.sol @@ -89,7 +89,7 @@ abstract contract PluginManager2 { for (uint256 i = 0; i < selectors.length; ++i) { bytes4 selector = selectors[i]; - if (!_storage.selectorData[selector].validations.add(toSetValue(validationFunction))) { + if (!_storage.validationData[validationFunction].selectors.add(toSetValue(selector))) { revert ValidationAlreadySet(selector, validationFunction); } } @@ -102,7 +102,6 @@ abstract contract PluginManager2 { function _uninstallValidation( FunctionReference validationFunction, - bytes4[] calldata selectors, bytes calldata uninstallData, bytes calldata preValidationHookUninstallData, bytes calldata permissionHookUninstallData @@ -144,14 +143,10 @@ abstract contract PluginManager2 { } delete _storage.validationData[validationFunction].preValidationHooks; - // Because this function also calls `onUninstall`, and removes the default flag from validation, we must - // assume these selectors passed in to be exhaustive. - // TODO: consider enforcing this from user-supplied install config. - for (uint256 i = 0; i < selectors.length; ++i) { - bytes4 selector = selectors[i]; - if (!_storage.selectorData[selector].validations.remove(toSetValue(validationFunction))) { - revert ValidationNotSet(selector, validationFunction); - } + // Clear selectors + while (_storage.validationData[validationFunction].selectors.length() > 0) { + bytes32 selector = _storage.validationData[validationFunction].selectors.at(0); + _storage.validationData[validationFunction].selectors.remove(selector); } if (uninstallData.length > 0) { diff --git a/src/account/PluginManagerInternals.sol b/src/account/PluginManagerInternals.sol index a9b79802..42868585 100644 --- a/src/account/PluginManagerInternals.sol +++ b/src/account/PluginManagerInternals.sol @@ -103,12 +103,10 @@ abstract contract PluginManagerInternals is IPluginManager { internal notNullFunction(validationFunction) { - SelectorData storage _selectorData = getAccountStorage().selectorData[selector]; - // Fail on duplicate validation functions. Otherwise, dependency validation functions could shadow // non-depdency validation functions. Then, if a either plugin is uninstalled, it would cause a partial // uninstall of the other. - if (!_selectorData.validations.add(toSetValue(validationFunction))) { + if (!getAccountStorage().validationData[validationFunction].selectors.add(toSetValue(selector))) { revert ValidationFunctionAlreadySet(selector, validationFunction); } } @@ -117,11 +115,9 @@ abstract contract PluginManagerInternals is IPluginManager { internal notNullFunction(validationFunction) { - SelectorData storage _selectorData = getAccountStorage().selectorData[selector]; - // May ignore return value, as the manifest hash is validated to ensure that the validation function // exists. - _selectorData.validations.remove(toSetValue(validationFunction)); + getAccountStorage().validationData[validationFunction].selectors.remove(toSetValue(selector)); } function _addExecHooks( diff --git a/src/account/UpgradeableModularAccount.sol b/src/account/UpgradeableModularAccount.sol index d57fcb0f..aa6dfe19 100644 --- a/src/account/UpgradeableModularAccount.sol +++ b/src/account/UpgradeableModularAccount.sol @@ -309,17 +309,12 @@ contract UpgradeableModularAccount is /// @notice May be validated by a default validation. function uninstallValidation( FunctionReference validationFunction, - bytes4[] calldata selectors, bytes calldata uninstallData, bytes calldata preValidationHookUninstallData, bytes calldata permissionHookUninstallData ) external wrapNativeFunction { _uninstallValidation( - validationFunction, - selectors, - uninstallData, - preValidationHookUninstallData, - permissionHookUninstallData + validationFunction, uninstallData, preValidationHookUninstallData, permissionHookUninstallData ); } @@ -685,7 +680,7 @@ contract UpgradeableModularAccount is } } else { // Not default validation, but per-selector - if (!getAccountStorage().selectorData[selector].validations.contains(toSetValue(validationFunction))) { + if (!getAccountStorage().validationData[validationFunction].selectors.contains(toSetValue(selector))) { revert UserOpValidationFunctionMissing(selector); } } diff --git a/src/helpers/KnownSelectors.sol b/src/helpers/KnownSelectors.sol index e5244d2c..1d02d2a3 100644 --- a/src/helpers/KnownSelectors.sol +++ b/src/helpers/KnownSelectors.sol @@ -34,8 +34,7 @@ library KnownSelectors { || selector == IStandardExecutor.executeWithAuthorization.selector // check against IAccountLoupe methods || selector == IAccountLoupe.getExecutionFunctionHandler.selector - || selector == IAccountLoupe.getValidations.selector - || selector == IAccountLoupe.getExecutionHooks.selector + || selector == IAccountLoupe.getSelectors.selector || selector == IAccountLoupe.getExecutionHooks.selector || selector == IAccountLoupe.getPreValidationHooks.selector || selector == IAccountLoupe.getInstalledPlugins.selector; } diff --git a/src/interfaces/IAccountLoupe.sol b/src/interfaces/IAccountLoupe.sol index b172464a..d74c5940 100644 --- a/src/interfaces/IAccountLoupe.sol +++ b/src/interfaces/IAccountLoupe.sol @@ -18,10 +18,10 @@ interface IAccountLoupe { /// @return plugin The plugin address for this selector. function getExecutionFunctionHandler(bytes4 selector) external view returns (address plugin); - /// @notice Get the validation functions for a selector. - /// @param selector The selector to get the validation functions for. - /// @return The validation functions for this selector. - function getValidations(bytes4 selector) external view returns (FunctionReference[] memory); + /// @notice Get the selectors for a validation function. + /// @param validationFunction The validation function to get the selectors for. + /// @return The allowed selectors for this validation function. + function getSelectors(FunctionReference validationFunction) external view returns (bytes4[] memory); /// @notice Get the pre and post execution hooks for a selector. /// @param selector The selector to get the hooks for. diff --git a/src/interfaces/IPluginManager.sol b/src/interfaces/IPluginManager.sol index 32634e34..d98badbf 100644 --- a/src/interfaces/IPluginManager.sol +++ b/src/interfaces/IPluginManager.sol @@ -46,7 +46,6 @@ interface IPluginManager { /// @notice Uninstall a validation function from a set of execution selectors. /// TODO: remove or update. /// @param validationFunction The validation function to uninstall. - /// @param selectors The selectors to uninstall the validation function for. /// @param uninstallData Optional data to be decoded and used by the plugin to clear plugin data for the /// account. /// @param preValidationHookUninstallData Optional data to be decoded and used by the plugin to clear account @@ -54,7 +53,6 @@ interface IPluginManager { /// @param permissionHookUninstallData Optional data to be decoded and used by the plugin to clear account data function uninstallValidation( FunctionReference validationFunction, - bytes4[] calldata selectors, bytes calldata uninstallData, bytes calldata preValidationHookUninstallData, bytes calldata permissionHookUninstallData diff --git a/test/account/AccountLoupe.t.sol b/test/account/AccountLoupe.t.sol index a89d04cb..4c0ddb88 100644 --- a/test/account/AccountLoupe.t.sol +++ b/test/account/AccountLoupe.t.sol @@ -88,23 +88,15 @@ contract AccountLoupeTest is AccountTestBase { } } - function test_pluginLoupe_getValidationFunctions() public { - FunctionReference[] memory validations = account1.getValidations(comprehensivePlugin.foo.selector); - - assertEq(validations.length, 1); - assertEq( - FunctionReference.unwrap(validations[0]), - FunctionReference.unwrap( - FunctionReferenceLib.pack( - address(comprehensivePlugin), uint8(ComprehensivePlugin.FunctionId.VALIDATION) - ) - ) + function test_pluginLoupe_getSelectors() public { + FunctionReference comprehensivePluginValidation = FunctionReferenceLib.pack( + address(comprehensivePlugin), uint8(ComprehensivePlugin.FunctionId.VALIDATION) ); - validations = account1.getValidations(account1.execute.selector); + bytes4[] memory selectors = account1.getSelectors(comprehensivePluginValidation); - assertEq(validations.length, 1); - assertEq(FunctionReference.unwrap(validations[0]), FunctionReference.unwrap(_ownerValidation)); + assertEq(selectors.length, 1); + assertEq(selectors[0], comprehensivePlugin.foo.selector); } function test_pluginLoupe_getExecutionHooks() public { diff --git a/test/account/MultiValidation.t.sol b/test/account/MultiValidation.t.sol index 78867f55..9c79be9d 100644 --- a/test/account/MultiValidation.t.sol +++ b/test/account/MultiValidation.t.sol @@ -42,10 +42,13 @@ contract MultiValidationTest is AccountTestBase { ); validations[1] = FunctionReferenceLib.pack(address(validator2), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER)); - FunctionReference[] memory validations2 = account1.getValidations(IStandardExecutor.execute.selector); - assertEq(validations2.length, 2); - assertEq(FunctionReference.unwrap(validations2[0]), FunctionReference.unwrap(validations[0])); - assertEq(FunctionReference.unwrap(validations2[1]), FunctionReference.unwrap(validations[1])); + + bytes4[] memory selectors0 = account1.getSelectors(validations[0]); + bytes4[] memory selectors1 = account1.getSelectors(validations[1]); + assertEq(selectors0.length, selectors1.length); + for (uint256 i = 0; i < selectors0.length; i++) { + assertEq(selectors0[i], selectors1[i]); + } } function test_runtimeValidation_specify() public {