diff --git a/electrum/lnchannel.py b/electrum/lnchannel.py index 85df4844a2f8..ba8b8750ac82 100644 --- a/electrum/lnchannel.py +++ b/electrum/lnchannel.py @@ -53,7 +53,7 @@ ShortChannelID, map_htlcs_to_ctx_output_idxs, LNPeerAddr, fee_for_htlc_output, offered_htlc_trim_threshold_sat, received_htlc_trim_threshold_sat, make_commitment_output_to_remote_address, - ChannelType) + ChannelType, LNProtocolWarning) from .lnsweep import create_sweeptxs_for_our_ctx, create_sweeptxs_for_their_ctx from .lnsweep import create_sweeptx_for_their_revoked_htlc, SweepInfo from .lnhtlc import HTLCManager @@ -981,7 +981,9 @@ def receive_new_commitment(self, sig: bytes, htlc_sigs: Sequence[bytes]) -> None preimage_hex = pending_local_commitment.serialize_preimage(0) pre_hash = sha256d(bfh(preimage_hex)) if not ecc.verify_signature(self.config[REMOTE].multisig_key.pubkey, sig, pre_hash): - raise Exception(f'failed verifying signature of our updated commitment transaction: {bh2u(sig)} preimage is {preimage_hex}') + raise LNProtocolWarning( + f'failed verifying signature of our updated commitment transaction: ' + f'{bh2u(sig)} preimage is {preimage_hex}, rawtx: {pending_local_commitment.serialize()}') htlc_sigs_string = b''.join(htlc_sigs) @@ -993,7 +995,7 @@ def receive_new_commitment(self, sig: bytes, htlc_sigs: Sequence[bytes]) -> None subject=LOCAL, ctn=next_local_ctn) if len(htlc_to_ctx_output_idx_map) != len(htlc_sigs): - raise Exception(f'htlc sigs failure. recv {len(htlc_sigs)} sigs, expected {len(htlc_to_ctx_output_idx_map)}') + raise LNProtocolWarning(f'htlc sigs failure. recv {len(htlc_sigs)} sigs, expected {len(htlc_to_ctx_output_idx_map)}') for (direction, htlc), (ctx_output_idx, htlc_relative_idx) in htlc_to_ctx_output_idx_map.items(): htlc_sig = htlc_sigs[htlc_relative_idx] self._verify_htlc_sig(htlc=htlc, @@ -1021,7 +1023,7 @@ def _verify_htlc_sig(self, *, htlc: UpdateAddHtlc, htlc_sig: bytes, htlc_directi pre_hash = sha256d(bfh(htlc_tx.serialize_preimage(0))) remote_htlc_pubkey = derive_pubkey(self.config[REMOTE].htlc_basepoint.pubkey, pcp) if not ecc.verify_signature(remote_htlc_pubkey, htlc_sig, pre_hash): - raise Exception(f'failed verifying HTLC signatures: {htlc} {htlc_direction}') + raise LNProtocolWarning(f'failed verifying HTLC signatures: {htlc} {htlc_direction}, rawtx: {htlc_tx.serialize()}') def get_remote_htlc_sig_for_htlc(self, *, htlc_relative_idx: int) -> bytes: data = self.config[LOCAL].current_htlc_signatures diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index a35b72b35d3d..73ded2d1446e 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -42,7 +42,7 @@ LightningPeerConnectionClosed, HandshakeFailed, RemoteMisbehaving, ShortChannelID, IncompatibleLightningFeatures, derive_payment_secret_from_payment_preimage, - UpfrontShutdownScriptViolation, ChannelType) + ChannelType, LNProtocolWarning) from .lnutil import FeeUpdate, channel_id_from_funding_tx from .lntransport import LNTransport, LNTransportBase from .lnmsg import encode_msg, decode_msg, UnknownOptionalMsgType @@ -97,7 +97,7 @@ def __init__( self.reply_channel_range = asyncio.Queue() # gossip uses a single queue to preserve message order self.gossip_queue = asyncio.Queue() - self.ordered_message_queues = defaultdict(asyncio.Queue) # for messsage that are ordered + self.ordered_message_queues = defaultdict(asyncio.Queue) # for messages that are ordered self.temp_id_to_id = {} # to forward error messages self.funding_created_sent = set() # for channels in PREOPENING self.funding_signed_sent = set() # for channels in PREOPENING @@ -204,7 +204,7 @@ def process_message(self, message): chan_id = payload.get('channel_id') or payload["temporary_channel_id"] self.ordered_message_queues[chan_id].put_nowait((message_type, payload)) else: - if message_type != 'error' and 'channel_id' in payload: + if message_type not in ('error', 'warning') and 'channel_id' in payload: chan = self.get_channel_by_id(payload['channel_id']) if chan is None: raise Exception('Got unknown '+ message_type) @@ -223,12 +223,96 @@ def process_message(self, message): if asyncio.iscoroutinefunction(f): asyncio.ensure_future(self.taskgroup.spawn(execution_result)) + def on_warning(self, payload): + # TODO: we could need some reconnection logic here -> delayed reconnect + self.logger.info(f"remote peer sent warning [DO NOT TRUST THIS MESSAGE]: {payload['data'].decode('ascii')}") + channel_id = payload.get("channel_id") + if channel_id == bytes(32): + for cid in self.channels.keys(): + self.ordered_message_queues[cid].put_nowait((None, {'warning': payload['data']})) + raise GracefulDisconnect + warned_channel_id = None + if channel_id in self.temp_id_to_id: + warned_channel_id = self.temp_id_to_id[channel_id] + elif channel_id in self.channels: + warned_channel_id = channel_id + if warned_channel_id: + # MAY disconnect. + self.ordered_message_queues[warned_channel_id].put_nowait((None, {'warning': payload['data']})) + raise GracefulDisconnect + def on_error(self, payload): self.logger.info(f"remote peer sent error [DO NOT TRUST THIS MESSAGE]: {payload['data'].decode('ascii')}") - chan_id = payload.get("channel_id") - if chan_id in self.temp_id_to_id: - chan_id = self.temp_id_to_id[chan_id] - self.ordered_message_queues[chan_id].put_nowait((None, {'error':payload['data']})) + channel_id = payload.get("channel_id") + # if channel_id is all zero: MUST fail all channels with the sending node. + if channel_id == bytes(32): + for cid in self.channels.keys(): + self.schedule_force_closing(cid) + self.ordered_message_queues[cid].put_nowait((None, {'error': payload['data']})) + raise GracefulDisconnect + # otherwise: MUST fail the channel referred to by channel_id, if that channel is with the sending node. + erring_channel_id = None + if channel_id in self.temp_id_to_id: + erring_channel_id = self.temp_id_to_id[channel_id] + elif channel_id in self.channels: + erring_channel_id = channel_id + if erring_channel_id: + self.schedule_force_closing(erring_channel_id) + self.ordered_message_queues[erring_channel_id].put_nowait((None, {'error': payload['data']})) + # disconnect now as there might be no one waiting on the queue... + # OTOH this means if there are waiters, they might not see the error + raise GracefulDisconnect + + async def send_warning(self, channel_id: bytes, message: str = None, *, close_connection=True): + """Sends a warning and disconnects if close_connection. + + Note: + * channel_id is the temporary channel id when the channel id is not yet available + + A sending node: + MAY set channel_id to all zero if the warning is not related to a specific channel. + + when failure was caused by an invalid signature check: + * SHOULD include the raw, hex-encoded transaction in reply to a funding_created, + funding_signed, closing_signed, or commitment_signed message. + """ + assert isinstance(channel_id, bytes) + encoded_data = b'' if not message else message.encode('ascii') + self.send_message('warning', channel_id=channel_id, data=encoded_data, len=len(encoded_data)) + if close_connection: + raise GracefulDisconnect + + async def send_error(self, channel_id: bytes, message: str = None, *, force_close_channel=False): + """Sends an error message and force closes the channel. + + Note: + * channel_id is the temporary channel id when the channel id is not yet available + + A sending node: + * SHOULD send error for protocol violations or internal errors that make channels + unusable or that make further communication unusable. + * SHOULD send error with the unknown channel_id in reply to messages of type + 32-255 related to unknown channels. + * MUST fail the channel(s) referred to by the error message. + * MAY set channel_id to all zero to indicate all channels. + + when failure was caused by an invalid signature check: + * SHOULD include the raw, hex-encoded transaction in reply to a funding_created, + funding_signed, closing_signed, or commitment_signed message. + """ + assert isinstance(channel_id, bytes) + encoded_data = b'' if not message else message.encode('ascii') + self.send_message('error', channel_id=channel_id, data=encoded_data, len=len(encoded_data)) + # MUST fail the channel(s) referred to by the error message: + # we may violate this with force_close_channel + if force_close_channel: + # channel_id of zero means that the error refers to all channels + if channel_id == bytes(32): + for channel_id in self.channels: + self.schedule_force_closing(channel_id) + else: + self.schedule_force_closing(channel_id) + raise GracefulDisconnect def on_ping(self, payload): l = payload['num_pong_bytes'] @@ -241,7 +325,9 @@ async def wait_for_message(self, expected_name, channel_id): q = self.ordered_message_queues[channel_id] name, payload = await asyncio.wait_for(q.get(), LN_P2P_NETWORK_TIMEOUT) if payload.get('error'): - raise Exception('Remote peer reported error [DO NOT TRUST THIS MESSAGE]: ' + repr(payload.get('error'))) + raise GracefulDisconnect(f'Waiting for {expected_name} failed due to an error sent by the peer.') + elif payload.get('warning'): + raise GracefulDisconnect(f'Waiting for {expected_name} failed due to a warning sent by the peer.') if name != expected_name: raise Exception(f"Received unexpected '{name}'") return payload @@ -774,7 +860,10 @@ async def channel_establishment_flow( payload = await self.wait_for_message('funding_signed', channel_id) self.logger.info('received funding_signed') remote_sig = payload['signature'] - chan.receive_new_commitment(remote_sig, []) + try: + chan.receive_new_commitment(remote_sig, []) + except LNProtocolWarning as e: + await self.send_warning(channel_id, message=str(e), close_connection=True) chan.open_with_first_pcp(remote_per_commitment_point, remote_sig) chan.set_state(ChannelState.OPENING) self.lnworker.add_new_channel(chan) @@ -933,7 +1022,10 @@ async def on_open_channel(self, payload): if isinstance(self.transport, LNTransport): chan.add_or_update_peer_addr(self.transport.peer_addr) remote_sig = funding_created['signature'] - chan.receive_new_commitment(remote_sig, []) + try: + chan.receive_new_commitment(remote_sig, []) + except LNProtocolWarning as e: + await self.send_warning(channel_id, message=str(e), close_connection=True) sig_64, _ = chan.sign_next_commitment() self.send_message('funding_signed', channel_id=channel_id, @@ -955,6 +1047,13 @@ async def trigger_force_close(self, channel_id: bytes): your_last_per_commitment_secret=0, my_current_per_commitment_point=latest_point) + def schedule_force_closing(self, channel_id: bytes): + channels_with_peer = list(self.channels.keys()) + channels_with_peer.extend(self.temp_id_to_id.values()) + if channel_id not in channels_with_peer: + raise ValueError(f"channel {channel_id.hex()} does not belong to this peer") + self.lnworker.schedule_force_closing(channel_id) + def on_channel_reestablish(self, chan, msg): their_next_local_ctn = msg["next_commitment_number"] their_oldest_unrevoked_remote_ctn = msg["next_revocation_number"] @@ -1774,12 +1873,18 @@ async def close_channel(self, chan_id: bytes): return txid async def on_shutdown(self, chan: Channel, payload): + # TODO: A receiving node: if it hasn't received a funding_signed (if it is a + # funder) or a funding_created (if it is a fundee): + # SHOULD send an error and fail the channel. their_scriptpubkey = payload['scriptpubkey'] their_upfront_scriptpubkey = chan.config[REMOTE].upfront_shutdown_script # BOLT-02 check if they use the upfront shutdown script they advertized - if their_upfront_scriptpubkey: + if self.is_upfront_shutdown_script() and their_upfront_scriptpubkey: if not (their_scriptpubkey == their_upfront_scriptpubkey): - raise UpfrontShutdownScriptViolation("remote didn't use upfront shutdown script it commited to in channel opening") + await self.send_warning( + chan.channel_id, + "remote didn't use upfront shutdown script it commited to in channel opening", + close_connection=True) else: # BOLT-02 restrict the scriptpubkey to some templates: if self.is_shutdown_anysegwit() and match_script_against_template(their_scriptpubkey, transaction.SCRIPTPUBKEY_TEMPLATE_ANYSEGWIT): @@ -1787,7 +1892,10 @@ async def on_shutdown(self, chan: Channel, payload): elif match_script_against_template(their_scriptpubkey, transaction.SCRIPTPUBKEY_TEMPLATE_WITNESS_V0): pass else: - raise Exception(f'scriptpubkey in received shutdown message does not conform to any template: {their_scriptpubkey.hex()}') + await self.send_warning( + chan.channel_id, + f'scriptpubkey in received shutdown message does not conform to any template: {their_scriptpubkey.hex()}', + close_connection=True) chan_id = chan.channel_id if chan_id in self.shutdown_received: diff --git a/electrum/lnutil.py b/electrum/lnutil.py index 8a4878746dc9..79a1f42f3885 100644 --- a/electrum/lnutil.py +++ b/electrum/lnutil.py @@ -351,7 +351,6 @@ class UnableToDeriveSecret(LightningError): pass class HandshakeFailed(LightningError): pass class ConnStringFormatError(LightningError): pass class RemoteMisbehaving(LightningError): pass -class UpfrontShutdownScriptViolation(RemoteMisbehaving): pass class NotFoundChanAnnouncementForUpdate(Exception): pass class InvalidGossipMsg(Exception): @@ -362,6 +361,16 @@ class NoPathFound(PaymentFailure): def __str__(self): return _('No path found') + +class LNProtocolError(Exception): + """Raised in peer methods to trigger an error message.""" + + +class LNProtocolWarning(Exception): + """Raised in peer methods to trigger a warning message.""" + + + # TODO make some of these values configurable? REDEEM_AFTER_DOUBLE_SPENT_DELAY = 30 diff --git a/electrum/lnwire/peer_wire.csv b/electrum/lnwire/peer_wire.csv index 17b8a103d557..b795f42a1369 100644 --- a/electrum/lnwire/peer_wire.csv +++ b/electrum/lnwire/peer_wire.csv @@ -10,6 +10,10 @@ msgtype,error,17 msgdata,error,channel_id,channel_id, msgdata,error,len,u16, msgdata,error,data,byte,len +msgtype,warning,1 +msgdata,warning,channel_id,channel_id, +msgdata,warning,len,u16, +msgdata,warning,data,byte,len msgtype,ping,18 msgdata,ping,num_pong_bytes,u16, msgdata,ping,byteslen,u16, diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py index 2d5bbb33a94b..89ffc5da9c8a 100644 --- a/electrum/tests/test_lnpeer.py +++ b/electrum/tests/test_lnpeer.py @@ -22,9 +22,8 @@ from electrum.lnaddr import lnencode, LnAddr, lndecode from electrum.bitcoin import COIN, sha256 from electrum.util import bh2u, create_and_start_event_loop, NetworkRetryManager, bfh, OldTaskGroup -from electrum.lnpeer import Peer, UpfrontShutdownScriptViolation +from electrum.lnpeer import Peer from electrum.lnutil import LNPeerAddr, Keypair, privkey_to_pubkey -from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving from electrum.lnutil import PaymentFailure, LnFeatures, HTLCOwner from electrum.lnchannel import ChannelState, PeerState, Channel from electrum.lnrouter import LNPathFinder, PathEdge, LNPathInconsistent @@ -38,6 +37,7 @@ from electrum.lnutil import derive_payment_secret_from_payment_preimage from electrum.lnutil import LOCAL, REMOTE from electrum.invoices import PR_PAID, PR_UNPAID +from electrum.interface import GracefulDisconnect from .test_lnchannel import create_test_channels from .test_bitcoin import needs_test_with_all_chacha20_implementations @@ -1096,6 +1096,38 @@ async def f(): with self.assertRaises(concurrent.futures.CancelledError): run(f()) + @needs_test_with_all_chacha20_implementations + def test_warning(self): + alice_channel, bob_channel = create_test_channels() + p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) + + async def action(): + await asyncio.wait_for(p1.initialized, 1) + await asyncio.wait_for(p2.initialized, 1) + await p1.send_warning(alice_channel.channel_id, 'be warned!', close_connection=True) + gath = asyncio.gather(action(), p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch()) + async def f(): + await gath + with self.assertRaises(GracefulDisconnect): + run(f()) + + @needs_test_with_all_chacha20_implementations + def test_error(self): + alice_channel, bob_channel = create_test_channels() + p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) + + async def action(): + await asyncio.wait_for(p1.initialized, 1) + await asyncio.wait_for(p2.initialized, 1) + await p1.send_error(alice_channel.channel_id, 'some error happened!', force_close_channel=True) + assert alice_channel.is_closed() + gath.cancel() + gath = asyncio.gather(action(), p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch()) + async def f(): + await gath + with self.assertRaises(GracefulDisconnect): + run(f()) + @needs_test_with_all_chacha20_implementations def test_close_upfront_shutdown_script(self): alice_channel, bob_channel = create_test_channels() @@ -1135,7 +1167,7 @@ async def main_loop(peer): gath = asyncio.gather(*coros) await gath - with self.assertRaises(UpfrontShutdownScriptViolation): + with self.assertRaises(GracefulDisconnect): run(test()) # bob sends the same upfront_shutdown_script has he announced