diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala index e644daded4..d3e57da393 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala @@ -647,20 +647,8 @@ class Router(val nodeParams: NodeParams, watcher: ActorRef, initialized: Option[ Kamon.runWithSpan(Kamon.spanBuilder("compute-timestamps-checksums").start(), finishSpan = true) { chunks.foreach { chunk => - val (timestamps, checksums) = routingMessage.queryFlags_opt match { - case Some(extension) if extension.wantChecksums | extension.wantTimestamps => - // we always compute timestamps and checksums even if we don't need both, overhead is negligible - val (timestamps, checksums) = chunk.shortChannelIds.map(getChannelDigestInfo(d.channels)).unzip - val encodedTimestamps = if (extension.wantTimestamps) Some(ReplyChannelRangeTlv.EncodedTimestamps(nodeParams.routerConf.encodingType, timestamps)) else None - val encodedChecksums = if (extension.wantChecksums) Some(ReplyChannelRangeTlv.EncodedChecksums(checksums)) else None - (encodedTimestamps, encodedChecksums) - case _ => (None, None) - } - transport ! ReplyChannelRange(chainHash, chunk.firstBlock, chunk.numBlocks, - complete = 1, - shortChannelIds = EncodedShortChannelIds(nodeParams.routerConf.encodingType, chunk.shortChannelIds), - timestamps = timestamps, - checksums = checksums) + val reply = Router.buildReplyChannelRange(chunk, chainHash, nodeParams.routerConf.encodingType, routingMessage.queryFlags_opt, d.channels) + transport ! reply } } stay @@ -699,17 +687,25 @@ class Router(val nodeParams: NodeParams, watcher: ActorRef, initialized: Option[ (c1, u1) } log.info(s"received reply_channel_range with {} channels, we're missing {} channel announcements and {} updates, format={}", shortChannelIds.array.size, channelCount, updatesCount, shortChannelIds.encoding) - // we update our sync data to this node (there may be multiple channel range responses and we can only query one set of ids at a time) - val replies = shortChannelIdAndFlags - .grouped(nodeParams.routerConf.channelQueryChunkSize) - .map(chunk => QueryShortChannelIds(chainHash, - shortChannelIds = EncodedShortChannelIds(shortChannelIds.encoding, chunk.map(_.shortChannelId)), + + def buildQuery(chunk: List[ShortChannelIdAndFlag]): QueryShortChannelIds = { + // always encode empty lists as UNCOMPRESSED + val encoding = if (chunk.isEmpty) EncodingType.UNCOMPRESSED else shortChannelIds.encoding + QueryShortChannelIds(chainHash, + shortChannelIds = EncodedShortChannelIds(encoding, chunk.map(_.shortChannelId)), if (routingMessage.timestamps_opt.isDefined || routingMessage.checksums_opt.isDefined) - TlvStream(QueryShortChannelIdsTlv.EncodedQueryFlags(shortChannelIds.encoding, chunk.map(_.flag))) + TlvStream(QueryShortChannelIdsTlv.EncodedQueryFlags(encoding, chunk.map(_.flag))) else TlvStream.empty - )) + ) + } + + // we update our sync data to this node (there may be multiple channel range responses and we can only query one set of ids at a time) + val replies = shortChannelIdAndFlags + .grouped(nodeParams.routerConf.channelQueryChunkSize) + .map(buildQuery) .toList + val (sync1, replynow_opt) = addToSync(d.sync, remoteNodeId, replies) // we only send a reply right away if there were no pending requests replynow_opt.foreach(transport ! _) @@ -1285,6 +1281,33 @@ object Router { */ def enforceMaximumSize(chunks: List[ShortChannelIdsChunk]) : List[ShortChannelIdsChunk] = chunks.map(_.enforceMaximumSize(MAXIMUM_CHUNK_SIZE)) + /** + * Build a `reply_channel_range` message + * @param chunk chunk of scids + * @param chainHash chain hash + * @param defaultEncoding default encoding + * @param queryFlags_opt query flag set by the requester + * @param channels channels map + * @return a ReplyChannelRange object + */ + def buildReplyChannelRange(chunk: ShortChannelIdsChunk, chainHash: ByteVector32, defaultEncoding: EncodingType, queryFlags_opt: Option[QueryChannelRangeTlv.QueryFlags], channels: SortedMap[ShortChannelId, PublicChannel]): ReplyChannelRange = { + val encoding = if (chunk.shortChannelIds.isEmpty) EncodingType.UNCOMPRESSED else defaultEncoding + val (timestamps, checksums) = queryFlags_opt match { + case Some(extension) if extension.wantChecksums | extension.wantTimestamps => + // we always compute timestamps and checksums even if we don't need both, overhead is negligible + val (timestamps, checksums) = chunk.shortChannelIds.map(getChannelDigestInfo(channels)).unzip + val encodedTimestamps = if (extension.wantTimestamps) Some(ReplyChannelRangeTlv.EncodedTimestamps(encoding, timestamps)) else None + val encodedChecksums = if (extension.wantChecksums) Some(ReplyChannelRangeTlv.EncodedChecksums(checksums)) else None + (encodedTimestamps, encodedChecksums) + case _ => (None, None) + } + ReplyChannelRange(chainHash, chunk.firstBlock, chunk.numBlocks, + complete = 1, + shortChannelIds = EncodedShortChannelIds(encoding, chunk.shortChannelIds), + timestamps = timestamps, + checksums = checksums) + } + def addToSync(syncMap: Map[PublicKey, Sync], remoteNodeId: PublicKey, pending: List[RoutingMessage]): (Map[PublicKey, Sync], Option[RoutingMessage]) = { pending match { case head +: rest => diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/router/ChannelRangeQueriesSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/router/ChannelRangeQueriesSpec.scala index d3a78bd7e1..860883760f 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/router/ChannelRangeQueriesSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/router/ChannelRangeQueriesSpec.scala @@ -16,8 +16,10 @@ package fr.acinq.eclair.router -import fr.acinq.bitcoin.ByteVector32 +import fr.acinq.bitcoin.{Block, ByteVector32} import fr.acinq.eclair.router.Router.ShortChannelIdsChunk +import fr.acinq.eclair.wire.QueryChannelRangeTlv.QueryFlags +import fr.acinq.eclair.wire.{EncodedShortChannelIds, EncodingType, QueryChannelRange, QueryChannelRangeTlv, ReplyChannelRange} import fr.acinq.eclair.wire.ReplyChannelRangeTlv._ import fr.acinq.eclair.{LongToBtcAmount, ShortChannelId, randomKey} import org.scalatest.FunSuite @@ -357,4 +359,23 @@ class ChannelRangeQueriesSpec extends FunSuite { validateChunks(chunks.toList, pruned) } } + + test("do not encode empty lists as COMPRESSED_ZLIB") { + { + val reply = Router.buildReplyChannelRange(ShortChannelIdsChunk(0, 42, Nil), Block.RegtestGenesisBlock.hash, EncodingType.COMPRESSED_ZLIB, Some(QueryFlags(QueryFlags.WANT_ALL)), SortedMap()) + assert(reply == ReplyChannelRange(Block.RegtestGenesisBlock.hash, 0L, 42L, 1.toByte, EncodedShortChannelIds(EncodingType.UNCOMPRESSED, Nil), Some(EncodedTimestamps(EncodingType.UNCOMPRESSED, Nil)), Some(EncodedChecksums(Nil)))) + } + { + val reply = Router.buildReplyChannelRange(ShortChannelIdsChunk(0, 42, Nil), Block.RegtestGenesisBlock.hash, EncodingType.COMPRESSED_ZLIB, Some(QueryFlags(QueryFlags.WANT_TIMESTAMPS)), SortedMap()) + assert(reply == ReplyChannelRange(Block.RegtestGenesisBlock.hash, 0L, 42L, 1.toByte, EncodedShortChannelIds(EncodingType.UNCOMPRESSED, Nil), Some(EncodedTimestamps(EncodingType.UNCOMPRESSED, Nil)), None)) + } + { + val reply = Router.buildReplyChannelRange(ShortChannelIdsChunk(0, 42, Nil), Block.RegtestGenesisBlock.hash, EncodingType.COMPRESSED_ZLIB, Some(QueryFlags(QueryFlags.WANT_CHECKSUMS)), SortedMap()) + assert(reply == ReplyChannelRange(Block.RegtestGenesisBlock.hash, 0L, 42L, 1.toByte, EncodedShortChannelIds(EncodingType.UNCOMPRESSED, Nil), None, Some(EncodedChecksums(Nil)))) + } + { + val reply = Router.buildReplyChannelRange(ShortChannelIdsChunk(0, 42, Nil), Block.RegtestGenesisBlock.hash, EncodingType.COMPRESSED_ZLIB, None, SortedMap()) + assert(reply == ReplyChannelRange(Block.RegtestGenesisBlock.hash, 0L, 42L, 1.toByte, EncodedShortChannelIds(EncodingType.UNCOMPRESSED, Nil), None, None)) + } + } }