Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,6 @@
path = lib/openzeppelin-contracts
url = https://github.com/OpenZeppelin/openzeppelin-contracts
branch = v4.8.2
[submodule "lib/solady"]
path = lib/solady
url = https://github.com/vectorized/solady
1 change: 1 addition & 0 deletions lib/solady
Submodule solady added at 50cbe1
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"name": "@zerodevapp/contracts",
"description": "ZeroDev Account Abstraction (EIP 4337) contracts",
"main": "./dist/index.js",
"version": "4.0.0-beta.9",
"version": "4.0.0-beta.13",
"scripts": {
"prepack": "./scripts/prepack-contracts-package.sh",
"postpack": "./scripts/postpack-contracts-package.sh"
Expand Down
25 changes: 25 additions & 0 deletions src/executor/KillSwitchAction.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import "src/validator/IValidator.sol";
import "src/abstract/KernelStorage.sol";

contract KillSwitchAction {
IKernelValidator public immutable killSwitchValidator;

constructor(IKernelValidator _killswitchValidator) {
killSwitchValidator = _killswitchValidator;
}

// Function to get the wallet kernel storage
function getKernelStorage() internal pure returns (WalletKernelStorage storage ws) {
bytes32 storagePosition = bytes32(uint256(keccak256("zerodev.kernel")) - 1);
assembly {
ws.slot := storagePosition
}
}

function activateKillSwitch() external {
WalletKernelStorage storage ws = getKernelStorage();
ws.defaultValidator = killSwitchValidator;
getKernelStorage().disabledMode = bytes4(0xffffffff);
getKernelStorage().lastDisabledTime = uint48(block.timestamp);
}
}
63 changes: 36 additions & 27 deletions src/validator/KillSwitchValidator.sol
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,21 @@ import "openzeppelin-contracts/contracts/utils/cryptography/EIP712.sol";
import "src/utils/KernelHelper.sol";
import "account-abstraction/core/Helpers.sol";
import "src/Kernel.sol";
import { WalletKernelStorage, ExecutionDetail} from "src/abstract/KernelStorage.sol";
import "./ECDSAValidator.sol";


struct KillSwitchValidatorStorage {
address owner;
address guardian;
IKernelValidator validator;
uint48 pausedUntil;
}

contract KillSwitchValidator is IKernelValidator {
mapping(address => KillSwitchValidatorStorage) public killSwitchValidatorStorage;

function enable(bytes calldata enableData) external override {
killSwitchValidatorStorage[msg.sender].owner = address(bytes20(enableData[0:20]));
killSwitchValidatorStorage[msg.sender].guardian = address(bytes20(enableData[20:40]));
killSwitchValidatorStorage[msg.sender].guardian = address(bytes20(enableData[0:20]));
}

function disable(bytes calldata) external override {
Expand All @@ -29,40 +30,48 @@ contract KillSwitchValidator is IKernelValidator {

function validateSignature(bytes32 hash, bytes calldata signature) external view override returns (uint256) {
KillSwitchValidatorStorage storage validatorStorage = killSwitchValidatorStorage[msg.sender];
return _packValidationData(
validatorStorage.owner != ECDSA.recover(hash, signature), 0, validatorStorage.pausedUntil
);
uint256 res = validatorStorage.validator.validateSignature(hash,signature);
uint48 pausedUntil = validatorStorage.pausedUntil;
ValidationData memory validationData = _parseValidationData(res);
if(validationData.aggregator != address(1)) { // if signature verification has not been failed, return with the result
uint256 delayedData = _packValidationData(false, 0, pausedUntil);
return _packValidationData(_intersectTimeRange(res, delayedData));
}
}

function validateUserOp(UserOperation calldata _userOp, bytes32 _userOpHash, uint256)
external
override
returns (uint256)
{
address signer;
bytes calldata signature;
KillSwitchValidatorStorage storage validatorStorage = killSwitchValidatorStorage[_userOp.sender];
if (_userOp.signature.length == 6 + 65) {
require(bytes4(_userOp.callData[0:4]) != KernelStorage.disableMode.selector);
signer = validatorStorage.guardian;
uint48 pausedUntil = uint48(bytes6(_userOp.signature[0:6]));
require(pausedUntil > validatorStorage.pausedUntil, "KillSwitchValidator: invalid pausedUntil");
killSwitchValidatorStorage[_userOp.sender].pausedUntil = pausedUntil;
signature = _userOp.signature[6:71];
} else {
signer = killSwitchValidatorStorage[_userOp.sender].owner;
signature = _userOp.signature;
uint48 pausedUntil = validatorStorage.pausedUntil;
uint256 validationResult = 0;
if(address(validatorStorage.validator) != address(0)){
// check for validator at first
try validatorStorage.validator.validateUserOp(_userOp, _userOpHash, pausedUntil) returns (uint256 res) {
validationResult = res;
} catch {
validationResult = SIG_VALIDATION_FAILED;
}
ValidationData memory validationData = _parseValidationData(validationResult);
if(validationData.aggregator != address(1)) { // if signature verification has not been failed, return with the result
uint256 delayedData = _packValidationData(false, 0, pausedUntil);
return _packValidationData(_intersectTimeRange(validationResult, delayedData));
}
}
if (signer == ECDSA.recover(_userOpHash, signature)) {
// address(0) attack has been resolved in ECDSA library
return _packValidationData(false, 0, validatorStorage.pausedUntil);
}

bytes32 hash = ECDSA.toEthSignedMessageHash(_userOpHash);
address recovered = ECDSA.recover(hash, signature);
if (signer != recovered) {
if(_userOp.signature.length == 71) {
// save data to this storage
validatorStorage.pausedUntil = uint48(bytes6(_userOp.signature[0:6]));
validatorStorage.validator = KernelStorage(msg.sender).getDefaultValidator();
bytes32 hash = ECDSA.toEthSignedMessageHash(keccak256(bytes.concat(_userOp.signature[0:6],_userOpHash)));
address recovered = ECDSA.recover(hash, _userOp.signature[6:]);
if (validatorStorage.guardian != recovered) {
return SIG_VALIDATION_FAILED;
}
return _packValidationData(false, 0, pausedUntil);
} else {
return SIG_VALIDATION_FAILED;
}
return _packValidationData(false, 0, validatorStorage.pausedUntil);
}
}
188 changes: 188 additions & 0 deletions test/foundry/KillSwitch.t.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
// SPDX-License-Identifier: MIT
pragma solidity ^0.8.0;

import "src/factory/KernelFactory.sol";
import "src/factory/TempKernel.sol";
import "src/validator/ECDSAValidator.sol";
import "src/factory/ECDSAKernelFactory.sol";
import "src/Kernel.sol";
import "src/validator/KillSwitchValidator.sol";
import "src/executor/KillSwitchAction.sol";
import "src/factory/EIP1967Proxy.sol";
// test utils
import "forge-std/Test.sol";
import {ERC4337Utils} from "./ERC4337Utils.sol";

using ERC4337Utils for EntryPoint;

contract KernelExecutionTest is Test {
Kernel kernel;
KernelFactory factory;
ECDSAKernelFactory ecdsaFactory;
EntryPoint entryPoint;
ECDSAValidator validator;

KillSwitchValidator killSwitch;
KillSwitchAction action;
address owner;
uint256 ownerKey;
address payable beneficiary;

function setUp() public {
(owner, ownerKey) = makeAddrAndKey("owner");
entryPoint = new EntryPoint();
factory = new KernelFactory(entryPoint);

validator = new ECDSAValidator();
ecdsaFactory = new ECDSAKernelFactory(factory, validator, entryPoint);

kernel = Kernel(payable(address(ecdsaFactory.createAccount(owner, 0))));
vm.deal(address(kernel), 1e30);
beneficiary = payable(address(makeAddr("beneficiary")));
killSwitch = new KillSwitchValidator();
action = new KillSwitchAction(killSwitch);
}

function test_mode_2() external {
UserOperation memory op = entryPoint.fillUserOp(
address(kernel),
abi.encodeWithSelector(Kernel.execute.selector, owner, 0, "", Operation.Call)
);

op.signature = bytes.concat(bytes4(0), entryPoint.signUserOpHash(vm, ownerKey, op));
UserOperation[] memory ops = new UserOperation[](1);
ops[0] = op;
entryPoint.handleOps(ops, beneficiary);


op = entryPoint.fillUserOp(
address(kernel),
abi.encodeWithSelector(KillSwitchAction.activateKillSwitch.selector)
);
address guardianKeyAddr;
uint256 guardianKeyPriv;
(guardianKeyAddr, guardianKeyPriv) = makeAddrAndKey("guardianKey");
bytes memory enableData = abi.encodePacked(
guardianKeyAddr
);
{
bytes32 digest = getTypedDataHash(
address(kernel),
KillSwitchAction.activateKillSwitch.selector,
0,
0,
address(killSwitch),
address(action),
enableData
);
(uint8 v, bytes32 r, bytes32 s) = vm.sign(ownerKey, digest);

op.signature = abi.encodePacked(
bytes4(0x00000002),
uint48(0),
uint48(0),
address(killSwitch),
address(action),
uint256(enableData.length),
enableData,
uint256(65),
r,
s,
v
);
}

uint256 pausedUntil = block.timestamp + 1000;

bytes32 hash = entryPoint.getUserOpHash(op);
{
(uint8 v, bytes32 r, bytes32 s) = vm.sign(guardianKeyPriv, ECDSA.toEthSignedMessageHash(keccak256(bytes.concat(bytes6(uint48(pausedUntil)),hash))));
bytes memory sig = abi.encodePacked(r, s, v);

op.signature = bytes.concat(op.signature, bytes6(uint48(pausedUntil)), sig);
}

ops[0] = op;
logGas(op);
entryPoint.handleOps(ops, beneficiary);
assertEq(address(kernel.getDefaultValidator()), address(killSwitch));
op = entryPoint.fillUserOp(
address(kernel),
abi.encodeWithSelector(Kernel.execute.selector, owner, 0, "", Operation.Call)
);

op.signature = bytes.concat(bytes4(0), entryPoint.signUserOpHash(vm, ownerKey, op));
ops[0] = op;
vm.expectRevert();
entryPoint.handleOps(ops, beneficiary); // should revert because kill switch is active
vm.warp(pausedUntil + 1);
entryPoint.handleOps(ops, beneficiary); // should not revert because pausedUntil has been passed
}

function logGas(UserOperation memory op) internal returns (uint256 used) {
try this.consoleGasUsage(op) {
revert("should revert");
} catch Error(string memory reason) {
used = abi.decode(bytes(reason), (uint256));
console.log("validation gas usage :", used);
}
}

function consoleGasUsage(UserOperation memory op) external {
uint256 gas = gasleft();
vm.startPrank(address(entryPoint));
kernel.validateUserOp(op, entryPoint.getUserOpHash(op), 0);
vm.stopPrank();
revert(string(abi.encodePacked(gas - gasleft())));
}
}

// computes the hash of a permit
function getStructHash(
bytes4 sig,
uint48 validUntil,
uint48 validAfter,
address validator,
address executor,
bytes memory enableData
) pure returns (bytes32) {
return keccak256(
abi.encode(
keccak256("ValidatorApproved(bytes4 sig,uint256 validatorData,address executor,bytes enableData)"),
bytes4(sig),
uint256(uint256(uint160(validator)) | (uint256(validAfter) << 160) | (uint256(validUntil) << (48 + 160))),
executor,
keccak256(enableData)
)
);
}

// computes the hash of the fully encoded EIP-712 message for the domain, which can be used to recover the signer
function getTypedDataHash(
address sender,
bytes4 sig,
uint48 validUntil,
uint48 validAfter,
address validator,
address executor,
bytes memory enableData
) view returns (bytes32) {
return keccak256(
abi.encodePacked(
"\x19\x01",
_buildDomainSeparator("Kernel", "0.0.2", sender),
getStructHash(sig, validUntil, validAfter, validator, executor, enableData)
)
);
}

function _buildDomainSeparator(string memory name, string memory version, address verifyingContract)
view
returns (bytes32)
{
bytes32 hashedName = keccak256(bytes(name));
bytes32 hashedVersion = keccak256(bytes(version));
bytes32 typeHash = keccak256("EIP712Domain(string name,string version,uint256 chainId,address verifyingContract)");

return keccak256(abi.encode(typeHash, hashedName, hashedVersion, block.chainid, address(verifyingContract)));
}