From bd4433f0ee08a7fd3bcc646fe5bc745053197d95 Mon Sep 17 00:00:00 2001 From: Tu Pham Date: Mon, 25 May 2026 16:17:22 +0700 Subject: [PATCH] Fix Lido withdrawal claim access control --- src/contracts/LidoARM.sol | 12 ++++-- .../LidoARM/ClaimStETHWithdrawalForWETH.t.sol | 37 +++++++++++++++++++ test/fork/utils/MockCall.sol | 2 +- 3 files changed, 47 insertions(+), 4 deletions(-) diff --git a/src/contracts/LidoARM.sol b/src/contracts/LidoARM.sol index 34bfbcac..6964424a 100644 --- a/src/contracts/LidoARM.sol +++ b/src/contracts/LidoARM.sol @@ -141,7 +141,12 @@ contract LidoARM is Initializable, AbstractARM { * @param hintIds The hint IDs of the withdrawal requests. * Call `findCheckpointHints` on the Lido withdrawal queue contract to get the hint IDs. */ - function claimLidoWithdrawals(uint256[] calldata requestIds, uint256[] calldata hintIds) external { + function claimLidoWithdrawals(uint256[] calldata requestIds, uint256[] calldata hintIds) + external + onlyOperatorOrOwner + { + uint256 ethBalanceBefore = address(this).balance; + // Claim the NFTs for ETH. lidoWithdrawalQueue.claimWithdrawals(requestIds, hintIds); @@ -165,8 +170,9 @@ contract LidoARM is Initializable, AbstractARM { // this subtraction should never underflow. lidoWithdrawalQueueAmount -= totalAmountRequested; - // Wrap all the received ETH to WETH. This can be less than the requested amount in the event of slashing. - weth.deposit{value: address(this).balance}(); + // Wrap only the ETH received from this claim. This can be less than the requested amount in the event of slashing. + uint256 ethReceived = address(this).balance - ethBalanceBefore; + if (ethReceived > 0) weth.deposit{value: ethReceived}(); emit ClaimLidoWithdrawals(requestIds); } diff --git a/test/fork/LidoARM/ClaimStETHWithdrawalForWETH.t.sol b/test/fork/LidoARM/ClaimStETHWithdrawalForWETH.t.sol index 00dff276..9d1cd372 100644 --- a/test/fork/LidoARM/ClaimStETHWithdrawalForWETH.t.sol +++ b/test/fork/LidoARM/ClaimStETHWithdrawalForWETH.t.sol @@ -34,6 +34,16 @@ contract Fork_Concrete_LidoARM_ClaimLidoWithdrawals_Test_ is Fork_Shared_Test_ { amounts2[1] = DEFAULT_AMOUNT; } + ////////////////////////////////////////////////////// + /// --- REVERTING TESTS + ////////////////////////////////////////////////////// + function test_RevertWhen_ClaimLidoWithdrawals_NotOperatorOrOwner() public asRandomAddress { + uint256[] memory emptyList = new uint256[](0); + + vm.expectRevert("ARM: Only operator or owner can call this function."); + lidoARM.claimLidoWithdrawals(emptyList, emptyList); + } + ////////////////////////////////////////////////////// /// --- PASSING TESTS ////////////////////////////////////////////////////// @@ -83,6 +93,33 @@ contract Fork_Concrete_LidoARM_ClaimLidoWithdrawals_Test_ is Fork_Shared_Test_ { assertEq(weth.balanceOf(address(lidoARM)), balanceBefore + DEFAULT_AMOUNT); } + function test_ClaimLidoWithdrawals_OnlyWrapsClaimedETH() + public + asOperator + requestLidoWithdrawalsOnLidoARM(amounts1) + mockFunctionClaimWithdrawOnLidoARM(DEFAULT_AMOUNT) + { + uint256 donatedETH = 0.5 ether; + vm.deal(address(lidoARM), donatedETH); + + uint256 balanceBefore = weth.balanceOf(address(lidoARM)); + assertEq(address(lidoARM).balance, donatedETH); + assertEq(lidoARM.lidoWithdrawalQueueAmount(), DEFAULT_AMOUNT); + + stETHWithdrawal.getLastRequestId(); + uint256[] memory requests = new uint256[](1); + requests[0] = stETHWithdrawal.getLastRequestId(); + + uint256 lastIndex = stETHWithdrawal.getLastCheckpointIndex(); + uint256[] memory hintIds = stETHWithdrawal.findCheckpointHints(requests, 1, lastIndex); + + lidoARM.claimLidoWithdrawals(requests, hintIds); + + assertEq(address(lidoARM).balance, donatedETH); + assertEq(lidoARM.lidoWithdrawalQueueAmount(), 0); + assertEq(weth.balanceOf(address(lidoARM)), balanceBefore + DEFAULT_AMOUNT); + } + function test_ClaimLidoWithdrawals_MultiRequest() public asOperator diff --git a/test/fork/utils/MockCall.sol b/test/fork/utils/MockCall.sol index 1d6bba40..ecb86edf 100644 --- a/test/fork/utils/MockCall.sol +++ b/test/fork/utils/MockCall.sol @@ -71,6 +71,6 @@ contract ETHSender { Vm private constant vm = Vm(address(uint160(uint256(keccak256("hevm cheat code"))))); function sendETH(address target) external { - vm.deal(target, address(this).balance); + vm.deal(target, target.balance + address(this).balance); } }