diff --git a/src/libraries/ModuleStorageLib.sol b/src/libraries/ModuleStorageLib.sol new file mode 100644 index 00000000..160b247e --- /dev/null +++ b/src/libraries/ModuleStorageLib.sol @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: GPL-3.0 +pragma solidity ^0.8.20; + +type StoragePointer is bytes32; + +/// @title Module Storage Library +/// @notice Library for allocating and accessing ERC-4337 address-associated storage within modules. +library ModuleStorageLib { + /// @notice Allocates a memory buffer for an associated storage key, and sets the associated address and batch + /// index. + /// @param addr The address to associate with the storage key. + /// @param batchIndex The batch index to associate with the storage key. + /// @param keySize The size of the key in words, where each word is 32 bytes. Not inclusive of the address and + /// batch index. + /// @return key The allocated memory buffer. + function allocateAssociatedStorageKey(address addr, uint256 batchIndex, uint8 keySize) + internal + pure + returns (bytes memory key) + { + /// @solidity memory-safe-assembly + assembly { + // Clear any dirty upper bits of keySize to prevent overflow + keySize := and(keySize, 0xff) + + // compute the total size of the buffer, include the address and batch index + let totalSize := add(64, mul(32, keySize)) + + // Allocate memory for the key + key := mload(0x40) + mstore(0x40, add(add(key, totalSize), 32)) + mstore(key, totalSize) + + // Clear any dirty upper bits of address + addr := and(addr, 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF) + // Store the address and batch index in the key buffer + mstore(add(key, 32), addr) + mstore(add(key, 64), batchIndex) + } + } + + function associatedStorageLookup(bytes memory key, bytes32 input) internal pure returns (StoragePointer ptr) { + /// @solidity memory-safe-assembly + assembly { + mstore(add(key, 96), input) + ptr := keccak256(add(key, 32), mload(key)) + } + } + + function associatedStorageLookup(bytes memory key, bytes32 input1, bytes32 input2) + internal + pure + returns (StoragePointer ptr) + { + /// @solidity memory-safe-assembly + assembly { + mstore(add(key, 96), input1) + mstore(add(key, 128), input2) + ptr := keccak256(add(key, 32), mload(key)) + } + } +} diff --git a/test/libraries/ModuleStorageLib.t.sol b/test/libraries/ModuleStorageLib.t.sol new file mode 100644 index 00000000..e2992c1e --- /dev/null +++ b/test/libraries/ModuleStorageLib.t.sol @@ -0,0 +1,93 @@ +// SPDX-License-Identifier: GPL-3.0 +pragma solidity ^0.8.20; + +import {Test} from "forge-std/Test.sol"; + +import {ModuleStorageLib, StoragePointer} from "../../src/libraries/ModuleStorageLib.sol"; + +contract ModuleStorageLibTest is Test { + using ModuleStorageLib for bytes; + using ModuleStorageLib for bytes32; + + uint256 public constant FUZZ_ARR_SIZE = 32; + + address public account1; + + struct TestStruct { + uint256 a; + uint256 b; + } + + function setUp() public { + account1 = makeAddr("account1"); + } + + function test_storagePointer() public { + bytes memory key = ModuleStorageLib.allocateAssociatedStorageKey(account1, 0, 1); + + StoragePointer ptr = ModuleStorageLib.associatedStorageLookup( + key, hex"00112233445566778899aabbccddeeff00112233445566778899aabbccddeeff" + ); + TestStruct storage val = _castPtrToStruct(ptr); + + vm.record(); + val.a = 0xdeadbeef; + val.b = 123; + (, bytes32[] memory accountWrites) = vm.accesses(address(this)); + + // printStorageReadsAndWrites(address(this)); + + assertEq(accountWrites.length, 2); + bytes32 expectedKey = keccak256( + abi.encodePacked( + uint256(uint160(account1)), + uint256(0), + hex"00112233445566778899aabbccddeeff00112233445566778899aabbccddeeff" + ) + ); + assertEq(accountWrites[0], expectedKey); + assertEq(vm.load(address(this), expectedKey), bytes32(uint256(0xdeadbeef))); + assertEq(accountWrites[1], bytes32(uint256(expectedKey) + 1)); + assertEq(vm.load(address(this), bytes32(uint256(expectedKey) + 1)), bytes32(uint256(123))); + } + + function testFuzz_storagePointer( + address account, + uint256 batchIndex, + bytes32 inputKey, + uint256[FUZZ_ARR_SIZE] calldata values + ) public { + bytes memory key = ModuleStorageLib.allocateAssociatedStorageKey(account, batchIndex, 1); + uint256[FUZZ_ARR_SIZE] storage val = + _castPtrToArray(ModuleStorageLib.associatedStorageLookup(key, inputKey)); + // Write values to storage + vm.record(); + for (uint256 i = 0; i < FUZZ_ARR_SIZE; i++) { + val[i] = values[i]; + } + // Assert the writes took place in the right location, and the correct value is stored there + (, bytes32[] memory accountWrites) = vm.accesses(address(this)); + assertEq(accountWrites.length, FUZZ_ARR_SIZE); + for (uint256 i = 0; i < FUZZ_ARR_SIZE; i++) { + bytes32 expectedKey = bytes32( + uint256(keccak256(abi.encodePacked(uint256(uint160(account)), uint256(batchIndex), inputKey))) + i + ); + assertEq(accountWrites[i], expectedKey); + assertEq(vm.load(address(this), expectedKey), bytes32(uint256(values[i]))); + } + } + + function _castPtrToArray(StoragePointer ptr) internal pure returns (uint256[FUZZ_ARR_SIZE] storage val) { + /// @solidity memory-safe-assembly + assembly { + val.slot := ptr + } + } + + function _castPtrToStruct(StoragePointer ptr) internal pure returns (TestStruct storage val) { + /// @solidity memory-safe-assembly + assembly { + val.slot := ptr + } + } +}