diff --git a/main.py b/main.py index 17d3fbd..0b21881 100644 --- a/main.py +++ b/main.py @@ -33,6 +33,8 @@ logger = logging.getLogger(__name__) BURN_ADDRESS = "0" * 40 +TRUSTED_PEERS = set() +LOCALHOST_PEERS = {"127.0.0.1", "::1", "localhost", "0:0:0:0:0:0:0:1"} # ────────────────────────────────────────────── @@ -90,6 +92,11 @@ def mine_and_process_block(chain, mempool, miner_pk): return mined_block else: logger.error("❌ Block rejected by chain") + restored = 0 + for tx in pending_txs: + if mempool.add_transaction(tx): + restored += 1 + logger.info("Mempool: Restored %d/%d txs after rejection", restored, len(pending_txs)) return None @@ -103,15 +110,31 @@ def make_network_handler(chain, mempool): async def handler(data): msg_type = data.get("type") payload = data.get("data") + peer_addr = data.get("_peer_addr", "unknown") if msg_type == "sync": + peer_host = peer_addr.rsplit(":", 1)[0] if ":" in peer_addr else peer_addr + peer_host = peer_host.strip("[]") + is_trusted = peer_addr in TRUSTED_PEERS or peer_host in TRUSTED_PEERS + is_localhost = peer_host in LOCALHOST_PEERS + if chain.state.accounts and not (is_trusted or is_localhost): + logger.warning("🔒 Rejected sync from untrusted peer %s", peer_addr) + return + # Merge remote state into local state (for accounts we don't have yet) - remote_accounts = payload.get("accounts", {}) + remote_accounts = payload.get("accounts") if isinstance(payload, dict) else None + if not isinstance(remote_accounts, dict): + logger.warning("🔒 Rejected sync from %s with invalid accounts payload", peer_addr) + return + for addr, acc in remote_accounts.items(): + if not isinstance(acc, dict): + logger.warning("🔒 Skipping malformed account %r from %s", addr, peer_addr) + continue if addr not in chain.state.accounts: chain.state.accounts[addr] = acc logger.info("🔄 Synced account %s... (balance=%d)", addr[:12], acc.get("balance", 0)) - logger.info("🔄 State sync complete — %d accounts", len(chain.state.accounts)) + logger.info("🔄 Accepted state sync from %s — %d accounts", peer_addr, len(chain.state.accounts)) elif msg_type == "tx": tx = Transaction(**payload) @@ -156,15 +179,15 @@ async def handler(data): ╔════════════════════════════════════════════════╗ ║ MiniChain Commands ║ ╠════════════════════════════════════════════════╣ -║ balance — show all balances ║ -║ send — send coins ║ -║ mine — mine a block ║ -║ peers — show connected peers ║ -║ connect — connect to a peer ║ -║ address — show your public key ║ -║ chain — show chain summary ║ -║ help — show this help ║ -║ quit — shut down ║ +║ balance - show all balances ║ +║ send - send coins ║ +║ mine - mine a block ║ +║ peers - show connected peers ║ +║ connect : - connect to a peer ║ +║ address - show your public key ║ +║ chain - show chain summary ║ +║ help - show this help ║ +║ quit - shut down ║ ╚════════════════════════════════════════════════╝ """ @@ -244,7 +267,11 @@ async def cli_loop(sk, pk, chain, mempool, network): except ValueError: print(" Invalid format. Use host:port") continue - await network.connect_to_peer(host, port) + success = await network.connect_to_peer(host, port) + if success: + print(f" Connected to {host}:{port}") + else: + print(f" Failed to connect to {host}:{port}") # ── address ── elif cmd == "address": @@ -311,14 +338,9 @@ async def on_peer_connected(writer): await writer.drain() logger.info("🔄 Sent state sync to new peer") - network._on_peer_connected = on_peer_connected + network.set_on_peer_connected(on_peer_connected) - await network.start(port=port) - - # Fund this node's wallet so it can transact in the demo - if fund > 0: - chain.state.credit_mining_reward(pk, reward=fund) - logger.info("💰 Funded %s... with %d coins", pk[:12], fund) + await network.start(port=port, host=host) # Connect to a seed peer if requested if connect_to: @@ -328,6 +350,14 @@ async def on_peer_connected(writer): except ValueError: logger.error("Invalid --connect format. Use host:port") + # Fund this node's wallet so it can transact in the demo + if fund > 0: + chain.state.credit_mining_reward(pk, reward=fund) + logger.info("💰 Funded %s... with %d coins", pk[:12], fund) + + # Nonce counter kept as a mutable list so the CLI closure can mutate it + nonce_counter = [0] + try: await cli_loop(sk, pk, chain, mempool, network) finally: @@ -344,6 +374,7 @@ async def on_peer_connected(writer): def main(): parser = argparse.ArgumentParser(description="MiniChain Node — Testnet Demo") + parser.add_argument("--host", type=str, default="127.0.0.1", help="Host/IP to bind the P2P server (default: 127.0.0.1)") parser.add_argument("--port", type=int, default=9000, help="TCP port to listen on (default: 9000)") parser.add_argument("--connect", type=str, default=None, help="Peer address to connect to (host:port)") parser.add_argument("--fund", type=int, default=100, help="Initial coins to fund this wallet (default: 100)") diff --git a/minichain/p2p.py b/minichain/p2p.py index c11d897..ee52d7d 100644 --- a/minichain/p2p.py +++ b/minichain/p2p.py @@ -47,15 +47,18 @@ async def start(self, port: int = 9000): """Start listening for incoming peer connections on the given port.""" self._port = port self._server = await asyncio.start_server( - self._handle_incoming, "0.0.0.0", port + self._handle_incoming, host, port ) - logger.info("Network: Listening on 0.0.0.0:%d", port) + logger.info("Network: Listening on %s:%d", host, port) async def stop(self): """Gracefully shut down the server and disconnect all peers.""" logger.info("Network: Shutting down") for task in self._listen_tasks: task.cancel() + if self._listen_tasks: + await asyncio.gather(*self._listen_tasks, return_exceptions=True) + self._listen_tasks.clear() for _, writer in self._peers: try: writer.close() @@ -262,6 +265,8 @@ async def _listen_to_peer( except (json.JSONDecodeError, UnicodeDecodeError): logger.warning("Network: Malformed message from %s", addr) continue + if isinstance(data, dict): + data["_peer_addr"] = addr if not self._validate_message(data): logger.warning("Network: Invalid message schema from %s", addr) @@ -305,7 +310,13 @@ async def _broadcast_raw(self, payload: dict): await writer.drain() except Exception: disconnected.append((reader, writer)) - for pair in disconnected: + for reader, writer in disconnected: + try: + writer.close() + await writer.wait_closed() + except Exception: + pass + pair = (reader, writer) if pair in self._peers: self._peers.remove(pair)