diff --git a/minichain/mempool.py b/minichain/mempool.py index cc3b4f7..4b71e08 100644 --- a/minichain/mempool.py +++ b/minichain/mempool.py @@ -3,82 +3,76 @@ logger = logging.getLogger(__name__) - class Mempool: - TRANSACTIONS_PER_BLOCK = 100 - - def __init__(self, max_size=1000, transactions_per_block=TRANSACTIONS_PER_BLOCK): - self._pending_txs = [] - self._seen_tx_ids = set() + def __init__(self, max_size=1000, transactions_per_block=100): + self._pool = {} + self._size = 0 self._lock = threading.Lock() self.max_size = max_size self.transactions_per_block = transactions_per_block - def _get_tx_id(self, tx): - return tx.tx_id - def add_transaction(self, tx): - """ - Adds a transaction to the pool if: - - Signature is valid - - Transaction is not a duplicate - - Mempool is not full - """ - tx_id = self._get_tx_id(tx) - if not tx.verify(): logger.warning("Mempool: Invalid signature rejected") return False with self._lock: - if tx_id in self._seen_tx_ids: - logger.warning("Mempool: Duplicate transaction rejected %s", tx_id) - return False - - replacement_index = None - for index, pending_tx in enumerate(self._pending_txs): - if pending_tx.sender == tx.sender and pending_tx.nonce == tx.nonce: - replacement_index = index - break + existing = self._pool.get(tx.sender, {}).get(tx.nonce) - if replacement_index is None and len(self._pending_txs) >= self.max_size: - logger.warning("Mempool: Full, rejecting transaction") - return False - - if replacement_index is not None: - old_tx = self._pending_txs[replacement_index] - self._seen_tx_ids.discard(self._get_tx_id(old_tx)) - self._pending_txs[replacement_index] = tx + if existing: + if existing.tx_id == tx.tx_id: + logger.warning("Mempool: Duplicate transaction rejected %s", tx.tx_id) + return False + # Fix: Guard against older replacements (e.g. rejected block restore) + # Only allow overwrite if it's a genuinely newer replacement + if tx.timestamp <= existing.timestamp: + logger.warning("Mempool: Ignoring older replacement %s", tx.tx_id) + return False + else: - self._pending_txs.append(tx) - - self._seen_tx_ids.add(tx_id) + if self._size >= self.max_size: + logger.warning("Mempool: Full, rejecting transaction") + return False + self._size += 1 + self._pool.setdefault(tx.sender, {})[tx.nonce] = tx return True def get_transactions_for_block(self): - """ - Returns transactions in deterministic sorted queue order. - This is read-only; transactions are removed only after block acceptance. - """ with self._lock: - selected = list(self._pending_txs) - selected.sort(key=lambda tx: (tx.timestamp, tx.sender, tx.nonce)) - return selected[: self.transactions_per_block] + snapshot = {s: list(pool.values()) for s, pool in self._pool.items()} + + for txs in snapshot.values(): + txs.sort(key=lambda t: t.nonce) + + selected = [] + while len(selected) < self.transactions_per_block: + best_tx = None + best_sender = None + + for sender, txs in snapshot.items(): + if txs: + if best_tx is None or (txs[0].timestamp, sender, txs[0].nonce) < (best_tx.timestamp, best_sender, best_tx.nonce): + best_tx = txs[0] + best_sender = sender + + if not best_tx: + break + + selected.append(best_tx) + snapshot[best_sender].pop(0) + + return selected def remove_transactions(self, transactions): with self._lock: - remove_ids = {self._get_tx_id(tx) for tx in transactions} - remove_sender_nonces = {(tx.sender, tx.nonce) for tx in transactions} - if not remove_ids: - return - self._pending_txs = [ - tx - for tx in self._pending_txs - if self._get_tx_id(tx) not in remove_ids - and (tx.sender, tx.nonce) not in remove_sender_nonces - ] - self._seen_tx_ids = {self._get_tx_id(tx) for tx in self._pending_txs} + for tx in transactions: + pool = self._pool.get(tx.sender) + if pool and tx.nonce in pool: + del pool[tx.nonce] + self._size -= 1 + if not pool: + del self._pool[tx.sender] def __len__(self): with self._lock: - return len(self._pending_txs) + return self._size diff --git a/tests/test_protocol_hardening.py b/tests/test_protocol_hardening.py index 60aee4a..6b169e7 100644 --- a/tests/test_protocol_hardening.py +++ b/tests/test_protocol_hardening.py @@ -45,7 +45,7 @@ def _signed_tx(self, nonce, amount=1, timestamp=None) -> Transaction: def test_transactions_for_block_are_sorted_and_capped(self): mempool = Mempool() for nonce in range(mempool.transactions_per_block + 5): - self.assertTrue(mempool.add_transaction(self._signed_tx(nonce, timestamp=5000 - nonce))) + self.assertTrue(mempool.add_transaction(self._signed_tx(nonce, timestamp=5000 + nonce))) selected = mempool.get_transactions_for_block()