diff --git a/docs/release-notes/eclair-vnext.md b/docs/release-notes/eclair-vnext.md index 06175f8bb3..94a34ecfbd 100644 --- a/docs/release-notes/eclair-vnext.md +++ b/docs/release-notes/eclair-vnext.md @@ -24,6 +24,14 @@ Eclair now supports the feature `option_onion_messages`. If this feature is enab It can also send onion messages with the `sendonionmessage` API. Messages sent to Eclair will be ignored. +### Support for `option_compression` + +Eclair now supports the `option_compression` feature as specified in https://github.com/lightning/bolts/pull/825. +Eclair will announce what compression algorithms it supports for routing sync, and will only use compression algorithms supported by its peers when forwarding gossip. + +If you were overriding the default `eclair.router.sync.encoding-type` in your `eclair.conf`, you need to update your configuration. +This field has been renamed `eclair.router.sync.preferred-compression-algorithm` and defaults to `zlib`. + ### API changes #### Timestamps diff --git a/eclair-core/src/main/resources/reference.conf b/eclair-core/src/main/resources/reference.conf index fef81d4392..112bfec88b 100644 --- a/eclair-core/src/main/resources/reference.conf +++ b/eclair-core/src/main/resources/reference.conf @@ -55,6 +55,7 @@ eclair { // Do not enable option_anchor_outputs unless you really know what you're doing. option_anchor_outputs = disabled option_anchors_zero_fee_htlc_tx = disabled + option_compression = optional option_shutdown_anysegwit = optional option_onion_messages = disabled trampoline_payment = disabled @@ -218,7 +219,7 @@ eclair { sync { request-node-announcements = true // if true we will ask for node announcements when we receive channel ids that we don't know - encoding-type = zlib // encoding for short_channel_ids and timestamps in query channel sync messages; other possible value is "uncompressed" + preferred-compression-algorithm = zlib // encoding for short_channel_ids and timestamps in query channel sync messages; other possible value is "uncompressed" channel-range-chunk-size = 1500 // max number of short_channel_ids (+ timestamps + checksums) in reply_channel_range *do not change this unless you know what you are doing* channel-query-chunk-size = 100 // max number of short_channel_ids in query_short_channel_ids *do not change this unless you know what you are doing* } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/Features.scala b/eclair-core/src/main/scala/fr/acinq/eclair/Features.scala index 4c392d5642..60054f0672 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/Features.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/Features.scala @@ -203,6 +203,11 @@ object Features { val mandatory = 26 } + case object CompressionSupport extends Feature { + val rfcName = "option_compression" + val mandatory = 32 + } + case object OnionMessages extends Feature { val rfcName = "option_onion_messages" val mandatory = 38 @@ -235,6 +240,7 @@ object Features { StaticRemoteKey, AnchorOutputs, AnchorOutputsZeroFeeHtlcTx, + CompressionSupport, ShutdownAnySegwit, OnionMessages, KeySend diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala b/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala index 0de805f9c4..56e6a332a2 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala @@ -32,7 +32,7 @@ import fr.acinq.eclair.router.Graph.{HeuristicsConstants, WeightRatios} import fr.acinq.eclair.router.PathFindingExperimentConf import fr.acinq.eclair.router.Router.{MultiPartParams, PathFindingConf, RouterConf, SearchBoundaries} import fr.acinq.eclair.tor.Socks5ProxyParams -import fr.acinq.eclair.wire.protocol.{Color, EncodingType, NodeAddress} +import fr.acinq.eclair.wire.protocol.{Color, CompressionAlgorithm, NodeAddress} import grizzled.slf4j.Logging import scodec.bits.ByteVector @@ -217,6 +217,7 @@ object NodeParams extends Logging { "router.path-finding.ratio-channel-capacity" -> "router.path-finding.default.ratios.channel-capacity", "router.path-finding.hop-cost-base-msat" -> "router.path-finding.default.hop-cost.fee-base-msat", "router.path-finding.hop-cost-millionths" -> "router.path-finding.default.hop-cost.fee-proportional-millionths", + "router.sync.encoding-type" -> "router.sync.preferred-compression-algorithm" ) deprecatedKeyPaths.foreach { case (old, new_) => require(!config.hasPath(old), s"configuration key '$old' has been replaced by '$new_'") @@ -353,7 +354,6 @@ object NodeParams extends Logging { experimentName = name, experimentPercentage = config.getInt("percentage")) - def getPathFindingExperimentConf(config: Config): PathFindingExperimentConf = { val experiments = config.root.asScala.keys.map(name => name -> getPathFindingConf(config.getConfig(name), name)) PathFindingExperimentConf(experiments.toMap) @@ -364,9 +364,9 @@ object NodeParams extends Logging { case "stop" => UnhandledExceptionStrategy.Stop } - val routerSyncEncodingType = config.getString("router.sync.encoding-type") match { - case "uncompressed" => EncodingType.UNCOMPRESSED - case "zlib" => EncodingType.COMPRESSED_ZLIB + val routerSyncPreferredCompression = config.getString("router.sync.preferred-compression-algorithm") match { + case "uncompressed" => CompressionAlgorithm.Uncompressed + case "zlib" => CompressionAlgorithm.ZlibDeflate } NodeParams( @@ -456,7 +456,7 @@ object NodeParams extends Logging { routerBroadcastInterval = FiniteDuration(config.getDuration("router.broadcast-interval").getSeconds, TimeUnit.SECONDS), networkStatsRefreshInterval = FiniteDuration(config.getDuration("router.network-stats-interval").getSeconds, TimeUnit.SECONDS), requestNodeAnnouncements = config.getBoolean("router.sync.request-node-announcements"), - encodingType = routerSyncEncodingType, + preferredCompression = routerSyncPreferredCompression, channelRangeChunkSize = config.getInt("router.sync.channel-range-chunk-size"), channelQueryChunkSize = config.getInt("router.sync.channel-query-chunk-size"), pathFindingExperimentConf = getPathFindingExperimentConf(config.getConfig("router.path-finding.experiments")) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala b/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala index ca7235088b..476aaca361 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala @@ -452,7 +452,7 @@ object Peer { case object GetPeerInfo case class PeerInfo(nodeId: PublicKey, state: String, address: Option[InetSocketAddress], channels: Int) - case class PeerRoutingMessage(peerConnection: ActorRef, remoteNodeId: PublicKey, message: RoutingMessage) extends RemoteTypes + case class PeerRoutingMessage(peerConnection: ActorRef, remoteNodeId: PublicKey, remoteInit: protocol.Init, message: RoutingMessage) extends RemoteTypes /** * Dedicated command for outgoing messages for logging purposes. diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/io/PeerConnection.scala b/eclair-core/src/main/scala/fr/acinq/eclair/io/PeerConnection.scala index 86671b9c7c..2a949515ff 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/io/PeerConnection.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/io/PeerConnection.scala @@ -101,7 +101,7 @@ class PeerConnection(keyPair: KeyPair, conf: PeerConnection.Conf, switchboard: A d.transport ! TransportHandler.Listener(self) Metrics.PeerConnectionsConnecting.withTag(Tags.ConnectionState, Tags.ConnectionStates.Initializing).increment() log.info(s"using features=$localFeatures") - val localInit = protocol.Init(localFeatures, TlvStream(InitTlv.Networks(chainHash :: Nil))) + val localInit = protocol.Init(localFeatures, TlvStream(InitTlv.Networks(chainHash :: Nil), InitTlv.CompressionAlgorithms(Set(CompressionAlgorithm.Uncompressed, CompressionAlgorithm.ZlibDeflate)))) d.transport ! localInit startSingleTimer(INIT_TIMER, InitTimeout, conf.initTimeout) goto(INITIALIZING) using InitializingData(chainHash, d.pendingAuth, d.remoteNodeId, d.transport, peer, localInit, doSync) @@ -113,7 +113,7 @@ class PeerConnection(keyPair: KeyPair, conf: PeerConnection.Conf, switchboard: A cancelTimer(INIT_TIMER) d.transport ! TransportHandler.ReadAck(remoteInit) - log.info(s"peer is using features=${remoteInit.features}, networks=${remoteInit.networks.mkString(",")}") + log.info(s"peer is using features=${remoteInit.features}, networks=${remoteInit.networks.mkString(",")} compression=${remoteInit.compressionAlgorithms.mkString(",")}") val featureGraphErr_opt = Features.validateFeatureGraph(remoteInit.features) if (remoteInit.networks.nonEmpty && remoteInit.networks.intersect(d.localInit.networks).isEmpty) { @@ -234,7 +234,7 @@ class PeerConnection(keyPair: KeyPair, conf: PeerConnection.Conf, switchboard: A case Event(DelayedRebroadcast(rebroadcast), d: ConnectedData) => - val thisRemote = RemoteGossip(self, d.remoteNodeId) + val thisRemote = RemoteGossip(self, d.remoteNodeId, d.remoteInit) /** * Send and count in a single iteration @@ -285,7 +285,7 @@ class PeerConnection(keyPair: KeyPair, conf: PeerConnection.Conf, switchboard: A d.transport ! TransportHandler.ReadAck(msg) case _ => // Note: we don't ack messages here because we don't want them to be stacked in the router's mailbox - router ! Peer.PeerRoutingMessage(self, d.remoteNodeId, msg) + router ! Peer.PeerRoutingMessage(self, d.remoteNodeId, d.remoteInit, msg) } stay() @@ -347,7 +347,8 @@ class PeerConnection(keyPair: KeyPair, conf: PeerConnection.Conf, switchboard: A case Event(DoSync(replacePrevious), d: ConnectedData) => val canUseChannelRangeQueries = Features.canUseFeature(d.localInit.features, d.remoteInit.features, Features.ChannelRangeQueries) val canUseChannelRangeQueriesEx = Features.canUseFeature(d.localInit.features, d.remoteInit.features, Features.ChannelRangeQueriesExtended) - if (canUseChannelRangeQueries || canUseChannelRangeQueriesEx) { + val hasCompatibleCompression = CompressionAlgorithm.select(d.localInit.compressionAlgorithms, d.remoteInit.compressionAlgorithms).nonEmpty + if ((canUseChannelRangeQueries || canUseChannelRangeQueriesEx) && hasCompatibleCompression) { val flags_opt = if (canUseChannelRangeQueriesEx) Some(QueryChannelRangeTlv.QueryFlags(QueryChannelRangeTlv.QueryFlags.WANT_ALL)) else None log.info(s"sending sync channel range query with flags_opt=$flags_opt replacePrevious=$replacePrevious") router ! SendChannelQuery(d.chainHash, d.remoteNodeId, self, replacePrevious, flags_opt) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/remote/EclairInternalsSerializer.scala b/eclair-core/src/main/scala/fr/acinq/eclair/remote/EclairInternalsSerializer.scala index 5a23fb478e..7bbeaba286 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/remote/EclairInternalsSerializer.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/remote/EclairInternalsSerializer.scala @@ -83,17 +83,19 @@ object EclairInternalsSerializer { ("experimentPercentage" | int32)).as[PathFindingConf] val pathFindingExperimentConfCodec: Codec[PathFindingExperimentConf] = ( - ("experiments" | listOfN(int32, pathFindingConfCodec).xmap[Map[String, PathFindingConf]](_.map(e => (e.experimentName -> e)).toMap, _.values.toList)) + "experiments" | listOfN(int32, pathFindingConfCodec).xmap[Map[String, PathFindingConf]](_.map(e => e.experimentName -> e).toMap, _.values.toList) ).as[PathFindingExperimentConf] + private val compressionAlgorithmCodec: Codec[CompressionAlgorithm] = discriminated[CompressionAlgorithm].by(uint8) + .typecase(CompressionAlgorithm.Uncompressed.bitPosition, provide(CompressionAlgorithm.Uncompressed)) + .typecase(CompressionAlgorithm.ZlibDeflate.bitPosition, provide(CompressionAlgorithm.ZlibDeflate)) + val routerConfCodec: Codec[RouterConf] = ( ("channelExcludeDuration" | finiteDurationCodec) :: ("routerBroadcastInterval" | finiteDurationCodec) :: ("networkStatsRefreshInterval" | finiteDurationCodec) :: ("requestNodeAnnouncements" | bool(8)) :: - ("encodingType" | discriminated[EncodingType].by(uint8) - .typecase(0, provide(EncodingType.UNCOMPRESSED)) - .typecase(1, provide(EncodingType.COMPRESSED_ZLIB))) :: + ("preferredCompression" | compressionAlgorithmCodec) :: ("channelRangeChunkSize" | int32) :: ("channelQueryChunkSize" | int32) :: ("pathFindingExperimentConf" | pathFindingExperimentConfCodec)).as[RouterConf] @@ -166,6 +168,7 @@ object EclairInternalsSerializer { def peerRoutingMessageCodec(system: ExtendedActorSystem): Codec[PeerRoutingMessage] = ( ("peerConnection" | actorRefCodec(system)) :: ("remoteNodeId" | publicKey) :: + ("remoteInit" | lengthPrefixedInitCodec) :: ("msg" | lengthPrefixedLightningMessageCodec.downcast[RoutingMessage])).as[PeerRoutingMessage] val singleChannelDiscoveredCodec: Codec[SingleChannelDiscovered] = (lengthPrefixedChannelAnnouncementCodec :: satoshi :: optional(bool(8), lengthPrefixedChannelUpdateCodec) :: optional(bool(8), lengthPrefixedChannelUpdateCodec)).as[SingleChannelDiscovered] diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala index 5041b7273c..bc2bd7b401 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala @@ -223,13 +223,13 @@ class Router(val nodeParams: NodeParams, watcher: typed.ActorRef[ZmqWatcher.Comm stay() using RouteCalculation.handleRouteRequest(d, nodeParams.routerConf, nodeParams.currentBlockHeight, r) // Warning: order matters here, this must be the first match for HasChainHash messages ! - case Event(PeerRoutingMessage(_, _, routingMessage: HasChainHash), _) if routingMessage.chainHash != nodeParams.chainHash => + case Event(PeerRoutingMessage(_, _, _, routingMessage: HasChainHash), _) if routingMessage.chainHash != nodeParams.chainHash => sender() ! TransportHandler.ReadAck(routingMessage) log.warning("message {} for wrong chain {}, we're on {}", routingMessage, routingMessage.chainHash, nodeParams.chainHash) stay() - case Event(PeerRoutingMessage(peerConnection, remoteNodeId, c: ChannelAnnouncement), d) => - stay() using Validation.handleChannelAnnouncement(d, nodeParams.db.network, watcher, RemoteGossip(peerConnection, remoteNodeId), c) + case Event(PeerRoutingMessage(peerConnection, remoteNodeId, remoteInit, c: ChannelAnnouncement), d) => + stay() using Validation.handleChannelAnnouncement(d, nodeParams.db.network, watcher, RemoteGossip(peerConnection, remoteNodeId, remoteInit), c) case Event(r: ValidateResult, d) => stay() using Validation.handleChannelValidationResponse(d, nodeParams, watcher, r) @@ -240,14 +240,14 @@ class Router(val nodeParams: NodeParams, watcher: typed.ActorRef[ZmqWatcher.Comm case Event(n: NodeAnnouncement, d: Data) => stay() using Validation.handleNodeAnnouncement(d, nodeParams.db.network, Set(LocalGossip), n) - case Event(PeerRoutingMessage(peerConnection, remoteNodeId, n: NodeAnnouncement), d: Data) => - stay() using Validation.handleNodeAnnouncement(d, nodeParams.db.network, Set(RemoteGossip(peerConnection, remoteNodeId)), n) + case Event(PeerRoutingMessage(peerConnection, remoteNodeId, remoteInit, n: NodeAnnouncement), d: Data) => + stay() using Validation.handleNodeAnnouncement(d, nodeParams.db.network, Set(RemoteGossip(peerConnection, remoteNodeId, remoteInit)), n) case Event(u: ChannelUpdate, d: Data) => stay() using Validation.handleChannelUpdate(d, nodeParams.db.network, nodeParams.routerConf, Right(RemoteChannelUpdate(u, Set(LocalGossip)))) - case Event(PeerRoutingMessage(peerConnection, remoteNodeId, u: ChannelUpdate), d) => - stay() using Validation.handleChannelUpdate(d, nodeParams.db.network, nodeParams.routerConf, Right(RemoteChannelUpdate(u, Set(RemoteGossip(peerConnection, remoteNodeId))))) + case Event(PeerRoutingMessage(peerConnection, remoteNodeId, remoteInit, u: ChannelUpdate), d) => + stay() using Validation.handleChannelUpdate(d, nodeParams.db.network, nodeParams.routerConf, Right(RemoteChannelUpdate(u, Set(RemoteGossip(peerConnection, remoteNodeId, remoteInit))))) case Event(lcu: LocalChannelUpdate, d: Data) => stay() using Validation.handleLocalChannelUpdate(d, nodeParams.db.network, nodeParams.routerConf, nodeParams.nodeId, watcher, lcu) @@ -261,19 +261,19 @@ class Router(val nodeParams: NodeParams, watcher: typed.ActorRef[ZmqWatcher.Comm case Event(s: SendChannelQuery, d) => stay() using Sync.handleSendChannelQuery(d, s) - case Event(PeerRoutingMessage(peerConnection, remoteNodeId, q: QueryChannelRange), d) => - Sync.handleQueryChannelRange(d.channels, nodeParams.routerConf, RemoteGossip(peerConnection, remoteNodeId), q) + case Event(PeerRoutingMessage(peerConnection, remoteNodeId, remoteInit, q: QueryChannelRange), d) => + Sync.handleQueryChannelRange(d.channels, nodeParams.routerConf, RemoteGossip(peerConnection, remoteNodeId, remoteInit), q) stay() - case Event(PeerRoutingMessage(peerConnection, remoteNodeId, r: ReplyChannelRange), d) => - stay() using Sync.handleReplyChannelRange(d, nodeParams.routerConf, RemoteGossip(peerConnection, remoteNodeId), r) + case Event(PeerRoutingMessage(peerConnection, remoteNodeId, remoteInit, r: ReplyChannelRange), d) => + stay() using Sync.handleReplyChannelRange(d, nodeParams.routerConf, RemoteGossip(peerConnection, remoteNodeId, remoteInit), r) - case Event(PeerRoutingMessage(peerConnection, remoteNodeId, q: QueryShortChannelIds), d) => - Sync.handleQueryShortChannelIds(d.nodes, d.channels, RemoteGossip(peerConnection, remoteNodeId), q) + case Event(PeerRoutingMessage(peerConnection, remoteNodeId, remoteInit, q: QueryShortChannelIds), d) => + Sync.handleQueryShortChannelIds(d.nodes, d.channels, RemoteGossip(peerConnection, remoteNodeId, remoteInit), q) stay() - case Event(PeerRoutingMessage(peerConnection, remoteNodeId, r: ReplyShortChannelIdsEnd), d) => - stay() using Sync.handleReplyShortChannelIdsEnd(d, RemoteGossip(peerConnection, remoteNodeId), r) + case Event(PeerRoutingMessage(peerConnection, remoteNodeId, remoteInit, r: ReplyShortChannelIdsEnd), d) => + stay() using Sync.handleReplyShortChannelIdsEnd(d, RemoteGossip(peerConnection, remoteNodeId, remoteInit), r) } @@ -322,7 +322,7 @@ object Router { routerBroadcastInterval: FiniteDuration, networkStatsRefreshInterval: FiniteDuration, requestNodeAnnouncements: Boolean, - encodingType: EncodingType, + preferredCompression: CompressionAlgorithm, channelRangeChunkSize: Int, channelQueryChunkSize: Int, pathFindingExperimentConf: PathFindingExperimentConf) @@ -552,7 +552,7 @@ object Router { // @formatter:off sealed trait GossipOrigin /** Gossip that we received from a remote peer. */ - case class RemoteGossip(peerConnection: ActorRef, nodeId: PublicKey) extends GossipOrigin + case class RemoteGossip(peerConnection: ActorRef, nodeId: PublicKey, remoteInit: Init) extends GossipOrigin /** Gossip that was generated by our node. */ case object LocalGossip extends GossipOrigin diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/Sync.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/Sync.scala index 15a0f25b42..330cc3509f 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/router/Sync.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/Sync.scala @@ -74,17 +74,22 @@ object Sync { ctx.sender() ! TransportHandler.ReadAck(q) Metrics.QueryChannelRange.Blocks.withoutTags().record(q.numberOfBlocks) log.info("received query_channel_range with firstBlockNum={} numberOfBlocks={} extendedQueryFlags_opt={}", q.firstBlockNum, q.numberOfBlocks, q.tlvStream) - // keep channel ids that are in [firstBlockNum, firstBlockNum + numberOfBlocks] - val shortChannelIds: SortedSet[ShortChannelId] = channels.keySet.filter(keep(q.firstBlockNum, q.numberOfBlocks, _)) - log.info("replying with {} items for range=({}, {})", shortChannelIds.size, q.firstBlockNum, q.numberOfBlocks) - val chunks = split(shortChannelIds, q.firstBlockNum, q.numberOfBlocks, routerConf.channelRangeChunkSize) - Metrics.QueryChannelRange.Replies.withoutTags().record(chunks.size) - chunks.zipWithIndex.foreach { case (chunk, i) => - val syncComplete = i == chunks.size - 1 - val reply = buildReplyChannelRange(chunk, syncComplete, q.chainHash, routerConf.encodingType, q.queryFlags_opt, channels) - origin.peerConnection ! reply - Metrics.ReplyChannelRange.Blocks.withTag(Tags.Direction, Tags.Directions.Outgoing).record(reply.numberOfBlocks) - Metrics.ReplyChannelRange.ShortChannelIds.withTag(Tags.Direction, Tags.Directions.Outgoing).record(reply.shortChannelIds.array.size) + CompressionAlgorithm.select(routerConf.preferredCompression, CompressionAlgorithm.defaultSupported, origin.remoteInit.compressionAlgorithms) match { + case Some(preferredCompression) => + // keep channel ids that are in [firstBlockNum, firstBlockNum + numberOfBlocks] + val shortChannelIds: SortedSet[ShortChannelId] = channels.keySet.filter(keep(q.firstBlockNum, q.numberOfBlocks, _)) + log.info("replying with {} items for range=({}, {})", shortChannelIds.size, q.firstBlockNum, q.numberOfBlocks) + val chunks = split(shortChannelIds, q.firstBlockNum, q.numberOfBlocks, routerConf.channelRangeChunkSize) + Metrics.QueryChannelRange.Replies.withoutTags().record(chunks.size) + chunks.zipWithIndex.foreach { case (chunk, i) => + val syncComplete = i == chunks.size - 1 + val reply = buildReplyChannelRange(chunk, syncComplete, q.chainHash, preferredCompression, q.queryFlags_opt, channels) + origin.peerConnection ! reply + Metrics.ReplyChannelRange.Blocks.withTag(Tags.Direction, Tags.Directions.Outgoing).record(reply.numberOfBlocks) + Metrics.ReplyChannelRange.ShortChannelIds.withTag(Tags.Direction, Tags.Directions.Outgoing).record(reply.shortChannelIds.array.size) + } + case None => + log.info("peer doesn't support any of our compression algorithms, ignoring query_channel_range") } } @@ -133,7 +138,7 @@ object Sync { def buildQuery(chunk: List[ShortChannelIdAndFlag]): QueryShortChannelIds = { // always encode empty lists as UNCOMPRESSED - val encoding = if (chunk.isEmpty) EncodingType.UNCOMPRESSED else r.shortChannelIds.encoding + val encoding = if (chunk.isEmpty) CompressionAlgorithm.Uncompressed else r.shortChannelIds.encoding val flags: TlvStream[QueryShortChannelIdsTlv] = if (r.timestamps_opt.isDefined || r.checksums_opt.isDefined) { TlvStream(QueryShortChannelIdsTlv.EncodedQueryFlags(encoding, chunk.map(_.flag))) } else { @@ -481,8 +486,8 @@ object Sync { * @param channels channels map * @return a ReplyChannelRange object */ - def buildReplyChannelRange(chunk: ShortChannelIdsChunk, syncComplete: Boolean, chainHash: ByteVector32, defaultEncoding: EncodingType, queryFlags_opt: Option[QueryChannelRangeTlv.QueryFlags], channels: SortedMap[ShortChannelId, PublicChannel]): ReplyChannelRange = { - val encoding = if (chunk.shortChannelIds.isEmpty) EncodingType.UNCOMPRESSED else defaultEncoding + def buildReplyChannelRange(chunk: ShortChannelIdsChunk, syncComplete: Boolean, chainHash: ByteVector32, defaultEncoding: CompressionAlgorithm, queryFlags_opt: Option[QueryChannelRangeTlv.QueryFlags], channels: SortedMap[ShortChannelId, PublicChannel]): ReplyChannelRange = { + val encoding = if (chunk.shortChannelIds.isEmpty) CompressionAlgorithm.Uncompressed else defaultEncoding val (timestamps, checksums) = queryFlags_opt match { case Some(extension) if extension.wantChecksums | extension.wantTimestamps => // we always compute timestamps and checksums even if we don't need both, overhead is negligible diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/Validation.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/Validation.scala index a78c9e9254..94201c45af 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/router/Validation.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/Validation.scala @@ -35,7 +35,7 @@ import fr.acinq.eclair.{Logs, MilliSatoshiLong, NodeParams, ShortChannelId, TxCo object Validation { private def sendDecision(origins: Set[GossipOrigin], decision: GossipDecision)(implicit sender: ActorRef): Unit = { - origins.collect { case RemoteGossip(peerConnection, _) => sendDecision(peerConnection, decision) } + origins.collect { case RemoteGossip(peerConnection, _, _) => sendDecision(peerConnection, decision) } } private def sendDecision(peerConnection: ActorRef, decision: GossipDecision)(implicit sender: ActorRef): Unit = { @@ -198,7 +198,7 @@ object Validation { val remoteOrigins = origins flatMap { case r: RemoteGossip if wasStashed => Some(r.peerConnection) - case RemoteGossip(peerConnection, _) => + case RemoteGossip(peerConnection, _, _) => peerConnection ! TransportHandler.ReadAck(n) log.debug("received node announcement for nodeId={}", n.nodeId) Some(peerConnection) @@ -253,7 +253,7 @@ object Validation { case Left(lcu) => (lcu.channelUpdate, Set(LocalGossip)) case Right(rcu) => rcu.origins.collect { - case RemoteGossip(peerConnection, _) if !wasStashed => // stashed changes have already been acknowledged + case RemoteGossip(peerConnection, _, _) if !wasStashed => // stashed changes have already been acknowledged log.debug("received channel update for shortChannelId={}", rcu.channelUpdate.shortChannelId) peerConnection ! TransportHandler.ReadAck(rcu.channelUpdate) } @@ -374,18 +374,24 @@ object Validation { db.removeFromPruned(u.shortChannelId) // peerConnection_opt will contain a valid peerConnection only when we're handling an update that we received from a peer, not // when we're sending updates to ourselves - origins head match { - case RemoteGossip(peerConnection, remoteNodeId) => - val query = QueryShortChannelIds(u.chainHash, EncodedShortChannelIds(routerConf.encodingType, List(u.shortChannelId)), TlvStream.empty) - d.sync.get(remoteNodeId) match { - case Some(sync) if sync.started => - // we already have a pending request to that node, let's add this channel to the list and we'll get it later - // TODO: we only request channels with old style channel_query - d.copy(sync = d.sync + (remoteNodeId -> sync.copy(remainingQueries = sync.remainingQueries :+ query, totalQueries = sync.totalQueries + 1))) - case _ => - // otherwise we send the query right away - peerConnection ! query - d.copy(sync = d.sync + (remoteNodeId -> Syncing(remainingQueries = Nil, totalQueries = 1))) + origins.head match { + case RemoteGossip(peerConnection, remoteNodeId, remoteInit) => + CompressionAlgorithm.select(routerConf.preferredCompression, CompressionAlgorithm.defaultSupported, remoteInit.compressionAlgorithms) match { + case Some(preferredCompression) => + val query = QueryShortChannelIds(u.chainHash, EncodedShortChannelIds(preferredCompression, List(u.shortChannelId)), TlvStream.empty) + d.sync.get(remoteNodeId) match { + case Some(sync) if sync.started => + // we already have a pending request to that node, let's add this channel to the list and we'll get it later + // TODO: we only request channels with old style channel_query + d.copy(sync = d.sync + (remoteNodeId -> sync.copy(remainingQueries = sync.remainingQueries :+ query, totalQueries = sync.totalQueries + 1))) + case _ => + // otherwise we send the query right away + peerConnection ! query + d.copy(sync = d.sync + (remoteNodeId -> Syncing(remainingQueries = Nil, totalQueries = 1))) + } + case None => + log.info(s"peer doesn't support any of our compression algorithms, we can't query channel updates for ${u.shortChannelId}") + d } case _ => // we don't know which node this update came from (maybe it was stashed and the channel got pruned in the meantime or some other corner case). diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/CommonCodecs.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/CommonCodecs.scala index 973367bb92..1a3ad0efa8 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/CommonCodecs.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/CommonCodecs.scala @@ -52,6 +52,19 @@ object CommonCodecs { /** byte-aligned boolean codec */ val bool8: Codec[Boolean] = bool(8) + /** byte-aligned codec for right to left bit vector */ + val reversedBitVector: Codec[ReversedBitVector] = bytes.xmap( + b => ReversedBitVector(b.bits.toIndexedSeq.reverse.zipWithIndex.collect { case (true, i) => i }.toSet), + { + case v if v.activated.isEmpty => ByteVector.empty + case v => + // When converting from BitVector to ByteVector, scodec pads right instead of left, so we make sure we pad to bytes *before* setting feature bits. + var buf = BitVector.fill(v.activated.max + 1)(high = false).bytes.bits + v.activated.foreach { i => buf = buf.set(i) } + buf.reverse.bytes + } + ) + // this codec can be safely used for values < 2^63 and will fail otherwise // (for something smarter see https://github.com/yzernik/bitcoin-scodec/blob/master/src/main/scala/io/github/yzernik/bitcoinscodec/structures/UInt64.scala) val uint64overflow: Codec[Long] = int64.narrow(l => if (l >= 0) Attempt.Successful(l) else Attempt.failure(Err(s"overflow for value $l")), l => l) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/LightningMessageCodecs.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/LightningMessageCodecs.scala index 1ffe992e0d..e3204c477c 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/LightningMessageCodecs.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/LightningMessageCodecs.scala @@ -270,13 +270,13 @@ object LightningMessageCodecs { val encodedShortChannelIdsCodec: Codec[EncodedShortChannelIds] = discriminated[EncodedShortChannelIds].by(byte) - .\(0) { + .\(CompressionAlgorithm.Uncompressed.bitPosition.toByte) { case a@EncodedShortChannelIds(_, Nil) => a // empty list is always encoded with encoding type 'uncompressed' for compatibility with other implementations - case a@EncodedShortChannelIds(EncodingType.UNCOMPRESSED, _) => a - }((provide[EncodingType](EncodingType.UNCOMPRESSED) :: list(shortchannelid)).as[EncodedShortChannelIds]) - .\(1) { - case a@EncodedShortChannelIds(EncodingType.COMPRESSED_ZLIB, _) => a - }((provide[EncodingType](EncodingType.COMPRESSED_ZLIB) :: zlib(list(shortchannelid))).as[EncodedShortChannelIds]) + case a@EncodedShortChannelIds(CompressionAlgorithm.Uncompressed, _) => a + }((provide[CompressionAlgorithm](CompressionAlgorithm.Uncompressed) :: list(shortchannelid)).as[EncodedShortChannelIds]) + .\(CompressionAlgorithm.ZlibDeflate.bitPosition.toByte) { + case a@EncodedShortChannelIds(CompressionAlgorithm.ZlibDeflate, _) => a + }((provide[CompressionAlgorithm](CompressionAlgorithm.ZlibDeflate) :: zlib(list(shortchannelid))).as[EncodedShortChannelIds]) val queryShortChannelIdsCodec: Codec[QueryShortChannelIds] = ( ("chainHash" | bytes32) :: diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/LightningMessageTypes.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/LightningMessageTypes.scala index 3974851076..3b414a120c 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/LightningMessageTypes.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/LightningMessageTypes.scala @@ -49,6 +49,7 @@ sealed trait HtlcSettlementMessage extends UpdateMessage { def id: Long } // <- case class Init(features: Features, tlvStream: TlvStream[InitTlv] = TlvStream.empty) extends SetupMessage { val networks = tlvStream.get[InitTlv.Networks].map(_.chainHashes).getOrElse(Nil) + val compressionAlgorithms = tlvStream.get[InitTlv.CompressionAlgorithms].map(_.supported).getOrElse(CompressionAlgorithm.defaultSupported) } case class Warning(channelId: ByteVector32, data: ByteVector, tlvStream: TlvStream[WarningTlv] = TlvStream.empty) extends SetupMessage with HasChannelId { @@ -276,15 +277,31 @@ object ChannelUpdate { } } -// @formatter:off -sealed trait EncodingType -object EncodingType { - case object UNCOMPRESSED extends EncodingType - case object COMPRESSED_ZLIB extends EncodingType +sealed trait CompressionAlgorithm { + def bitPosition: Int +} + +object CompressionAlgorithm { + // @formatter:off + case object Uncompressed extends CompressionAlgorithm { override val bitPosition: Int = 0 } + case object ZlibDeflate extends CompressionAlgorithm { override val bitPosition: Int = 1 } + // @formatter:on + + // When not provided, we assume support for uncompressed and zlib (which was the case before option_compression was introduced). + val defaultSupported: Set[CompressionAlgorithm] = Set(Uncompressed, ZlibDeflate) + + def select(localSupport: Set[CompressionAlgorithm], remoteSupport: Set[CompressionAlgorithm]): Option[CompressionAlgorithm] = { + localSupport.intersect(remoteSupport).headOption + } + + def select(preferred: CompressionAlgorithm, localSupport: Set[CompressionAlgorithm], remoteSupport: Set[CompressionAlgorithm]): Option[CompressionAlgorithm] = { + val candidates = localSupport.intersect(remoteSupport) + if (candidates.contains(preferred)) Some(preferred) else candidates.headOption + } + } -// @formatter:on -case class EncodedShortChannelIds(encoding: EncodingType, array: List[ShortChannelId]) { +case class EncodedShortChannelIds(encoding: CompressionAlgorithm, array: List[ShortChannelId]) { /** custom toString because it can get huge in logs */ override def toString: String = s"EncodedShortChannelIds($encoding,${array.headOption.getOrElse("")}->${array.lastOption.getOrElse("")} size=${array.size})" } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/ReversedBitVector.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/ReversedBitVector.scala new file mode 100644 index 0000000000..2404570c58 --- /dev/null +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/ReversedBitVector.scala @@ -0,0 +1,32 @@ +/* + * Copyright 2021 ACINQ SAS + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package fr.acinq.eclair.wire.protocol + +/** + * Created by t-bast on 15/09/2021. + */ + +/** + * Similar to [[scodec.bits.BitVector]], but bits are numbered right to left instead of left to right, which is what we + * do in most places in lightning. + */ +case class ReversedBitVector(activated: Set[Int]) { + + /** Returns true if the bit is set at the given position. */ + def get(position: Int): Boolean = activated.contains(position) + +} diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/RoutingTlv.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/RoutingTlv.scala index c567aa86d6..1284e0e64d 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/RoutingTlv.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/RoutingTlv.scala @@ -104,7 +104,7 @@ object ReplyChannelRangeTlv { * * @param encoding same convention as for short channel ids */ - case class EncodedTimestamps(encoding: EncodingType, timestamps: List[Timestamps]) extends ReplyChannelRangeTlv { + case class EncodedTimestamps(encoding: CompressionAlgorithm, timestamps: List[Timestamps]) extends ReplyChannelRangeTlv { /** custom toString because it can get huge in logs */ override def toString: String = s"EncodedTimestamps($encoding, size=${timestamps.size})" } @@ -130,8 +130,8 @@ object ReplyChannelRangeTlv { val encodedTimestampsCodec: Codec[EncodedTimestamps] = variableSizeBytesLong(varintoverflow, discriminated[EncodedTimestamps].by(byte) - .\(0) { case a@EncodedTimestamps(EncodingType.UNCOMPRESSED, _) => a }((provide[EncodingType](EncodingType.UNCOMPRESSED) :: list(timestampsCodec)).as[EncodedTimestamps]) - .\(1) { case a@EncodedTimestamps(EncodingType.COMPRESSED_ZLIB, _) => a }((provide[EncodingType](EncodingType.COMPRESSED_ZLIB) :: zlib(list(timestampsCodec))).as[EncodedTimestamps]) + .\(CompressionAlgorithm.Uncompressed.bitPosition.toByte) { case a@EncodedTimestamps(CompressionAlgorithm.Uncompressed, _) => a }((provide[CompressionAlgorithm](CompressionAlgorithm.Uncompressed) :: list(timestampsCodec)).as[EncodedTimestamps]) + .\(CompressionAlgorithm.ZlibDeflate.bitPosition.toByte) { case a@EncodedTimestamps(CompressionAlgorithm.ZlibDeflate, _) => a }((provide[CompressionAlgorithm](CompressionAlgorithm.ZlibDeflate) :: zlib(list(timestampsCodec))).as[EncodedTimestamps]) ) val checksumsCodec: Codec[Checksums] = ( @@ -158,7 +158,7 @@ object QueryShortChannelIdsTlv { * @param encoding 0 means uncompressed, 1 means compressed with zlib * @param array array of query flags, each flags specifies the info we want for a given channel */ - case class EncodedQueryFlags(encoding: EncodingType, array: List[Long]) extends QueryShortChannelIdsTlv { + case class EncodedQueryFlags(encoding: CompressionAlgorithm, array: List[Long]) extends QueryShortChannelIdsTlv { /** custom toString because it can get huge in logs */ override def toString: String = s"EncodedQueryFlags($encoding, size=${array.size})" } @@ -183,9 +183,8 @@ object QueryShortChannelIdsTlv { val encodedQueryFlagsCodec: Codec[EncodedQueryFlags] = discriminated[EncodedQueryFlags].by(byte) - .\(0) { case a@EncodedQueryFlags(EncodingType.UNCOMPRESSED, _) => a }((provide[EncodingType](EncodingType.UNCOMPRESSED) :: list(varintoverflow)).as[EncodedQueryFlags]) - .\(1) { case a@EncodedQueryFlags(EncodingType.COMPRESSED_ZLIB, _) => a }((provide[EncodingType](EncodingType.COMPRESSED_ZLIB) :: zlib(list(varintoverflow))).as[EncodedQueryFlags]) - + .\(CompressionAlgorithm.Uncompressed.bitPosition.toByte) { case a@EncodedQueryFlags(CompressionAlgorithm.Uncompressed, _) => a }((provide[CompressionAlgorithm](CompressionAlgorithm.Uncompressed) :: list(varintoverflow)).as[EncodedQueryFlags]) + .\(CompressionAlgorithm.ZlibDeflate.bitPosition.toByte) { case a@EncodedQueryFlags(CompressionAlgorithm.ZlibDeflate, _) => a }((provide[CompressionAlgorithm](CompressionAlgorithm.ZlibDeflate) :: zlib(list(varintoverflow))).as[EncodedQueryFlags]) val codec: Codec[TlvStream[QueryShortChannelIdsTlv]] = TlvCodecs.tlvStream(discriminated.by(varint) .typecase(UInt64(1), variableSizeBytesLong(varintoverflow, encodedQueryFlagsCodec)) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/SetupAndControlTlv.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/SetupAndControlTlv.scala index ebcb8a9252..f585d94de1 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/SetupAndControlTlv.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/SetupAndControlTlv.scala @@ -35,6 +35,9 @@ object InitTlv { /** The chains the node is interested in. */ case class Networks(chainHashes: List[ByteVector32]) extends InitTlv + /** Compression algorithms supported by the node. */ + case class CompressionAlgorithms(supported: Set[CompressionAlgorithm]) extends InitTlv + } object InitTlvCodecs { @@ -42,9 +45,20 @@ object InitTlvCodecs { import InitTlv._ private val networks: Codec[Networks] = variableSizeBytesLong(varintoverflow, list(bytes32)).as[Networks] + private val compressionAlgorithms: Codec[CompressionAlgorithms] = variableSizeBytesLong(varintoverflow, reversedBitVector).xmap( + bits => { + val supportedAlgorithms = Set( + if (bits.get(CompressionAlgorithm.Uncompressed.bitPosition)) Some(CompressionAlgorithm.Uncompressed) else Option.empty[CompressionAlgorithm], + if (bits.get(CompressionAlgorithm.ZlibDeflate.bitPosition)) Some(CompressionAlgorithm.ZlibDeflate) else Option.empty[CompressionAlgorithm] + ).flatten + CompressionAlgorithms(supportedAlgorithms) + }, + algorithms => ReversedBitVector(algorithms.supported.map(_.bitPosition).toSet) + ) val initTlvCodec = tlvStream(discriminated[InitTlv].by(varint) .typecase(UInt64(1), networks) + .typecase(UInt64(3), compressionAlgorithms) ) } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/FeaturesSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/FeaturesSpec.scala index f53df88b96..ed789b4117 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/FeaturesSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/FeaturesSpec.scala @@ -216,8 +216,8 @@ class FeaturesSpec extends AnyFunSuite { hex"" -> Features.empty, hex"0100" -> Features(VariableLengthOnion -> Mandatory), hex"028a8a" -> Features(OptionDataLossProtect -> Optional, InitialRoutingSync -> Optional, ChannelRangeQueries -> Optional, VariableLengthOnion -> Optional, ChannelRangeQueriesExtended -> Optional, PaymentSecret -> Optional, BasicMultiPartPayment -> Optional), - hex"09004200" -> Features(Map[Feature, FeatureSupport](VariableLengthOnion -> Optional, PaymentSecret -> Mandatory, ShutdownAnySegwit -> Optional), Set(UnknownFeature(24))), - hex"52000000" -> Features(Map.empty[Feature, FeatureSupport], Set(UnknownFeature(25), UnknownFeature(28), UnknownFeature(30))) + hex"48004200" -> Features(Map[Feature, FeatureSupport](VariableLengthOnion -> Optional, PaymentSecret -> Mandatory, ShutdownAnySegwit -> Optional), Set(UnknownFeature(30))), + hex"c0000000" -> Features(Map.empty[Feature, FeatureSupport], Set(UnknownFeature(30), UnknownFeature(31))) ) for ((bin, features) <- testCases) { diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala b/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala index 832959837b..d47de51ff7 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala @@ -28,7 +28,7 @@ import fr.acinq.eclair.payment.relay.Relayer.{RelayFees, RelayParams} import fr.acinq.eclair.router.Graph.WeightRatios import fr.acinq.eclair.router.PathFindingExperimentConf import fr.acinq.eclair.router.Router.{MultiPartParams, PathFindingConf, RouterConf, SearchBoundaries} -import fr.acinq.eclair.wire.protocol.{Color, EncodingType, NodeAddress, OnionRoutingPacket} +import fr.acinq.eclair.wire.protocol.{Color, CompressionAlgorithm, NodeAddress, OnionRoutingPacket} import org.scalatest.Tag import scodec.bits.ByteVector @@ -170,7 +170,7 @@ object TestConstants { routerBroadcastInterval = 5 seconds, networkStatsRefreshInterval = 1 hour, requestNodeAnnouncements = true, - encodingType = EncodingType.COMPRESSED_ZLIB, + preferredCompression = CompressionAlgorithm.ZlibDeflate, channelRangeChunkSize = 20, channelQueryChunkSize = 5, pathFindingExperimentConf = PathFindingExperimentConf(Map("alice-test-experiment" -> PathFindingConf( @@ -297,7 +297,7 @@ object TestConstants { routerBroadcastInterval = 5 seconds, networkStatsRefreshInterval = 1 hour, requestNodeAnnouncements = true, - encodingType = EncodingType.UNCOMPRESSED, + preferredCompression = CompressionAlgorithm.Uncompressed, channelRangeChunkSize = 20, channelQueryChunkSize = 5, pathFindingExperimentConf = PathFindingExperimentConf(Map("bob-test-experiment" -> PathFindingConf( diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/integration/PaymentIntegrationSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/integration/PaymentIntegrationSpec.scala index fca9b79aba..a1a0bea8a5 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/integration/PaymentIntegrationSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/integration/PaymentIntegrationSpec.scala @@ -42,8 +42,8 @@ import fr.acinq.eclair.payment.send.PaymentInitiator.{SendPaymentToNode, SendTra import fr.acinq.eclair.router.Graph.WeightRatios import fr.acinq.eclair.router.Router.{GossipDecision, PublicChannel} import fr.acinq.eclair.router.{Announcements, AnnouncementsBatchValidationSpec, Router} -import fr.acinq.eclair.wire.protocol.{ChannelAnnouncement, ChannelUpdate, IncorrectOrUnknownPaymentDetails, NodeAnnouncement} -import fr.acinq.eclair.{CltvExpiryDelta, Kit, MilliSatoshiLong, ShortChannelId, TimestampMilli, randomBytes32} +import fr.acinq.eclair.wire.protocol.{ChannelAnnouncement, ChannelUpdate, IncorrectOrUnknownPaymentDetails, Init, NodeAnnouncement} +import fr.acinq.eclair.{CltvExpiryDelta, Features, Kit, MilliSatoshiLong, ShortChannelId, TimestampMilli, randomBytes32} import org.json4s.JsonAST.{JString, JValue} import scodec.bits.ByteVector @@ -662,7 +662,7 @@ class PaymentIntegrationSpec extends IntegrationSpec { // then we make the announcements val announcements = channels.map(c => AnnouncementsBatchValidationSpec.makeChannelAnnouncement(c, bitcoinClient)) announcements.foreach { ann => - nodes("A").router ! PeerRoutingMessage(sender.ref, remoteNodeId, ann) + nodes("A").router ! PeerRoutingMessage(sender.ref, remoteNodeId, Init(Features.empty), ann) sender.expectMsg(TransportHandler.ReadAck(ann)) sender.expectMsg(GossipDecision.Accepted(ann)) } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerConnectionSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerConnectionSpec.scala index 39f4f6f419..1870ec038a 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerConnectionSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerConnectionSpec.scala @@ -76,6 +76,7 @@ class PeerConnectionSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike wi transport.expectMsgType[TransportHandler.Listener] val localInit = transport.expectMsgType[protocol.Init] assert(localInit.networks === List(Block.RegtestGenesisBlock.hash)) + assert(localInit.compressionAlgorithms === CompressionAlgorithm.defaultSupported) transport.send(peerConnection, remoteInit) transport.expectMsgType[TransportHandler.ReadAck] if (doSync) { @@ -194,6 +195,25 @@ class PeerConnectionSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike wi connect(nodeParams, remoteNodeId, switchboard, router, connection, transport, peerConnection, peer, remoteInit, doSync = true) } + test("don't sync when compression algorithms don't match") { f => + import f._ + + val probe = TestProbe() + probe.send(peerConnection, PeerConnection.PendingAuth(connection.ref, Some(remoteNodeId), address, origin_opt = None, transport_opt = Some(transport.ref))) + transport.send(peerConnection, TransportHandler.HandshakeCompleted(remoteNodeId)) + switchboard.expectMsg(PeerConnection.Authenticated(peerConnection, remoteNodeId)) + probe.send(peerConnection, PeerConnection.InitializeConnection(peer.ref, nodeParams.chainHash, nodeParams.features, doSync = true)) + transport.expectMsgType[TransportHandler.Listener] + val localInit = transport.expectMsgType[protocol.Init] + val remoteInit = protocol.Init(Bob.nodeParams.features, TlvStream(InitTlv.CompressionAlgorithms(Set.empty))) + transport.send(peerConnection, remoteInit) + transport.expectMsgType[TransportHandler.ReadAck] + // We don't send channel queries because they don't support any type of compression + router.expectNoMessage(1 second) + peer.expectMsg(PeerConnection.ConnectionReady(peerConnection, remoteNodeId, address, outgoing = true, localInit, remoteInit)) + assert(peerConnection.stateName === PeerConnection.CONNECTED) + } + test("reply to ping") { f => import f._ connect(nodeParams, remoteNodeId, switchboard, router, connection, transport, peerConnection, peer) @@ -250,7 +270,7 @@ class PeerConnectionSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike wi test("filter gossip message (no filtering)") { f => import f._ val probe = TestProbe() - val gossipOrigin = Set[GossipOrigin](RemoteGossip(TestProbe().ref, randomKey().publicKey)) + val gossipOrigin = Set[GossipOrigin](RemoteGossip(TestProbe().ref, randomKey().publicKey, protocol.Init(Bob.nodeParams.features))) connect(nodeParams, remoteNodeId, switchboard, router, connection, transport, peerConnection, peer) val rebroadcast = Rebroadcast(channels.map(_ -> gossipOrigin).toMap, updates.map(_ -> gossipOrigin).toMap, nodes.map(_ -> gossipOrigin).toMap) probe.send(peerConnection, rebroadcast) @@ -260,8 +280,8 @@ class PeerConnectionSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike wi test("filter gossip message (filtered by origin)") { f => import f._ connect(nodeParams, remoteNodeId, switchboard, router, connection, transport, peerConnection, peer) - val gossipOrigin = Set[GossipOrigin](RemoteGossip(TestProbe().ref, randomKey().publicKey)) - val bobOrigin = RemoteGossip(peerConnection, remoteNodeId) + val gossipOrigin = Set[GossipOrigin](RemoteGossip(TestProbe().ref, randomKey().publicKey, protocol.Init(Bob.nodeParams.features))) + val bobOrigin = RemoteGossip(peerConnection, remoteNodeId, protocol.Init(Bob.nodeParams.features)) val rebroadcast = Rebroadcast( channels.map(_ -> gossipOrigin).toMap + (channels(5) -> Set(bobOrigin)), updates.map(_ -> gossipOrigin).toMap + (updates(6) -> (gossipOrigin + bobOrigin)) + (updates(10) -> Set(bobOrigin)), @@ -279,7 +299,7 @@ class PeerConnectionSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike wi test("filter gossip message (filtered by timestamp)") { f => import f._ connect(nodeParams, remoteNodeId, switchboard, router, connection, transport, peerConnection, peer) - val gossipOrigin = Set[GossipOrigin](RemoteGossip(TestProbe().ref, randomKey().publicKey)) + val gossipOrigin = Set[GossipOrigin](RemoteGossip(TestProbe().ref, randomKey().publicKey, protocol.Init(Bob.nodeParams.features))) val rebroadcast = Rebroadcast(channels.map(_ -> gossipOrigin).toMap, updates.map(_ -> gossipOrigin).toMap, nodes.map(_ -> gossipOrigin).toMap) val timestamps = updates.map(_.timestamp).sorted.slice(10, 30) val filter = protocol.GossipTimestampFilter(Alice.nodeParams.chainHash, timestamps.head, (timestamps.last - timestamps.head).toSeconds) @@ -297,7 +317,7 @@ class PeerConnectionSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike wi import f._ val probe = TestProbe() connect(nodeParams, remoteNodeId, switchboard, router, connection, transport, peerConnection, peer) - val gossipOrigin = Set[GossipOrigin](RemoteGossip(TestProbe().ref, randomKey().publicKey)) + val gossipOrigin = Set[GossipOrigin](RemoteGossip(TestProbe().ref, randomKey().publicKey, protocol.Init(Bob.nodeParams.features))) val rebroadcast = Rebroadcast( channels.map(_ -> gossipOrigin).toMap + (channels(5) -> Set(LocalGossip)), updates.map(_ -> gossipOrigin).toMap + (updates(6) -> (gossipOrigin + LocalGossip)) + (updates(10) -> Set(LocalGossip)), @@ -314,17 +334,18 @@ class PeerConnectionSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike wi test("react to peer's bad behavior") { f => import f._ val probe = TestProbe() + val remoteInit = protocol.Init(Bob.nodeParams.features) connect(nodeParams, remoteNodeId, switchboard, router, connection, transport, peerConnection, peer) val query = QueryShortChannelIds( Alice.nodeParams.chainHash, - EncodedShortChannelIds(EncodingType.UNCOMPRESSED, List(ShortChannelId(42000))), + EncodedShortChannelIds(CompressionAlgorithm.Uncompressed, List(ShortChannelId(42000))), TlvStream.empty) // make sure that routing messages go through for (ann <- channels ++ updates) { transport.send(peerConnection, ann) - router.expectMsg(Peer.PeerRoutingMessage(peerConnection, remoteNodeId, ann)) + router.expectMsg(Peer.PeerRoutingMessage(peerConnection, remoteNodeId, remoteInit, ann)) } transport.expectNoMessage(1 second) // peer hasn't acknowledged the messages @@ -345,7 +366,7 @@ class PeerConnectionSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike wi router.expectNoMessage(1 second) // other routing messages go through transport.send(peerConnection, query) - router.expectMsg(Peer.PeerRoutingMessage(peerConnection, remoteNodeId, query)) + router.expectMsg(Peer.PeerRoutingMessage(peerConnection, remoteNodeId, remoteInit, query)) // after a while the ban is lifted probe.send(peerConnection, PeerConnection.ResumeAnnouncements) @@ -353,7 +374,7 @@ class PeerConnectionSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike wi // and announcements are processed again for (ann <- channels ++ updates) { transport.send(peerConnection, ann) - router.expectMsg(Peer.PeerRoutingMessage(peerConnection, remoteNodeId, ann)) + router.expectMsg(Peer.PeerRoutingMessage(peerConnection, remoteNodeId, remoteInit, ann)) } transport.expectNoMessage(1 second) // peer hasn't acknowledged the messages diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala index 89b905e58f..2d7fd4937c 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala @@ -675,9 +675,9 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val channelUpdate_hb = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_h, b, channelId_bh, CltvExpiryDelta(9), htlcMinimumMsat = 0 msat, feeBaseMsat = 10 msat, feeProportionalMillionths = 8, htlcMaximumMsat = 500000000 msat) assert(Router.getDesc(channelUpdate_bh, chan_bh) === ChannelDesc(channelId_bh, priv_b.publicKey, priv_h.publicKey)) val peerConnection = TestProbe() - router ! PeerRoutingMessage(peerConnection.ref, remoteNodeId, chan_bh) - router ! PeerRoutingMessage(peerConnection.ref, remoteNodeId, channelUpdate_bh) - router ! PeerRoutingMessage(peerConnection.ref, remoteNodeId, channelUpdate_hb) + router ! PeerRoutingMessage(peerConnection.ref, remoteNodeId, remoteInit, chan_bh) + router ! PeerRoutingMessage(peerConnection.ref, remoteNodeId, remoteInit, channelUpdate_bh) + router ! PeerRoutingMessage(peerConnection.ref, remoteNodeId, remoteInit, channelUpdate_hb) assert(watcher.expectMsgType[ValidateRequest].ann === chan_bh) watcher.send(router, ValidateResult(chan_bh, Right((Transaction(version = 0, txIn = Nil, txOut = TxOut(1000000 sat, write(pay2wsh(Scripts.multiSig2of2(funding_b, funding_h)))) :: Nil, lockTime = 0), UtxoStatus.Unspent)))) watcher.expectMsgType[WatchExternalChannelSpent] diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentRequestSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentRequestSpec.scala index a35656e7b8..b3fb1d39ea 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentRequestSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentRequestSpec.scala @@ -421,9 +421,9 @@ class PaymentRequestSpec extends AnyFunSuite { PaymentRequestFeatures(bin" 0000110000101000100000000") -> Result(allowMultiPart = false, requirePaymentSecret = true, areSupported = true), PaymentRequestFeatures(bin" 0000100000101000100000000") -> Result(allowMultiPart = false, requirePaymentSecret = true, areSupported = true), PaymentRequestFeatures(bin" 0010000000101000100000000") -> Result(allowMultiPart = false, requirePaymentSecret = true, areSupported = true), + PaymentRequestFeatures(bin" 000100000000000100000100000000") -> Result(allowMultiPart = false, requirePaymentSecret = true, areSupported = true), // those are useful for nonreg testing of the areSupported method (which needs to be updated with every new supported mandatory bit) PaymentRequestFeatures(bin" 000001000000000100000100000000") -> Result(allowMultiPart = false, requirePaymentSecret = true, areSupported = false), - PaymentRequestFeatures(bin" 000100000000000100000100000000") -> Result(allowMultiPart = false, requirePaymentSecret = true, areSupported = true), PaymentRequestFeatures(bin"00000010000000000000100000100000000") -> Result(allowMultiPart = false, requirePaymentSecret = true, areSupported = false), PaymentRequestFeatures(bin"00001000000000000000100000100000000") -> Result(allowMultiPart = false, requirePaymentSecret = true, areSupported = false) ) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/router/BaseRouterSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/router/BaseRouterSpec.scala index b2fccb69be..faab7a7dea 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/router/BaseRouterSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/router/BaseRouterSpec.scala @@ -51,6 +51,7 @@ abstract class BaseRouterSpec extends TestKitBaseClass with FixtureAnyFunSuiteLi case class FixtureParam(nodeParams: NodeParams, router: ActorRef, watcher: TestProbe) val remoteNodeId = PrivateKey(ByteVector32(ByteVector.fill(32)(1))).publicKey + val remoteInit = Init(TestConstants.Bob.nodeParams.features) val publicChannelCapacity = 1000000 sat val htlcMaximum = 500000000 msat @@ -125,30 +126,30 @@ abstract class BaseRouterSpec extends TestKitBaseClass with FixtureAnyFunSuiteLi .modify(_.routerConf.routerBroadcastInterval).setTo(1 day) // "disable" auto rebroadcast val router = system.actorOf(Router.props(nodeParams, watcher.ref)) // we announce channels - peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, chan_ab)) - peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, chan_bc)) - peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, chan_cd)) - peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, chan_ef)) - peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, chan_gh)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, remoteInit, chan_ab)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, remoteInit, chan_bc)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, remoteInit, chan_cd)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, remoteInit, chan_ef)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, remoteInit, chan_gh)) // then nodes - peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, node_b)) - peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, node_c)) - peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, node_d)) - peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, node_e)) - peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, node_f)) - peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, node_g)) - peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, node_h)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, remoteInit, node_b)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, remoteInit, node_c)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, remoteInit, node_d)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, remoteInit, node_e)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, remoteInit, node_f)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, remoteInit, node_g)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, remoteInit, node_h)) // then channel updates - peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, update_ab)) - peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, update_ba)) - peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, update_bc)) - peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, update_cb)) - peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, update_cd)) - peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, update_dc)) - peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, update_ef)) - peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, update_fe)) - peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, update_gh)) - peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, update_hg)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, remoteInit, update_ab)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, remoteInit, update_ba)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, remoteInit, update_bc)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, remoteInit, update_cb)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, remoteInit, update_cd)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, remoteInit, update_dc)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, remoteInit, update_ef)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, remoteInit, update_fe)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, remoteInit, update_gh)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, remoteInit, update_hg)) // then private channels sender.send(router, LocalChannelUpdate(sender.ref, randomBytes32(), channelId_ag, g, None, update_ag, CommitmentsSpec.makeCommitments(30000000 msat, 8000000 msat, a, g, announceChannel = false))) // watcher receives the get tx requests diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/router/ChannelRangeQueriesSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/router/ChannelRangeQueriesSpec.scala index 9bd6348636..84ea99a129 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/router/ChannelRangeQueriesSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/router/ChannelRangeQueriesSpec.scala @@ -22,7 +22,7 @@ import fr.acinq.eclair.router.Sync._ import fr.acinq.eclair.wire.protocol.QueryChannelRangeTlv.QueryFlags import fr.acinq.eclair.wire.protocol.QueryShortChannelIdsTlv.QueryFlagType._ import fr.acinq.eclair.wire.protocol.ReplyChannelRangeTlv._ -import fr.acinq.eclair.wire.protocol.{EncodedShortChannelIds, EncodingType, ReplyChannelRange} +import fr.acinq.eclair.wire.protocol.{CompressionAlgorithm, EncodedShortChannelIds, ReplyChannelRange} import fr.acinq.eclair.{MilliSatoshiLong, ShortChannelId, TimestampSecond, TimestampSecondLong, randomKey} import org.scalatest.funsuite.AnyFunSuite import scodec.bits.ByteVector @@ -358,20 +358,20 @@ class ChannelRangeQueriesSpec extends AnyFunSuite { test("do not encode empty lists as COMPRESSED_ZLIB") { { - val reply = buildReplyChannelRange(ShortChannelIdsChunk(0, 42, Nil), syncComplete = true, Block.RegtestGenesisBlock.hash, EncodingType.COMPRESSED_ZLIB, Some(QueryFlags(QueryFlags.WANT_ALL)), SortedMap()) - assert(reply == ReplyChannelRange(Block.RegtestGenesisBlock.hash, 0, 42L, 1, EncodedShortChannelIds(EncodingType.UNCOMPRESSED, Nil), Some(EncodedTimestamps(EncodingType.UNCOMPRESSED, Nil)), Some(EncodedChecksums(Nil)))) + val reply = buildReplyChannelRange(ShortChannelIdsChunk(0, 42, Nil), syncComplete = true, Block.RegtestGenesisBlock.hash, CompressionAlgorithm.ZlibDeflate, Some(QueryFlags(QueryFlags.WANT_ALL)), SortedMap()) + assert(reply == ReplyChannelRange(Block.RegtestGenesisBlock.hash, 0, 42L, 1, EncodedShortChannelIds(CompressionAlgorithm.Uncompressed, Nil), Some(EncodedTimestamps(CompressionAlgorithm.Uncompressed, Nil)), Some(EncodedChecksums(Nil)))) } { - val reply = buildReplyChannelRange(ShortChannelIdsChunk(0, 42, Nil), syncComplete = false, Block.RegtestGenesisBlock.hash, EncodingType.COMPRESSED_ZLIB, Some(QueryFlags(QueryFlags.WANT_TIMESTAMPS)), SortedMap()) - assert(reply == ReplyChannelRange(Block.RegtestGenesisBlock.hash, 0, 42L, 0, EncodedShortChannelIds(EncodingType.UNCOMPRESSED, Nil), Some(EncodedTimestamps(EncodingType.UNCOMPRESSED, Nil)), None)) + val reply = buildReplyChannelRange(ShortChannelIdsChunk(0, 42, Nil), syncComplete = false, Block.RegtestGenesisBlock.hash, CompressionAlgorithm.ZlibDeflate, Some(QueryFlags(QueryFlags.WANT_TIMESTAMPS)), SortedMap()) + assert(reply == ReplyChannelRange(Block.RegtestGenesisBlock.hash, 0, 42L, 0, EncodedShortChannelIds(CompressionAlgorithm.Uncompressed, Nil), Some(EncodedTimestamps(CompressionAlgorithm.Uncompressed, Nil)), None)) } { - val reply = buildReplyChannelRange(ShortChannelIdsChunk(0, 42, Nil), syncComplete = false, Block.RegtestGenesisBlock.hash, EncodingType.COMPRESSED_ZLIB, Some(QueryFlags(QueryFlags.WANT_CHECKSUMS)), SortedMap()) - assert(reply == ReplyChannelRange(Block.RegtestGenesisBlock.hash, 0, 42L, 0, EncodedShortChannelIds(EncodingType.UNCOMPRESSED, Nil), None, Some(EncodedChecksums(Nil)))) + val reply = buildReplyChannelRange(ShortChannelIdsChunk(0, 42, Nil), syncComplete = false, Block.RegtestGenesisBlock.hash, CompressionAlgorithm.ZlibDeflate, Some(QueryFlags(QueryFlags.WANT_CHECKSUMS)), SortedMap()) + assert(reply == ReplyChannelRange(Block.RegtestGenesisBlock.hash, 0, 42L, 0, EncodedShortChannelIds(CompressionAlgorithm.Uncompressed, Nil), None, Some(EncodedChecksums(Nil)))) } { - val reply = buildReplyChannelRange(ShortChannelIdsChunk(0, 42, Nil), syncComplete = true, Block.RegtestGenesisBlock.hash, EncodingType.COMPRESSED_ZLIB, None, SortedMap()) - assert(reply == ReplyChannelRange(Block.RegtestGenesisBlock.hash, 0, 42L, 1, EncodedShortChannelIds(EncodingType.UNCOMPRESSED, Nil), None, None)) + val reply = buildReplyChannelRange(ShortChannelIdsChunk(0, 42, Nil), syncComplete = true, Block.RegtestGenesisBlock.hash, CompressionAlgorithm.ZlibDeflate, None, SortedMap()) + assert(reply == ReplyChannelRange(Block.RegtestGenesisBlock.hash, 0, 42L, 1, EncodedShortChannelIds(CompressionAlgorithm.Uncompressed, Nil), None, None)) } } } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/router/RouterSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/router/RouterSpec.scala index 5c4b52b496..802fb80589 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/router/RouterSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/router/RouterSpec.scala @@ -56,7 +56,7 @@ class RouterSpec extends BaseRouterSpec { val chan_ac = channelAnnouncement(ShortChannelId(420000, 5, 0), priv_a, priv_c, priv_funding_a, priv_funding_c) val update_ac = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_a, c, chan_ac.shortChannelId, CltvExpiryDelta(7), 0 msat, 766000 msat, 10, htlcMaximum) val node_c = makeNodeAnnouncement(priv_c, "node-C", Color(123, 100, -40), Nil, TestConstants.Bob.nodeParams.features, timestamp = TimestampSecond.now() + 1) - peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, chan_ac)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, remoteInit, chan_ac)) peerConnection.expectNoMessage(100 millis) // we don't immediately acknowledge the announcement (back pressure) assert(watcher.expectMsgType[ValidateRequest].ann === chan_ac) watcher.send(router, ValidateResult(chan_ac, Right(Transaction(version = 0, txIn = Nil, txOut = TxOut(1000000 sat, write(pay2wsh(Scripts.multiSig2of2(funding_a, funding_c)))) :: Nil, lockTime = 0), UtxoStatus.Unspent))) @@ -64,10 +64,10 @@ class RouterSpec extends BaseRouterSpec { peerConnection.expectMsg(GossipDecision.Accepted(chan_ac)) assert(peerConnection.sender() == router) assert(watcher.expectMsgType[WatchExternalChannelSpent].shortChannelId === chan_ac.shortChannelId) - peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, update_ac)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, remoteInit, update_ac)) peerConnection.expectMsg(TransportHandler.ReadAck(update_ac)) peerConnection.expectMsg(GossipDecision.Accepted(update_ac)) - peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, node_c)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, remoteInit, node_c)) peerConnection.expectMsg(TransportHandler.ReadAck(node_c)) peerConnection.expectMsg(GossipDecision.Accepted(node_c)) eventListener.expectMsg(ChannelsDiscovered(SingleChannelDiscovered(chan_ac, 1000000 sat, None, None) :: Nil)) @@ -86,12 +86,12 @@ class RouterSpec extends BaseRouterSpec { val chan_uc = channelAnnouncement(ShortChannelId(420000, 100, 0), priv_u, priv_c, priv_funding_u, priv_funding_c) val update_uc = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_u, c, chan_uc.shortChannelId, CltvExpiryDelta(7), 0 msat, 766000 msat, 10, htlcMaximum) val node_u = makeNodeAnnouncement(priv_u, "node-U", Color(-120, -20, 60), Nil, Features.empty) - peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, chan_uc)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, remoteInit, chan_uc)) peerConnection.expectNoMessage(200 millis) // we don't immediately acknowledge the announcement (back pressure) assert(watcher.expectMsgType[ValidateRequest].ann === chan_uc) - peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, update_uc)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, remoteInit, update_uc)) peerConnection.expectMsg(TransportHandler.ReadAck(update_uc)) - peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, node_u)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, remoteInit, node_u)) peerConnection.expectMsg(TransportHandler.ReadAck(node_u)) watcher.send(router, ValidateResult(chan_uc, Right(Transaction(version = 0, txIn = Nil, txOut = TxOut(2000000 sat, write(pay2wsh(Scripts.multiSig2of2(priv_funding_u.publicKey, funding_c)))) :: Nil, lockTime = 0), UtxoStatus.Unspent))) peerConnection.expectMsg(TransportHandler.ReadAck(chan_uc)) @@ -111,13 +111,13 @@ class RouterSpec extends BaseRouterSpec { { // duplicates - peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, node_b)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, remoteInit, node_b)) peerConnection.expectMsg(TransportHandler.ReadAck(node_b)) peerConnection.expectMsg(GossipDecision.Duplicate(node_b)) - peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, chan_ab)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, remoteInit, chan_ab)) peerConnection.expectMsg(TransportHandler.ReadAck(chan_ab)) peerConnection.expectMsg(GossipDecision.Duplicate(chan_ab)) - peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, update_ab)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, remoteInit, update_ab)) peerConnection.expectMsg(TransportHandler.ReadAck(update_ab)) peerConnection.expectMsg(GossipDecision.Duplicate(update_ab)) peerConnection.expectNoMessage(100 millis) @@ -130,13 +130,13 @@ class RouterSpec extends BaseRouterSpec { val invalid_node_b = node_b.copy(timestamp = node_b.timestamp + 10) val invalid_chan_ac = channelAnnouncement(ShortChannelId(420000, 101, 1), priv_a, priv_c, priv_funding_a, priv_funding_c).copy(nodeId1 = randomKey().publicKey) val invalid_update_ab = update_ab.copy(cltvExpiryDelta = CltvExpiryDelta(21), timestamp = update_ab.timestamp + 1) - peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, invalid_node_b)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, remoteInit, invalid_node_b)) peerConnection.expectMsg(TransportHandler.ReadAck(invalid_node_b)) peerConnection.expectMsg(GossipDecision.InvalidSignature(invalid_node_b)) - peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, invalid_chan_ac)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, remoteInit, invalid_chan_ac)) peerConnection.expectMsg(TransportHandler.ReadAck(invalid_chan_ac)) peerConnection.expectMsg(GossipDecision.InvalidSignature(invalid_chan_ac)) - peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, invalid_update_ab)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, remoteInit, invalid_update_ab)) peerConnection.expectMsg(TransportHandler.ReadAck(invalid_update_ab)) peerConnection.expectMsg(GossipDecision.InvalidSignature(invalid_update_ab)) peerConnection.expectNoMessage(100 millis) @@ -150,7 +150,7 @@ class RouterSpec extends BaseRouterSpec { val priv_funding_v = randomKey() val chan_vc = channelAnnouncement(ShortChannelId(420000, 102, 0), priv_v, priv_c, priv_funding_v, priv_funding_c) nodeParams.db.network.addToPruned(chan_vc.shortChannelId :: Nil) - peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, chan_vc)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, remoteInit, chan_vc)) peerConnection.expectMsg(TransportHandler.ReadAck(chan_vc)) peerConnection.expectMsg(GossipDecision.ChannelPruned(chan_vc)) peerConnection.expectNoMessage(100 millis) @@ -161,7 +161,7 @@ class RouterSpec extends BaseRouterSpec { { // stale channel update val update_ab = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_a, priv_b.publicKey, chan_ab.shortChannelId, CltvExpiryDelta(7), 0 msat, 766000 msat, 10, htlcMaximum, timestamp = TimestampSecond.now() - 15.days) - peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, update_ab)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, remoteInit, update_ab)) peerConnection.expectMsg(TransportHandler.ReadAck(update_ab)) peerConnection.expectMsg(GossipDecision.Stale(update_ab)) peerConnection.expectNoMessage(100 millis) @@ -174,10 +174,10 @@ class RouterSpec extends BaseRouterSpec { val priv_y = randomKey() val update_ay = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_a, priv_y.publicKey, ShortChannelId(4646464), CltvExpiryDelta(7), 0 msat, 766000 msat, 10, htlcMaximum) val node_y = makeNodeAnnouncement(priv_y, "node-Y", Color(123, 100, -40), Nil, TestConstants.Bob.nodeParams.features) - peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, update_ay)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, remoteInit, update_ay)) peerConnection.expectMsg(TransportHandler.ReadAck(update_ay)) peerConnection.expectMsg(GossipDecision.NoRelatedChannel(update_ay)) - peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, node_y)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, remoteInit, node_y)) peerConnection.expectMsg(TransportHandler.ReadAck(node_y)) peerConnection.expectMsg(GossipDecision.NoKnownChannel(node_y)) peerConnection.expectNoMessage(100 millis) @@ -192,11 +192,11 @@ class RouterSpec extends BaseRouterSpec { val chan_ay = channelAnnouncement(ShortChannelId(42002), priv_a, priv_y, priv_funding_a, priv_funding_y) val update_ay = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_a, priv_y.publicKey, chan_ay.shortChannelId, CltvExpiryDelta(7), 0 msat, 766000 msat, 10, htlcMaximum) val node_y = makeNodeAnnouncement(priv_y, "node-Y", Color(123, 100, -40), Nil, TestConstants.Bob.nodeParams.features) - peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, chan_ay)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, remoteInit, chan_ay)) assert(watcher.expectMsgType[ValidateRequest].ann === chan_ay) - peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, update_ay)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, remoteInit, update_ay)) peerConnection.expectMsg(TransportHandler.ReadAck(update_ay)) - peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, node_y)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, remoteInit, node_y)) peerConnection.expectMsg(TransportHandler.ReadAck(node_y)) watcher.send(router, ValidateResult(chan_ay, Right(Transaction(version = 0, txIn = Nil, txOut = TxOut(1000000 sat, write(pay2wsh(Scripts.multiSig2of2(funding_a, randomKey().publicKey)))) :: Nil, lockTime = 0), UtxoStatus.Unspent))) peerConnection.expectMsg(TransportHandler.ReadAck(chan_ay)) @@ -212,7 +212,7 @@ class RouterSpec extends BaseRouterSpec { // validation failure val priv_x = randomKey() val chan_ax = channelAnnouncement(ShortChannelId(42001), priv_a, priv_x, priv_funding_a, randomKey()) - peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, chan_ax)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, remoteInit, chan_ax)) assert(watcher.expectMsgType[ValidateRequest].ann === chan_ax) watcher.send(router, ValidateResult(chan_ax, Left(new RuntimeException("funding tx not found")))) peerConnection.expectMsg(TransportHandler.ReadAck(chan_ax)) @@ -227,7 +227,7 @@ class RouterSpec extends BaseRouterSpec { val priv_z = randomKey() val priv_funding_z = randomKey() val chan_az = channelAnnouncement(ShortChannelId(42003), priv_a, priv_z, priv_funding_a, priv_funding_z) - peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, chan_az)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, remoteInit, chan_az)) assert(watcher.expectMsgType[ValidateRequest].ann === chan_az) watcher.send(router, ValidateResult(chan_az, Right(Transaction(version = 0, txIn = Nil, txOut = TxOut(1000000 sat, write(pay2wsh(Scripts.multiSig2of2(funding_a, priv_funding_z.publicKey)))) :: Nil, lockTime = 0), UtxoStatus.Spent(spendingTxConfirmed = false)))) peerConnection.expectMsg(TransportHandler.ReadAck(chan_az)) @@ -242,7 +242,7 @@ class RouterSpec extends BaseRouterSpec { val priv_z = randomKey() val priv_funding_z = randomKey() val chan_az = channelAnnouncement(ShortChannelId(42003), priv_a, priv_z, priv_funding_a, priv_funding_z) - peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, chan_az)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, remoteInit, chan_az)) assert(watcher.expectMsgType[ValidateRequest].ann === chan_az) watcher.send(router, ValidateResult(chan_az, Right(Transaction(version = 0, txIn = Nil, txOut = TxOut(1000000 sat, write(pay2wsh(Scripts.multiSig2of2(funding_a, priv_funding_z.publicKey)))) :: Nil, lockTime = 0), UtxoStatus.Spent(spendingTxConfirmed = true)))) peerConnection.expectMsg(TransportHandler.ReadAck(chan_az)) @@ -287,7 +287,7 @@ class RouterSpec extends BaseRouterSpec { val channelId_ac = ShortChannelId(420000, 105, 0) val chan_ac = channelAnnouncement(channelId_ac, priv_a, priv_c, priv_funding_a, priv_funding_c) val buggy_chan_ac = chan_ac.copy(nodeSignature1 = chan_ac.nodeSignature2) - peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, buggy_chan_ac)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, remoteInit, buggy_chan_ac)) peerConnection.expectMsg(TransportHandler.ReadAck(buggy_chan_ac)) peerConnection.expectMsg(GossipDecision.InvalidSignature(buggy_chan_ac)) } @@ -296,7 +296,7 @@ class RouterSpec extends BaseRouterSpec { import fixture._ val peerConnection = TestProbe() val buggy_ann_b = node_b.copy(signature = node_c.signature, timestamp = node_b.timestamp + 1) - peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, buggy_ann_b)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, remoteInit, buggy_ann_b)) peerConnection.expectMsg(TransportHandler.ReadAck(buggy_ann_b)) peerConnection.expectMsg(GossipDecision.InvalidSignature(buggy_ann_b)) } @@ -305,7 +305,7 @@ class RouterSpec extends BaseRouterSpec { import fixture._ val peerConnection = TestProbe() val buggy_channelUpdate_ab = update_ab.copy(signature = node_b.signature, timestamp = update_ab.timestamp + 1) - peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, buggy_channelUpdate_ab)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, remoteInit, buggy_channelUpdate_ab)) peerConnection.expectMsg(TransportHandler.ReadAck(buggy_channelUpdate_ab)) peerConnection.expectMsg(GossipDecision.InvalidSignature(buggy_channelUpdate_ab)) } @@ -373,7 +373,7 @@ class RouterSpec extends BaseRouterSpec { assert(res.routes.head.hops.last.nextNodeId === d) val channelUpdate_cd1 = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_c, d, channelId_cd, CltvExpiryDelta(3), 0 msat, 153000 msat, 4, htlcMaximum, enable = false) - peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, channelUpdate_cd1)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, remoteInit, channelUpdate_cd1)) peerConnection.expectMsg(TransportHandler.ReadAck(channelUpdate_cd1)) sender.send(router, RouteRequest(a, d, DEFAULT_AMOUNT_MSAT, DEFAULT_MAX_FEE, routeParams = DEFAULT_ROUTE_PARAMS)) sender.expectMsg(Failure(RouteNotFound)) @@ -454,7 +454,7 @@ class RouterSpec extends BaseRouterSpec { import fixture._ // We need a channel update from our private remote peer, otherwise we can't create invoice routing information. val peerConnection = TestProbe() - peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, g, update_ga)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, g, remoteInit, update_ga)) val sender = TestProbe() sender.send(router, GetLocalChannels) val localChannels = sender.expectMsgType[Seq[LocalChannel]] @@ -554,9 +554,9 @@ class RouterSpec extends BaseRouterSpec { val staleUpdate = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_a, c, channelId, CltvExpiryDelta(7), 0 msat, 766000 msat, 10, 5 msat, timestamp = oldTimestamp) val peerConnection = TestProbe() peerConnection.ignoreMsg { case _: TransportHandler.ReadAck => true } - peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, announcement)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, remoteInit, announcement)) watcher.expectMsgType[ValidateRequest] - peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, staleUpdate)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, remoteInit, staleUpdate)) watcher.send(router, ValidateResult(announcement, Right((Transaction(version = 0, txIn = Nil, txOut = TxOut(1000000 sat, write(pay2wsh(Scripts.multiSig2of2(funding_a, funding_c)))) :: Nil, lockTime = 0), UtxoStatus.Unspent)))) peerConnection.expectMsg(GossipDecision.Accepted(announcement)) peerConnection.expectMsg(GossipDecision.Stale(staleUpdate)) @@ -570,7 +570,7 @@ class RouterSpec extends BaseRouterSpec { val recentUpdate = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_a, c, channelId, CltvExpiryDelta(7), 0 msat, 766000 msat, 10, htlcMaximum, timestamp = TimestampSecond.now()) // we want to make sure that transport receives the query - peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, recentUpdate)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, remoteInit, recentUpdate)) peerConnection.expectMsg(GossipDecision.RelatedChannelPruned(recentUpdate)) val query = peerConnection.expectMsgType[QueryShortChannelIds] assert(query.shortChannelIds.array == List(channelId)) @@ -690,7 +690,7 @@ class RouterSpec extends BaseRouterSpec { // new announcements val update_ab_2 = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_a, b, channelId_ab, CltvExpiryDelta(7), htlcMinimumMsat = 0 msat, feeBaseMsat = 10 msat, feeProportionalMillionths = 10, htlcMaximumMsat = htlcMaximum, timestamp = update_ab.timestamp + 1) val peerConnection = TestProbe() - router ! PeerRoutingMessage(peerConnection.ref, remoteNodeId, update_ab_2) + router ! PeerRoutingMessage(peerConnection.ref, remoteNodeId, remoteInit, update_ab_2) sender.fishForMessage() { case cu: ChannelUpdatesReceived => cu == ChannelUpdatesReceived(List(update_ab_2)) case _ => false diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/router/RoutingSyncSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/router/RoutingSyncSpec.scala index f3fccd4a46..ebe3d7985e 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/router/RoutingSyncSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/router/RoutingSyncSpec.scala @@ -35,7 +35,7 @@ import fr.acinq.eclair.wire.protocol._ import org.scalatest.ParallelTestExecution import org.scalatest.funsuite.AnyFunSuiteLike -import scala.collection.immutable.{TreeMap, SortedSet} +import scala.collection.immutable.{SortedSet, TreeMap} import scala.collection.mutable import scala.concurrent.duration._ @@ -86,7 +86,7 @@ class RoutingSyncSpec extends TestKitBaseClass with AnyFunSuiteLike with Paralle sender.send(src, SendChannelQuery(src.underlyingActor.nodeParams.chainHash, tgtId, pipe.ref, replacePrevious = true, extendedQueryFlags_opt)) // src sends a query_channel_range to bob val qcr = pipe.expectMsgType[QueryChannelRange] - pipe.send(tgt, PeerRoutingMessage(pipe.ref, srcId, qcr)) + pipe.send(tgt, PeerRoutingMessage(pipe.ref, srcId, defaultInit, qcr)) // this allows us to know when the last reply_channel_range has been set pipe.send(tgt, Router.GetRouterData) // tgt answers with reply_channel_ranges @@ -96,7 +96,7 @@ class RoutingSyncSpec extends TestKitBaseClass with AnyFunSuiteLike with Paralle rcrs.dropRight(1).foreach(rcr => assert(rcr.syncComplete == 0)) assert(rcrs.last.syncComplete == 1) pipe.expectMsgType[Data] - rcrs.foreach(rcr => pipe.send(src, PeerRoutingMessage(pipe.ref, tgtId, rcr))) + rcrs.foreach(rcr => pipe.send(src, PeerRoutingMessage(pipe.ref, tgtId, defaultInit, rcr))) // then src will now query announcements var queries = Vector.empty[QueryShortChannelIds] var channels = Vector.empty[ChannelAnnouncement] @@ -105,7 +105,7 @@ class RoutingSyncSpec extends TestKitBaseClass with AnyFunSuiteLike with Paralle while (src.stateData.sync.nonEmpty) { // for each chunk, src sends a query_short_channel_id val query = pipe.expectMsgType[QueryShortChannelIds] - pipe.send(tgt, PeerRoutingMessage(pipe.ref, srcId, query)) + pipe.send(tgt, PeerRoutingMessage(pipe.ref, srcId, defaultInit, query)) queries = queries :+ query val announcements = pipe.receiveWhile() { case c: ChannelAnnouncement => @@ -119,10 +119,10 @@ class RoutingSyncSpec extends TestKitBaseClass with AnyFunSuiteLike with Paralle n } // tgt replies with announcements - announcements.foreach(ann => pipe.send(src, PeerRoutingMessage(pipe.ref, tgtId, ann))) + announcements.foreach(ann => pipe.send(src, PeerRoutingMessage(pipe.ref, tgtId, defaultInit, ann))) // and tgt ends this chunk with a reply_short_channel_id_end val rscie = pipe.expectMsgType[ReplyShortChannelIdsEnd] - pipe.send(src, PeerRoutingMessage(pipe.ref, tgtId, rscie)) + pipe.send(src, PeerRoutingMessage(pipe.ref, tgtId, defaultInit, rscie)) } SyncResult(rcrs, queries, channels, updates, nodes) } @@ -147,11 +147,11 @@ class RoutingSyncSpec extends TestKitBaseClass with AnyFunSuiteLike with Paralle // add some channels and updates to bob and resync fakeRoutingInfo.take(10).values.foreach { case (pc, na1, na2) => - sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, pc.ann)) - sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, pc.update_1_opt.get)) + sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, defaultInit, pc.ann)) + sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, defaultInit, pc.update_1_opt.get)) // we don't send channel_update #2 - sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, na1)) - sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, na2)) + sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, defaultInit, na1)) + sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, defaultInit, na2)) } awaitCond(bob.stateData.channels.size === 10 && countUpdates(bob.stateData.channels) === 10) assert(BasicSyncResult(ranges = 1, queries = 2, channels = 10, updates = 10, nodes = 10 * 2) === sync(alice, bob, extendedQueryFlags_opt).counts) @@ -160,7 +160,7 @@ class RoutingSyncSpec extends TestKitBaseClass with AnyFunSuiteLike with Paralle // add some updates to bob and resync fakeRoutingInfo.take(10).values.foreach { case (pc, _, _) => - sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, pc.update_2_opt.get)) + sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, defaultInit, pc.update_2_opt.get)) } awaitCond(bob.stateData.channels.size === 10 && countUpdates(bob.stateData.channels) === 10 * 2) assert(BasicSyncResult(ranges = 1, queries = 2, channels = 10, updates = 10 * 2, nodes = 10 * 2) === sync(alice, bob, extendedQueryFlags_opt).counts) @@ -169,11 +169,11 @@ class RoutingSyncSpec extends TestKitBaseClass with AnyFunSuiteLike with Paralle // add everything (duplicates will be ignored) fakeRoutingInfo.values.foreach { case (pc, na1, na2) => - sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, pc.ann)) - sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, pc.update_1_opt.get)) - sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, pc.update_2_opt.get)) - sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, na1)) - sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, na2)) + sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, defaultInit, pc.ann)) + sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, defaultInit, pc.update_1_opt.get)) + sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, defaultInit, pc.update_2_opt.get)) + sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, defaultInit, na1)) + sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, defaultInit, na2)) } awaitCond(bob.stateData.channels.size === fakeRoutingInfo.size && countUpdates(bob.stateData.channels) === 2 * fakeRoutingInfo.size, max = 60 seconds) assert(BasicSyncResult(ranges = 3, queries = 13, channels = fakeRoutingInfo.size, updates = 2 * fakeRoutingInfo.size, nodes = 2 * fakeRoutingInfo.size) === sync(alice, bob, extendedQueryFlags_opt).counts) @@ -195,11 +195,11 @@ class RoutingSyncSpec extends TestKitBaseClass with AnyFunSuiteLike with Paralle // add some channels and updates to bob and resync fakeRoutingInfo.take(10).values.foreach { case (pc, na1, na2) => - sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, pc.ann)) - sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, pc.update_1_opt.get)) + sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, defaultInit, pc.ann)) + sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, defaultInit, pc.update_1_opt.get)) // we don't send channel_update #2 - sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, na1)) - sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, na2)) + sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, defaultInit, na1)) + sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, defaultInit, na2)) } awaitCond(bob.stateData.channels.size === 10 && countUpdates(bob.stateData.channels) === 10) assert(BasicSyncResult(ranges = 1, queries = 2, channels = 10, updates = 10, nodes = if (requestNodeAnnouncements) 10 * 2 else 0) === sync(alice, bob, extendedQueryFlags_opt).counts) @@ -209,7 +209,7 @@ class RoutingSyncSpec extends TestKitBaseClass with AnyFunSuiteLike with Paralle // add some updates to bob and resync fakeRoutingInfo.take(10).values.foreach { case (pc, _, _) => - sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, pc.update_2_opt.get)) + sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, defaultInit, pc.update_2_opt.get)) } awaitCond(bob.stateData.channels.size === 10 && countUpdates(bob.stateData.channels) === 10 * 2) assert(BasicSyncResult(ranges = 1, queries = 2, channels = 0, updates = 10, nodes = if (requestNodeAnnouncements) 10 * 2 else 0) === sync(alice, bob, extendedQueryFlags_opt).counts) @@ -218,11 +218,11 @@ class RoutingSyncSpec extends TestKitBaseClass with AnyFunSuiteLike with Paralle // add everything (duplicates will be ignored) fakeRoutingInfo.values.foreach { case (pc, na1, na2) => - sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, pc.ann)) - sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, pc.update_1_opt.get)) - sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, pc.update_2_opt.get)) - sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, na1)) - sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, na2)) + sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, defaultInit, pc.ann)) + sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, defaultInit, pc.update_1_opt.get)) + sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, defaultInit, pc.update_2_opt.get)) + sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, defaultInit, na1)) + sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, defaultInit, na2)) } awaitCond(bob.stateData.channels.size === fakeRoutingInfo.size && countUpdates(bob.stateData.channels) === 2 * fakeRoutingInfo.size, max = 60 seconds) assert(BasicSyncResult(ranges = 3, queries = 11, channels = fakeRoutingInfo.size - 10, updates = 2 * (fakeRoutingInfo.size - 10), nodes = if (requestNodeAnnouncements) 2 * (fakeRoutingInfo.size - 10) else 0) === sync(alice, bob, extendedQueryFlags_opt).counts) @@ -235,7 +235,7 @@ class RoutingSyncSpec extends TestKitBaseClass with AnyFunSuiteLike with Paralle } val bumpedUpdates = (List(0, 3, 7).map(touchUpdate(_, side = true)) ++ List(1, 3, 9).map(touchUpdate(_, side = false))).toSet - bumpedUpdates.foreach(c => sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, c))) + bumpedUpdates.foreach(c => sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, defaultInit, c))) assert(BasicSyncResult(ranges = 3, queries = 1, channels = 0, updates = bumpedUpdates.size, nodes = if (requestNodeAnnouncements) 5 * 2 else 0) === sync(alice, bob, extendedQueryFlags_opt).counts) awaitCond(alice.stateData.channels === bob.stateData.channels, max = 60 seconds) if (requestNodeAnnouncements) awaitCond(alice.stateData.nodes === bob.stateData.nodes) @@ -270,10 +270,10 @@ class RoutingSyncSpec extends TestKitBaseClass with AnyFunSuiteLike with Paralle sender.expectNoMessage(100 millis) // it's a duplicate and should be ignored assert(router.stateData.sync.get(remoteNodeId) === Some(Syncing(Nil, 0))) - val block1 = ReplyChannelRange(chainHash, firstBlockNum, numberOfBlocks, 1, EncodedShortChannelIds(EncodingType.UNCOMPRESSED, fakeRoutingInfo.take(params.routerConf.channelQueryChunkSize).keys.toList), None, None) + val block1 = ReplyChannelRange(chainHash, firstBlockNum, numberOfBlocks, 1, EncodedShortChannelIds(CompressionAlgorithm.Uncompressed, fakeRoutingInfo.take(params.routerConf.channelQueryChunkSize).keys.toList), None, None) // send first block - peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, block1)) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, defaultInit, block1)) // router should ask for our first block of ids assert(peerConnection.expectMsgType[QueryShortChannelIds] === QueryShortChannelIds(chainHash, block1.shortChannelIds, TlvStream.empty)) @@ -300,17 +300,35 @@ class RoutingSyncSpec extends TestKitBaseClass with AnyFunSuiteLike with Paralle assert(!router.stateData.sync.contains(remoteNodeId)) // we didn't send a corresponding query_channel_range, but peer sends us a reply_channel_range - val unsolicitedBlocks = ReplyChannelRange(params.chainHash, 10, 5, 0, EncodedShortChannelIds(EncodingType.UNCOMPRESSED, fakeRoutingInfo.take(5).keys.toList), None, None) - peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, unsolicitedBlocks)) + val unsolicitedBlocks = ReplyChannelRange(params.chainHash, 10, 5, 0, EncodedShortChannelIds(CompressionAlgorithm.Uncompressed, fakeRoutingInfo.take(5).keys.toList), None, None) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, defaultInit, unsolicitedBlocks)) // it will be simply ignored peerConnection.expectNoMessage(100 millis) assert(!router.stateData.sync.contains(remoteNodeId)) } + test("reject sync with no matching compression algorithm") { + val params = TestConstants.Alice.nodeParams + val router = TestFSMRef(new Router(params, TestProbe().ref)) + val peerConnection = TestProbe() + peerConnection.ignoreMsg { case _: TransportHandler.ReadAck => true } + val sender = TestProbe() + sender.ignoreMsg { case _: TransportHandler.ReadAck => true } + + // our peer doesn't support any of our compression algorithms, but sends us a query_channel_range + val remoteNodeId = TestConstants.Bob.nodeParams.nodeId + val remoteInit = Init(Features(Features.CompressionSupport -> FeatureSupport.Optional), TlvStream(InitTlv.CompressionAlgorithms(Set.empty))) + val queryChannelRange = QueryChannelRange(TestConstants.Bob.nodeParams.chainHash, 0, 0xffffffff) + peerConnection.send(router, PeerRoutingMessage(peerConnection.ref, remoteNodeId, remoteInit, queryChannelRange)) + + // it should be simply ignored + peerConnection.expectNoMessage(100 millis) + } + test("sync progress") { - def req = QueryShortChannelIds(Block.RegtestGenesisBlock.hash, EncodedShortChannelIds(EncodingType.UNCOMPRESSED, List(ShortChannelId(42))), TlvStream.empty) + def req = QueryShortChannelIds(Block.RegtestGenesisBlock.hash, EncodedShortChannelIds(CompressionAlgorithm.Uncompressed, List(ShortChannelId(42))), TlvStream.empty) val nodeIdA = randomKey().publicKey val nodeIdB = randomKey().publicKey @@ -339,6 +357,8 @@ object RoutingSyncSpec { val unused: PrivateKey = randomKey() + val defaultInit = Init(Features.empty) + def makeFakeRoutingInfo(pub2priv: mutable.Map[PublicKey, PrivateKey])(shortChannelId: ShortChannelId): (PublicChannel, NodeAnnouncement, NodeAnnouncement) = { val timestamp = TimestampSecond.now() val (priv1, priv2) = { diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/CommonCodecsSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/CommonCodecsSpec.scala index 05fa6eb687..ec3c0e2097 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/CommonCodecsSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/CommonCodecsSpec.scala @@ -25,7 +25,7 @@ import fr.acinq.eclair.wire.protocol.CommonCodecs._ import fr.acinq.eclair.{UInt64, randomBytes32} import org.scalatest.funsuite.AnyFunSuite import scodec.DecodeResult -import scodec.bits.{BitVector, HexStringSyntax} +import scodec.bits.{BinStringSyntax, BitVector, HexStringSyntax} import scodec.codecs.uint32 import java.net.{Inet4Address, Inet6Address, InetAddress} @@ -137,6 +137,23 @@ class CommonCodecsSpec extends AnyFunSuite { } } + test("encode/decode reversed bit vector") { + case class TestCase(encoded: BitVector, decoded: ReversedBitVector, reEncoded: BitVector) + val testCases = Seq( + TestCase(bin"", ReversedBitVector(Set.empty), bin""), + TestCase(bin"00000000", ReversedBitVector(Set.empty), bin""), + TestCase(bin"0000000000000000", ReversedBitVector(Set.empty), bin""), + TestCase(bin"0000000001010001", ReversedBitVector(Set(0, 4, 6)), bin"01010001"), + TestCase(bin"1001000001110001", ReversedBitVector(Set(0, 4, 5, 6, 12, 15)), bin"1001000001110001"), + TestCase(bin"0101000000000000", ReversedBitVector(Set(12, 14)), bin"0101000000000000"), + ) + + for (testCase <- testCases) { + assert(reversedBitVector.decode(testCase.encoded).require.value === testCase.decoded) + assert(reversedBitVector.encode(testCase.decoded).require === testCase.reEncoded) + } + } + test("encode/decode with rgb codec") { val color = Color(47.toByte, 255.toByte, 142.toByte) val bin = rgb.encode(color).require diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/ExtendedQueriesCodecsSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/ExtendedQueriesCodecsSpec.scala index bec1abf9bd..91bf5af5c0 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/ExtendedQueriesCodecsSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/ExtendedQueriesCodecsSpec.scala @@ -29,7 +29,7 @@ class ExtendedQueriesCodecsSpec extends AnyFunSuite { test("encode a list of short channel ids") { { // encode/decode with encoding 'uncompressed' - val ids = EncodedShortChannelIds(EncodingType.UNCOMPRESSED, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))) + val ids = EncodedShortChannelIds(CompressionAlgorithm.Uncompressed, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))) val encoded = encodedShortChannelIdsCodec.encode(ids).require val decoded = encodedShortChannelIdsCodec.decode(encoded).require.value assert(decoded === ids) @@ -37,7 +37,7 @@ class ExtendedQueriesCodecsSpec extends AnyFunSuite { { // encode/decode with encoding 'zlib' - val ids = EncodedShortChannelIds(EncodingType.COMPRESSED_ZLIB, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))) + val ids = EncodedShortChannelIds(CompressionAlgorithm.ZlibDeflate, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))) val encoded = encodedShortChannelIdsCodec.encode(ids).require val decoded = encodedShortChannelIdsCodec.decode(encoded).require.value assert(decoded === ids) @@ -45,7 +45,7 @@ class ExtendedQueriesCodecsSpec extends AnyFunSuite { { // encode/decode empty list with encoding 'uncompressed' - val ids = EncodedShortChannelIds(EncodingType.UNCOMPRESSED, List.empty) + val ids = EncodedShortChannelIds(CompressionAlgorithm.Uncompressed, List.empty) val encoded = encodedShortChannelIdsCodec.encode(ids).require assert(encoded.bytes === hex"00") val decoded = encodedShortChannelIdsCodec.decode(encoded).require.value @@ -54,18 +54,18 @@ class ExtendedQueriesCodecsSpec extends AnyFunSuite { { // encode/decode empty list with encoding 'zlib' - val ids = EncodedShortChannelIds(EncodingType.COMPRESSED_ZLIB, List.empty) + val ids = EncodedShortChannelIds(CompressionAlgorithm.ZlibDeflate, List.empty) val encoded = encodedShortChannelIdsCodec.encode(ids).require assert(encoded.bytes === hex"00") // NB: empty list is always encoded with encoding type 'uncompressed' val decoded = encodedShortChannelIdsCodec.decode(encoded).require.value - assert(decoded === EncodedShortChannelIds(EncodingType.UNCOMPRESSED, List.empty)) + assert(decoded === EncodedShortChannelIds(CompressionAlgorithm.Uncompressed, List.empty)) } } test("encode query_short_channel_ids (no optional data)") { val query_short_channel_id = QueryShortChannelIds( Block.RegtestGenesisBlock.blockId, - EncodedShortChannelIds(EncodingType.UNCOMPRESSED, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))), + EncodedShortChannelIds(CompressionAlgorithm.Uncompressed, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))), TlvStream.empty) val encoded = queryShortChannelIdsCodec.encode(query_short_channel_id).require @@ -76,8 +76,8 @@ class ExtendedQueriesCodecsSpec extends AnyFunSuite { test("encode query_short_channel_ids (with optional data)") { val query_short_channel_id = QueryShortChannelIds( Block.RegtestGenesisBlock.blockId, - EncodedShortChannelIds(EncodingType.UNCOMPRESSED, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))), - TlvStream(QueryShortChannelIdsTlv.EncodedQueryFlags(EncodingType.UNCOMPRESSED, List(1.toByte, 2.toByte, 3.toByte, 4.toByte, 5.toByte)))) + EncodedShortChannelIds(CompressionAlgorithm.Uncompressed, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))), + TlvStream(QueryShortChannelIdsTlv.EncodedQueryFlags(CompressionAlgorithm.Uncompressed, List(1.toByte, 2.toByte, 3.toByte, 4.toByte, 5.toByte)))) val encoded = queryShortChannelIdsCodec.encode(query_short_channel_id).require val decoded = queryShortChannelIdsCodec.decode(encoded).require.value @@ -87,9 +87,9 @@ class ExtendedQueriesCodecsSpec extends AnyFunSuite { test("encode query_short_channel_ids (with optional data including unknown data)") { val query_short_channel_id = QueryShortChannelIds( Block.RegtestGenesisBlock.blockId, - EncodedShortChannelIds(EncodingType.UNCOMPRESSED, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))), + EncodedShortChannelIds(CompressionAlgorithm.Uncompressed, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))), TlvStream( - QueryShortChannelIdsTlv.EncodedQueryFlags(EncodingType.UNCOMPRESSED, List(1.toByte, 2.toByte, 3.toByte, 4.toByte, 5.toByte)) :: Nil, + QueryShortChannelIdsTlv.EncodedQueryFlags(CompressionAlgorithm.Uncompressed, List(1.toByte, 2.toByte, 3.toByte, 4.toByte, 5.toByte)) :: Nil, GenericTlv(UInt64(43), ByteVector.fromValidHex("deadbeef")) :: Nil ) ) @@ -104,7 +104,7 @@ class ExtendedQueriesCodecsSpec extends AnyFunSuite { Block.RegtestGenesisBlock.blockId, 1, 100, 1.toByte, - EncodedShortChannelIds(EncodingType.UNCOMPRESSED, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))), + EncodedShortChannelIds(CompressionAlgorithm.Uncompressed, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))), None, None) val encoded = replyChannelRangeCodec.encode(replyChannelRange).require @@ -117,8 +117,8 @@ class ExtendedQueriesCodecsSpec extends AnyFunSuite { Block.RegtestGenesisBlock.blockId, 1, 100, 1.toByte, - EncodedShortChannelIds(EncodingType.UNCOMPRESSED, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))), - Some(EncodedTimestamps(EncodingType.COMPRESSED_ZLIB, List(Timestamps(1 unixsec, 1 unixsec), Timestamps(2 unixsec, 2 unixsec), Timestamps(3 unixsec, 3 unixsec)))), + EncodedShortChannelIds(CompressionAlgorithm.Uncompressed, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))), + Some(EncodedTimestamps(CompressionAlgorithm.ZlibDeflate, List(Timestamps(1 unixsec, 1 unixsec), Timestamps(2 unixsec, 2 unixsec), Timestamps(3 unixsec, 3 unixsec)))), None) val encoded = replyChannelRangeCodec.encode(replyChannelRange).require @@ -131,10 +131,10 @@ class ExtendedQueriesCodecsSpec extends AnyFunSuite { Block.RegtestGenesisBlock.blockId, 1, 100, 1.toByte, - EncodedShortChannelIds(EncodingType.UNCOMPRESSED, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))), + EncodedShortChannelIds(CompressionAlgorithm.Uncompressed, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))), TlvStream( List( - EncodedTimestamps(EncodingType.COMPRESSED_ZLIB, List(Timestamps(1 unixsec, 1 unixsec), Timestamps(2 unixsec, 2 unixsec), Timestamps(3 unixsec, 3 unixsec))), + EncodedTimestamps(CompressionAlgorithm.ZlibDeflate, List(Timestamps(1 unixsec, 1 unixsec), Timestamps(2 unixsec, 2 unixsec), Timestamps(3 unixsec, 3 unixsec))), EncodedChecksums(List(Checksums(1, 1), Checksums(2, 2), Checksums(3, 3))) ), GenericTlv(UInt64(7), ByteVector.fromValidHex("deadbeef")) :: Nil diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/LightningMessageCodecsSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/LightningMessageCodecsSpec.scala index 8fcbca9e59..dfe10912cf 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/LightningMessageCodecsSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/LightningMessageCodecsSpec.scala @@ -49,24 +49,30 @@ class LightningMessageCodecsSpec extends AnyFunSuite { def publicKey(fill: Byte) = PrivateKey(ByteVector.fill(32)(fill)).publicKey test("encode/decode init message") { - case class TestCase(encoded: ByteVector, rawFeatures: ByteVector, networks: List[ByteVector32], valid: Boolean, reEncoded: Option[ByteVector] = None) + case class TestCase(encoded: ByteVector, rawFeatures: ByteVector, networks: List[ByteVector32], compressionAlgorithms: Set[CompressionAlgorithm], valid: Boolean, reEncoded: Option[ByteVector] = None) val chainHash1 = ByteVector32(hex"0101010101010101010101010101010101010101010101010101010101010101") val chainHash2 = ByteVector32(hex"0202020202020202020202020202020202020202020202020202020202020202") val testCases = Seq( - TestCase(hex"0000 0000", hex"", Nil, valid = true), // no features - TestCase(hex"0000 0002088a", hex"088a", Nil, valid = true), // no global features - TestCase(hex"00020200 0000", hex"0200", Nil, valid = true, Some(hex"0000 00020200")), // no local features - TestCase(hex"00020200 0002088a", hex"0a8a", Nil, valid = true, Some(hex"0000 00020a8a")), // local and global - no conflict - same size - TestCase(hex"00020200 0003020002", hex"020202", Nil, valid = true, Some(hex"0000 0003020202")), // local and global - no conflict - different sizes - TestCase(hex"00020a02 0002088a", hex"0a8a", Nil, valid = true, Some(hex"0000 00020a8a")), // local and global - conflict - same size - TestCase(hex"00022200 000302aaa2", hex"02aaa2", Nil, valid = true, Some(hex"0000 000302aaa2")), // local and global - conflict - different sizes - TestCase(hex"0000 0002088a 03012a05022aa2", hex"088a", Nil, valid = true), // unknown odd records - TestCase(hex"0000 0002088a 03012a04022aa2", hex"088a", Nil, valid = false), // unknown even records - TestCase(hex"0000 0002088a 0120010101010101010101010101010101010101010101010101010101010101", hex"088a", Nil, valid = false), // invalid tlv stream - TestCase(hex"0000 0002088a 01200101010101010101010101010101010101010101010101010101010101010101", hex"088a", List(chainHash1), valid = true), // single network - TestCase(hex"0000 0002088a 014001010101010101010101010101010101010101010101010101010101010101010202020202020202020202020202020202020202020202020202020202020202", hex"088a", List(chainHash1, chainHash2), valid = true), // multiple networks - TestCase(hex"0000 0002088a 0120010101010101010101010101010101010101010101010101010101010101010103012a", hex"088a", List(chainHash1), valid = true), // network and unknown odd records - TestCase(hex"0000 0002088a 0120010101010101010101010101010101010101010101010101010101010101010102012a", hex"088a", Nil, valid = false) // network and unknown even records + TestCase(hex"0000 0000", hex"", Nil, CompressionAlgorithm.defaultSupported, valid = true), // no features + TestCase(hex"0000 0002088a", hex"088a", Nil, CompressionAlgorithm.defaultSupported, valid = true), // no global features + TestCase(hex"00020200 0000", hex"0200", Nil, CompressionAlgorithm.defaultSupported, valid = true, Some(hex"0000 00020200")), // no local features + TestCase(hex"00020200 0002088a", hex"0a8a", Nil, CompressionAlgorithm.defaultSupported, valid = true, Some(hex"0000 00020a8a")), // local and global - no conflict - same size + TestCase(hex"00020200 0003020002", hex"020202", Nil, CompressionAlgorithm.defaultSupported, valid = true, Some(hex"0000 0003020202")), // local and global - no conflict - different sizes + TestCase(hex"00020a02 0002088a", hex"0a8a", Nil, CompressionAlgorithm.defaultSupported, valid = true, Some(hex"0000 00020a8a")), // local and global - conflict - same size + TestCase(hex"00022200 000302aaa2", hex"02aaa2", Nil, CompressionAlgorithm.defaultSupported, valid = true, Some(hex"0000 000302aaa2")), // local and global - conflict - different sizes + TestCase(hex"0000 0002088a 05022aa207012a", hex"088a", Nil, CompressionAlgorithm.defaultSupported, valid = true), // unknown odd records + TestCase(hex"0000 0002088a 04022aa207012a", hex"088a", Nil, CompressionAlgorithm.defaultSupported, valid = false), // unknown even records + TestCase(hex"0000 0002088a 0120010101010101010101010101010101010101010101010101010101010101", hex"088a", Nil, CompressionAlgorithm.defaultSupported, valid = false), // invalid tlv stream + TestCase(hex"0000 0002088a 01200101010101010101010101010101010101010101010101010101010101010101", hex"088a", List(chainHash1), CompressionAlgorithm.defaultSupported, valid = true), // single network + TestCase(hex"0000 0002088a 014001010101010101010101010101010101010101010101010101010101010101010202020202020202020202020202020202020202020202020202020202020202", hex"088a", List(chainHash1, chainHash2), CompressionAlgorithm.defaultSupported, valid = true), // multiple networks + TestCase(hex"0000 0002088a 0120010101010101010101010101010101010101010101010101010101010101010107012a", hex"088a", List(chainHash1), CompressionAlgorithm.defaultSupported, valid = true), // network and unknown odd records + TestCase(hex"0000 0002088a 0300", hex"088a", Nil, Set.empty, valid = true), // no compression support + TestCase(hex"0000 0002088a 030100", hex"088a", Nil, Set.empty, valid = true, reEncoded = Some(hex"0000 0002088a 0300")), // no compression support + TestCase(hex"0000 0002088a 030101", hex"088a", Nil, Set(CompressionAlgorithm.Uncompressed), valid = true), // no zlib compression support + TestCase(hex"0000 0002088a 030102", hex"088a", Nil, Set(CompressionAlgorithm.ZlibDeflate), valid = true), // only zlib compression support + TestCase(hex"0000 0002088a 030103", hex"088a", Nil, Set(CompressionAlgorithm.Uncompressed, CompressionAlgorithm.ZlibDeflate), valid = true), // zlib compression and uncompressed support + TestCase(hex"0000 0002088a 0302ff0a", hex"088a", Nil, Set(CompressionAlgorithm.ZlibDeflate), valid = true, reEncoded = Some(hex"0000 0002088a 030102")), // zlib and unknown compression support + TestCase(hex"0000 0002088a 0120010101010101010101010101010101010101010101010101010101010101010102012a", hex"088a", Nil, CompressionAlgorithm.defaultSupported, valid = false) // network and unknown even records ) for (testCase <- testCases) { @@ -74,6 +80,7 @@ class LightningMessageCodecsSpec extends AnyFunSuite { val init = initCodec.decode(testCase.encoded.bits).require.value assert(init.features.toByteVector === testCase.rawFeatures) assert(init.networks === testCase.networks) + assert(init.compressionAlgorithms === testCase.compressionAlgorithms) val encoded = initCodec.encode(init).require assert(encoded.bytes === testCase.reEncoded.getOrElse(testCase.encoded)) assert(initCodec.decode(encoded).require.value === init) @@ -268,16 +275,16 @@ class LightningMessageCodecsSpec extends AnyFunSuite { val channel_update = ChannelUpdate(randomBytes64(), Block.RegtestGenesisBlock.hash, ShortChannelId(1), 2 unixsec, ChannelUpdate.ChannelFlags.DUMMY, CltvExpiryDelta(3), 4 msat, 5 msat, 6, None) val announcement_signatures = AnnouncementSignatures(randomBytes32(), ShortChannelId(42), randomBytes64(), randomBytes64()) val gossip_timestamp_filter = GossipTimestampFilter(Block.RegtestGenesisBlock.blockId, 100000 unixsec, 1500) - val query_short_channel_id = QueryShortChannelIds(Block.RegtestGenesisBlock.blockId, EncodedShortChannelIds(EncodingType.UNCOMPRESSED, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))), TlvStream.empty) + val query_short_channel_id = QueryShortChannelIds(Block.RegtestGenesisBlock.blockId, EncodedShortChannelIds(CompressionAlgorithm.Uncompressed, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))), TlvStream.empty) val unknownTlv = GenericTlv(UInt64(5), ByteVector.fromValidHex("deadbeef")) val query_channel_range = QueryChannelRange(Block.RegtestGenesisBlock.blockId, 100000, 1500, TlvStream(QueryChannelRangeTlv.QueryFlags(QueryChannelRangeTlv.QueryFlags.WANT_ALL) :: Nil, unknownTlv :: Nil)) val reply_channel_range = ReplyChannelRange(Block.RegtestGenesisBlock.blockId, 100000, 1500, 1, - EncodedShortChannelIds(EncodingType.UNCOMPRESSED, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))), + EncodedShortChannelIds(CompressionAlgorithm.Uncompressed, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))), TlvStream( - EncodedTimestamps(EncodingType.UNCOMPRESSED, List(Timestamps(1 unixsec, 1 unixsec), Timestamps(2 unixsec, 2 unixsec), Timestamps(3 unixsec, 3 unixsec))) :: EncodedChecksums(List(Checksums(1, 1), Checksums(2, 2), Checksums(3, 3))) :: Nil, + EncodedTimestamps(CompressionAlgorithm.Uncompressed, List(Timestamps(1 unixsec, 1 unixsec), Timestamps(2 unixsec, 2 unixsec), Timestamps(3 unixsec, 3 unixsec))) :: EncodedChecksums(List(Checksums(1, 1), Checksums(2, 2), Checksums(3, 3))) :: Nil, unknownTlv :: Nil) ) val ping = Ping(100, bin(10, 1)) @@ -314,13 +321,13 @@ class LightningMessageCodecsSpec extends AnyFunSuite { test("non-reg encoding type") { val refs = Map( hex"01050f9188f13cb7b2c71f2a335e3a4fc328bf5beb436012afca590b1a11466e2206001900000000000000008e0000000000003c69000000000045a6c4" - -> QueryShortChannelIds(Block.RegtestGenesisBlock.blockId, EncodedShortChannelIds(EncodingType.UNCOMPRESSED, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))), TlvStream.empty), + -> QueryShortChannelIds(Block.RegtestGenesisBlock.blockId, EncodedShortChannelIds(CompressionAlgorithm.Uncompressed, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))), TlvStream.empty), hex"01050f9188f13cb7b2c71f2a335e3a4fc328bf5beb436012afca590b1a11466e2206001601789c636000833e08659309a65c971d0100126e02e3" - -> QueryShortChannelIds(Block.RegtestGenesisBlock.blockId, EncodedShortChannelIds(EncodingType.COMPRESSED_ZLIB, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))), TlvStream.empty), + -> QueryShortChannelIds(Block.RegtestGenesisBlock.blockId, EncodedShortChannelIds(CompressionAlgorithm.ZlibDeflate, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))), TlvStream.empty), hex"01050f9188f13cb7b2c71f2a335e3a4fc328bf5beb436012afca590b1a11466e2206001900000000000000008e0000000000003c69000000000045a6c4010400010204" - -> QueryShortChannelIds(Block.RegtestGenesisBlock.blockId, EncodedShortChannelIds(EncodingType.UNCOMPRESSED, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))), TlvStream(QueryShortChannelIdsTlv.EncodedQueryFlags(EncodingType.UNCOMPRESSED, List(1, 2, 4)))), + -> QueryShortChannelIds(Block.RegtestGenesisBlock.blockId, EncodedShortChannelIds(CompressionAlgorithm.Uncompressed, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))), TlvStream(QueryShortChannelIdsTlv.EncodedQueryFlags(CompressionAlgorithm.Uncompressed, List(1, 2, 4)))), hex"01050f9188f13cb7b2c71f2a335e3a4fc328bf5beb436012afca590b1a11466e2206001601789c636000833e08659309a65c971d0100126e02e3010c01789c6364620100000e0008" - -> QueryShortChannelIds(Block.RegtestGenesisBlock.blockId, EncodedShortChannelIds(EncodingType.COMPRESSED_ZLIB, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))), TlvStream(QueryShortChannelIdsTlv.EncodedQueryFlags(EncodingType.COMPRESSED_ZLIB, List(1, 2, 4)))) + -> QueryShortChannelIds(Block.RegtestGenesisBlock.blockId, EncodedShortChannelIds(CompressionAlgorithm.ZlibDeflate, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))), TlvStream(QueryShortChannelIdsTlv.EncodedQueryFlags(CompressionAlgorithm.ZlibDeflate, List(1, 2, 4)))) ) refs.forall { @@ -338,21 +345,21 @@ class LightningMessageCodecsSpec extends AnyFunSuite { 100, TlvStream(QueryChannelRangeTlv.QueryFlags(QueryChannelRangeTlv.QueryFlags.WANT_ALL))) val reply_channel_range = ReplyChannelRange(Block.RegtestGenesisBlock.blockId, 756230, 1500, 1, - EncodedShortChannelIds(EncodingType.UNCOMPRESSED, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))), None, None) + EncodedShortChannelIds(CompressionAlgorithm.Uncompressed, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))), None, None) val reply_channel_range_zlib = ReplyChannelRange(Block.RegtestGenesisBlock.blockId, 1600, 110, 1, - EncodedShortChannelIds(EncodingType.COMPRESSED_ZLIB, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(265462))), None, None) + EncodedShortChannelIds(CompressionAlgorithm.ZlibDeflate, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(265462))), None, None) val reply_channel_range_timestamps_checksums = ReplyChannelRange(Block.RegtestGenesisBlock.blockId, 122334, 1500, 1, - EncodedShortChannelIds(EncodingType.UNCOMPRESSED, List(ShortChannelId(12355), ShortChannelId(489686), ShortChannelId(4645313))), - Some(EncodedTimestamps(EncodingType.UNCOMPRESSED, List(Timestamps(164545 unixsec, 948165 unixsec), Timestamps(489645 unixsec, 4786864 unixsec), Timestamps(46456 unixsec, 9788415 unixsec)))), + EncodedShortChannelIds(CompressionAlgorithm.Uncompressed, List(ShortChannelId(12355), ShortChannelId(489686), ShortChannelId(4645313))), + Some(EncodedTimestamps(CompressionAlgorithm.Uncompressed, List(Timestamps(164545 unixsec, 948165 unixsec), Timestamps(489645 unixsec, 4786864 unixsec), Timestamps(46456 unixsec, 9788415 unixsec)))), Some(EncodedChecksums(List(Checksums(1111, 2222), Checksums(3333, 4444), Checksums(5555, 6666))))) val reply_channel_range_timestamps_checksums_zlib = ReplyChannelRange(Block.RegtestGenesisBlock.blockId, 122334, 1500, 1, - EncodedShortChannelIds(EncodingType.COMPRESSED_ZLIB, List(ShortChannelId(12355), ShortChannelId(489686), ShortChannelId(4645313))), - Some(EncodedTimestamps(EncodingType.COMPRESSED_ZLIB, List(Timestamps(164545 unixsec, 948165 unixsec), Timestamps(489645 unixsec, 4786864 unixsec), Timestamps(46456 unixsec, 9788415 unixsec)))), + EncodedShortChannelIds(CompressionAlgorithm.ZlibDeflate, List(ShortChannelId(12355), ShortChannelId(489686), ShortChannelId(4645313))), + Some(EncodedTimestamps(CompressionAlgorithm.ZlibDeflate, List(Timestamps(164545 unixsec, 948165 unixsec), Timestamps(489645 unixsec, 4786864 unixsec), Timestamps(46456 unixsec, 9788415 unixsec)))), Some(EncodedChecksums(List(Checksums(1111, 2222), Checksums(3333, 4444), Checksums(5555, 6666))))) - val query_short_channel_id = QueryShortChannelIds(Block.RegtestGenesisBlock.blockId, EncodedShortChannelIds(EncodingType.UNCOMPRESSED, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))), TlvStream.empty) - val query_short_channel_id_zlib = QueryShortChannelIds(Block.RegtestGenesisBlock.blockId, EncodedShortChannelIds(EncodingType.COMPRESSED_ZLIB, List(ShortChannelId(4564), ShortChannelId(178622), ShortChannelId(4564676))), TlvStream.empty) - val query_short_channel_id_flags = QueryShortChannelIds(Block.RegtestGenesisBlock.blockId, EncodedShortChannelIds(EncodingType.UNCOMPRESSED, List(ShortChannelId(12232), ShortChannelId(15556), ShortChannelId(4564676))), TlvStream(QueryShortChannelIdsTlv.EncodedQueryFlags(EncodingType.COMPRESSED_ZLIB, List(1, 2, 4)))) - val query_short_channel_id_flags_zlib = QueryShortChannelIds(Block.RegtestGenesisBlock.blockId, EncodedShortChannelIds(EncodingType.COMPRESSED_ZLIB, List(ShortChannelId(14200), ShortChannelId(46645), ShortChannelId(4564676))), TlvStream(QueryShortChannelIdsTlv.EncodedQueryFlags(EncodingType.COMPRESSED_ZLIB, List(1, 2, 4)))) + val query_short_channel_id = QueryShortChannelIds(Block.RegtestGenesisBlock.blockId, EncodedShortChannelIds(CompressionAlgorithm.Uncompressed, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))), TlvStream.empty) + val query_short_channel_id_zlib = QueryShortChannelIds(Block.RegtestGenesisBlock.blockId, EncodedShortChannelIds(CompressionAlgorithm.ZlibDeflate, List(ShortChannelId(4564), ShortChannelId(178622), ShortChannelId(4564676))), TlvStream.empty) + val query_short_channel_id_flags = QueryShortChannelIds(Block.RegtestGenesisBlock.blockId, EncodedShortChannelIds(CompressionAlgorithm.Uncompressed, List(ShortChannelId(12232), ShortChannelId(15556), ShortChannelId(4564676))), TlvStream(QueryShortChannelIdsTlv.EncodedQueryFlags(CompressionAlgorithm.ZlibDeflate, List(1, 2, 4)))) + val query_short_channel_id_flags_zlib = QueryShortChannelIds(Block.RegtestGenesisBlock.blockId, EncodedShortChannelIds(CompressionAlgorithm.ZlibDeflate, List(ShortChannelId(14200), ShortChannelId(46645), ShortChannelId(4564676))), TlvStream(QueryShortChannelIdsTlv.EncodedQueryFlags(CompressionAlgorithm.ZlibDeflate, List(1, 2, 4)))) val refs = Map( query_channel_range -> hex"01070f9188f13cb7b2c71f2a335e3a4fc328bf5beb436012afca590b1a11466e2206000186a0000005dc", diff --git a/eclair-front/src/main/scala/fr/acinq/eclair/router/FrontRouter.scala b/eclair-front/src/main/scala/fr/acinq/eclair/router/FrontRouter.scala index a3a7ea1d5b..817f52b7ed 100644 --- a/eclair-front/src/main/scala/fr/acinq/eclair/router/FrontRouter.scala +++ b/eclair-front/src/main/scala/fr/acinq/eclair/router/FrontRouter.scala @@ -66,16 +66,16 @@ class FrontRouter(routerConf: RouterConf, remoteRouter: ActorRef, initialized: O remoteRouter forward s stay() - case Event(PeerRoutingMessage(peerConnection, remoteNodeId, q: QueryChannelRange), d) => - Sync.handleQueryChannelRange(d.channels, routerConf, RemoteGossip(peerConnection, remoteNodeId), q) + case Event(PeerRoutingMessage(peerConnection, remoteNodeId, remoteInit, q: QueryChannelRange), d) => + Sync.handleQueryChannelRange(d.channels, routerConf, RemoteGossip(peerConnection, remoteNodeId, remoteInit), q) stay() - case Event(PeerRoutingMessage(peerConnection, remoteNodeId, q: QueryShortChannelIds), d) => - Sync.handleQueryShortChannelIds(d.nodes, d.channels, RemoteGossip(peerConnection, remoteNodeId), q) + case Event(PeerRoutingMessage(peerConnection, remoteNodeId, remoteInit, q: QueryShortChannelIds), d) => + Sync.handleQueryShortChannelIds(d.nodes, d.channels, RemoteGossip(peerConnection, remoteNodeId, remoteInit), q) stay() - case Event(PeerRoutingMessage(peerConnection, remoteNodeId, ann: AnnouncementMessage), d) => - val origin = RemoteGossip(peerConnection, remoteNodeId) + case Event(PeerRoutingMessage(peerConnection, remoteNodeId, remoteInit, ann: AnnouncementMessage), d) => + val origin = RemoteGossip(peerConnection, remoteNodeId, remoteInit) val d1 = d.processing.get(ann) match { case Some(origins) if origins.contains(origin) => log.warning("acking duplicate msg={}", ann) @@ -129,7 +129,7 @@ class FrontRouter(routerConf: RouterConf, remoteRouter: ActorRef, initialized: O case _ => Metrics.gossipForwarded(ann).increment() log.debug("sending announcement class={} to master router", ann.getClass.getSimpleName) - remoteRouter ! PeerRoutingMessage(self, remoteNodeId, ann) // nb: we set ourselves as the origin + remoteRouter ! PeerRoutingMessage(self, remoteNodeId, remoteInit, ann) // nb: we set ourselves as the origin d.copy(processing = d.processing + (ann -> Set(origin))) } } @@ -196,7 +196,7 @@ class FrontRouter(routerConf: RouterConf, remoteRouter: ActorRef, initialized: O override def mdc(currentMessage: Any): MDC = { val category_opt = LogCategory(currentMessage) currentMessage match { - case PeerRoutingMessage(_, remoteNodeId, _) => Logs.mdc(category_opt, remoteNodeId_opt = Some(remoteNodeId)) + case PeerRoutingMessage(_, remoteNodeId, _, _) => Logs.mdc(category_opt, remoteNodeId_opt = Some(remoteNodeId)) case _ => Logs.mdc(category_opt) } } diff --git a/eclair-front/src/test/scala/fr/acinq/eclair/router/FrontRouterSpec.scala b/eclair-front/src/test/scala/fr/acinq/eclair/router/FrontRouterSpec.scala index e2005616cf..a008115a29 100644 --- a/eclair-front/src/test/scala/fr/acinq/eclair/router/FrontRouterSpec.scala +++ b/eclair-front/src/test/scala/fr/acinq/eclair/router/FrontRouterSpec.scala @@ -30,7 +30,7 @@ import fr.acinq.eclair.io.Peer.PeerRoutingMessage import fr.acinq.eclair.router.Announcements.{makeChannelAnnouncement, makeChannelUpdate, makeNodeAnnouncement} import fr.acinq.eclair.router.Router._ import fr.acinq.eclair.transactions.Scripts -import fr.acinq.eclair.wire.protocol.Color +import fr.acinq.eclair.wire.protocol.{Color, Init} import org.scalatest.funsuite.AnyFunSuiteLike import scodec.bits._ @@ -84,22 +84,22 @@ class FrontRouterSpec extends TestKit(ActorSystem("test")) with AnyFunSuiteLike system2.eventStream.subscribe(peerConnection2a.ref, classOf[Rebroadcast]) system3.eventStream.subscribe(peerConnection3a.ref, classOf[Rebroadcast]) - val origin1a = RemoteGossip(peerConnection1a.ref, randomKey().publicKey) - val origin1b = RemoteGossip(peerConnection1b.ref, randomKey().publicKey) - val origin2a = RemoteGossip(peerConnection2a.ref, randomKey().publicKey) + val origin1a = RemoteGossip(peerConnection1a.ref, randomKey().publicKey, remoteInit) + val origin1b = RemoteGossip(peerConnection1b.ref, randomKey().publicKey, remoteInit) + val origin2a = RemoteGossip(peerConnection2a.ref, randomKey().publicKey, remoteInit) - peerConnection1a.send(front1, PeerRoutingMessage(peerConnection1a.ref, origin1a.nodeId, chan_ab)) - pipe1.expectMsg(PeerRoutingMessage(front1, origin1a.nodeId, chan_ab)) - pipe1.send(router, PeerRoutingMessage(pipe1.ref, origin1a.nodeId, chan_ab)) + peerConnection1a.send(front1, PeerRoutingMessage(peerConnection1a.ref, origin1a.nodeId, remoteInit, chan_ab)) + pipe1.expectMsg(PeerRoutingMessage(front1, origin1a.nodeId, remoteInit, chan_ab)) + pipe1.send(router, PeerRoutingMessage(pipe1.ref, origin1a.nodeId, remoteInit, chan_ab)) assert(watcher.expectMsgType[ValidateRequest].ann === chan_ab) - peerConnection1b.send(front1, PeerRoutingMessage(peerConnection1b.ref, origin1b.nodeId, chan_ab)) + peerConnection1b.send(front1, PeerRoutingMessage(peerConnection1b.ref, origin1b.nodeId, remoteInit, chan_ab)) pipe1.expectNoMessage() - peerConnection2a.send(front2, PeerRoutingMessage(peerConnection2a.ref, origin2a.nodeId, chan_ab)) - pipe2.expectMsg(PeerRoutingMessage(front2, origin2a.nodeId, chan_ab)) - pipe2.send(router, PeerRoutingMessage(pipe2.ref, origin2a.nodeId, chan_ab)) + peerConnection2a.send(front2, PeerRoutingMessage(peerConnection2a.ref, origin2a.nodeId, remoteInit, chan_ab)) + pipe2.expectMsg(PeerRoutingMessage(front2, origin2a.nodeId, remoteInit, chan_ab)) + pipe2.send(router, PeerRoutingMessage(pipe2.ref, origin2a.nodeId, remoteInit, chan_ab)) pipe2.expectMsg(TransportHandler.ReadAck(chan_ab)) pipe1.expectNoMessage() @@ -163,22 +163,22 @@ class FrontRouterSpec extends TestKit(ActorSystem("test")) with AnyFunSuiteLike system2.eventStream.subscribe(peerConnection2a.ref, classOf[Rebroadcast]) system3.eventStream.subscribe(peerConnection3a.ref, classOf[Rebroadcast]) - val origin1a = RemoteGossip(peerConnection1a.ref, randomKey().publicKey) - val origin1b = RemoteGossip(peerConnection1b.ref, randomKey().publicKey) - val origin2a = RemoteGossip(peerConnection2a.ref, randomKey().publicKey) - val origin3a = RemoteGossip(peerConnection3a.ref, randomKey().publicKey) + val origin1a = RemoteGossip(peerConnection1a.ref, randomKey().publicKey, remoteInit) + val origin1b = RemoteGossip(peerConnection1b.ref, randomKey().publicKey, remoteInit) + val origin2a = RemoteGossip(peerConnection2a.ref, randomKey().publicKey, remoteInit) + val origin3a = RemoteGossip(peerConnection3a.ref, randomKey().publicKey, remoteInit) - peerConnection1a.send(front1, PeerRoutingMessage(peerConnection1a.ref, origin1a.nodeId, chan_ab)) + peerConnection1a.send(front1, PeerRoutingMessage(peerConnection1a.ref, origin1a.nodeId, remoteInit, chan_ab)) assert(watcher.expectMsgType[ValidateRequest].ann === chan_ab) - peerConnection1b.send(front1, PeerRoutingMessage(peerConnection1b.ref, origin1b.nodeId, chan_ab)) - peerConnection2a.send(front2, PeerRoutingMessage(peerConnection2a.ref, origin2a.nodeId, chan_ab)) + peerConnection1b.send(front1, PeerRoutingMessage(peerConnection1b.ref, origin1b.nodeId, remoteInit, chan_ab)) + peerConnection2a.send(front2, PeerRoutingMessage(peerConnection2a.ref, origin2a.nodeId, remoteInit, chan_ab)) - peerConnection1a.send(front1, PeerRoutingMessage(peerConnection1a.ref, origin1a.nodeId, ann_c)) + peerConnection1a.send(front1, PeerRoutingMessage(peerConnection1a.ref, origin1a.nodeId, remoteInit, ann_c)) peerConnection1a.expectMsg(TransportHandler.ReadAck(ann_c)) peerConnection1a.expectMsg(GossipDecision.NoKnownChannel(ann_c)) - peerConnection3a.send(front3, PeerRoutingMessage(peerConnection3a.ref, origin3a.nodeId, ann_a)) - peerConnection3a.send(front3, PeerRoutingMessage(peerConnection3a.ref, origin3a.nodeId, channelUpdate_ba)) - peerConnection3a.send(front3, PeerRoutingMessage(peerConnection3a.ref, origin3a.nodeId, channelUpdate_bc)) + peerConnection3a.send(front3, PeerRoutingMessage(peerConnection3a.ref, origin3a.nodeId, remoteInit, ann_a)) + peerConnection3a.send(front3, PeerRoutingMessage(peerConnection3a.ref, origin3a.nodeId, remoteInit, channelUpdate_ba)) + peerConnection3a.send(front3, PeerRoutingMessage(peerConnection3a.ref, origin3a.nodeId, remoteInit, channelUpdate_bc)) peerConnection3a.expectMsg(TransportHandler.ReadAck(channelUpdate_bc)) peerConnection3a.expectMsg(GossipDecision.NoRelatedChannel(channelUpdate_bc)) @@ -199,11 +199,11 @@ class FrontRouterSpec extends TestKit(ActorSystem("test")) with AnyFunSuiteLike peerConnection3a.expectMsg(TransportHandler.ReadAck(ann_a)) peerConnection3a.expectMsg(GossipDecision.Accepted(ann_a)) - peerConnection1b.send(front1, PeerRoutingMessage(peerConnection1b.ref, origin1b.nodeId, channelUpdate_ab)) + peerConnection1b.send(front1, PeerRoutingMessage(peerConnection1b.ref, origin1b.nodeId, remoteInit, channelUpdate_ab)) peerConnection1b.expectMsg(TransportHandler.ReadAck(channelUpdate_ab)) peerConnection1b.expectMsg(GossipDecision.Accepted(channelUpdate_ab)) - peerConnection3a.send(front3, PeerRoutingMessage(peerConnection3a.ref, origin3a.nodeId, ann_b)) + peerConnection3a.send(front3, PeerRoutingMessage(peerConnection3a.ref, origin3a.nodeId, remoteInit, ann_b)) peerConnection3a.expectMsg(TransportHandler.ReadAck(ann_b)) peerConnection3a.expectMsg(GossipDecision.Accepted(ann_b)) @@ -226,10 +226,10 @@ class FrontRouterSpec extends TestKit(ActorSystem("test")) with AnyFunSuiteLike val peerConnection1 = TestProbe() system1.eventStream.subscribe(peerConnection1.ref, classOf[Rebroadcast]) - val origin1 = RemoteGossip(peerConnection1.ref, randomKey().publicKey) + val origin1 = RemoteGossip(peerConnection1.ref, randomKey().publicKey, remoteInit) - peerConnection1.send(front1, PeerRoutingMessage(peerConnection1.ref, origin1.nodeId, chan_ab)) - router.expectMsg(PeerRoutingMessage(front1, origin1.nodeId, chan_ab)) + peerConnection1.send(front1, PeerRoutingMessage(peerConnection1.ref, origin1.nodeId, remoteInit, chan_ab)) + router.expectMsg(PeerRoutingMessage(front1, origin1.nodeId, remoteInit, chan_ab)) router.send(front1, TransportHandler.ReadAck(chan_ab)) peerConnection1.expectNoMessage() router.send(front1, GossipDecision.Accepted(chan_ab)) @@ -237,14 +237,14 @@ class FrontRouterSpec extends TestKit(ActorSystem("test")) with AnyFunSuiteLike peerConnection1.expectMsg(GossipDecision.Accepted(chan_ab)) router.send(front1, ChannelsDiscovered(SingleChannelDiscovered(chan_ab, 0.sat, None, None) :: Nil)) - peerConnection1.send(front1, PeerRoutingMessage(peerConnection1.ref, origin1.nodeId, chan_ab)) + peerConnection1.send(front1, PeerRoutingMessage(peerConnection1.ref, origin1.nodeId, remoteInit, chan_ab)) router.expectNoMessage() // announcement is pending rebroadcast peerConnection1.expectMsg(TransportHandler.ReadAck(chan_ab)) router.send(front1, TickBroadcast) peerConnection1.expectMsg(Rebroadcast(channels = Map(chan_ab -> Set(origin1)), updates = Map.empty, nodes = Map.empty)) - peerConnection1.send(front1, PeerRoutingMessage(peerConnection1.ref, origin1.nodeId, chan_ab)) + peerConnection1.send(front1, PeerRoutingMessage(peerConnection1.ref, origin1.nodeId, remoteInit, chan_ab)) router.expectNoMessage() // announcement is already known peerConnection1.expectMsg(TransportHandler.ReadAck(chan_ab)) } @@ -260,14 +260,14 @@ class FrontRouterSpec extends TestKit(ActorSystem("test")) with AnyFunSuiteLike val peerConnection1 = TestProbe() system1.eventStream.subscribe(peerConnection1.ref, classOf[Rebroadcast]) - val origin1 = RemoteGossip(peerConnection1.ref, randomKey().publicKey) + val origin1 = RemoteGossip(peerConnection1.ref, randomKey().publicKey, remoteInit) // first message arrives and is forwarded to router - peerConnection1.send(front1, PeerRoutingMessage(peerConnection1.ref, origin1.nodeId, chan_ab)) - router.expectMsg(PeerRoutingMessage(front1, origin1.nodeId, chan_ab)) + peerConnection1.send(front1, PeerRoutingMessage(peerConnection1.ref, origin1.nodeId, remoteInit, chan_ab)) + router.expectMsg(PeerRoutingMessage(front1, origin1.nodeId, remoteInit, chan_ab)) peerConnection1.expectNoMessage() // duplicate message is immediately acknowledged - peerConnection1.send(front1, PeerRoutingMessage(peerConnection1.ref, origin1.nodeId, chan_ab)) + peerConnection1.send(front1, PeerRoutingMessage(peerConnection1.ref, origin1.nodeId, remoteInit, chan_ab)) peerConnection1.expectMsg(TransportHandler.ReadAck(chan_ab)) // router acknowledges the first message router.send(front1, TransportHandler.ReadAck(chan_ab)) @@ -290,11 +290,11 @@ class FrontRouterSpec extends TestKit(ActorSystem("test")) with AnyFunSuiteLike val peerConnection1 = TestProbe() system1.eventStream.subscribe(peerConnection1.ref, classOf[Rebroadcast]) - val origin1 = RemoteGossip(peerConnection1.ref, randomKey().publicKey) + val origin1 = RemoteGossip(peerConnection1.ref, randomKey().publicKey, remoteInit) // channel_update arrives and is forwarded to router (there is no associated channel, because it is private) - peerConnection1.send(front1, PeerRoutingMessage(peerConnection1.ref, origin1.nodeId, channelUpdate_ab)) - router.expectMsg(PeerRoutingMessage(front1, origin1.nodeId, channelUpdate_ab)) + peerConnection1.send(front1, PeerRoutingMessage(peerConnection1.ref, origin1.nodeId, remoteInit, channelUpdate_ab)) + router.expectMsg(PeerRoutingMessage(front1, origin1.nodeId, remoteInit, channelUpdate_ab)) peerConnection1.expectNoMessage() // router acknowledges the message router.send(front1, TransportHandler.ReadAck(channelUpdate_ab)) @@ -321,6 +321,8 @@ object FrontRouterSpec { val (priv_funding_a, priv_funding_b, priv_funding_c, priv_funding_d, priv_funding_e, priv_funding_f) = (randomKey(), randomKey(), randomKey(), randomKey(), randomKey(), randomKey()) val (funding_a, funding_b, funding_c, funding_d, funding_e, funding_f) = (priv_funding_a.publicKey, priv_funding_b.publicKey, priv_funding_c.publicKey, priv_funding_d.publicKey, priv_funding_e.publicKey, priv_funding_f.publicKey) + val remoteInit = Init(Features.empty) + val ann_a = makeNodeAnnouncement(priv_a, "node-A", Color(15, 10, -70), Nil, Features(hex"0200")) val ann_b = makeNodeAnnouncement(priv_b, "node-B", Color(50, 99, -80), Nil, Features(hex"")) val ann_c = makeNodeAnnouncement(priv_c, "node-C", Color(123, 100, -40), Nil, Features(hex"0200"))