From af9c48ca36332d970dc6448df9dda4171ef64622 Mon Sep 17 00:00:00 2001 From: Thomas HUET Date: Fri, 14 Oct 2022 15:50:08 +0200 Subject: [PATCH 1/9] Build blinded routes from provided nodes --- .../acinq/eclair/payment/Bolt12Invoice.scala | 9 +- .../payment/receive/MultiPartHandler.scala | 61 +++- .../BlindPaymentIntegrationSpec.scala | 298 ++++++++++++++++++ .../eclair/payment/Bolt12InvoiceSpec.scala | 45 +-- .../eclair/payment/MultiPartHandlerSpec.scala | 61 +++- 5 files changed, 426 insertions(+), 48 deletions(-) create mode 100644 eclair-core/src/test/scala/fr/acinq/eclair/integration/BlindPaymentIntegrationSpec.scala diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Bolt12Invoice.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Bolt12Invoice.scala index 5405c546a4..5905fc2457 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Bolt12Invoice.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Bolt12Invoice.scala @@ -54,6 +54,7 @@ case class Bolt12Invoice(records: TlvStream[InvoiceTlv]) extends Invoice { val chain: ByteVector32 = records.get[Chain].map(_.hash).getOrElse(Block.LivenetGenesisBlock.hash) val offerId: Option[ByteVector32] = records.get[OfferId].map(_.offerId) val blindedPaths: Seq[RouteBlinding.BlindedRoute] = records.get[Paths].get.paths + val blindedPathsInfo: Seq[PaymentInfo] = records.get[PaymentPathsInfo].get.paymentInfo val issuer: Option[String] = records.get[Issuer].map(_.issuer) val quantity: Option[Long] = records.get[Quantity].map(_.quantity) val refundFor: Option[ByteVector32] = records.get[RefundFor].map(_.refundedPaymentHash) @@ -101,6 +102,8 @@ case class Bolt12Invoice(records: TlvStream[InvoiceTlv]) extends Invoice { } +case class PaymentBlindedRoute(route: Sphinx.RouteBlinding.BlindedRoute, paymentInfo: PaymentInfo) + object Bolt12Invoice { val hrp = "lni" val DEFAULT_EXPIRY_SECONDS: Long = 7200 @@ -122,7 +125,7 @@ object Bolt12Invoice { nodeKey: PrivateKey, minFinalCltvExpiryDelta: CltvExpiryDelta, features: Features[InvoiceFeature], - paths: Seq[Sphinx.RouteBlinding.BlindedRoute]): Bolt12Invoice = { + paths: Seq[PaymentBlindedRoute]): Bolt12Invoice = { require(request.amount.nonEmpty || offer.amount.nonEmpty) val amount = request.amount.orElse(offer.amount.map(_ * request.quantity)).get val tlvs: Seq[InvoiceTlv] = Seq( @@ -131,8 +134,8 @@ object Bolt12Invoice { Some(Amount(amount)), Some(Description(offer.description)), if (!features.isEmpty) Some(FeaturesTlv(features.unscoped())) else None, - Some(Paths(paths)), - Some(PaymentPathsInfo(Seq(PaymentInfo(0 msat, 0, CltvExpiryDelta(0), 0 msat, amount, Features.empty)))), + Some(Paths(paths.map(_.route))), + Some(PaymentPathsInfo(paths.map(_.paymentInfo))), offer.issuer.map(Issuer), Some(NodeId(nodeKey.publicKey)), request.quantity_opt.map(Quantity), diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartHandler.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartHandler.scala index 604039a3e9..2b23e72110 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartHandler.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartHandler.scala @@ -22,6 +22,7 @@ import akka.actor.typed.scaladsl.Behaviors import akka.actor.typed.scaladsl.adapter.ClassicActorContextOps import akka.actor.{ActorContext, ActorRef, PoisonPill, Status} import akka.event.{DiagnosticLoggingAdapter, LoggingAdapter} +import akka.pattern.ask import fr.acinq.bitcoin.scalacompat.Crypto.{PrivateKey, PublicKey} import fr.acinq.bitcoin.scalacompat.{ByteVector32, Crypto} import fr.acinq.eclair.Logs.LogCategory @@ -31,12 +32,16 @@ import fr.acinq.eclair.db._ import fr.acinq.eclair.payment.Bolt11Invoice.ExtraHop import fr.acinq.eclair.payment.Monitoring.{Metrics, Tags} import fr.acinq.eclair.payment._ -import fr.acinq.eclair.wire.protocol.OfferTypes.{InvoiceRequest, Offer} +import fr.acinq.eclair.router.Router +import fr.acinq.eclair.router.Router.ChannelHop +import fr.acinq.eclair.wire.protocol.OfferTypes.{InvoiceRequest, Offer, PaymentInfo} import fr.acinq.eclair.wire.protocol.PaymentOnion.FinalPayload import fr.acinq.eclair.wire.protocol._ -import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, FeatureSupport, Features, InvoiceFeature, Logs, MilliSatoshi, MilliSatoshiLong, NodeParams, ShortChannelId, TimestampMilli, randomBytes, randomBytes32, randomKey} +import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, FeatureSupport, Features, InvoiceFeature, Logs, MilliSatoshi, MilliSatoshiLong, NodeParams, TimestampMilli, randomBytes32, randomKey} import scodec.bits.HexStringSyntax +import scala.concurrent.Await +import scala.concurrent.duration.DurationInt import scala.util.{Failure, Success, Try} /** @@ -255,6 +260,8 @@ object MultiPartHandler { case class ReceiveOfferPayment(nodeKey: PrivateKey, offer: Offer, invoiceRequest: InvoiceRequest, + routes: Seq[Seq[PublicKey]], + router: ActorRef, paymentPreimage_opt: Option[ByteVector32] = None, paymentType: String = PaymentType.Blinded) extends ReceivePayment @@ -298,20 +305,46 @@ object MultiPartHandler { nodeParams.db.payments.addIncomingPayment(invoice, paymentPreimage, r.paymentType) invoice case r: ReceiveOfferPayment => - // TODO: get blinded paths from the router instead - val pathId = RouteBlindingEncryptedDataTlv.PathId(randomBytes32()) - val dummyConstraints = RouteBlindingEncryptedDataTlv.PaymentConstraints(CltvExpiry(nodeParams.currentBlockHeight + 144), 1 msat) - val dummyRelay = RouteBlindingEncryptedDataTlv.PaymentRelay(CltvExpiryDelta(0), 0, 0 msat) - val dummyScid = RouteBlindingEncryptedDataTlv.OutgoingChannelId(ShortChannelId.toSelf) - val dummyPath = Seq( - (nodeParams.nodeId, RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec.encode(TlvStream(dummyScid, dummyRelay, dummyConstraints)).require.bytes), - (nodeParams.nodeId, RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec.encode(TlvStream(dummyConstraints, pathId)).require.bytes), - ) - val blindedRoute = Sphinx.RouteBlinding.create(randomKey(), dummyPath.map(_._1), dummyPath.map(_._2)) + val amount = r.invoiceRequest.amount.orElse(r.offer.amount.map(_ * r.invoiceRequest.quantity)).get + val paths = r.routes.map(nodeIds => { + require(nodeIds.nonEmpty, "route can't be empty") + val pathId = randomBytes32() + val finalExpiryDelta = nodeParams.channelConf.minFinalExpiryDelta + 3 + val finalConstraints = RouteBlindingEncryptedDataTlv.PaymentConstraints(finalExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight), nodeParams.channelConf.htlcMinimum) + val zeroPaymentInfo = PaymentInfo(0 msat, 0, CltvExpiryDelta(0), nodeParams.channelConf.htlcMinimum, amount, Features.empty) + val finalPayload = RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec.encode(TlvStream( + finalConstraints, + RouteBlindingEncryptedDataTlv.PathId(pathId) + )).require.bytes + val (paymentInfo, payloads) = if (nodeIds.length > 1) { + val timeout = 30 second + val routeResponse = Await.result(r.router.ask(Router.FinalizeRoute(0 msat, Router.PredefinedNodeRoute(nodeIds)))(timeout).mapTo[Router.RouteResponse], timeout) + val routeToBlind = routeResponse.routes.head + val totalCltvDelta = routeToBlind.hops.map(_.cltvExpiryDelta).fold(finalExpiryDelta)(_ + _) + routeToBlind.hops.foldRight((zeroPaymentInfo, Seq(finalPayload))) { + case (channel: ChannelHop, (payInfo, nextPayloads)) => + val newFeeBase = MilliSatoshi((channel.params.relayFees.feeBase.toLong * 1_000_000 + payInfo.feeBase.toLong * (1_000_000 + channel.params.relayFees.feeProportionalMillionths) + 1_000_000 - 1) / 1_000_000) + val newFeeProp = ((payInfo.feeProportionalMillionths + channel.params.relayFees.feeProportionalMillionths) * 1_000_000 + payInfo.feeProportionalMillionths * channel.params.relayFees.feeProportionalMillionths + 1_000_000 - 1) / 1_000_000 + // Because eclair (and others) lies about max HTLC, we remove 10% as a safety margin. + val channelMaxHtlc = channel.params.htlcMaximum_opt.map(_ * 0.9).getOrElse(amount) + val newPayInfo = PaymentInfo(newFeeBase, newFeeProp, payInfo.cltvExpiryDelta + channel.cltvExpiryDelta, payInfo.minHtlc.max(channel.params.htlcMinimum), payInfo.maxHtlc.min(channelMaxHtlc), payInfo.allowedFeatures) + val payload = RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec.encode(TlvStream( + RouteBlindingEncryptedDataTlv.OutgoingChannelId(channel.shortChannelId), + RouteBlindingEncryptedDataTlv.PaymentRelay(channel.cltvExpiryDelta, channel.params.relayFees.feeProportionalMillionths, channel.params.relayFees.feeBase), + RouteBlindingEncryptedDataTlv.PaymentConstraints(CltvExpiry(nodeParams.currentBlockHeight) + totalCltvDelta, channel.params.htlcMinimum) + )).require.bytes + (newPayInfo, payload +: nextPayloads) + } + } else { + (zeroPaymentInfo, Seq(finalPayload)) + } + val blindedRoute = Sphinx.RouteBlinding.create(randomKey(), nodeIds, payloads) + (blindedRoute, paymentInfo, pathId) + }) val invoiceFeatures = featuresTrampolineOpt.remove(Features.RouteBlinding).add(Features.RouteBlinding, FeatureSupport.Mandatory) - val invoice = Bolt12Invoice(r.offer, r.invoiceRequest, paymentPreimage, r.nodeKey, nodeParams.channelConf.minFinalExpiryDelta, invoiceFeatures, Seq(blindedRoute.route)) + val invoice = Bolt12Invoice(r.offer, r.invoiceRequest, paymentPreimage, r.nodeKey, nodeParams.channelConf.minFinalExpiryDelta, invoiceFeatures, paths.map { case (blindedRoute, paymentInfo, _) => PaymentBlindedRoute(blindedRoute.route, paymentInfo) }) context.log.debug("generated invoice={} for offerId={}", invoice.toString, r.offer.offerId) - nodeParams.db.payments.addIncomingBlindedPayment(invoice, paymentPreimage, Map(blindedRoute.lastBlinding -> pathId.data), r.paymentType) + nodeParams.db.payments.addIncomingBlindedPayment(invoice, paymentPreimage, paths.map { case (blindedRoute, _, pathId) => (blindedRoute.lastBlinding -> pathId.bytes) }.toMap, r.paymentType) invoice } } match { diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/integration/BlindPaymentIntegrationSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/integration/BlindPaymentIntegrationSpec.scala new file mode 100644 index 0000000000..5b1a3c73bf --- /dev/null +++ b/eclair-core/src/test/scala/fr/acinq/eclair/integration/BlindPaymentIntegrationSpec.scala @@ -0,0 +1,298 @@ +/* + * Copyright 2022 ACINQ SAS + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package fr.acinq.eclair.integration + +import akka.actor.typed.scaladsl.adapter.actorRefAdapter +import akka.testkit.TestProbe +import com.typesafe.config.ConfigFactory +import fr.acinq.bitcoin.scalacompat.{Crypto, SatoshiLong} +import fr.acinq.eclair.blockchain.bitcoind.ZmqWatcher +import fr.acinq.eclair.blockchain.bitcoind.ZmqWatcher.{Watch, WatchFundingConfirmed} +import fr.acinq.eclair.channel._ +import fr.acinq.eclair.payment._ +import fr.acinq.eclair.payment.receive.MultiPartHandler.ReceiveOfferPayment +import fr.acinq.eclair.payment.relay.Relayer +import fr.acinq.eclair.payment.send.PaymentInitiator.SendPaymentToNode +import fr.acinq.eclair.router.Router +import fr.acinq.eclair.wire.protocol.OfferTypes.{InvoiceRequest, Offer} +import fr.acinq.eclair.{Kit, MilliSatoshiLong, randomKey} + +import java.util.UUID +import scala.concurrent.duration._ +import scala.jdk.CollectionConverters._ + +class BlindPaymentIntegrationSpec extends IntegrationSpec { + + test("start eclair nodes") { + instantiateEclairNode("A", ConfigFactory.parseMap(Map("eclair.node-alias" -> "A", "eclair.channel.expiry-delta-blocks" -> 130, "eclair.server.port" -> 29730, "eclair.api.port" -> 28080, "eclair.features.option_route_blinding" -> "optional").asJava).withFallback(withDefaultCommitment).withFallback(commonConfig)) + instantiateEclairNode("B", ConfigFactory.parseMap(Map("eclair.node-alias" -> "B", "eclair.channel.expiry-delta-blocks" -> 131, "eclair.server.port" -> 29731, "eclair.api.port" -> 28081, "eclair.features.option_route_blinding" -> "optional").asJava).withFallback(withDefaultCommitment).withFallback(commonConfig)) + instantiateEclairNode("C", ConfigFactory.parseMap(Map("eclair.node-alias" -> "C", "eclair.channel.expiry-delta-blocks" -> 132, "eclair.server.port" -> 29732, "eclair.api.port" -> 28082, "eclair.features.option_route_blinding" -> "optional").asJava).withFallback(withDefaultCommitment).withFallback(commonConfig)) + instantiateEclairNode("D", ConfigFactory.parseMap(Map("eclair.node-alias" -> "D", "eclair.channel.expiry-delta-blocks" -> 133, "eclair.server.port" -> 29733, "eclair.api.port" -> 28083, "eclair.features.option_route_blinding" -> "optional").asJava).withFallback(withDefaultCommitment).withFallback(commonConfig)) + instantiateEclairNode("E", ConfigFactory.parseMap(Map("eclair.node-alias" -> "E", "eclair.channel.expiry-delta-blocks" -> 134, "eclair.server.port" -> 29734, "eclair.api.port" -> 28084, "eclair.features.option_route_blinding" -> "optional").asJava).withFallback(withDefaultCommitment).withFallback(commonConfig)) + instantiateEclairNode("F", ConfigFactory.parseMap(Map("eclair.node-alias" -> "F", "eclair.channel.expiry-delta-blocks" -> 135, "eclair.server.port" -> 29735, "eclair.api.port" -> 28085, "eclair.features.option_route_blinding" -> "optional").asJava).withFallback(withDefaultCommitment).withFallback(commonConfig)) + instantiateEclairNode("G", ConfigFactory.parseMap(Map("eclair.node-alias" -> "G", "eclair.channel.expiry-delta-blocks" -> 136, "eclair.server.port" -> 29736, "eclair.api.port" -> 28086, "eclair.features.option_route_blinding" -> "disabled").asJava).withFallback(withDefaultCommitment).withFallback(commonConfig)) // G does not support blinded routes. + } + + test("connect nodes") { + // ,--G--, + // / \ + // A---B ------- C ==== D + // \ / \ + // '--E--' \ + // \_____ F + // + // All channels have fees 1 sat + 200 millionths + + val sender = TestProbe() + val eventListener = TestProbe() + nodes.values.foreach(_.system.eventStream.subscribe(eventListener.ref, classOf[ChannelStateChanged])) + + connect(nodes("A"), nodes("B"), 11000000 sat, 0 msat) + connect(nodes("B"), nodes("C"), 2000000 sat, 0 msat) + connect(nodes("C"), nodes("D"), 5000000 sat, 0 msat) + connect(nodes("C"), nodes("D"), 5000000 sat, 0 msat) + connect(nodes("B"), nodes("E"), 10000000 sat, 0 msat) + connect(nodes("E"), nodes("C"), 10000000 sat, 0 msat) + connect(nodes("B"), nodes("G"), 16000000 sat, 0 msat) + connect(nodes("G"), nodes("C"), 16000000 sat, 0 msat) + connect(nodes("C"), nodes("F"), 2000000 sat, 0 msat) + connect(nodes("E"), nodes("F"), 2000000 sat, 0 msat) + + val numberOfChannels = 10 + val channelEndpointsCount = 2 * numberOfChannels + + // we make sure all channels have set up their WatchConfirmed for the funding tx + awaitCond({ + val watches = nodes.values.foldLeft(Set.empty[Watch[_]]) { + case (watches, setup) => + setup.watcher ! ZmqWatcher.ListWatches(sender.ref) + watches ++ sender.expectMsgType[Set[Watch[_]]] + } + watches.count(_.isInstanceOf[WatchFundingConfirmed]) == channelEndpointsCount + }, max = 20 seconds, interval = 1 second) + + // confirming the funding tx + generateBlocks(2) + + within(60 seconds) { + var count = 0 + while (count < channelEndpointsCount) { + if (eventListener.expectMsgType[ChannelStateChanged](60 seconds).currentState == NORMAL) count = count + 1 + } + } + } + + def awaitAnnouncements(subset: Map[String, Kit], nodes: Int, privateChannels: Int, publicChannels: Int, privateUpdates: Int, publicUpdates: Int): Unit = { + val sender = TestProbe() + subset.foreach { + case (node, setup) => + withClue(node) { + awaitAssert({ + sender.send(setup.router, Router.GetRouterData) + val data = sender.expectMsgType[Router.Data] + assert(data.nodes.size == nodes) + assert(data.privateChannels.size == privateChannels) + assert(data.channels.size == publicChannels) + assert(data.privateChannels.values.flatMap(pc => pc.update_1_opt.toSeq ++ pc.update_2_opt.toSeq).size == privateUpdates) + assert(data.channels.values.flatMap(pc => pc.update_1_opt.toSeq ++ pc.update_2_opt.toSeq).size == publicUpdates) + }, max = 10 seconds, interval = 1 second) + } + } + } + + test("wait for network announcements") { + // generating more blocks so that all funding txes are buried under at least 6 blocks + generateBlocks(4) + awaitAnnouncements(nodes.view.filterKeys(key => List("A", "B", "C", "D", "E", "G").contains(key)).toMap, nodes = 7, privateChannels = 0, publicChannels = 10, privateUpdates = 0, publicUpdates = 20) + } + + test("wait for channels balance") { + // Channels balance should now be available in the router + val sender = TestProbe() + val nodeId = nodes("C").nodeParams.nodeId + sender.send(nodes("C").router, Router.GetRoutingState) + val routingState = sender.expectMsgType[Router.RoutingState] + val publicChannels = routingState.channels.filter(pc => Set(pc.ann.nodeId1, pc.ann.nodeId2).contains(nodeId)) + assert(publicChannels.nonEmpty) + publicChannels.foreach(pc => assert(pc.meta_opt.exists(m => m.balance1 > 0.msat || m.balance2 > 0.msat), pc)) + } + + test("send an HTLC A->(D), minimal blinded route") { + val (sender, eventListener) = (TestProbe(), TestProbe()) + nodes("D").system.eventStream.subscribe(eventListener.ref, classOf[PaymentMetadataReceived]) + + val recipientKey = randomKey() + val payerKey = randomKey() + + // first we retrieve an invoice from D + val amount = 42000000 msat + val chain = nodes("D").nodeParams.chainHash + val offer = Offer(Some(amount), "test offer", recipientKey.publicKey, nodes("D").nodeParams.features.invoiceFeatures(), chain) + val invoiceRequest = InvoiceRequest(offer, amount, 1, nodes("A").nodeParams.features.invoiceFeatures(), payerKey, chain) + + sender.send(nodes("D").paymentHandler, ReceiveOfferPayment(recipientKey, offer, invoiceRequest, Seq(Seq(nodes("D").nodeParams.nodeId)), nodes("D").router)) + val invoice = sender.expectMsgType[Bolt12Invoice] + + // then we make the actual payment + sender.send(nodes("A").paymentInitiator, SendPaymentToNode(amount, invoice, routeParams = integrationTestRouteParams, maxAttempts = 1)) + val paymentId = sender.expectMsgType[UUID] + val ps = sender.expectMsgType[PaymentSent] + assert(ps.id == paymentId) + assert(Crypto.sha256(ps.paymentPreimage) == invoice.paymentHash) + } + + test("send an HTLC A->(D->D->D), blinded route with dummy hops") { + val (sender, eventListener) = (TestProbe(), TestProbe()) + nodes("D").system.eventStream.subscribe(eventListener.ref, classOf[PaymentMetadataReceived]) + + val recipientKey = randomKey() + val payerKey = randomKey() + + // first we retrieve an invoice from D + val amount = 5600000 msat + val chain = nodes("D").nodeParams.chainHash + val offer = Offer(Some(amount), "test offer", recipientKey.publicKey, nodes("D").nodeParams.features.invoiceFeatures(), chain) + val invoiceRequest = InvoiceRequest(offer, amount, 1, nodes("A").nodeParams.features.invoiceFeatures(), payerKey, chain) + + sender.send(nodes("D").paymentHandler, ReceiveOfferPayment(recipientKey, offer, invoiceRequest, Seq(Seq(nodes("D").nodeParams.nodeId, nodes("D").nodeParams.nodeId, nodes("D").nodeParams.nodeId)), nodes("D").router)) + val invoice = sender.expectMsgType[Bolt12Invoice] + + // then we make the actual payment + sender.send(nodes("A").paymentInitiator, SendPaymentToNode(amount, invoice, routeParams = integrationTestRouteParams, maxAttempts = 1)) + val paymentId = sender.expectMsgType[UUID] + val ps = sender.expectMsgType[PaymentSent] + assert(ps.id == paymentId) + assert(Crypto.sha256(ps.paymentPreimage) == invoice.paymentHash) + } + + test("send an HTLC A->(E->C->D), blinded route with intermediate hops") { + val (sender, eventListener) = (TestProbe(), TestProbe()) + nodes("D").system.eventStream.subscribe(eventListener.ref, classOf[PaymentMetadataReceived]) + + val recipientKey = randomKey() + val payerKey = randomKey() + + // first we retrieve an invoice from D + val amount = 7500000 msat + val chain = nodes("D").nodeParams.chainHash + val offer = Offer(Some(amount), "test offer", recipientKey.publicKey, nodes("D").nodeParams.features.invoiceFeatures(), chain) + val invoiceRequest = InvoiceRequest(offer, amount, 1, nodes("A").nodeParams.features.invoiceFeatures(), payerKey, chain) + + sender.send(nodes("D").paymentHandler, ReceiveOfferPayment(recipientKey, offer, invoiceRequest, Seq(Seq(nodes("E").nodeParams.nodeId, nodes("C").nodeParams.nodeId, nodes("D").nodeParams.nodeId)), nodes("D").router)) + val invoice = sender.expectMsgType[Bolt12Invoice] + + // then we make the actual payment + sender.send(nodes("A").paymentInitiator, SendPaymentToNode(amount, invoice, routeParams = integrationTestRouteParams, maxAttempts = 1)) + val paymentId = sender.expectMsgType[UUID] + val ps = sender.expectMsgType[PaymentSent] + assert(ps.id == paymentId) + assert(Crypto.sha256(ps.paymentPreimage) == invoice.paymentHash) + } + + test("send an HTLC A->(G->C->D), blinded route with node not supporting blinded routes") { + val (sender, eventListener) = (TestProbe(), TestProbe()) + nodes("D").system.eventStream.subscribe(eventListener.ref, classOf[PaymentMetadataReceived]) + + val recipientKey = randomKey() + val payerKey = randomKey() + + // first we retrieve an invoice from D + val amount = 7500000 msat + val chain = nodes("D").nodeParams.chainHash + val offer = Offer(Some(amount), "test offer", recipientKey.publicKey, nodes("D").nodeParams.features.invoiceFeatures(), chain) + val invoiceRequest = InvoiceRequest(offer, amount, 1, nodes("A").nodeParams.features.invoiceFeatures(), payerKey, chain) + + sender.send(nodes("D").paymentHandler, ReceiveOfferPayment(recipientKey, offer, invoiceRequest, Seq(Seq(nodes("G").nodeParams.nodeId, nodes("C").nodeParams.nodeId, nodes("D").nodeParams.nodeId)), nodes("D").router)) + val invoice = sender.expectMsgType[Bolt12Invoice] + + // then we make the actual payment + sender.send(nodes("A").paymentInitiator, SendPaymentToNode(amount, invoice, routeParams = integrationTestRouteParams, maxAttempts = 1)) + sender.expectMsgType[UUID] + sender.expectMsgType[PaymentFailed] + } + + test("send an HTLC A->(A->B->C->D), blinded route with introduction node being the sender") { + val (sender, eventListener) = (TestProbe(), TestProbe()) + nodes("D").system.eventStream.subscribe(eventListener.ref, classOf[PaymentMetadataReceived]) + + val recipientKey = randomKey() + val payerKey = randomKey() + + // first we retrieve an invoice from D + val amount = 3200000 msat + val chain = nodes("D").nodeParams.chainHash + val offer = Offer(Some(amount), "test offer", recipientKey.publicKey, nodes("D").nodeParams.features.invoiceFeatures(), chain) + val invoiceRequest = InvoiceRequest(offer, amount, 1, nodes("A").nodeParams.features.invoiceFeatures(), payerKey, chain) + + sender.send(nodes("D").paymentHandler, ReceiveOfferPayment(recipientKey, offer, invoiceRequest, Seq(Seq(nodes("A").nodeParams.nodeId, nodes("B").nodeParams.nodeId, nodes("C").nodeParams.nodeId, nodes("D").nodeParams.nodeId)), nodes("D").router)) + val invoice = sender.expectMsgType[Bolt12Invoice] + + // then we make the actual payment + sender.send(nodes("A").paymentInitiator, SendPaymentToNode(amount, invoice, routeParams = integrationTestRouteParams, maxAttempts = 1)) + val paymentId = sender.expectMsgType[UUID] + val ps = sender.expectMsgType[PaymentSent] + assert(ps.id == paymentId) + assert(Crypto.sha256(ps.paymentPreimage) == invoice.paymentHash) + } + + test("send to multiple blinded routes: "){ + val (sender, eventListener) = (TestProbe(), TestProbe()) + nodes("F").system.eventStream.subscribe(eventListener.ref, classOf[PaymentMetadataReceived]) + + val recipientKey = randomKey() + val payerKey = randomKey() + + // first we retrieve an invoice from D + val amount = 3_000_000_000L msat + val chain = nodes("F").nodeParams.chainHash + val offer = Offer(Some(amount), "test offer", recipientKey.publicKey, nodes("F").nodeParams.features.invoiceFeatures(), chain) + val invoiceRequest = InvoiceRequest(offer, amount, 1, nodes("A").nodeParams.features.invoiceFeatures(), payerKey, chain) + + sender.send(nodes("F").paymentHandler, ReceiveOfferPayment(recipientKey, offer, invoiceRequest, Seq(Seq(nodes("E").nodeParams.nodeId, nodes("F").nodeParams.nodeId), Seq(nodes("C").nodeParams.nodeId, nodes("F").nodeParams.nodeId)), nodes("F").router)) + val invoice = sender.expectMsgType[Bolt12Invoice] + + // then we make the actual payment + sender.send(nodes("A").paymentInitiator, SendPaymentToNode(amount, invoice, routeParams = integrationTestRouteParams, maxAttempts = 1)) + val paymentId = sender.expectMsgType[UUID] + val ps = sender.expectMsgType[PaymentSent] + assert(ps.id == paymentId) + assert(Crypto.sha256(ps.paymentPreimage) == invoice.paymentHash) + } + + // To show channel balances before adding new tests + def debugChannelBalances(): Unit = { + val names = nodes.map { case (name, kit) => (kit.nodeParams.nodeId, name) } + val sender = TestProbe() + sender.send(nodes("A").relayer, Relayer.GetOutgoingChannels()) + sender.expectMsgType[Relayer.OutgoingChannels].channels.map(_.toChannelBalance).foreach { balance => + println(s"A ${balance.canSend} ===== ${balance.canReceive} ${names(balance.remoteNodeId)}") + } + sender.send(nodes("B").relayer, Relayer.GetOutgoingChannels()) + sender.expectMsgType[Relayer.OutgoingChannels].channels.map(_.toChannelBalance).foreach { balance => + println(s"B ${balance.canSend} ===== ${balance.canReceive} ${names(balance.remoteNodeId)}") + } + sender.send(nodes("C").relayer, Relayer.GetOutgoingChannels()) + sender.expectMsgType[Relayer.OutgoingChannels].channels.map(_.toChannelBalance).foreach{ balance => + println(s"C ${balance.canSend} ===== ${balance.canReceive} ${names(balance.remoteNodeId)}") + } + sender.send(nodes("E").relayer, Relayer.GetOutgoingChannels()) + sender.expectMsgType[Relayer.OutgoingChannels].channels.map(_.toChannelBalance).foreach{ balance => + println(s"E ${balance.canSend} ===== ${balance.canReceive} ${names(balance.remoteNodeId)}") + } + names.foreach(println) + } +} diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/Bolt12InvoiceSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/Bolt12InvoiceSpec.scala index 77a7728a6b..b7c811d4b7 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/Bolt12InvoiceSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/Bolt12InvoiceSpec.scala @@ -50,16 +50,16 @@ class Bolt12InvoiceSpec extends AnyFunSuite { signedInvoice } - def createDirectPath(sessionKey: PrivateKey, nodeId: PublicKey, pathId: ByteVector): BlindedRoute = { + def createPaymentBlindedRoute(nodeId: PublicKey, sessionKey: PrivateKey = randomKey(), pathId: ByteVector = randomBytes32()): PaymentBlindedRoute = { val selfPayload = blindedRouteDataCodec.encode(TlvStream(Seq(RouteBlindingEncryptedDataTlv.PathId(pathId), PaymentConstraints(CltvExpiry(1234567), 0 msat), AllowedFeatures(Features.empty)))).require.bytes - Sphinx.RouteBlinding.create(sessionKey, Seq(nodeId), Seq(selfPayload)).route + PaymentBlindedRoute(Sphinx.RouteBlinding.create(sessionKey, Seq(nodeId), Seq(selfPayload)).route, PaymentInfo(1 msat, 2, CltvExpiryDelta(3), 4 msat, 5 msat, Features.empty)) } test("check invoice signature") { val (nodeKey, payerKey, chain) = (randomKey(), randomKey(), randomBytes32()) val offer = Offer(Some(10000 msat), "test offer", nodeKey.publicKey, Features.empty, chain) val request = InvoiceRequest(offer, 11000 msat, 1, Features.empty, payerKey, chain) - val invoice = Bolt12Invoice(offer, request, randomBytes32(), nodeKey, CltvExpiryDelta(20), Features.empty, Seq(createDirectPath(randomKey(), nodeKey.publicKey, randomBytes32()))) + val invoice = Bolt12Invoice(offer, request, randomBytes32(), nodeKey, CltvExpiryDelta(20), Features.empty, Seq(createPaymentBlindedRoute(nodeKey.publicKey))) assert(invoice.isValidFor(offer, request)) assert(invoice.checkSignature()) assert(!invoice.checkRefundSignature()) @@ -84,7 +84,7 @@ class Bolt12InvoiceSpec extends AnyFunSuite { val (nodeKey, payerKey, chain) = (randomKey(), randomKey(), randomBytes32()) val offer = Offer(Some(10000 msat), "test offer", nodeKey.publicKey, Features.empty, chain) val request = InvoiceRequest(offer, 11000 msat, 1, Features.empty, payerKey, chain) - val invoice = Bolt12Invoice(offer, request, randomBytes32(), nodeKey, CltvExpiryDelta(20), Features.empty, Seq(createDirectPath(randomKey(), nodeKey.publicKey, randomBytes32()))) + val invoice = Bolt12Invoice(offer, request, randomBytes32(), nodeKey, CltvExpiryDelta(20), Features.empty, Seq(createPaymentBlindedRoute(nodeKey.publicKey))) assert(invoice.isValidFor(offer, request)) assert(!invoice.isValidFor(Offer(None, "test offer", randomKey().publicKey, Features.empty, chain), request)) // amount must match the offer @@ -112,7 +112,7 @@ class Bolt12InvoiceSpec extends AnyFunSuite { val offer = Offer(Some(15000 msat), "test offer", nodeKey.publicKey, Features(VariableLengthOnion -> Mandatory), chain) val request = InvoiceRequest(offer, 15000 msat, 1, Features(VariableLengthOnion -> Mandatory), payerKey, chain) assert(request.quantity_opt.isEmpty) // when paying for a single item, the quantity field must not be present - val invoice = Bolt12Invoice(offer, request, randomBytes32(), nodeKey, CltvExpiryDelta(20), Features(VariableLengthOnion -> Mandatory, BasicMultiPartPayment -> Optional), Seq(createDirectPath(randomKey(), nodeKey.publicKey, randomBytes32()))) + val invoice = Bolt12Invoice(offer, request, randomBytes32(), nodeKey, CltvExpiryDelta(20), Features(VariableLengthOnion -> Mandatory, BasicMultiPartPayment -> Optional), Seq(createPaymentBlindedRoute(nodeKey.publicKey))) assert(invoice.isValidFor(offer, request)) val withInvalidFeatures = signInvoice(Bolt12Invoice(TlvStream(invoice.records.records.map { case FeaturesTlv(_) => FeaturesTlv(Features(VariableLengthOnion -> Mandatory, BasicMultiPartPayment -> Mandatory)) case x => x }.toSeq)), nodeKey) assert(!withInvalidFeatures.isValidFor(offer, request)) @@ -139,7 +139,7 @@ class Bolt12InvoiceSpec extends AnyFunSuite { val signature = signSchnorr(InvoiceRequest.signatureTag, rootHash(TlvStream(tlvs), invoiceRequestTlvCodec), payerKey) InvoiceRequest(TlvStream(tlvs :+ Signature(signature))) } - val withPayerDetails = Bolt12Invoice(offer, requestWithPayerDetails, randomBytes32(), nodeKey, CltvExpiryDelta(20), Features.empty, Seq(createDirectPath(randomKey(), nodeKey.publicKey, randomBytes32()))) + val withPayerDetails = Bolt12Invoice(offer, requestWithPayerDetails, randomBytes32(), nodeKey, CltvExpiryDelta(20), Features.empty, Seq(createPaymentBlindedRoute(nodeKey.publicKey))) assert(withPayerDetails.isValidFor(offer, requestWithPayerDetails)) assert(!withPayerDetails.isValidFor(offer, request)) val withOtherPayerInfo = signInvoice(Bolt12Invoice(TlvStream(withPayerDetails.records.records.map { case PayerInfo(_) => PayerInfo(hex"deadbeef") case x => x }.toSeq)), nodeKey) @@ -154,7 +154,7 @@ class Bolt12InvoiceSpec extends AnyFunSuite { val (nodeKey, payerKey, chain) = (randomKey(), randomKey(), randomBytes32()) val offer = Offer(Some(5000 msat), "test offer", nodeKey.publicKey, Features.empty, chain) val request = InvoiceRequest(offer, 5000 msat, 1, Features.empty, payerKey, chain) - val invoice = Bolt12Invoice(offer, request, randomBytes32(), nodeKey, CltvExpiryDelta(20), Features.empty, Seq(createDirectPath(randomKey(), nodeKey.publicKey, randomBytes32()))) + val invoice = Bolt12Invoice(offer, request, randomBytes32(), nodeKey, CltvExpiryDelta(20), Features.empty, Seq(createPaymentBlindedRoute(nodeKey.publicKey))) assert(!invoice.isExpired()) assert(invoice.isValidFor(offer, request)) val expiredInvoice1 = signInvoice(Bolt12Invoice(TlvStream(invoice.records.records.map { case CreatedAt(_) => CreatedAt(0 unixsec) case x => x })), nodeKey) @@ -171,14 +171,15 @@ class Bolt12InvoiceSpec extends AnyFunSuite { val (chain1, chain2) = (randomBytes32(), randomBytes32()) val offerBtc = Offer(Some(amount), "bitcoin offer", nodeKey.publicKey, Features.empty, Block.LivenetGenesisBlock.hash) val requestBtc = InvoiceRequest(offerBtc, amount, 1, Features.empty, payerKey, Block.LivenetGenesisBlock.hash) + val paymentBlindedRoute = createPaymentBlindedRoute(nodeKey.publicKey) val invoiceImplicitBtc = { val tlvs: Seq[InvoiceTlv] = Seq( CreatedAt(TimestampSecond.now()), PaymentHash(Crypto.sha256(randomBytes32())), OfferId(offerBtc.offerId), NodeId(nodeKey.publicKey), - Paths(Seq(createDirectPath(randomKey(), nodeKey.publicKey, randomBytes32()))), - PaymentPathsInfo(Seq(PaymentInfo(0 msat, 0, CltvExpiryDelta(0), 0 msat, amount, Features.empty))), + Paths(Seq(paymentBlindedRoute.route)), + PaymentPathsInfo(Seq(paymentBlindedRoute.paymentInfo)), Amount(amount), Description(offerBtc.description), PayerKey(payerKey.publicKey) @@ -194,8 +195,8 @@ class Bolt12InvoiceSpec extends AnyFunSuite { PaymentHash(Crypto.sha256(randomBytes32())), OfferId(offerBtc.offerId), NodeId(nodeKey.publicKey), - Paths(Seq(createDirectPath(randomKey(), nodeKey.publicKey, randomBytes32()))), - PaymentPathsInfo(Seq(PaymentInfo(0 msat, 0, CltvExpiryDelta(0), 0 msat, amount, Features.empty))), + Paths(Seq(paymentBlindedRoute.route)), + PaymentPathsInfo(Seq(paymentBlindedRoute.paymentInfo)), Amount(amount), Description(offerBtc.description), PayerKey(payerKey.publicKey) @@ -211,8 +212,8 @@ class Bolt12InvoiceSpec extends AnyFunSuite { PaymentHash(Crypto.sha256(randomBytes32())), OfferId(offerBtc.offerId), NodeId(nodeKey.publicKey), - Paths(Seq(createDirectPath(randomKey(), nodeKey.publicKey, randomBytes32()))), - PaymentPathsInfo(Seq(PaymentInfo(0 msat, 0, CltvExpiryDelta(0), 0 msat, amount, Features.empty))), + Paths(Seq(paymentBlindedRoute.route)), + PaymentPathsInfo(Seq(paymentBlindedRoute.paymentInfo)), Amount(amount), Description(offerBtc.description), PayerKey(payerKey.publicKey) @@ -230,8 +231,8 @@ class Bolt12InvoiceSpec extends AnyFunSuite { PaymentHash(Crypto.sha256(randomBytes32())), OfferId(offerOtherChains.offerId), NodeId(nodeKey.publicKey), - Paths(Seq(createDirectPath(randomKey(), nodeKey.publicKey, randomBytes32()))), - PaymentPathsInfo(Seq(PaymentInfo(0 msat, 0, CltvExpiryDelta(0), 0 msat, amount, Features.empty))), + Paths(Seq(paymentBlindedRoute.route)), + PaymentPathsInfo(Seq(paymentBlindedRoute.paymentInfo)), Amount(amount), Description(offerOtherChains.description), PayerKey(payerKey.publicKey) @@ -247,8 +248,8 @@ class Bolt12InvoiceSpec extends AnyFunSuite { PaymentHash(Crypto.sha256(randomBytes32())), OfferId(offerOtherChains.offerId), NodeId(nodeKey.publicKey), - Paths(Seq(createDirectPath(randomKey(), nodeKey.publicKey, randomBytes32()))), - PaymentPathsInfo(Seq(PaymentInfo(0 msat, 0, CltvExpiryDelta(0), 0 msat, amount, Features.empty))), + Paths(Seq(paymentBlindedRoute.route)), + PaymentPathsInfo(Seq(paymentBlindedRoute.paymentInfo)), Amount(amount), Description(offerOtherChains.description), PayerKey(payerKey.publicKey) @@ -267,7 +268,7 @@ class Bolt12InvoiceSpec extends AnyFunSuite { Amount(765432 msat), Description("minimal invoice"), NodeId(nodeKey.publicKey), - Paths(Seq(createDirectPath(randomKey(), randomKey().publicKey, randomBytes32()))), + Paths(Seq(createPaymentBlindedRoute(randomKey().publicKey).route)), PaymentPathsInfo(Seq(PaymentInfo(0 msat, 0, CltvExpiryDelta(0), 0 msat, 765432 msat, Features.empty))), CreatedAt(TimestampSecond(123456789L)), PaymentHash(randomBytes32()), @@ -296,7 +297,7 @@ class Bolt12InvoiceSpec extends AnyFunSuite { val features = Features[Feature](Features.VariableLengthOnion -> FeatureSupport.Mandatory) val issuer = "alice" val nodeKey = PrivateKey(hex"998cf8ecab46f949bb960813b79d3317cabf4193452a211795cd8af1b9a25d90") - val path = createDirectPath(PrivateKey(hex"f0442c17bdd2cefe4a4ede210f163b068bb3fea6113ffacea4f322de7aa9737b"), nodeKey.publicKey, hex"76030536ba732cdc4e7bb0a883750bab2e88cb3dddd042b1952c44b4849c86bb") + val path = createPaymentBlindedRoute(nodeKey.publicKey, PrivateKey(hex"f0442c17bdd2cefe4a4ede210f163b068bb3fea6113ffacea4f322de7aa9737b"), hex"76030536ba732cdc4e7bb0a883750bab2e88cb3dddd042b1952c44b4849c86bb").route val payInfo = PaymentInfo(2345 msat, 765, CltvExpiryDelta(324), 1000 msat, amount, Features.empty) val quantity = 57 val payerKey = ByteVector32.fromValidHex("8faadd71b1f78b16265e5b061b9d2b88891012dc7ad38626eeaaa2a271615a65") @@ -369,7 +370,7 @@ class Bolt12InvoiceSpec extends AnyFunSuite { assert(request.toString == encodedRequest) assert(InvoiceRequest.decode(encodedRequest).get == request) assert(request.isValidFor(offer)) - val invoice = Bolt12Invoice(offer, request, preimage, nodeKey, CltvExpiryDelta(22), Features.empty, Seq(createDirectPath(randomKey(), nodeKey.publicKey, hex""))) + val invoice = Bolt12Invoice(offer, request, preimage, nodeKey, CltvExpiryDelta(22), Features.empty, Seq(createPaymentBlindedRoute(nodeKey.publicKey))) assert(Bolt12Invoice.fromString(invoice.toString).get.records == invoice.records) assert(invoice.isValidFor(offer, request)) // Invoice generation is not reproducible as the timestamp and blinding point will change but all other fields should be the same. @@ -398,7 +399,7 @@ class Bolt12InvoiceSpec extends AnyFunSuite { assert(request.toString == encodedRequest) assert(InvoiceRequest.decode(encodedRequest).get == request) assert(request.isValidFor(offer)) - val invoice = Bolt12Invoice(offer, request, preimage, nodeKey, CltvExpiryDelta(22), Features.empty, Seq(createDirectPath(randomKey(), nodeKey.publicKey, hex"747e01a7152169b058a1fbc0024c254077db7e399308483e0c30e2352ba1d6cc"))) + val invoice = Bolt12Invoice(offer, request, preimage, nodeKey, CltvExpiryDelta(22), Features.empty, Seq(createPaymentBlindedRoute(nodeKey.publicKey))) assert(Bolt12Invoice.fromString(invoice.toString).get.records == invoice.records) assert(invoice.isValidFor(offer, request)) // Invoice generation is not reproducible as the timestamp and blinding point will change but all other fields should be the same. @@ -434,7 +435,7 @@ class Bolt12InvoiceSpec extends AnyFunSuite { assert(request.toString == encodedRequest) assert(InvoiceRequest.decode(encodedRequest).get == request) assert(request.isValidFor(offer)) - val invoice = Bolt12Invoice(offer, request, preimage, nodeKey, CltvExpiryDelta(34), Features.empty, Seq(createDirectPath(randomKey(), nodeKey.publicKey, hex"9134d86e269a13203bd85bb3fd05bf396b72fcb9fd5206e3a392f6a0ab94011d"))) + val invoice = Bolt12Invoice(offer, request, preimage, nodeKey, CltvExpiryDelta(34), Features.empty, Seq(createPaymentBlindedRoute(nodeKey.publicKey))) assert(Bolt12Invoice.fromString(invoice.toString).get.records == invoice.records) assert(invoice.isValidFor(offer, request)) // Invoice generation is not reproducible as the timestamp and blinding point will change but all other fields should be the same. diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartHandlerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartHandlerSpec.scala index 933a9198b4..23273af9ad 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartHandlerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartHandlerSpec.scala @@ -30,7 +30,9 @@ import fr.acinq.eclair.payment.PaymentReceived.PartialPayment import fr.acinq.eclair.payment.receive.MultiPartHandler._ import fr.acinq.eclair.payment.receive.MultiPartPaymentFSM.HtlcPart import fr.acinq.eclair.payment.receive.{MultiPartPaymentFSM, PaymentHandler} -import fr.acinq.eclair.wire.protocol.OfferTypes.{InvoiceRequest, Offer} +import fr.acinq.eclair.router.Router +import fr.acinq.eclair.router.Router.RouteResponse +import fr.acinq.eclair.wire.protocol.OfferTypes.{InvoiceRequest, Offer, PaymentInfo} import fr.acinq.eclair.wire.protocol.OnionPaymentPayloadTlv.{AmountToForward, BlindingPoint, EncryptedRecipientData, OutgoingCltv} import fr.acinq.eclair.wire.protocol.PaymentOnion.FinalPayload import fr.acinq.eclair.wire.protocol.RouteBlindingEncryptedDataTlv.{PathId, PaymentConstraints} @@ -160,7 +162,10 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike val privKey = randomKey() val offer = Offer(Some(amountMsat), "a blinded coffee please", privKey.publicKey, Features.empty, Block.RegtestGenesisBlock.hash) val invoiceReq = InvoiceRequest(offer, amountMsat, 1, featuresWithRouteBlinding.invoiceFeatures(), randomKey(), Block.RegtestGenesisBlock.hash) - sender.send(handlerWithRouteBlinding, ReceiveOfferPayment(privKey, offer, invoiceReq)) + val router = TestProbe() + val nodeId = randomKey().publicKey + sender.send(handlerWithRouteBlinding, ReceiveOfferPayment(privKey, offer, invoiceReq, Seq(Seq(nodeId)), router.ref)) + router.expectNoMessage() val invoice = sender.expectMsgType[Bolt12Invoice] assert(invoice.features.hasFeature(RouteBlinding, Some(Mandatory))) val pendingPayment = nodeParams.db.payments.getIncomingPayment(invoice.paymentHash).get.asInstanceOf[IncomingBlindedPayment] @@ -274,7 +279,19 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike val privKey = randomKey() val offer = Offer(Some(25_000 msat), "a blinded coffee please", privKey.publicKey, Features.empty, Block.RegtestGenesisBlock.hash) val invoiceReq = InvoiceRequest(offer, 25_000 msat, 1, featuresWithRouteBlinding.invoiceFeatures(), randomKey(), Block.RegtestGenesisBlock.hash) - sender.send(handlerWithRouteBlinding, ReceiveOfferPayment(privKey, offer, invoiceReq)) + val router = TestProbe() + val (a, b, c, d) = (randomKey().publicKey, randomKey().publicKey, randomKey().publicKey, randomKey().publicKey) + val hop_ab = Router.ChannelHop(ShortChannelId(1), a, b, Router.ChannelRelayParams.FromHint(Invoice.BasicEdge(a, b, ShortChannelId(1), 1000 msat, 699, CltvExpiryDelta(123)))) + val hop_bc = Router.ChannelHop(ShortChannelId(2), b, c, Router.ChannelRelayParams.FromHint(Invoice.BasicEdge(b, c, ShortChannelId(2), 800 msat, 455, CltvExpiryDelta(78)))) + val hop_dc = Router.ChannelHop(ShortChannelId(3), d, c, Router.ChannelRelayParams.FromHint(Invoice.BasicEdge(c, d, ShortChannelId(3), 0 msat, 1700, CltvExpiryDelta(89)))) + val hop_cc = Router.ChannelHop(ShortChannelId.toSelf, c, c, Router.ChannelRelayParams.FromHint(Invoice.BasicEdge(c, c, ShortChannelId.toSelf, 0 msat, 0, CltvExpiryDelta(0)))) + sender.send(handlerWithRouteBlinding, ReceiveOfferPayment(privKey, offer, invoiceReq, Seq(Seq(a, b, c, c), Seq(d, c, c, c), Seq(c)), router.ref)) + val finalizeRoute1 = router.expectMsgType[Router.FinalizeRoute] + assert(finalizeRoute1.route == Router.PredefinedNodeRoute(Seq(a, b, c, c))) + router.send(router.lastSender, RouteResponse(Seq(Router.Route(finalizeRoute1.amount, Seq(hop_ab, hop_bc, hop_cc))))) + val finalizeRoute2 = router.expectMsgType[Router.FinalizeRoute] + assert(finalizeRoute2.route == Router.PredefinedNodeRoute(Seq(d, c, c, c))) + router.send(router.lastSender, RouteResponse(Seq(Router.Route(finalizeRoute2.amount, Seq(hop_dc, hop_cc, hop_cc))))) val invoice = sender.expectMsgType[Bolt12Invoice] assert(invoice.amount == 25_000.msat) assert(invoice.nodeId == privKey.publicKey) @@ -282,6 +299,16 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike assert(invoice.features.hasFeature(RouteBlinding, Some(Mandatory))) assert(invoice.description == Left("a blinded coffee please")) assert(invoice.offerId.contains(offer.offerId)) + assert(invoice.extraEdges.length == 3) + assert(invoice.blindedPaths(0).blindedNodeIds.length == 4) + assert(invoice.blindedPaths(0).introductionNodeId == a) + assert(invoice.blindedPathsInfo(0) == PaymentInfo(1801 msat, 1155, CltvExpiryDelta(201), 0 msat, 25_000 msat, Features.empty)) + assert(invoice.blindedPaths(1).blindedNodeIds.length == 4) + assert(invoice.blindedPaths(1).introductionNodeId == d) + assert(invoice.blindedPathsInfo(1) == PaymentInfo(0 msat, 1700, CltvExpiryDelta(89), 0 msat, 25_000 msat, Features.empty)) + assert(invoice.blindedPaths(2).blindedNodeIds.length == 1) + assert(invoice.blindedPaths(2).introductionNodeId == c) + assert(invoice.blindedPathsInfo(2) == PaymentInfo(0 msat, 0, CltvExpiryDelta(0), 0 msat, 25_000 msat, Features.empty)) val pendingPayment = nodeParams.db.payments.getIncomingPayment(invoice.paymentHash).get.asInstanceOf[IncomingBlindedPayment] assert(pendingPayment.invoice == invoice) @@ -444,7 +471,10 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike val offer = Offer(None, "a blinded coffee please", randomKey().publicKey, Features.empty, Block.RegtestGenesisBlock.hash) val invoiceReq = InvoiceRequest(offer, 5000 msat, 1, featuresWithRouteBlinding.invoiceFeatures(), randomKey(), Block.RegtestGenesisBlock.hash) - sender.send(handlerWithRouteBlinding, ReceiveOfferPayment(randomKey(), offer, invoiceReq)) + val router = TestProbe() + val nodeId = randomKey().publicKey + sender.send(handlerWithRouteBlinding, ReceiveOfferPayment(randomKey(), offer, invoiceReq, Seq(Seq(nodeId)), router.ref)) + router.expectNoMessage() val invoice = sender.expectMsgType[Bolt12Invoice] assert(invoice.features.hasFeature(RouteBlinding, Some(Mandatory))) @@ -460,7 +490,9 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike val offer = Offer(None, "a blinded coffee please", randomKey().publicKey, Features.empty, Block.RegtestGenesisBlock.hash) val invoiceReq = InvoiceRequest(offer, 5000 msat, 1, featuresWithRouteBlinding.invoiceFeatures(), randomKey(), Block.RegtestGenesisBlock.hash) - sender.send(handlerWithRouteBlinding, ReceiveOfferPayment(randomKey(), offer, invoiceReq)) + val router = TestProbe() + val nodeId = randomKey().publicKey + sender.send(handlerWithRouteBlinding, ReceiveOfferPayment(randomKey(), offer, invoiceReq, Seq(Seq(nodeId)), router.ref)) val invoice = sender.expectMsgType[Bolt12Invoice] assert(invoice.features.hasFeature(RouteBlinding, Some(Mandatory))) val pathIds = nodeParams.db.payments.getIncomingPayment(invoice.paymentHash).get.asInstanceOf[IncomingBlindedPayment].pathIds @@ -477,7 +509,9 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike val offer = Offer(None, "a blinded coffee please", randomKey().publicKey, Features.empty, Block.RegtestGenesisBlock.hash) val invoiceReq = InvoiceRequest(offer, 5000 msat, 1, featuresWithRouteBlinding.invoiceFeatures(), randomKey(), Block.RegtestGenesisBlock.hash) - sender.send(handlerWithRouteBlinding, ReceiveOfferPayment(randomKey(), offer, invoiceReq)) + val router = TestProbe() + val nodeId = randomKey().publicKey + sender.send(handlerWithRouteBlinding, ReceiveOfferPayment(randomKey(), offer, invoiceReq, Seq(Seq(nodeId)), router.ref)) val invoice = sender.expectMsgType[Bolt12Invoice] assert(invoice.features.hasFeature(RouteBlinding, Some(Mandatory))) val pathIds = nodeParams.db.payments.getIncomingPayment(invoice.paymentHash).get.asInstanceOf[IncomingBlindedPayment].pathIds @@ -496,7 +530,10 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike val offer = Offer(None, "a blinded coffee please", randomKey().publicKey, Features.empty, Block.RegtestGenesisBlock.hash) val invoiceReq = InvoiceRequest(offer, 5000 msat, 1, featuresWithRouteBlinding.invoiceFeatures(), randomKey(), Block.RegtestGenesisBlock.hash) - sender.send(handlerWithRouteBlinding, ReceiveOfferPayment(randomKey(), offer, invoiceReq)) + val router = TestProbe() + val nodeId = randomKey().publicKey + sender.send(handlerWithRouteBlinding, ReceiveOfferPayment(randomKey(), offer, invoiceReq, Seq(Seq(nodeId)), router.ref)) + router.expectNoMessage() val invoice = sender.expectMsgType[Bolt12Invoice] assert(invoice.features.hasFeature(RouteBlinding, Some(Mandatory))) val pathIds = nodeParams.db.payments.getIncomingPayment(invoice.paymentHash).get.asInstanceOf[IncomingBlindedPayment].pathIds @@ -514,7 +551,10 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike val offer = Offer(None, "a blinded coffee please", randomKey().publicKey, Features.empty, Block.RegtestGenesisBlock.hash) val invoiceReq = InvoiceRequest(offer, 5000 msat, 1, featuresWithRouteBlinding.invoiceFeatures(), randomKey(), Block.RegtestGenesisBlock.hash) - sender.send(handlerWithRouteBlinding, ReceiveOfferPayment(randomKey(), offer, invoiceReq)) + val router = TestProbe() + val nodeId = randomKey().publicKey + sender.send(handlerWithRouteBlinding, ReceiveOfferPayment(randomKey(), offer, invoiceReq, Seq(Seq(nodeId)), router.ref)) + router.expectNoMessage() val invoice = sender.expectMsgType[Bolt12Invoice] assert(invoice.features.hasFeature(RouteBlinding, Some(Mandatory))) val pathIds = nodeParams.db.payments.getIncomingPayment(invoice.paymentHash).get.asInstanceOf[IncomingBlindedPayment].pathIds @@ -532,7 +572,10 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike val offer = Offer(None, "a blinded coffee please", randomKey().publicKey, Features.empty, Block.RegtestGenesisBlock.hash) val invoiceReq = InvoiceRequest(offer, 5000 msat, 1, featuresWithRouteBlinding.invoiceFeatures(), randomKey(), Block.RegtestGenesisBlock.hash) - sender.send(handlerWithRouteBlinding, ReceiveOfferPayment(randomKey(), offer, invoiceReq)) + val router = TestProbe() + val nodeId = randomKey().publicKey + sender.send(handlerWithRouteBlinding, ReceiveOfferPayment(randomKey(), offer, invoiceReq, Seq(Seq(nodeId)), router.ref)) + router.expectNoMessage() val invoice = sender.expectMsgType[Bolt12Invoice] assert(invoice.features.hasFeature(RouteBlinding, Some(Mandatory))) val pathIds = nodeParams.db.payments.getIncomingPayment(invoice.paymentHash).get.asInstanceOf[IncomingBlindedPayment].pathIds From 97220e599e778cd1ead44a155f3572547107ab0f Mon Sep 17 00:00:00 2001 From: Thomas HUET Date: Mon, 21 Nov 2022 16:59:28 +0100 Subject: [PATCH 2/9] pipeTo --- .../payment/receive/MultiPartHandler.scala | 104 +++++++++--------- 1 file changed, 52 insertions(+), 52 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartHandler.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartHandler.scala index 2b23e72110..6ad163e8fe 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartHandler.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartHandler.scala @@ -22,7 +22,8 @@ import akka.actor.typed.scaladsl.Behaviors import akka.actor.typed.scaladsl.adapter.ClassicActorContextOps import akka.actor.{ActorContext, ActorRef, PoisonPill, Status} import akka.event.{DiagnosticLoggingAdapter, LoggingAdapter} -import akka.pattern.ask +import akka.pattern.{ask, pipe} +import akka.util.Timeout import fr.acinq.bitcoin.scalacompat.Crypto.{PrivateKey, PublicKey} import fr.acinq.bitcoin.scalacompat.{ByteVector32, Crypto} import fr.acinq.eclair.Logs.LogCategory @@ -40,8 +41,8 @@ import fr.acinq.eclair.wire.protocol._ import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, FeatureSupport, Features, InvoiceFeature, Logs, MilliSatoshi, MilliSatoshiLong, NodeParams, TimestampMilli, randomBytes32, randomKey} import scodec.bits.HexStringSyntax -import scala.concurrent.Await import scala.concurrent.duration.DurationInt +import scala.concurrent.{ExecutionContextExecutor, Future} import scala.util.{Failure, Success, Try} /** @@ -276,7 +277,6 @@ object MultiPartHandler { Behaviors.setup { context => Behaviors.receiveMessage { case CreateInvoice(replyTo, receivePayment) => - Try { val paymentPreimage = receivePayment.paymentPreimage_opt.getOrElse(randomBytes32()) val paymentHash = Crypto.sha256(paymentPreimage) val featuresTrampolineOpt = if (nodeParams.enableTrampolinePayment) { @@ -286,47 +286,52 @@ object MultiPartHandler { } receivePayment match { case r: ReceiveStandardPayment => - val expirySeconds = r.expirySeconds_opt.getOrElse(nodeParams.invoiceExpiry.toSeconds) - val paymentMetadata = hex"2a" - val invoice = Bolt11Invoice( - nodeParams.chainHash, - r.amount_opt, - paymentHash, - nodeParams.privateKey, - r.description, - nodeParams.channelConf.minFinalExpiryDelta, - r.fallbackAddress_opt, - expirySeconds = Some(expirySeconds), - extraHops = r.extraHops, - paymentMetadata = Some(paymentMetadata), - features = featuresTrampolineOpt.remove(Features.RouteBlinding) - ) - context.log.debug("generated invoice={} from amount={}", invoice.toString, r.amount_opt) - nodeParams.db.payments.addIncomingPayment(invoice, paymentPreimage, r.paymentType) - invoice + Try { + val expirySeconds = r.expirySeconds_opt.getOrElse(nodeParams.invoiceExpiry.toSeconds) + val paymentMetadata = hex"2a" + val invoice = Bolt11Invoice( + nodeParams.chainHash, + r.amount_opt, + paymentHash, + nodeParams.privateKey, + r.description, + nodeParams.channelConf.minFinalExpiryDelta, + r.fallbackAddress_opt, + expirySeconds = Some(expirySeconds), + extraHops = r.extraHops, + paymentMetadata = Some(paymentMetadata), + features = featuresTrampolineOpt.remove(Features.RouteBlinding) + ) + context.log.debug("generated invoice={} from amount={}", invoice.toString, r.amount_opt) + nodeParams.db.payments.addIncomingPayment(invoice, paymentPreimage, r.paymentType) + invoice + } match { + case Success(invoice) => replyTo ! invoice + case Failure(exception) => replyTo ! Status.Failure(exception) + } case r: ReceiveOfferPayment => - val amount = r.invoiceRequest.amount.orElse(r.offer.amount.map(_ * r.invoiceRequest.quantity)).get - val paths = r.routes.map(nodeIds => { - require(nodeIds.nonEmpty, "route can't be empty") - val pathId = randomBytes32() - val finalExpiryDelta = nodeParams.channelConf.minFinalExpiryDelta + 3 - val finalConstraints = RouteBlindingEncryptedDataTlv.PaymentConstraints(finalExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight), nodeParams.channelConf.htlcMinimum) - val zeroPaymentInfo = PaymentInfo(0 msat, 0, CltvExpiryDelta(0), nodeParams.channelConf.htlcMinimum, amount, Features.empty) - val finalPayload = RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec.encode(TlvStream( - finalConstraints, - RouteBlindingEncryptedDataTlv.PathId(pathId) - )).require.bytes - val (paymentInfo, payloads) = if (nodeIds.length > 1) { - val timeout = 30 second - val routeResponse = Await.result(r.router.ask(Router.FinalizeRoute(0 msat, Router.PredefinedNodeRoute(nodeIds)))(timeout).mapTo[Router.RouteResponse], timeout) + val amount = r.invoiceRequest.amount.orElse(r.offer.amount.map(_ * r.invoiceRequest.quantity)) + implicit val ec: ExecutionContextExecutor = context.executionContext + Future.sequence(r.routes.map(nodeIds => { + require(nodeIds.length > 1, "route must have at least one hop") + implicit val timeout: Timeout = 10.seconds + r.router.ask(Router.FinalizeRoute(0 msat, Router.PredefinedNodeRoute(nodeIds))).mapTo[Router.RouteResponse].map(routeResponse => { + val pathId = randomBytes32() + val finalExpiryDelta = nodeParams.channelConf.minFinalExpiryDelta + 3 + val finalConstraints = RouteBlindingEncryptedDataTlv.PaymentConstraints(finalExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight), nodeParams.channelConf.htlcMinimum) + val zeroPaymentInfo = PaymentInfo(0 msat, 0, CltvExpiryDelta(0), nodeParams.channelConf.htlcMinimum, amount.get, Features.empty) + val finalPayload = RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec.encode(TlvStream( + finalConstraints, + RouteBlindingEncryptedDataTlv.PathId(pathId) + )).require.bytes val routeToBlind = routeResponse.routes.head val totalCltvDelta = routeToBlind.hops.map(_.cltvExpiryDelta).fold(finalExpiryDelta)(_ + _) - routeToBlind.hops.foldRight((zeroPaymentInfo, Seq(finalPayload))) { + val (paymentInfo, payloads) = routeToBlind.hops.foldRight((zeroPaymentInfo, Seq(finalPayload))) { case (channel: ChannelHop, (payInfo, nextPayloads)) => val newFeeBase = MilliSatoshi((channel.params.relayFees.feeBase.toLong * 1_000_000 + payInfo.feeBase.toLong * (1_000_000 + channel.params.relayFees.feeProportionalMillionths) + 1_000_000 - 1) / 1_000_000) val newFeeProp = ((payInfo.feeProportionalMillionths + channel.params.relayFees.feeProportionalMillionths) * 1_000_000 + payInfo.feeProportionalMillionths * channel.params.relayFees.feeProportionalMillionths + 1_000_000 - 1) / 1_000_000 // Because eclair (and others) lies about max HTLC, we remove 10% as a safety margin. - val channelMaxHtlc = channel.params.htlcMaximum_opt.map(_ * 0.9).getOrElse(amount) + val channelMaxHtlc = channel.params.htlcMaximum_opt.map(_ * 0.9).getOrElse(amount.get) val newPayInfo = PaymentInfo(newFeeBase, newFeeProp, payInfo.cltvExpiryDelta + channel.cltvExpiryDelta, payInfo.minHtlc.max(channel.params.htlcMinimum), payInfo.maxHtlc.min(channelMaxHtlc), payInfo.allowedFeatures) val payload = RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec.encode(TlvStream( RouteBlindingEncryptedDataTlv.OutgoingChannelId(channel.shortChannelId), @@ -335,22 +340,17 @@ object MultiPartHandler { )).require.bytes (newPayInfo, payload +: nextPayloads) } - } else { - (zeroPaymentInfo, Seq(finalPayload)) - } - val blindedRoute = Sphinx.RouteBlinding.create(randomKey(), nodeIds, payloads) - (blindedRoute, paymentInfo, pathId) - }) - val invoiceFeatures = featuresTrampolineOpt.remove(Features.RouteBlinding).add(Features.RouteBlinding, FeatureSupport.Mandatory) - val invoice = Bolt12Invoice(r.offer, r.invoiceRequest, paymentPreimage, r.nodeKey, nodeParams.channelConf.minFinalExpiryDelta, invoiceFeatures, paths.map { case (blindedRoute, paymentInfo, _) => PaymentBlindedRoute(blindedRoute.route, paymentInfo) }) - context.log.debug("generated invoice={} for offerId={}", invoice.toString, r.offer.offerId) - nodeParams.db.payments.addIncomingBlindedPayment(invoice, paymentPreimage, paths.map { case (blindedRoute, _, pathId) => (blindedRoute.lastBlinding -> pathId.bytes) }.toMap, r.paymentType) - invoice + val blindedRoute = Sphinx.RouteBlinding.create(randomKey(), nodeIds, payloads) + (blindedRoute, paymentInfo, pathId) + }) + })).map(paths => { + val invoiceFeatures = featuresTrampolineOpt.remove(Features.RouteBlinding).add(Features.RouteBlinding, FeatureSupport.Mandatory) + val invoice = Bolt12Invoice(r.offer, r.invoiceRequest, paymentPreimage, r.nodeKey, nodeParams.channelConf.minFinalExpiryDelta, invoiceFeatures, paths.map { case (blindedRoute, paymentInfo, _) => PaymentBlindedRoute(blindedRoute.route, paymentInfo) }) + context.log.debug("generated invoice={} for offerId={}", invoice.toString, r.offer.offerId) + nodeParams.db.payments.addIncomingBlindedPayment(invoice, paymentPreimage, paths.map { case (blindedRoute, _, pathId) => (blindedRoute.lastBlinding -> pathId.bytes) }.toMap, r.paymentType) + invoice + }).recover(exception => Status.Failure(exception)).pipeTo(replyTo) } - } match { - case Success(invoice) => replyTo ! invoice - case Failure(exception) => replyTo ! Status.Failure(exception) - } Behaviors.stopped } } From 9bf3b9ff4e7bfa3b3b48a9f3a657f40fd872bd4d Mon Sep 17 00:00:00 2001 From: Thomas HUET Date: Mon, 21 Nov 2022 17:02:19 +0100 Subject: [PATCH 3/9] Remove blinded payment tests --- .../BlindPaymentIntegrationSpec.scala | 298 ------------------ 1 file changed, 298 deletions(-) delete mode 100644 eclair-core/src/test/scala/fr/acinq/eclair/integration/BlindPaymentIntegrationSpec.scala diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/integration/BlindPaymentIntegrationSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/integration/BlindPaymentIntegrationSpec.scala deleted file mode 100644 index 5b1a3c73bf..0000000000 --- a/eclair-core/src/test/scala/fr/acinq/eclair/integration/BlindPaymentIntegrationSpec.scala +++ /dev/null @@ -1,298 +0,0 @@ -/* - * Copyright 2022 ACINQ SAS - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package fr.acinq.eclair.integration - -import akka.actor.typed.scaladsl.adapter.actorRefAdapter -import akka.testkit.TestProbe -import com.typesafe.config.ConfigFactory -import fr.acinq.bitcoin.scalacompat.{Crypto, SatoshiLong} -import fr.acinq.eclair.blockchain.bitcoind.ZmqWatcher -import fr.acinq.eclair.blockchain.bitcoind.ZmqWatcher.{Watch, WatchFundingConfirmed} -import fr.acinq.eclair.channel._ -import fr.acinq.eclair.payment._ -import fr.acinq.eclair.payment.receive.MultiPartHandler.ReceiveOfferPayment -import fr.acinq.eclair.payment.relay.Relayer -import fr.acinq.eclair.payment.send.PaymentInitiator.SendPaymentToNode -import fr.acinq.eclair.router.Router -import fr.acinq.eclair.wire.protocol.OfferTypes.{InvoiceRequest, Offer} -import fr.acinq.eclair.{Kit, MilliSatoshiLong, randomKey} - -import java.util.UUID -import scala.concurrent.duration._ -import scala.jdk.CollectionConverters._ - -class BlindPaymentIntegrationSpec extends IntegrationSpec { - - test("start eclair nodes") { - instantiateEclairNode("A", ConfigFactory.parseMap(Map("eclair.node-alias" -> "A", "eclair.channel.expiry-delta-blocks" -> 130, "eclair.server.port" -> 29730, "eclair.api.port" -> 28080, "eclair.features.option_route_blinding" -> "optional").asJava).withFallback(withDefaultCommitment).withFallback(commonConfig)) - instantiateEclairNode("B", ConfigFactory.parseMap(Map("eclair.node-alias" -> "B", "eclair.channel.expiry-delta-blocks" -> 131, "eclair.server.port" -> 29731, "eclair.api.port" -> 28081, "eclair.features.option_route_blinding" -> "optional").asJava).withFallback(withDefaultCommitment).withFallback(commonConfig)) - instantiateEclairNode("C", ConfigFactory.parseMap(Map("eclair.node-alias" -> "C", "eclair.channel.expiry-delta-blocks" -> 132, "eclair.server.port" -> 29732, "eclair.api.port" -> 28082, "eclair.features.option_route_blinding" -> "optional").asJava).withFallback(withDefaultCommitment).withFallback(commonConfig)) - instantiateEclairNode("D", ConfigFactory.parseMap(Map("eclair.node-alias" -> "D", "eclair.channel.expiry-delta-blocks" -> 133, "eclair.server.port" -> 29733, "eclair.api.port" -> 28083, "eclair.features.option_route_blinding" -> "optional").asJava).withFallback(withDefaultCommitment).withFallback(commonConfig)) - instantiateEclairNode("E", ConfigFactory.parseMap(Map("eclair.node-alias" -> "E", "eclair.channel.expiry-delta-blocks" -> 134, "eclair.server.port" -> 29734, "eclair.api.port" -> 28084, "eclair.features.option_route_blinding" -> "optional").asJava).withFallback(withDefaultCommitment).withFallback(commonConfig)) - instantiateEclairNode("F", ConfigFactory.parseMap(Map("eclair.node-alias" -> "F", "eclair.channel.expiry-delta-blocks" -> 135, "eclair.server.port" -> 29735, "eclair.api.port" -> 28085, "eclair.features.option_route_blinding" -> "optional").asJava).withFallback(withDefaultCommitment).withFallback(commonConfig)) - instantiateEclairNode("G", ConfigFactory.parseMap(Map("eclair.node-alias" -> "G", "eclair.channel.expiry-delta-blocks" -> 136, "eclair.server.port" -> 29736, "eclair.api.port" -> 28086, "eclair.features.option_route_blinding" -> "disabled").asJava).withFallback(withDefaultCommitment).withFallback(commonConfig)) // G does not support blinded routes. - } - - test("connect nodes") { - // ,--G--, - // / \ - // A---B ------- C ==== D - // \ / \ - // '--E--' \ - // \_____ F - // - // All channels have fees 1 sat + 200 millionths - - val sender = TestProbe() - val eventListener = TestProbe() - nodes.values.foreach(_.system.eventStream.subscribe(eventListener.ref, classOf[ChannelStateChanged])) - - connect(nodes("A"), nodes("B"), 11000000 sat, 0 msat) - connect(nodes("B"), nodes("C"), 2000000 sat, 0 msat) - connect(nodes("C"), nodes("D"), 5000000 sat, 0 msat) - connect(nodes("C"), nodes("D"), 5000000 sat, 0 msat) - connect(nodes("B"), nodes("E"), 10000000 sat, 0 msat) - connect(nodes("E"), nodes("C"), 10000000 sat, 0 msat) - connect(nodes("B"), nodes("G"), 16000000 sat, 0 msat) - connect(nodes("G"), nodes("C"), 16000000 sat, 0 msat) - connect(nodes("C"), nodes("F"), 2000000 sat, 0 msat) - connect(nodes("E"), nodes("F"), 2000000 sat, 0 msat) - - val numberOfChannels = 10 - val channelEndpointsCount = 2 * numberOfChannels - - // we make sure all channels have set up their WatchConfirmed for the funding tx - awaitCond({ - val watches = nodes.values.foldLeft(Set.empty[Watch[_]]) { - case (watches, setup) => - setup.watcher ! ZmqWatcher.ListWatches(sender.ref) - watches ++ sender.expectMsgType[Set[Watch[_]]] - } - watches.count(_.isInstanceOf[WatchFundingConfirmed]) == channelEndpointsCount - }, max = 20 seconds, interval = 1 second) - - // confirming the funding tx - generateBlocks(2) - - within(60 seconds) { - var count = 0 - while (count < channelEndpointsCount) { - if (eventListener.expectMsgType[ChannelStateChanged](60 seconds).currentState == NORMAL) count = count + 1 - } - } - } - - def awaitAnnouncements(subset: Map[String, Kit], nodes: Int, privateChannels: Int, publicChannels: Int, privateUpdates: Int, publicUpdates: Int): Unit = { - val sender = TestProbe() - subset.foreach { - case (node, setup) => - withClue(node) { - awaitAssert({ - sender.send(setup.router, Router.GetRouterData) - val data = sender.expectMsgType[Router.Data] - assert(data.nodes.size == nodes) - assert(data.privateChannels.size == privateChannels) - assert(data.channels.size == publicChannels) - assert(data.privateChannels.values.flatMap(pc => pc.update_1_opt.toSeq ++ pc.update_2_opt.toSeq).size == privateUpdates) - assert(data.channels.values.flatMap(pc => pc.update_1_opt.toSeq ++ pc.update_2_opt.toSeq).size == publicUpdates) - }, max = 10 seconds, interval = 1 second) - } - } - } - - test("wait for network announcements") { - // generating more blocks so that all funding txes are buried under at least 6 blocks - generateBlocks(4) - awaitAnnouncements(nodes.view.filterKeys(key => List("A", "B", "C", "D", "E", "G").contains(key)).toMap, nodes = 7, privateChannels = 0, publicChannels = 10, privateUpdates = 0, publicUpdates = 20) - } - - test("wait for channels balance") { - // Channels balance should now be available in the router - val sender = TestProbe() - val nodeId = nodes("C").nodeParams.nodeId - sender.send(nodes("C").router, Router.GetRoutingState) - val routingState = sender.expectMsgType[Router.RoutingState] - val publicChannels = routingState.channels.filter(pc => Set(pc.ann.nodeId1, pc.ann.nodeId2).contains(nodeId)) - assert(publicChannels.nonEmpty) - publicChannels.foreach(pc => assert(pc.meta_opt.exists(m => m.balance1 > 0.msat || m.balance2 > 0.msat), pc)) - } - - test("send an HTLC A->(D), minimal blinded route") { - val (sender, eventListener) = (TestProbe(), TestProbe()) - nodes("D").system.eventStream.subscribe(eventListener.ref, classOf[PaymentMetadataReceived]) - - val recipientKey = randomKey() - val payerKey = randomKey() - - // first we retrieve an invoice from D - val amount = 42000000 msat - val chain = nodes("D").nodeParams.chainHash - val offer = Offer(Some(amount), "test offer", recipientKey.publicKey, nodes("D").nodeParams.features.invoiceFeatures(), chain) - val invoiceRequest = InvoiceRequest(offer, amount, 1, nodes("A").nodeParams.features.invoiceFeatures(), payerKey, chain) - - sender.send(nodes("D").paymentHandler, ReceiveOfferPayment(recipientKey, offer, invoiceRequest, Seq(Seq(nodes("D").nodeParams.nodeId)), nodes("D").router)) - val invoice = sender.expectMsgType[Bolt12Invoice] - - // then we make the actual payment - sender.send(nodes("A").paymentInitiator, SendPaymentToNode(amount, invoice, routeParams = integrationTestRouteParams, maxAttempts = 1)) - val paymentId = sender.expectMsgType[UUID] - val ps = sender.expectMsgType[PaymentSent] - assert(ps.id == paymentId) - assert(Crypto.sha256(ps.paymentPreimage) == invoice.paymentHash) - } - - test("send an HTLC A->(D->D->D), blinded route with dummy hops") { - val (sender, eventListener) = (TestProbe(), TestProbe()) - nodes("D").system.eventStream.subscribe(eventListener.ref, classOf[PaymentMetadataReceived]) - - val recipientKey = randomKey() - val payerKey = randomKey() - - // first we retrieve an invoice from D - val amount = 5600000 msat - val chain = nodes("D").nodeParams.chainHash - val offer = Offer(Some(amount), "test offer", recipientKey.publicKey, nodes("D").nodeParams.features.invoiceFeatures(), chain) - val invoiceRequest = InvoiceRequest(offer, amount, 1, nodes("A").nodeParams.features.invoiceFeatures(), payerKey, chain) - - sender.send(nodes("D").paymentHandler, ReceiveOfferPayment(recipientKey, offer, invoiceRequest, Seq(Seq(nodes("D").nodeParams.nodeId, nodes("D").nodeParams.nodeId, nodes("D").nodeParams.nodeId)), nodes("D").router)) - val invoice = sender.expectMsgType[Bolt12Invoice] - - // then we make the actual payment - sender.send(nodes("A").paymentInitiator, SendPaymentToNode(amount, invoice, routeParams = integrationTestRouteParams, maxAttempts = 1)) - val paymentId = sender.expectMsgType[UUID] - val ps = sender.expectMsgType[PaymentSent] - assert(ps.id == paymentId) - assert(Crypto.sha256(ps.paymentPreimage) == invoice.paymentHash) - } - - test("send an HTLC A->(E->C->D), blinded route with intermediate hops") { - val (sender, eventListener) = (TestProbe(), TestProbe()) - nodes("D").system.eventStream.subscribe(eventListener.ref, classOf[PaymentMetadataReceived]) - - val recipientKey = randomKey() - val payerKey = randomKey() - - // first we retrieve an invoice from D - val amount = 7500000 msat - val chain = nodes("D").nodeParams.chainHash - val offer = Offer(Some(amount), "test offer", recipientKey.publicKey, nodes("D").nodeParams.features.invoiceFeatures(), chain) - val invoiceRequest = InvoiceRequest(offer, amount, 1, nodes("A").nodeParams.features.invoiceFeatures(), payerKey, chain) - - sender.send(nodes("D").paymentHandler, ReceiveOfferPayment(recipientKey, offer, invoiceRequest, Seq(Seq(nodes("E").nodeParams.nodeId, nodes("C").nodeParams.nodeId, nodes("D").nodeParams.nodeId)), nodes("D").router)) - val invoice = sender.expectMsgType[Bolt12Invoice] - - // then we make the actual payment - sender.send(nodes("A").paymentInitiator, SendPaymentToNode(amount, invoice, routeParams = integrationTestRouteParams, maxAttempts = 1)) - val paymentId = sender.expectMsgType[UUID] - val ps = sender.expectMsgType[PaymentSent] - assert(ps.id == paymentId) - assert(Crypto.sha256(ps.paymentPreimage) == invoice.paymentHash) - } - - test("send an HTLC A->(G->C->D), blinded route with node not supporting blinded routes") { - val (sender, eventListener) = (TestProbe(), TestProbe()) - nodes("D").system.eventStream.subscribe(eventListener.ref, classOf[PaymentMetadataReceived]) - - val recipientKey = randomKey() - val payerKey = randomKey() - - // first we retrieve an invoice from D - val amount = 7500000 msat - val chain = nodes("D").nodeParams.chainHash - val offer = Offer(Some(amount), "test offer", recipientKey.publicKey, nodes("D").nodeParams.features.invoiceFeatures(), chain) - val invoiceRequest = InvoiceRequest(offer, amount, 1, nodes("A").nodeParams.features.invoiceFeatures(), payerKey, chain) - - sender.send(nodes("D").paymentHandler, ReceiveOfferPayment(recipientKey, offer, invoiceRequest, Seq(Seq(nodes("G").nodeParams.nodeId, nodes("C").nodeParams.nodeId, nodes("D").nodeParams.nodeId)), nodes("D").router)) - val invoice = sender.expectMsgType[Bolt12Invoice] - - // then we make the actual payment - sender.send(nodes("A").paymentInitiator, SendPaymentToNode(amount, invoice, routeParams = integrationTestRouteParams, maxAttempts = 1)) - sender.expectMsgType[UUID] - sender.expectMsgType[PaymentFailed] - } - - test("send an HTLC A->(A->B->C->D), blinded route with introduction node being the sender") { - val (sender, eventListener) = (TestProbe(), TestProbe()) - nodes("D").system.eventStream.subscribe(eventListener.ref, classOf[PaymentMetadataReceived]) - - val recipientKey = randomKey() - val payerKey = randomKey() - - // first we retrieve an invoice from D - val amount = 3200000 msat - val chain = nodes("D").nodeParams.chainHash - val offer = Offer(Some(amount), "test offer", recipientKey.publicKey, nodes("D").nodeParams.features.invoiceFeatures(), chain) - val invoiceRequest = InvoiceRequest(offer, amount, 1, nodes("A").nodeParams.features.invoiceFeatures(), payerKey, chain) - - sender.send(nodes("D").paymentHandler, ReceiveOfferPayment(recipientKey, offer, invoiceRequest, Seq(Seq(nodes("A").nodeParams.nodeId, nodes("B").nodeParams.nodeId, nodes("C").nodeParams.nodeId, nodes("D").nodeParams.nodeId)), nodes("D").router)) - val invoice = sender.expectMsgType[Bolt12Invoice] - - // then we make the actual payment - sender.send(nodes("A").paymentInitiator, SendPaymentToNode(amount, invoice, routeParams = integrationTestRouteParams, maxAttempts = 1)) - val paymentId = sender.expectMsgType[UUID] - val ps = sender.expectMsgType[PaymentSent] - assert(ps.id == paymentId) - assert(Crypto.sha256(ps.paymentPreimage) == invoice.paymentHash) - } - - test("send to multiple blinded routes: "){ - val (sender, eventListener) = (TestProbe(), TestProbe()) - nodes("F").system.eventStream.subscribe(eventListener.ref, classOf[PaymentMetadataReceived]) - - val recipientKey = randomKey() - val payerKey = randomKey() - - // first we retrieve an invoice from D - val amount = 3_000_000_000L msat - val chain = nodes("F").nodeParams.chainHash - val offer = Offer(Some(amount), "test offer", recipientKey.publicKey, nodes("F").nodeParams.features.invoiceFeatures(), chain) - val invoiceRequest = InvoiceRequest(offer, amount, 1, nodes("A").nodeParams.features.invoiceFeatures(), payerKey, chain) - - sender.send(nodes("F").paymentHandler, ReceiveOfferPayment(recipientKey, offer, invoiceRequest, Seq(Seq(nodes("E").nodeParams.nodeId, nodes("F").nodeParams.nodeId), Seq(nodes("C").nodeParams.nodeId, nodes("F").nodeParams.nodeId)), nodes("F").router)) - val invoice = sender.expectMsgType[Bolt12Invoice] - - // then we make the actual payment - sender.send(nodes("A").paymentInitiator, SendPaymentToNode(amount, invoice, routeParams = integrationTestRouteParams, maxAttempts = 1)) - val paymentId = sender.expectMsgType[UUID] - val ps = sender.expectMsgType[PaymentSent] - assert(ps.id == paymentId) - assert(Crypto.sha256(ps.paymentPreimage) == invoice.paymentHash) - } - - // To show channel balances before adding new tests - def debugChannelBalances(): Unit = { - val names = nodes.map { case (name, kit) => (kit.nodeParams.nodeId, name) } - val sender = TestProbe() - sender.send(nodes("A").relayer, Relayer.GetOutgoingChannels()) - sender.expectMsgType[Relayer.OutgoingChannels].channels.map(_.toChannelBalance).foreach { balance => - println(s"A ${balance.canSend} ===== ${balance.canReceive} ${names(balance.remoteNodeId)}") - } - sender.send(nodes("B").relayer, Relayer.GetOutgoingChannels()) - sender.expectMsgType[Relayer.OutgoingChannels].channels.map(_.toChannelBalance).foreach { balance => - println(s"B ${balance.canSend} ===== ${balance.canReceive} ${names(balance.remoteNodeId)}") - } - sender.send(nodes("C").relayer, Relayer.GetOutgoingChannels()) - sender.expectMsgType[Relayer.OutgoingChannels].channels.map(_.toChannelBalance).foreach{ balance => - println(s"C ${balance.canSend} ===== ${balance.canReceive} ${names(balance.remoteNodeId)}") - } - sender.send(nodes("E").relayer, Relayer.GetOutgoingChannels()) - sender.expectMsgType[Relayer.OutgoingChannels].channels.map(_.toChannelBalance).foreach{ balance => - println(s"E ${balance.canSend} ===== ${balance.canReceive} ${names(balance.remoteNodeId)}") - } - names.foreach(println) - } -} From d675c6b39ff74d67893b80182cf01380f4481d34 Mon Sep 17 00:00:00 2001 From: Thomas HUET Date: Tue, 22 Nov 2022 16:31:42 +0100 Subject: [PATCH 4/9] refactor --- .../payment/receive/MultiPartHandler.scala | 94 +++++++++++-------- .../eclair/router/RouteCalculation.scala | 9 +- .../eclair/wire/protocol/OfferTypes.scala | 6 +- .../eclair/payment/MultiPartHandlerSpec.scala | 27 +++++- 4 files changed, 94 insertions(+), 42 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartHandler.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartHandler.scala index 6ad163e8fe..d956de991e 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartHandler.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartHandler.scala @@ -29,6 +29,7 @@ import fr.acinq.bitcoin.scalacompat.{ByteVector32, Crypto} import fr.acinq.eclair.Logs.LogCategory import fr.acinq.eclair.channel.{CMD_FAIL_HTLC, CMD_FULFILL_HTLC, RES_SUCCESS} import fr.acinq.eclair.crypto.Sphinx +import fr.acinq.eclair.crypto.Sphinx.RouteBlinding.BlindedRouteDetails import fr.acinq.eclair.db._ import fr.acinq.eclair.payment.Bolt11Invoice.ExtraHop import fr.acinq.eclair.payment.Monitoring.{Metrics, Tags} @@ -38,8 +39,8 @@ import fr.acinq.eclair.router.Router.ChannelHop import fr.acinq.eclair.wire.protocol.OfferTypes.{InvoiceRequest, Offer, PaymentInfo} import fr.acinq.eclair.wire.protocol.PaymentOnion.FinalPayload import fr.acinq.eclair.wire.protocol._ -import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, FeatureSupport, Features, InvoiceFeature, Logs, MilliSatoshi, MilliSatoshiLong, NodeParams, TimestampMilli, randomBytes32, randomKey} -import scodec.bits.HexStringSyntax +import fr.acinq.eclair.{CltvExpiryDelta, FeatureSupport, Features, InvoiceFeature, Logs, MilliSatoshi, MilliSatoshiLong, NodeParams, TimestampMilli, randomBytes32, randomKey} +import scodec.bits.{ByteVector, HexStringSyntax} import scala.concurrent.duration.DurationInt import scala.concurrent.{ExecutionContextExecutor, Future} @@ -273,15 +274,48 @@ object MultiPartHandler { case class CreateInvoice(replyTo: ActorRef, receivePayment: ReceivePayment) extends Command // @formatter:on + def blindedRouteFromHops(nodeParams: NodeParams, hops: Seq[ChannelHop], nodeIds: Seq[PublicKey], pathId: ByteVector): BlindedRouteDetails = { + val finalExpiryDelta = nodeParams.channelConf.minFinalExpiryDelta + 500 // We let the sender add up to 500 blocks to the CLTV. + val finalConstraints = RouteBlindingEncryptedDataTlv.PaymentConstraints(finalExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight), nodeParams.channelConf.htlcMinimum) + val finalPayload = RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec.encode(TlvStream( + finalConstraints, + RouteBlindingEncryptedDataTlv.PathId(pathId) + )).require.bytes + val maxCltvExpiry = hops.map(_.cltvExpiryDelta).fold(finalExpiryDelta)(_ + _).toCltvExpiry(nodeParams.currentBlockHeight) // Same CLTV for all nodes so they can't use it to guess their position. + val payloads = hops.foldRight(Seq(finalPayload)) { + case (channel: ChannelHop, nextPayloads) => + // Because eclair (and others) lies about max HTLC, we remove 10% as a safety margin. + val payload = RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec.encode(TlvStream( + RouteBlindingEncryptedDataTlv.OutgoingChannelId(channel.shortChannelId), + RouteBlindingEncryptedDataTlv.PaymentRelay(channel.cltvExpiryDelta, channel.params.relayFees.feeProportionalMillionths, channel.params.relayFees.feeBase), + RouteBlindingEncryptedDataTlv.PaymentConstraints(maxCltvExpiry, channel.params.htlcMinimum) + )).require.bytes + payload +: nextPayloads + } + Sphinx.RouteBlinding.create(randomKey(), nodeIds, payloads) + } + + def aggregatePayInfo(route: Router.Route): PaymentInfo = { + val zeroPaymentInfo = PaymentInfo(0 msat, 0, CltvExpiryDelta(0), 0 msat, route.amount, Features.empty) + route.hops.foldRight(zeroPaymentInfo) { + case (channel: ChannelHop, payInfo) => + val newFeeBase = MilliSatoshi((channel.params.relayFees.feeBase.toLong * 1_000_000 + payInfo.feeBase.toLong * (1_000_000 + channel.params.relayFees.feeProportionalMillionths) + 1_000_000 - 1) / 1_000_000) + val newFeeProp = ((payInfo.feeProportionalMillionths + channel.params.relayFees.feeProportionalMillionths) * 1_000_000 + payInfo.feeProportionalMillionths * channel.params.relayFees.feeProportionalMillionths + 1_000_000 - 1) / 1_000_000 + // Because eclair (and others) lies about max HTLC, we remove 10% as a safety margin. + val channelMaxHtlc = channel.params.htlcMaximum_opt.map(_ * 0.9).getOrElse(route.amount) + PaymentInfo(newFeeBase, newFeeProp, payInfo.cltvExpiryDelta + channel.cltvExpiryDelta, payInfo.minHtlc.max(channel.params.htlcMinimum), payInfo.maxHtlc.min(channelMaxHtlc), payInfo.allowedFeatures) + } + } + def apply(nodeParams: NodeParams): Behavior[Command] = { Behaviors.setup { context => Behaviors.receiveMessage { case CreateInvoice(replyTo, receivePayment) => - val paymentPreimage = receivePayment.paymentPreimage_opt.getOrElse(randomBytes32()) - val paymentHash = Crypto.sha256(paymentPreimage) - val featuresTrampolineOpt = if (nodeParams.enableTrampolinePayment) { - nodeParams.features.invoiceFeatures().add(Features.TrampolinePaymentPrototype, FeatureSupport.Optional) - } else { + val paymentPreimage = receivePayment.paymentPreimage_opt.getOrElse(randomBytes32()) + val paymentHash = Crypto.sha256(paymentPreimage) + val featuresTrampolineOpt = if (nodeParams.enableTrampolinePayment) { + nodeParams.features.invoiceFeatures().add(Features.TrampolinePaymentPrototype, FeatureSupport.Optional) + } else { nodeParams.features.invoiceFeatures() } receivePayment match { @@ -312,41 +346,25 @@ object MultiPartHandler { case r: ReceiveOfferPayment => val amount = r.invoiceRequest.amount.orElse(r.offer.amount.map(_ * r.invoiceRequest.quantity)) implicit val ec: ExecutionContextExecutor = context.executionContext + val log = context.log Future.sequence(r.routes.map(nodeIds => { - require(nodeIds.length > 1, "route must have at least one hop") - implicit val timeout: Timeout = 10.seconds - r.router.ask(Router.FinalizeRoute(0 msat, Router.PredefinedNodeRoute(nodeIds))).mapTo[Router.RouteResponse].map(routeResponse => { - val pathId = randomBytes32() - val finalExpiryDelta = nodeParams.channelConf.minFinalExpiryDelta + 3 - val finalConstraints = RouteBlindingEncryptedDataTlv.PaymentConstraints(finalExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight), nodeParams.channelConf.htlcMinimum) - val zeroPaymentInfo = PaymentInfo(0 msat, 0, CltvExpiryDelta(0), nodeParams.channelConf.htlcMinimum, amount.get, Features.empty) - val finalPayload = RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec.encode(TlvStream( - finalConstraints, - RouteBlindingEncryptedDataTlv.PathId(pathId) - )).require.bytes - val routeToBlind = routeResponse.routes.head - val totalCltvDelta = routeToBlind.hops.map(_.cltvExpiryDelta).fold(finalExpiryDelta)(_ + _) - val (paymentInfo, payloads) = routeToBlind.hops.foldRight((zeroPaymentInfo, Seq(finalPayload))) { - case (channel: ChannelHop, (payInfo, nextPayloads)) => - val newFeeBase = MilliSatoshi((channel.params.relayFees.feeBase.toLong * 1_000_000 + payInfo.feeBase.toLong * (1_000_000 + channel.params.relayFees.feeProportionalMillionths) + 1_000_000 - 1) / 1_000_000) - val newFeeProp = ((payInfo.feeProportionalMillionths + channel.params.relayFees.feeProportionalMillionths) * 1_000_000 + payInfo.feeProportionalMillionths * channel.params.relayFees.feeProportionalMillionths + 1_000_000 - 1) / 1_000_000 - // Because eclair (and others) lies about max HTLC, we remove 10% as a safety margin. - val channelMaxHtlc = channel.params.htlcMaximum_opt.map(_ * 0.9).getOrElse(amount.get) - val newPayInfo = PaymentInfo(newFeeBase, newFeeProp, payInfo.cltvExpiryDelta + channel.cltvExpiryDelta, payInfo.minHtlc.max(channel.params.htlcMinimum), payInfo.maxHtlc.min(channelMaxHtlc), payInfo.allowedFeatures) - val payload = RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec.encode(TlvStream( - RouteBlindingEncryptedDataTlv.OutgoingChannelId(channel.shortChannelId), - RouteBlindingEncryptedDataTlv.PaymentRelay(channel.cltvExpiryDelta, channel.params.relayFees.feeProportionalMillionths, channel.params.relayFees.feeBase), - RouteBlindingEncryptedDataTlv.PaymentConstraints(CltvExpiry(nodeParams.currentBlockHeight) + totalCltvDelta, channel.params.htlcMinimum) - )).require.bytes - (newPayInfo, payload +: nextPayloads) - } - val blindedRoute = Sphinx.RouteBlinding.create(randomKey(), nodeIds, payloads) - (blindedRoute, paymentInfo, pathId) - }) + val pathId = randomBytes32() + if (nodeIds.length == 1) { + Future.successful( + (blindedRouteFromHops(nodeParams, Nil, nodeIds, pathId), + PaymentInfo(0 msat, 0, CltvExpiryDelta(0), 0 msat, amount.get, Features.empty), + pathId)) + } else { + implicit val timeout: Timeout = 10.seconds + r.router.ask(Router.FinalizeRoute(amount.get, Router.PredefinedNodeRoute(nodeIds))).mapTo[Router.RouteResponse].map(routeResponse => { + val route = routeResponse.routes.head + (blindedRouteFromHops(nodeParams, route.hops, nodeIds, pathId), aggregatePayInfo(route), pathId) + }) + } })).map(paths => { val invoiceFeatures = featuresTrampolineOpt.remove(Features.RouteBlinding).add(Features.RouteBlinding, FeatureSupport.Mandatory) val invoice = Bolt12Invoice(r.offer, r.invoiceRequest, paymentPreimage, r.nodeKey, nodeParams.channelConf.minFinalExpiryDelta, invoiceFeatures, paths.map { case (blindedRoute, paymentInfo, _) => PaymentBlindedRoute(blindedRoute.route, paymentInfo) }) - context.log.debug("generated invoice={} for offerId={}", invoice.toString, r.offer.offerId) + log.debug("generated invoice={} for offerId={}", invoice.toString, r.offer.offerId) nodeParams.db.payments.addIncomingBlindedPayment(invoice, paymentPreimage, paths.map { case (blindedRoute, _, pathId) => (blindedRoute.lastBlinding -> pathId.bytes) }.toMap, r.paymentType) invoice }).recover(exception => Status.Failure(exception)).pipeTo(replyTo) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/RouteCalculation.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/RouteCalculation.scala index b15763634b..c592af06ac 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/router/RouteCalculation.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/RouteCalculation.scala @@ -22,6 +22,7 @@ import com.softwaremill.quicklens.ModifyPimp import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.eclair.Logs.LogCategory import fr.acinq.eclair._ +import fr.acinq.eclair.payment.Invoice.BasicEdge import fr.acinq.eclair.router.Graph.GraphStructure.DirectedGraph.graphEdgeToHop import fr.acinq.eclair.router.Graph.GraphStructure.{DirectedGraph, GraphEdge} import fr.acinq.eclair.router.Graph.{InfiniteLoop, NegativeProbability, RichWeight} @@ -48,7 +49,13 @@ object RouteCalculation { fr.route match { case PredefinedNodeRoute(hops) => // split into sublists [(a,b),(b,c), ...] then get the edges between each of those pairs - hops.sliding(2).map { case List(v1, v2) => g.getEdgesBetween(v1, v2) }.toList match { + hops.sliding(2).map { + case List(v1, v2) => if (v1 == localNodeId && v2 == localNodeId) { + Seq(GraphEdge(BasicEdge(localNodeId, localNodeId, ShortChannelId.toSelf, 0 msat, 0, CltvExpiryDelta(0)))) + } else { + g.getEdgesBetween(v1, v2) + } + }.toList match { case edges if edges.nonEmpty && edges.forall(_.nonEmpty) => // select the largest edge (using balance when available, otherwise capacity). val selectedEdges = edges.map(es => es.maxBy(e => e.balance_opt.getOrElse(e.capacity.toMilliSatoshi))) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/OfferTypes.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/OfferTypes.scala index 61c3c047eb..abb284cb18 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/OfferTypes.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/OfferTypes.scala @@ -22,7 +22,7 @@ import fr.acinq.bitcoin.scalacompat.{Block, ByteVector32, ByteVector64, Crypto, import fr.acinq.eclair.crypto.Sphinx.RouteBlinding.BlindedRoute import fr.acinq.eclair.wire.protocol.OnionRoutingCodecs.{InvalidTlvPayload, MissingRequiredTlv} import fr.acinq.eclair.wire.protocol.TlvCodecs.genericTlv -import fr.acinq.eclair.{CltvExpiryDelta, Feature, Features, InvoiceFeature, MilliSatoshi, TimestampSecond, UInt64} +import fr.acinq.eclair.{CltvExpiryDelta, Feature, Features, InvoiceFeature, MilliSatoshi, TimestampSecond, UInt64, nodeFee} import fr.acinq.secp256k1.Secp256k1JvmKt import scodec.Codec import scodec.bits.ByteVector @@ -65,7 +65,9 @@ object OfferTypes { cltvExpiryDelta: CltvExpiryDelta, minHtlc: MilliSatoshi, maxHtlc: MilliSatoshi, - allowedFeatures: Features[Feature]) + allowedFeatures: Features[Feature]) { + def fee(amount: MilliSatoshi): MilliSatoshi = nodeFee(feeBase, feeProportionalMillionths, amount) + } case class PaymentPathsInfo(paymentInfo: Seq[PaymentInfo]) extends InvoiceTlv diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartHandlerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartHandlerSpec.scala index 23273af9ad..146d3a09c7 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartHandlerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartHandlerSpec.scala @@ -26,6 +26,7 @@ import fr.acinq.eclair.TestConstants.Alice import fr.acinq.eclair.channel.{CMD_FAIL_HTLC, CMD_FULFILL_HTLC, Register} import fr.acinq.eclair.db.{IncomingBlindedPayment, IncomingPaymentStatus} import fr.acinq.eclair.payment.Bolt11Invoice.ExtraHop +import fr.acinq.eclair.payment.Invoice.BasicEdge import fr.acinq.eclair.payment.PaymentReceived.PartialPayment import fr.acinq.eclair.payment.receive.MultiPartHandler._ import fr.acinq.eclair.payment.receive.MultiPartPaymentFSM.HtlcPart @@ -273,6 +274,30 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike } } + test("Aggregate route fees") { f => + val rand = new scala.util.Random + for (_ <- 0 to 100) { + val routeLength = rand.nextInt(10) + 1 + val hops = + for (_ <- 1 to routeLength; + scid = ShortChannelId.generateLocalAlias(); + nid = randomKey().publicKey; + params = Router.ChannelRelayParams.FromHint(BasicEdge(nid, nid, scid, MilliSatoshi(rand.nextLong(10_000)), rand.nextInt(5000), CltvExpiryDelta(0)))) + yield Router.ChannelHop(scid, nid, nid, params) + val route = Router.Route(0 msat, hops) + val aggregate = CreateInvoiceActor.aggregatePayInfo(route) + for (_ <- 0 to 100) { + val amount = MilliSatoshi(rand.nextLong(10_000_000_000L)) + val fee1 = aggregate.fee(amount) + val fee2 = route.copy(amount = amount).fee(true) + // The aggregated fees are always enough + assert(fee1 >= fee2, s"amount=$amount, route=${route.hops.map(_.params.relayFees)}, aggregate=$aggregate") + // and we don't overpay too much. + assert(fee1 - fee2 < 1000.msat.max(amount * 1e-5), s"amount=$amount, route=${route.hops.map(_.params.relayFees)}, aggregate=$aggregate") + } + } + } + test("Invoice generation with route blinding support") { f => import f._ @@ -299,7 +324,7 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike assert(invoice.features.hasFeature(RouteBlinding, Some(Mandatory))) assert(invoice.description == Left("a blinded coffee please")) assert(invoice.offerId.contains(offer.offerId)) - assert(invoice.extraEdges.length == 3) + assert(invoice.blindedPaths.length == 3) assert(invoice.blindedPaths(0).blindedNodeIds.length == 4) assert(invoice.blindedPaths(0).introductionNodeId == a) assert(invoice.blindedPathsInfo(0) == PaymentInfo(1801 msat, 1155, CltvExpiryDelta(201), 0 msat, 25_000 msat, Features.empty)) From aa40ddb41a030d0e78432c3d989c7d5ef9c681c3 Mon Sep 17 00:00:00 2001 From: Thomas HUET Date: Wed, 23 Nov 2022 15:49:12 +0100 Subject: [PATCH 5/9] more tests --- .../eclair/payment/MultiPartHandlerSpec.scala | 71 +++++++++++++++++-- 1 file changed, 67 insertions(+), 4 deletions(-) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartHandlerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartHandlerSpec.scala index 146d3a09c7..d7e3847efe 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartHandlerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartHandlerSpec.scala @@ -16,6 +16,7 @@ package fr.acinq.eclair.payment +import akka.actor.Status import akka.actor.Status.Failure import akka.testkit.{TestActorRef, TestProbe} import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey @@ -32,13 +33,13 @@ import fr.acinq.eclair.payment.receive.MultiPartHandler._ import fr.acinq.eclair.payment.receive.MultiPartPaymentFSM.HtlcPart import fr.acinq.eclair.payment.receive.{MultiPartPaymentFSM, PaymentHandler} import fr.acinq.eclair.router.Router -import fr.acinq.eclair.router.Router.RouteResponse +import fr.acinq.eclair.router.Router.{ChannelRelayParams, RouteResponse} import fr.acinq.eclair.wire.protocol.OfferTypes.{InvoiceRequest, Offer, PaymentInfo} import fr.acinq.eclair.wire.protocol.OnionPaymentPayloadTlv.{AmountToForward, BlindingPoint, EncryptedRecipientData, OutgoingCltv} import fr.acinq.eclair.wire.protocol.PaymentOnion.FinalPayload -import fr.acinq.eclair.wire.protocol.RouteBlindingEncryptedDataTlv.{PathId, PaymentConstraints} +import fr.acinq.eclair.wire.protocol.RouteBlindingEncryptedDataTlv.{OutgoingChannelId, PathId, PaymentConstraints, PaymentRelay} import fr.acinq.eclair.wire.protocol._ -import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Feature, Features, MilliSatoshi, MilliSatoshiLong, NodeParams, ShortChannelId, TestConstants, TestKitBaseClass, TimestampMilliLong, randomBytes32, randomKey} +import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Feature, Features, MilliSatoshi, MilliSatoshiLong, NodeParams, ShortChannelId, TestConstants, TestKitBaseClass, TimestampMilli, TimestampMilliLong, randomBytes32, randomKey} import org.scalatest.Outcome import org.scalatest.funsuite.FixtureAnyFunSuiteLike import scodec.bits.{ByteVector, HexStringSyntax} @@ -274,7 +275,7 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike } } - test("Aggregate route fees") { f => + test("Aggregate route fees") { _ => val rand = new scala.util.Random for (_ <- 0 to 100) { val routeLength = rand.nextInt(10) + 1 @@ -298,6 +299,50 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike } } + test("Generate blinded route from zero hop"){f => + import f._ + + val a = randomKey() + val pathId = randomBytes32() + val route = CreateInvoiceActor.blindedRouteFromHops(nodeParams, Nil, Seq(a.publicKey), pathId) + assert(route.route.introductionNodeId == a.publicKey) + assert(route.route.encryptedPayloads.length == 1) + assert(route.route.blindingKey == route.lastBlinding) + val Right(decoded) = RouteBlindingEncryptedDataCodecs.decode(a, route.route.blindingKey, route.route.encryptedPayloads.head) + assert(BlindedRouteData.validPaymentRecipientData(decoded.tlvs).isRight) + assert(decoded.tlvs.get[PathId].get.data == pathId.bytes) + } + + test("Generate blinded route from hops"){f => + import f._ + + val (a, b, c) = (randomKey(), randomKey(), randomKey()) + val pathId = randomBytes32() + val (channelId1, channelId2) = (ShortChannelId.generateLocalAlias(), ShortChannelId.generateLocalAlias()) + val hops = Seq( + Router.ChannelHop(channelId1, a.publicKey, b.publicKey, ChannelRelayParams.FromHint(Invoice.BasicEdge(a.publicKey, b.publicKey, channelId1, 10 msat, 300, CltvExpiryDelta(200)))), + Router.ChannelHop(channelId2, b.publicKey, c.publicKey, ChannelRelayParams.FromHint(Invoice.BasicEdge(b.publicKey, c.publicKey, channelId2, 20 msat, 150, CltvExpiryDelta(600)))), + ) + val route = CreateInvoiceActor.blindedRouteFromHops(nodeParams, hops, Seq(a.publicKey, b.publicKey, c.publicKey), pathId) + assert(route.route.introductionNodeId == a.publicKey) + assert(route.route.encryptedPayloads.length == 3) + val Right(decoded1) = RouteBlindingEncryptedDataCodecs.decode(a, route.route.blindingKey, route.route.encryptedPayloads(0)) + assert(BlindedRouteData.validatePaymentRelayData(decoded1.tlvs).isRight) + assert(decoded1.tlvs.get[OutgoingChannelId].get.shortChannelId == channelId1) + assert(decoded1.tlvs.get[PaymentRelay].get.feeBase == 10.msat) + assert(decoded1.tlvs.get[PaymentRelay].get.feeProportionalMillionths == 300) + assert(decoded1.tlvs.get[PaymentRelay].get.cltvExpiryDelta == CltvExpiryDelta(200)) + val Right(decoded2) = RouteBlindingEncryptedDataCodecs.decode(b, decoded1.nextBlinding, route.route.encryptedPayloads(1)) + assert(BlindedRouteData.validatePaymentRelayData(decoded2.tlvs).isRight) + assert(decoded2.tlvs.get[OutgoingChannelId].get.shortChannelId == channelId2) + assert(decoded2.tlvs.get[PaymentRelay].get.feeBase == 20.msat) + assert(decoded2.tlvs.get[PaymentRelay].get.feeProportionalMillionths == 150) + assert(decoded2.tlvs.get[PaymentRelay].get.cltvExpiryDelta == CltvExpiryDelta(600)) + val Right(decoded3) = RouteBlindingEncryptedDataCodecs.decode(c, decoded2.nextBlinding, route.route.encryptedPayloads(2)) + assert(BlindedRouteData.validPaymentRecipientData(decoded3.tlvs).isRight) + assert(decoded3.tlvs.get[PathId].get.data == pathId.bytes) + } + test("Invoice generation with route blinding support") { f => import f._ @@ -342,6 +387,24 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike pendingPayment.pathIds.values.foreach(pathId => assert(pathId.length == 32)) } + test("Invoice generation with route blinding - incorrect route") { f => + import f._ + + val privKey = randomKey() + val offer = Offer(Some(25_000 msat), "a blinded coffee please", privKey.publicKey, Features.empty, Block.RegtestGenesisBlock.hash) + val invoiceReq = InvoiceRequest(offer, 25_000 msat, 1, featuresWithRouteBlinding.invoiceFeatures(), randomKey(), Block.RegtestGenesisBlock.hash) + val router = TestProbe() + val (a, b, c) = (randomKey().publicKey, randomKey().publicKey, randomKey().publicKey) + sender.send(handlerWithRouteBlinding, ReceiveOfferPayment(privKey, offer, invoiceReq, Seq(Seq(a, b, c)), router.ref)) + val finalizeRoute1 = router.expectMsgType[Router.FinalizeRoute] + assert(finalizeRoute1.route == Router.PredefinedNodeRoute(Seq(a, b, c))) + router.send(router.lastSender, Status.Failure(new IllegalArgumentException("Not all the nodes in the supplied route are connected with public channels"))) + sender.expectMsgType[Status.Failure] + + val pendingPayments = nodeParams.db.payments.listIncomingPayments(TimestampMilli.min, TimestampMilli.max, None) + assert(pendingPayments.isEmpty) + } + test("Generated invoice contains the provided extra hops") { f => import f._ From 1543309988f3dca646fa57081c9f6376fdbc8d9e Mon Sep 17 00:00:00 2001 From: Thomas HUET Date: Wed, 23 Nov 2022 15:55:17 +0100 Subject: [PATCH 6/9] require --- .../fr/acinq/eclair/payment/receive/MultiPartHandler.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartHandler.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartHandler.scala index d956de991e..c9f8a9db1d 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartHandler.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartHandler.scala @@ -265,7 +265,9 @@ object MultiPartHandler { routes: Seq[Seq[PublicKey]], router: ActorRef, paymentPreimage_opt: Option[ByteVector32] = None, - paymentType: String = PaymentType.Blinded) extends ReceivePayment + paymentType: String = PaymentType.Blinded) extends ReceivePayment { + require(routes.forall(_.nonEmpty), "each route must have at least one node") + } object CreateInvoiceActor { From 03574850e08a8a1cc5b6db8922c230184a7acb48 Mon Sep 17 00:00:00 2001 From: t-bast Date: Thu, 24 Nov 2022 14:08:36 +0100 Subject: [PATCH 7/9] Fix PR comments --- .../payment/receive/MultiPartHandler.scala | 180 +++++++++++------- .../eclair/router/RouteCalculation.scala | 10 +- .../eclair/payment/MultiPartHandlerSpec.scala | 143 +++++++------- 3 files changed, 181 insertions(+), 152 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartHandler.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartHandler.scala index c9f8a9db1d..2050fb423d 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartHandler.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartHandler.scala @@ -35,11 +35,11 @@ import fr.acinq.eclair.payment.Bolt11Invoice.ExtraHop import fr.acinq.eclair.payment.Monitoring.{Metrics, Tags} import fr.acinq.eclair.payment._ import fr.acinq.eclair.router.Router -import fr.acinq.eclair.router.Router.ChannelHop +import fr.acinq.eclair.router.Router.{ChannelHop, ChannelRelayParams} import fr.acinq.eclair.wire.protocol.OfferTypes.{InvoiceRequest, Offer, PaymentInfo} import fr.acinq.eclair.wire.protocol.PaymentOnion.FinalPayload import fr.acinq.eclair.wire.protocol._ -import fr.acinq.eclair.{CltvExpiryDelta, FeatureSupport, Features, InvoiceFeature, Logs, MilliSatoshi, MilliSatoshiLong, NodeParams, TimestampMilli, randomBytes32, randomKey} +import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, FeatureSupport, Features, InvoiceFeature, Logs, MilliSatoshi, MilliSatoshiLong, NodeParams, ShortChannelId, TimestampMilli, randomBytes32, randomKey} import scodec.bits.{ByteVector, HexStringSyntax} import scala.concurrent.duration.DurationInt @@ -251,22 +251,42 @@ object MultiPartHandler { paymentPreimage_opt: Option[ByteVector32] = None, paymentType: String = PaymentType.Standard) extends ReceivePayment + /** + * A dummy blinded hop that will be added at the end of a blinded route. + * The fees and expiry delta should match those of real channels, otherwise it will be obvious that dummy hops are used. + */ + case class DummyBlindedHop(feeBase: MilliSatoshi, feeProportionalMillionths: Long, cltvExpiryDelta: CltvExpiryDelta) + + /** + * A route that will be blinded and included in a Bolt 12 invoice. + * + * @param nodes a valid route ending at our nodeId. + * @param maxFinalExpiryDelta maximum expiry delta that senders can use: the route expiry will be computed based on this value. + * @param dummyHops (optional) dummy hops to add to the blinded route. + */ + case class ReceivingRoute(nodes: Seq[PublicKey], maxFinalExpiryDelta: CltvExpiryDelta, dummyHops: Seq[DummyBlindedHop] = Nil) + /** * Use this message to create a Bolt 12 invoice to receive a payment for a given offer. * * @param nodeKey the key that will be used to sign the invoice, which may be different from our public nodeId. * @param offer the offer this invoice corresponds to. * @param invoiceRequest the request this invoice responds to. + * @param routes routes that must be blinded and provided in the invoice. + * @param router router actor. * @param paymentPreimage_opt payment preimage. */ case class ReceiveOfferPayment(nodeKey: PrivateKey, offer: Offer, invoiceRequest: InvoiceRequest, - routes: Seq[Seq[PublicKey]], + routes: Seq[ReceivingRoute], router: ActorRef, paymentPreimage_opt: Option[ByteVector32] = None, paymentType: String = PaymentType.Blinded) extends ReceivePayment { - require(routes.forall(_.nonEmpty), "each route must have at least one node") + require(routes.forall(_.nodes.nonEmpty), "each route must have at least one node") + require(offer.amount.nonEmpty || invoiceRequest.amount.nonEmpty, "an amount must be specified in the offer or in the invoice request") + + val amount = invoiceRequest.amount.orElse(offer.amount.map(_ * invoiceRequest.quantity)).get } object CreateInvoiceActor { @@ -276,35 +296,45 @@ object MultiPartHandler { case class CreateInvoice(replyTo: ActorRef, receivePayment: ReceivePayment) extends Command // @formatter:on - def blindedRouteFromHops(nodeParams: NodeParams, hops: Seq[ChannelHop], nodeIds: Seq[PublicKey], pathId: ByteVector): BlindedRouteDetails = { - val finalExpiryDelta = nodeParams.channelConf.minFinalExpiryDelta + 500 // We let the sender add up to 500 blocks to the CLTV. - val finalConstraints = RouteBlindingEncryptedDataTlv.PaymentConstraints(finalExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight), nodeParams.channelConf.htlcMinimum) + def blindedRouteFromHops(hops: Seq[ChannelHop], pathId: ByteVector, minAmount: MilliSatoshi, routeFinalExpiry: CltvExpiry): BlindedRouteDetails = { + require(hops.nonEmpty, "route must contain at least one hop") + // We use the same constraints for all nodes so they can't use it to guess their position. + val routeExpiry = hops.foldLeft(routeFinalExpiry) { case (expiry, hop) => expiry + hop.cltvExpiryDelta } + val routeMinAmount = hops.foldLeft(minAmount) { case (amount, hop) => amount.max(hop.params.htlcMinimum) } val finalPayload = RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec.encode(TlvStream( - finalConstraints, + RouteBlindingEncryptedDataTlv.PaymentConstraints(routeExpiry, routeMinAmount), RouteBlindingEncryptedDataTlv.PathId(pathId) )).require.bytes - val maxCltvExpiry = hops.map(_.cltvExpiryDelta).fold(finalExpiryDelta)(_ + _).toCltvExpiry(nodeParams.currentBlockHeight) // Same CLTV for all nodes so they can't use it to guess their position. val payloads = hops.foldRight(Seq(finalPayload)) { - case (channel: ChannelHop, nextPayloads) => - // Because eclair (and others) lies about max HTLC, we remove 10% as a safety margin. + case (channel, payloads) => val payload = RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec.encode(TlvStream( RouteBlindingEncryptedDataTlv.OutgoingChannelId(channel.shortChannelId), RouteBlindingEncryptedDataTlv.PaymentRelay(channel.cltvExpiryDelta, channel.params.relayFees.feeProportionalMillionths, channel.params.relayFees.feeBase), - RouteBlindingEncryptedDataTlv.PaymentConstraints(maxCltvExpiry, channel.params.htlcMinimum) + RouteBlindingEncryptedDataTlv.PaymentConstraints(routeExpiry, routeMinAmount), )).require.bytes - payload +: nextPayloads + payload +: payloads } + val nodeIds = hops.map(_.nodeId) :+ hops.last.nextNodeId Sphinx.RouteBlinding.create(randomKey(), nodeIds, payloads) } - def aggregatePayInfo(route: Router.Route): PaymentInfo = { - val zeroPaymentInfo = PaymentInfo(0 msat, 0, CltvExpiryDelta(0), 0 msat, route.amount, Features.empty) - route.hops.foldRight(zeroPaymentInfo) { - case (channel: ChannelHop, payInfo) => + def blindedRouteWithoutHops(nodeId: PublicKey, pathId: ByteVector, minAmount: MilliSatoshi, routeExpiry: CltvExpiry): BlindedRouteDetails = { + val finalPayload = RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec.encode(TlvStream( + RouteBlindingEncryptedDataTlv.PaymentConstraints(routeExpiry, minAmount), + RouteBlindingEncryptedDataTlv.PathId(pathId) + )).require.bytes + Sphinx.RouteBlinding.create(randomKey(), Seq(nodeId), Seq(finalPayload)) + } + + def aggregatePayInfo(amount: MilliSatoshi, hops: Seq[ChannelHop]): PaymentInfo = { + val zeroPaymentInfo = PaymentInfo(0 msat, 0, CltvExpiryDelta(0), 0 msat, amount, Features.empty) + hops.foldRight(zeroPaymentInfo) { + case (channel, payInfo) => val newFeeBase = MilliSatoshi((channel.params.relayFees.feeBase.toLong * 1_000_000 + payInfo.feeBase.toLong * (1_000_000 + channel.params.relayFees.feeProportionalMillionths) + 1_000_000 - 1) / 1_000_000) val newFeeProp = ((payInfo.feeProportionalMillionths + channel.params.relayFees.feeProportionalMillionths) * 1_000_000 + payInfo.feeProportionalMillionths * channel.params.relayFees.feeProportionalMillionths + 1_000_000 - 1) / 1_000_000 - // Because eclair (and others) lies about max HTLC, we remove 10% as a safety margin. - val channelMaxHtlc = channel.params.htlcMaximum_opt.map(_ * 0.9).getOrElse(route.amount) + // Most nodes on the network set `htlc_maximum_msat` to the channel capacity. We cannot expect the route to be + // able to relay that amount, so we remove 10% as a safety margin. + val channelMaxHtlc = channel.params.htlcMaximum_opt.map(_ * 0.9).getOrElse(amount) PaymentInfo(newFeeBase, newFeeProp, payInfo.cltvExpiryDelta + channel.cltvExpiryDelta, payInfo.minHtlc.max(channel.params.htlcMinimum), payInfo.maxHtlc.min(channelMaxHtlc), payInfo.allowedFeatures) } } @@ -318,59 +348,69 @@ object MultiPartHandler { val featuresTrampolineOpt = if (nodeParams.enableTrampolinePayment) { nodeParams.features.invoiceFeatures().add(Features.TrampolinePaymentPrototype, FeatureSupport.Optional) } else { - nodeParams.features.invoiceFeatures() - } - receivePayment match { - case r: ReceiveStandardPayment => - Try { - val expirySeconds = r.expirySeconds_opt.getOrElse(nodeParams.invoiceExpiry.toSeconds) - val paymentMetadata = hex"2a" - val invoice = Bolt11Invoice( - nodeParams.chainHash, - r.amount_opt, - paymentHash, - nodeParams.privateKey, - r.description, - nodeParams.channelConf.minFinalExpiryDelta, - r.fallbackAddress_opt, - expirySeconds = Some(expirySeconds), - extraHops = r.extraHops, - paymentMetadata = Some(paymentMetadata), - features = featuresTrampolineOpt.remove(Features.RouteBlinding) - ) - context.log.debug("generated invoice={} from amount={}", invoice.toString, r.amount_opt) - nodeParams.db.payments.addIncomingPayment(invoice, paymentPreimage, r.paymentType) - invoice - } match { - case Success(invoice) => replyTo ! invoice - case Failure(exception) => replyTo ! Status.Failure(exception) - } - case r: ReceiveOfferPayment => - val amount = r.invoiceRequest.amount.orElse(r.offer.amount.map(_ * r.invoiceRequest.quantity)) - implicit val ec: ExecutionContextExecutor = context.executionContext - val log = context.log - Future.sequence(r.routes.map(nodeIds => { - val pathId = randomBytes32() - if (nodeIds.length == 1) { - Future.successful( - (blindedRouteFromHops(nodeParams, Nil, nodeIds, pathId), - PaymentInfo(0 msat, 0, CltvExpiryDelta(0), 0 msat, amount.get, Features.empty), - pathId)) + nodeParams.features.invoiceFeatures() + } + receivePayment match { + case r: ReceiveStandardPayment => + Try { + val expirySeconds = r.expirySeconds_opt.getOrElse(nodeParams.invoiceExpiry.toSeconds) + val paymentMetadata = hex"2a" + val invoice = Bolt11Invoice( + nodeParams.chainHash, + r.amount_opt, + paymentHash, + nodeParams.privateKey, + r.description, + nodeParams.channelConf.minFinalExpiryDelta, + r.fallbackAddress_opt, + expirySeconds = Some(expirySeconds), + extraHops = r.extraHops, + paymentMetadata = Some(paymentMetadata), + features = featuresTrampolineOpt.remove(Features.RouteBlinding) + ) + context.log.debug("generated invoice={} from amount={}", invoice.toString, r.amount_opt) + nodeParams.db.payments.addIncomingPayment(invoice, paymentPreimage, r.paymentType) + invoice + } match { + case Success(invoice) => replyTo ! invoice + case Failure(exception) => replyTo ! Status.Failure(exception) + } + case r: ReceiveOfferPayment if r.routes.exists(!_.nodes.lastOption.contains(nodeParams.nodeId)) => + replyTo ! Status.Failure(new IllegalArgumentException("receiving routes must end at our node")) + case r: ReceiveOfferPayment => + implicit val ec: ExecutionContextExecutor = context.executionContext + val log = context.log + Future.sequence(r.routes.map(route => { + val pathId = randomBytes32() + val dummyHops = route.dummyHops.map(h => { + val edge = Invoice.BasicEdge(nodeParams.nodeId, nodeParams.nodeId, ShortChannelId.toSelf, h.feeBase, h.feeProportionalMillionths, h.cltvExpiryDelta) + ChannelHop(edge.shortChannelId, edge.sourceNodeId, edge.targetNodeId, ChannelRelayParams.FromHint(edge)) + }) + if (route.nodes.length == 1) { + val blindedRoute = if (dummyHops.isEmpty) { + blindedRouteWithoutHops(route.nodes.last, pathId, nodeParams.channelConf.htlcMinimum, route.maxFinalExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight)) } else { - implicit val timeout: Timeout = 10.seconds - r.router.ask(Router.FinalizeRoute(amount.get, Router.PredefinedNodeRoute(nodeIds))).mapTo[Router.RouteResponse].map(routeResponse => { - val route = routeResponse.routes.head - (blindedRouteFromHops(nodeParams, route.hops, nodeIds, pathId), aggregatePayInfo(route), pathId) - }) + blindedRouteFromHops(dummyHops, pathId, nodeParams.channelConf.htlcMinimum, route.maxFinalExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight)) } - })).map(paths => { - val invoiceFeatures = featuresTrampolineOpt.remove(Features.RouteBlinding).add(Features.RouteBlinding, FeatureSupport.Mandatory) - val invoice = Bolt12Invoice(r.offer, r.invoiceRequest, paymentPreimage, r.nodeKey, nodeParams.channelConf.minFinalExpiryDelta, invoiceFeatures, paths.map { case (blindedRoute, paymentInfo, _) => PaymentBlindedRoute(blindedRoute.route, paymentInfo) }) - log.debug("generated invoice={} for offerId={}", invoice.toString, r.offer.offerId) - nodeParams.db.payments.addIncomingBlindedPayment(invoice, paymentPreimage, paths.map { case (blindedRoute, _, pathId) => (blindedRoute.lastBlinding -> pathId.bytes) }.toMap, r.paymentType) - invoice - }).recover(exception => Status.Failure(exception)).pipeTo(replyTo) - } + val paymentInfo = aggregatePayInfo(r.amount, dummyHops) + Future.successful((blindedRoute, paymentInfo, pathId)) + } else { + implicit val timeout: Timeout = 10.seconds + r.router.ask(Router.FinalizeRoute(r.amount, Router.PredefinedNodeRoute(route.nodes))).mapTo[Router.RouteResponse].map(routeResponse => { + val clearRoute = routeResponse.routes.head + val blindedRoute = blindedRouteFromHops(clearRoute.hops ++ dummyHops, pathId, nodeParams.channelConf.htlcMinimum, route.maxFinalExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight)) + val paymentInfo = aggregatePayInfo(r.amount, clearRoute.hops ++ dummyHops) + (blindedRoute, paymentInfo, pathId) + }) + } + })).map(paths => { + val invoiceFeatures = featuresTrampolineOpt.remove(Features.RouteBlinding).add(Features.RouteBlinding, FeatureSupport.Mandatory) + val invoice = Bolt12Invoice(r.offer, r.invoiceRequest, paymentPreimage, r.nodeKey, nodeParams.channelConf.minFinalExpiryDelta, invoiceFeatures, paths.map { case (blindedRoute, paymentInfo, _) => PaymentBlindedRoute(blindedRoute.route, paymentInfo) }) + log.debug("generated invoice={} for offerId={}", invoice.toString, r.offer.offerId) + nodeParams.db.payments.addIncomingBlindedPayment(invoice, paymentPreimage, paths.map { case (blindedRoute, _, pathId) => blindedRoute.lastBlinding -> pathId.bytes }.toMap, r.paymentType) + invoice + }).recover(exception => Status.Failure(exception)).pipeTo(replyTo) + } Behaviors.stopped } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/RouteCalculation.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/RouteCalculation.scala index c592af06ac..9231eba761 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/router/RouteCalculation.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/RouteCalculation.scala @@ -49,17 +49,11 @@ object RouteCalculation { fr.route match { case PredefinedNodeRoute(hops) => // split into sublists [(a,b),(b,c), ...] then get the edges between each of those pairs - hops.sliding(2).map { - case List(v1, v2) => if (v1 == localNodeId && v2 == localNodeId) { - Seq(GraphEdge(BasicEdge(localNodeId, localNodeId, ShortChannelId.toSelf, 0 msat, 0, CltvExpiryDelta(0)))) - } else { - g.getEdgesBetween(v1, v2) - } - }.toList match { + hops.sliding(2).map { case List(v1, v2) => g.getEdgesBetween(v1, v2) }.toList match { case edges if edges.nonEmpty && edges.forall(_.nonEmpty) => // select the largest edge (using balance when available, otherwise capacity). val selectedEdges = edges.map(es => es.maxBy(e => e.balance_opt.getOrElse(e.capacity.toMilliSatoshi))) - val hops = selectedEdges.map(d => ChannelHop(d.desc.shortChannelId, d.desc.a, d.desc.b, d.params)) + val hops = selectedEdges.map(e => ChannelHop(e.desc.shortChannelId, e.desc.a, e.desc.b, e.params)) ctx.sender() ! RouteResponse(Route(fr.amount, hops) :: Nil) case _ => // some nodes in the supplied route aren't connected in our graph diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartHandlerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartHandlerSpec.scala index d7e3847efe..208a52d04d 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartHandlerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartHandlerSpec.scala @@ -40,12 +40,13 @@ import fr.acinq.eclair.wire.protocol.PaymentOnion.FinalPayload import fr.acinq.eclair.wire.protocol.RouteBlindingEncryptedDataTlv.{OutgoingChannelId, PathId, PaymentConstraints, PaymentRelay} import fr.acinq.eclair.wire.protocol._ import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Feature, Features, MilliSatoshi, MilliSatoshiLong, NodeParams, ShortChannelId, TestConstants, TestKitBaseClass, TimestampMilli, TimestampMilliLong, randomBytes32, randomKey} -import org.scalatest.Outcome import org.scalatest.funsuite.FixtureAnyFunSuiteLike +import org.scalatest.{Outcome, Tag} import scodec.bits.{ByteVector, HexStringSyntax} import scala.collection.immutable.Queue import scala.concurrent.duration._ +import scala.util.Random /** * Created by PM on 24/03/2017. @@ -82,6 +83,8 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike lazy val handlerWithMpp = TestActorRef[PaymentHandler](PaymentHandler.props(nodeParams.copy(features = featuresWithMpp), register.ref)) lazy val handlerWithKeySend = TestActorRef[PaymentHandler](PaymentHandler.props(nodeParams.copy(features = featuresWithKeySend), register.ref)) lazy val handlerWithRouteBlinding = TestActorRef[PaymentHandler](PaymentHandler.props(nodeParams.copy(features = featuresWithRouteBlinding), register.ref)) + + def createEmptyReceivingRoute(): Seq[ReceivingRoute] = Seq(ReceivingRoute(Seq(nodeParams.nodeId), CltvExpiryDelta(144))) } override def withFixture(test: OneArgTest): Outcome = { @@ -165,9 +168,8 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike val offer = Offer(Some(amountMsat), "a blinded coffee please", privKey.publicKey, Features.empty, Block.RegtestGenesisBlock.hash) val invoiceReq = InvoiceRequest(offer, amountMsat, 1, featuresWithRouteBlinding.invoiceFeatures(), randomKey(), Block.RegtestGenesisBlock.hash) val router = TestProbe() - val nodeId = randomKey().publicKey - sender.send(handlerWithRouteBlinding, ReceiveOfferPayment(privKey, offer, invoiceReq, Seq(Seq(nodeId)), router.ref)) - router.expectNoMessage() + sender.send(handlerWithRouteBlinding, ReceiveOfferPayment(privKey, offer, invoiceReq, createEmptyReceivingRoute(), router.ref)) + router.expectNoMessage(50 millis) val invoice = sender.expectMsgType[Bolt12Invoice] assert(invoice.features.hasFeature(RouteBlinding, Some(Mandatory))) val pendingPayment = nodeParams.db.payments.getIncomingPayment(invoice.paymentHash).get.asInstanceOf[IncomingBlindedPayment] @@ -275,47 +277,44 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike } } - test("Aggregate route fees") { _ => - val rand = new scala.util.Random + test("Aggregate route fees", Tag("fuzzy")) { _ => + val rand = new Random() + val nodeId = randomKey().publicKey for (_ <- 0 to 100) { val routeLength = rand.nextInt(10) + 1 - val hops = - for (_ <- 1 to routeLength; - scid = ShortChannelId.generateLocalAlias(); - nid = randomKey().publicKey; - params = Router.ChannelRelayParams.FromHint(BasicEdge(nid, nid, scid, MilliSatoshi(rand.nextLong(10_000)), rand.nextInt(5000), CltvExpiryDelta(0)))) - yield Router.ChannelHop(scid, nid, nid, params) - val route = Router.Route(0 msat, hops) - val aggregate = CreateInvoiceActor.aggregatePayInfo(route) + val hops = (1 to routeLength).map(i => { + val scid = ShortChannelId(i) + val feeBase = rand.nextInt(10_000).msat + val feeProp = rand.nextInt(5000) + val cltvExpiryDelta = CltvExpiryDelta(rand.nextInt(500)) + val params = Router.ChannelRelayParams.FromHint(BasicEdge(nodeId, nodeId, scid, feeBase, feeProp, cltvExpiryDelta)) + Router.ChannelHop(scid, nodeId, nodeId, params) + }) for (_ <- 0 to 100) { - val amount = MilliSatoshi(rand.nextLong(10_000_000_000L)) - val fee1 = aggregate.fee(amount) - val fee2 = route.copy(amount = amount).fee(true) - // The aggregated fees are always enough - assert(fee1 >= fee2, s"amount=$amount, route=${route.hops.map(_.params.relayFees)}, aggregate=$aggregate") - // and we don't overpay too much. - assert(fee1 - fee2 < 1000.msat.max(amount * 1e-5), s"amount=$amount, route=${route.hops.map(_.params.relayFees)}, aggregate=$aggregate") + val amount = rand.nextLong(10_000_000_000L).msat + val payInfo = CreateInvoiceActor.aggregatePayInfo(amount, hops) + // We verify that the aggregated fee slightly exceeds the actual fee (because of proportional fees rounding). + val aggregatedFee = payInfo.fee(amount) + val actualFee = Router.Route(amount, hops).fee(includeLocalChannelCost = true) + assert(aggregatedFee >= actualFee, s"amount=$amount, hops=${hops.map(_.params.relayFees)}, aggregatedFee=$aggregatedFee, actualFee=$actualFee") + assert(aggregatedFee - actualFee < 1000.msat.max(amount * 1e-5), s"amount=$amount, hops=${hops.map(_.params.relayFees)}, aggregatedFee=$aggregatedFee, actualFee=$actualFee") } } } - test("Generate blinded route from zero hop"){f => - import f._ - + test("Generate blinded route from zero hop") { () => val a = randomKey() val pathId = randomBytes32() - val route = CreateInvoiceActor.blindedRouteFromHops(nodeParams, Nil, Seq(a.publicKey), pathId) + val route = CreateInvoiceActor.blindedRouteWithoutHops(a.publicKey, pathId, 1 msat, CltvExpiry(500)) assert(route.route.introductionNodeId == a.publicKey) assert(route.route.encryptedPayloads.length == 1) assert(route.route.blindingKey == route.lastBlinding) - val Right(decoded) = RouteBlindingEncryptedDataCodecs.decode(a, route.route.blindingKey, route.route.encryptedPayloads.head) + val Right(decoded) = RouteBlindingEncryptedDataCodecs.decode(a, route.route.blindingKey, route.route.encryptedPayloads.head) assert(BlindedRouteData.validPaymentRecipientData(decoded.tlvs).isRight) assert(decoded.tlvs.get[PathId].get.data == pathId.bytes) } - test("Generate blinded route from hops"){f => - import f._ - + test("Generate blinded route from hops") { () => val (a, b, c) = (randomKey(), randomKey(), randomKey()) val pathId = randomBytes32() val (channelId1, channelId2) = (ShortChannelId.generateLocalAlias(), ShortChannelId.generateLocalAlias()) @@ -323,22 +322,22 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike Router.ChannelHop(channelId1, a.publicKey, b.publicKey, ChannelRelayParams.FromHint(Invoice.BasicEdge(a.publicKey, b.publicKey, channelId1, 10 msat, 300, CltvExpiryDelta(200)))), Router.ChannelHop(channelId2, b.publicKey, c.publicKey, ChannelRelayParams.FromHint(Invoice.BasicEdge(b.publicKey, c.publicKey, channelId2, 20 msat, 150, CltvExpiryDelta(600)))), ) - val route = CreateInvoiceActor.blindedRouteFromHops(nodeParams, hops, Seq(a.publicKey, b.publicKey, c.publicKey), pathId) + val route = CreateInvoiceActor.blindedRouteFromHops(hops, pathId, 1 msat, CltvExpiry(500)) assert(route.route.introductionNodeId == a.publicKey) assert(route.route.encryptedPayloads.length == 3) - val Right(decoded1) = RouteBlindingEncryptedDataCodecs.decode(a, route.route.blindingKey, route.route.encryptedPayloads(0)) + val Right(decoded1) = RouteBlindingEncryptedDataCodecs.decode(a, route.route.blindingKey, route.route.encryptedPayloads(0)) assert(BlindedRouteData.validatePaymentRelayData(decoded1.tlvs).isRight) assert(decoded1.tlvs.get[OutgoingChannelId].get.shortChannelId == channelId1) assert(decoded1.tlvs.get[PaymentRelay].get.feeBase == 10.msat) assert(decoded1.tlvs.get[PaymentRelay].get.feeProportionalMillionths == 300) assert(decoded1.tlvs.get[PaymentRelay].get.cltvExpiryDelta == CltvExpiryDelta(200)) - val Right(decoded2) = RouteBlindingEncryptedDataCodecs.decode(b, decoded1.nextBlinding, route.route.encryptedPayloads(1)) + val Right(decoded2) = RouteBlindingEncryptedDataCodecs.decode(b, decoded1.nextBlinding, route.route.encryptedPayloads(1)) assert(BlindedRouteData.validatePaymentRelayData(decoded2.tlvs).isRight) assert(decoded2.tlvs.get[OutgoingChannelId].get.shortChannelId == channelId2) assert(decoded2.tlvs.get[PaymentRelay].get.feeBase == 20.msat) assert(decoded2.tlvs.get[PaymentRelay].get.feeProportionalMillionths == 150) assert(decoded2.tlvs.get[PaymentRelay].get.cltvExpiryDelta == CltvExpiryDelta(600)) - val Right(decoded3) = RouteBlindingEncryptedDataCodecs.decode(c, decoded2.nextBlinding, route.route.encryptedPayloads(2)) + val Right(decoded3) = RouteBlindingEncryptedDataCodecs.decode(c, decoded2.nextBlinding, route.route.encryptedPayloads(2)) assert(BlindedRouteData.validPaymentRecipientData(decoded3.tlvs).isRight) assert(decoded3.tlvs.get[PathId].get.data == pathId.bytes) } @@ -350,18 +349,22 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike val offer = Offer(Some(25_000 msat), "a blinded coffee please", privKey.publicKey, Features.empty, Block.RegtestGenesisBlock.hash) val invoiceReq = InvoiceRequest(offer, 25_000 msat, 1, featuresWithRouteBlinding.invoiceFeatures(), randomKey(), Block.RegtestGenesisBlock.hash) val router = TestProbe() - val (a, b, c, d) = (randomKey().publicKey, randomKey().publicKey, randomKey().publicKey, randomKey().publicKey) - val hop_ab = Router.ChannelHop(ShortChannelId(1), a, b, Router.ChannelRelayParams.FromHint(Invoice.BasicEdge(a, b, ShortChannelId(1), 1000 msat, 699, CltvExpiryDelta(123)))) - val hop_bc = Router.ChannelHop(ShortChannelId(2), b, c, Router.ChannelRelayParams.FromHint(Invoice.BasicEdge(b, c, ShortChannelId(2), 800 msat, 455, CltvExpiryDelta(78)))) - val hop_dc = Router.ChannelHop(ShortChannelId(3), d, c, Router.ChannelRelayParams.FromHint(Invoice.BasicEdge(c, d, ShortChannelId(3), 0 msat, 1700, CltvExpiryDelta(89)))) - val hop_cc = Router.ChannelHop(ShortChannelId.toSelf, c, c, Router.ChannelRelayParams.FromHint(Invoice.BasicEdge(c, c, ShortChannelId.toSelf, 0 msat, 0, CltvExpiryDelta(0)))) - sender.send(handlerWithRouteBlinding, ReceiveOfferPayment(privKey, offer, invoiceReq, Seq(Seq(a, b, c, c), Seq(d, c, c, c), Seq(c)), router.ref)) + val (a, b, c, d) = (randomKey().publicKey, randomKey().publicKey, randomKey().publicKey, nodeParams.nodeId) + val hop_ab = Router.ChannelHop(ShortChannelId(1), a, b, Router.ChannelRelayParams.FromHint(Invoice.BasicEdge(a, b, ShortChannelId(1), 1000 msat, 0, CltvExpiryDelta(100)))) + val hop_bd = Router.ChannelHop(ShortChannelId(2), b, d, Router.ChannelRelayParams.FromHint(Invoice.BasicEdge(b, d, ShortChannelId(2), 800 msat, 0, CltvExpiryDelta(50)))) + val hop_cd = Router.ChannelHop(ShortChannelId(3), c, d, Router.ChannelRelayParams.FromHint(Invoice.BasicEdge(c, d, ShortChannelId(3), 0 msat, 0, CltvExpiryDelta(75)))) + val receivingRoutes = Seq( + ReceivingRoute(Seq(a, b, d), CltvExpiryDelta(100), Seq(DummyBlindedHop(150 msat, 0, CltvExpiryDelta(25)))), + ReceivingRoute(Seq(c, d), CltvExpiryDelta(50), Seq(DummyBlindedHop(250 msat, 0, CltvExpiryDelta(10)), DummyBlindedHop(150 msat, 0, CltvExpiryDelta(80)))), + ReceivingRoute(Seq(d), CltvExpiryDelta(250)), + ) + sender.send(handlerWithRouteBlinding, ReceiveOfferPayment(privKey, offer, invoiceReq, receivingRoutes, router.ref)) val finalizeRoute1 = router.expectMsgType[Router.FinalizeRoute] - assert(finalizeRoute1.route == Router.PredefinedNodeRoute(Seq(a, b, c, c))) - router.send(router.lastSender, RouteResponse(Seq(Router.Route(finalizeRoute1.amount, Seq(hop_ab, hop_bc, hop_cc))))) + assert(finalizeRoute1.route == Router.PredefinedNodeRoute(Seq(a, b, d))) + router.send(router.lastSender, RouteResponse(Seq(Router.Route(finalizeRoute1.amount, Seq(hop_ab, hop_bd))))) val finalizeRoute2 = router.expectMsgType[Router.FinalizeRoute] - assert(finalizeRoute2.route == Router.PredefinedNodeRoute(Seq(d, c, c, c))) - router.send(router.lastSender, RouteResponse(Seq(Router.Route(finalizeRoute2.amount, Seq(hop_dc, hop_cc, hop_cc))))) + assert(finalizeRoute2.route == Router.PredefinedNodeRoute(Seq(c, d))) + router.send(router.lastSender, RouteResponse(Seq(Router.Route(finalizeRoute2.amount, Seq(hop_cd))))) val invoice = sender.expectMsgType[Bolt12Invoice] assert(invoice.amount == 25_000.msat) assert(invoice.nodeId == privKey.publicKey) @@ -372,12 +375,12 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike assert(invoice.blindedPaths.length == 3) assert(invoice.blindedPaths(0).blindedNodeIds.length == 4) assert(invoice.blindedPaths(0).introductionNodeId == a) - assert(invoice.blindedPathsInfo(0) == PaymentInfo(1801 msat, 1155, CltvExpiryDelta(201), 0 msat, 25_000 msat, Features.empty)) + assert(invoice.blindedPathsInfo(0) == PaymentInfo(1950 msat, 0, CltvExpiryDelta(175), 0 msat, 25_000 msat, Features.empty)) assert(invoice.blindedPaths(1).blindedNodeIds.length == 4) - assert(invoice.blindedPaths(1).introductionNodeId == d) - assert(invoice.blindedPathsInfo(1) == PaymentInfo(0 msat, 1700, CltvExpiryDelta(89), 0 msat, 25_000 msat, Features.empty)) + assert(invoice.blindedPaths(1).introductionNodeId == c) + assert(invoice.blindedPathsInfo(1) == PaymentInfo(400 msat, 0, CltvExpiryDelta(165), 0 msat, 25_000 msat, Features.empty)) assert(invoice.blindedPaths(2).blindedNodeIds.length == 1) - assert(invoice.blindedPaths(2).introductionNodeId == c) + assert(invoice.blindedPaths(2).introductionNodeId == d) assert(invoice.blindedPathsInfo(2) == PaymentInfo(0 msat, 0, CltvExpiryDelta(0), 0 msat, 25_000 msat, Features.empty)) val pendingPayment = nodeParams.db.payments.getIncomingPayment(invoice.paymentHash).get.asInstanceOf[IncomingBlindedPayment] @@ -387,18 +390,26 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike pendingPayment.pathIds.values.foreach(pathId => assert(pathId.length == 32)) } - test("Invoice generation with route blinding - incorrect route") { f => + test("Invoice generation with route blinding should fail when router returns an error") { f => import f._ val privKey = randomKey() val offer = Offer(Some(25_000 msat), "a blinded coffee please", privKey.publicKey, Features.empty, Block.RegtestGenesisBlock.hash) val invoiceReq = InvoiceRequest(offer, 25_000 msat, 1, featuresWithRouteBlinding.invoiceFeatures(), randomKey(), Block.RegtestGenesisBlock.hash) val router = TestProbe() - val (a, b, c) = (randomKey().publicKey, randomKey().publicKey, randomKey().publicKey) - sender.send(handlerWithRouteBlinding, ReceiveOfferPayment(privKey, offer, invoiceReq, Seq(Seq(a, b, c)), router.ref)) + val (a, b, c) = (randomKey().publicKey, randomKey().publicKey, nodeParams.nodeId) + val hop_ac = Router.ChannelHop(ShortChannelId(1), a, c, Router.ChannelRelayParams.FromHint(Invoice.BasicEdge(a, c, ShortChannelId(1), 100 msat, 0, CltvExpiryDelta(50)))) + val receivingRoutes = Seq( + ReceivingRoute(Seq(a, c), CltvExpiryDelta(100)), + ReceivingRoute(Seq(b, c), CltvExpiryDelta(100)), + ) + sender.send(handlerWithRouteBlinding, ReceiveOfferPayment(privKey, offer, invoiceReq, receivingRoutes, router.ref)) val finalizeRoute1 = router.expectMsgType[Router.FinalizeRoute] - assert(finalizeRoute1.route == Router.PredefinedNodeRoute(Seq(a, b, c))) - router.send(router.lastSender, Status.Failure(new IllegalArgumentException("Not all the nodes in the supplied route are connected with public channels"))) + assert(finalizeRoute1.route == Router.PredefinedNodeRoute(Seq(a, c))) + router.send(router.lastSender, RouteResponse(Seq(Router.Route(finalizeRoute1.amount, Seq(hop_ac))))) + val finalizeRoute2 = router.expectMsgType[Router.FinalizeRoute] + assert(finalizeRoute2.route == Router.PredefinedNodeRoute(Seq(b, c))) + router.send(router.lastSender, Status.Failure(new IllegalArgumentException("invalid route"))) sender.expectMsgType[Status.Failure] val pendingPayments = nodeParams.db.payments.listIncomingPayments(TimestampMilli.min, TimestampMilli.max, None) @@ -559,10 +570,7 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike val offer = Offer(None, "a blinded coffee please", randomKey().publicKey, Features.empty, Block.RegtestGenesisBlock.hash) val invoiceReq = InvoiceRequest(offer, 5000 msat, 1, featuresWithRouteBlinding.invoiceFeatures(), randomKey(), Block.RegtestGenesisBlock.hash) - val router = TestProbe() - val nodeId = randomKey().publicKey - sender.send(handlerWithRouteBlinding, ReceiveOfferPayment(randomKey(), offer, invoiceReq, Seq(Seq(nodeId)), router.ref)) - router.expectNoMessage() + sender.send(handlerWithRouteBlinding, ReceiveOfferPayment(randomKey(), offer, invoiceReq, createEmptyReceivingRoute(), TestProbe().ref)) val invoice = sender.expectMsgType[Bolt12Invoice] assert(invoice.features.hasFeature(RouteBlinding, Some(Mandatory))) @@ -578,9 +586,7 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike val offer = Offer(None, "a blinded coffee please", randomKey().publicKey, Features.empty, Block.RegtestGenesisBlock.hash) val invoiceReq = InvoiceRequest(offer, 5000 msat, 1, featuresWithRouteBlinding.invoiceFeatures(), randomKey(), Block.RegtestGenesisBlock.hash) - val router = TestProbe() - val nodeId = randomKey().publicKey - sender.send(handlerWithRouteBlinding, ReceiveOfferPayment(randomKey(), offer, invoiceReq, Seq(Seq(nodeId)), router.ref)) + sender.send(handlerWithRouteBlinding, ReceiveOfferPayment(randomKey(), offer, invoiceReq, createEmptyReceivingRoute(), TestProbe().ref)) val invoice = sender.expectMsgType[Bolt12Invoice] assert(invoice.features.hasFeature(RouteBlinding, Some(Mandatory))) val pathIds = nodeParams.db.payments.getIncomingPayment(invoice.paymentHash).get.asInstanceOf[IncomingBlindedPayment].pathIds @@ -597,9 +603,7 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike val offer = Offer(None, "a blinded coffee please", randomKey().publicKey, Features.empty, Block.RegtestGenesisBlock.hash) val invoiceReq = InvoiceRequest(offer, 5000 msat, 1, featuresWithRouteBlinding.invoiceFeatures(), randomKey(), Block.RegtestGenesisBlock.hash) - val router = TestProbe() - val nodeId = randomKey().publicKey - sender.send(handlerWithRouteBlinding, ReceiveOfferPayment(randomKey(), offer, invoiceReq, Seq(Seq(nodeId)), router.ref)) + sender.send(handlerWithRouteBlinding, ReceiveOfferPayment(randomKey(), offer, invoiceReq, createEmptyReceivingRoute(), TestProbe().ref)) val invoice = sender.expectMsgType[Bolt12Invoice] assert(invoice.features.hasFeature(RouteBlinding, Some(Mandatory))) val pathIds = nodeParams.db.payments.getIncomingPayment(invoice.paymentHash).get.asInstanceOf[IncomingBlindedPayment].pathIds @@ -618,10 +622,7 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike val offer = Offer(None, "a blinded coffee please", randomKey().publicKey, Features.empty, Block.RegtestGenesisBlock.hash) val invoiceReq = InvoiceRequest(offer, 5000 msat, 1, featuresWithRouteBlinding.invoiceFeatures(), randomKey(), Block.RegtestGenesisBlock.hash) - val router = TestProbe() - val nodeId = randomKey().publicKey - sender.send(handlerWithRouteBlinding, ReceiveOfferPayment(randomKey(), offer, invoiceReq, Seq(Seq(nodeId)), router.ref)) - router.expectNoMessage() + sender.send(handlerWithRouteBlinding, ReceiveOfferPayment(randomKey(), offer, invoiceReq, createEmptyReceivingRoute(), TestProbe().ref)) val invoice = sender.expectMsgType[Bolt12Invoice] assert(invoice.features.hasFeature(RouteBlinding, Some(Mandatory))) val pathIds = nodeParams.db.payments.getIncomingPayment(invoice.paymentHash).get.asInstanceOf[IncomingBlindedPayment].pathIds @@ -639,10 +640,7 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike val offer = Offer(None, "a blinded coffee please", randomKey().publicKey, Features.empty, Block.RegtestGenesisBlock.hash) val invoiceReq = InvoiceRequest(offer, 5000 msat, 1, featuresWithRouteBlinding.invoiceFeatures(), randomKey(), Block.RegtestGenesisBlock.hash) - val router = TestProbe() - val nodeId = randomKey().publicKey - sender.send(handlerWithRouteBlinding, ReceiveOfferPayment(randomKey(), offer, invoiceReq, Seq(Seq(nodeId)), router.ref)) - router.expectNoMessage() + sender.send(handlerWithRouteBlinding, ReceiveOfferPayment(randomKey(), offer, invoiceReq, createEmptyReceivingRoute(), TestProbe().ref)) val invoice = sender.expectMsgType[Bolt12Invoice] assert(invoice.features.hasFeature(RouteBlinding, Some(Mandatory))) val pathIds = nodeParams.db.payments.getIncomingPayment(invoice.paymentHash).get.asInstanceOf[IncomingBlindedPayment].pathIds @@ -660,10 +658,7 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike val offer = Offer(None, "a blinded coffee please", randomKey().publicKey, Features.empty, Block.RegtestGenesisBlock.hash) val invoiceReq = InvoiceRequest(offer, 5000 msat, 1, featuresWithRouteBlinding.invoiceFeatures(), randomKey(), Block.RegtestGenesisBlock.hash) - val router = TestProbe() - val nodeId = randomKey().publicKey - sender.send(handlerWithRouteBlinding, ReceiveOfferPayment(randomKey(), offer, invoiceReq, Seq(Seq(nodeId)), router.ref)) - router.expectNoMessage() + sender.send(handlerWithRouteBlinding, ReceiveOfferPayment(randomKey(), offer, invoiceReq, createEmptyReceivingRoute(), TestProbe().ref)) val invoice = sender.expectMsgType[Bolt12Invoice] assert(invoice.features.hasFeature(RouteBlinding, Some(Mandatory))) val pathIds = nodeParams.db.payments.getIncomingPayment(invoice.paymentHash).get.asInstanceOf[IncomingBlindedPayment].pathIds From ad8d90c56ebb94e1347c3a9100c5a0b3f1f29c78 Mon Sep 17 00:00:00 2001 From: t-bast Date: Thu, 24 Nov 2022 14:30:49 +0100 Subject: [PATCH 8/9] Move helper functions This commit contains no logical changes, it just moves some code to places where it makes more sense. --- .../payment/receive/MultiPartHandler.scala | 62 +++------------- .../eclair/wire/protocol/OfferTypes.scala | 19 ++++- .../eclair/wire/protocol/RouteBlinding.scala | 29 +++++++- .../eclair/payment/MultiPartHandlerSpec.scala | 73 +------------------ .../wire/protocol/RouteBlindingSpec.scala | 73 ++++++++++++++++++- 5 files changed, 129 insertions(+), 127 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartHandler.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartHandler.scala index 2050fb423d..2d6a7040a2 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartHandler.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartHandler.scala @@ -28,19 +28,18 @@ import fr.acinq.bitcoin.scalacompat.Crypto.{PrivateKey, PublicKey} import fr.acinq.bitcoin.scalacompat.{ByteVector32, Crypto} import fr.acinq.eclair.Logs.LogCategory import fr.acinq.eclair.channel.{CMD_FAIL_HTLC, CMD_FULFILL_HTLC, RES_SUCCESS} -import fr.acinq.eclair.crypto.Sphinx -import fr.acinq.eclair.crypto.Sphinx.RouteBlinding.BlindedRouteDetails import fr.acinq.eclair.db._ import fr.acinq.eclair.payment.Bolt11Invoice.ExtraHop import fr.acinq.eclair.payment.Monitoring.{Metrics, Tags} import fr.acinq.eclair.payment._ import fr.acinq.eclair.router.Router import fr.acinq.eclair.router.Router.{ChannelHop, ChannelRelayParams} -import fr.acinq.eclair.wire.protocol.OfferTypes.{InvoiceRequest, Offer, PaymentInfo} +import fr.acinq.eclair.wire.protocol.OfferTypes.{InvoiceRequest, Offer} import fr.acinq.eclair.wire.protocol.PaymentOnion.FinalPayload +import fr.acinq.eclair.wire.protocol.RouteBlindingEncryptedDataCodecs.{createBlindedRouteFromHops, createBlindedRouteWithoutHops} import fr.acinq.eclair.wire.protocol._ -import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, FeatureSupport, Features, InvoiceFeature, Logs, MilliSatoshi, MilliSatoshiLong, NodeParams, ShortChannelId, TimestampMilli, randomBytes32, randomKey} -import scodec.bits.{ByteVector, HexStringSyntax} +import fr.acinq.eclair.{CltvExpiryDelta, FeatureSupport, Features, InvoiceFeature, Logs, MilliSatoshi, NodeParams, ShortChannelId, TimestampMilli, randomBytes32} +import scodec.bits.HexStringSyntax import scala.concurrent.duration.DurationInt import scala.concurrent.{ExecutionContextExecutor, Future} @@ -296,49 +295,6 @@ object MultiPartHandler { case class CreateInvoice(replyTo: ActorRef, receivePayment: ReceivePayment) extends Command // @formatter:on - def blindedRouteFromHops(hops: Seq[ChannelHop], pathId: ByteVector, minAmount: MilliSatoshi, routeFinalExpiry: CltvExpiry): BlindedRouteDetails = { - require(hops.nonEmpty, "route must contain at least one hop") - // We use the same constraints for all nodes so they can't use it to guess their position. - val routeExpiry = hops.foldLeft(routeFinalExpiry) { case (expiry, hop) => expiry + hop.cltvExpiryDelta } - val routeMinAmount = hops.foldLeft(minAmount) { case (amount, hop) => amount.max(hop.params.htlcMinimum) } - val finalPayload = RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec.encode(TlvStream( - RouteBlindingEncryptedDataTlv.PaymentConstraints(routeExpiry, routeMinAmount), - RouteBlindingEncryptedDataTlv.PathId(pathId) - )).require.bytes - val payloads = hops.foldRight(Seq(finalPayload)) { - case (channel, payloads) => - val payload = RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec.encode(TlvStream( - RouteBlindingEncryptedDataTlv.OutgoingChannelId(channel.shortChannelId), - RouteBlindingEncryptedDataTlv.PaymentRelay(channel.cltvExpiryDelta, channel.params.relayFees.feeProportionalMillionths, channel.params.relayFees.feeBase), - RouteBlindingEncryptedDataTlv.PaymentConstraints(routeExpiry, routeMinAmount), - )).require.bytes - payload +: payloads - } - val nodeIds = hops.map(_.nodeId) :+ hops.last.nextNodeId - Sphinx.RouteBlinding.create(randomKey(), nodeIds, payloads) - } - - def blindedRouteWithoutHops(nodeId: PublicKey, pathId: ByteVector, minAmount: MilliSatoshi, routeExpiry: CltvExpiry): BlindedRouteDetails = { - val finalPayload = RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec.encode(TlvStream( - RouteBlindingEncryptedDataTlv.PaymentConstraints(routeExpiry, minAmount), - RouteBlindingEncryptedDataTlv.PathId(pathId) - )).require.bytes - Sphinx.RouteBlinding.create(randomKey(), Seq(nodeId), Seq(finalPayload)) - } - - def aggregatePayInfo(amount: MilliSatoshi, hops: Seq[ChannelHop]): PaymentInfo = { - val zeroPaymentInfo = PaymentInfo(0 msat, 0, CltvExpiryDelta(0), 0 msat, amount, Features.empty) - hops.foldRight(zeroPaymentInfo) { - case (channel, payInfo) => - val newFeeBase = MilliSatoshi((channel.params.relayFees.feeBase.toLong * 1_000_000 + payInfo.feeBase.toLong * (1_000_000 + channel.params.relayFees.feeProportionalMillionths) + 1_000_000 - 1) / 1_000_000) - val newFeeProp = ((payInfo.feeProportionalMillionths + channel.params.relayFees.feeProportionalMillionths) * 1_000_000 + payInfo.feeProportionalMillionths * channel.params.relayFees.feeProportionalMillionths + 1_000_000 - 1) / 1_000_000 - // Most nodes on the network set `htlc_maximum_msat` to the channel capacity. We cannot expect the route to be - // able to relay that amount, so we remove 10% as a safety margin. - val channelMaxHtlc = channel.params.htlcMaximum_opt.map(_ * 0.9).getOrElse(amount) - PaymentInfo(newFeeBase, newFeeProp, payInfo.cltvExpiryDelta + channel.cltvExpiryDelta, payInfo.minHtlc.max(channel.params.htlcMinimum), payInfo.maxHtlc.min(channelMaxHtlc), payInfo.allowedFeatures) - } - } - def apply(nodeParams: NodeParams): Behavior[Command] = { Behaviors.setup { context => Behaviors.receiveMessage { @@ -388,18 +344,18 @@ object MultiPartHandler { }) if (route.nodes.length == 1) { val blindedRoute = if (dummyHops.isEmpty) { - blindedRouteWithoutHops(route.nodes.last, pathId, nodeParams.channelConf.htlcMinimum, route.maxFinalExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight)) + createBlindedRouteWithoutHops(route.nodes.last, pathId, nodeParams.channelConf.htlcMinimum, route.maxFinalExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight)) } else { - blindedRouteFromHops(dummyHops, pathId, nodeParams.channelConf.htlcMinimum, route.maxFinalExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight)) + createBlindedRouteFromHops(dummyHops, pathId, nodeParams.channelConf.htlcMinimum, route.maxFinalExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight)) } - val paymentInfo = aggregatePayInfo(r.amount, dummyHops) + val paymentInfo = OfferTypes.PaymentInfo(r.amount, dummyHops) Future.successful((blindedRoute, paymentInfo, pathId)) } else { implicit val timeout: Timeout = 10.seconds r.router.ask(Router.FinalizeRoute(r.amount, Router.PredefinedNodeRoute(route.nodes))).mapTo[Router.RouteResponse].map(routeResponse => { val clearRoute = routeResponse.routes.head - val blindedRoute = blindedRouteFromHops(clearRoute.hops ++ dummyHops, pathId, nodeParams.channelConf.htlcMinimum, route.maxFinalExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight)) - val paymentInfo = aggregatePayInfo(r.amount, clearRoute.hops ++ dummyHops) + val blindedRoute = createBlindedRouteFromHops(clearRoute.hops ++ dummyHops, pathId, nodeParams.channelConf.htlcMinimum, route.maxFinalExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight)) + val paymentInfo = OfferTypes.PaymentInfo(r.amount, clearRoute.hops ++ dummyHops) (blindedRoute, paymentInfo, pathId) }) } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/OfferTypes.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/OfferTypes.scala index abb284cb18..1724810769 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/OfferTypes.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/OfferTypes.scala @@ -20,9 +20,10 @@ import fr.acinq.bitcoin.Bech32 import fr.acinq.bitcoin.scalacompat.Crypto.{PrivateKey, PublicKey} import fr.acinq.bitcoin.scalacompat.{Block, ByteVector32, ByteVector64, Crypto, LexicographicalOrdering} import fr.acinq.eclair.crypto.Sphinx.RouteBlinding.BlindedRoute +import fr.acinq.eclair.router.Router import fr.acinq.eclair.wire.protocol.OnionRoutingCodecs.{InvalidTlvPayload, MissingRequiredTlv} import fr.acinq.eclair.wire.protocol.TlvCodecs.genericTlv -import fr.acinq.eclair.{CltvExpiryDelta, Feature, Features, InvoiceFeature, MilliSatoshi, TimestampSecond, UInt64, nodeFee} +import fr.acinq.eclair.{CltvExpiryDelta, Feature, Features, InvoiceFeature, MilliSatoshi, MilliSatoshiLong, TimestampSecond, UInt64, nodeFee} import fr.acinq.secp256k1.Secp256k1JvmKt import scodec.Codec import scodec.bits.ByteVector @@ -69,6 +70,22 @@ object OfferTypes { def fee(amount: MilliSatoshi): MilliSatoshi = nodeFee(feeBase, feeProportionalMillionths, amount) } + object PaymentInfo { + /** Compute aggregated fees and expiry for a blinded route. */ + def apply(amount: MilliSatoshi, hops: Seq[Router.ChannelHop]): PaymentInfo = { + val zeroPaymentInfo = PaymentInfo(0 msat, 0, CltvExpiryDelta(0), 0 msat, amount, Features.empty) + hops.foldRight(zeroPaymentInfo) { + case (channel, payInfo) => + val newFeeBase = MilliSatoshi((channel.params.relayFees.feeBase.toLong * 1_000_000 + payInfo.feeBase.toLong * (1_000_000 + channel.params.relayFees.feeProportionalMillionths) + 1_000_000 - 1) / 1_000_000) + val newFeeProp = ((payInfo.feeProportionalMillionths + channel.params.relayFees.feeProportionalMillionths) * 1_000_000 + payInfo.feeProportionalMillionths * channel.params.relayFees.feeProportionalMillionths + 1_000_000 - 1) / 1_000_000 + // Most nodes on the network set `htlc_maximum_msat` to the channel capacity. We cannot expect the route to be + // able to relay that amount, so we remove 10% as a safety margin. + val channelMaxHtlc = channel.params.htlcMaximum_opt.map(_ * 0.9).getOrElse(amount) + PaymentInfo(newFeeBase, newFeeProp, payInfo.cltvExpiryDelta + channel.cltvExpiryDelta, payInfo.minHtlc.max(channel.params.htlcMinimum), payInfo.maxHtlc.min(channelMaxHtlc), payInfo.allowedFeatures) + } + } + } + case class PaymentPathsInfo(paymentInfo: Seq[PaymentInfo]) extends InvoiceTlv case class PaymentPathsCapacities(capacities: Seq[MilliSatoshi]) extends InvoiceTlv diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/RouteBlinding.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/RouteBlinding.scala index cbe392ef63..10f23eaa91 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/RouteBlinding.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/RouteBlinding.scala @@ -18,10 +18,11 @@ package fr.acinq.eclair.wire.protocol import fr.acinq.bitcoin.scalacompat.Crypto.{PrivateKey, PublicKey} import fr.acinq.eclair.crypto.Sphinx +import fr.acinq.eclair.router.Router import fr.acinq.eclair.wire.protocol.CommonCodecs.{cltvExpiry, cltvExpiryDelta, featuresCodec} import fr.acinq.eclair.wire.protocol.OnionRoutingCodecs.{ForbiddenTlv, InvalidTlvPayload, MissingRequiredTlv} import fr.acinq.eclair.wire.protocol.TlvCodecs.{fixedLengthTlvField, tlvField, tmillisatoshi, tmillisatoshi32} -import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Feature, Features, MilliSatoshi, ShortChannelId, UInt64} +import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Feature, Features, MilliSatoshi, ShortChannelId, UInt64, randomKey} import scodec.bits.ByteVector import scala.util.{Failure, Success} @@ -139,6 +140,32 @@ object RouteBlindingEncryptedDataCodecs { case class CannotDecodeData(message: String) extends InvalidEncryptedData // @formatter:on + /** Create a blinded route from a non-empty list of channel hops. */ + def createBlindedRouteFromHops(hops: Seq[Router.ChannelHop], pathId: ByteVector, minAmount: MilliSatoshi, routeFinalExpiry: CltvExpiry): Sphinx.RouteBlinding.BlindedRouteDetails = { + require(hops.nonEmpty, "route must contain at least one hop") + // We use the same constraints for all nodes so they can't use it to guess their position. + val routeExpiry = hops.foldLeft(routeFinalExpiry) { case (expiry, hop) => expiry + hop.cltvExpiryDelta } + val routeMinAmount = hops.foldLeft(minAmount) { case (amount, hop) => amount.max(hop.params.htlcMinimum) } + val finalPayload = blindedRouteDataCodec.encode(TlvStream(PaymentConstraints(routeExpiry, routeMinAmount), PathId(pathId))).require.bytes + val payloads = hops.foldRight(Seq(finalPayload)) { + case (channel, payloads) => + val payload = blindedRouteDataCodec.encode(TlvStream( + OutgoingChannelId(channel.shortChannelId), + PaymentRelay(channel.cltvExpiryDelta, channel.params.relayFees.feeProportionalMillionths, channel.params.relayFees.feeBase), + PaymentConstraints(routeExpiry, routeMinAmount), + )).require.bytes + payload +: payloads + } + val nodeIds = hops.map(_.nodeId) :+ hops.last.nextNodeId + Sphinx.RouteBlinding.create(randomKey(), nodeIds, payloads) + } + + /** Create a blinded route where the recipient is also the introduction point (which reveals the recipient's identity). */ + def createBlindedRouteWithoutHops(nodeId: PublicKey, pathId: ByteVector, minAmount: MilliSatoshi, routeExpiry: CltvExpiry): Sphinx.RouteBlinding.BlindedRouteDetails = { + val finalPayload = blindedRouteDataCodec.encode(TlvStream(PaymentConstraints(routeExpiry, minAmount), PathId(pathId))).require.bytes + Sphinx.RouteBlinding.create(randomKey(), Seq(nodeId), Seq(finalPayload)) + } + /** * Decrypt and decode the contents of an encrypted_recipient_data TLV field. * diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartHandlerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartHandlerSpec.scala index 208a52d04d..97acf19574 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartHandlerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartHandlerSpec.scala @@ -27,26 +27,24 @@ import fr.acinq.eclair.TestConstants.Alice import fr.acinq.eclair.channel.{CMD_FAIL_HTLC, CMD_FULFILL_HTLC, Register} import fr.acinq.eclair.db.{IncomingBlindedPayment, IncomingPaymentStatus} import fr.acinq.eclair.payment.Bolt11Invoice.ExtraHop -import fr.acinq.eclair.payment.Invoice.BasicEdge import fr.acinq.eclair.payment.PaymentReceived.PartialPayment import fr.acinq.eclair.payment.receive.MultiPartHandler._ import fr.acinq.eclair.payment.receive.MultiPartPaymentFSM.HtlcPart import fr.acinq.eclair.payment.receive.{MultiPartPaymentFSM, PaymentHandler} import fr.acinq.eclair.router.Router -import fr.acinq.eclair.router.Router.{ChannelRelayParams, RouteResponse} +import fr.acinq.eclair.router.Router.RouteResponse import fr.acinq.eclair.wire.protocol.OfferTypes.{InvoiceRequest, Offer, PaymentInfo} import fr.acinq.eclair.wire.protocol.OnionPaymentPayloadTlv.{AmountToForward, BlindingPoint, EncryptedRecipientData, OutgoingCltv} import fr.acinq.eclair.wire.protocol.PaymentOnion.FinalPayload -import fr.acinq.eclair.wire.protocol.RouteBlindingEncryptedDataTlv.{OutgoingChannelId, PathId, PaymentConstraints, PaymentRelay} +import fr.acinq.eclair.wire.protocol.RouteBlindingEncryptedDataTlv.{PathId, PaymentConstraints} import fr.acinq.eclair.wire.protocol._ import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Feature, Features, MilliSatoshi, MilliSatoshiLong, NodeParams, ShortChannelId, TestConstants, TestKitBaseClass, TimestampMilli, TimestampMilliLong, randomBytes32, randomKey} +import org.scalatest.Outcome import org.scalatest.funsuite.FixtureAnyFunSuiteLike -import org.scalatest.{Outcome, Tag} import scodec.bits.{ByteVector, HexStringSyntax} import scala.collection.immutable.Queue import scala.concurrent.duration._ -import scala.util.Random /** * Created by PM on 24/03/2017. @@ -277,71 +275,6 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike } } - test("Aggregate route fees", Tag("fuzzy")) { _ => - val rand = new Random() - val nodeId = randomKey().publicKey - for (_ <- 0 to 100) { - val routeLength = rand.nextInt(10) + 1 - val hops = (1 to routeLength).map(i => { - val scid = ShortChannelId(i) - val feeBase = rand.nextInt(10_000).msat - val feeProp = rand.nextInt(5000) - val cltvExpiryDelta = CltvExpiryDelta(rand.nextInt(500)) - val params = Router.ChannelRelayParams.FromHint(BasicEdge(nodeId, nodeId, scid, feeBase, feeProp, cltvExpiryDelta)) - Router.ChannelHop(scid, nodeId, nodeId, params) - }) - for (_ <- 0 to 100) { - val amount = rand.nextLong(10_000_000_000L).msat - val payInfo = CreateInvoiceActor.aggregatePayInfo(amount, hops) - // We verify that the aggregated fee slightly exceeds the actual fee (because of proportional fees rounding). - val aggregatedFee = payInfo.fee(amount) - val actualFee = Router.Route(amount, hops).fee(includeLocalChannelCost = true) - assert(aggregatedFee >= actualFee, s"amount=$amount, hops=${hops.map(_.params.relayFees)}, aggregatedFee=$aggregatedFee, actualFee=$actualFee") - assert(aggregatedFee - actualFee < 1000.msat.max(amount * 1e-5), s"amount=$amount, hops=${hops.map(_.params.relayFees)}, aggregatedFee=$aggregatedFee, actualFee=$actualFee") - } - } - } - - test("Generate blinded route from zero hop") { () => - val a = randomKey() - val pathId = randomBytes32() - val route = CreateInvoiceActor.blindedRouteWithoutHops(a.publicKey, pathId, 1 msat, CltvExpiry(500)) - assert(route.route.introductionNodeId == a.publicKey) - assert(route.route.encryptedPayloads.length == 1) - assert(route.route.blindingKey == route.lastBlinding) - val Right(decoded) = RouteBlindingEncryptedDataCodecs.decode(a, route.route.blindingKey, route.route.encryptedPayloads.head) - assert(BlindedRouteData.validPaymentRecipientData(decoded.tlvs).isRight) - assert(decoded.tlvs.get[PathId].get.data == pathId.bytes) - } - - test("Generate blinded route from hops") { () => - val (a, b, c) = (randomKey(), randomKey(), randomKey()) - val pathId = randomBytes32() - val (channelId1, channelId2) = (ShortChannelId.generateLocalAlias(), ShortChannelId.generateLocalAlias()) - val hops = Seq( - Router.ChannelHop(channelId1, a.publicKey, b.publicKey, ChannelRelayParams.FromHint(Invoice.BasicEdge(a.publicKey, b.publicKey, channelId1, 10 msat, 300, CltvExpiryDelta(200)))), - Router.ChannelHop(channelId2, b.publicKey, c.publicKey, ChannelRelayParams.FromHint(Invoice.BasicEdge(b.publicKey, c.publicKey, channelId2, 20 msat, 150, CltvExpiryDelta(600)))), - ) - val route = CreateInvoiceActor.blindedRouteFromHops(hops, pathId, 1 msat, CltvExpiry(500)) - assert(route.route.introductionNodeId == a.publicKey) - assert(route.route.encryptedPayloads.length == 3) - val Right(decoded1) = RouteBlindingEncryptedDataCodecs.decode(a, route.route.blindingKey, route.route.encryptedPayloads(0)) - assert(BlindedRouteData.validatePaymentRelayData(decoded1.tlvs).isRight) - assert(decoded1.tlvs.get[OutgoingChannelId].get.shortChannelId == channelId1) - assert(decoded1.tlvs.get[PaymentRelay].get.feeBase == 10.msat) - assert(decoded1.tlvs.get[PaymentRelay].get.feeProportionalMillionths == 300) - assert(decoded1.tlvs.get[PaymentRelay].get.cltvExpiryDelta == CltvExpiryDelta(200)) - val Right(decoded2) = RouteBlindingEncryptedDataCodecs.decode(b, decoded1.nextBlinding, route.route.encryptedPayloads(1)) - assert(BlindedRouteData.validatePaymentRelayData(decoded2.tlvs).isRight) - assert(decoded2.tlvs.get[OutgoingChannelId].get.shortChannelId == channelId2) - assert(decoded2.tlvs.get[PaymentRelay].get.feeBase == 20.msat) - assert(decoded2.tlvs.get[PaymentRelay].get.feeProportionalMillionths == 150) - assert(decoded2.tlvs.get[PaymentRelay].get.cltvExpiryDelta == CltvExpiryDelta(600)) - val Right(decoded3) = RouteBlindingEncryptedDataCodecs.decode(c, decoded2.nextBlinding, route.route.encryptedPayloads(2)) - assert(BlindedRouteData.validPaymentRecipientData(decoded3.tlvs).isRight) - assert(decoded3.tlvs.get[PathId].get.data == pathId.bytes) - } - test("Invoice generation with route blinding support") { f => import f._ diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/RouteBlindingSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/RouteBlindingSpec.scala index 8659e97fe5..9dd9743869 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/RouteBlindingSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/RouteBlindingSpec.scala @@ -3,10 +3,13 @@ package fr.acinq.eclair.wire.protocol import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.crypto.Sphinx.RouteBlinding.BlindedRouteDetails +import fr.acinq.eclair.payment.Invoice +import fr.acinq.eclair.router.Router import fr.acinq.eclair.wire.protocol.OnionRoutingCodecs.{ForbiddenTlv, MissingRequiredTlv} -import fr.acinq.eclair.wire.protocol.RouteBlindingEncryptedDataCodecs.{RouteBlindingDecryptedData, blindedRouteDataCodec} +import fr.acinq.eclair.wire.protocol.RouteBlindingEncryptedDataCodecs.{RouteBlindingDecryptedData, blindedRouteDataCodec, createBlindedRouteFromHops, createBlindedRouteWithoutHops} import fr.acinq.eclair.wire.protocol.RouteBlindingEncryptedDataTlv._ -import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Feature, FeatureSupport, Features, MilliSatoshiLong, ShortChannelId, UInt64, UnknownFeature, randomKey} +import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Feature, FeatureSupport, Features, MilliSatoshiLong, ShortChannelId, UInt64, UnknownFeature, randomBytes32, randomKey} +import org.scalatest.Tag import org.scalatest.funsuite.AnyFunSuiteLike import scodec.bits.{ByteVector, HexStringSyntax} @@ -167,4 +170,70 @@ class RouteBlindingSpec extends AnyFunSuiteLike { } } + test("create blinded route without hops") { + val a = randomKey() + val pathId = randomBytes32() + val route = createBlindedRouteWithoutHops(a.publicKey, pathId, 1 msat, CltvExpiry(500)) + assert(route.route.introductionNodeId == a.publicKey) + assert(route.route.encryptedPayloads.length == 1) + assert(route.route.blindingKey == route.lastBlinding) + val Right(decoded) = RouteBlindingEncryptedDataCodecs.decode(a, route.route.blindingKey, route.route.encryptedPayloads.head) + assert(BlindedRouteData.validPaymentRecipientData(decoded.tlvs).isRight) + assert(decoded.tlvs.get[PathId].get.data == pathId.bytes) + } + + test("create blinded route from channel hops") { + val (a, b, c) = (randomKey(), randomKey(), randomKey()) + val pathId = randomBytes32() + val (channelId1, channelId2) = (ShortChannelId(1), ShortChannelId(2)) + val hops = Seq( + Router.ChannelHop(channelId1, a.publicKey, b.publicKey, Router.ChannelRelayParams.FromHint(Invoice.BasicEdge(a.publicKey, b.publicKey, channelId1, 10 msat, 300, CltvExpiryDelta(200)))), + Router.ChannelHop(channelId2, b.publicKey, c.publicKey, Router.ChannelRelayParams.FromHint(Invoice.BasicEdge(b.publicKey, c.publicKey, channelId2, 20 msat, 150, CltvExpiryDelta(600)))), + ) + val route = createBlindedRouteFromHops(hops, pathId, 1 msat, CltvExpiry(500)) + assert(route.route.introductionNodeId == a.publicKey) + assert(route.route.encryptedPayloads.length == 3) + val Right(decoded1) = RouteBlindingEncryptedDataCodecs.decode(a, route.route.blindingKey, route.route.encryptedPayloads(0)) + assert(BlindedRouteData.validatePaymentRelayData(decoded1.tlvs).isRight) + assert(decoded1.tlvs.get[OutgoingChannelId].get.shortChannelId == channelId1) + assert(decoded1.tlvs.get[PaymentRelay].get.feeBase == 10.msat) + assert(decoded1.tlvs.get[PaymentRelay].get.feeProportionalMillionths == 300) + assert(decoded1.tlvs.get[PaymentRelay].get.cltvExpiryDelta == CltvExpiryDelta(200)) + val Right(decoded2) = RouteBlindingEncryptedDataCodecs.decode(b, decoded1.nextBlinding, route.route.encryptedPayloads(1)) + assert(BlindedRouteData.validatePaymentRelayData(decoded2.tlvs).isRight) + assert(decoded2.tlvs.get[OutgoingChannelId].get.shortChannelId == channelId2) + assert(decoded2.tlvs.get[PaymentRelay].get.feeBase == 20.msat) + assert(decoded2.tlvs.get[PaymentRelay].get.feeProportionalMillionths == 150) + assert(decoded2.tlvs.get[PaymentRelay].get.cltvExpiryDelta == CltvExpiryDelta(600)) + val Right(decoded3) = RouteBlindingEncryptedDataCodecs.decode(c, decoded2.nextBlinding, route.route.encryptedPayloads(2)) + assert(BlindedRouteData.validPaymentRecipientData(decoded3.tlvs).isRight) + assert(decoded3.tlvs.get[PathId].get.data == pathId.bytes) + } + + test("create blinded route payment info", Tag("fuzzy")) { + val rand = new scala.util.Random() + val nodeId = randomKey().publicKey + for (_ <- 0 to 100) { + val routeLength = rand.nextInt(10) + 1 + val hops = (1 to routeLength).map(i => { + val scid = ShortChannelId(i) + val feeBase = rand.nextInt(10_000).msat + val feeProp = rand.nextInt(5000) + val cltvExpiryDelta = CltvExpiryDelta(rand.nextInt(500)) + val params = Router.ChannelRelayParams.FromHint(Invoice.BasicEdge(nodeId, nodeId, scid, feeBase, feeProp, cltvExpiryDelta)) + Router.ChannelHop(scid, nodeId, nodeId, params) + }) + for (_ <- 0 to 100) { + val amount = rand.nextLong(10_000_000_000L).msat + val payInfo = OfferTypes.PaymentInfo(amount, hops) + assert(payInfo.cltvExpiryDelta == CltvExpiryDelta(hops.map(_.cltvExpiryDelta.toInt).sum)) + // We verify that the aggregated fee slightly exceeds the actual fee (because of proportional fees rounding). + val aggregatedFee = payInfo.fee(amount) + val actualFee = Router.Route(amount, hops).fee(includeLocalChannelCost = true) + assert(aggregatedFee >= actualFee, s"amount=$amount, hops=${hops.map(_.params.relayFees)}, aggregatedFee=$aggregatedFee, actualFee=$actualFee") + assert(aggregatedFee - actualFee < 1000.msat.max(amount * 1e-5), s"amount=$amount, hops=${hops.map(_.params.relayFees)}, aggregatedFee=$aggregatedFee, actualFee=$actualFee") + } + } + } + } From a7f812bdf13927d2642e3912cd88390a6f1e6227 Mon Sep 17 00:00:00 2001 From: t-bast Date: Thu, 24 Nov 2022 16:17:25 +0100 Subject: [PATCH 9/9] Move route blinding construction to router --- .../payment/receive/MultiPartHandler.scala | 6 +- .../eclair/router/BlindedRouteCreation.scala | 75 +++++++++++++++ .../eclair/router/RouteCalculation.scala | 1 - .../eclair/wire/protocol/OfferTypes.scala | 19 +--- .../eclair/wire/protocol/RouteBlinding.scala | 29 +----- .../router/BlindedRouteCreationSpec.scala | 96 +++++++++++++++++++ .../wire/protocol/RouteBlindingSpec.scala | 73 +------------- 7 files changed, 178 insertions(+), 121 deletions(-) create mode 100644 eclair-core/src/main/scala/fr/acinq/eclair/router/BlindedRouteCreation.scala create mode 100644 eclair-core/src/test/scala/fr/acinq/eclair/router/BlindedRouteCreationSpec.scala diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartHandler.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartHandler.scala index 2d6a7040a2..6242af7080 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartHandler.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartHandler.scala @@ -32,11 +32,11 @@ import fr.acinq.eclair.db._ import fr.acinq.eclair.payment.Bolt11Invoice.ExtraHop import fr.acinq.eclair.payment.Monitoring.{Metrics, Tags} import fr.acinq.eclair.payment._ +import fr.acinq.eclair.router.BlindedRouteCreation.{aggregatePaymentInfo, createBlindedRouteFromHops, createBlindedRouteWithoutHops} import fr.acinq.eclair.router.Router import fr.acinq.eclair.router.Router.{ChannelHop, ChannelRelayParams} import fr.acinq.eclair.wire.protocol.OfferTypes.{InvoiceRequest, Offer} import fr.acinq.eclair.wire.protocol.PaymentOnion.FinalPayload -import fr.acinq.eclair.wire.protocol.RouteBlindingEncryptedDataCodecs.{createBlindedRouteFromHops, createBlindedRouteWithoutHops} import fr.acinq.eclair.wire.protocol._ import fr.acinq.eclair.{CltvExpiryDelta, FeatureSupport, Features, InvoiceFeature, Logs, MilliSatoshi, NodeParams, ShortChannelId, TimestampMilli, randomBytes32} import scodec.bits.HexStringSyntax @@ -348,14 +348,14 @@ object MultiPartHandler { } else { createBlindedRouteFromHops(dummyHops, pathId, nodeParams.channelConf.htlcMinimum, route.maxFinalExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight)) } - val paymentInfo = OfferTypes.PaymentInfo(r.amount, dummyHops) + val paymentInfo = aggregatePaymentInfo(r.amount, dummyHops) Future.successful((blindedRoute, paymentInfo, pathId)) } else { implicit val timeout: Timeout = 10.seconds r.router.ask(Router.FinalizeRoute(r.amount, Router.PredefinedNodeRoute(route.nodes))).mapTo[Router.RouteResponse].map(routeResponse => { val clearRoute = routeResponse.routes.head val blindedRoute = createBlindedRouteFromHops(clearRoute.hops ++ dummyHops, pathId, nodeParams.channelConf.htlcMinimum, route.maxFinalExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight)) - val paymentInfo = OfferTypes.PaymentInfo(r.amount, clearRoute.hops ++ dummyHops) + val paymentInfo = aggregatePaymentInfo(r.amount, clearRoute.hops ++ dummyHops) (blindedRoute, paymentInfo, pathId) }) } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/BlindedRouteCreation.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/BlindedRouteCreation.scala new file mode 100644 index 0000000000..503ac7ffc2 --- /dev/null +++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/BlindedRouteCreation.scala @@ -0,0 +1,75 @@ +/* + * Copyright 2022 ACINQ SAS + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package fr.acinq.eclair.router + +import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey +import fr.acinq.eclair.crypto.Sphinx +import fr.acinq.eclair.router.Router.ChannelHop +import fr.acinq.eclair.wire.protocol.OfferTypes.PaymentInfo +import fr.acinq.eclair.wire.protocol.{RouteBlindingEncryptedDataCodecs, RouteBlindingEncryptedDataTlv, TlvStream} +import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Features, MilliSatoshi, MilliSatoshiLong, randomKey} +import scodec.bits.ByteVector + +object BlindedRouteCreation { + + /** Compute aggregated fees and expiry for a given route. */ + def aggregatePaymentInfo(amount: MilliSatoshi, hops: Seq[ChannelHop]): PaymentInfo = { + val zeroPaymentInfo = PaymentInfo(0 msat, 0, CltvExpiryDelta(0), 0 msat, amount, Features.empty) + hops.foldRight(zeroPaymentInfo) { + case (channel, payInfo) => + val newFeeBase = MilliSatoshi((channel.params.relayFees.feeBase.toLong * 1_000_000 + payInfo.feeBase.toLong * (1_000_000 + channel.params.relayFees.feeProportionalMillionths) + 1_000_000 - 1) / 1_000_000) + val newFeeProp = ((payInfo.feeProportionalMillionths + channel.params.relayFees.feeProportionalMillionths) * 1_000_000 + payInfo.feeProportionalMillionths * channel.params.relayFees.feeProportionalMillionths + 1_000_000 - 1) / 1_000_000 + // Most nodes on the network set `htlc_maximum_msat` to the channel capacity. We cannot expect the route to be + // able to relay that amount, so we remove 10% as a safety margin. + val channelMaxHtlc = channel.params.htlcMaximum_opt.map(_ * 0.9).getOrElse(amount) + PaymentInfo(newFeeBase, newFeeProp, payInfo.cltvExpiryDelta + channel.cltvExpiryDelta, payInfo.minHtlc.max(channel.params.htlcMinimum), payInfo.maxHtlc.min(channelMaxHtlc), payInfo.allowedFeatures) + } + } + + /** Create a blinded route from a non-empty list of channel hops. */ + def createBlindedRouteFromHops(hops: Seq[Router.ChannelHop], pathId: ByteVector, minAmount: MilliSatoshi, routeFinalExpiry: CltvExpiry): Sphinx.RouteBlinding.BlindedRouteDetails = { + require(hops.nonEmpty, "route must contain at least one hop") + // We use the same constraints for all nodes so they can't use it to guess their position. + val routeExpiry = hops.foldLeft(routeFinalExpiry) { case (expiry, hop) => expiry + hop.cltvExpiryDelta } + val routeMinAmount = hops.foldLeft(minAmount) { case (amount, hop) => amount.max(hop.params.htlcMinimum) } + val finalPayload = RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec.encode(TlvStream( + RouteBlindingEncryptedDataTlv.PaymentConstraints(routeExpiry, routeMinAmount), + RouteBlindingEncryptedDataTlv.PathId(pathId), + )).require.bytes + val payloads = hops.foldRight(Seq(finalPayload)) { + case (channel, payloads) => + val payload = RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec.encode(TlvStream( + RouteBlindingEncryptedDataTlv.OutgoingChannelId(channel.shortChannelId), + RouteBlindingEncryptedDataTlv.PaymentRelay(channel.cltvExpiryDelta, channel.params.relayFees.feeProportionalMillionths, channel.params.relayFees.feeBase), + RouteBlindingEncryptedDataTlv.PaymentConstraints(routeExpiry, routeMinAmount), + )).require.bytes + payload +: payloads + } + val nodeIds = hops.map(_.nodeId) :+ hops.last.nextNodeId + Sphinx.RouteBlinding.create(randomKey(), nodeIds, payloads) + } + + /** Create a blinded route where the recipient is also the introduction point (which reveals the recipient's identity). */ + def createBlindedRouteWithoutHops(nodeId: PublicKey, pathId: ByteVector, minAmount: MilliSatoshi, routeExpiry: CltvExpiry): Sphinx.RouteBlinding.BlindedRouteDetails = { + val finalPayload = RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec.encode(TlvStream( + RouteBlindingEncryptedDataTlv.PaymentConstraints(routeExpiry, minAmount), + RouteBlindingEncryptedDataTlv.PathId(pathId), + )).require.bytes + Sphinx.RouteBlinding.create(randomKey(), Seq(nodeId), Seq(finalPayload)) + } + +} diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/RouteCalculation.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/RouteCalculation.scala index 9231eba761..3b5ac6b9a4 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/router/RouteCalculation.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/RouteCalculation.scala @@ -22,7 +22,6 @@ import com.softwaremill.quicklens.ModifyPimp import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.eclair.Logs.LogCategory import fr.acinq.eclair._ -import fr.acinq.eclair.payment.Invoice.BasicEdge import fr.acinq.eclair.router.Graph.GraphStructure.DirectedGraph.graphEdgeToHop import fr.acinq.eclair.router.Graph.GraphStructure.{DirectedGraph, GraphEdge} import fr.acinq.eclair.router.Graph.{InfiniteLoop, NegativeProbability, RichWeight} diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/OfferTypes.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/OfferTypes.scala index 1724810769..abb284cb18 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/OfferTypes.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/OfferTypes.scala @@ -20,10 +20,9 @@ import fr.acinq.bitcoin.Bech32 import fr.acinq.bitcoin.scalacompat.Crypto.{PrivateKey, PublicKey} import fr.acinq.bitcoin.scalacompat.{Block, ByteVector32, ByteVector64, Crypto, LexicographicalOrdering} import fr.acinq.eclair.crypto.Sphinx.RouteBlinding.BlindedRoute -import fr.acinq.eclair.router.Router import fr.acinq.eclair.wire.protocol.OnionRoutingCodecs.{InvalidTlvPayload, MissingRequiredTlv} import fr.acinq.eclair.wire.protocol.TlvCodecs.genericTlv -import fr.acinq.eclair.{CltvExpiryDelta, Feature, Features, InvoiceFeature, MilliSatoshi, MilliSatoshiLong, TimestampSecond, UInt64, nodeFee} +import fr.acinq.eclair.{CltvExpiryDelta, Feature, Features, InvoiceFeature, MilliSatoshi, TimestampSecond, UInt64, nodeFee} import fr.acinq.secp256k1.Secp256k1JvmKt import scodec.Codec import scodec.bits.ByteVector @@ -70,22 +69,6 @@ object OfferTypes { def fee(amount: MilliSatoshi): MilliSatoshi = nodeFee(feeBase, feeProportionalMillionths, amount) } - object PaymentInfo { - /** Compute aggregated fees and expiry for a blinded route. */ - def apply(amount: MilliSatoshi, hops: Seq[Router.ChannelHop]): PaymentInfo = { - val zeroPaymentInfo = PaymentInfo(0 msat, 0, CltvExpiryDelta(0), 0 msat, amount, Features.empty) - hops.foldRight(zeroPaymentInfo) { - case (channel, payInfo) => - val newFeeBase = MilliSatoshi((channel.params.relayFees.feeBase.toLong * 1_000_000 + payInfo.feeBase.toLong * (1_000_000 + channel.params.relayFees.feeProportionalMillionths) + 1_000_000 - 1) / 1_000_000) - val newFeeProp = ((payInfo.feeProportionalMillionths + channel.params.relayFees.feeProportionalMillionths) * 1_000_000 + payInfo.feeProportionalMillionths * channel.params.relayFees.feeProportionalMillionths + 1_000_000 - 1) / 1_000_000 - // Most nodes on the network set `htlc_maximum_msat` to the channel capacity. We cannot expect the route to be - // able to relay that amount, so we remove 10% as a safety margin. - val channelMaxHtlc = channel.params.htlcMaximum_opt.map(_ * 0.9).getOrElse(amount) - PaymentInfo(newFeeBase, newFeeProp, payInfo.cltvExpiryDelta + channel.cltvExpiryDelta, payInfo.minHtlc.max(channel.params.htlcMinimum), payInfo.maxHtlc.min(channelMaxHtlc), payInfo.allowedFeatures) - } - } - } - case class PaymentPathsInfo(paymentInfo: Seq[PaymentInfo]) extends InvoiceTlv case class PaymentPathsCapacities(capacities: Seq[MilliSatoshi]) extends InvoiceTlv diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/RouteBlinding.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/RouteBlinding.scala index 10f23eaa91..cbe392ef63 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/RouteBlinding.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/RouteBlinding.scala @@ -18,11 +18,10 @@ package fr.acinq.eclair.wire.protocol import fr.acinq.bitcoin.scalacompat.Crypto.{PrivateKey, PublicKey} import fr.acinq.eclair.crypto.Sphinx -import fr.acinq.eclair.router.Router import fr.acinq.eclair.wire.protocol.CommonCodecs.{cltvExpiry, cltvExpiryDelta, featuresCodec} import fr.acinq.eclair.wire.protocol.OnionRoutingCodecs.{ForbiddenTlv, InvalidTlvPayload, MissingRequiredTlv} import fr.acinq.eclair.wire.protocol.TlvCodecs.{fixedLengthTlvField, tlvField, tmillisatoshi, tmillisatoshi32} -import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Feature, Features, MilliSatoshi, ShortChannelId, UInt64, randomKey} +import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Feature, Features, MilliSatoshi, ShortChannelId, UInt64} import scodec.bits.ByteVector import scala.util.{Failure, Success} @@ -140,32 +139,6 @@ object RouteBlindingEncryptedDataCodecs { case class CannotDecodeData(message: String) extends InvalidEncryptedData // @formatter:on - /** Create a blinded route from a non-empty list of channel hops. */ - def createBlindedRouteFromHops(hops: Seq[Router.ChannelHop], pathId: ByteVector, minAmount: MilliSatoshi, routeFinalExpiry: CltvExpiry): Sphinx.RouteBlinding.BlindedRouteDetails = { - require(hops.nonEmpty, "route must contain at least one hop") - // We use the same constraints for all nodes so they can't use it to guess their position. - val routeExpiry = hops.foldLeft(routeFinalExpiry) { case (expiry, hop) => expiry + hop.cltvExpiryDelta } - val routeMinAmount = hops.foldLeft(minAmount) { case (amount, hop) => amount.max(hop.params.htlcMinimum) } - val finalPayload = blindedRouteDataCodec.encode(TlvStream(PaymentConstraints(routeExpiry, routeMinAmount), PathId(pathId))).require.bytes - val payloads = hops.foldRight(Seq(finalPayload)) { - case (channel, payloads) => - val payload = blindedRouteDataCodec.encode(TlvStream( - OutgoingChannelId(channel.shortChannelId), - PaymentRelay(channel.cltvExpiryDelta, channel.params.relayFees.feeProportionalMillionths, channel.params.relayFees.feeBase), - PaymentConstraints(routeExpiry, routeMinAmount), - )).require.bytes - payload +: payloads - } - val nodeIds = hops.map(_.nodeId) :+ hops.last.nextNodeId - Sphinx.RouteBlinding.create(randomKey(), nodeIds, payloads) - } - - /** Create a blinded route where the recipient is also the introduction point (which reveals the recipient's identity). */ - def createBlindedRouteWithoutHops(nodeId: PublicKey, pathId: ByteVector, minAmount: MilliSatoshi, routeExpiry: CltvExpiry): Sphinx.RouteBlinding.BlindedRouteDetails = { - val finalPayload = blindedRouteDataCodec.encode(TlvStream(PaymentConstraints(routeExpiry, minAmount), PathId(pathId))).require.bytes - Sphinx.RouteBlinding.create(randomKey(), Seq(nodeId), Seq(finalPayload)) - } - /** * Decrypt and decode the contents of an encrypted_recipient_data TLV field. * diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/router/BlindedRouteCreationSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/router/BlindedRouteCreationSpec.scala new file mode 100644 index 0000000000..9f4217615e --- /dev/null +++ b/eclair-core/src/test/scala/fr/acinq/eclair/router/BlindedRouteCreationSpec.scala @@ -0,0 +1,96 @@ +/* + * Copyright 2022 ACINQ SAS + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package fr.acinq.eclair.router + +import fr.acinq.eclair.router.RouteCalculationSpec.makeUpdateShort +import fr.acinq.eclair.router.Router.{ChannelHop, ChannelRelayParams} +import fr.acinq.eclair.wire.protocol.{BlindedRouteData, RouteBlindingEncryptedDataCodecs, RouteBlindingEncryptedDataTlv} +import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, MilliSatoshiLong, ShortChannelId, randomBytes32, randomKey} +import org.scalatest.funsuite.AnyFunSuite +import org.scalatest.{ParallelTestExecution, Tag} + +class BlindedRouteCreationSpec extends AnyFunSuite with ParallelTestExecution { + + import BlindedRouteCreation._ + + test("create blinded route without hops") { + val a = randomKey() + val pathId = randomBytes32() + val route = createBlindedRouteWithoutHops(a.publicKey, pathId, 1 msat, CltvExpiry(500)) + assert(route.route.introductionNodeId == a.publicKey) + assert(route.route.encryptedPayloads.length == 1) + assert(route.route.blindingKey == route.lastBlinding) + val Right(decoded) = RouteBlindingEncryptedDataCodecs.decode(a, route.route.blindingKey, route.route.encryptedPayloads.head) + assert(BlindedRouteData.validPaymentRecipientData(decoded.tlvs).isRight) + assert(decoded.tlvs.get[RouteBlindingEncryptedDataTlv.PathId].get.data == pathId.bytes) + } + + test("create blinded route from channel hops") { + val (a, b, c) = (randomKey(), randomKey(), randomKey()) + val pathId = randomBytes32() + val (scid1, scid2) = (ShortChannelId(1), ShortChannelId(2)) + val hops = Seq( + ChannelHop(scid1, a.publicKey, b.publicKey, ChannelRelayParams.FromAnnouncement(makeUpdateShort(scid1, a.publicKey, b.publicKey, 10 msat, 300, cltvDelta = CltvExpiryDelta(200)))), + ChannelHop(scid2, b.publicKey, c.publicKey, ChannelRelayParams.FromAnnouncement(makeUpdateShort(scid2, b.publicKey, c.publicKey, 20 msat, 150, cltvDelta = CltvExpiryDelta(600)))), + ) + val route = createBlindedRouteFromHops(hops, pathId, 1 msat, CltvExpiry(500)) + assert(route.route.introductionNodeId == a.publicKey) + assert(route.route.encryptedPayloads.length == 3) + val Right(decoded1) = RouteBlindingEncryptedDataCodecs.decode(a, route.route.blindingKey, route.route.encryptedPayloads(0)) + assert(BlindedRouteData.validatePaymentRelayData(decoded1.tlvs).isRight) + assert(decoded1.tlvs.get[RouteBlindingEncryptedDataTlv.OutgoingChannelId].get.shortChannelId == scid1) + assert(decoded1.tlvs.get[RouteBlindingEncryptedDataTlv.PaymentRelay].get.feeBase == 10.msat) + assert(decoded1.tlvs.get[RouteBlindingEncryptedDataTlv.PaymentRelay].get.feeProportionalMillionths == 300) + assert(decoded1.tlvs.get[RouteBlindingEncryptedDataTlv.PaymentRelay].get.cltvExpiryDelta == CltvExpiryDelta(200)) + val Right(decoded2) = RouteBlindingEncryptedDataCodecs.decode(b, decoded1.nextBlinding, route.route.encryptedPayloads(1)) + assert(BlindedRouteData.validatePaymentRelayData(decoded2.tlvs).isRight) + assert(decoded2.tlvs.get[RouteBlindingEncryptedDataTlv.OutgoingChannelId].get.shortChannelId == scid2) + assert(decoded2.tlvs.get[RouteBlindingEncryptedDataTlv.PaymentRelay].get.feeBase == 20.msat) + assert(decoded2.tlvs.get[RouteBlindingEncryptedDataTlv.PaymentRelay].get.feeProportionalMillionths == 150) + assert(decoded2.tlvs.get[RouteBlindingEncryptedDataTlv.PaymentRelay].get.cltvExpiryDelta == CltvExpiryDelta(600)) + val Right(decoded3) = RouteBlindingEncryptedDataCodecs.decode(c, decoded2.nextBlinding, route.route.encryptedPayloads(2)) + assert(BlindedRouteData.validPaymentRecipientData(decoded3.tlvs).isRight) + assert(decoded3.tlvs.get[RouteBlindingEncryptedDataTlv.PathId].get.data == pathId.bytes) + } + + test("create blinded route payment info", Tag("fuzzy")) { + val rand = new scala.util.Random() + val nodeId = randomKey().publicKey + for (_ <- 0 to 100) { + val routeLength = rand.nextInt(10) + 1 + val hops = (1 to routeLength).map(i => { + val scid = ShortChannelId(i) + val feeBase = rand.nextInt(10_000).msat + val feeProp = rand.nextInt(5000) + val cltvExpiryDelta = CltvExpiryDelta(rand.nextInt(500)) + val params = ChannelRelayParams.FromAnnouncement(makeUpdateShort(scid, nodeId, nodeId, feeBase, feeProp, cltvDelta = cltvExpiryDelta)) + ChannelHop(scid, nodeId, nodeId, params) + }) + for (_ <- 0 to 100) { + val amount = rand.nextLong(10_000_000_000L).msat + val payInfo = aggregatePaymentInfo(amount, hops) + assert(payInfo.cltvExpiryDelta == CltvExpiryDelta(hops.map(_.cltvExpiryDelta.toInt).sum)) + // We verify that the aggregated fee slightly exceeds the actual fee (because of proportional fees rounding). + val aggregatedFee = payInfo.fee(amount) + val actualFee = Router.Route(amount, hops).fee(includeLocalChannelCost = true) + assert(aggregatedFee >= actualFee, s"amount=$amount, hops=${hops.map(_.params.relayFees)}, aggregatedFee=$aggregatedFee, actualFee=$actualFee") + assert(aggregatedFee - actualFee < 1000.msat.max(amount * 1e-5), s"amount=$amount, hops=${hops.map(_.params.relayFees)}, aggregatedFee=$aggregatedFee, actualFee=$actualFee") + } + } + } + +} diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/RouteBlindingSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/RouteBlindingSpec.scala index 9dd9743869..8659e97fe5 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/RouteBlindingSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/RouteBlindingSpec.scala @@ -3,13 +3,10 @@ package fr.acinq.eclair.wire.protocol import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.crypto.Sphinx.RouteBlinding.BlindedRouteDetails -import fr.acinq.eclair.payment.Invoice -import fr.acinq.eclair.router.Router import fr.acinq.eclair.wire.protocol.OnionRoutingCodecs.{ForbiddenTlv, MissingRequiredTlv} -import fr.acinq.eclair.wire.protocol.RouteBlindingEncryptedDataCodecs.{RouteBlindingDecryptedData, blindedRouteDataCodec, createBlindedRouteFromHops, createBlindedRouteWithoutHops} +import fr.acinq.eclair.wire.protocol.RouteBlindingEncryptedDataCodecs.{RouteBlindingDecryptedData, blindedRouteDataCodec} import fr.acinq.eclair.wire.protocol.RouteBlindingEncryptedDataTlv._ -import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Feature, FeatureSupport, Features, MilliSatoshiLong, ShortChannelId, UInt64, UnknownFeature, randomBytes32, randomKey} -import org.scalatest.Tag +import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Feature, FeatureSupport, Features, MilliSatoshiLong, ShortChannelId, UInt64, UnknownFeature, randomKey} import org.scalatest.funsuite.AnyFunSuiteLike import scodec.bits.{ByteVector, HexStringSyntax} @@ -170,70 +167,4 @@ class RouteBlindingSpec extends AnyFunSuiteLike { } } - test("create blinded route without hops") { - val a = randomKey() - val pathId = randomBytes32() - val route = createBlindedRouteWithoutHops(a.publicKey, pathId, 1 msat, CltvExpiry(500)) - assert(route.route.introductionNodeId == a.publicKey) - assert(route.route.encryptedPayloads.length == 1) - assert(route.route.blindingKey == route.lastBlinding) - val Right(decoded) = RouteBlindingEncryptedDataCodecs.decode(a, route.route.blindingKey, route.route.encryptedPayloads.head) - assert(BlindedRouteData.validPaymentRecipientData(decoded.tlvs).isRight) - assert(decoded.tlvs.get[PathId].get.data == pathId.bytes) - } - - test("create blinded route from channel hops") { - val (a, b, c) = (randomKey(), randomKey(), randomKey()) - val pathId = randomBytes32() - val (channelId1, channelId2) = (ShortChannelId(1), ShortChannelId(2)) - val hops = Seq( - Router.ChannelHop(channelId1, a.publicKey, b.publicKey, Router.ChannelRelayParams.FromHint(Invoice.BasicEdge(a.publicKey, b.publicKey, channelId1, 10 msat, 300, CltvExpiryDelta(200)))), - Router.ChannelHop(channelId2, b.publicKey, c.publicKey, Router.ChannelRelayParams.FromHint(Invoice.BasicEdge(b.publicKey, c.publicKey, channelId2, 20 msat, 150, CltvExpiryDelta(600)))), - ) - val route = createBlindedRouteFromHops(hops, pathId, 1 msat, CltvExpiry(500)) - assert(route.route.introductionNodeId == a.publicKey) - assert(route.route.encryptedPayloads.length == 3) - val Right(decoded1) = RouteBlindingEncryptedDataCodecs.decode(a, route.route.blindingKey, route.route.encryptedPayloads(0)) - assert(BlindedRouteData.validatePaymentRelayData(decoded1.tlvs).isRight) - assert(decoded1.tlvs.get[OutgoingChannelId].get.shortChannelId == channelId1) - assert(decoded1.tlvs.get[PaymentRelay].get.feeBase == 10.msat) - assert(decoded1.tlvs.get[PaymentRelay].get.feeProportionalMillionths == 300) - assert(decoded1.tlvs.get[PaymentRelay].get.cltvExpiryDelta == CltvExpiryDelta(200)) - val Right(decoded2) = RouteBlindingEncryptedDataCodecs.decode(b, decoded1.nextBlinding, route.route.encryptedPayloads(1)) - assert(BlindedRouteData.validatePaymentRelayData(decoded2.tlvs).isRight) - assert(decoded2.tlvs.get[OutgoingChannelId].get.shortChannelId == channelId2) - assert(decoded2.tlvs.get[PaymentRelay].get.feeBase == 20.msat) - assert(decoded2.tlvs.get[PaymentRelay].get.feeProportionalMillionths == 150) - assert(decoded2.tlvs.get[PaymentRelay].get.cltvExpiryDelta == CltvExpiryDelta(600)) - val Right(decoded3) = RouteBlindingEncryptedDataCodecs.decode(c, decoded2.nextBlinding, route.route.encryptedPayloads(2)) - assert(BlindedRouteData.validPaymentRecipientData(decoded3.tlvs).isRight) - assert(decoded3.tlvs.get[PathId].get.data == pathId.bytes) - } - - test("create blinded route payment info", Tag("fuzzy")) { - val rand = new scala.util.Random() - val nodeId = randomKey().publicKey - for (_ <- 0 to 100) { - val routeLength = rand.nextInt(10) + 1 - val hops = (1 to routeLength).map(i => { - val scid = ShortChannelId(i) - val feeBase = rand.nextInt(10_000).msat - val feeProp = rand.nextInt(5000) - val cltvExpiryDelta = CltvExpiryDelta(rand.nextInt(500)) - val params = Router.ChannelRelayParams.FromHint(Invoice.BasicEdge(nodeId, nodeId, scid, feeBase, feeProp, cltvExpiryDelta)) - Router.ChannelHop(scid, nodeId, nodeId, params) - }) - for (_ <- 0 to 100) { - val amount = rand.nextLong(10_000_000_000L).msat - val payInfo = OfferTypes.PaymentInfo(amount, hops) - assert(payInfo.cltvExpiryDelta == CltvExpiryDelta(hops.map(_.cltvExpiryDelta.toInt).sum)) - // We verify that the aggregated fee slightly exceeds the actual fee (because of proportional fees rounding). - val aggregatedFee = payInfo.fee(amount) - val actualFee = Router.Route(amount, hops).fee(includeLocalChannelCost = true) - assert(aggregatedFee >= actualFee, s"amount=$amount, hops=${hops.map(_.params.relayFees)}, aggregatedFee=$aggregatedFee, actualFee=$actualFee") - assert(aggregatedFee - actualFee < 1000.msat.max(amount * 1e-5), s"amount=$amount, hops=${hops.map(_.params.relayFees)}, aggregatedFee=$aggregatedFee, actualFee=$actualFee") - } - } - } - }