Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 15 additions & 18 deletions src/account/AccountLoupe.sol
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,16 @@ pragma solidity ^0.8.25;
import {UUPSUpgradeable} from "@openzeppelin/contracts/proxy/utils/UUPSUpgradeable.sol";
import {EnumerableSet} from "@openzeppelin/contracts/utils/structs/EnumerableSet.sol";

import {IAccountLoupe} from "../interfaces/IAccountLoupe.sol";
import {IAccountLoupe, ExecutionHook} from "../interfaces/IAccountLoupe.sol";
import {FunctionReference, IPluginManager} from "../interfaces/IPluginManager.sol";
import {IStandardExecutor} from "../interfaces/IStandardExecutor.sol";
import {AccountStorage, getAccountStorage, SelectorData, toFunctionReferenceArray} from "./AccountStorage.sol";
import {
AccountStorage,
getAccountStorage,
SelectorData,
toFunctionReferenceArray,
toExecutionHook
} from "./AccountStorage.sol";

abstract contract AccountLoupe is IAccountLoupe {
using EnumerableSet for EnumerableSet.Bytes32Set;
Expand Down Expand Up @@ -36,25 +42,16 @@ abstract contract AccountLoupe is IAccountLoupe {
}

/// @inheritdoc IAccountLoupe
function getExecutionHooks(bytes4 selector) external view returns (ExecutionHooks[] memory execHooks) {
function getExecutionHooks(bytes4 selector) external view returns (ExecutionHook[] memory execHooks) {
SelectorData storage selectorData = getAccountStorage().selectorData[selector];
uint256 preExecHooksLength = selectorData.preHooks.length();
uint256 postOnlyExecHooksLength = selectorData.postOnlyHooks.length();
uint256 executionHooksLength = selectorData.executionHooks.length();

execHooks = new ExecutionHooks[](preExecHooksLength + postOnlyExecHooksLength);
execHooks = new ExecutionHook[](executionHooksLength);

for (uint256 i = 0; i < preExecHooksLength; ++i) {
bytes32 key = selectorData.preHooks.at(i);
FunctionReference preExecHook = FunctionReference.wrap(bytes21(key));
FunctionReference associatedPostExecHook = selectorData.associatedPostHooks[preExecHook];

execHooks[i].preExecHook = preExecHook;
execHooks[i].postExecHook = associatedPostExecHook;
}

for (uint256 i = 0; i < postOnlyExecHooksLength; ++i) {
bytes32 key = selectorData.postOnlyHooks.at(i);
execHooks[preExecHooksLength + i].postExecHook = FunctionReference.wrap(bytes21(key));
for (uint256 i = 0; i < executionHooksLength; ++i) {
bytes32 key = selectorData.executionHooks.at(i);
ExecutionHook memory execHook = execHooks[i];
(execHook.hookFunction, execHook.isPreHook, execHook.isPostHook) = toExecutionHook(key);
}
}

Expand Down
34 changes: 30 additions & 4 deletions src/account/AccountStorage.sol
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ pragma solidity ^0.8.25;
import {EnumerableSet} from "@openzeppelin/contracts/utils/structs/EnumerableSet.sol";

import {IPlugin} from "../interfaces/IPlugin.sol";
import {ExecutionHook} from "../interfaces/IAccountLoupe.sol";
import {FunctionReference} from "../interfaces/IPluginManager.sol";

// bytes = keccak256("ERC6900.UpgradeableModularAccount.Storage")
Expand Down Expand Up @@ -44,10 +45,7 @@ struct SelectorData {
// The pre validation hooks for this function selector.
EnumerableSet.Bytes32Set preValidationHooks;
// The execution hooks for this function selector.
EnumerableSet.Bytes32Set preHooks;
// bytes21 key = pre hook function reference
mapping(FunctionReference => FunctionReference) associatedPostHooks;
EnumerableSet.Bytes32Set postOnlyHooks;
EnumerableSet.Bytes32Set executionHooks;
}

struct AccountStorage {
Expand Down Expand Up @@ -79,6 +77,34 @@ function getPermittedCallKey(address addr, bytes4 selector) pure returns (bytes2

using EnumerableSet for EnumerableSet.Bytes32Set;

function toSetValue(FunctionReference functionReference) pure returns (bytes32) {
return bytes32(FunctionReference.unwrap(functionReference));
}

function toFunctionReference(bytes32 setValue) pure returns (FunctionReference) {
return FunctionReference.wrap(bytes21(setValue));
}

// ExecutionHook layout:
// 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF______________________ Hook Function Reference
// 0x__________________________________________AA____________________ is pre hook
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps we could precede the hook function with the type identifier or implement a sorting byte. This way, in storage, all prehooks could be organized before posthooks, allowing the loop to terminate early.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree this would be useful in practice. Currently though, the spec doesn't mandate any ordering for hooks, so can we address this problem in a different PR, for this issue? erc6900/resources#5

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍 thanks!

// 0x____________________________________________BB__________________ is post hook

function toSetValue(ExecutionHook memory executionHook) pure returns (bytes32) {
return bytes32(FunctionReference.unwrap(executionHook.hookFunction))
| bytes32(executionHook.isPreHook ? uint256(1) << 80 : 0)
| bytes32(executionHook.isPostHook ? uint256(1) << 72 : 0);
}

function toExecutionHook(bytes32 setValue)
pure
returns (FunctionReference hookFunction, bool isPreHook, bool isPostHook)
{
hookFunction = FunctionReference.wrap(bytes21(setValue));
isPreHook = (uint256(setValue) >> 80) & 0xFF == 1;
isPostHook = (uint256(setValue) >> 72) & 0xFF == 1;
Comment on lines +104 to +105
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curious why 0xFF over 0x01?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I used 0xFF because I wanted to just compare the least significant byte, and I could have also used 0x01 here. It would behave the same if you only use toHookData and toSetValue to convert between the types, it could differ if someone manually dirties the upper bits of either of the two bytes though.

}

/// @dev Helper function to get all elements of a set into memory.
function toFunctionReferenceArray(EnumerableSet.Bytes32Set storage set)
view
Expand Down
99 changes: 30 additions & 69 deletions src/account/PluginManagerInternals.sol
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@ import {
ManifestExternalCallPermission,
PluginManifest
} from "../interfaces/IPlugin.sol";
import {ExecutionHook} from "../interfaces/IAccountLoupe.sol";
import {FunctionReference, IPluginManager} from "../interfaces/IPluginManager.sol";
import {
AccountStorage,
getAccountStorage,
SelectorData,
toSetValue,
getPermittedCallKey,
PermittedExternalCallData
} from "./AccountStorage.sol";
Expand Down Expand Up @@ -96,49 +98,30 @@ abstract contract PluginManagerInternals is IPluginManager {
_selectorData.validation = FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE;
}

function _addExecHooks(bytes4 selector, FunctionReference preExecHook, FunctionReference postExecHook)
internal
{
SelectorData storage _selectorData = getAccountStorage().selectorData[selector];

if (preExecHook.notEmpty()) {
// Don't need to check for duplicates, as the hook can be run at most once.
_selectorData.preHooks.add(_toSetValue(preExecHook));

if (postExecHook.notEmpty()) {
_selectorData.associatedPostHooks[preExecHook] = postExecHook;
}

return;
}

if (postExecHook.isEmpty()) {
// both pre and post hooks cannot be null
revert NullFunctionReference();
}

_selectorData.postOnlyHooks.add(_toSetValue(postExecHook));
function _addExecHooks(
bytes4 selector,
FunctionReference hookFunction,
bool isPreExecHook,
bool isPostExecHook
) internal {
getAccountStorage().selectorData[selector].executionHooks.add(
toSetValue(
ExecutionHook({hookFunction: hookFunction, isPreHook: isPreExecHook, isPostHook: isPostExecHook})
)
);
}

function _removeExecHooks(bytes4 selector, FunctionReference preExecHook, FunctionReference postExecHook)
internal
{
SelectorData storage _selectorData = getAccountStorage().selectorData[selector];

if (preExecHook.notEmpty()) {
_selectorData.preHooks.remove(_toSetValue(preExecHook));

if (postExecHook.notEmpty()) {
_selectorData.associatedPostHooks[preExecHook] = FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE;
}

return;
}

// The case where both pre and post hooks are null was checked during installation.

// May ignore return value, as the manifest hash is validated to ensure that the hook exists.
_selectorData.postOnlyHooks.remove(_toSetValue(postExecHook));
function _removeExecHooks(
bytes4 selector,
FunctionReference hookFunction,
bool isPreExecHook,
bool isPostExecHook
) internal {
getAccountStorage().selectorData[selector].executionHooks.remove(
toSetValue(
ExecutionHook({hookFunction: hookFunction, isPreHook: isPreExecHook, isPostHook: isPostExecHook})
)
);
}

function _addPreValidationHook(bytes4 selector, FunctionReference preValidationHook)
Expand All @@ -151,7 +134,7 @@ abstract contract PluginManagerInternals is IPluginManager {
_selectorData.denyExecutionCount += 1;
return;
}
_selectorData.preValidationHooks.add(_toSetValue(preValidationHook));
_selectorData.preValidationHooks.add(toSetValue(preValidationHook));
}

function _removePreValidationHook(bytes4 selector, FunctionReference preValidationHook)
Expand All @@ -165,7 +148,7 @@ abstract contract PluginManagerInternals is IPluginManager {
return;
}
// May ignore return value, as the manifest hash is validated to ensure that the hook exists.
_selectorData.preValidationHooks.remove(_toSetValue(preValidationHook));
_selectorData.preValidationHooks.remove(toSetValue(preValidationHook));
}

function _installPlugin(
Expand Down Expand Up @@ -299,15 +282,8 @@ abstract contract PluginManagerInternals is IPluginManager {
length = manifest.executionHooks.length;
for (uint256 i = 0; i < length; ++i) {
ManifestExecutionHook memory mh = manifest.executionHooks[i];
_addExecHooks(
mh.executionSelector,
_resolveManifestFunction(
mh.preExecHook, plugin, emptyDependencies, ManifestAssociatedFunctionType.NONE
),
_resolveManifestFunction(
mh.postExecHook, plugin, emptyDependencies, ManifestAssociatedFunctionType.NONE
)
);
FunctionReference hookFunction = FunctionReferenceLib.pack(plugin, mh.functionId);
_addExecHooks(mh.executionSelector, hookFunction, mh.isPreHook, mh.isPostHook);
}

length = manifest.interfaceIds.length;
Expand Down Expand Up @@ -365,15 +341,8 @@ abstract contract PluginManagerInternals is IPluginManager {
length = manifest.executionHooks.length;
for (uint256 i = 0; i < length; ++i) {
ManifestExecutionHook memory mh = manifest.executionHooks[i];
_removeExecHooks(
mh.executionSelector,
_resolveManifestFunction(
mh.preExecHook, plugin, emptyDependencies, ManifestAssociatedFunctionType.NONE
),
_resolveManifestFunction(
mh.postExecHook, plugin, emptyDependencies, ManifestAssociatedFunctionType.NONE
)
);
FunctionReference hookFunction = FunctionReferenceLib.pack(plugin, mh.functionId);
_removeExecHooks(mh.executionSelector, hookFunction, mh.isPreHook, mh.isPostHook);
}

length = manifest.preValidationHooks.length;
Expand Down Expand Up @@ -461,14 +430,6 @@ abstract contract PluginManagerInternals is IPluginManager {
emit PluginUninstalled(plugin, onUninstallSuccess);
}

function _toSetValue(FunctionReference functionReference) internal pure returns (bytes32) {
return bytes32(FunctionReference.unwrap(functionReference));
}

function _toFunctionReference(bytes32 setValue) internal pure returns (FunctionReference) {
return FunctionReference.wrap(bytes21(setValue));
}

function _isValidPluginManifest(PluginManifest memory manifest, bytes32 manifestHash)
internal
pure
Expand Down
56 changes: 26 additions & 30 deletions src/account/UpgradeableModularAccount.sol
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,14 @@ import {FunctionReference, IPluginManager} from "../interfaces/IPluginManager.so
import {IStandardExecutor, Call} from "../interfaces/IStandardExecutor.sol";
import {AccountExecutor} from "./AccountExecutor.sol";
import {AccountLoupe} from "./AccountLoupe.sol";
import {AccountStorage, getAccountStorage, getPermittedCallKey, SelectorData} from "./AccountStorage.sol";
import {
AccountStorage,
getAccountStorage,
getPermittedCallKey,
SelectorData,
toFunctionReference,
toExecutionHook
} from "./AccountStorage.sol";
import {AccountStorageInitializable} from "./AccountStorageInitializable.sol";
import {PluginManagerInternals} from "./PluginManagerInternals.sol";

Expand Down Expand Up @@ -51,7 +58,6 @@ contract UpgradeableModularAccount is
error AuthorizeUpgradeReverted(bytes revertReason);
error ExecFromPluginNotPermitted(address plugin, bytes4 selector);
error ExecFromPluginExternalNotPermitted(address plugin, address target, uint256 value, bytes data);
error InvalidConfiguration();
error NativeTokenSpendingNotPermitted(address plugin);
error PostExecHookReverted(address plugin, uint8 functionId, bytes revertReason);
error PreExecHookReverted(address plugin, uint8 functionId, bytes revertReason);
Expand Down Expand Up @@ -355,7 +361,7 @@ contract UpgradeableModularAccount is
uint256 preUserOpValidationHooksLength = preUserOpValidationHooks.length();
for (uint256 i = 0; i < preUserOpValidationHooksLength; ++i) {
bytes32 key = preUserOpValidationHooks.at(i);
FunctionReference preUserOpValidationHook = _toFunctionReference(key);
FunctionReference preUserOpValidationHook = toFunctionReference(key);

(address plugin, uint8 functionId) = preUserOpValidationHook.unpack();
currentValidationData = IPlugin(plugin).preUserOpValidationHook(functionId, userOp, userOpHash);
Expand Down Expand Up @@ -398,7 +404,7 @@ contract UpgradeableModularAccount is
uint256 preRuntimeValidationHooksLength = preRuntimeValidationHooks.length();
for (uint256 i = 0; i < preRuntimeValidationHooksLength; ++i) {
bytes32 key = preRuntimeValidationHooks.at(i);
FunctionReference preRuntimeValidationHook = _toFunctionReference(key);
FunctionReference preRuntimeValidationHook = toFunctionReference(key);

(address plugin, uint8 functionId) = preRuntimeValidationHook.unpack();
// solhint-disable-next-line no-empty-blocks
Expand Down Expand Up @@ -432,44 +438,34 @@ contract UpgradeableModularAccount is
{
SelectorData storage selectorData = getAccountStorage().selectorData[selector];

uint256 preExecHooksLength = selectorData.preHooks.length();
uint256 postOnlyHooksLength = selectorData.postOnlyHooks.length();
uint256 hooksLength = selectorData.executionHooks.length();

// Overallocate on length - not all of this may get filled up. We set the correct length later.
postHooksToRun = new PostExecToRun[](preExecHooksLength + postOnlyHooksLength);
postHooksToRun = new PostExecToRun[](hooksLength);

// Copy all post hooks to the array. This happens before any pre hooks are run, so we can
// be sure that the set of hooks to run will not be affected by state changes mid-execution.

// Copy post-only hooks.
for (uint256 i = 0; i < postOnlyHooksLength; ++i) {
bytes32 key = selectorData.postOnlyHooks.at(i);
postHooksToRun[i].postExecHook = _toFunctionReference(key);
}

// Copy associated post hooks to the array.
for (uint256 i = 0; i < preExecHooksLength; ++i) {
FunctionReference preExecHook = _toFunctionReference(selectorData.preHooks.at(i));

FunctionReference associatedPostExecHook = selectorData.associatedPostHooks[preExecHook];

if (associatedPostExecHook.notEmpty()) {
postHooksToRun[i + postOnlyHooksLength].postExecHook = associatedPostExecHook;
for (uint256 i = 0; i < hooksLength; ++i) {
bytes32 key = selectorData.executionHooks.at(i);
(FunctionReference hookFunction,, bool isPostHook) = toExecutionHook(key);
if (isPostHook) {
postHooksToRun[i].postExecHook = hookFunction;
}
}

// Run the pre hooks and copy their return data to the post hooks array, if an associated post-exec hook
// exists.
for (uint256 i = 0; i < preExecHooksLength; ++i) {
bytes32 key = selectorData.preHooks.at(i);
FunctionReference preExecHook = _toFunctionReference(key);
for (uint256 i = 0; i < hooksLength; ++i) {
bytes32 key = selectorData.executionHooks.at(i);
(FunctionReference hookFunction, bool isPreHook, bool isPostHook) = toExecutionHook(key);

bytes memory preExecHookReturnData = _runPreExecHook(preExecHook, data);
if (isPreHook) {
bytes memory preExecHookReturnData = _runPreExecHook(hookFunction, data);

// If there is an associated post-exec hook, save the return data.
PostExecToRun memory postExecToRun = postHooksToRun[i + postOnlyHooksLength];
if (postExecToRun.postExecHook.notEmpty()) {
postExecToRun.preExecHookReturnData = preExecHookReturnData;
// If there is an associated post-exec hook, save the return data.
if (isPostHook) {
postHooksToRun[i].preExecHookReturnData = preExecHookReturnData;
}
}
}
}
Expand Down
Loading