diff --git a/contracts/contracts/interfaces/IRewardManager.sol b/contracts/contracts/interfaces/IRewardManager.sol index 29e08a989..8e27b11d3 100644 --- a/contracts/contracts/interfaces/IRewardManager.sol +++ b/contracts/contracts/interfaces/IRewardManager.sol @@ -48,8 +48,7 @@ interface IRewardManager { function overrideReceiver(address overrideAddress, bool migrateExistingRewards) external; /// @dev Removes the override address for a receiver. - /// @param migrateExistingRewards If true, existing rewards for the overridden address will be migrated atomically to the msg.sender. - function removeOverrideAddress(bool migrateExistingRewards) external; + function removeOverrideAddress() external; /// @dev Allows a reward recipient to claim their rewards. function claimRewards() external; diff --git a/contracts/contracts/validator-registry/rewards/RewardManager.sol b/contracts/contracts/validator-registry/rewards/RewardManager.sol index fb79cba8e..8d697a398 100644 --- a/contracts/contracts/validator-registry/rewards/RewardManager.sol +++ b/contracts/contracts/validator-registry/rewards/RewardManager.sol @@ -121,11 +121,9 @@ contract RewardManager is IRewardManager, RewardManagerStorage, } /// @dev Removes the override address for a receiver. - /// @param migrateExistingRewards If true, existing rewards for the overridden address will be migrated atomically to the msg.sender. - function removeOverrideAddress(bool migrateExistingRewards) external whenNotPaused nonReentrant { + function removeOverrideAddress() external whenNotPaused nonReentrant { address toBeRemoved = overrideAddresses[msg.sender]; require(toBeRemoved != address(0), NoOverriddenAddressToRemove()); - if (migrateExistingRewards) { _migrateRewards(toBeRemoved, msg.sender); } overrideAddresses[msg.sender] = address(0); emit OverrideAddressRemoved(msg.sender); } @@ -190,7 +188,7 @@ contract RewardManager is IRewardManager, RewardManagerStorage, emit RewardsClaimed(msg.sender, amount); } - /// @dev DANGER: This function should ONLY be called from overrideClaimAddress or removeOverriddenClaimAddress + /// @dev DANGER: This function should ONLY be called from overrideReceiver /// with careful attention to parameter order. function _migrateRewards(address from, address to) internal { uint256 amount = unclaimedRewards[from]; diff --git a/contracts/test/validator-registry/rewards/RewardManagerTest.sol b/contracts/test/validator-registry/rewards/RewardManagerTest.sol index 781338c68..f757c018b 100644 --- a/contracts/test/validator-registry/rewards/RewardManagerTest.sol +++ b/contracts/test/validator-registry/rewards/RewardManagerTest.sol @@ -158,7 +158,7 @@ contract RewardManagerTest is Test { vm.prank(user1); vm.expectRevert(PausableUpgradeable.EnforcedPause.selector); - rewardManager.removeOverrideAddress(false); + rewardManager.removeOverrideAddress(); vm.prank(user1); vm.expectRevert(PausableUpgradeable.EnforcedPause.selector); @@ -335,7 +335,7 @@ contract RewardManagerTest is Test { vm.prank(operatorFromMiddlewareTest); vm.expectEmit(); emit OverrideAddressRemoved(operatorFromMiddlewareTest); - rewardManager.removeOverrideAddress(false); + rewardManager.removeOverrideAddress(); vm.deal(user3, 4 ether); vm.expectEmit(); @@ -394,17 +394,18 @@ contract RewardManagerTest is Test { vm.expectEmit(); emit OverrideAddressRemoved(vanillaTestUser); vm.prank(vanillaTestUser); - rewardManager.removeOverrideAddress(true); + rewardManager.removeOverrideAddress(); - assertEq(rewardManager.unclaimedRewards(user4), 0 ether); - assertEq(rewardManager.unclaimedRewards(vanillaTestUser), 9 ether); + assertEq(rewardManager.unclaimedRewards(user4), 9 ether); + assertEq(rewardManager.unclaimedRewards(vanillaTestUser), 0 ether); - uint256 balanceBefore = vanillaTestUser.balance; - vm.prank(vanillaTestUser); + // Rewards must be claimed manually from the override address, even if that override address is removed + uint256 balanceBefore = user4.balance; + vm.prank(user4); vm.expectEmit(); - emit RewardsClaimed(vanillaTestUser, 9 ether); + emit RewardsClaimed(user4, 9 ether); rewardManager.claimRewards(); - assertEq(vanillaTestUser.balance, balanceBefore + 9 ether); + assertEq(user4.balance, balanceBefore + 9 ether); } function testAutoClaim() public {