Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 27 additions & 27 deletions electrum/channel_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@
from . import constants
from .util import bh2u, profiler, get_headers_dir, bfh, is_ip_address, list_enabled_bits
from .logging import Logger
from .lnutil import LN_GLOBAL_FEATURES_KNOWN_SET, LNPeerAddr, format_short_channel_id, ShortChannelID
from .lnutil import (LNPeerAddr, format_short_channel_id, ShortChannelID,
validate_features, IncompatibleOrInsaneFeatures)
from .lnverifier import LNChannelVerifier, verify_sig_for_channel_update
from .lnmsg import decode_msg

Expand All @@ -47,15 +48,6 @@
from .lnchannel import Channel


class UnknownEvenFeatureBits(Exception): pass

def validate_features(features : int):
enabled_features = list_enabled_bits(features)
for fbit in enabled_features:
if (1 << fbit) not in LN_GLOBAL_FEATURES_KNOWN_SET and fbit % 2 == 0:
raise UnknownEvenFeatureBits()


FLAG_DISABLE = 1 << 1
FLAG_DIRECTION = 1 << 0

Expand Down Expand Up @@ -102,14 +94,14 @@ class Policy(NamedTuple):
def from_msg(payload: dict) -> 'Policy':
return Policy(
key = payload['short_channel_id'] + payload['start_node'],
cltv_expiry_delta = int.from_bytes(payload['cltv_expiry_delta'], "big"),
htlc_minimum_msat = int.from_bytes(payload['htlc_minimum_msat'], "big"),
htlc_maximum_msat = int.from_bytes(payload['htlc_maximum_msat'], "big") if 'htlc_maximum_msat' in payload else None,
fee_base_msat = int.from_bytes(payload['fee_base_msat'], "big"),
fee_proportional_millionths = int.from_bytes(payload['fee_proportional_millionths'], "big"),
cltv_expiry_delta = payload['cltv_expiry_delta'],
htlc_minimum_msat = payload['htlc_minimum_msat'],
htlc_maximum_msat = payload.get('htlc_maximum_msat', None),
fee_base_msat = payload['fee_base_msat'],
fee_proportional_millionths = payload['fee_proportional_millionths'],
message_flags = int.from_bytes(payload['message_flags'], "big"),
channel_flags = int.from_bytes(payload['channel_flags'], "big"),
timestamp = int.from_bytes(payload['timestamp'], "big")
timestamp = payload['timestamp'],
)

@staticmethod
Expand Down Expand Up @@ -154,7 +146,7 @@ def from_msg(payload) -> Tuple['NodeInfo', Sequence['LNPeerAddr']]:
alias = alias.decode('utf8')
except:
alias = ''
timestamp = int.from_bytes(payload['timestamp'], "big")
timestamp = payload['timestamp']
node_info = NodeInfo(node_id=node_id, features=features, timestamp=timestamp, alias=alias)
return node_info, peer_addrs

Expand Down Expand Up @@ -321,11 +313,12 @@ def get_recent_peers(self):
return ret

# note: currently channel announcements are trusted by default (trusted=True);
# they are not verified. Verifying them would make the gossip sync
# they are not SPV-verified. Verifying them would make the gossip sync
# even slower; especially as servers will start throttling us.
# It would probably put significant strain on servers if all clients
# verified the complete gossip.
def add_channel_announcement(self, msg_payloads, *, trusted=True):
# note: signatures have already been verified.
if type(msg_payloads) is dict:
msg_payloads = [msg_payloads]
added = 0
Expand All @@ -338,8 +331,8 @@ def add_channel_announcement(self, msg_payloads, *, trusted=True):
continue
try:
channel_info = ChannelInfo.from_msg(msg)
except UnknownEvenFeatureBits:
self.logger.info("unknown feature bits")
except IncompatibleOrInsaneFeatures as e:
self.logger.info(f"unknown or insane feature bits: {e!r}")
continue
if trusted:
added += 1
Expand All @@ -353,7 +346,7 @@ def add_channel_announcement(self, msg_payloads, *, trusted=True):
def add_verified_channel_info(self, msg: dict, *, capacity_sat: int = None) -> None:
try:
channel_info = ChannelInfo.from_msg(msg)
except UnknownEvenFeatureBits:
except IncompatibleOrInsaneFeatures:
return
channel_info = channel_info._replace(capacity_sat=capacity_sat)
with self.lock:
Expand Down Expand Up @@ -392,7 +385,7 @@ def add_channel_updates(self, payloads, max_age=None, verify=True) -> Categorize
now = int(time.time())
for payload in payloads:
short_channel_id = ShortChannelID(payload['short_channel_id'])
timestamp = int.from_bytes(payload['timestamp'], "big")
timestamp = payload['timestamp']
if max_age and now - timestamp > max_age:
expired.append(payload)
continue
Expand All @@ -407,7 +400,7 @@ def add_channel_updates(self, payloads, max_age=None, verify=True) -> Categorize
known.append(payload)
# compare updates to existing database entries
for payload in known:
timestamp = int.from_bytes(payload['timestamp'], "big")
timestamp = payload['timestamp']
start_node = payload['start_node']
short_channel_id = ShortChannelID(payload['short_channel_id'])
key = (start_node, short_channel_id)
Expand Down Expand Up @@ -499,13 +492,14 @@ def verify_channel_update(self, payload):
raise Exception(f'failed verifying channel update for {short_channel_id}')

def add_node_announcement(self, msg_payloads):
# note: signatures have already been verified.
if type(msg_payloads) is dict:
msg_payloads = [msg_payloads]
new_nodes = {}
for msg_payload in msg_payloads:
try:
node_info, node_addresses = NodeInfo.from_msg(msg_payload)
except UnknownEvenFeatureBits:
except IncompatibleOrInsaneFeatures:
continue
node_id = node_info.node_id
# Ignore node if it has no associated channel (DoS protection)
Expand Down Expand Up @@ -599,11 +593,17 @@ def newest_ts_for_node_id(node_id):
self._recent_peers = sorted_node_ids[:self.NUM_MAX_RECENT_PEERS]
c.execute("""SELECT * FROM channel_info""")
for short_channel_id, msg in c:
ci = ChannelInfo.from_raw_msg(msg)
try:
ci = ChannelInfo.from_raw_msg(msg)
except IncompatibleOrInsaneFeatures:
continue
self._channels[ShortChannelID.normalize(short_channel_id)] = ci
c.execute("""SELECT * FROM node_info""")
for node_id, msg in c:
node_info, node_addresses = NodeInfo.from_raw_msg(msg)
try:
node_info, node_addresses = NodeInfo.from_raw_msg(msg)
except IncompatibleOrInsaneFeatures:
continue
# don't load node_addresses because they dont have timestamps
self._nodes[node_id] = node_info
c.execute("""SELECT * FROM policy""")
Expand Down Expand Up @@ -671,7 +671,7 @@ def get_policy_for_node(self, short_channel_id: bytes, node_id: bytes, *,
return
now = int(time.time())
remote_update_decoded = decode_msg(remote_update_raw)[1]
remote_update_decoded['timestamp'] = now.to_bytes(4, byteorder="big")
remote_update_decoded['timestamp'] = now
remote_update_decoded['start_node'] = node_id
return Policy.from_msg(remote_update_decoded)
elif node_id == chan.get_local_pubkey(): # outgoing direction (from us)
Expand Down
Loading