Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions electrum/lnchannel.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
ShortChannelID, map_htlcs_to_ctx_output_idxs, LNPeerAddr,
fee_for_htlc_output, offered_htlc_trim_threshold_sat,
received_htlc_trim_threshold_sat, make_commitment_output_to_remote_address,
ChannelType)
ChannelType, LNProtocolWarning)
from .lnsweep import create_sweeptxs_for_our_ctx, create_sweeptxs_for_their_ctx
from .lnsweep import create_sweeptx_for_their_revoked_htlc, SweepInfo
from .lnhtlc import HTLCManager
Expand Down Expand Up @@ -981,7 +981,9 @@ def receive_new_commitment(self, sig: bytes, htlc_sigs: Sequence[bytes]) -> None
preimage_hex = pending_local_commitment.serialize_preimage(0)
pre_hash = sha256d(bfh(preimage_hex))
if not ecc.verify_signature(self.config[REMOTE].multisig_key.pubkey, sig, pre_hash):
raise Exception(f'failed verifying signature of our updated commitment transaction: {bh2u(sig)} preimage is {preimage_hex}')
raise LNProtocolWarning(
f'failed verifying signature of our updated commitment transaction: '
f'{bh2u(sig)} preimage is {preimage_hex}, rawtx: {pending_local_commitment.serialize()}')

htlc_sigs_string = b''.join(htlc_sigs)

Expand All @@ -993,7 +995,7 @@ def receive_new_commitment(self, sig: bytes, htlc_sigs: Sequence[bytes]) -> None
subject=LOCAL,
ctn=next_local_ctn)
if len(htlc_to_ctx_output_idx_map) != len(htlc_sigs):
raise Exception(f'htlc sigs failure. recv {len(htlc_sigs)} sigs, expected {len(htlc_to_ctx_output_idx_map)}')
raise LNProtocolWarning(f'htlc sigs failure. recv {len(htlc_sigs)} sigs, expected {len(htlc_to_ctx_output_idx_map)}')
for (direction, htlc), (ctx_output_idx, htlc_relative_idx) in htlc_to_ctx_output_idx_map.items():
htlc_sig = htlc_sigs[htlc_relative_idx]
self._verify_htlc_sig(htlc=htlc,
Expand Down Expand Up @@ -1021,7 +1023,7 @@ def _verify_htlc_sig(self, *, htlc: UpdateAddHtlc, htlc_sig: bytes, htlc_directi
pre_hash = sha256d(bfh(htlc_tx.serialize_preimage(0)))
remote_htlc_pubkey = derive_pubkey(self.config[REMOTE].htlc_basepoint.pubkey, pcp)
if not ecc.verify_signature(remote_htlc_pubkey, htlc_sig, pre_hash):
raise Exception(f'failed verifying HTLC signatures: {htlc} {htlc_direction}')
raise LNProtocolWarning(f'failed verifying HTLC signatures: {htlc} {htlc_direction}, rawtx: {htlc_tx.serialize()}')

def get_remote_htlc_sig_for_htlc(self, *, htlc_relative_idx: int) -> bytes:
data = self.config[LOCAL].current_htlc_signatures
Expand Down
134 changes: 121 additions & 13 deletions electrum/lnpeer.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
LightningPeerConnectionClosed, HandshakeFailed,
RemoteMisbehaving, ShortChannelID,
IncompatibleLightningFeatures, derive_payment_secret_from_payment_preimage,
UpfrontShutdownScriptViolation, ChannelType)
ChannelType, LNProtocolWarning)
from .lnutil import FeeUpdate, channel_id_from_funding_tx
from .lntransport import LNTransport, LNTransportBase
from .lnmsg import encode_msg, decode_msg, UnknownOptionalMsgType
Expand Down Expand Up @@ -97,7 +97,7 @@ def __init__(
self.reply_channel_range = asyncio.Queue()
# gossip uses a single queue to preserve message order
self.gossip_queue = asyncio.Queue()
self.ordered_message_queues = defaultdict(asyncio.Queue) # for messsage that are ordered
self.ordered_message_queues = defaultdict(asyncio.Queue) # for messages that are ordered
self.temp_id_to_id = {} # to forward error messages
self.funding_created_sent = set() # for channels in PREOPENING
self.funding_signed_sent = set() # for channels in PREOPENING
Expand Down Expand Up @@ -204,7 +204,7 @@ def process_message(self, message):
chan_id = payload.get('channel_id') or payload["temporary_channel_id"]
self.ordered_message_queues[chan_id].put_nowait((message_type, payload))
else:
if message_type != 'error' and 'channel_id' in payload:
if message_type not in ('error', 'warning') and 'channel_id' in payload:
chan = self.get_channel_by_id(payload['channel_id'])
if chan is None:
raise Exception('Got unknown '+ message_type)
Expand All @@ -223,12 +223,96 @@ def process_message(self, message):
if asyncio.iscoroutinefunction(f):
asyncio.ensure_future(self.taskgroup.spawn(execution_result))

def on_warning(self, payload):
# TODO: we could need some reconnection logic here -> delayed reconnect
self.logger.info(f"remote peer sent warning [DO NOT TRUST THIS MESSAGE]: {payload['data'].decode('ascii')}")
channel_id = payload.get("channel_id")
if channel_id == bytes(32):
for cid in self.channels.keys():
self.ordered_message_queues[cid].put_nowait((None, {'warning': payload['data']}))
raise GracefulDisconnect
warned_channel_id = None
if channel_id in self.temp_id_to_id:
warned_channel_id = self.temp_id_to_id[channel_id]
elif channel_id in self.channels:
warned_channel_id = channel_id
if warned_channel_id:
# MAY disconnect.
self.ordered_message_queues[warned_channel_id].put_nowait((None, {'warning': payload['data']}))
raise GracefulDisconnect

def on_error(self, payload):
self.logger.info(f"remote peer sent error [DO NOT TRUST THIS MESSAGE]: {payload['data'].decode('ascii')}")
chan_id = payload.get("channel_id")
if chan_id in self.temp_id_to_id:
chan_id = self.temp_id_to_id[chan_id]
self.ordered_message_queues[chan_id].put_nowait((None, {'error':payload['data']}))
channel_id = payload.get("channel_id")
# if channel_id is all zero: MUST fail all channels with the sending node.
if channel_id == bytes(32):
for cid in self.channels.keys():
self.schedule_force_closing(cid)
self.ordered_message_queues[cid].put_nowait((None, {'error': payload['data']}))
raise GracefulDisconnect
# otherwise: MUST fail the channel referred to by channel_id, if that channel is with the sending node.
erring_channel_id = None
if channel_id in self.temp_id_to_id:
erring_channel_id = self.temp_id_to_id[channel_id]
elif channel_id in self.channels:
erring_channel_id = channel_id
if erring_channel_id:
self.schedule_force_closing(erring_channel_id)
self.ordered_message_queues[erring_channel_id].put_nowait((None, {'error': payload['data']}))
Comment thread
bitromortac marked this conversation as resolved.
# disconnect now as there might be no one waiting on the queue...
# OTOH this means if there are waiters, they might not see the error
raise GracefulDisconnect
Comment thread
bitromortac marked this conversation as resolved.

async def send_warning(self, channel_id: bytes, message: str = None, *, close_connection=True):
"""Sends a warning and disconnects if close_connection.

Note:
* channel_id is the temporary channel id when the channel id is not yet available

A sending node:
MAY set channel_id to all zero if the warning is not related to a specific channel.

when failure was caused by an invalid signature check:
* SHOULD include the raw, hex-encoded transaction in reply to a funding_created,
funding_signed, closing_signed, or commitment_signed message.
"""
assert isinstance(channel_id, bytes)
encoded_data = b'' if not message else message.encode('ascii')
self.send_message('warning', channel_id=channel_id, data=encoded_data, len=len(encoded_data))
if close_connection:
raise GracefulDisconnect

async def send_error(self, channel_id: bytes, message: str = None, *, force_close_channel=False):
"""Sends an error message and force closes the channel.

Note:
* channel_id is the temporary channel id when the channel id is not yet available

A sending node:
* SHOULD send error for protocol violations or internal errors that make channels
unusable or that make further communication unusable.
* SHOULD send error with the unknown channel_id in reply to messages of type
32-255 related to unknown channels.
* MUST fail the channel(s) referred to by the error message.
* MAY set channel_id to all zero to indicate all channels.

when failure was caused by an invalid signature check:
* SHOULD include the raw, hex-encoded transaction in reply to a funding_created,
funding_signed, closing_signed, or commitment_signed message.
"""
assert isinstance(channel_id, bytes)
encoded_data = b'' if not message else message.encode('ascii')
self.send_message('error', channel_id=channel_id, data=encoded_data, len=len(encoded_data))
# MUST fail the channel(s) referred to by the error message:
# we may violate this with force_close_channel
if force_close_channel:
# channel_id of zero means that the error refers to all channels
if channel_id == bytes(32):
for channel_id in self.channels:
self.schedule_force_closing(channel_id)
else:
self.schedule_force_closing(channel_id)
raise GracefulDisconnect

def on_ping(self, payload):
l = payload['num_pong_bytes']
Expand All @@ -241,7 +325,9 @@ async def wait_for_message(self, expected_name, channel_id):
q = self.ordered_message_queues[channel_id]
name, payload = await asyncio.wait_for(q.get(), LN_P2P_NETWORK_TIMEOUT)
if payload.get('error'):
raise Exception('Remote peer reported error [DO NOT TRUST THIS MESSAGE]: ' + repr(payload.get('error')))
raise GracefulDisconnect(f'Waiting for {expected_name} failed due to an error sent by the peer.')
elif payload.get('warning'):
raise GracefulDisconnect(f'Waiting for {expected_name} failed due to a warning sent by the peer.')
if name != expected_name:
raise Exception(f"Received unexpected '{name}'")
return payload
Expand Down Expand Up @@ -774,7 +860,10 @@ async def channel_establishment_flow(
payload = await self.wait_for_message('funding_signed', channel_id)
self.logger.info('received funding_signed')
remote_sig = payload['signature']
chan.receive_new_commitment(remote_sig, [])
try:
chan.receive_new_commitment(remote_sig, [])
except LNProtocolWarning as e:
await self.send_warning(channel_id, message=str(e), close_connection=True)
chan.open_with_first_pcp(remote_per_commitment_point, remote_sig)
chan.set_state(ChannelState.OPENING)
self.lnworker.add_new_channel(chan)
Expand Down Expand Up @@ -933,7 +1022,10 @@ async def on_open_channel(self, payload):
if isinstance(self.transport, LNTransport):
chan.add_or_update_peer_addr(self.transport.peer_addr)
remote_sig = funding_created['signature']
chan.receive_new_commitment(remote_sig, [])
try:
chan.receive_new_commitment(remote_sig, [])
except LNProtocolWarning as e:
await self.send_warning(channel_id, message=str(e), close_connection=True)
sig_64, _ = chan.sign_next_commitment()
self.send_message('funding_signed',
channel_id=channel_id,
Expand All @@ -955,6 +1047,13 @@ async def trigger_force_close(self, channel_id: bytes):
your_last_per_commitment_secret=0,
my_current_per_commitment_point=latest_point)

def schedule_force_closing(self, channel_id: bytes):
channels_with_peer = list(self.channels.keys())
channels_with_peer.extend(self.temp_id_to_id.values())
if channel_id not in channels_with_peer:
raise ValueError(f"channel {channel_id.hex()} does not belong to this peer")
self.lnworker.schedule_force_closing(channel_id)

def on_channel_reestablish(self, chan, msg):
their_next_local_ctn = msg["next_commitment_number"]
their_oldest_unrevoked_remote_ctn = msg["next_revocation_number"]
Expand Down Expand Up @@ -1774,20 +1873,29 @@ async def close_channel(self, chan_id: bytes):
return txid

async def on_shutdown(self, chan: Channel, payload):
# TODO: A receiving node: if it hasn't received a funding_signed (if it is a
# funder) or a funding_created (if it is a fundee):
# SHOULD send an error and fail the channel.
their_scriptpubkey = payload['scriptpubkey']
their_upfront_scriptpubkey = chan.config[REMOTE].upfront_shutdown_script
# BOLT-02 check if they use the upfront shutdown script they advertized
if their_upfront_scriptpubkey:
if self.is_upfront_shutdown_script() and their_upfront_scriptpubkey:
if not (their_scriptpubkey == their_upfront_scriptpubkey):
raise UpfrontShutdownScriptViolation("remote didn't use upfront shutdown script it commited to in channel opening")
await self.send_warning(
chan.channel_id,
"remote didn't use upfront shutdown script it commited to in channel opening",
close_connection=True)
else:
# BOLT-02 restrict the scriptpubkey to some templates:
if self.is_shutdown_anysegwit() and match_script_against_template(their_scriptpubkey, transaction.SCRIPTPUBKEY_TEMPLATE_ANYSEGWIT):
pass
elif match_script_against_template(their_scriptpubkey, transaction.SCRIPTPUBKEY_TEMPLATE_WITNESS_V0):
pass
else:
raise Exception(f'scriptpubkey in received shutdown message does not conform to any template: {their_scriptpubkey.hex()}')
await self.send_warning(
chan.channel_id,
f'scriptpubkey in received shutdown message does not conform to any template: {their_scriptpubkey.hex()}',
close_connection=True)

chan_id = chan.channel_id
if chan_id in self.shutdown_received:
Expand Down
11 changes: 10 additions & 1 deletion electrum/lnutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,6 @@ class UnableToDeriveSecret(LightningError): pass
class HandshakeFailed(LightningError): pass
class ConnStringFormatError(LightningError): pass
class RemoteMisbehaving(LightningError): pass
class UpfrontShutdownScriptViolation(RemoteMisbehaving): pass

class NotFoundChanAnnouncementForUpdate(Exception): pass
class InvalidGossipMsg(Exception):
Expand All @@ -362,6 +361,16 @@ class NoPathFound(PaymentFailure):
def __str__(self):
return _('No path found')


class LNProtocolError(Exception):
"""Raised in peer methods to trigger an error message."""


class LNProtocolWarning(Exception):
"""Raised in peer methods to trigger a warning message."""



# TODO make some of these values configurable?
REDEEM_AFTER_DOUBLE_SPENT_DELAY = 30

Expand Down
4 changes: 4 additions & 0 deletions electrum/lnwire/peer_wire.csv
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ msgtype,error,17
msgdata,error,channel_id,channel_id,
msgdata,error,len,u16,
msgdata,error,data,byte,len
msgtype,warning,1
msgdata,warning,channel_id,channel_id,
msgdata,warning,len,u16,
msgdata,warning,data,byte,len
msgtype,ping,18
msgdata,ping,num_pong_bytes,u16,
msgdata,ping,byteslen,u16,
Expand Down
38 changes: 35 additions & 3 deletions electrum/tests/test_lnpeer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,8 @@
from electrum.lnaddr import lnencode, LnAddr, lndecode
from electrum.bitcoin import COIN, sha256
from electrum.util import bh2u, create_and_start_event_loop, NetworkRetryManager, bfh, OldTaskGroup
from electrum.lnpeer import Peer, UpfrontShutdownScriptViolation
from electrum.lnpeer import Peer
from electrum.lnutil import LNPeerAddr, Keypair, privkey_to_pubkey
from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving
from electrum.lnutil import PaymentFailure, LnFeatures, HTLCOwner
from electrum.lnchannel import ChannelState, PeerState, Channel
from electrum.lnrouter import LNPathFinder, PathEdge, LNPathInconsistent
Expand All @@ -38,6 +37,7 @@
from electrum.lnutil import derive_payment_secret_from_payment_preimage
from electrum.lnutil import LOCAL, REMOTE
from electrum.invoices import PR_PAID, PR_UNPAID
from electrum.interface import GracefulDisconnect

from .test_lnchannel import create_test_channels
from .test_bitcoin import needs_test_with_all_chacha20_implementations
Expand Down Expand Up @@ -1096,6 +1096,38 @@ async def f():
with self.assertRaises(concurrent.futures.CancelledError):
run(f())

@needs_test_with_all_chacha20_implementations
def test_warning(self):
alice_channel, bob_channel = create_test_channels()
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)

async def action():
await asyncio.wait_for(p1.initialized, 1)
await asyncio.wait_for(p2.initialized, 1)
await p1.send_warning(alice_channel.channel_id, 'be warned!', close_connection=True)
gath = asyncio.gather(action(), p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch())
async def f():
await gath
with self.assertRaises(GracefulDisconnect):
run(f())

@needs_test_with_all_chacha20_implementations
def test_error(self):
alice_channel, bob_channel = create_test_channels()
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)

async def action():
await asyncio.wait_for(p1.initialized, 1)
await asyncio.wait_for(p2.initialized, 1)
await p1.send_error(alice_channel.channel_id, 'some error happened!', force_close_channel=True)
assert alice_channel.is_closed()
gath.cancel()
gath = asyncio.gather(action(), p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch())
async def f():
await gath
with self.assertRaises(GracefulDisconnect):
run(f())

@needs_test_with_all_chacha20_implementations
def test_close_upfront_shutdown_script(self):
alice_channel, bob_channel = create_test_channels()
Expand Down Expand Up @@ -1135,7 +1167,7 @@ async def main_loop(peer):
gath = asyncio.gather(*coros)
await gath

with self.assertRaises(UpfrontShutdownScriptViolation):
with self.assertRaises(GracefulDisconnect):
run(test())

# bob sends the same upfront_shutdown_script has he announced
Expand Down