From cf4c09a7d533d2c7959e790ce5882815b13929bd Mon Sep 17 00:00:00 2001 From: siddhant Date: Sat, 14 Mar 2026 19:26:11 +0530 Subject: [PATCH 1/3] refactor: simplify mempool to sorted queue and fix tx removal semantics --- main.py | 3 +- minichain/mempool.py | 90 +++++++++++++------------------- tests/test_protocol_hardening.py | 60 +++++++++++---------- 3 files changed, 71 insertions(+), 82 deletions(-) diff --git a/main.py b/main.py index 67681be..183824a 100644 --- a/main.py +++ b/main.py @@ -49,7 +49,7 @@ def create_wallet(): def mine_and_process_block(chain, mempool, miner_pk): """Mine pending transactions into a new block.""" - pending_txs = mempool.get_transactions_for_block(chain.state) + pending_txs = mempool.get_transactions_for_block() if not pending_txs: logger.info("Mempool is empty — nothing to mine.") return None @@ -64,6 +64,7 @@ def mine_and_process_block(chain, mempool, miner_pk): if chain.add_block(mined_block): logger.info("✅ Block #%d mined and added (%d txs)", mined_block.index, len(pending_txs)) + mempool.remove_transactions(pending_txs) chain.state.credit_mining_reward(miner_pk) return mined_block else: diff --git a/minichain/mempool.py b/minichain/mempool.py index 1e4f082..cc3b4f7 100644 --- a/minichain/mempool.py +++ b/minichain/mempool.py @@ -1,4 +1,3 @@ -from collections import defaultdict import logging import threading @@ -6,33 +5,24 @@ class Mempool: - def __init__(self, max_size=1000): - self._pending_by_sender = defaultdict(dict) + TRANSACTIONS_PER_BLOCK = 100 + + def __init__(self, max_size=1000, transactions_per_block=TRANSACTIONS_PER_BLOCK): + self._pending_txs = [] self._seen_tx_ids = set() self._lock = threading.Lock() self.max_size = max_size + self.transactions_per_block = transactions_per_block def _get_tx_id(self, tx): return tx.tx_id - def _count_transactions_unlocked(self): - return sum(len(sender_queue) for sender_queue in self._pending_by_sender.values()) - - def _expected_nonce_for_sender(self, sender, state): - if state is not None: - return state.get_account(sender)["nonce"] - - sender_queue = self._pending_by_sender.get(sender, {}) - if not sender_queue: - return 0 - return min(sender_queue) - def add_transaction(self, tx): """ Adds a transaction to the pool if: - Signature is valid - Transaction is not a duplicate - - Sender nonce is not already present in the pool + - Mempool is not full """ tx_id = self._get_tx_id(tx) @@ -45,58 +35,50 @@ def add_transaction(self, tx): logger.warning("Mempool: Duplicate transaction rejected %s", tx_id) return False - sender_queue = self._pending_by_sender[tx.sender] - if tx.nonce in sender_queue: - logger.warning( - "Mempool: Duplicate sender nonce rejected sender=%s nonce=%s", - tx.sender[:8], - tx.nonce, - ) - return False + replacement_index = None + for index, pending_tx in enumerate(self._pending_txs): + if pending_tx.sender == tx.sender and pending_tx.nonce == tx.nonce: + replacement_index = index + break - if self._count_transactions_unlocked() >= self.max_size: + if replacement_index is None and len(self._pending_txs) >= self.max_size: logger.warning("Mempool: Full, rejecting transaction") return False - sender_queue[tx.nonce] = tx + if replacement_index is not None: + old_tx = self._pending_txs[replacement_index] + self._seen_tx_ids.discard(self._get_tx_id(old_tx)) + self._pending_txs[replacement_index] = tx + else: + self._pending_txs.append(tx) + self._seen_tx_ids.add(tx_id) return True - def get_transactions_for_block(self, state=None): + def get_transactions_for_block(self): """ - Returns ready transactions only. - - Transactions for the same sender are included in nonce order starting - from the sender's current account nonce. Later nonces stay queued until - earlier ones are confirmed. + Returns transactions in deterministic sorted queue order. + This is read-only; transactions are removed only after block acceptance. """ with self._lock: - selected = [] - - for sender, sender_queue in self._pending_by_sender.items(): - expected_nonce = self._expected_nonce_for_sender(sender, state) - while expected_nonce in sender_queue: - selected.append(sender_queue[expected_nonce]) - expected_nonce += 1 - + selected = list(self._pending_txs) selected.sort(key=lambda tx: (tx.timestamp, tx.sender, tx.nonce)) - self._remove_transactions_unlocked(selected) - return selected + return selected[: self.transactions_per_block] def remove_transactions(self, transactions): with self._lock: - self._remove_transactions_unlocked(transactions) - - def _remove_transactions_unlocked(self, transactions): - for tx in transactions: - tx_id = self._get_tx_id(tx) - sender_queue = self._pending_by_sender.get(tx.sender) - if sender_queue and tx.nonce in sender_queue: - del sender_queue[tx.nonce] - if not sender_queue: - del self._pending_by_sender[tx.sender] - self._seen_tx_ids.discard(tx_id) + remove_ids = {self._get_tx_id(tx) for tx in transactions} + remove_sender_nonces = {(tx.sender, tx.nonce) for tx in transactions} + if not remove_ids: + return + self._pending_txs = [ + tx + for tx in self._pending_txs + if self._get_tx_id(tx) not in remove_ids + and (tx.sender, tx.nonce) not in remove_sender_nonces + ] + self._seen_tx_ids = {self._get_tx_id(tx) for tx in self._pending_txs} def __len__(self): with self._lock: - return self._count_transactions_unlocked() + return len(self._pending_txs) diff --git a/tests/test_protocol_hardening.py b/tests/test_protocol_hardening.py index 2c19710..60aee4a 100644 --- a/tests/test_protocol_hardening.py +++ b/tests/test_protocol_hardening.py @@ -23,7 +23,7 @@ def test_block_hash_matches_compute_hash(self): self.assertEqual(block.compute_hash(), calculate_hash(block.to_header_dict())) -class TestMempoolNonceQueues(unittest.TestCase): +class TestMempoolQueue(unittest.TestCase): def setUp(self): self.state = State() self.sender_sk = SigningKey.generate() @@ -31,7 +31,7 @@ def setUp(self): self.receiver_pk = SigningKey.generate().verify_key.encode(encoder=HexEncoder).decode() self.state.credit_mining_reward(self.sender_pk, 100) - def _signed_tx(self, nonce, amount=1, timestamp=None): + def _signed_tx(self, nonce, amount=1, timestamp=None) -> Transaction: tx = Transaction( sender=self.sender_pk, receiver=self.receiver_pk, @@ -42,38 +42,31 @@ def _signed_tx(self, nonce, amount=1, timestamp=None): tx.sign(self.sender_sk) return tx - def test_ready_transactions_preserve_sender_nonce_order(self): + def test_transactions_for_block_are_sorted_and_capped(self): mempool = Mempool() - late_tx = self._signed_tx(1, timestamp=2000) - early_tx = self._signed_tx(0, timestamp=1000) + for nonce in range(mempool.transactions_per_block + 5): + self.assertTrue(mempool.add_transaction(self._signed_tx(nonce, timestamp=5000 - nonce))) - self.assertTrue(mempool.add_transaction(late_tx)) - self.assertTrue(mempool.add_transaction(early_tx)) + selected = mempool.get_transactions_for_block() - selected = mempool.get_transactions_for_block(self.state) - - self.assertEqual([tx.nonce for tx in selected], [0, 1]) - self.assertEqual(len(mempool), 0) + self.assertEqual(len(selected), mempool.transactions_per_block) + self.assertEqual(len(mempool), mempool.transactions_per_block + 5) + self.assertEqual( + [tx.timestamp for tx in selected], + sorted(tx.timestamp for tx in selected), + ) - def test_gap_transactions_stay_waiting(self): + def test_same_nonce_replaces_pending_transaction(self): mempool = Mempool() - ready_tx = self._signed_tx(0, timestamp=1000) - waiting_tx = self._signed_tx(2, timestamp=3000) - - self.assertTrue(mempool.add_transaction(ready_tx)) - self.assertTrue(mempool.add_transaction(waiting_tx)) - - selected = mempool.get_transactions_for_block(self.state) - - self.assertEqual([tx.nonce for tx in selected], [0]) - self.assertEqual(len(mempool), 1) + original_tx = self._signed_tx(0, amount=1, timestamp=1000) + replacement_tx = self._signed_tx(0, amount=2, timestamp=2000) - self.state.apply_transaction(ready_tx) - middle_tx = self._signed_tx(1, timestamp=2000) - self.assertTrue(mempool.add_transaction(middle_tx)) + self.assertTrue(mempool.add_transaction(original_tx)) + self.assertTrue(mempool.add_transaction(replacement_tx)) - selected = mempool.get_transactions_for_block(self.state) - self.assertEqual([tx.nonce for tx in selected], [1, 2]) + selected = mempool.get_transactions_for_block() + self.assertEqual(len(selected), 1) + self.assertEqual(selected[0].amount, 2) def test_remove_transactions_keeps_other_pending(self): mempool = Mempool() @@ -83,8 +76,21 @@ def test_remove_transactions_keeps_other_pending(self): self.assertTrue(mempool.add_transaction(tx0)) self.assertTrue(mempool.add_transaction(tx1)) mempool.remove_transactions([tx0]) + selected = mempool.get_transactions_for_block() self.assertEqual(len(mempool), 1) + self.assertEqual(len(selected), 1) + self.assertEqual(selected[0].tx_id, tx1.tx_id) + + def test_remove_transactions_by_sender_nonce_when_tx_id_differs(self): + mempool = Mempool() + local_tx = self._signed_tx(0, amount=1, timestamp=1000) + remote_confirmed_tx = self._signed_tx(0, amount=2, timestamp=2000) + + self.assertTrue(mempool.add_transaction(local_tx)) + mempool.remove_transactions([remote_confirmed_tx]) + + self.assertEqual(len(mempool), 0) class TestP2PValidationAndDedup(unittest.IsolatedAsyncioTestCase): From a8e8718059d969e3ef6c7c32e2be29be50bcd618 Mon Sep 17 00:00:00 2001 From: siddhant Date: Sat, 14 Mar 2026 20:49:54 +0530 Subject: [PATCH 2/3] use chain nonce and add simple send/p2p safety checks --- main.py | 27 +++++++++++++++------------ minichain/p2p.py | 12 ++++++++++++ 2 files changed, 27 insertions(+), 12 deletions(-) diff --git a/main.py b/main.py index 183824a..246276d 100644 --- a/main.py +++ b/main.py @@ -98,8 +98,8 @@ async def handler(data): logger.info("📥 Received tx from %s... (amount=%s)", tx.sender[:8], tx.amount) elif msg_type == "block": - txs_raw = payload.pop("transactions", []) - block_hash = payload.pop("hash", None) + txs_raw = payload.get("transactions", []) + block_hash = payload.get("hash") transactions = [Transaction(**t) for t in txs_raw] block = Block( @@ -148,7 +148,11 @@ async def handler(data): """ -async def cli_loop(sk, pk, chain, mempool, network, nonce_counter): +def _is_valid_receiver(receiver): + return bool(re.fullmatch(r"[0-9a-fA-F]{40}|[0-9a-fA-F]{64}", receiver)) + + +async def cli_loop(sk, pk, chain, mempool, network): """Read commands from stdin asynchronously.""" loop = asyncio.get_event_loop() print(HELP_TEXT) @@ -180,18 +184,23 @@ async def cli_loop(sk, pk, chain, mempool, network, nonce_counter): print(" Usage: send ") continue receiver = parts[1] + if not _is_valid_receiver(receiver): + print(" Invalid receiver format. Expected 40 or 64 hex characters.") + continue try: amount = int(parts[2]) except ValueError: print(" Amount must be an integer.") continue + if amount <= 0: + print(" Amount must be greater than 0.") + continue - nonce = nonce_counter[0] + nonce = chain.state.get_account(pk).get("nonce", 0) tx = Transaction(sender=pk, receiver=receiver, amount=amount, nonce=nonce) tx.sign(sk) if mempool.add_transaction(tx): - nonce_counter[0] += 1 await network.broadcast_transaction(tx) print(f" ✅ Tx sent: {amount} coins → {receiver[:12]}...") else: @@ -202,9 +211,6 @@ async def cli_loop(sk, pk, chain, mempool, network, nonce_counter): mined = mine_and_process_block(chain, mempool, pk) if mined: await network.broadcast_block(mined, miner=pk) - # Sync local nonce from chain state - acc = chain.state.get_account(pk) - nonce_counter[0] = acc.get("nonce", 0) # ── peers ── elif cmd == "peers": @@ -289,11 +295,8 @@ async def on_peer_connected(writer): except ValueError: logger.error("Invalid --connect format. Use host:port") - # Nonce counter kept as a mutable list so the CLI closure can mutate it - nonce_counter = [0] - try: - await cli_loop(sk, pk, chain, mempool, network, nonce_counter) + await cli_loop(sk, pk, chain, mempool, network) finally: await network.stop() diff --git a/minichain/p2p.py b/minichain/p2p.py index f4aa97e..88a87f0 100644 --- a/minichain/p2p.py +++ b/minichain/p2p.py @@ -8,6 +8,7 @@ import asyncio import json import logging +import re from .serialization import canonical_json_hash @@ -17,6 +18,10 @@ SUPPORTED_MESSAGE_TYPES = {"sync", "tx", "block"} +def _is_valid_receiver(receiver): + return bool(re.fullmatch(r"[0-9a-fA-F]{40}|[0-9a-fA-F]{64}", receiver)) + + class P2PNetwork: """ Lightweight peer-to-peer networking using asyncio TCP streams. @@ -133,6 +138,13 @@ def _validate_transaction_payload(self, payload): if not isinstance(payload.get(field), expected_type): return False + if payload["amount"] <= 0: + return False + + receiver = payload.get("receiver") + if receiver is not None and not _is_valid_receiver(receiver): + return False + return True def _validate_sync_payload(self, payload): From 38bbde745948149c21f958fac53b62b41be510e5 Mon Sep 17 00:00:00 2001 From: siddhant Date: Sat, 14 Mar 2026 21:15:27 +0530 Subject: [PATCH 3/3] prevent stale tx mining loops --- main.py | 33 ++++++++++++++++++++++++--------- minichain/p2p.py | 8 ++------ minichain/validators.py | 5 +++++ 3 files changed, 31 insertions(+), 15 deletions(-) create mode 100644 minichain/validators.py diff --git a/main.py b/main.py index 246276d..383b5fb 100644 --- a/main.py +++ b/main.py @@ -19,13 +19,13 @@ import argparse import asyncio import logging -import re import sys from nacl.signing import SigningKey from nacl.encoding import HexEncoder from minichain import Transaction, Blockchain, Block, State, Mempool, P2PNetwork, mine_block +from minichain.validators import is_valid_receiver logger = logging.getLogger(__name__) @@ -54,17 +54,36 @@ def mine_and_process_block(chain, mempool, miner_pk): logger.info("Mempool is empty — nothing to mine.") return None + # Filter queue candidates against a temporary state snapshot. + temp_state = chain.state.copy() + mineable_txs = [] + stale_txs = [] + for tx in pending_txs: + expected_nonce = temp_state.get_account(tx.sender).get("nonce", 0) + if tx.nonce < expected_nonce: + stale_txs.append(tx) + continue + if temp_state.validate_and_apply(tx): + mineable_txs.append(tx) + + if stale_txs: + mempool.remove_transactions(stale_txs) + + if not mineable_txs: + logger.info("No mineable transactions in current queue window.") + return None + block = Block( index=chain.last_block.index + 1, previous_hash=chain.last_block.hash, - transactions=pending_txs, + transactions=mineable_txs, ) mined_block = mine_block(block) if chain.add_block(mined_block): - logger.info("✅ Block #%d mined and added (%d txs)", mined_block.index, len(pending_txs)) - mempool.remove_transactions(pending_txs) + logger.info("✅ Block #%d mined and added (%d txs)", mined_block.index, len(mineable_txs)) + mempool.remove_transactions(mineable_txs) chain.state.credit_mining_reward(miner_pk) return mined_block else: @@ -148,10 +167,6 @@ async def handler(data): """ -def _is_valid_receiver(receiver): - return bool(re.fullmatch(r"[0-9a-fA-F]{40}|[0-9a-fA-F]{64}", receiver)) - - async def cli_loop(sk, pk, chain, mempool, network): """Read commands from stdin asynchronously.""" loop = asyncio.get_event_loop() @@ -184,7 +199,7 @@ async def cli_loop(sk, pk, chain, mempool, network): print(" Usage: send ") continue receiver = parts[1] - if not _is_valid_receiver(receiver): + if not is_valid_receiver(receiver): print(" Invalid receiver format. Expected 40 or 64 hex characters.") continue try: diff --git a/minichain/p2p.py b/minichain/p2p.py index 88a87f0..c11d897 100644 --- a/minichain/p2p.py +++ b/minichain/p2p.py @@ -8,9 +8,9 @@ import asyncio import json import logging -import re from .serialization import canonical_json_hash +from .validators import is_valid_receiver logger = logging.getLogger(__name__) @@ -18,10 +18,6 @@ SUPPORTED_MESSAGE_TYPES = {"sync", "tx", "block"} -def _is_valid_receiver(receiver): - return bool(re.fullmatch(r"[0-9a-fA-F]{40}|[0-9a-fA-F]{64}", receiver)) - - class P2PNetwork: """ Lightweight peer-to-peer networking using asyncio TCP streams. @@ -142,7 +138,7 @@ def _validate_transaction_payload(self, payload): return False receiver = payload.get("receiver") - if receiver is not None and not _is_valid_receiver(receiver): + if receiver is not None and not is_valid_receiver(receiver): return False return True diff --git a/minichain/validators.py b/minichain/validators.py new file mode 100644 index 0000000..b813df4 --- /dev/null +++ b/minichain/validators.py @@ -0,0 +1,5 @@ +import re + + +def is_valid_receiver(receiver): + return bool(re.fullmatch(r"[0-9a-fA-F]{40}|[0-9a-fA-F]{64}", receiver))