From ce90f557ffa0397c19efa42680938d7a3b5f29d0 Mon Sep 17 00:00:00 2001 From: adam Date: Wed, 19 Jun 2024 11:00:12 -0400 Subject: [PATCH] self-call protection --- src/account/UpgradeableModularAccount.sol | 73 ++++- test/account/SelfCallAuthorization.t.sol | 343 ++++++++++++++++++++++ test/utils/AccountTestBase.sol | 2 +- 3 files changed, 405 insertions(+), 13 deletions(-) create mode 100644 test/account/SelfCallAuthorization.t.sol diff --git a/src/account/UpgradeableModularAccount.sol b/src/account/UpgradeableModularAccount.sol index 99f92c80..d57fcb0f 100644 --- a/src/account/UpgradeableModularAccount.sol +++ b/src/account/UpgradeableModularAccount.sol @@ -72,6 +72,7 @@ contract UpgradeableModularAccount is error RequireUserOperationContext(); error RuntimeValidationFunctionMissing(bytes4 selector); error RuntimeValidationFunctionReverted(address plugin, uint8 functionId, bytes revertReason); + error SelfCallRecursionDepthExceeded(); error SignatureValidationInvalid(address plugin, uint8 functionId); error UnexpectedAggregator(address plugin, uint8 functionId, address aggregator); error UnrecognizedFunction(bytes4 selector); @@ -216,14 +217,12 @@ contract UpgradeableModularAccount is payable returns (bytes memory) { - bytes4 execSelector = bytes4(data[:4]); - // Revert if the provided `authorization` less than 21 bytes long, rather than right-padding. FunctionReference runtimeValidationFunction = FunctionReference.wrap(bytes21(authorization[:21])); // Check if the runtime validation function is allowed to be called bool isDefaultValidation = uint8(authorization[21]) == 1; - _checkIfValidationApplies(execSelector, runtimeValidationFunction, isDefaultValidation); + _checkIfValidationAppliesCallData(data, runtimeValidationFunction, isDefaultValidation); _doRuntimeValidation(runtimeValidationFunction, data, authorization[22:]); @@ -388,16 +387,12 @@ contract UpgradeableModularAccount is if (userOp.callData.length < 4) { revert UnrecognizedFunction(bytes4(userOp.callData)); } - bytes4 selector = bytes4(userOp.callData); - if (selector == this.executeUserOp.selector) { - selector = bytes4(userOp.callData[4:8]); - } // Revert if the provided `authorization` less than 21 bytes long, rather than right-padding. FunctionReference userOpValidationFunction = FunctionReference.wrap(bytes21(userOp.signature[:21])); bool isDefaultValidation = uint8(userOp.signature[21]) == 1; - _checkIfValidationApplies(selector, userOpValidationFunction, isDefaultValidation); + _checkIfValidationAppliesCallData(userOp.callData, userOpValidationFunction, isDefaultValidation); // Check if there are permission hooks associated with the validator, and revert if the call isn't to // `executeUserOp` @@ -623,10 +618,64 @@ contract UpgradeableModularAccount is // solhint-disable-next-line no-empty-blocks function _authorizeUpgrade(address newImplementation) internal override {} - function _checkIfValidationApplies(bytes4 selector, FunctionReference validationFunction, bool isDefault) - internal - view - { + function _checkIfValidationAppliesCallData( + bytes calldata callData, + FunctionReference validationFunction, + bool isDefault + ) internal view { + bytes4 outerSelector = bytes4(callData[:4]); + if (outerSelector == this.executeUserOp.selector) { + // If the selector is executeUserOp, pull the actual selector from the following data, + // and trim the calldata to ensure the self-call decoding is still accurate. + callData = callData[4:]; + outerSelector = bytes4(callData[:4]); + } + + _checkIfValidationAppliesSelector(outerSelector, validationFunction, isDefault); + + if (outerSelector == IStandardExecutor.execute.selector) { + (address target,,) = abi.decode(callData[4:], (address, uint256, bytes)); + + if (target == address(this)) { + // There is no point to call `execute` to recurse exactly once - this is equivalent to just having + // the calldata as a top-level call. + revert SelfCallRecursionDepthExceeded(); + } + } else if (outerSelector == IStandardExecutor.executeBatch.selector) { + // executeBatch may be used to batch account actions together, by targetting the account itself. + // If this is done, we must ensure all of the inner calls are allowed by the provided validation + // function. + + (Call[] memory calls) = abi.decode(callData[4:], (Call[])); + + for (uint256 i = 0; i < calls.length; ++i) { + if (calls[i].target == address(this)) { + bytes4 nestedSelector = bytes4(calls[i].data); + + if ( + nestedSelector == IStandardExecutor.execute.selector + || nestedSelector == IStandardExecutor.executeBatch.selector + ) { + // To prevent arbitrarily-deep recursive checking, we limit the depth of self-calls to one + // for the purposes of batching. + // This means that all self-calls must occur at the top level of the batch. + // Note that plugins of other contracts using `executeWithAuthorization` may still + // independently call into this account with a different validation function, allowing + // composition of multiple batches. + revert SelfCallRecursionDepthExceeded(); + } + + _checkIfValidationAppliesSelector(nestedSelector, validationFunction, isDefault); + } + } + } + } + + function _checkIfValidationAppliesSelector( + bytes4 selector, + FunctionReference validationFunction, + bool isDefault + ) internal view { AccountStorage storage _storage = getAccountStorage(); // Check that the provided validation function is applicable to the selector diff --git a/test/account/SelfCallAuthorization.t.sol b/test/account/SelfCallAuthorization.t.sol new file mode 100644 index 00000000..840f268a --- /dev/null +++ b/test/account/SelfCallAuthorization.t.sol @@ -0,0 +1,343 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.25; + +import {IAccountExecute} from "@eth-infinitism/account-abstraction/interfaces/IAccountExecute.sol"; +import {IEntryPoint} from "@eth-infinitism/account-abstraction/interfaces/IEntryPoint.sol"; +import {PackedUserOperation} from "@eth-infinitism/account-abstraction/interfaces/PackedUserOperation.sol"; + +import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAccount.sol"; +import {IStandardExecutor, Call} from "../../src/interfaces/IStandardExecutor.sol"; +import {FunctionReference, FunctionReferenceLib} from "../../src/helpers/FunctionReferenceLib.sol"; + +import {AccountTestBase} from "../utils/AccountTestBase.sol"; +import {DefaultValidationFactoryFixture} from "../mocks/DefaultValidationFactoryFixture.sol"; +import {ComprehensivePlugin} from "../mocks/plugins/ComprehensivePlugin.sol"; + +contract SelfCallAuthorizationTest is AccountTestBase { + DefaultValidationFactoryFixture public defaultValidationFactoryFixture; + + ComprehensivePlugin public comprehensivePlugin; + + FunctionReference public comprehensivePluginValidation; + + function setUp() public { + defaultValidationFactoryFixture = new DefaultValidationFactoryFixture(entryPoint, singleOwnerPlugin); + + account1 = UpgradeableModularAccount(payable(defaultValidationFactoryFixture.createAccount(owner1, 0))); + + vm.deal(address(account1), 100 ether); + + // install the comprehensive plugin to get new exec functions with different validations configured. + + comprehensivePlugin = new ComprehensivePlugin(); + + bytes32 manifestHash = keccak256(abi.encode(comprehensivePlugin.pluginManifest())); + vm.prank(address(entryPoint)); + account1.installPlugin(address(comprehensivePlugin), manifestHash, "", new FunctionReference[](0)); + + comprehensivePluginValidation = FunctionReferenceLib.pack( + address(comprehensivePlugin), uint8(ComprehensivePlugin.FunctionId.VALIDATION) + ); + } + + function test_selfCallFails_userOp() public { + // Uses default validation + _runUserOp( + abi.encodeCall(ComprehensivePlugin.foo, ()), + abi.encodeWithSelector( + IEntryPoint.FailedOpWithRevert.selector, + 0, + "AA23 reverted", + abi.encodeWithSelector( + UpgradeableModularAccount.UserOpValidationFunctionMissing.selector, + ComprehensivePlugin.foo.selector + ) + ) + ); + } + + function test_selfCallFails_execUserOp() public { + // Uses default validation + _runUserOp( + abi.encodePacked(IAccountExecute.executeUserOp.selector, abi.encodeCall(ComprehensivePlugin.foo, ())), + abi.encodeWithSelector( + IEntryPoint.FailedOpWithRevert.selector, + 0, + "AA23 reverted", + abi.encodeWithSelector( + UpgradeableModularAccount.UserOpValidationFunctionMissing.selector, + ComprehensivePlugin.foo.selector + ) + ) + ); + } + + function test_selfCallFails_runtime() public { + // Uses default validation + _runtimeCall( + abi.encodeCall(ComprehensivePlugin.foo, ()), + abi.encodeWithSelector( + UpgradeableModularAccount.UserOpValidationFunctionMissing.selector, + ComprehensivePlugin.foo.selector + ) + ); + } + + function test_selfCallPrivilegeEscalation_prevented_userOp() public { + // Using default validation, self-call bypasses custom validation needed for ComprehensivePlugin.foo + _runUserOp( + abi.encodeCall( + UpgradeableModularAccount.execute, + (address(account1), 0, abi.encodeCall(ComprehensivePlugin.foo, ())) + ), + abi.encodeWithSelector( + IEntryPoint.FailedOpWithRevert.selector, + 0, + "AA23 reverted", + abi.encodeWithSelector(UpgradeableModularAccount.SelfCallRecursionDepthExceeded.selector) + ) + ); + + Call[] memory calls = new Call[](1); + calls[0] = Call(address(account1), 0, abi.encodeCall(ComprehensivePlugin.foo, ())); + + _runUserOp( + abi.encodeCall(IStandardExecutor.executeBatch, (calls)), + abi.encodeWithSelector( + IEntryPoint.FailedOpWithRevert.selector, + 0, + "AA23 reverted", + abi.encodeWithSelector( + UpgradeableModularAccount.UserOpValidationFunctionMissing.selector, + ComprehensivePlugin.foo.selector + ) + ) + ); + } + + function test_selfCallPrivilegeEscalation_prevented_execUserOp() public { + // Using default validation, self-call bypasses custom validation needed for ComprehensivePlugin.foo + _runUserOp( + abi.encodePacked( + IAccountExecute.executeUserOp.selector, + abi.encodeCall( + UpgradeableModularAccount.execute, + (address(account1), 0, abi.encodeCall(ComprehensivePlugin.foo, ())) + ) + ), + abi.encodeWithSelector( + IEntryPoint.FailedOpWithRevert.selector, + 0, + "AA23 reverted", + abi.encodeWithSelector(UpgradeableModularAccount.SelfCallRecursionDepthExceeded.selector) + ) + ); + + Call[] memory calls = new Call[](1); + calls[0] = Call(address(account1), 0, abi.encodeCall(ComprehensivePlugin.foo, ())); + + _runUserOp( + abi.encodePacked( + IAccountExecute.executeUserOp.selector, abi.encodeCall(IStandardExecutor.executeBatch, (calls)) + ), + abi.encodeWithSelector( + IEntryPoint.FailedOpWithRevert.selector, + 0, + "AA23 reverted", + abi.encodeWithSelector( + UpgradeableModularAccount.UserOpValidationFunctionMissing.selector, + ComprehensivePlugin.foo.selector + ) + ) + ); + } + + function test_selfCallPrivilegeEscalation_prevented_runtime() public { + // Using default validation, self-call bypasses custom validation needed for ComprehensivePlugin.foo + _runtimeCall( + abi.encodeCall( + UpgradeableModularAccount.execute, + (address(account1), 0, abi.encodeCall(ComprehensivePlugin.foo, ())) + ), + abi.encodeWithSelector(UpgradeableModularAccount.SelfCallRecursionDepthExceeded.selector) + ); + + Call[] memory calls = new Call[](1); + calls[0] = Call(address(account1), 0, abi.encodeCall(ComprehensivePlugin.foo, ())); + + _runtimeExecBatchExpFail( + calls, + abi.encodeWithSelector( + UpgradeableModularAccount.UserOpValidationFunctionMissing.selector, + ComprehensivePlugin.foo.selector + ) + ); + } + + function test_batchAction_allowed_userOp() public { + _enableBatchValidation(); + + Call[] memory calls = new Call[](2); + calls[0] = Call(address(account1), 0, abi.encodeCall(ComprehensivePlugin.foo, ())); + calls[1] = Call(address(account1), 0, abi.encodeCall(ComprehensivePlugin.foo, ())); + + PackedUserOperation memory userOp = _generateUserOpWithComprehensivePluginValidation( + abi.encodeCall(IStandardExecutor.executeBatch, (calls)) + ); + + PackedUserOperation[] memory userOps = new PackedUserOperation[](1); + userOps[0] = userOp; + + vm.expectCall(address(comprehensivePlugin), abi.encodeCall(ComprehensivePlugin.foo, ()), 2); + entryPoint.handleOps(userOps, beneficiary); + } + + function test_batchAction_allowed_execUserOp() public { + _enableBatchValidation(); + + Call[] memory calls = new Call[](2); + calls[0] = Call(address(account1), 0, abi.encodeCall(ComprehensivePlugin.foo, ())); + calls[1] = Call(address(account1), 0, abi.encodeCall(ComprehensivePlugin.foo, ())); + + PackedUserOperation memory userOp = _generateUserOpWithComprehensivePluginValidation( + abi.encodePacked( + IAccountExecute.executeUserOp.selector, abi.encodeCall(IStandardExecutor.executeBatch, (calls)) + ) + ); + + PackedUserOperation[] memory userOps = new PackedUserOperation[](1); + userOps[0] = userOp; + + vm.expectCall(address(comprehensivePlugin), abi.encodeCall(ComprehensivePlugin.foo, ()), 2); + entryPoint.handleOps(userOps, beneficiary); + } + + function test_batchAction_allowed_runtime() public { + _enableBatchValidation(); + + Call[] memory calls = new Call[](2); + calls[0] = Call(address(account1), 0, abi.encodeCall(ComprehensivePlugin.foo, ())); + calls[1] = Call(address(account1), 0, abi.encodeCall(ComprehensivePlugin.foo, ())); + + vm.expectCall(address(comprehensivePlugin), abi.encodeCall(ComprehensivePlugin.foo, ()), 2); + account1.executeWithAuthorization( + abi.encodeCall(IStandardExecutor.executeBatch, (calls)), + _encodeSignature(comprehensivePluginValidation, SELECTOR_ASSOCIATED_VALIDATION, "") + ); + } + + function test_recursiveDepthCapped_userOp() public { + _enableBatchValidation(); + + Call[] memory innerCalls = new Call[](1); + innerCalls[0] = Call(address(account1), 0, abi.encodeCall(ComprehensivePlugin.foo, ())); + + Call[] memory outerCalls = new Call[](1); + outerCalls[0] = Call(address(account1), 0, abi.encodeCall(IStandardExecutor.executeBatch, (innerCalls))); + + PackedUserOperation memory userOp = _generateUserOpWithComprehensivePluginValidation( + abi.encodeCall(IStandardExecutor.executeBatch, (outerCalls)) + ); + + PackedUserOperation[] memory userOps = new PackedUserOperation[](1); + userOps[0] = userOp; + + vm.expectRevert( + abi.encodeWithSelector( + IEntryPoint.FailedOpWithRevert.selector, + 0, + "AA23 reverted", + abi.encodeWithSelector(UpgradeableModularAccount.SelfCallRecursionDepthExceeded.selector) + ) + ); + entryPoint.handleOps(userOps, beneficiary); + } + + function test_recursiveDepthCapped_execUserOp() public { + _enableBatchValidation(); + + Call[] memory innerCalls = new Call[](1); + innerCalls[0] = Call(address(account1), 0, abi.encodeCall(ComprehensivePlugin.foo, ())); + + Call[] memory outerCalls = new Call[](1); + outerCalls[0] = Call(address(account1), 0, abi.encodeCall(IStandardExecutor.executeBatch, (innerCalls))); + + PackedUserOperation memory userOp = _generateUserOpWithComprehensivePluginValidation( + abi.encodePacked( + IAccountExecute.executeUserOp.selector, + abi.encodeCall(IStandardExecutor.executeBatch, (outerCalls)) + ) + ); + + PackedUserOperation[] memory userOps = new PackedUserOperation[](1); + userOps[0] = userOp; + + vm.expectRevert( + abi.encodeWithSelector( + IEntryPoint.FailedOpWithRevert.selector, + 0, + "AA23 reverted", + abi.encodeWithSelector(UpgradeableModularAccount.SelfCallRecursionDepthExceeded.selector) + ) + ); + entryPoint.handleOps(userOps, beneficiary); + } + + function test_recursiveDepthCapped_runtime() public { + _enableBatchValidation(); + + Call[] memory innerCalls = new Call[](1); + innerCalls[0] = Call(address(account1), 0, abi.encodeCall(ComprehensivePlugin.foo, ())); + + Call[] memory outerCalls = new Call[](1); + outerCalls[0] = Call(address(account1), 0, abi.encodeCall(IStandardExecutor.executeBatch, (innerCalls))); + + vm.expectRevert(abi.encodeWithSelector(UpgradeableModularAccount.SelfCallRecursionDepthExceeded.selector)); + account1.executeWithAuthorization( + abi.encodeCall(IStandardExecutor.executeBatch, (outerCalls)), + _encodeSignature(comprehensivePluginValidation, SELECTOR_ASSOCIATED_VALIDATION, "") + ); + } + + function _enableBatchValidation() internal { + // Extend ComprehensivePlugin's validation function to also validate `executeBatch`, to allow the + // self-call. + + bytes4[] memory selectors = new bytes4[](1); + selectors[0] = IStandardExecutor.executeBatch.selector; + + vm.prank(owner1); + account1.executeWithAuthorization( + abi.encodeCall( + UpgradeableModularAccount.installValidation, + (comprehensivePluginValidation, false, selectors, "", "", "") + ), + _encodeSignature(_ownerValidation, DEFAULT_VALIDATION, "") + ); + } + + function _generateUserOpWithComprehensivePluginValidation(bytes memory callData) + internal + view + returns (PackedUserOperation memory) + { + uint256 nonce = entryPoint.getNonce(address(account1), 0); + return PackedUserOperation({ + sender: address(account1), + nonce: nonce, + initCode: hex"", + callData: callData, + accountGasLimits: _encodeGas(VERIFICATION_GAS_LIMIT, CALL_GAS_LIMIT), + preVerificationGas: 0, + gasFees: _encodeGas(1, 1), + paymasterAndData: hex"", + signature: _encodeSignature( + comprehensivePluginValidation, + SELECTOR_ASSOCIATED_VALIDATION, + // Comprehensive plugin's validation function doesn't actually check anything, so we don't need to + // sign anything. + "" + ) + }); + } +} diff --git a/test/utils/AccountTestBase.sol b/test/utils/AccountTestBase.sol index f5fe033b..cc11334d 100644 --- a/test/utils/AccountTestBase.sol +++ b/test/utils/AccountTestBase.sol @@ -35,7 +35,7 @@ abstract contract AccountTestBase is OptimizedTest { uint8 public constant SELECTOR_ASSOCIATED_VALIDATION = 0; uint8 public constant DEFAULT_VALIDATION = 1; - uint256 public constant CALL_GAS_LIMIT = 50000; + uint256 public constant CALL_GAS_LIMIT = 100000; uint256 public constant VERIFICATION_GAS_LIMIT = 1200000; struct PreValidationHookData {