From f618d2b9f90686a4b6056b22ec3cf0dbb8233acc Mon Sep 17 00:00:00 2001 From: Ameesha Agrawal Date: Tue, 29 Jul 2025 16:50:14 +0530 Subject: [PATCH 1/6] doc: commments --- contracts/protocol/Socket.sol | 3 +-- contracts/protocol/SocketConfig.sol | 19 +++++++++++-------- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/contracts/protocol/Socket.sol b/contracts/protocol/Socket.sol index ad3f97dd..d4c6e944 100644 --- a/contracts/protocol/Socket.sol +++ b/contracts/protocol/Socket.sol @@ -85,7 +85,6 @@ contract Socket is SocketUtils { // verify the digest _verify(payloadId, plugConfig, executeParams_, transmissionParams_.transmitterProof); - return _execute(payloadId, executeParams_, transmissionParams_); } @@ -101,6 +100,7 @@ contract Socket is SocketUtils { if (isValidSwitchboard[plugConfig_.switchboardId] != SwitchboardStatus.REGISTERED) revert InvalidSwitchboard(); + // NOTE: the first un-trusted call in the system address transmitter = ISwitchboard(switchboardAddresses[plugConfig_.switchboardId]) .getTransmitter(msg.sender, payloadId_, transmitterProof_); @@ -114,7 +114,6 @@ contract Socket is SocketUtils { ); payloadIdToDigest[payloadId_] = digest; - // NOTE: is the the first un-trusted call in the system, another one is Plug.call if ( !ISwitchboard(switchboardAddresses[plugConfig_.switchboardId]).allowPayload( digest, diff --git a/contracts/protocol/SocketConfig.sol b/contracts/protocol/SocketConfig.sol index 2a701b4a..a9c0e2ed 100644 --- a/contracts/protocol/SocketConfig.sol +++ b/contracts/protocol/SocketConfig.sol @@ -79,21 +79,25 @@ abstract contract SocketConfig is ISocket, AccessControl { emit SwitchboardDisabled(switchboardId_); } - // @notice function to enable a switchboard + // @notice function to enable a switchboard if disabled // @dev only callable by governance role function enableSwitchboard(uint64 switchboardId_) external onlyRole(GOVERNANCE_ROLE) { isValidSwitchboard[switchboardId_] = SwitchboardStatus.REGISTERED; emit SwitchboardEnabled(switchboardId_); } + // @notice function to set the socket fee manager + // @dev only callable by governance role + // @param socketFeeManager_ address of the socket fee manager function setSocketFeeManager(address socketFeeManager_) external onlyRole(GOVERNANCE_ROLE) { - emit SocketFeeManagerUpdated(address(socketFeeManager), socketFeeManager_); socketFeeManager = ISocketFeeManager(socketFeeManager_); + emit SocketFeeManagerUpdated(address(socketFeeManager), socketFeeManager_); } - /** - * @notice connects Plug to Socket and sets the config for given `siblingChainSlug_` - */ + // @notice function to connect a plug to socket + // @dev only callable by plug + // @param appGatewayId_ app gateway id + // @param switchboardId_ switchboard id function connect(bytes32 appGatewayId_, uint64 switchboardId_) external override { if (isValidSwitchboard[switchboardId_] != SwitchboardStatus.REGISTERED) revert InvalidSwitchboard(); @@ -105,9 +109,8 @@ abstract contract SocketConfig is ISocket, AccessControl { emit PlugConnected(msg.sender, appGatewayId_, switchboardId_); } - /** - * @notice disconnects Plug from Socket - */ + // @notice function to disconnect a plug from socket + // @dev only callable by plug function disconnect() external override { PlugConfigEvm storage _plugConfig = _plugConfigs[msg.sender]; if (_plugConfig.appGatewayId == bytes32(0)) revert PlugNotConnected(); From 1cd0c98c40b4822324c4df05378bc15ca254d4d6 Mon Sep 17 00:00:00 2001 From: Ameesha Agrawal Date: Thu, 31 Jul 2025 14:05:19 +0530 Subject: [PATCH 2/6] feat: test mocks --- test/mock/MockERC721.sol | 48 +++++++++++ test/mock/MockFastSwitchboard.sol | 30 +++++-- test/mock/MockFeesManager.sol | 44 +++++++++++ test/mock/MockPlug.sol | 127 ++++++++++++++++++++++++++++++ 4 files changed, 241 insertions(+), 8 deletions(-) create mode 100644 test/mock/MockERC721.sol create mode 100644 test/mock/MockFeesManager.sol create mode 100644 test/mock/MockPlug.sol diff --git a/test/mock/MockERC721.sol b/test/mock/MockERC721.sol new file mode 100644 index 00000000..08461c4e --- /dev/null +++ b/test/mock/MockERC721.sol @@ -0,0 +1,48 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.21; + +import {ERC721} from "solady/tokens/ERC721.sol"; + +contract MockERC721 is ERC721 { + string private _name = "MockERC721"; + string private _symbol = "MOCK721"; + + // Simple tokenId tracker for minting + uint256 private _nextTokenId = 1; + + function name() public view override returns (string memory) { + return _name; + } + + function symbol() public view override returns (string memory) { + return _symbol; + } + + /// @notice Mint a new token to `to` with a specific `tokenId` + function mint(address to, uint256 tokenId) public { + _mint(to, tokenId); + } + + /// @notice Mint a new token to `to` with auto-incremented tokenId + function mint(address to) public returns (uint256 tokenId) { + tokenId = _nextTokenId++; + _mint(to, tokenId); + } + + /// @notice Burn a token + function burn(uint256 tokenId) public { + // Only owner or approved can burn + if ( + msg.sender != ownerOf(tokenId) && + !isApprovedForAll(ownerOf(tokenId), msg.sender) && + getApproved(tokenId) != msg.sender + ) { + revert NotOwnerNorApproved(); + } + _burn(tokenId); + } + + function tokenURI(uint256 id) public view virtual override returns (string memory) { + return ""; + } +} diff --git a/test/mock/MockFastSwitchboard.sol b/test/mock/MockFastSwitchboard.sol index 24b31cd3..6ed61324 100644 --- a/test/mock/MockFastSwitchboard.sol +++ b/test/mock/MockFastSwitchboard.sol @@ -11,6 +11,8 @@ contract MockFastSwitchboard is ISwitchboard { // chain slug of deployed chain uint32 public immutable chainSlug; uint64 public switchboardId; + bool public attestCalled; + bool public isPayloadAllowed; /** * @dev Constructor of SwitchboardBase @@ -21,17 +23,26 @@ contract MockFastSwitchboard is ISwitchboard { chainSlug = chainSlug_; socket__ = ISocket(socket_); owner = owner_; + + isPayloadAllowed = true; } - function attest(bytes32, bytes calldata) external {} + function attest(bytes32, bytes calldata) external { + attestCalled = true; + } - function allowPayload(bytes32, bytes32) external pure returns (bool) { + function allowPayload(bytes32, bytes32) external view returns (bool) { // digest has enough attestations - return true; + return isPayloadAllowed; } - function registerSwitchboard() external { + function setIsPayloadAllowed(bool isPayloadAllowed_) external { + isPayloadAllowed = isPayloadAllowed_; + } + + function registerSwitchboard() external returns (uint64) { switchboardId = socket__.registerSwitchboard(); + return switchboardId; } function processTrigger( @@ -39,13 +50,16 @@ contract MockFastSwitchboard is ISwitchboard { bytes32 triggerId_, bytes calldata payload_, bytes calldata overrides_ - ) external payable override {} + ) external payable override { + // Simple implementation that just accepts the trigger + // In a real switchboard, this would process the trigger + } function getTransmitter( address sender_, - bytes32 payloadId_, - bytes calldata transmitterSignature_ - ) external view returns (address) { + bytes32, + bytes calldata + ) external pure returns (address) { return sender_; } } diff --git a/test/mock/MockFeesManager.sol b/test/mock/MockFeesManager.sol new file mode 100644 index 00000000..a63bffea --- /dev/null +++ b/test/mock/MockFeesManager.sol @@ -0,0 +1,44 @@ +// SPDX-License-Identifier: GPL-3.0-only +pragma solidity ^0.8.21; + +import "../../contracts/protocol/interfaces/ISocketFeeManager.sol"; +import "../../contracts/utils/common/Structs.sol"; + +/** + * @title MockFeeManager + * @dev Mock fee manager for testing + */ +contract MockFeeManager is ISocketFeeManager { + bool public payAndCheckFeesCalled = false; + ExecuteParams public lastExecuteParams; + TransmissionParams public lastTransmissionParams; + + function payAndCheckFees( + ExecuteParams memory executeParams_, + TransmissionParams memory transmissionParams_ + ) external payable override { + payAndCheckFeesCalled = true; + lastExecuteParams = executeParams_; + lastTransmissionParams = transmissionParams_; + } + + function getMinSocketFees() external pure override returns (uint256) { + return 0.001 ether; + } + + function setSocketFees(uint256) external override { + // Mock implementation - do nothing + } + + function socketFees() external pure override returns (uint256) { + return 0.001 ether; + } + + function getLastTransmissionParams() external view returns (TransmissionParams memory) { + return lastTransmissionParams; + } + + function reset() external { + payAndCheckFeesCalled = false; + } +} diff --git a/test/mock/MockPlug.sol b/test/mock/MockPlug.sol new file mode 100644 index 00000000..13b9dc9f --- /dev/null +++ b/test/mock/MockPlug.sol @@ -0,0 +1,127 @@ +// SPDX-License-Identifier: GPL-3.0-only +pragma solidity ^0.8.21; + +import "../../contracts/protocol/interfaces/IPlug.sol"; +import "../../contracts/protocol/interfaces/ISocket.sol"; + +/** + * @title MockPlug + * @dev Mock plug for testing + */ +contract MockPlug is IPlug { + ISocket public socket__; + bytes32 public appGatewayId; + uint64 public switchboardId; + bytes public overridesData = hex"1234"; + bool public initSocketCalled = false; + + error CallFailed(); + + function initSocket( + bytes32 appGatewayId_, + address socket_, + uint64 switchboardId_ + ) external override { + appGatewayId = appGatewayId_; + socket__ = ISocket(socket_); + initSocketCalled = true; + + socket__.connect(appGatewayId_, switchboardId_); + } + + function disconnect() external { + socket__.disconnect(); + } + + function overrides() external view override returns (bytes memory) { + return overridesData; + } + + function setOverrides(bytes calldata newOverrides) external { + overridesData = newOverrides; + } + + function setSocket(address newSocket) external { + socket__ = ISocket(newSocket); + } + + // Test helper function to trigger socket calls + function triggerSocket(bytes calldata data) external payable returns (bytes memory) { + (bool success, bytes memory result) = address(socket__).call{value: msg.value}(data); + require(success, "Socket call failed"); + return result; + } + + // Test helper functions for simulation testing + function processPayload(bytes calldata payload) external payable returns (bool) { + return payload.length > 0; + } + + function failingFunction() external pure { + revert("Intentional failure"); + } + + function returnLargeData() external pure returns (bytes memory) { + // Return 3KB of data to test maxCopyBytes truncation (maxCopyBytes is 2KB) + return new bytes(3072); + } + + // Function to call mockTarget + function callMockTarget(address target, bytes calldata data) external payable returns (bool) { + (bool success, ) = target.call{value: msg.value}(data); + if (!success) revert CallFailed(); + return success; + } +} + +/** + * @title MockTarget + * @dev Mock target contract for execution testing + */ +contract MockTarget { + uint256 public counter = 0; + bytes public lastCallData; + uint256 public lastValue; + bool public shouldRevert = false; + + event Called(address caller, uint256 value, bytes data); + + function increment() external payable { + if (shouldRevert) { + revert(); + } + counter++; + lastCallData = msg.data; + lastValue = msg.value; + emit Called(msg.sender, msg.value, msg.data); + } + + function setShouldRevert(bool _shouldRevert) external { + shouldRevert = _shouldRevert; + } + + function reset() external { + counter = 0; + lastCallData = ""; + lastValue = 0; + shouldRevert = false; + } + + // Fallback to handle any call + fallback() external payable { + if (shouldRevert) { + revert("Mock revert"); + } + lastCallData = msg.data; + lastValue = msg.value; + emit Called(msg.sender, msg.value, msg.data); + } + + receive() external payable { + if (shouldRevert) { + revert("Mock revert"); + } + lastValue = msg.value; + emit Called(msg.sender, msg.value, ""); + } +} From c13b51d442f0619cc9f4786bc29161e20b672a2d Mon Sep 17 00:00:00 2001 From: Ameesha Agrawal Date: Thu, 31 Jul 2025 14:05:35 +0530 Subject: [PATCH 3/6] test: unit socket tests --- contracts/protocol/Socket.sol | 1 - contracts/protocol/SocketUtils.sol | 15 - test/Socket.t.sol | 952 +++++++++++++++++++++++++++++ 3 files changed, 952 insertions(+), 16 deletions(-) create mode 100644 test/Socket.t.sol diff --git a/contracts/protocol/Socket.sol b/contracts/protocol/Socket.sol index d4c6e944..f08e44cc 100644 --- a/contracts/protocol/Socket.sol +++ b/contracts/protocol/Socket.sol @@ -229,7 +229,6 @@ contract Socket is SocketUtils { /// @notice Receive function that forwards all calls to Socket's callAppGateway receive() external payable { - // todo: need a fn to increase trigger fees revert("Socket does not accept ETH"); } } diff --git a/contracts/protocol/SocketUtils.sol b/contracts/protocol/SocketUtils.sol index 3e23ede4..ac652670 100644 --- a/contracts/protocol/SocketUtils.sol +++ b/contracts/protocol/SocketUtils.sol @@ -91,21 +91,6 @@ abstract contract SocketUtils is SocketConfig { ); } - /** - * @notice recovers the signer from the signature - * @param digest_ The digest of the payload - * @param signature_ The signature of the payload - * @return signer The address of the signer - */ - function _recoverSigner( - bytes32 digest_, - bytes calldata signature_ - ) internal view returns (address signer) { - bytes32 digest = keccak256(abi.encodePacked("\x19Ethereum Signed Message:\n32", digest_)); - // recovered signer is checked for the valid roles later - signer = ECDSA.recover(digest, signature_); - } - /** * @notice Encodes the trigger ID with the chain slug, socket address and nonce * @return The trigger ID diff --git a/test/Socket.t.sol b/test/Socket.t.sol new file mode 100644 index 00000000..190cdcb3 --- /dev/null +++ b/test/Socket.t.sol @@ -0,0 +1,952 @@ +// SPDX-License-Identifier: GPL-3.0-only +pragma solidity ^0.8.21; + +import {Test} from "forge-std/Test.sol"; +import {console} from "forge-std/console.sol"; + +import {MockFastSwitchboard} from "./mock/MockFastSwitchboard.sol"; +import {MockPlug, MockTarget} from "./mock/MockPlug.sol"; +import {MockFeeManager} from "./mock/MockFeesManager.sol"; +import {SuperToken} from "./apps/app-gateways/super-token/SuperToken.sol"; +import {MockERC721} from "./mock/MockERC721.sol"; + +import "../contracts/protocol/Socket.sol"; +import "../contracts/protocol/SocketUtils.sol"; +import "../contracts/utils/common/Errors.sol"; +import "../contracts/utils/common/Constants.sol"; +import "../contracts/utils/common/Structs.sol"; + +/** + * @title SocketTestWrapper + * @dev Wrapper contract to expose internal functions for testing + */ +contract SocketTestWrapper is Socket { + constructor( + uint32 chainSlug_, + address owner_, + string memory version_ + ) Socket(chainSlug_, owner_, version_) {} + + // Expose internal functions for testing + function createDigest( + address transmitter_, + bytes32 payloadId_, + bytes32 appGatewayId_, + ExecuteParams calldata executeParams_ + ) external view returns (bytes32) { + return _createDigest(transmitter_, payloadId_, appGatewayId_, executeParams_); + } + + function encodeTriggerId() external returns (bytes32) { + return _encodeTriggerId(); + } + + function executeInternal( + bytes32 payloadId_, + ExecuteParams calldata executeParams_, + TransmissionParams calldata transmissionParams_ + ) external payable returns (bool, bytes memory) { + return _execute(payloadId_, executeParams_, transmissionParams_); + } +} + +/** + * @title SocketTestBase + * @dev Base contract for Socket protocol unit tests + * Provides common setup, utilities, and mock contracts + */ +contract SocketTestBase is Test { + uint256 c = 1; + string constant VERSION = "1.0.0"; + address public socketOwner = address(uint160(c++)); + address public transmitter = address(uint160(c++)); + address public testUser = address(uint160(c++)); + + uint32 constant TEST_CHAIN_SLUG = 1; + bytes32 constant TEST_APP_GATEWAY_ID = keccak256("TEST_APP_GATEWAY"); + bytes constant TEST_PAYLOAD = hex"1234567890abcdef"; + bytes constant TEST_OVERRIDES = hex"abcdef"; + + // Contracts + Socket public socket; + MockFastSwitchboard public mockSwitchboard; + MockPlug public mockPlug; + MockFeeManager public mockFeeManager; + SuperToken public mockToken; + MockTarget public mockTarget; + SocketTestWrapper public socketWrapper; + + uint64 public switchboardId; + ExecuteParams public executeParams; + TransmissionParams public transmissionParams; + + event ExecutionSuccess(bytes32 payloadId, bool exceededMaxCopy, bytes returnData); + event ExecutionFailed(bytes32 payloadId, bool exceededMaxCopy, bytes returnData); + + function setUp() public virtual { + socket = new Socket(TEST_CHAIN_SLUG, socketOwner, VERSION); + mockSwitchboard = new MockFastSwitchboard(TEST_CHAIN_SLUG, address(socket), socketOwner); + mockPlug = new MockPlug(); + mockFeeManager = new MockFeeManager(); + mockToken = new SuperToken("Test Token", "TEST", 18, testUser, 1000000000000000000); + mockTarget = new MockTarget(); + socketWrapper = new SocketTestWrapper(TEST_CHAIN_SLUG, socketOwner, VERSION); + mockToken.setOwner(socketOwner); + + // Set up initial state + vm.startPrank(socketOwner); + socket.grantRole(GOVERNANCE_ROLE, socketOwner); + socket.grantRole(RESCUE_ROLE, socketOwner); + socket.grantRole(SWITCHBOARD_DISABLER_ROLE, socketOwner); + + socket.setSocketFeeManager(address(mockFeeManager)); + mockToken.setSocket(address(socket)); + vm.stopPrank(); + + switchboardId = mockSwitchboard.registerSwitchboard(); + mockPlug.initSocket(TEST_APP_GATEWAY_ID, address(socket), switchboardId); + + executeParams = _createExecuteParams(); + transmissionParams = _createTransmissionParams(); + + vm.deal(transmitter, 100 ether); + vm.deal(testUser, 100 ether); + } + + function _createExecuteParams() internal view returns (ExecuteParams memory) { + return + ExecuteParams({ + callType: WRITE, + payloadPointer: 1, + deadline: block.timestamp + 1 hours, + gasLimit: 100000, + value: 0, + prevBatchDigestHash: bytes32(0), + target: address(mockPlug), + payload: TEST_PAYLOAD, + extraData: bytes("") + }); + } + + function _createTransmissionParams() internal view returns (TransmissionParams memory) { + return + TransmissionParams({ + socketFees: 0, + refundAddress: testUser, + extraData: bytes(""), + transmitterProof: bytes("") + }); + } + + function createSignature( + bytes32 digest_, + uint256 privateKey_ + ) public pure returns (bytes memory sig) { + bytes32 digest = keccak256(abi.encodePacked("\x19Ethereum Signed Message:\n32", digest_)); + (uint8 sigV, bytes32 sigR, bytes32 sigS) = vm.sign(privateKey_, digest); + sig = new bytes(65); + bytes1 v32 = bytes1(sigV); + assembly { + mstore(add(sig, 96), v32) + mstore(add(sig, 32), sigR) + mstore(add(sig, 64), sigS) + } + } +} + +/** + * @title SocketExecuteTest + * @dev Tests for Socket execute function + */ +contract SocketExecuteTest is SocketTestBase { + function testConstructorWithValidParameters() public view { + assertEq(socket.chainSlug(), TEST_CHAIN_SLUG, "Chain slug should match"); + assertEq(socket.owner(), socketOwner, "Owner should match"); + assertEq(socket.version(), keccak256(bytes(VERSION)), "Version should match"); + assertEq(socket.gasLimitBuffer(), 105, "Gas limit buffer should be 105"); + } + + function testExecuteDeadlinePassed() public { + executeParams.deadline = block.timestamp - 1; // Past deadline + vm.expectRevert(DeadlinePassed.selector); + socket.execute{value: 1 ether}(executeParams, transmissionParams); + } + + function testExecutePlugNotFound() public { + executeParams.target = address(0x999); // Non-existent plug + vm.expectRevert(PlugNotFound.selector); + socket.execute{value: 1 ether}(executeParams, transmissionParams); + } + + function testExecuteInsufficientValue() public { + executeParams.value = 1 ether; + transmissionParams.socketFees = 0.5 ether; + + vm.expectRevert(Socket.InsufficientMsgValue.selector); + socket.execute{value: 0.5 ether}(executeParams, transmissionParams); + } + + function testExecuteInvalidSwitchboardDisabled() public { + hoax(socketOwner); + socket.disableSwitchboard(switchboardId); + + vm.expectRevert(InvalidSwitchboard.selector); + socket.execute{value: 1 ether}(executeParams, transmissionParams); + } + + function testExecuteWithInvalidCallType() public { + executeParams.callType = bytes4(0x12345678); // Invalid call type + + vm.expectRevert(InvalidCallType.selector); + socket.execute{value: 1 ether}(executeParams, transmissionParams); + } + + function testExecuteRefundIfExecutionFails() public { + executeParams.target = address(mockPlug); + executeParams.value = 0.5 ether; + executeParams.payload = abi.encodeWithSelector( + mockPlug.callMockTarget.selector, + address(mockTarget), + abi.encodeWithSelector(mockTarget.increment.selector) + ); + + // Set up mock target to revert + mockTarget.setShouldRevert(true); + + uint256 userBalance = testUser.balance; + transmissionParams.refundAddress = testUser; + + hoax(transmitter); + (bool success, ) = socket.execute{value: 1 ether}(executeParams, transmissionParams); + assertFalse(success, "Execution should fail"); + + // Check that refund was sent + assertEq(testUser.balance, userBalance + 1 ether, "Refund should be sent to user"); + + // Set up mock target to revert + mockTarget.setShouldRevert(false); + userBalance = testUser.balance; + + hoax(transmitter); + (success, ) = socket.execute{value: 1 ether}(executeParams, transmissionParams); + assertTrue(success, "Execution should succeed"); + + // Check that refund was sent + assertEq(testUser.balance, userBalance, "Refund should not be sent to user"); + } + + function testExecuteWithValidParameters() public { + // Use mockPlug as target since it's connected to socket + executeParams.target = address(mockPlug); + executeParams.payload = abi.encodeWithSelector( + mockPlug.processPayload.selector, + TEST_PAYLOAD + ); + + bytes32 payloadId = createPayloadId( + executeParams.payloadPointer, + switchboardId, + TEST_CHAIN_SLUG + ); + + vm.expectEmit(true, true, true, true, address(socket)); + emit ExecutionSuccess(payloadId, false, bytes(abi.encode(true))); + (bool success, ) = socket.execute{value: 1 ether}(executeParams, transmissionParams); + assertTrue(success, "Execution should succeed"); + } + + function testExecuteWithPayloadAlreadyExecuted() public { + executeParams.payload = abi.encodeWithSelector( + mockPlug.processPayload.selector, + TEST_PAYLOAD + ); + + bytes32 payloadId = createPayloadId( + executeParams.payloadPointer, + switchboardId, + TEST_CHAIN_SLUG + ); + + vm.expectEmit(true, true, true, true, address(socket)); + emit ExecutionSuccess(payloadId, false, bytes(abi.encode(true))); + socket.execute{value: 1 ether}(executeParams, transmissionParams); + + vm.expectRevert( + abi.encodeWithSelector(Socket.PayloadAlreadyExecuted.selector, ExecutionStatus.Executed) + ); + socket.execute{value: 1 ether}(executeParams, transmissionParams); + } + + function testExecuteWithVerificationFailed() public { + // Override the allowPayload function to return false + mockSwitchboard.setIsPayloadAllowed(false); + + vm.expectRevert(Socket.VerificationFailed.selector); + socket.execute{value: 1 ether}(executeParams, transmissionParams); + } + + function testExecuteWithFailedExecution() public { + executeParams.target = address(mockPlug); + executeParams.payload = abi.encodeWithSelector( + mockPlug.callMockTarget.selector, + address(mockTarget), + abi.encodeWithSelector(mockTarget.increment.selector) + ); + + // Set up mock target to revert + mockTarget.setShouldRevert(true); + + bytes32 payloadId = createPayloadId( + executeParams.payloadPointer, + switchboardId, + TEST_CHAIN_SLUG + ); + + vm.expectEmit(true, true, true, true, address(socket)); + emit ExecutionFailed( + payloadId, + false, + abi.encodeWithSelector(MockPlug.CallFailed.selector) + ); + (bool success, ) = socket.execute{value: 1 ether}(executeParams, transmissionParams); + assertFalse(success, "Execution should fail"); + assertEq(mockTarget.counter(), 0, "Target should not be called"); + } + + function testExecuteWithExceededMaxCopyBytes() public { + // Set up mock target to return large data + executeParams.target = address(mockPlug); + executeParams.payload = abi.encodeWithSelector(mockPlug.returnLargeData.selector); + + (bool success, bytes memory returnData) = socket.execute{value: 1 ether}( + executeParams, + transmissionParams + ); + + // The return data should be truncated to maxCopyBytes (2048 bytes) + assertEq(returnData.length, 2048, "Return data should be exactly maxCopyBytes"); + assertLt(returnData.length, 3072, "Return data should be truncated"); + assertTrue(success, "Execution should succeed even with large return data"); + } + + function testExecutionRetryIfFailing() public { + executeParams.target = address(mockPlug); + executeParams.payload = abi.encodeWithSelector( + mockPlug.callMockTarget.selector, + address(mockTarget), + abi.encodeWithSelector(mockTarget.increment.selector) + ); + + bytes32 payloadId = createPayloadId( + executeParams.payloadPointer, + switchboardId, + TEST_CHAIN_SLUG + ); + + mockTarget.setShouldRevert(true); + vm.expectEmit(true, true, true, true, address(socket)); + emit ExecutionFailed( + payloadId, + false, + abi.encodeWithSelector(MockPlug.CallFailed.selector) + ); + (bool success, ) = socket.execute{value: 1 ether}(executeParams, transmissionParams); + assertFalse(success, "First execution should fail"); + + mockTarget.setShouldRevert(false); + vm.expectEmit(true, true, true, true, address(socket)); + emit ExecutionSuccess(payloadId, false, bytes(abi.encode(true))); + (success, ) = socket.execute{value: 1 ether}(executeParams, transmissionParams); + assertTrue(success, "Second execution should succeed"); + } + + function testGasUsageForExecute() public { + executeParams.target = address(mockPlug); + executeParams.payload = abi.encodeWithSelector( + mockPlug.callMockTarget.selector, + address(mockTarget), + abi.encodeWithSelector(mockTarget.increment.selector) + ); + + uint256 gasBefore = gasleft(); + socket.execute{value: 1 ether}(executeParams, transmissionParams); + uint256 gasUsed = gasBefore - gasleft(); + console.log("gasUsed", gasUsed); + } + + function testExecuteLowGasLimit() public { + // set high gas limit + executeParams.gasLimit = 10000000; + + vm.expectRevert(Socket.LowGasLimit.selector); + socket.execute{value: 1 ether, gas: 100000}(executeParams, transmissionParams); + } +} + +/** + * @title SocketTriggerTest + * @dev Tests for Socket triggerAppGateway function + */ +contract SocketTriggerTest is SocketTestBase { + function testTriggerAppGatewayPlugNotFound() public { + bytes memory triggerData = abi.encodeWithSelector( + mockPlug.processPayload.selector, + TEST_PAYLOAD + ); + + hoax(testUser); + vm.expectRevert(PlugNotFound.selector); + socket.triggerAppGateway{value: 1 ether}(triggerData); + } + + function testTriggerAppGatewayWithInvalidSwitchboard() public { + // Give mockPlug some funds + vm.deal(address(mockPlug), 10 ether); + + hoax(socketOwner); + socket.disableSwitchboard(switchboardId); + + bytes memory triggerData = abi.encodeWithSelector( + mockPlug.processPayload.selector, + TEST_PAYLOAD + ); + + vm.expectRevert(InvalidSwitchboard.selector); + hoax(address(mockPlug)); + socket.triggerAppGateway{value: 1 ether}(triggerData); + } + + function testGasUsageForTriggerAppGateway() public { + // Give mockPlug some funds + vm.deal(address(mockPlug), 10 ether); + + bytes memory triggerData = abi.encodeWithSelector( + mockPlug.processPayload.selector, + TEST_PAYLOAD + ); + + uint256 gasBefore = gasleft(); + hoax(address(mockPlug)); + socket.triggerAppGateway{value: 1 ether}(triggerData); + uint256 gasUsed = gasBefore - gasleft(); + + // Gas usage should be reasonable + console.log("gasUsed", gasUsed); + } +} + +/** + * @title SocketConnectDisconnectTest + * @dev Tests for Socket connect and disconnect functions + */ +contract SocketConnectDisconnectTest is SocketTestBase { + function testConnectWithInvalidSwitchboard() public { + vm.expectRevert(InvalidSwitchboard.selector); + mockPlug.initSocket(TEST_APP_GATEWAY_ID, address(socket), switchboardId + 1); + } + + function testConnectWithNewSwitchboard() public { + // Create a new switchboard + MockFastSwitchboard newSwitchboard = new MockFastSwitchboard( + TEST_CHAIN_SLUG, + address(socket), + socketOwner + ); + uint64 newSwitchboardId = newSwitchboard.registerSwitchboard(); + mockPlug.initSocket(TEST_APP_GATEWAY_ID, address(socket), newSwitchboardId); + + (bytes32 appGatewayId, uint64 switchboardId) = socket.getPlugConfig(address(mockPlug)); + assertEq(appGatewayId, TEST_APP_GATEWAY_ID, "App gateway ID should match"); + assertEq(switchboardId, newSwitchboardId, "Switchboard ID should match"); + } + + // Try to disconnect a plug that was never connected + function testDisconnectWithPlugNotConnected() public { + MockPlug newPlug = new MockPlug(); + newPlug.setSocket(address(socket)); + + vm.expectRevert(SocketConfig.PlugNotConnected.selector); + newPlug.disconnect(); + } +} + +/** + * @title SocketSwitchboardManagementTest + * @dev Tests for Socket switchboard management functions + */ +contract SocketSwitchboardManagementTest is SocketTestBase { + function testRegisterSwitchboardWithExistingSwitchboard() public { + assertEq(mockSwitchboard.switchboardId(), socket.switchboardIds(address(mockSwitchboard))); + + // Try to register the same switchboard again + vm.expectRevert(SocketConfig.SwitchboardExists.selector); + mockSwitchboard.registerSwitchboard(); + } + + function testDisableSwitchboardWithGovernanceRole() public { + hoax(socketOwner); + socket.disableSwitchboard(switchboardId); + + // Try to register the same switchboard again + vm.expectRevert(SocketConfig.SwitchboardExists.selector); + mockSwitchboard.registerSwitchboard(); + + assertEq( + uint256(socket.isValidSwitchboard(switchboardId)), + uint256(SwitchboardStatus.DISABLED), + "Switchboard should be disabled" + ); + } + + function testEnableSwitchboardWithGovernanceRole() public { + // First disable the switchboard + hoax(socketOwner); + socket.disableSwitchboard(switchboardId); + assertEq( + uint256(socket.isValidSwitchboard(switchboardId)), + uint256(SwitchboardStatus.DISABLED), + "Switchboard should be disabled" + ); + + // Then enable it + hoax(socketOwner); + socket.enableSwitchboard(switchboardId); + assertEq( + uint256(socket.isValidSwitchboard(switchboardId)), + uint256(SwitchboardStatus.REGISTERED), + "Switchboard should be enabled" + ); + } + + function testSwitchboardStatusValidation() public { + // Test initial status + assertEq( + uint256(socket.isValidSwitchboard(switchboardId)), + uint256(SwitchboardStatus.REGISTERED), + "Switchboard should be registered initially" + ); + + // Test disabled status + hoax(socketOwner); + socket.disableSwitchboard(switchboardId); + assertEq( + uint256(socket.isValidSwitchboard(switchboardId)), + uint256(SwitchboardStatus.DISABLED), + "Switchboard should be disabled" + ); + + // Test enabled status + hoax(socketOwner); + socket.enableSwitchboard(switchboardId); + + assertEq( + uint256(socket.isValidSwitchboard(switchboardId)), + uint256(SwitchboardStatus.REGISTERED), + "Switchboard should be enabled" + ); + } +} + +/** + * @title SocketSetterTest + * @dev Tests for Socket setter functions + */ +contract SocketSetterTest is SocketTestBase { + function testSetGasLimitBuffer() public { + uint256 newBuffer = 110; + + hoax(testUser); + vm.expectRevert(abi.encodeWithSelector(AccessControl.NoPermit.selector, GOVERNANCE_ROLE)); + socket.setGasLimitBuffer(newBuffer); + + hoax(socketOwner); + socket.setGasLimitBuffer(newBuffer); + + assertEq(socket.gasLimitBuffer(), newBuffer, "Gas limit buffer should be updated"); + } + + function testSetMaxCopyBytes() public { + uint16 newMaxCopyBytes = 4096; + + hoax(testUser); + vm.expectRevert(abi.encodeWithSelector(AccessControl.NoPermit.selector, GOVERNANCE_ROLE)); + socket.setMaxCopyBytes(newMaxCopyBytes); + + hoax(socketOwner); + socket.setMaxCopyBytes(newMaxCopyBytes); + assertEq(socket.maxCopyBytes(), newMaxCopyBytes, "Max copy bytes should be updated"); + } + + function testSetSocketFeeManager() public { + address newFeeManager = address(0x123); + + hoax(testUser); + vm.expectRevert(abi.encodeWithSelector(AccessControl.NoPermit.selector, GOVERNANCE_ROLE)); + socket.setSocketFeeManager(newFeeManager); + + hoax(socketOwner); + socket.setSocketFeeManager(newFeeManager); + assertEq( + address(socket.socketFeeManager()), + newFeeManager, + "Socket fee manager should be updated" + ); + } + + function testConfigurationValidation() public { + // Test initial configuration + assertEq(socket.gasLimitBuffer(), 105, "Initial gas limit buffer should be 105"); + assertEq(socket.maxCopyBytes(), 2048, "Initial max copy bytes should be 2048"); + assertEq( + address(socket.socketFeeManager()), + address(mockFeeManager), + "Initial fee manager should be mockFeeManager" + ); + + MockFeeManager newFeeManager = new MockFeeManager(); + // Test configuration updates + vm.startPrank(socketOwner); + socket.setGasLimitBuffer(120); + socket.setMaxCopyBytes(4096); + socket.setSocketFeeManager(address(newFeeManager)); + vm.stopPrank(); + + assertEq(socket.gasLimitBuffer(), 120, "Gas limit buffer should be updated"); + assertEq(socket.maxCopyBytes(), 4096, "Max copy bytes should be updated"); + assertEq( + address(socket.socketFeeManager()), + address(newFeeManager), + "Fee manager should be updated" + ); + } +} + +/** + * @title SocketDigestTest + * @dev Tests for digest creation functionality + */ +contract SocketDigestTest is SocketTestBase { + function testCreateDigestWithBasicPayload() public view { + ExecuteParams memory params = ExecuteParams({ + target: address(mockTarget), + value: 1 ether, + gasLimit: 100000, + deadline: block.timestamp + 3600, + callType: WRITE, + payload: TEST_PAYLOAD, + payloadPointer: uint160(1), + prevBatchDigestHash: bytes32(uint256(0)), + extraData: bytes("") + }); + + bytes32 digest = socketWrapper.createDigest( + transmitter, + bytes32(uint256(0x123)), + TEST_APP_GATEWAY_ID, + params + ); + + assertTrue(digest != bytes32(0), "Digest should not be zero"); + } + + function testCreateDigestWithLargePayload() public view { + bytes memory largePayload = new bytes(1000); + for (uint256 i = 0; i < 1000; i++) { + largePayload[i] = bytes1(uint8(i % 256)); + } + + ExecuteParams memory params = ExecuteParams({ + target: address(mockTarget), + value: 2 ether, + gasLimit: 200000, + deadline: block.timestamp + 7200, + callType: WRITE, + payload: largePayload, + payloadPointer: uint160(1), + prevBatchDigestHash: bytes32(uint256(0)), + extraData: bytes("") + }); + + bytes32 digest = socketWrapper.createDigest( + transmitter, + bytes32(uint256(0xdef)), + TEST_APP_GATEWAY_ID, + params + ); + + assertTrue(digest != bytes32(0), "Digest should not be zero"); + } + + function testCreateDigestWithDifferentTransmitters() public view { + ExecuteParams memory params = ExecuteParams({ + target: address(mockTarget), + value: 1 ether, + gasLimit: 100000, + deadline: block.timestamp + 3600, + callType: WRITE, + payload: TEST_PAYLOAD, + payloadPointer: uint160(1), + prevBatchDigestHash: bytes32(0), + extraData: bytes("") + }); + + bytes32 digest1 = socketWrapper.createDigest( + transmitter, + bytes32(uint256(0x123)), + TEST_APP_GATEWAY_ID, + params + ); + + bytes32 digest2 = socketWrapper.createDigest( + address(uint160(0x456)), + bytes32(uint256(0x123)), + TEST_APP_GATEWAY_ID, + params + ); + + assertTrue(digest1 != digest2, "Digests should be different for different transmitters"); + } + + function testCreateDigestWithDifferentPayloads() public view { + ExecuteParams memory params1 = ExecuteParams({ + target: address(mockTarget), + value: 1 ether, + gasLimit: 100000, + deadline: block.timestamp + 3600, + callType: WRITE, + payload: TEST_PAYLOAD, + payloadPointer: uint160(1), + prevBatchDigestHash: bytes32(0), + extraData: bytes("") + }); + + ExecuteParams memory params2 = ExecuteParams({ + target: address(mockTarget), + value: 1 ether, + gasLimit: 100000, + deadline: block.timestamp + 3600, + callType: WRITE, + payload: hex"abcdef", + payloadPointer: uint160(1), + prevBatchDigestHash: bytes32(0), + extraData: bytes("") + }); + + bytes32 digest1 = socketWrapper.createDigest( + transmitter, + bytes32(uint256(0x123)), + TEST_APP_GATEWAY_ID, + params1 + ); + + bytes32 digest2 = socketWrapper.createDigest( + transmitter, + bytes32(uint256(0x123)), + TEST_APP_GATEWAY_ID, + params2 + ); + + assertTrue(digest1 != digest2, "Digests should be different for different payloads"); + } +} + +/** + * @title SocketTriggerIdTest + * @dev Tests for trigger ID encoding and uniqueness + */ +contract SocketTriggerIdTest is SocketTestBase { + function testEncodeTriggerIdIncrementsCounter() public { + bytes32 triggerId1 = socketWrapper.encodeTriggerId(); + bytes32 triggerId2 = socketWrapper.encodeTriggerId(); + bytes32 triggerId3 = socketWrapper.encodeTriggerId(); + + assertTrue(triggerId1 != triggerId2, "Trigger IDs should be different"); + assertTrue(triggerId2 != triggerId3, "Trigger IDs should be different"); + assertTrue(triggerId1 != triggerId3, "Trigger IDs should be different"); + } + + function testTriggerIdUniqueness() public { + bytes32[] memory triggerIds = new bytes32[](100); + + for (uint256 i = 0; i < 100; i++) { + triggerIds[i] = socketWrapper.encodeTriggerId(); + } + + // Check for uniqueness + for (uint256 i = 0; i < 100; i++) { + for (uint256 j = i + 1; j < 100; j++) { + assertTrue(triggerIds[i] != triggerIds[j], "Trigger IDs should be unique"); + assertEq( + uint32(uint256(triggerIds[i]) >> 224), + TEST_CHAIN_SLUG, + "Chain slug should match" + ); + assertEq( + address(uint160(uint256(triggerIds[i]) >> 64)), + address(socketWrapper), + "Socket address should match" + ); + assertEq(uint64(uint256(triggerIds[i])), i, "Counter should increment"); + } + } + } + + function testTriggerIdFormatFuzz(uint64 counter) public { + vm.assume(counter < type(uint64).max - 1000); + + // Set the counter to a specific value + uint256 counterSlot = uint256(57); + vm.store(address(socketWrapper), bytes32(counterSlot), bytes32(uint256(counter))); + + bytes32 triggerId = socketWrapper.encodeTriggerId(); + uint32 chainSlugFromId = uint32(uint256(triggerId) >> 224); + address socketAddressFromId = address(uint160(uint256(triggerId) >> 64)); + uint64 counterFromId = uint64(uint256(triggerId)); + + assertEq(chainSlugFromId, TEST_CHAIN_SLUG, "Chain slug should match"); + assertEq(socketAddressFromId, address(socketWrapper), "Socket address should match"); + assertEq(counterFromId, counter, "Counter should match"); + } +} + +/** + * @title SocketSimulationTest + * @dev Tests for simulation functionality + */ +contract SocketSimulationTest is SocketTestBase { + function testSimulationModifier() public { + SocketUtils.SimulateParams[] memory params = new SocketUtils.SimulateParams[](1); + params[0] = SocketUtils.SimulateParams({ + target: address(mockTarget), + value: 1 ether, + gasLimit: 100000, + payload: TEST_PAYLOAD + }); + + // Should revert when called by non-off-chain caller + vm.expectRevert(SocketUtils.OnlyOffChain.selector); + socket.simulate(params); + } +} + +/** + * @title SocketRescueTest + * @dev Tests for rescue functionality + */ +contract SocketRescueTest is SocketTestBase { + function testRescueFunds() public { + // Send some ETH to the socket + vm.deal(address(socket), 10 ether); + uint256 balanceBefore = testUser.balance; + + hoax(socketOwner); + socket.rescueFunds(ETH_ADDRESS, testUser, 5 ether); + + uint256 balanceAfter = testUser.balance; + assertEq(balanceAfter - balanceBefore, 5 ether, "User should receive rescued ETH"); + } + + function testRescueFundsWithNonRescueRole() public { + vm.deal(address(socket), 10 ether); + + hoax(testUser); + vm.expectRevert(abi.encodeWithSelector(AccessControl.NoPermit.selector, RESCUE_ROLE)); + socket.rescueFunds(address(0), testUser, 5 ether); + } + + function testRescueTokenFunds() public { + // Transfer some tokens to the socket + hoax(address(socket)); + mockToken.mint(address(socket), 1000); + uint256 balanceBefore = mockToken.balanceOf(testUser); + + hoax(socketOwner); + socket.rescueFunds(address(mockToken), testUser, 500); + + uint256 balanceAfter = mockToken.balanceOf(testUser); + assertEq(balanceAfter - balanceBefore, 500, "User should receive rescued tokens"); + } + + function testSendingEthToSocket() public { + vm.expectRevert("Socket does not accept ETH"); + address(socket).call{value: 1 ether}(""); + } + + function testRescueNFT() public { + // Deploy a mock ERC721 NFT and mint one to this test contract + MockERC721 mockNFT = new MockERC721(); + uint256 tokenId = 1; + mockNFT.mint(address(socket), tokenId); + assertEq(mockNFT.ownerOf(tokenId), address(socket), "Socket should own the NFT"); + + hoax(socketOwner); + socket.rescueFunds(address(mockNFT), testUser, tokenId); + + // Check that testUser received the NFT + assertEq(mockNFT.ownerOf(tokenId), testUser, "User should receive rescued NFT"); + } +} + +/** + * @title SocketFeeManagerTest + * @dev Tests for fee manager functionality + */ +contract SocketFeeManagerTest is SocketTestBase { + function testFeeCollectedIfExecutionSuccess() public { + // Set up execution parameters with fees + executeParams.gasLimit = 100000; + transmissionParams.socketFees = 0.1 ether; + executeParams.target = address(mockPlug); + executeParams.payload = abi.encodeWithSelector( + mockPlug.processPayload.selector, + TEST_PAYLOAD + ); + + // Execute with fees + socket.execute{value: 1.1 ether}(executeParams, transmissionParams); + + // Check that fees were collected + assertEq(address(mockFeeManager).balance, 0.1 ether, "Fee manager should receive fees"); + } + + function testGasUsage() public { + executeParams.gasLimit = 100000; + transmissionParams.socketFees = 0.1 ether; + + // mockSwitchboard.setTransmitter(transmitter); + + uint256 gasBefore = gasleft(); + + vm.deal(testUser, 2 ether); + hoax(testUser); + socket.execute{value: 1.1 ether}(executeParams, transmissionParams); + + uint256 gasUsed = gasBefore - gasleft(); + console.log("Gas used for execution with fees:", gasUsed); + + assertTrue(gasUsed > 0, "Gas should be used"); + } + + function testFeeManagerNotSet() public { + // Remove fee manager + hoax(socketOwner); + socket.setSocketFeeManager(address(0)); + + executeParams.gasLimit = 100000; + transmissionParams.socketFees = 0.1 ether; + + // mockSwitchboard.setTransmitter(transmitter); + + // Should still execute successfully without fee manager + vm.deal(testUser, 2 ether); + hoax(testUser); + socket.execute{value: 1.1 ether}(executeParams, transmissionParams); + + // No fees should be collected + assertEq( + address(mockFeeManager).balance, + 0, + "No fees should be collected when fee manager is not set" + ); + } +} From 918b71a8821c47f8e2116fc1ff87fc937f64a9d9 Mon Sep 17 00:00:00 2001 From: Ameesha Agrawal Date: Thu, 31 Jul 2025 14:20:52 +0530 Subject: [PATCH 4/6] test: utils --- test/Utils.t.sol | 149 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 149 insertions(+) create mode 100644 test/Utils.t.sol diff --git a/test/Utils.t.sol b/test/Utils.t.sol new file mode 100644 index 00000000..7cba44a8 --- /dev/null +++ b/test/Utils.t.sol @@ -0,0 +1,149 @@ +// SPDX-License-Identifier: GPL-3.0-only +pragma solidity ^0.8.21; + +import "forge-std/Test.sol"; +import "../contracts/utils/common/IdUtils.sol"; +import "../contracts/utils/common/Converters.sol"; + +/** + * @title IdUtilsTest + * @dev Tests for IdUtils utility functions + */ +contract IdUtilsTest is Test { + function testCreatePayloadId() public { + uint160 payloadPointer = 12345; + uint64 switchboardId = 67890; + uint32 chainSlug = 1; + + bytes32 payloadId = createPayloadId(payloadPointer, switchboardId, chainSlug); + + // Verify the structure + uint32 chainSlugFromId = uint32(uint256(payloadId) >> 224); + uint64 switchboardIdFromId = uint64(uint256(payloadId) >> 160); + uint160 payloadPointerFromId = uint160(uint256(payloadId)); + + assertEq(chainSlugFromId, chainSlug, "Chain slug should match"); + assertEq(switchboardIdFromId, switchboardId, "Switchboard ID should match"); + assertEq(payloadPointerFromId, payloadPointer, "Payload pointer should match"); + } + + function testCreatePayloadIdWithZeroValues() public { + bytes32 payloadId = createPayloadId(0, 0, 0); + + assertEq(payloadId, bytes32(0), "Payload ID should be zero for zero inputs"); + } + + function testCreatePayloadIdWithMaxValues() public { + uint160 maxPayloadPointer = type(uint160).max; + uint64 maxSwitchboardId = type(uint64).max; + uint32 maxChainSlug = type(uint32).max; + + bytes32 payloadId = createPayloadId(maxPayloadPointer, maxSwitchboardId, maxChainSlug); + + // Verify the structure + uint32 chainSlugFromId = uint32(uint256(payloadId) >> 224); + uint64 switchboardIdFromId = uint64(uint256(payloadId) >> 160); + uint160 payloadPointerFromId = uint160(uint256(payloadId)); + + assertEq(chainSlugFromId, maxChainSlug, "Chain slug should match"); + assertEq(switchboardIdFromId, maxSwitchboardId, "Switchboard ID should match"); + assertEq(payloadPointerFromId, maxPayloadPointer, "Payload pointer should match"); + } + + function testCreatePayloadIdFuzz( + uint160 payloadPointer, + uint64 switchboardId, + uint32 chainSlug + ) public { + bytes32 payloadId = createPayloadId(payloadPointer, switchboardId, chainSlug); + + // Verify the structure + uint32 chainSlugFromId = uint32(uint256(payloadId) >> 224); + uint64 switchboardIdFromId = uint64(uint256(payloadId) >> 160); + uint160 payloadPointerFromId = uint160(uint256(payloadId)); + + assertEq(chainSlugFromId, chainSlug, "Chain slug should match"); + assertEq(switchboardIdFromId, switchboardId, "Switchboard ID should match"); + assertEq(payloadPointerFromId, payloadPointer, "Payload pointer should match"); + } +} + +/** + * @title ConvertersTest + * @dev Tests for Converters utility functions + */ +contract ConvertersTest is Test { + function testToBytes32Format() public { + address testAddr = address(0x1234567890123456789012345678901234567890); + bytes32 result = toBytes32Format(testAddr); + + assertEq(result, bytes32(uint256(uint160(testAddr))), "Conversion should be correct"); + } + + function testToBytes32FormatWithZeroAddress() public { + bytes32 result = toBytes32Format(address(0)); + + assertEq(result, bytes32(0), "Zero address should convert to zero bytes32"); + } + + function testFromBytes32Format() public { + address originalAddr = address(0x1234567890123456789012345678901234567890); + bytes32 bytes32Format = toBytes32Format(originalAddr); + + address convertedAddr = fromBytes32Format(bytes32Format); + + assertEq(convertedAddr, originalAddr, "Conversion should be reversible"); + } + + function testFromBytes32FormatWithZeroAddress() public { + bytes32 zeroBytes32 = bytes32(0); + address convertedAddr = fromBytes32Format(zeroBytes32); + + assertEq(convertedAddr, address(0), "Zero bytes32 should convert to zero address"); + } + + function testFromBytes32FormatWithInvalidAddress() public { + // Create a bytes32 with non-zero upper bits + bytes32 invalidBytes32 = bytes32(uint256(1) << 160); + + try this.fromBytes32FormatWrapper(invalidBytes32) { + fail(); + } catch { + // Expected to revert + } + } + + function fromBytes32FormatWrapper( + bytes32 bytes32FormatAddress + ) external pure returns (address) { + return fromBytes32Format(bytes32FormatAddress); + } + + function testFromBytes32FormatWithMaxValidAddress() public { + address maxAddr = address(0xFFfFfFffFFfffFFfFFfFFFFFffFFFffffFfFFFfF); + bytes32 bytes32Format = toBytes32Format(maxAddr); + + address convertedAddr = fromBytes32Format(bytes32Format); + + assertEq(convertedAddr, maxAddr, "Max address should convert correctly"); + } + + function testConvertersRoundTrip() public { + address originalAddr = address(0xabCDEF1234567890ABcDEF1234567890aBCDeF12); + + bytes32 bytes32Format = toBytes32Format(originalAddr); + address convertedAddr = fromBytes32Format(bytes32Format); + + assertEq(convertedAddr, originalAddr, "Round trip conversion should work"); + } + + function testConvertersFuzz(address addr) public { + // Skip addresses that would cause overflow + vm.assume(uint256(uint160(addr)) <= type(uint160).max); + + bytes32 bytes32Format = toBytes32Format(addr); + address convertedAddr = fromBytes32Format(bytes32Format); + + assertEq(convertedAddr, addr, "Fuzz test should pass"); + } +} From 8c00261fc8a2dcabe117575fa8c4f8ae96c96f73 Mon Sep 17 00:00:00 2001 From: Ameesha Agrawal Date: Thu, 31 Jul 2025 17:21:48 +0530 Subject: [PATCH 5/6] fix: rescue nft --- contracts/utils/RescueFundsLib.sol | 25 +++++++++- test/SetupTest.t.sol | 42 +++++++++++----- test/apps/ParallelCounter.t.sol | 78 ++++++++++++++++++------------ 3 files changed, 100 insertions(+), 45 deletions(-) diff --git a/contracts/utils/RescueFundsLib.sol b/contracts/utils/RescueFundsLib.sol index 738ec796..d2921df9 100644 --- a/contracts/utils/RescueFundsLib.sol +++ b/contracts/utils/RescueFundsLib.sol @@ -5,6 +5,19 @@ import "solady/utils/SafeTransferLib.sol"; import {ZeroAddress, InvalidTokenAddress} from "./common/Errors.sol"; import {ETH_ADDRESS} from "./common/Constants.sol"; +interface IERC721 { + function safeTransferFrom( + address from, + address to, + uint256 tokenId, + bytes calldata data + ) external; +} + +interface IERC165 { + function supportsInterface(bytes4 interfaceId) external view returns (bool); +} + /** * @title RescueFundsLib * @dev A library that provides a function to rescue funds from a contract. @@ -23,7 +36,17 @@ library RescueFundsLib { SafeTransferLib.forceSafeTransferETH(rescueTo_, amount_); } else { if (token_.code.length == 0) revert InvalidTokenAddress(); - SafeTransferLib.safeTransfer(token_, rescueTo_, amount_); + + // Identify if a token is an NFT (ERC721) by checking if it supports the ERC721 interface + try IERC165(token_).supportsInterface(0x80ac58cd) returns (bool isERC721) { + if (isERC721) + IERC721(token_).safeTransferFrom(address(this), rescueTo_, amount_, ""); + else { + SafeTransferLib.safeTransfer(token_, rescueTo_, amount_); + } + } catch { + SafeTransferLib.safeTransfer(token_, rescueTo_, amount_); + } } } } diff --git a/test/SetupTest.t.sol b/test/SetupTest.t.sol index 303935a2..434f6f46 100644 --- a/test/SetupTest.t.sol +++ b/test/SetupTest.t.sol @@ -1130,21 +1130,45 @@ contract WatcherSetup is AuctionSetup { ) internal { // Count valid plugs first. In some cases we might have contractIds such that oly a subset is // deployed on a chain. for ex, vault on source, and supertoken on destination. - uint256 validPlugCount = 0; + uint256 validPlugCount = _countValidPlugs(appGateway_, contractIds_, chainSlug_); + + // Create array with exact size needed + AppGatewayConfig[] memory configs = new AppGatewayConfig[](validPlugCount); + _populateConfigs(configs, appGateway_, contractIds_, chainSlug_); + + // Only call watcher if we have valid configs + if (validPlugCount > 0) { + watcherMultiCall( + address(configurations), + abi.encodeWithSelector(Configurations.setAppGatewayConfigs.selector, configs) + ); + } + } + + function _countValidPlugs( + IAppGateway appGateway_, + bytes32[] memory contractIds_, + uint32 chainSlug_ + ) internal view returns (uint256 validCount) { for (uint i = 0; i < contractIds_.length; i++) { bytes32 plug = appGateway_.getOnChainAddress(contractIds_[i], chainSlug_); if (plug != bytes32(0)) { - validPlugCount++; + validCount++; } } + } - // Create array with exact size needed - AppGatewayConfig[] memory configs = new AppGatewayConfig[](validPlugCount); + function _populateConfigs( + AppGatewayConfig[] memory configs, + IAppGateway appGateway_, + bytes32[] memory contractIds_, + uint32 chainSlug_ + ) internal view { uint256 configIndex = 0; + uint64 switchboardId = configurations.switchboards(chainSlug_, appGateway_.sbType()); for (uint i = 0; i < contractIds_.length; i++) { bytes32 plug = appGateway_.getOnChainAddress(contractIds_[i], chainSlug_); - uint64 switchboardId = configurations.switchboards(chainSlug_, appGateway_.sbType()); if (plug != bytes32(0)) { configs[configIndex] = AppGatewayConfig({ plug: plug, @@ -1157,14 +1181,6 @@ contract WatcherSetup is AuctionSetup { configIndex++; } } - - // Only call watcher if we have valid configs - if (validPlugCount > 0) { - watcherMultiCall( - address(configurations), - abi.encodeWithSelector(Configurations.setAppGatewayConfigs.selector, configs) - ); - } } } diff --git a/test/apps/ParallelCounter.t.sol b/test/apps/ParallelCounter.t.sol index 51763bb3..b71587fd 100644 --- a/test/apps/ParallelCounter.t.sol +++ b/test/apps/ParallelCounter.t.sol @@ -37,6 +37,11 @@ contract ParallelCounterTest is AppGatewayBaseSetup { chainSlugs[1] = optChainSlug; deployCounterApp(chainSlugs); + // Break down into smaller functions to avoid stack too deep + _testForwarderAddresses(); + } + + function _testForwarderAddresses() internal view { (bytes32 onChainArb1, address forwarderArb1) = getOnChainAndForwarderAddresses( arbChainSlug, counterId1, @@ -61,44 +66,55 @@ contract ParallelCounterTest is AppGatewayBaseSetup { parallelCounterGateway ); - assertEq( - IForwarder(forwarderArb1).getChainSlug(), - arbChainSlug, - "Forwarder chainSlug should be correct" - ); - assertEq( - IForwarder(forwarderArb1).getOnChainAddress(), + _assertForwarderAddresses( + forwarderArb1, onChainArb1, - "Forwarder onChainAddress should be correct" - ); - assertEq( - IForwarder(forwarderOpt1).getChainSlug(), - optChainSlug, - "Forwarder chainSlug should be correct" - ); - assertEq( - IForwarder(forwarderOpt1).getOnChainAddress(), - onChainOpt1, - "Forwarder onChainAddress should be correct" - ); - assertEq( - IForwarder(forwarderArb2).getChainSlug(), arbChainSlug, - "Forwarder chainSlug should be correct" - ); - assertEq( - IForwarder(forwarderArb2).getOnChainAddress(), + forwarderOpt1, + onChainOpt1, + optChainSlug, + forwarderArb2, onChainArb2, - "Forwarder onChainAddress should be correct" + arbChainSlug, + forwarderOpt2, + onChainOpt2, + optChainSlug ); + } + + function _assertForwarderAddresses( + address forwarderArb1, + bytes32 onChainArb1, + uint32 arbChainSlug, + address forwarderOpt1, + bytes32 onChainOpt1, + uint32 optChainSlug, + address forwarderArb2, + bytes32 onChainArb2, + uint32 arbChainSlug2, + address forwarderOpt2, + bytes32 onChainOpt2, + uint32 optChainSlug2 + ) internal view { + _assertForwarderAddress(forwarderArb1, onChainArb1, arbChainSlug); + _assertForwarderAddress(forwarderOpt1, onChainOpt1, optChainSlug); + _assertForwarderAddress(forwarderArb2, onChainArb2, arbChainSlug2); + _assertForwarderAddress(forwarderOpt2, onChainOpt2, optChainSlug2); + } + + function _assertForwarderAddress( + address forwarder, + bytes32 onChainAddress, + uint32 chainSlug + ) internal view { assertEq( - IForwarder(forwarderOpt2).getOnChainAddress(), - onChainOpt2, - "Forwarder onChainAddress should be correct" + IForwarder(forwarder).getChainSlug(), + chainSlug, + "Forwarder chainSlug should be correct" ); assertEq( - IForwarder(forwarderOpt2).getOnChainAddress(), - onChainOpt2, + IForwarder(forwarder).getOnChainAddress(), + onChainAddress, "Forwarder onChainAddress should be correct" ); } From c6b8b8506a7de80a4b0e9471b7463a7690c2fc56 Mon Sep 17 00:00:00 2001 From: Ameesha Agrawal Date: Mon, 4 Aug 2025 15:03:37 +0530 Subject: [PATCH 6/6] test: memory to storage --- contracts/protocol/Socket.sol | 7 +++---- test/Socket.t.sol | 2 +- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/contracts/protocol/Socket.sol b/contracts/protocol/Socket.sol index f08e44cc..57e5e05e 100644 --- a/contracts/protocol/Socket.sol +++ b/contracts/protocol/Socket.sol @@ -67,7 +67,7 @@ contract Socket is SocketUtils { // check if the call type is valid if (executeParams_.callType != WRITE) revert InvalidCallType(); - PlugConfigEvm memory plugConfig = _plugConfigs[executeParams_.target]; + PlugConfigEvm storage plugConfig = _plugConfigs[executeParams_.target]; // check if the plug is disconnected if (plugConfig.appGatewayId == bytes32(0)) revert PlugNotFound(); @@ -157,9 +157,8 @@ contract Socket is SocketUtils { } else { payloadExecuted[payloadId_] = ExecutionStatus.Reverted; - address receiver = transmissionParams_.refundAddress == address(0) - ? msg.sender - : transmissionParams_.refundAddress; + address receiver = transmissionParams_.refundAddress; + if (receiver == address(0)) receiver = msg.sender; SafeTransferLib.forceSafeTransferETH(receiver, msg.value); emit ExecutionFailed(payloadId_, exceededMaxCopy, returnData); } diff --git a/test/Socket.t.sol b/test/Socket.t.sol index 190cdcb3..10bcf4e0 100644 --- a/test/Socket.t.sol +++ b/test/Socket.t.sol @@ -869,7 +869,7 @@ contract SocketRescueTest is SocketTestBase { function testSendingEthToSocket() public { vm.expectRevert("Socket does not accept ETH"); - address(socket).call{value: 1 ether}(""); + (bool success, ) = address(socket).call{value: 1 ether}(""); } function testRescueNFT() public {