Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 35 additions & 16 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@
import argparse
import asyncio
import logging
import re
import sys

from nacl.signing import SigningKey
from nacl.encoding import HexEncoder

from minichain import Transaction, Blockchain, Block, State, Mempool, P2PNetwork, mine_block
from minichain.validators import is_valid_receiver


logger = logging.getLogger(__name__)
Expand All @@ -49,21 +49,41 @@ def create_wallet():

def mine_and_process_block(chain, mempool, miner_pk):
"""Mine pending transactions into a new block."""
pending_txs = mempool.get_transactions_for_block(chain.state)
pending_txs = mempool.get_transactions_for_block()
if not pending_txs:
logger.info("Mempool is empty — nothing to mine.")
return None

# Filter queue candidates against a temporary state snapshot.
temp_state = chain.state.copy()
mineable_txs = []
stale_txs = []
for tx in pending_txs:
expected_nonce = temp_state.get_account(tx.sender).get("nonce", 0)
if tx.nonce < expected_nonce:
stale_txs.append(tx)
continue
if temp_state.validate_and_apply(tx):
mineable_txs.append(tx)

if stale_txs:
mempool.remove_transactions(stale_txs)

if not mineable_txs:
logger.info("No mineable transactions in current queue window.")
return None

block = Block(
index=chain.last_block.index + 1,
previous_hash=chain.last_block.hash,
transactions=pending_txs,
transactions=mineable_txs,
)

mined_block = mine_block(block)

if chain.add_block(mined_block):
logger.info("✅ Block #%d mined and added (%d txs)", mined_block.index, len(pending_txs))
logger.info("✅ Block #%d mined and added (%d txs)", mined_block.index, len(mineable_txs))
mempool.remove_transactions(mineable_txs)
chain.state.credit_mining_reward(miner_pk)
return mined_block
else:
Expand Down Expand Up @@ -97,8 +117,8 @@ async def handler(data):
logger.info("📥 Received tx from %s... (amount=%s)", tx.sender[:8], tx.amount)

elif msg_type == "block":
txs_raw = payload.pop("transactions", [])
block_hash = payload.pop("hash", None)
txs_raw = payload.get("transactions", [])
block_hash = payload.get("hash")
transactions = [Transaction(**t) for t in txs_raw]

block = Block(
Expand Down Expand Up @@ -147,7 +167,7 @@ async def handler(data):
"""


async def cli_loop(sk, pk, chain, mempool, network, nonce_counter):
async def cli_loop(sk, pk, chain, mempool, network):
"""Read commands from stdin asynchronously."""
loop = asyncio.get_event_loop()
print(HELP_TEXT)
Expand Down Expand Up @@ -179,18 +199,23 @@ async def cli_loop(sk, pk, chain, mempool, network, nonce_counter):
print(" Usage: send <receiver_address> <amount>")
continue
receiver = parts[1]
if not is_valid_receiver(receiver):
print(" Invalid receiver format. Expected 40 or 64 hex characters.")
continue
try:
amount = int(parts[2])
except ValueError:
print(" Amount must be an integer.")
continue
if amount <= 0:
print(" Amount must be greater than 0.")
continue

nonce = nonce_counter[0]
nonce = chain.state.get_account(pk).get("nonce", 0)
tx = Transaction(sender=pk, receiver=receiver, amount=amount, nonce=nonce)
tx.sign(sk)

if mempool.add_transaction(tx):
nonce_counter[0] += 1
await network.broadcast_transaction(tx)
print(f" ✅ Tx sent: {amount} coins → {receiver[:12]}...")
else:
Expand All @@ -201,9 +226,6 @@ async def cli_loop(sk, pk, chain, mempool, network, nonce_counter):
mined = mine_and_process_block(chain, mempool, pk)
if mined:
await network.broadcast_block(mined, miner=pk)
# Sync local nonce from chain state
acc = chain.state.get_account(pk)
nonce_counter[0] = acc.get("nonce", 0)

# ── peers ──
elif cmd == "peers":
Expand Down Expand Up @@ -288,11 +310,8 @@ async def on_peer_connected(writer):
except ValueError:
logger.error("Invalid --connect format. Use host:port")

# Nonce counter kept as a mutable list so the CLI closure can mutate it
nonce_counter = [0]

try:
await cli_loop(sk, pk, chain, mempool, network, nonce_counter)
await cli_loop(sk, pk, chain, mempool, network)
finally:
await network.stop()

Expand Down
90 changes: 36 additions & 54 deletions minichain/mempool.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,28 @@
from collections import defaultdict
import logging
import threading

logger = logging.getLogger(__name__)


class Mempool:
def __init__(self, max_size=1000):
self._pending_by_sender = defaultdict(dict)
TRANSACTIONS_PER_BLOCK = 100

def __init__(self, max_size=1000, transactions_per_block=TRANSACTIONS_PER_BLOCK):
self._pending_txs = []
self._seen_tx_ids = set()
self._lock = threading.Lock()
self.max_size = max_size
self.transactions_per_block = transactions_per_block

def _get_tx_id(self, tx):
return tx.tx_id

def _count_transactions_unlocked(self):
return sum(len(sender_queue) for sender_queue in self._pending_by_sender.values())

def _expected_nonce_for_sender(self, sender, state):
if state is not None:
return state.get_account(sender)["nonce"]

sender_queue = self._pending_by_sender.get(sender, {})
if not sender_queue:
return 0
return min(sender_queue)

def add_transaction(self, tx):
"""
Adds a transaction to the pool if:
- Signature is valid
- Transaction is not a duplicate
- Sender nonce is not already present in the pool
- Mempool is not full
"""
tx_id = self._get_tx_id(tx)

Expand All @@ -45,58 +35,50 @@ def add_transaction(self, tx):
logger.warning("Mempool: Duplicate transaction rejected %s", tx_id)
return False

sender_queue = self._pending_by_sender[tx.sender]
if tx.nonce in sender_queue:
logger.warning(
"Mempool: Duplicate sender nonce rejected sender=%s nonce=%s",
tx.sender[:8],
tx.nonce,
)
return False
replacement_index = None
for index, pending_tx in enumerate(self._pending_txs):
if pending_tx.sender == tx.sender and pending_tx.nonce == tx.nonce:
replacement_index = index
break

if self._count_transactions_unlocked() >= self.max_size:
if replacement_index is None and len(self._pending_txs) >= self.max_size:
logger.warning("Mempool: Full, rejecting transaction")
return False

sender_queue[tx.nonce] = tx
if replacement_index is not None:
old_tx = self._pending_txs[replacement_index]
self._seen_tx_ids.discard(self._get_tx_id(old_tx))
self._pending_txs[replacement_index] = tx
else:
self._pending_txs.append(tx)

self._seen_tx_ids.add(tx_id)
return True

def get_transactions_for_block(self, state=None):
def get_transactions_for_block(self):
"""
Returns ready transactions only.

Transactions for the same sender are included in nonce order starting
from the sender's current account nonce. Later nonces stay queued until
earlier ones are confirmed.
Returns transactions in deterministic sorted queue order.
This is read-only; transactions are removed only after block acceptance.
"""
with self._lock:
selected = []

for sender, sender_queue in self._pending_by_sender.items():
expected_nonce = self._expected_nonce_for_sender(sender, state)
while expected_nonce in sender_queue:
selected.append(sender_queue[expected_nonce])
expected_nonce += 1

selected = list(self._pending_txs)
selected.sort(key=lambda tx: (tx.timestamp, tx.sender, tx.nonce))
self._remove_transactions_unlocked(selected)
return selected
return selected[: self.transactions_per_block]

def remove_transactions(self, transactions):
with self._lock:
self._remove_transactions_unlocked(transactions)

def _remove_transactions_unlocked(self, transactions):
for tx in transactions:
tx_id = self._get_tx_id(tx)
sender_queue = self._pending_by_sender.get(tx.sender)
if sender_queue and tx.nonce in sender_queue:
del sender_queue[tx.nonce]
if not sender_queue:
del self._pending_by_sender[tx.sender]
self._seen_tx_ids.discard(tx_id)
remove_ids = {self._get_tx_id(tx) for tx in transactions}
remove_sender_nonces = {(tx.sender, tx.nonce) for tx in transactions}
if not remove_ids:
return
self._pending_txs = [
tx
for tx in self._pending_txs
if self._get_tx_id(tx) not in remove_ids
and (tx.sender, tx.nonce) not in remove_sender_nonces
]
self._seen_tx_ids = {self._get_tx_id(tx) for tx in self._pending_txs}

def __len__(self):
with self._lock:
return self._count_transactions_unlocked()
return len(self._pending_txs)
8 changes: 8 additions & 0 deletions minichain/p2p.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import logging

from .serialization import canonical_json_hash
from .validators import is_valid_receiver

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -133,6 +134,13 @@ def _validate_transaction_payload(self, payload):
if not isinstance(payload.get(field), expected_type):
return False

if payload["amount"] <= 0:
return False

receiver = payload.get("receiver")
if receiver is not None and not is_valid_receiver(receiver):
return False

return True

def _validate_sync_payload(self, payload):
Expand Down
5 changes: 5 additions & 0 deletions minichain/validators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import re


def is_valid_receiver(receiver):
return bool(re.fullmatch(r"[0-9a-fA-F]{40}|[0-9a-fA-F]{64}", receiver))
60 changes: 33 additions & 27 deletions tests/test_protocol_hardening.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@ def test_block_hash_matches_compute_hash(self):
self.assertEqual(block.compute_hash(), calculate_hash(block.to_header_dict()))


class TestMempoolNonceQueues(unittest.TestCase):
class TestMempoolQueue(unittest.TestCase):
def setUp(self):
self.state = State()
self.sender_sk = SigningKey.generate()
self.sender_pk = self.sender_sk.verify_key.encode(encoder=HexEncoder).decode()
self.receiver_pk = SigningKey.generate().verify_key.encode(encoder=HexEncoder).decode()
self.state.credit_mining_reward(self.sender_pk, 100)

def _signed_tx(self, nonce, amount=1, timestamp=None):
def _signed_tx(self, nonce, amount=1, timestamp=None) -> Transaction:
tx = Transaction(
sender=self.sender_pk,
receiver=self.receiver_pk,
Expand All @@ -42,38 +42,31 @@ def _signed_tx(self, nonce, amount=1, timestamp=None):
tx.sign(self.sender_sk)
return tx

def test_ready_transactions_preserve_sender_nonce_order(self):
def test_transactions_for_block_are_sorted_and_capped(self):
mempool = Mempool()
late_tx = self._signed_tx(1, timestamp=2000)
early_tx = self._signed_tx(0, timestamp=1000)
for nonce in range(mempool.transactions_per_block + 5):
self.assertTrue(mempool.add_transaction(self._signed_tx(nonce, timestamp=5000 - nonce)))

self.assertTrue(mempool.add_transaction(late_tx))
self.assertTrue(mempool.add_transaction(early_tx))
selected = mempool.get_transactions_for_block()

selected = mempool.get_transactions_for_block(self.state)

self.assertEqual([tx.nonce for tx in selected], [0, 1])
self.assertEqual(len(mempool), 0)
self.assertEqual(len(selected), mempool.transactions_per_block)
self.assertEqual(len(mempool), mempool.transactions_per_block + 5)
self.assertEqual(
[tx.timestamp for tx in selected],
sorted(tx.timestamp for tx in selected),
)

def test_gap_transactions_stay_waiting(self):
def test_same_nonce_replaces_pending_transaction(self):
mempool = Mempool()
ready_tx = self._signed_tx(0, timestamp=1000)
waiting_tx = self._signed_tx(2, timestamp=3000)

self.assertTrue(mempool.add_transaction(ready_tx))
self.assertTrue(mempool.add_transaction(waiting_tx))

selected = mempool.get_transactions_for_block(self.state)

self.assertEqual([tx.nonce for tx in selected], [0])
self.assertEqual(len(mempool), 1)
original_tx = self._signed_tx(0, amount=1, timestamp=1000)
replacement_tx = self._signed_tx(0, amount=2, timestamp=2000)

self.state.apply_transaction(ready_tx)
middle_tx = self._signed_tx(1, timestamp=2000)
self.assertTrue(mempool.add_transaction(middle_tx))
self.assertTrue(mempool.add_transaction(original_tx))
self.assertTrue(mempool.add_transaction(replacement_tx))

selected = mempool.get_transactions_for_block(self.state)
self.assertEqual([tx.nonce for tx in selected], [1, 2])
selected = mempool.get_transactions_for_block()
self.assertEqual(len(selected), 1)
self.assertEqual(selected[0].amount, 2)

def test_remove_transactions_keeps_other_pending(self):
mempool = Mempool()
Expand All @@ -83,8 +76,21 @@ def test_remove_transactions_keeps_other_pending(self):
self.assertTrue(mempool.add_transaction(tx0))
self.assertTrue(mempool.add_transaction(tx1))
mempool.remove_transactions([tx0])
selected = mempool.get_transactions_for_block()

self.assertEqual(len(mempool), 1)
self.assertEqual(len(selected), 1)
self.assertEqual(selected[0].tx_id, tx1.tx_id)

def test_remove_transactions_by_sender_nonce_when_tx_id_differs(self):
mempool = Mempool()
local_tx = self._signed_tx(0, amount=1, timestamp=1000)
remote_confirmed_tx = self._signed_tx(0, amount=2, timestamp=2000)

self.assertTrue(mempool.add_transaction(local_tx))
mempool.remove_transactions([remote_confirmed_tx])

self.assertEqual(len(mempool), 0)


class TestP2PValidationAndDedup(unittest.IsolatedAsyncioTestCase):
Expand Down