Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 50 additions & 56 deletions minichain/mempool.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,82 +3,76 @@

logger = logging.getLogger(__name__)


class Mempool:
TRANSACTIONS_PER_BLOCK = 100

def __init__(self, max_size=1000, transactions_per_block=TRANSACTIONS_PER_BLOCK):
self._pending_txs = []
self._seen_tx_ids = set()
def __init__(self, max_size=1000, transactions_per_block=100):
self._pool = {}
self._size = 0
self._lock = threading.Lock()
self.max_size = max_size
self.transactions_per_block = transactions_per_block

def _get_tx_id(self, tx):
return tx.tx_id

def add_transaction(self, tx):
"""
Adds a transaction to the pool if:
- Signature is valid
- Transaction is not a duplicate
- Mempool is not full
"""
tx_id = self._get_tx_id(tx)

if not tx.verify():
logger.warning("Mempool: Invalid signature rejected")
return False

with self._lock:
if tx_id in self._seen_tx_ids:
logger.warning("Mempool: Duplicate transaction rejected %s", tx_id)
return False

replacement_index = None
for index, pending_tx in enumerate(self._pending_txs):
if pending_tx.sender == tx.sender and pending_tx.nonce == tx.nonce:
replacement_index = index
break
existing = self._pool.get(tx.sender, {}).get(tx.nonce)

if replacement_index is None and len(self._pending_txs) >= self.max_size:
logger.warning("Mempool: Full, rejecting transaction")
return False

if replacement_index is not None:
old_tx = self._pending_txs[replacement_index]
self._seen_tx_ids.discard(self._get_tx_id(old_tx))
self._pending_txs[replacement_index] = tx
if existing:
if existing.tx_id == tx.tx_id:
logger.warning("Mempool: Duplicate transaction rejected %s", tx.tx_id)
return False
# Fix: Guard against older replacements (e.g. rejected block restore)
# Only allow overwrite if it's a genuinely newer replacement
if tx.timestamp <= existing.timestamp:
logger.warning("Mempool: Ignoring older replacement %s", tx.tx_id)
return False

else:
self._pending_txs.append(tx)

self._seen_tx_ids.add(tx_id)
if self._size >= self.max_size:
logger.warning("Mempool: Full, rejecting transaction")
return False
self._size += 1
self._pool.setdefault(tx.sender, {})[tx.nonce] = tx
return True

def get_transactions_for_block(self):
"""
Returns transactions in deterministic sorted queue order.
This is read-only; transactions are removed only after block acceptance.
"""
with self._lock:
selected = list(self._pending_txs)
selected.sort(key=lambda tx: (tx.timestamp, tx.sender, tx.nonce))
return selected[: self.transactions_per_block]
snapshot = {s: list(pool.values()) for s, pool in self._pool.items()}

for txs in snapshot.values():
txs.sort(key=lambda t: t.nonce)

selected = []
while len(selected) < self.transactions_per_block:
best_tx = None
best_sender = None

for sender, txs in snapshot.items():
if txs:
if best_tx is None or (txs[0].timestamp, sender, txs[0].nonce) < (best_tx.timestamp, best_sender, best_tx.nonce):
best_tx = txs[0]
best_sender = sender

if not best_tx:
break

selected.append(best_tx)
snapshot[best_sender].pop(0)

return selected

def remove_transactions(self, transactions):
with self._lock:
remove_ids = {self._get_tx_id(tx) for tx in transactions}
remove_sender_nonces = {(tx.sender, tx.nonce) for tx in transactions}
if not remove_ids:
return
self._pending_txs = [
tx
for tx in self._pending_txs
if self._get_tx_id(tx) not in remove_ids
and (tx.sender, tx.nonce) not in remove_sender_nonces
]
self._seen_tx_ids = {self._get_tx_id(tx) for tx in self._pending_txs}
for tx in transactions:
pool = self._pool.get(tx.sender)
if pool and tx.nonce in pool:
del pool[tx.nonce]
self._size -= 1
if not pool:
del self._pool[tx.sender]

def __len__(self):
with self._lock:
return len(self._pending_txs)
return self._size
2 changes: 1 addition & 1 deletion tests/test_protocol_hardening.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def _signed_tx(self, nonce, amount=1, timestamp=None) -> Transaction:
def test_transactions_for_block_are_sorted_and_capped(self):
mempool = Mempool()
for nonce in range(mempool.transactions_per_block + 5):
self.assertTrue(mempool.add_transaction(self._signed_tx(nonce, timestamp=5000 - nonce)))
self.assertTrue(mempool.add_transaction(self._signed_tx(nonce, timestamp=5000 + nonce)))

selected = mempool.get_transactions_for_block()

Expand Down
Loading