From 7434da3994997b61853c39e4338e285e8d38a8a2 Mon Sep 17 00:00:00 2001 From: Jay Paik Date: Wed, 10 Jan 2024 03:00:36 -0500 Subject: [PATCH] feat: add PluginStorageLib, AssociatedLinkedListSetLib --- src/libraries/AssociatedLinkedListSetLib.sol | 507 ++++++++++++++++++ src/libraries/PluginStorageLib.sol | 59 ++ .../AssociatedLinkedListSetLib.t.sol | 220 ++++++++ test/libraries/PluginStorageLib.t.sol | 88 +++ 4 files changed, 874 insertions(+) create mode 100644 src/libraries/AssociatedLinkedListSetLib.sol create mode 100644 src/libraries/PluginStorageLib.sol create mode 100644 test/libraries/AssociatedLinkedListSetLib.t.sol create mode 100644 test/libraries/PluginStorageLib.t.sol diff --git a/src/libraries/AssociatedLinkedListSetLib.sol b/src/libraries/AssociatedLinkedListSetLib.sol new file mode 100644 index 00000000..56a34289 --- /dev/null +++ b/src/libraries/AssociatedLinkedListSetLib.sol @@ -0,0 +1,507 @@ +// SPDX-License-Identifier: GPL-3.0 +pragma solidity ^0.8.19; + +type SetValue is bytes30; + +/// @dev The sentinel value is used to indicate the head and tail of the list. +bytes32 constant SENTINEL_VALUE = bytes32(uint256(1)); + +/// @dev Removing the last element will result in this flag not being set correctly, but all operations will +/// function normally, albeit with one extra sload for getAll. +bytes32 constant HAS_NEXT_FLAG = bytes32(uint256(2)); + +/// @dev Type representing the set, which is just a storage slot placeholder like the solidity mapping type. +struct AssociatedLinkedListSet { + bytes32 placeholder; +} + +/// @title Associated Linked List Set Library +/// @notice Provides a set data structure that is enumerable and held in address-associated storage (per the +/// ERC-4337 spec) +library AssociatedLinkedListSetLib { + // Mapping Entry Byte Layout + // | value | 0xAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA____ | + // | meta | 0x____________________________________________________________BBBB | + + // Bit-layout of the meta bytes (2 bytes) + // | user flags | 11111111 11111100 | + // | has next | 00000000 00000010 | + // | sentinel | 00000000 00000001 | + + // Mapping keys exclude the upper 15 bits of the meta bytes, which allows keys to be either a value or the + // sentinel. + + bytes4 internal constant _ASSOCIATED_STORAGE_PREFIX = 0x9cc6c923; // bytes4(keccak256("AssociatedLinkedListSet")) + + // A custom type representing the index of a storage slot + type StoragePointer is bytes32; + + // A custom type representing a pointer to a location in memory beyond the current free memory pointer. + // Holds a fixed-size buffer similar to "bytes memory", but without a length field. + // Care must be taken when using these, as they may be overwritten if ANY memory is allocated after allocating + // a TempBytesMemory. + type TempBytesMemory is bytes32; + + // INTERNAL METHODS + + /// @notice Adds a value to a set. + /// @param set The set to add the value to. + /// @param associated The address the set is associated with. + /// @param value The value to add. + /// @return True if the value was added, false if the value cannot be added (already exists or is zero). + function tryAdd(AssociatedLinkedListSet storage set, address associated, SetValue value) + internal + returns (bool) + { + bytes32 unwrappedKey = bytes32(SetValue.unwrap(value)); + if (unwrappedKey == bytes32(0)) { + // Cannot add the zero value + return false; + } + + TempBytesMemory keyBuffer = _allocateTempKeyBuffer(set, associated); + + StoragePointer valueSlot = _mapLookup(keyBuffer, unwrappedKey); + if (_load(valueSlot) != bytes32(0)) { + // Entry already exists + return false; + } + + // Load the head of the set + StoragePointer sentinelSlot = _mapLookup(keyBuffer, SENTINEL_VALUE); + bytes32 prev = _load(sentinelSlot); + if (prev == bytes32(0) || isSentinel(prev)) { + // set is empty, need to do: + // map[SENTINEL_VALUE] = unwrappedKey; + // map[unwrappedKey] = SENTINEL_VALUE; + _store(sentinelSlot, unwrappedKey); + _store(valueSlot, SENTINEL_VALUE); + } else { + // set is not empty, need to do: + // map[SENTINEL_VALUE] = unwrappedKey | HAS_NEXT_FLAG; + // map[unwrappedKey] = prev; + _store(sentinelSlot, unwrappedKey | HAS_NEXT_FLAG); + _store(valueSlot, prev); + } + + return true; + } + + /// @notice Removes a value from a set. + /// @dev This is an O(n) operation, where n is the number of elements in the set. + /// @param set The set to remove the value from + /// @param associated The address the set is associated with + /// @param value The value to remove + /// @return True if the value was removed, false if the value does not exist + function tryRemove(AssociatedLinkedListSet storage set, address associated, SetValue value) + internal + returns (bool) + { + bytes32 unwrappedKey = bytes32(SetValue.unwrap(value)); + TempBytesMemory keyBuffer = _allocateTempKeyBuffer(set, associated); + + StoragePointer valueSlot = _mapLookup(keyBuffer, unwrappedKey); + bytes32 nextValue = _load(valueSlot); + if (unwrappedKey == bytes32(0) || nextValue == bytes32(0)) { + // Entry does not exist + return false; + } + + bytes32 prevKey = SENTINEL_VALUE; + bytes32 currentVal; + do { + // Load the current entry + StoragePointer prevSlot = _mapLookup(keyBuffer, prevKey); + currentVal = _load(prevSlot); + bytes32 currentKey = clearFlags(currentVal); + if (currentKey == unwrappedKey) { + // Found the entry + // Set the previous value's next value to the next value, + // and the flags to the current value's flags. + // and the next value's `hasNext` flag to determine whether or not the next value is (or points to) + // the sentinel value. + + // Need to do: + // map[prevKey] = clearFlags(nextValue) | getUserFlags(currentVal) | (nextValue & HAS_NEXT_FLAG); + // map[currentKey] = bytes32(0); + + _store(prevSlot, clearFlags(nextValue) | getUserFlags(currentVal) | (nextValue & HAS_NEXT_FLAG)); + _store(valueSlot, bytes32(0)); + + return true; + } + prevKey = currentKey; + } while (!isSentinel(currentVal) && currentVal != bytes32(0)); + return false; + } + + /// @notice Removes a value from a set, given the previous value in the set. + /// @dev This is an O(1) operation but requires additional knowledge. + /// @param set The set to remove the value from + /// @param associated The address the set is associated with + /// @param value The value to remove + /// @param prev The previous value in the set + /// @return True if the value was removed, false if the value does not exist + function tryRemoveKnown(AssociatedLinkedListSet storage set, address associated, SetValue value, bytes32 prev) + internal + returns (bool) + { + bytes32 unwrappedKey = bytes32(SetValue.unwrap(value)); + TempBytesMemory keyBuffer = _allocateTempKeyBuffer(set, associated); + + prev = clearFlags(prev); + + if (prev == bytes32(0) || unwrappedKey == bytes32(0)) { + return false; + } + + // assert that the previous key's next value is the value to be removed + StoragePointer prevSlot = _mapLookup(keyBuffer, prev); + bytes32 currentValue = _load(prevSlot); + if (clearFlags(currentValue) != unwrappedKey) { + return false; + } + + StoragePointer valueSlot = _mapLookup(keyBuffer, unwrappedKey); + bytes32 next = _load(valueSlot); + if (next == bytes32(0)) { + // The set didn't actually contain the value + return false; + } + + // Need to do: + // map[prev] = clearFlags(next) | getUserFlags(currentValue) | (next & HAS_NEXT_FLAG); + // map[unwrappedKey] = bytes32(0); + _store(prevSlot, clearFlags(next) | getUserFlags(currentValue) | (next & HAS_NEXT_FLAG)); + _store(valueSlot, bytes32(0)); + + return true; + } + + /// @notice Removes all values from a set. + /// @dev This is an O(n) operation, where n is the number of elements in the set. + /// @param set The set to remove the values from + /// @param associated The address the set is associated with + function clear(AssociatedLinkedListSet storage set, address associated) internal { + TempBytesMemory keyBuffer = _allocateTempKeyBuffer(set, associated); + + bytes32 cursor = SENTINEL_VALUE; + + do { + bytes32 cleared = clearFlags(cursor); + StoragePointer cursorSlot = _mapLookup(keyBuffer, cleared); + bytes32 next = _load(cursorSlot); + _store(cursorSlot, bytes32(0)); + cursor = next; + } while (!isSentinel(cursor) && cursor != bytes32(0)); + + StoragePointer sentinelSlot = _mapLookup(keyBuffer, SENTINEL_VALUE); + _store(sentinelSlot, bytes32(0)); + } + + /// @notice Set the flags on a value in the set. + /// @dev The user flags can only be set on the upper 14 bits, because the lower two are reserved for the + /// sentinel and has next bit. + /// @param set The set containing the value. + /// @param associated The address the set is associated with. + /// @param value The value to set the flags on. + /// @param flags The flags to set. + /// @return True if the set contains the value and the operation succeeds, false otherwise. + function trySetFlags(AssociatedLinkedListSet storage set, address associated, SetValue value, uint16 flags) + internal + returns (bool) + { + bytes32 unwrappedKey = SetValue.unwrap(value); + TempBytesMemory keyBuffer = _allocateTempKeyBuffer(set, associated); + + // Ignore the lower 2 bits. + flags &= 0xFFFC; + + // If the set doesn't actually contain the value, return false; + StoragePointer valueSlot = _mapLookup(keyBuffer, unwrappedKey); + bytes32 next = _load(valueSlot); + if (next == bytes32(0)) { + return false; + } + + // Set the flags + _store(valueSlot, clearUserFlags(next) | bytes32(uint256(flags))); + + return true; + } + + /// @notice Set the given flags on a value in the set, preserving the values of other flags. + /// @dev The user flags can only be set on the upper 14 bits, because the lower two are reserved for the + /// sentinel and has next bit. + /// Short-circuits if the flags are already enabled, returning true. + /// @param set The set containing the value. + /// @param associated The address the set is associated with. + /// @param value The value to enable the flags on. + /// @param flags The flags to enable. + /// @return True if the operation succeeds or short-circuits due to the flags already being enabled. False + /// otherwise. + function tryEnableFlags(AssociatedLinkedListSet storage set, address associated, SetValue value, uint16 flags) + internal + returns (bool) + { + flags &= 0xFFFC; // Allow short-circuit if lower bits are accidentally set + uint16 currFlags = getFlags(set, associated, value); + if (currFlags & flags == flags) return true; // flags are already enabled + return trySetFlags(set, associated, value, currFlags | flags); + } + + /// @notice Clear the given flags on a value in the set, preserving the values of other flags. + /// @notice If the value is not in the set, this function will still return true. + /// @dev The user flags can only be set on the upper 14 bits, because the lower two are reserved for the + /// sentinel and has next bit. + /// Short-circuits if the flags are already disabled, or if set does not contain the value. Short-circuits + /// return true. + /// @param set The set containing the value. + /// @param associated The address the set is associated with. + /// @param value The value to disable the flags on. + /// @param flags The flags to disable. + /// @return True if the operation succeeds, or short-circuits due to the flags already being disabled or if the + /// set does not contain the value. False otherwise. + function tryDisableFlags(AssociatedLinkedListSet storage set, address associated, SetValue value, uint16 flags) + internal + returns (bool) + { + flags &= 0xFFFC; // Allow short-circuit if lower bits are accidentally set + uint16 currFlags = getFlags(set, associated, value); + if (currFlags & flags == 0) return true; // flags are already disabled + return trySetFlags(set, associated, value, currFlags & ~flags); + } + + /// @notice Checks if a set contains a value + /// @dev This method does not clear the upper bits of `value`, that is expected to be done as part of casting + /// to the correct type. If this function is provided the sentinel value by using the upper bits, this function + /// may returns `true`. + /// @param set The set to check + /// @param associated The address the set is associated with + /// @param value The value to check for + /// @return True if the set contains the value, false otherwise + function contains(AssociatedLinkedListSet storage set, address associated, SetValue value) + internal + view + returns (bool) + { + bytes32 unwrappedKey = bytes32(SetValue.unwrap(value)); + TempBytesMemory keyBuffer = _allocateTempKeyBuffer(set, associated); + + StoragePointer slot = _mapLookup(keyBuffer, unwrappedKey); + return _load(slot) != bytes32(0); + } + + /// @notice Checks if a set is empty + /// @param set The set to check + /// @param associated The address the set is associated with + /// @return True if the set is empty, false otherwise + function isEmpty(AssociatedLinkedListSet storage set, address associated) internal view returns (bool) { + TempBytesMemory keyBuffer = _allocateTempKeyBuffer(set, associated); + + StoragePointer sentinelSlot = _mapLookup(keyBuffer, SENTINEL_VALUE); + bytes32 val = _load(sentinelSlot); + return val == bytes32(0) || isSentinel(val); // either the sentinel is unset, or points to itself + } + + /// @notice Get the flags on a value in the set. + /// @dev The reserved lower 2 bits will not be returned, as those are reserved for the sentinel and has next + /// bit. + /// @param set The set containing the value. + /// @param associated The address the set is associated with. + /// @param value The value to get the flags from. + /// @return The flags set on the value. + function getFlags(AssociatedLinkedListSet storage set, address associated, SetValue value) + internal + view + returns (uint16) + { + bytes32 unwrappedKey = SetValue.unwrap(value); + TempBytesMemory keyBuffer = _allocateTempKeyBuffer(set, associated); + return uint16(uint256(_load(_mapLookup(keyBuffer, unwrappedKey))) & 0xFFFC); + } + + /// @notice Check if the flags on a value are enabled. + /// @dev The reserved lower 2 bits will be ignored, as those are reserved for the sentinel and has next bit. + /// @param set The set containing the value. + /// @param associated The address the set is associated with. + /// @param value The value to check the flags on. + /// @param flags The flags to check. + /// @return True if all of the flags are enabled, false otherwise. + function flagsEnabled(AssociatedLinkedListSet storage set, address associated, SetValue value, uint16 flags) + internal + view + returns (bool) + { + flags &= 0xFFFC; + return getFlags(set, associated, value) & flags == flags; + } + + /// @notice Check if the flags on a value are disabled. + /// @dev The reserved lower 2 bits will be ignored, as those are reserved for the sentinel and has next bit. + /// @param set The set containing the value. + /// @param associated The address the set is associated with. + /// @param value The value to check the flags on. + /// @param flags The flags to check. + /// @return True if all of the flags are disabled, false otherwise. + function flagsDisabled(AssociatedLinkedListSet storage set, address associated, SetValue value, uint16 flags) + internal + view + returns (bool) + { + flags &= 0xFFFC; + return ~(getFlags(set, associated, value)) & flags == flags; + } + + /// @notice Gets all elements in a set. + /// @dev This is an O(n) operation, where n is the number of elements in the set. + /// @param set The set to get the elements of. + /// @return res An array of all elements in the set. + function getAll(AssociatedLinkedListSet storage set, address associated) + internal + view + returns (SetValue[] memory res) + { + TempBytesMemory keyBuffer = _allocateTempKeyBuffer(set, associated); + + StoragePointer sentinelSlot = _mapLookup(keyBuffer, SENTINEL_VALUE); + bytes32 cursor = _load(sentinelSlot); + + uint256 count; + while (!isSentinel(cursor) && cursor != bytes32(0)) { + unchecked { + ++count; + } + bytes32 cleared = clearFlags(cursor); + + if (hasNext(cursor)) { + StoragePointer cursorSlot = _mapLookup(keyBuffer, cleared); + cursor = _load(cursorSlot); + } else { + cursor = bytes32(0); + } + } + + res = new SetValue[](count); + + if (count == 0) { + return res; + } + + // Re-allocate the key buffer because we just overwrote it! + keyBuffer = _allocateTempKeyBuffer(set, associated); + + cursor = SENTINEL_VALUE; + for (uint256 i = 0; i < count;) { + StoragePointer cursorSlot = _mapLookup(keyBuffer, cursor); + bytes32 cursorValue = _load(cursorSlot); + bytes32 cleared = clearFlags(cursorValue); + res[i] = SetValue.wrap(bytes30(cleared)); + cursor = cleared; + + unchecked { + ++i; + } + } + } + + function isSentinel(bytes32 value) internal pure returns (bool ret) { + assembly ("memory-safe") { + ret := and(value, 1) + } + } + + function hasNext(bytes32 value) internal pure returns (bool) { + return value & HAS_NEXT_FLAG != 0; + } + + function clearFlags(bytes32 val) internal pure returns (bytes32) { + return val & 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF0001; + } + + /// @dev Preserves the lower two bits + function clearUserFlags(bytes32 val) internal pure returns (bytes32) { + return val & 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF0003; + } + + function getUserFlags(bytes32 val) internal pure returns (bytes32) { + return val & bytes32(uint256(0xFFFC)); + } + + // PRIVATE METHODS + + /// @notice Given an allocated key buffer, returns the storage slot for a given key + function _mapLookup(TempBytesMemory keyBuffer, bytes32 value) private pure returns (StoragePointer slot) { + assembly ("memory-safe") { + // Store the value in the last word. + let keyWord2 := value + mstore(add(keyBuffer, 0x60), keyWord2) + slot := keccak256(keyBuffer, 0x80) + } + } + + /// @notice Allocates a key buffer for a given ID and associated address into scratch space memory. + /// @dev The returned buffer must not be used if any additional memory is allocated after calling this + /// function. + /// @param set The set to allocate the key buffer for. + /// @param associated The address the set is associated with. + /// @return key A key buffer that can be used to lookup values in the set + function _allocateTempKeyBuffer(AssociatedLinkedListSet storage set, address associated) + private + pure + returns (TempBytesMemory key) + { + // Key derivation for an entry + // associated addr (left-padded) || prefix || uint224(0) batchIndex || set storage slot || entry + // Word 1: + // | zeros | 0x000000000000000000000000________________________________________ | + // | address | 0x________________________AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA | + // Word 2: + // | prefix | 0xPPPPPPPP________________________________________________________ | + // | batch index (zero) | 0x________00000000000000000000000000000000000000000000000000000000 | + // Word 3: + // | set storage slot | 0xSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSS | + // Word 4: + // | entry value | 0xVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVV____ | + // | entry meta | 0x____________________________________________________________MMMM | + + // The batch index is for consistency with PluginStorageLib, and the prefix in front of it is + // to prevent any potential crafted collisions where the batch index may be equal to storage slot + // of the ALLS. The prefix is set to the upper bits of the batch index to make it infeasible to + // reach from just incrementing the value. + + // This segment is memory-safe because it only uses the scratch space memory after the value of the free + // memory pointer. + // See https://docs.soliditylang.org/en/v0.8.21/assembly.html#memory-safety + assembly ("memory-safe") { + // Clean upper bits of arguments + associated := and(associated, 0xffffffffffffffffffffffffffffffffffffffff) + + // Use memory past-the-free-memory-pointer without updating it, as this is just scratch space + key := mload(0x40) + // Store the associated address in the first word, left-padded with zeroes + mstore(key, associated) + // Store the prefix and a batch index of 0 + mstore(add(key, 0x20), _ASSOCIATED_STORAGE_PREFIX) + // Store the list's storage slot in the third word + mstore(add(key, 0x40), set.slot) + // Leaves the last word open for the value entry + } + + return key; + } + + /// @dev Loads a value from storage + function _load(StoragePointer ptr) private view returns (bytes32 val) { + assembly ("memory-safe") { + val := sload(ptr) + } + } + + /// @dev Writes a value into storage + function _store(StoragePointer ptr, bytes32 val) private { + assembly ("memory-safe") { + sstore(ptr, val) + } + } +} diff --git a/src/libraries/PluginStorageLib.sol b/src/libraries/PluginStorageLib.sol new file mode 100644 index 00000000..8ad17ed4 --- /dev/null +++ b/src/libraries/PluginStorageLib.sol @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: GPL-3.0 +pragma solidity ^0.8.19; + +type StoragePointer is bytes32; + +/// @title Plugin Storage Library +/// @notice Library for allocating and accessing ERC-4337 address-associated storage within plugins. +library PluginStorageLib { + /// @notice Allocates a memory buffer for an associated storage key, and sets the associated address and batch + /// index. + /// @param addr The address to associate with the storage key. + /// @param batchIndex The batch index to associate with the storage key. + /// @param keySize The size of the key in words, where each word is 32 bytes. Not inclusive of the address and + /// batch index. + /// @return key The allocated memory buffer. + function allocateAssociatedStorageKey(address addr, uint256 batchIndex, uint8 keySize) + internal + pure + returns (bytes memory key) + { + assembly ("memory-safe") { + // Clear any dirty upper bits of keySize to prevent overflow + keySize := and(keySize, 0xff) + + // compute the total size of the buffer, include the address and batch index + let totalSize := add(64, mul(32, keySize)) + + // Allocate memory for the key + key := mload(0x40) + mstore(0x40, add(add(key, totalSize), 32)) + mstore(key, totalSize) + + // Clear any dirty upper bits of address + addr := and(addr, 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF) + // Store the address and batch index in the key buffer + mstore(add(key, 32), addr) + mstore(add(key, 64), batchIndex) + } + } + + function associatedStorageLookup(bytes memory key, bytes32 input) internal pure returns (StoragePointer ptr) { + assembly ("memory-safe") { + mstore(add(key, 96), input) + ptr := keccak256(add(key, 32), mload(key)) + } + } + + function associatedStorageLookup(bytes memory key, bytes32 input1, bytes32 input2) + internal + pure + returns (StoragePointer ptr) + { + assembly ("memory-safe") { + mstore(add(key, 96), input1) + mstore(add(key, 128), input2) + ptr := keccak256(add(key, 32), mload(key)) + } + } +} diff --git a/test/libraries/AssociatedLinkedListSetLib.t.sol b/test/libraries/AssociatedLinkedListSetLib.t.sol new file mode 100644 index 00000000..21dd59ef --- /dev/null +++ b/test/libraries/AssociatedLinkedListSetLib.t.sol @@ -0,0 +1,220 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.19; + +import {Test} from "forge-std/Test.sol"; +import { + AssociatedLinkedListSet, + AssociatedLinkedListSetLib, + SENTINEL_VALUE, + SetValue +} from "../../src/libraries/AssociatedLinkedListSetLib.sol"; + +contract AssociatedLinkedListSetLibTest is Test { + using AssociatedLinkedListSetLib for AssociatedLinkedListSet; + + AssociatedLinkedListSet internal _set1; + AssociatedLinkedListSet internal _set2; + + address internal _associated = address(this); + + // User-defined function for wrapping from bytes30 (uint240) to SetValue + // Can define a custom one for addresses, uints, etc. + function _getListValue(uint240 value) internal pure returns (SetValue) { + return SetValue.wrap(bytes30(value)); + } + + function test_add_contains() public { + SetValue value = _getListValue(12); + + assertTrue(_set1.tryAdd(_associated, value)); + assertTrue(_set1.contains(_associated, value)); + } + + function test_empty() public { + SetValue value = _getListValue(12); + + assertFalse(_set1.contains(_associated, value)); + assertTrue(_set1.isEmpty(_associated)); + } + + function test_remove() public { + SetValue value = _getListValue(12); + + assertTrue(_set1.tryAdd(_associated, value)); + assertTrue(_set1.contains(_associated, value)); + + assertTrue(_set1.tryRemove(_associated, value)); + assertFalse(_set1.contains(_associated, value)); + } + + function test_remove_empty() public { + SetValue value = _getListValue(12); + + assertFalse(_set1.tryRemove(_associated, value)); + } + + function test_remove_nonexistent() public { + SetValue value = _getListValue(12); + + assertTrue(_set1.tryAdd(_associated, value)); + assertTrue(_set1.contains(_associated, value)); + + SetValue value2 = _getListValue(13); + assertFalse(_set1.tryRemove(_associated, value2)); + assertTrue(_set1.contains(_associated, value)); + } + + function test_remove_nonexistent_empty() public { + SetValue value = _getListValue(12); + + assertFalse(_set1.tryRemove(_associated, value)); + } + + function test_remove_nonexistent_empty2() public { + SetValue value = _getListValue(12); + + assertTrue(_set1.tryAdd(_associated, value)); + assertTrue(_set1.contains(_associated, value)); + + SetValue value2 = _getListValue(13); + assertFalse(_set1.tryRemove(_associated, value2)); + assertTrue(_set1.contains(_associated, value)); + } + + function test_add_remove_add() public { + SetValue value = _getListValue(12); + + assertTrue(_set1.tryAdd(_associated, value)); + assertTrue(_set1.contains(_associated, value)); + + assertTrue(_set1.tryRemove(_associated, value)); + assertFalse(_set1.contains(_associated, value)); + + assertTrue(_set1.tryAdd(_associated, value)); + assertTrue(_set1.contains(_associated, value)); + } + + function test_add_remove_add_empty() public { + SetValue value = _getListValue(12); + + assertTrue(_set1.tryAdd(_associated, value)); + assertTrue(_set1.contains(_associated, value)); + + assertTrue(_set1.tryRemove(_associated, value)); + assertFalse(_set1.contains(_associated, value)); + + assertTrue(_set1.tryAdd(_associated, value)); + assertTrue(_set1.contains(_associated, value)); + } + + function test_no_address_collision() public { + SetValue value = _getListValue(12); + + assertTrue(_set1.tryAdd(_associated, value)); + assertTrue(_set1.contains(_associated, value)); + assertFalse(_set2.contains(_associated, value)); + } + + function test_clear() public { + SetValue value = _getListValue(12); + + assertTrue(_set1.tryAdd(_associated, value)); + assertTrue(_set1.contains(_associated, value)); + + _set1.clear(_associated); + + assertFalse(_set1.contains(_associated, value)); + assertTrue(_set1.isEmpty(_associated)); + } + + function test_getAll() public { + SetValue value = _getListValue(12); + SetValue value2 = _getListValue(13); + + assertTrue(_set1.tryAdd(_associated, value)); + assertTrue(_set1.tryAdd(_associated, value2)); + + SetValue[] memory values = _set1.getAll(_associated); + assertEq(values.length, 2); + // Returned set will be in reverse order of added elements + assertEq(SetValue.unwrap(values[1]), SetValue.unwrap(value)); + assertEq(SetValue.unwrap(values[0]), SetValue.unwrap(value2)); + } + + function test_getAll2() public { + SetValue value = _getListValue(12); + SetValue value2 = _getListValue(13); + SetValue value3 = _getListValue(14); + + assertTrue(_set1.tryAdd(_associated, value)); + assertTrue(_set1.tryAdd(_associated, value2)); + assertTrue(_set1.tryAdd(_associated, value3)); + + SetValue[] memory values = _set1.getAll(_associated); + assertEq(values.length, 3); + // Returned set will be in reverse order of added elements + assertEq(SetValue.unwrap(values[2]), SetValue.unwrap(value)); + assertEq(SetValue.unwrap(values[1]), SetValue.unwrap(value2)); + assertEq(SetValue.unwrap(values[0]), SetValue.unwrap(value3)); + } + + function test_getAll_empty() public { + SetValue[] memory values = _set1.getAll(_associated); + assertEq(values.length, 0); + } + + function test_tryRemoveKnown1() public { + SetValue value = _getListValue(12); + + assertTrue(_set1.tryAdd(_associated, value)); + assertTrue(_set1.contains(_associated, value)); + + assertTrue(_set1.tryRemoveKnown(_associated, value, SENTINEL_VALUE)); + assertFalse(_set1.contains(_associated, value)); + assertTrue(_set1.isEmpty(_associated)); + } + + function test_tryRemoveKnown2() public { + SetValue value1 = _getListValue(12); + SetValue value2 = _getListValue(13); + + assertTrue(_set1.tryAdd(_associated, value1)); + assertTrue(_set1.tryAdd(_associated, value2)); + assertTrue(_set1.contains(_associated, value1)); + assertTrue(_set1.contains(_associated, value2)); + + // Assert that getAll returns the correct values + SetValue[] memory values = _set1.getAll(_associated); + assertEq(values.length, 2); + assertEq(SetValue.unwrap(values[1]), SetValue.unwrap(value1)); + assertEq(SetValue.unwrap(values[0]), SetValue.unwrap(value2)); + + assertTrue(_set1.tryRemoveKnown(_associated, value1, bytes32(SetValue.unwrap(value2)))); + assertFalse(_set1.contains(_associated, value1)); + assertTrue(_set1.contains(_associated, value2)); + + // Assert that getAll returns the correct values + values = _set1.getAll(_associated); + assertEq(values.length, 1); + assertEq(SetValue.unwrap(values[0]), SetValue.unwrap(value2)); + + assertTrue(_set1.tryRemoveKnown(_associated, value2, SENTINEL_VALUE)); + assertFalse(_set1.contains(_associated, value1)); + + assertTrue(_set1.isEmpty(_associated)); + } + + function test_tryRemoveKnown_invalid1() public { + SetValue value1 = _getListValue(12); + SetValue value2 = _getListValue(13); + + assertTrue(_set1.tryAdd(_associated, value1)); + assertTrue(_set1.tryAdd(_associated, value2)); + + assertFalse(_set1.tryRemoveKnown(_associated, value1, bytes32(SetValue.unwrap(value1)))); + assertTrue(_set1.contains(_associated, value1)); + + assertFalse(_set1.tryRemoveKnown(_associated, value2, bytes32(SetValue.unwrap(value2)))); + assertTrue(_set1.contains(_associated, value2)); + } +} diff --git a/test/libraries/PluginStorageLib.t.sol b/test/libraries/PluginStorageLib.t.sol new file mode 100644 index 00000000..2dc58e1f --- /dev/null +++ b/test/libraries/PluginStorageLib.t.sol @@ -0,0 +1,88 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.19; + +import {Test} from "forge-std/Test.sol"; +import {PluginStorageLib, StoragePointer} from "../../src/libraries/PluginStorageLib.sol"; + +contract PluginStorageLibTest is Test { + using PluginStorageLib for bytes; + using PluginStorageLib for bytes32; + + uint256 public constant FUZZ_ARR_SIZE = 32; + + address public account1; + + struct TestStruct { + uint256 a; + uint256 b; + } + + function setUp() public { + account1 = makeAddr("account1"); + } + + function test_storagePointer() public { + bytes memory key = PluginStorageLib.allocateAssociatedStorageKey(account1, 0, 1); + + StoragePointer ptr = PluginStorageLib.associatedStorageLookup( + key, hex"00112233445566778899aabbccddeeff00112233445566778899aabbccddeeff" + ); + TestStruct storage val = _castPtrToStruct(ptr); + + vm.record(); + val.a = 0xdeadbeef; + val.b = 123; + (, bytes32[] memory accountWrites) = vm.accesses(address(this)); + + assertEq(accountWrites.length, 2); + bytes32 expectedKey = keccak256( + abi.encodePacked( + uint256(uint160(account1)), + uint256(0), + hex"00112233445566778899aabbccddeeff00112233445566778899aabbccddeeff" + ) + ); + assertEq(accountWrites[0], expectedKey); + assertEq(vm.load(address(this), expectedKey), bytes32(uint256(0xdeadbeef))); + assertEq(accountWrites[1], bytes32(uint256(expectedKey) + 1)); + assertEq(vm.load(address(this), bytes32(uint256(expectedKey) + 1)), bytes32(uint256(123))); + } + + function testFuzz_storagePointer( + address account, + uint256 batchIndex, + bytes32 inputKey, + uint256[FUZZ_ARR_SIZE] calldata values + ) public { + bytes memory key = PluginStorageLib.allocateAssociatedStorageKey(account, batchIndex, 1); + uint256[FUZZ_ARR_SIZE] storage val = + _castPtrToArray(PluginStorageLib.associatedStorageLookup(key, inputKey)); + // Write values to storage + vm.record(); + for (uint256 i = 0; i < FUZZ_ARR_SIZE; i++) { + val[i] = values[i]; + } + // Assert the writes took place in the right location, and the correct value is stored there + (, bytes32[] memory accountWrites) = vm.accesses(address(this)); + assertEq(accountWrites.length, FUZZ_ARR_SIZE); + for (uint256 i = 0; i < FUZZ_ARR_SIZE; i++) { + bytes32 expectedKey = bytes32( + uint256(keccak256(abi.encodePacked(uint256(uint160(account)), uint256(batchIndex), inputKey))) + i + ); + assertEq(accountWrites[i], expectedKey); + assertEq(vm.load(address(this), expectedKey), bytes32(uint256(values[i]))); + } + } + + function _castPtrToArray(StoragePointer ptr) internal pure returns (uint256[FUZZ_ARR_SIZE] storage val) { + assembly ("memory-safe") { + val.slot := ptr + } + } + + function _castPtrToStruct(StoragePointer ptr) internal pure returns (TestStruct storage val) { + assembly ("memory-safe") { + val.slot := ptr + } + } +}