diff --git a/main.py b/main.py index 12577dd..c3776ea 100644 --- a/main.py +++ b/main.py @@ -30,7 +30,7 @@ logger = logging.getLogger(__name__) -BURN_ADDRESS = "0" * 40 +MAX_TRANSACTIONS_PER_BLOCK = 100 # ────────────────────────────────────────────── @@ -49,7 +49,9 @@ def create_wallet(): def mine_and_process_block(chain, mempool, miner_pk): """Mine pending transactions into a new block.""" - pending_txs = mempool.get_transactions_for_block() + pending_txs = mempool.get_transactions_for_block( + chain.state, max_count=MAX_TRANSACTIONS_PER_BLOCK + ) if not pending_txs: logger.info("Mempool is empty — nothing to mine.") return None @@ -58,11 +60,13 @@ def mine_and_process_block(chain, mempool, miner_pk): index=chain.last_block.index + 1, previous_hash=chain.last_block.hash, transactions=pending_txs, + miner=miner_pk, ) mined_block = mine_block(block) if chain.add_block(mined_block): + mempool.remove_transactions(pending_txs) logger.info("✅ Block #%d mined and added (%d txs)", mined_block.index, len(pending_txs)) chain.state.credit_mining_reward(miner_pk) return mined_block @@ -83,6 +87,9 @@ async def handler(data): payload = data.get("data") if msg_type == "sync": + if not isinstance(payload, dict): + logger.warning("Received malformed sync payload") + return False # Merge remote state into local state (for accounts we don't have yet) remote_accounts = payload.get("accounts", {}) for addr, acc in remote_accounts.items(): @@ -90,38 +97,56 @@ async def handler(data): chain.state.accounts[addr] = acc logger.info("🔄 Synced account %s... (balance=%d)", addr[:12], acc.get("balance", 0)) logger.info("🔄 State sync complete — %d accounts", len(chain.state.accounts)) + return True elif msg_type == "tx": + if not isinstance(payload, dict): + logger.warning("Received malformed tx payload") + return False tx = Transaction(**payload) if mempool.add_transaction(tx): logger.info("📥 Received tx from %s... (amount=%s)", tx.sender[:8], tx.amount) + return True + return False elif msg_type == "block": - txs_raw = payload.pop("transactions", []) - block_hash = payload.pop("hash", None) + if not isinstance(payload, dict): + logger.warning("Received malformed block payload") + return False + payload_data = dict(payload) + + txs_raw = payload_data.get("transactions", []) + block_hash = payload_data.get("hash") transactions = [Transaction(**t) for t in txs_raw] block = Block( - index=payload["index"], - previous_hash=payload["previous_hash"], + index=payload_data["index"], + previous_hash=payload_data["previous_hash"], transactions=transactions, - timestamp=payload.get("timestamp"), - difficulty=payload.get("difficulty"), + timestamp=payload_data.get("timestamp"), + difficulty=payload_data.get("difficulty"), + miner=payload_data.get("miner"), ) - block.nonce = payload.get("nonce", 0) + block.nonce = payload_data.get("nonce", 0) block.hash = block_hash if chain.add_block(block): logger.info("📥 Received Block #%d — added to chain", block.index) - # Apply mining reward for the remote miner (burn address as placeholder) - miner = payload.get("miner", BURN_ADDRESS) - chain.state.credit_mining_reward(miner) + # Reward only when miner is authenticated as part of hashed block data. + if block.miner: + chain.state.credit_mining_reward(block.miner) + else: + logger.warning("Received block without authenticated miner; reward not credited") - # Drain matching txs from mempool so they aren't re-mined - mempool.get_transactions_for_block() + # Drop only confirmed transactions so higher nonces can remain queued. + mempool.remove_transactions(block.transactions) + return True else: logger.warning("📥 Received Block #%s — rejected", block.index) + return False + + return False return handler @@ -147,7 +172,7 @@ async def handler(data): """ -async def cli_loop(sk, pk, chain, mempool, network, nonce_counter): +async def cli_loop(sk, pk, chain, mempool, network): """Read commands from stdin asynchronously.""" loop = asyncio.get_event_loop() print(HELP_TEXT) @@ -179,18 +204,23 @@ async def cli_loop(sk, pk, chain, mempool, network, nonce_counter): print(" Usage: send ") continue receiver = parts[1] + if len(receiver) != len(pk) or not re.fullmatch(r"[0-9a-fA-F]+", receiver): + print(" Receiver address must be a valid hex public key.") + continue try: amount = int(parts[2]) except ValueError: print(" Amount must be an integer.") continue + if amount <= 0: + print(" Amount must be greater than 0.") + continue - nonce = nonce_counter[0] + nonce = chain.state.get_account(pk)["nonce"] tx = Transaction(sender=pk, receiver=receiver, amount=amount, nonce=nonce) tx.sign(sk) if mempool.add_transaction(tx): - nonce_counter[0] += 1 await network.broadcast_transaction(tx) print(f" ✅ Tx sent: {amount} coins → {receiver[:12]}...") else: @@ -201,9 +231,6 @@ async def cli_loop(sk, pk, chain, mempool, network, nonce_counter): mined = mine_and_process_block(chain, mempool, pk) if mined: await network.broadcast_block(mined) - # Sync local nonce from chain state - acc = chain.state.get_account(pk) - nonce_counter[0] = acc.get("nonce", 0) # ── peers ── elif cmd == "peers": @@ -288,11 +315,8 @@ async def on_peer_connected(writer): except ValueError: logger.error("Invalid --connect format. Use host:port") - # Nonce counter kept as a mutable list so the CLI closure can mutate it - nonce_counter = [0] - try: - await cli_loop(sk, pk, chain, mempool, network, nonce_counter) + await cli_loop(sk, pk, chain, mempool, network) finally: await network.stop() diff --git a/minichain/block.py b/minichain/block.py index 859a32a..d7b6e9c 100644 --- a/minichain/block.py +++ b/minichain/block.py @@ -3,6 +3,7 @@ import json from typing import List, Optional from .transaction import Transaction +from .serialization import canonical_json_hash def _sha256(data: str) -> str: return hashlib.sha256(data.encode()).hexdigest() @@ -12,11 +13,8 @@ def _calculate_merkle_root(transactions: List[Transaction]) -> Optional[str]: if not transactions: return None - # Hash each transaction deterministically - tx_hashes = [ - _sha256(json.dumps(tx.to_dict(), sort_keys=True)) - for tx in transactions - ] + # Keep legacy leaf format for compatibility with existing blocks. + tx_hashes = [_transaction_leaf(tx) for tx in transactions] # Build Merkle tree while len(tx_hashes) > 1: @@ -33,6 +31,26 @@ def _calculate_merkle_root(transactions: List[Transaction]) -> Optional[str]: return tx_hashes[0] +def _transaction_leaf(tx: Transaction) -> str: + """Return a deterministic transaction leaf hash with compatibility fallback.""" + # Prefer an explicit legacy-compatible leaf method if present. + if hasattr(tx, "get_leaf_digest") and callable(getattr(tx, "get_leaf_digest")): + value = tx.get_leaf_digest() + if isinstance(value, str): + return value + if hasattr(tx, "digest"): + value = getattr(tx, "digest") + if isinstance(value, str): + return value + + # Legacy default used in prior versions. + if hasattr(tx, "to_dict") and callable(getattr(tx, "to_dict")): + return _sha256(json.dumps(tx.to_dict(), sort_keys=True)) + + # Final fallback for newer transaction shapes. + return tx.tx_id + + class Block: def __init__( self, @@ -41,6 +59,7 @@ def __init__( transactions: Optional[List[Transaction]] = None, timestamp: Optional[float] = None, difficulty: Optional[int] = None, + miner: Optional[str] = None, ): self.index = index self.previous_hash = previous_hash @@ -54,6 +73,7 @@ def __init__( ) self.difficulty: Optional[int] = difficulty + self.miner: Optional[str] = miner self.nonce: int = 0 self.hash: Optional[str] = None @@ -64,7 +84,7 @@ def __init__( # HEADER (used for mining) # ------------------------- def to_header_dict(self): - return { + header = { "index": self.index, "previous_hash": self.previous_hash, "merkle_root": self.merkle_root, @@ -72,6 +92,10 @@ def to_header_dict(self): "difficulty": self.difficulty, "nonce": self.nonce, } + # Include miner only when present so old-format headers stay valid. + if self.miner is not None: + header["miner"] = self.miner + return header # ------------------------- # BODY (transactions only) @@ -97,8 +121,4 @@ def to_dict(self): # HASH CALCULATION # ------------------------- def compute_hash(self) -> str: - header_string = json.dumps( - self.to_header_dict(), - sort_keys=True - ) - return _sha256(header_string) + return canonical_json_hash(self.to_header_dict()) diff --git a/minichain/contract.py b/minichain/contract.py index c12fddd..d286b7d 100644 --- a/minichain/contract.py +++ b/minichain/contract.py @@ -1,10 +1,16 @@ import logging import multiprocessing import ast +import os import json # Moved to module-level import logger = logging.getLogger(__name__) + +def _allow_unrestricted_contracts() -> bool: + value = os.getenv("MINICHAIN_ALLOW_UNRESTRICTED_CONTRACTS", "") + return value.lower() in {"1", "true", "yes", "on"} + def _safe_exec_worker(code, globals_dict, context_dict, result_queue): """ Worker function to execute contract code in a separate process. @@ -19,8 +25,16 @@ def _safe_exec_worker(code, globals_dict, context_dict, result_queue): except ImportError: logger.warning("Resource module not available. Contract will run without OS-level resource limits.") except (OSError, ValueError) as e: - logger.error(f"Failed to set resource limits: {e}") - raise RuntimeError(f"Failed to set resource limits: {e}") + if _allow_unrestricted_contracts(): + logger.warning( + "Failed to set resource limits but unsafe mode is enabled: %s", + e, + ) + else: + raise RuntimeError( + "Failed to set contract resource limits; refusing to execute " + "without explicit MINICHAIN_ALLOW_UNRESTRICTED_CONTRACTS=1" + ) from e exec(code, globals_dict, context_dict) # Return the updated storage diff --git a/minichain/mempool.py b/minichain/mempool.py index 06a60d0..7e98e7f 100644 --- a/minichain/mempool.py +++ b/minichain/mempool.py @@ -1,30 +1,34 @@ -from .pow import calculate_hash +from collections import defaultdict import logging import threading logger = logging.getLogger(__name__) + class Mempool: def __init__(self, max_size=1000): - self._pending_txs = [] - self._seen_tx_ids = set() # Dedup tracking + self._pending_by_sender = defaultdict(dict) + self._seen_tx_ids = set() self._lock = threading.Lock() self.max_size = max_size def _get_tx_id(self, tx): - """ - Compute a unique deterministic ID for a transaction. - Uses full serialized tx (payload + signature). - """ - return calculate_hash(tx.to_dict()) + return tx.tx_id + + def _count_transactions_unlocked(self): + return sum(len(sender_queue) for sender_queue in self._pending_by_sender.values()) + + def _expected_nonce_for_sender(self, sender: str, state) -> int: + """Return the sender nonce from required chain state.""" + return state.get_account(sender)["nonce"] def add_transaction(self, tx): """ Adds a transaction to the pool if: - Signature is valid - Transaction is not a duplicate + - Sender nonce is not already present in the pool """ - tx_id = self._get_tx_id(tx) if not tx.verify(): @@ -33,30 +37,64 @@ def add_transaction(self, tx): with self._lock: if tx_id in self._seen_tx_ids: - logger.warning(f"Mempool: Duplicate transaction rejected {tx_id}") + logger.warning("Mempool: Duplicate transaction rejected %s", tx_id) + return False + + sender_queue = self._pending_by_sender[tx.sender] + if tx.nonce in sender_queue: + logger.warning( + "Mempool: Duplicate sender nonce rejected sender=%s nonce=%s", + tx.sender[:8], + tx.nonce, + ) return False - if len(self._pending_txs) >= self.max_size: - # Simple eviction: drop oldest or reject. Here we reject. + if self._count_transactions_unlocked() >= self.max_size: logger.warning("Mempool: Full, rejecting transaction") return False - self._pending_txs.append(tx) + sender_queue[tx.nonce] = tx self._seen_tx_ids.add(tx_id) - return True - def get_transactions_for_block(self): + def get_transactions_for_block(self, state, max_count=None): """ - Returns pending transactions and clears the pool. + Returns ready transactions only. + + Transactions for the same sender are included in nonce order starting + from the sender's current account nonce. Later nonces stay queued until + earlier ones are confirmed. """ + with self._lock: + selected = [] + for sender, sender_queue in self._pending_by_sender.items(): + expected_nonce = self._expected_nonce_for_sender(sender, state) + while expected_nonce in sender_queue: + selected.append(sender_queue[expected_nonce]) + expected_nonce += 1 + + selected.sort(key=lambda tx: (tx.timestamp, tx.sender, tx.nonce)) + if max_count is not None: + if not isinstance(max_count, int) or max_count <= 0: + return [] + selected = selected[:max_count] + return selected + + def remove_transactions(self, transactions): with self._lock: - txs = self._pending_txs[:] + self._remove_transactions_unlocked(transactions) - # Clear both list and dedup set to stay in sync - self._pending_txs = [] - confirmed_ids = {self._get_tx_id(tx) for tx in txs} - self._seen_tx_ids.difference_update(confirmed_ids) + def _remove_transactions_unlocked(self, transactions): + for tx in transactions: + tx_id = self._get_tx_id(tx) + sender_queue = self._pending_by_sender.get(tx.sender) + if sender_queue and tx.nonce in sender_queue: + del sender_queue[tx.nonce] + if not sender_queue: + del self._pending_by_sender[tx.sender] + self._seen_tx_ids.discard(tx_id) - return txs + def __len__(self): + with self._lock: + return self._count_transactions_unlocked() diff --git a/minichain/p2p.py b/minichain/p2p.py index 81ff100..9fc874d 100644 --- a/minichain/p2p.py +++ b/minichain/p2p.py @@ -9,9 +9,12 @@ import json import logging +from .serialization import canonical_json_hash + logger = logging.getLogger(__name__) TOPIC = "minichain-global" +SUPPORTED_MESSAGE_TYPES = {"sync", "tx", "block"} class P2PNetwork: @@ -19,7 +22,7 @@ class P2PNetwork: Lightweight peer-to-peer networking using asyncio TCP streams. JSON wire format (one JSON object per line): - {"type": "tx" | "block", "data": {...}} + {"type": "sync" | "tx" | "block", "data": {...}} """ def __init__(self, handler_callback=None): @@ -30,17 +33,15 @@ def __init__(self, handler_callback=None): self._server: asyncio.Server | None = None self._port: int = 0 self._listen_tasks: list[asyncio.Task] = [] - self._on_peer_connected = None # callback(writer) called when a new peer connects + self._on_peer_connected = None + self._seen_tx_ids = set() + self._seen_block_hashes = set() def register_handler(self, handler_callback): if not callable(handler_callback): raise ValueError("handler_callback must be callable") self._handler_callback = handler_callback - # ------------------------------------------------------------------ - # Server lifecycle - # ------------------------------------------------------------------ - async def start(self, port: int = 9000): """Start listening for incoming peer connections on the given port.""" self._port = port @@ -66,24 +67,31 @@ async def stop(self): await self._server.wait_closed() self._server = None - # ------------------------------------------------------------------ - # Peer connections - # ------------------------------------------------------------------ - async def connect_to_peer(self, host: str, port: int) -> bool: """Actively connect to another MiniChain node.""" try: reader, writer = await asyncio.open_connection(host, port) self._peers.append((reader, writer)) - task = asyncio.create_task(self._listen_to_peer(reader, writer, f"{host}:{port}")) + task = asyncio.create_task( + self._listen_to_peer(reader, writer, f"{host}:{port}") + ) self._listen_tasks.append(task) + if self._on_peer_connected: + try: + await self._on_peer_connected(writer) + except Exception: + logger.exception("Network: Error during outbound peer sync") logger.info("Network: Connected to peer %s:%d", host, port) return True - except Exception as e: - logger.error("Network: Failed to connect to %s:%d — %s", host, port, e) + except Exception as exc: + logger.error("Network: Failed to connect to %s:%d — %s", host, port, exc) return False - async def _handle_incoming(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter): + async def _handle_incoming( + self, + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter, + ): """Accept an incoming peer connection.""" peername = writer.get_extra_info("peername") addr = f"{peername[0]}:{peername[1]}" if peername else "unknown" @@ -91,14 +99,164 @@ async def _handle_incoming(self, reader: asyncio.StreamReader, writer: asyncio.S self._peers.append((reader, writer)) task = asyncio.create_task(self._listen_to_peer(reader, writer, addr)) self._listen_tasks.append(task) - # Send current state to the new peer if self._on_peer_connected: try: await self._on_peer_connected(writer) except Exception: logger.exception("Network: Error during peer sync") - async def _listen_to_peer(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, addr: str): + def _validate_transaction_payload(self, payload): + if not isinstance(payload, dict): + return False + + required_fields = { + "sender": str, + "amount": int, + "nonce": int, + "timestamp": int, + "signature": str, + } + optional_fields = { + "receiver": (str, type(None)), + "data": (str, type(None)), + } + allowed_fields = set(required_fields) | set(optional_fields) + + if set(payload) != allowed_fields: + return False + + for field, expected_type in required_fields.items(): + if not isinstance(payload.get(field), expected_type): + return False + + for field, expected_type in optional_fields.items(): + if not isinstance(payload.get(field), expected_type): + return False + + if payload["amount"] <= 0: + return False + + receiver = payload.get("receiver") + if receiver is not None: + if not isinstance(receiver, str): + return False + if not receiver: + return False + if len(receiver) != 64: + return False + if any(ch not in "0123456789abcdefABCDEF" for ch in receiver): + return False + + return True + + def _validate_sync_payload(self, payload): + if not isinstance(payload, dict) or set(payload) != {"accounts"}: + return False + + accounts = payload["accounts"] + if not isinstance(accounts, dict): + return False + + for address, account in accounts.items(): + if not isinstance(address, str) or not isinstance(account, dict): + return False + required = {"balance", "nonce", "code", "storage"} + if set(account) != required: + return False + if not isinstance(account["balance"], int): + return False + if not isinstance(account["nonce"], int): + return False + if not isinstance(account["code"], (str, type(None))): + return False + if not isinstance(account["storage"], dict): + return False + + return True + + def _validate_block_payload(self, payload): + if not isinstance(payload, dict): + return False + + required_fields = { + "index": int, + "previous_hash": str, + "merkle_root": (str, type(None)), + "transactions": list, + "timestamp": int, + "difficulty": (int, type(None)), + "nonce": int, + "hash": str, + } + optional_fields = {"miner": str} + allowed_fields = set(required_fields) | set(optional_fields) + + if not set(payload).issubset(allowed_fields): + return False + + for field, expected_type in required_fields.items(): + if not isinstance(payload.get(field), expected_type): + return False + + if "miner" in payload and not isinstance(payload["miner"], str): + return False + + return all( + self._validate_transaction_payload(tx_payload) + for tx_payload in payload["transactions"] + ) + + def _validate_message(self, message): + if not isinstance(message, dict): + return False + if set(message) != {"type", "data"}: + return False + + msg_type = message.get("type") + payload = message.get("data") + + if msg_type not in SUPPORTED_MESSAGE_TYPES: + return False + + validators = { + "sync": self._validate_sync_payload, + "tx": self._validate_transaction_payload, + "block": self._validate_block_payload, + } + return validators[msg_type](payload) + + def _message_id(self, msg_type, payload): + if msg_type == "tx": + return canonical_json_hash(payload) + if msg_type == "block": + return payload["hash"] + return None + + def _mark_seen(self, msg_type, payload): + message_id = self._message_id(msg_type, payload) + if message_id is None: + return + if msg_type == "tx": + self._seen_tx_ids.add(message_id) + elif msg_type == "block": + self._seen_block_hashes.add(message_id) + + def _is_duplicate(self, msg_type, payload): + message_id = self._message_id(msg_type, payload) + if message_id is None: + return False + if msg_type == "tx": + return message_id in self._seen_tx_ids + if msg_type == "block": + return message_id in self._seen_block_hashes + return False + + async def _listen_to_peer( + self, + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter, + addr: str, + ): """Read newline-delimited JSON messages from a peer.""" try: while True: @@ -111,11 +269,28 @@ async def _listen_to_peer(self, reader: asyncio.StreamReader, writer: asyncio.St logger.warning("Network: Malformed message from %s", addr) continue + if not self._validate_message(data): + logger.warning("Network: Invalid message schema from %s", addr) + continue + + msg_type = data["type"] + payload = data["data"] + if self._is_duplicate(msg_type, payload): + logger.info("Network: Duplicate %s ignored from %s", msg_type, addr) + continue + + accepted = False if self._handler_callback: try: - await self._handler_callback(data) + accepted = bool(await self._handler_callback(data)) except Exception: - logger.exception("Network: Handler error for message from %s", addr) + logger.exception( + "Network: Handler error for message from %s", addr + ) + accepted = False + + if accepted: + self._mark_seen(msg_type, payload) except asyncio.CancelledError: pass except ConnectionResetError: @@ -130,10 +305,6 @@ async def _listen_to_peer(self, reader: asyncio.StreamReader, writer: asyncio.St if (reader, writer) in self._peers: self._peers.remove((reader, writer)) - # ------------------------------------------------------------------ - # Broadcasting - # ------------------------------------------------------------------ - async def _broadcast_raw(self, payload: dict): """Send a JSON message to every connected peer.""" line = (json.dumps(payload) + "\n").encode() @@ -153,14 +324,17 @@ async def broadcast_transaction(self, tx): logger.info("Network: Broadcasting Tx from %s...", sender[:8]) try: payload = {"type": "tx", "data": tx.to_dict()} - except (TypeError, ValueError) as e: - logger.error("Network: Failed to serialize tx: %s", e) + except (TypeError, ValueError) as exc: + logger.error("Network: Failed to serialize tx: %s", exc) return + self._mark_seen("tx", payload["data"]) await self._broadcast_raw(payload) async def broadcast_block(self, block): logger.info("Network: Broadcasting Block #%d", block.index) - await self._broadcast_raw({"type": "block", "data": block.to_dict()}) + payload = {"type": "block", "data": block.to_dict()} + self._mark_seen("block", payload["data"]) + await self._broadcast_raw(payload) @property def peer_count(self) -> int: diff --git a/minichain/pow.py b/minichain/pow.py index b8484b1..40503a5 100644 --- a/minichain/pow.py +++ b/minichain/pow.py @@ -1,6 +1,5 @@ -import json import time -import hashlib +from .serialization import canonical_json_hash class MiningExceededError(Exception): @@ -9,8 +8,7 @@ class MiningExceededError(Exception): def calculate_hash(block_dict): """Calculates SHA256 hash of a block header.""" - block_string = json.dumps(block_dict, sort_keys=True).encode("utf-8") - return hashlib.sha256(block_string).hexdigest() + return canonical_json_hash(block_dict) def mine_block( diff --git a/minichain/serialization.py b/minichain/serialization.py new file mode 100644 index 0000000..46741f0 --- /dev/null +++ b/minichain/serialization.py @@ -0,0 +1,15 @@ +import hashlib +import json + + +def canonical_json_dumps(payload) -> str: + """Serialize payloads deterministically for signing and hashing.""" + return json.dumps(payload, sort_keys=True, separators=(",", ":"), ensure_ascii=False) + + +def canonical_json_bytes(payload) -> bytes: + return canonical_json_dumps(payload).encode("utf-8") + + +def canonical_json_hash(payload) -> str: + return hashlib.sha256(canonical_json_bytes(payload)).hexdigest() diff --git a/minichain/transaction.py b/minichain/transaction.py index 8625f7a..c17137e 100644 --- a/minichain/transaction.py +++ b/minichain/transaction.py @@ -1,8 +1,8 @@ -import json import time from nacl.signing import SigningKey, VerifyKey from nacl.encoding import HexEncoder from nacl.exceptions import BadSignatureError, CryptoError +from .serialization import canonical_json_bytes, canonical_json_hash class Transaction: @@ -31,18 +31,25 @@ def to_dict(self): "signature": self.signature, } - @property - def hash_payload(self): - """Returns the bytes to be signed.""" - payload = { + def to_signing_dict(self): + return { "sender": self.sender, "receiver": self.receiver, "amount": self.amount, "nonce": self.nonce, "data": self.data, - "timestamp": self.timestamp, # Already integer milliseconds + "timestamp": self.timestamp, } - return json.dumps(payload, sort_keys=True).encode("utf-8") + + @property + def hash_payload(self): + """Returns the bytes to be signed.""" + return canonical_json_bytes(self.to_signing_dict()) + + @property + def tx_id(self): + """Deterministic identifier for the signed transaction.""" + return canonical_json_hash(self.to_dict()) def sign(self, signing_key: SigningKey): # Validate that the signing key matches the sender diff --git a/tests/test_protocol_hardening.py b/tests/test_protocol_hardening.py new file mode 100644 index 0000000..96bd9d2 --- /dev/null +++ b/tests/test_protocol_hardening.py @@ -0,0 +1,181 @@ +import unittest + +from nacl.encoding import HexEncoder +from nacl.signing import SigningKey + +from minichain import Block, Mempool, P2PNetwork, State, Transaction, calculate_hash +from minichain.serialization import canonical_json_dumps + + +class TestDeterministicConsensus(unittest.TestCase): + def test_canonical_json_is_order_independent(self): + left = {"b": 2, "a": 1, "nested": {"z": 3, "x": 4}} + right = {"nested": {"x": 4, "z": 3}, "a": 1, "b": 2} + + self.assertEqual(canonical_json_dumps(left), canonical_json_dumps(right)) + self.assertEqual(calculate_hash(left), calculate_hash(right)) + + def test_block_hash_matches_compute_hash(self): + block = Block(index=1, previous_hash="abc", transactions=[], timestamp=1234567890) + block.difficulty = 2 + block.nonce = 7 + + self.assertEqual(block.compute_hash(), calculate_hash(block.to_header_dict())) + + +class TestMempoolNonceQueues(unittest.TestCase): + def setUp(self): + self.state = State() + self.sender_sk = SigningKey.generate() + self.sender_pk = self.sender_sk.verify_key.encode(encoder=HexEncoder).decode() + self.receiver_pk = SigningKey.generate().verify_key.encode(encoder=HexEncoder).decode() + self.state.credit_mining_reward(self.sender_pk, 100) + + def _signed_tx(self, nonce, amount=1, timestamp=None) -> Transaction: + tx = Transaction( + sender=self.sender_pk, + receiver=self.receiver_pk, + amount=amount, + nonce=nonce, + timestamp=timestamp, + ) + tx.sign(self.sender_sk) + return tx + + def test_ready_transactions_preserve_sender_nonce_order(self): + mempool = Mempool() + late_tx = self._signed_tx(1, timestamp=2000) + early_tx = self._signed_tx(0, timestamp=1000) + + self.assertTrue(mempool.add_transaction(late_tx)) + self.assertTrue(mempool.add_transaction(early_tx)) + + selected = mempool.get_transactions_for_block(self.state) + + self.assertEqual([tx.nonce for tx in selected], [0, 1]) + self.assertEqual(len(mempool), 2) + mempool.remove_transactions(selected) + self.assertEqual(len(mempool), 0) + + def test_gap_transactions_stay_waiting(self): + mempool = Mempool() + ready_tx = self._signed_tx(0, timestamp=1000) + waiting_tx = self._signed_tx(2, timestamp=3000) + + self.assertTrue(mempool.add_transaction(ready_tx)) + self.assertTrue(mempool.add_transaction(waiting_tx)) + + selected = mempool.get_transactions_for_block(self.state) + + self.assertEqual([tx.nonce for tx in selected], [0]) + mempool.remove_transactions(selected) + self.assertEqual(len(mempool), 1) + + self.state.apply_transaction(ready_tx) + middle_tx = self._signed_tx(1, timestamp=2000) + self.assertTrue(mempool.add_transaction(middle_tx)) + + selected = mempool.get_transactions_for_block(self.state) + self.assertEqual([tx.nonce for tx in selected], [1, 2]) + + def test_remove_transactions_keeps_other_pending(self): + mempool = Mempool() + tx0 = self._signed_tx(0, timestamp=1000) + tx1 = self._signed_tx(1, timestamp=2000) + + self.assertTrue(mempool.add_transaction(tx0)) + self.assertTrue(mempool.add_transaction(tx1)) + mempool.remove_transactions([tx0]) + + self.assertEqual(len(mempool), 1) + self.assertFalse(mempool.add_transaction(self._signed_tx(1, timestamp=3000))) + + self.assertEqual(mempool.get_transactions_for_block(self.state), []) + self.state.apply_transaction(tx0) + selected = mempool.get_transactions_for_block(self.state) + self.assertEqual([tx.nonce for tx in selected], [1]) + self.assertEqual(selected[0].tx_id, tx1.tx_id) + + def test_selection_respects_per_block_cap(self): + mempool = Mempool() + txs = [self._signed_tx(nonce=i, timestamp=1000 + i) for i in range(4)] + for tx in txs: + self.assertTrue(mempool.add_transaction(tx)) + + cap = 2 + selected = mempool.get_transactions_for_block(self.state, max_count=cap) + self.assertEqual(len(selected), cap) + self.assertEqual([tx.nonce for tx in selected], [0, 1]) + self.assertEqual(len(mempool), 4) + + mempool.remove_transactions(selected) + self.assertEqual(len(mempool), 2) + self.state.apply_transaction(txs[0]) + self.state.apply_transaction(txs[1]) + + remaining = mempool.get_transactions_for_block(self.state) + self.assertEqual([tx.nonce for tx in remaining], [2, 3]) + + +class TestP2PValidationAndDedup(unittest.IsolatedAsyncioTestCase): + async def test_invalid_message_schema_is_rejected(self): + network = P2PNetwork() + + invalid_message = {"type": "tx", "data": {"sender": "abc"}} + self.assertFalse(network._validate_message(invalid_message)) + + async def test_block_schema_accepts_current_block_wire_format(self): + sender_sk = SigningKey.generate() + sender_pk = sender_sk.verify_key.encode(encoder=HexEncoder).decode() + receiver_pk = SigningKey.generate().verify_key.encode(encoder=HexEncoder).decode() + + tx = Transaction(sender_pk, receiver_pk, 1, 0, timestamp=123) + tx.sign(sender_sk) + + block = Block(index=1, previous_hash="0" * 64, transactions=[tx], timestamp=456, difficulty=2) + block.nonce = 9 + block.hash = block.compute_hash() + + network = P2PNetwork() + message = {"type": "block", "data": block.to_dict()} + + self.assertTrue(network._validate_message(message)) + + async def test_duplicate_tx_and_block_detection(self): + network = P2PNetwork() + + tx_message = { + "type": "tx", + "data": { + "sender": "a" * 64, + "receiver": "b" * 64, + "amount": 1, + "nonce": 0, + "data": None, + "timestamp": 123, + "signature": "c" * 128, + }, + } + block_message = { + "type": "block", + "data": { + "index": 1, + "previous_hash": "0" * 64, + "transactions": [tx_message["data"]], + "timestamp": 123, + "difficulty": 2, + "nonce": 1, + "hash": "f" * 64, + }, + } + + self.assertFalse(network._is_duplicate("tx", tx_message["data"])) + network._mark_seen("tx", tx_message["data"]) + tx_equivalent = dict(tx_message["data"]) + self.assertTrue(network._is_duplicate("tx", tx_equivalent)) + + self.assertFalse(network._is_duplicate("block", block_message["data"])) + network._mark_seen("block", block_message["data"]) + block_equivalent = dict(block_message["data"]) + block_equivalent["transactions"] = [dict(tx_message["data"])] + self.assertTrue(network._is_duplicate("block", block_equivalent))