From 002a74116c5cf68b8ce5ce435292e78c8f732ed6 Mon Sep 17 00:00:00 2001 From: t-bast Date: Wed, 7 Dec 2022 15:42:59 +0100 Subject: [PATCH 1/2] Send payments to blinded routes Since blinded routes have to be used from start to end and are somewhat similar to Bolt 11 routing hints, we model them as an abstract single hop during path-finding. This makes it trivial to reuse existing algorithms without any modifications. We then add support for paying blinded routes. We introduce a new type of recipient for those payments, that uses blinded hops and creates onion payloads accordingly. There is a subtlety in the case where we're the introduction of the blinded route: when that happens we need to decrypt the first payload to figure out where to send the payment. When we receive a failure from a blinded route, we simply ignore it in retries: we don't know what caused the issue so we assume it's permanent, which makes sense in most cases since we cannot change the relaying parameters (fees and expiry delta are chosen by the recipient). --- .../main/scala/fr/acinq/eclair/Eclair.scala | 5 +- .../scala/fr/acinq/eclair/crypto/Sphinx.scala | 1 + .../scala/fr/acinq/eclair/db/PaymentsDb.scala | 3 +- .../acinq/eclair/json/JsonSerializers.scala | 7 +- .../acinq/eclair/payment/PaymentEvents.scala | 24 +- .../acinq/eclair/payment/PaymentPacket.scala | 44 ++- .../send/MultiPartPaymentLifecycle.scala | 2 +- .../payment/send/PaymentInitiator.scala | 78 ++-- .../payment/send/PaymentLifecycle.scala | 19 +- .../acinq/eclair/payment/send/Recipient.scala | 89 ++++- .../eclair/router/RouteCalculation.scala | 31 +- .../scala/fr/acinq/eclair/router/Router.scala | 31 ++ .../eclair/wire/protocol/OfferCodecs.scala | 5 +- .../eclair/wire/protocol/PaymentOnion.scala | 41 +- .../eclair/wire/protocol/RouteBlinding.scala | 2 +- .../fr/acinq/eclair/channel/FuzzySpec.scala | 2 +- .../ChannelStateTestsHelperMethods.scala | 2 +- .../channel/states/f/ShutdownStateSpec.scala | 4 +- .../eclair/json/JsonSerializersSpec.scala | 4 +- .../eclair/message/OnionMessagesSpec.scala | 4 - .../eclair/payment/Bolt12InvoiceSpec.scala | 13 +- .../MultiPartPaymentLifecycleSpec.scala | 76 +++- .../eclair/payment/PaymentInitiatorSpec.scala | 99 ++++- .../eclair/payment/PaymentLifecycleSpec.scala | 156 ++++++-- .../eclair/payment/PaymentPacketSpec.scala | 352 ++++++++++++++++-- .../payment/PostRestartHtlcCleanerSpec.scala | 8 +- .../eclair/payment/relay/RelayerSpec.scala | 10 +- .../acinq/eclair/router/BaseRouterSpec.scala | 40 ++ .../fr/acinq/eclair/router/RouterSpec.scala | 81 +++- 29 files changed, 1049 insertions(+), 184 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala b/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala index ed4f63a023..6101ab5693 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala @@ -566,7 +566,10 @@ class EclairImpl(appKit: Kit) extends Eclair with Logging { randomKey(), randomKey(), intermediateNodes.map(OnionMessages.IntermediateNode(_)), - destination match { case Left(key) => OnionMessages.Recipient(key, None) case Right(route) => OnionMessages.BlindedPath(route) }, + destination match { + case Left(key) => OnionMessages.Recipient(key, None) + case Right(route) => OnionMessages.BlindedPath(route) + }, replyRoute.map(OnionMessagePayloadTlv.ReplyPath(_) :: Nil).getOrElse(Nil), userCustomTlvs) match { case Success((nextNodeId, message)) => diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/crypto/Sphinx.scala b/eclair-core/src/main/scala/fr/acinq/eclair/crypto/Sphinx.scala index bb50bdd555..1f9f27f37e 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/crypto/Sphinx.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/crypto/Sphinx.scala @@ -367,6 +367,7 @@ object Sphinx extends Logging { val subsequentNodes: Seq[BlindedNode] = blindedNodes.tail val blindedNodeIds: Seq[PublicKey] = blindedNodes.map(_.blindedPublicKey) val encryptedPayloads: Seq[ByteVector] = blindedNodes.map(_.encryptedPayload) + val length: Int = blindedNodes.length - 1 } /** diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala index de4398403e..5f24180c9d 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala @@ -19,7 +19,7 @@ package fr.acinq.eclair.db import fr.acinq.bitcoin.scalacompat.ByteVector32 import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.eclair.payment._ -import fr.acinq.eclair.router.Router.{ChannelHop, Hop, NodeHop} +import fr.acinq.eclair.router.Router.{BlindedHop, ChannelHop, Hop, NodeHop} import fr.acinq.eclair.{MilliSatoshi, Paginated, ShortChannelId, TimestampMilli} import scodec.bits.ByteVector @@ -226,6 +226,7 @@ object HopSummary { def apply(h: Hop): HopSummary = { val shortChannelId = h match { case ch: ChannelHop => Some(ch.shortChannelId) + case _: BlindedHop => None case _: NodeHop => None } HopSummary(h.nodeId, h.nextNodeId, shortChannelId) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/json/JsonSerializers.scala b/eclair-core/src/main/scala/fr/acinq/eclair/json/JsonSerializers.scala index 4fb7ea15bc..84db12f7b5 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/json/JsonSerializers.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/json/JsonSerializers.scala @@ -31,7 +31,7 @@ import fr.acinq.eclair.io.Peer import fr.acinq.eclair.message.OnionMessages import fr.acinq.eclair.payment.PaymentFailure.PaymentFailedSummary import fr.acinq.eclair.payment._ -import fr.acinq.eclair.router.Router.{HopRelayParams, NodeHop, Route} +import fr.acinq.eclair.router.Router._ import fr.acinq.eclair.transactions.DirectedHtlc import fr.acinq.eclair.transactions.Transactions._ import fr.acinq.eclair.wire.protocol.MessageOnionCodecs.blindedRouteCodec @@ -296,12 +296,14 @@ object ColorSerializer extends MinimalSerializer({ // @formatter:off private sealed trait HopJson private case class ChannelHopJson(nodeId: PublicKey, nextNodeId: PublicKey, source: HopRelayParams) extends HopJson +private case class BlindedHopJson(nodeId: PublicKey, nextNodeId: PublicKey, paymentInfo: OfferTypes.PaymentInfo) extends HopJson private case class NodeHopJson(nodeId: PublicKey, nextNodeId: PublicKey, fee: MilliSatoshi, cltvExpiryDelta: CltvExpiryDelta) extends HopJson private case class RouteFullJson(amount: MilliSatoshi, hops: Seq[HopJson]) object RouteFullSerializer extends ConvertClassSerializer[Route](route => { val channelHops = route.hops.map(h => ChannelHopJson(h.nodeId, h.nextNodeId, h.params)) val finalHop_opt = route.finalHop_opt.map { case h: NodeHop => NodeHopJson(h.nodeId, h.nextNodeId, h.fee, h.cltvExpiryDelta) + case h: BlindedHop => BlindedHopJson(h.nodeId, h.nextNodeId, h.paymentInfo) } RouteFullJson(route.amount, channelHops ++ finalHop_opt.toSeq) }) @@ -315,6 +317,8 @@ object RouteNodeIdsSerializer extends ConvertClassSerializer[Route](route => { val finalNodeIds = route.finalHop_opt match { case Some(hop: NodeHop) if channelNodeIds.nonEmpty => Seq(hop.nextNodeId) case Some(hop: NodeHop) => Seq(hop.nodeId, hop.nextNodeId) + case Some(hop: BlindedHop) if channelNodeIds.nonEmpty => hop.route.blindedNodeIds.tail + case Some(hop: BlindedHop) => hop.route.introductionNodeId +: hop.route.blindedNodeIds.tail case None => Nil } RouteNodeIdsJson(route.amount, channelNodeIds ++ finalNodeIds) @@ -325,6 +329,7 @@ object RouteShortChannelIdsSerializer extends ConvertClassSerializer[Route](rout val hops = route.hops.map(_.shortChannelId) val finalHop = route.finalHop_opt.map { case _: NodeHop => "trampoline" + case _: BlindedHop => "blinded" } RouteShortChannelIdsJson(route.amount, hops, finalHop) }) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentEvents.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentEvents.scala index 98a579785c..cb4d139c00 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentEvents.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentEvents.scala @@ -24,8 +24,8 @@ import fr.acinq.eclair.payment.send.PaymentError.RetryExhausted import fr.acinq.eclair.payment.send.PaymentInitiator.SendPaymentConfig import fr.acinq.eclair.payment.send.{ClearRecipient, Recipient} import fr.acinq.eclair.router.Announcements -import fr.acinq.eclair.router.Router.{ChannelDesc, ChannelHop, Hop, Ignore} -import fr.acinq.eclair.wire.protocol.{ChannelDisabled, ChannelUpdate, Node, TemporaryChannelFailure} +import fr.acinq.eclair.router.Router._ +import fr.acinq.eclair.wire.protocol._ import fr.acinq.eclair.{MilliSatoshi, ShortChannelId, TimestampMilli} import scodec.bits.ByteVector @@ -183,11 +183,15 @@ object PaymentFailure { .isDefined /** Ignore the channel outgoing from the given nodeId in the given route. */ - private def ignoreNodeOutgoingChannel(nodeId: PublicKey, hops: Seq[Hop], ignore: Ignore): Ignore = { + private def ignoreNodeOutgoingEdge(nodeId: PublicKey, hops: Seq[Hop], ignore: Ignore): Ignore = { hops.collectFirst { case hop: ChannelHop if hop.nodeId == nodeId => ChannelDesc(hop.shortChannelId, hop.nodeId, hop.nextNodeId) + case hop: BlindedHop if hop.nodeId == nodeId => ChannelDesc(hop.dummyId, hop.nodeId, hop.nextNodeId) + // The error comes from inside the blinded route: this is a spec violation, errors should always come from the + // introduction node, so we definitely want to ignore this blinded route when this happens. + case hop: BlindedHop if hop.route.blindedNodeIds.contains(nodeId) => ChannelDesc(hop.dummyId, hop.nodeId, hop.nextNodeId) } match { - case Some(faultyChannel) => ignore + faultyChannel + case Some(faultyEdge) => ignore + faultyEdge case None => ignore } } @@ -207,7 +211,7 @@ object PaymentFailure { case _ => false } if (shouldIgnore) { - ignoreNodeOutgoingChannel(nodeId, hops, ignore) + ignoreNodeOutgoingEdge(nodeId, hops, ignore) } else { // We were using an outdated channel update, we should retry with the new one and nobody should be penalized. ignore @@ -217,10 +221,14 @@ object PaymentFailure { ignore + nodeId } case RemoteFailure(_, hops, Sphinx.DecryptedFailurePacket(nodeId, _)) => - ignoreNodeOutgoingChannel(nodeId, hops, ignore) + ignoreNodeOutgoingEdge(nodeId, hops, ignore) case UnreadableRemoteFailure(_, hops) => - // We don't know which node is sending garbage, let's blacklist all nodes except the one we are directly connected to and the final recipient. - val blacklist = hops.map(_.nextNodeId).drop(1).dropRight(1).toSet + // We don't know which node is sending garbage, let's blacklist all nodes except: + // - the one we are directly connected to: it would be too restrictive for retries + // - the final recipient: they have no incentive to send garbage since they want that payment + // - the introduction point of a blinded route: we don't want a node before the blinded path to force us to ignore that blinded path + // - the trampoline node: we don't want a node before the trampoline node to force us to ignore that trampoline node + val blacklist = hops.collect { case hop: ChannelHop => hop }.map(_.nextNodeId).drop(1).dropRight(1).toSet ignore ++ blacklist case LocalFailure(_, hops, _) => hops.headOption match { case Some(hop: ChannelHop) => diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentPacket.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentPacket.scala index a063d74c7b..9410dbce35 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentPacket.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentPacket.scala @@ -23,7 +23,7 @@ import fr.acinq.bitcoin.scalacompat.Crypto.{PrivateKey, PublicKey} import fr.acinq.eclair.channel.{CMD_ADD_HTLC, CMD_FAIL_HTLC, CannotExtractSharedSecret, Origin} import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.payment.send.Recipient -import fr.acinq.eclair.router.Router.Route +import fr.acinq.eclair.router.Router.{BlindedHop, ChannelHop, Route} import fr.acinq.eclair.wire.protocol.PaymentOnion.{FinalPayload, IntermediatePayload, PerHopPayload} import fr.acinq.eclair.wire.protocol._ import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Feature, Features, MilliSatoshi, ShortChannelId, UInt64, randomKey} @@ -182,6 +182,8 @@ object IncomingPaymentPacket { case payload if add.amountMsat < payload.paymentConstraints.minAmount => Left(InvalidOnionBlinding(Sphinx.hash(add.onionRoutingPacket))) case payload if add.cltvExpiry > payload.paymentConstraints.maxCltvExpiry => Left(InvalidOnionBlinding(Sphinx.hash(add.onionRoutingPacket))) case payload if !Features.areCompatible(Features.empty, payload.allowedFeatures) => Left(InvalidOnionBlinding(Sphinx.hash(add.onionRoutingPacket))) + case payload if add.amountMsat < payload.amount => Left(InvalidOnionBlinding(Sphinx.hash(add.onionRoutingPacket))) + case payload if add.cltvExpiry < payload.expiry => Left(InvalidOnionBlinding(Sphinx.hash(add.onionRoutingPacket))) case payload => Right(FinalPacket(add, payload)) } } @@ -231,8 +233,10 @@ object OutgoingPaymentPacket { sealed trait OutgoingPaymentError extends Throwable case class CannotCreateOnion(message: String) extends OutgoingPaymentError { override def getMessage: String = message } + case class CannotDecryptBlindedRoute(message: String) extends OutgoingPaymentError { override def getMessage: String = message } case class InvalidRouteRecipient(expected: PublicKey, actual: PublicKey) extends OutgoingPaymentError { override def getMessage: String = s"expected route to $expected, got route to $actual" } case class MissingTrampolineHop(trampolineNodeId: PublicKey) extends OutgoingPaymentError { override def getMessage: String = s"expected route to trampoline node $trampolineNodeId" } + case class MissingBlindedHop(introductionNodeIds: Set[PublicKey]) extends OutgoingPaymentError { override def getMessage: String = s"expected blinded route using one of the following introduction nodes: ${introductionNodeIds.mkString(", ")}" } case object EmptyRoute extends OutgoingPaymentError { override def getMessage: String = "route cannot be empty" } sealed trait Upstream @@ -261,15 +265,41 @@ object OutgoingPaymentPacket { } } + private case class OutgoingPaymentWithChannel(shortChannelId: ShortChannelId, nextBlinding_opt: Option[PublicKey], payment: PaymentPayloads) + + private def getOutgoingChannel(privateKey: PrivateKey, payment: PaymentPayloads, route: Route): Either[OutgoingPaymentError, OutgoingPaymentWithChannel] = { + route.hops.headOption match { + case Some(hop) => Right(OutgoingPaymentWithChannel(hop.shortChannelId, None, payment)) + case None => route.finalHop_opt match { + case Some(hop: BlindedHop) => + // We are the introduction node of the blinded route: we need to decrypt the first payload. + val firstBlinding = hop.route.introductionNode.blindingEphemeralKey + val firstEncryptedPayload = hop.route.introductionNode.encryptedPayload + RouteBlindingEncryptedDataCodecs.decode(privateKey, firstBlinding, firstEncryptedPayload) match { + case Left(e) => Left(CannotDecryptBlindedRoute(e.message)) + case Right(decoded) => + val tlvs = TlvStream[OnionPaymentPayloadTlv](OnionPaymentPayloadTlv.EncryptedRecipientData(firstEncryptedPayload), OnionPaymentPayloadTlv.BlindingPoint(firstBlinding)) + IntermediatePayload.ChannelRelay.Blinded.validate(tlvs, decoded.tlvs, decoded.nextBlinding) match { + case Left(e) => Left(CannotDecryptBlindedRoute(e.failureMessage.message)) + case Right(payload) => + val payment1 = PaymentPayloads(payload.amountToForward(payment.amount), payload.outgoingCltv(payment.expiry), payment.payloads.tail) + Right(OutgoingPaymentWithChannel(payload.outgoingChannelId, Some(decoded.nextBlinding), payment1)) + } + } + case _ => Left(EmptyRoute) + } + } + } + /** Build the command to add an HTLC for the given recipient using the provided route. */ - def buildOutgoingPayment(replyTo: ActorRef, upstream: Upstream, paymentHash: ByteVector32, route: Route, recipient: Recipient): Either[OutgoingPaymentError, OutgoingPaymentPacket] = { - val outgoingChannel = route.hops.head.shortChannelId + def buildOutgoingPayment(replyTo: ActorRef, privateKey: PrivateKey, upstream: Upstream, paymentHash: ByteVector32, route: Route, recipient: Recipient): Either[OutgoingPaymentError, OutgoingPaymentPacket] = { for { - payment <- recipient.buildPayloads(paymentHash, route) - onion <- buildOnion(PaymentOnionCodecs.paymentOnionPayloadLength, payment.payloads, paymentHash) // BOLT 2 requires that associatedData == paymentHash + paymentTmp <- recipient.buildPayloads(paymentHash, route) + outgoing <- getOutgoingChannel(privateKey, paymentTmp, route) + onion <- buildOnion(PaymentOnionCodecs.paymentOnionPayloadLength, outgoing.payment.payloads, paymentHash) // BOLT 2 requires that associatedData == paymentHash } yield { - val cmd = CMD_ADD_HTLC(replyTo, payment.amount, paymentHash, payment.expiry, onion.packet, None, Origin.Hot(replyTo, upstream), commit = true) - OutgoingPaymentPacket(cmd, outgoingChannel, onion.sharedSecrets) + val cmd = CMD_ADD_HTLC(replyTo, outgoing.payment.amount, paymentHash, outgoing.payment.expiry, onion.packet, outgoing.nextBlinding_opt, Origin.Hot(replyTo, upstream), commit = true) + OutgoingPaymentPacket(cmd, outgoing.shortChannelId, onion.sharedSecrets) } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/MultiPartPaymentLifecycle.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/MultiPartPaymentLifecycle.scala index 68574d0d0f..97fb1d8721 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/MultiPartPaymentLifecycle.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/MultiPartPaymentLifecycle.scala @@ -97,7 +97,7 @@ class MultiPartPaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, if (cfg.storeInDb && d.pending.isEmpty && d.failures.isEmpty) { // In cases where we fail early (router error during the first attempt), the DB won't have an entry for that // payment, which may be confusing for users. - val dummyPayment = OutgoingPayment(id, cfg.parentId, cfg.externalId, paymentHash, PaymentType.Standard, d.request.recipient.totalAmount, d.request.recipient.totalAmount, d.request.recipient.nodeId, TimestampMilli.now(), cfg.invoice, OutgoingPaymentStatus.Pending) + val dummyPayment = OutgoingPayment(id, cfg.parentId, cfg.externalId, paymentHash, cfg.paymentType, d.request.recipient.totalAmount, d.request.recipient.totalAmount, d.request.recipient.nodeId, TimestampMilli.now(), cfg.invoice, OutgoingPaymentStatus.Pending) nodeParams.db.payments.addOutgoingPayment(dummyPayment) nodeParams.db.payments.updateOutgoingPayment(PaymentFailed(id, paymentHash, failure :: Nil)) } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentInitiator.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentInitiator.scala index 2a05451dc3..f0fc579f2d 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentInitiator.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentInitiator.scala @@ -21,6 +21,7 @@ import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.bitcoin.scalacompat.{ByteVector32, Crypto} import fr.acinq.eclair.channel.fsm.Channel import fr.acinq.eclair.crypto.Sphinx +import fr.acinq.eclair.db.PaymentType import fr.acinq.eclair.payment.OutgoingPaymentPacket.Upstream import fr.acinq.eclair.payment._ import fr.acinq.eclair.payment.send.PaymentError._ @@ -50,22 +51,20 @@ class PaymentInitiator(nodeParams: NodeParams, outgoingPaymentFactory: PaymentIn } val paymentCfg = SendPaymentConfig(paymentId, paymentId, r.externalId, r.paymentHash, r.invoice.nodeId, Upstream.Local(paymentId), Some(r.invoice), storeInDb = true, publishEvent = true, recordPathFindingMetrics = true) val finalExpiry = r.finalExpiry(nodeParams) - r.invoice match { - case invoice: Bolt11Invoice => - val recipient = ClearRecipient(invoice, r.recipientAmount, finalExpiry, r.userCustomTlvs) - if (!nodeParams.features.invoiceFeatures().areSupported(recipient.features)) { - sender() ! PaymentFailed(paymentId, r.paymentHash, LocalFailure(r.recipientAmount, Nil, UnsupportedFeatures(recipient.features)) :: Nil) - } else if (Features.canUseFeature(nodeParams.features.invoiceFeatures(), recipient.features, Features.BasicMultiPartPayment)) { - val fsm = outgoingPaymentFactory.spawnOutgoingMultiPartPayment(context, paymentCfg) - fsm ! MultiPartPaymentLifecycle.SendMultiPartPayment(self, recipient, r.maxAttempts, r.routeParams) - context become main(pending + (paymentId -> PendingPaymentToNode(sender(), r))) - } else { - val fsm = outgoingPaymentFactory.spawnOutgoingPayment(context, paymentCfg) - fsm ! PaymentLifecycle.SendPaymentToNode(self, recipient, r.maxAttempts, r.routeParams) - context become main(pending + (paymentId -> PendingPaymentToNode(sender(), r))) - } - case _: Bolt12Invoice => - sender() ! PaymentFailed(paymentId, r.paymentHash, LocalFailure(r.recipientAmount, Nil, new IllegalArgumentException("payments to Bolt12 invoices are not supported yet")) :: Nil) + val recipient = r.invoice match { + case invoice: Bolt11Invoice => ClearRecipient(invoice, r.recipientAmount, finalExpiry, r.userCustomTlvs) + case invoice: Bolt12Invoice => BlindedRecipient(invoice, r.recipientAmount, finalExpiry, r.userCustomTlvs) + } + if (!nodeParams.features.invoiceFeatures().areSupported(recipient.features)) { + sender() ! PaymentFailed(paymentId, r.paymentHash, LocalFailure(r.recipientAmount, Nil, UnsupportedFeatures(recipient.features)) :: Nil) + } else if (Features.canUseFeature(nodeParams.features.invoiceFeatures(), recipient.features, Features.BasicMultiPartPayment)) { + val fsm = outgoingPaymentFactory.spawnOutgoingMultiPartPayment(context, paymentCfg) + fsm ! MultiPartPaymentLifecycle.SendMultiPartPayment(self, recipient, r.maxAttempts, r.routeParams) + context become main(pending + (paymentId -> PendingPaymentToNode(sender(), r))) + } else { + val fsm = outgoingPaymentFactory.spawnOutgoingPayment(context, paymentCfg) + fsm ! PaymentLifecycle.SendPaymentToNode(self, recipient, r.maxAttempts, r.routeParams) + context become main(pending + (paymentId -> PendingPaymentToNode(sender(), r))) } case r: SendSpontaneousPayment => @@ -119,18 +118,16 @@ class PaymentInitiator(nodeParams: NodeParams, outgoingPaymentFactory: PaymentIn sender() ! PaymentFailed(paymentId, r.paymentHash, LocalFailure(r.recipientAmount, Nil, t) :: Nil) } case None => - r.invoice match { - case invoice: Bolt11Invoice => - sender() ! SendPaymentToRouteResponse(paymentId, parentPaymentId, None) - val paymentCfg = SendPaymentConfig(paymentId, parentPaymentId, r.externalId, r.paymentHash, r.recipientNodeId, Upstream.Local(paymentId), Some(r.invoice), storeInDb = true, publishEvent = true, recordPathFindingMetrics = false) - val finalExpiry = r.finalExpiry(nodeParams) - val recipient = ClearRecipient(invoice, r.recipientAmount, finalExpiry, Nil) - val payFsm = outgoingPaymentFactory.spawnOutgoingPayment(context, paymentCfg) - payFsm ! PaymentLifecycle.SendPaymentToRoute(self, Left(r.route), recipient) - context become main(pending + (paymentId -> PendingPaymentToRoute(sender(), r))) - case _: Bolt12Invoice => - sender() ! PaymentFailed(paymentId, r.paymentHash, LocalFailure(r.recipientAmount, Nil, new IllegalArgumentException("payments to Bolt12 invoices are not supported yet")) :: Nil) + sender() ! SendPaymentToRouteResponse(paymentId, parentPaymentId, None) + val paymentCfg = SendPaymentConfig(paymentId, parentPaymentId, r.externalId, r.paymentHash, r.recipientNodeId, Upstream.Local(paymentId), Some(r.invoice), storeInDb = true, publishEvent = true, recordPathFindingMetrics = false) + val finalExpiry = r.finalExpiry(nodeParams) + val recipient = r.invoice match { + case invoice: Bolt11Invoice => ClearRecipient(invoice, r.recipientAmount, finalExpiry, Nil) + case invoice: Bolt12Invoice => BlindedRecipient(invoice, r.recipientAmount, finalExpiry, Nil) } + val payFsm = outgoingPaymentFactory.spawnOutgoingPayment(context, paymentCfg) + payFsm ! PaymentLifecycle.SendPaymentToRoute(self, Left(r.route), recipient) + context become main(pending + (paymentId -> PendingPaymentToRoute(sender(), r))) case _ => sender() ! PaymentFailed(paymentId, r.paymentHash, LocalFailure(r.recipientAmount, Nil, TrampolineMultiNodeNotSupported) :: Nil) } @@ -195,18 +192,15 @@ class PaymentInitiator(nodeParams: NodeParams, outgoingPaymentFactory: PaymentIn } private def buildTrampolineRecipient(r: SendRequestedPayment, trampolineHop: NodeHop): Try[ClearTrampolineRecipient] = { + // We generate a random secret for the payment to the trampoline node. + val trampolineSecret = r match { + case r: SendPaymentToRoute => r.trampoline_opt.map(_.paymentSecret).getOrElse(randomBytes32()) + case _ => randomBytes32() + } + val finalExpiry = r.finalExpiry(nodeParams) r.invoice match { - case invoice: Bolt11Invoice => - // We generate a random secret for the payment to the trampoline node. - val trampolineSecret = r match { - case r: SendPaymentToRoute => r.trampoline_opt.map(_.paymentSecret).getOrElse(randomBytes32()) - case _ => randomBytes32() - } - val finalExpiry = r.finalExpiry(nodeParams) - val recipient = ClearTrampolineRecipient(invoice, r.recipientAmount, finalExpiry, trampolineHop, trampolineSecret) - Success(recipient) - case _: Bolt12Invoice => - Failure(new IllegalArgumentException("payments to Bolt12 invoices are not supported yet")) + case invoice: Bolt11Invoice => Success(ClearTrampolineRecipient(invoice, r.recipientAmount, finalExpiry, trampolineHop, trampolineSecret)) + case _: Bolt12Invoice => Failure(new IllegalArgumentException("trampoline blinded payments are not supported yet")) } } @@ -403,9 +397,13 @@ object PaymentInitiator { storeInDb: Boolean, // e.g. for trampoline we don't want to store in the DB when we're relaying payments publishEvent: Boolean, recordPathFindingMetrics: Boolean) { - def createPaymentSent(recipient: Recipient, preimage: ByteVector32, parts: Seq[PaymentSent.PartialPayment]) = PaymentSent(parentId, paymentHash, preimage, recipient.totalAmount, recipient.nodeId, parts) + val paymentContext: PaymentContext = PaymentContext(id, parentId, paymentHash) + val paymentType = invoice match { + case Some(_: Bolt12Invoice) => PaymentType.Blinded + case _ => PaymentType.Standard + } - def paymentContext: PaymentContext = PaymentContext(id, parentId, paymentHash) + def createPaymentSent(recipient: Recipient, preimage: ByteVector32, parts: Seq[PaymentSent.PartialPayment]) = PaymentSent(parentId, paymentHash, preimage, recipient.totalAmount, recipient.nodeId, parts) } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentLifecycle.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentLifecycle.scala index c5e0476e96..066e8a0ab5 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentLifecycle.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentLifecycle.scala @@ -60,7 +60,7 @@ class PaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, router: A route => self ! RouteResponse(route :: Nil) ) if (cfg.storeInDb) { - paymentsDb.addOutgoingPayment(OutgoingPayment(id, cfg.parentId, cfg.externalId, paymentHash, PaymentType.Standard, request.amount, request.recipient.totalAmount, request.recipient.nodeId, TimestampMilli.now(), cfg.invoice, OutgoingPaymentStatus.Pending)) + paymentsDb.addOutgoingPayment(OutgoingPayment(id, cfg.parentId, cfg.externalId, paymentHash, cfg.paymentType, request.amount, request.recipient.totalAmount, request.recipient.nodeId, TimestampMilli.now(), cfg.invoice, OutgoingPaymentStatus.Pending)) } goto(WAITING_FOR_ROUTE) using WaitingForRoute(request, Nil, Ignore.empty) @@ -68,7 +68,7 @@ class PaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, router: A log.debug("sending {} to {}", request.amount, request.recipient.nodeId) router ! RouteRequest(nodeParams.nodeId, request.recipient, request.routeParams, paymentContext = Some(cfg.paymentContext)) if (cfg.storeInDb) { - paymentsDb.addOutgoingPayment(OutgoingPayment(id, cfg.parentId, cfg.externalId, paymentHash, PaymentType.Standard, request.amount, request.recipient.totalAmount, request.recipient.nodeId, TimestampMilli.now(), cfg.invoice, OutgoingPaymentStatus.Pending)) + paymentsDb.addOutgoingPayment(OutgoingPayment(id, cfg.parentId, cfg.externalId, paymentHash, cfg.paymentType, request.amount, request.recipient.totalAmount, request.recipient.nodeId, TimestampMilli.now(), cfg.invoice, OutgoingPaymentStatus.Pending)) } goto(WAITING_FOR_ROUTE) using WaitingForRoute(request, Nil, Ignore.empty) } @@ -76,7 +76,7 @@ class PaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, router: A when(WAITING_FOR_ROUTE) { case Event(RouteResponse(route +: _), WaitingForRoute(request, failures, ignore)) => log.info(s"route found: attempt=${failures.size + 1}/${request.maxAttempts} route=${route.printNodes()} channels=${route.printChannels()}") - OutgoingPaymentPacket.buildOutgoingPayment(self, cfg.upstream, paymentHash, route, request.recipient) match { + OutgoingPaymentPacket.buildOutgoingPayment(self, nodeParams.privateKey, cfg.upstream, paymentHash, route, request.recipient) match { case Right(payment) => register ! Register.ForwardShortId(self.toTyped[Register.ForwardShortIdFailure[CMD_ADD_HTLC]], payment.outgoingChannel, payment.cmd) goto(WAITING_FOR_PAYMENT_COMPLETE) using WaitingForComplete(request, payment.cmd, failures, payment.sharedSecrets, ignore, route) @@ -252,6 +252,19 @@ class PaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, router: A goto(WAITING_FOR_ROUTE) using WaitingForRoute(request, failures :+ failure, ignore + nodeId) } } + case Success(e@Sphinx.DecryptedFailurePacket(nodeId, _: InvalidOnionBlinding)) => + // there was a failure inside the blinded route we used: we cannot know why it failed, so let's ignore it. + log.info(s"received an error coming from nodeId=$nodeId inside the blinded route, retrying with different blinded routes") + val failure = RemoteFailure(request.amount, route.fullRoute, e) + val ignore1 = PaymentFailure.updateIgnored(failure, ignore) + request match { + case _: SendPaymentToRoute => + log.error("unexpected retry during SendPaymentToRoute") + stop(FSM.Normal) + case request: SendPaymentToNode => + router ! RouteRequest(nodeParams.nodeId, recipient, request.routeParams, ignore1, paymentContext = Some(cfg.paymentContext)) + goto(WAITING_FOR_ROUTE) using WaitingForRoute(request, failures :+ failure, ignore1) + } case Success(e@Sphinx.DecryptedFailurePacket(nodeId, failureMessage)) => log.info(s"received an error message from nodeId=$nodeId, trying to use a different channel (failure=$failureMessage)") val failure = RemoteFailure(request.amount, route.fullRoute, e) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/Recipient.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/Recipient.scala index ed0690c400..3e599bd255 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/Recipient.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/Recipient.scala @@ -21,9 +21,9 @@ import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.payment.Invoice.ExtraEdge import fr.acinq.eclair.payment.OutgoingPaymentPacket._ -import fr.acinq.eclair.payment.{Bolt11Invoice, OutgoingPaymentPacket} -import fr.acinq.eclair.router.Router.{ChannelHop, NodeHop, Route} -import fr.acinq.eclair.wire.protocol.PaymentOnion.{FinalPayload, IntermediatePayload} +import fr.acinq.eclair.payment.{Bolt11Invoice, Bolt12Invoice, OutgoingPaymentPacket} +import fr.acinq.eclair.router.Router._ +import fr.acinq.eclair.wire.protocol.PaymentOnion.{FinalPayload, IntermediatePayload, OutgoingBlindedPerHopPayload} import fr.acinq.eclair.wire.protocol.{GenericTlv, OnionRoutingPacket, PaymentOnionCodecs} import fr.acinq.eclair.{CltvExpiry, Features, InvoiceFeature, MilliSatoshi, MilliSatoshiLong, ShortChannelId} import scodec.bits.ByteVector @@ -54,9 +54,9 @@ sealed trait Recipient { object Recipient { /** Iteratively build all the payloads for a payment relayed through channel hops. */ - def buildPayloads(finalAmount: MilliSatoshi, finalExpiry: CltvExpiry, finalPayload: NodePayload, hops: Seq[ChannelHop]): PaymentPayloads = { + def buildPayloads(finalAmount: MilliSatoshi, finalExpiry: CltvExpiry, finalPayloads: Seq[NodePayload], hops: Seq[ChannelHop]): PaymentPayloads = { // We ignore the first hop since the route starts at our node. - hops.tail.foldRight(PaymentPayloads(finalAmount, finalExpiry, Seq(finalPayload))) { + hops.tail.foldRight(PaymentPayloads(finalAmount, finalExpiry, finalPayloads)) { case (hop, current) => val payload = NodePayload(hop.nodeId, IntermediatePayload.ChannelRelay.Standard(hop.shortChannelId, current.amount, current.expiry)) PaymentPayloads(current.amount + hop.fee(current.amount), current.expiry + hop.cltvExpiryDelta, payload +: current.payloads) @@ -80,7 +80,7 @@ case class ClearRecipient(nodeId: PublicKey, case Some(trampolinePacket) => NodePayload(nodeId, FinalPayload.Standard.createTrampolinePayload(route.amount, totalAmount, expiry, paymentSecret, trampolinePacket)) case None => NodePayload(nodeId, FinalPayload.Standard.createPayload(route.amount, totalAmount, expiry, paymentSecret, paymentMetadata_opt, customTlvs)) } - Recipient.buildPayloads(route.amount, expiry, finalPayload, route.hops) + Recipient.buildPayloads(route.amount, expiry, Seq(finalPayload), route.hops) }) } } @@ -111,12 +111,81 @@ case class SpontaneousRecipient(nodeId: PublicKey, override def buildPayloads(paymentHash: ByteVector32, route: Route): Either[OutgoingPaymentError, PaymentPayloads] = { ClearRecipient.validateRoute(nodeId, route).map(_ => { val finalPayload = NodePayload(nodeId, FinalPayload.Standard.createKeySendPayload(route.amount, totalAmount, expiry, preimage, customTlvs)) - Recipient.buildPayloads(totalAmount, expiry, finalPayload, route.hops) + Recipient.buildPayloads(totalAmount, expiry, Seq(finalPayload), route.hops) }) } } -/** A payment recipient that can be reached through a given trampoline node (usually not found in the routing graph). */ +/** A payment recipient that hides its real identity using route blinding. */ +case class BlindedRecipient(nodeId: PublicKey, + features: Features[InvoiceFeature], + totalAmount: MilliSatoshi, + expiry: CltvExpiry, + blindedHops: Seq[BlindedHop], + customTlvs: Seq[GenericTlv] = Nil) extends Recipient { + require(blindedHops.nonEmpty, "blinded routes must be provided") + + override val extraEdges = blindedHops.map { h => + ExtraEdge(h.route.introductionNodeId, nodeId, h.dummyId, h.paymentInfo.feeBase, h.paymentInfo.feeProportionalMillionths, h.paymentInfo.cltvExpiryDelta, h.paymentInfo.minHtlc, Some(h.paymentInfo.maxHtlc)) + } + + private def validateRoute(route: Route): Either[OutgoingPaymentError, BlindedHop] = { + route.finalHop_opt match { + case Some(blindedHop: BlindedHop) => Right(blindedHop) + case _ => Left(MissingBlindedHop(blindedHops.map(_.route.introductionNodeId).toSet)) + } + } + + private def buildBlindedPayloads(amount: MilliSatoshi, blindedHop: BlindedHop): PaymentPayloads = { + val blinding = blindedHop.route.introductionNode.blindingEphemeralKey + val payloads = if (blindedHop.route.subsequentNodes.isEmpty) { + // The recipient is also the introduction node. + Seq(NodePayload(blindedHop.route.introductionNodeId, OutgoingBlindedPerHopPayload.createFinalIntroductionPayload(amount, totalAmount, expiry, blinding, blindedHop.route.introductionNode.encryptedPayload, customTlvs))) + } else { + val introductionPayload = NodePayload(blindedHop.route.introductionNodeId, OutgoingBlindedPerHopPayload.createIntroductionPayload(blindedHop.route.introductionNode.encryptedPayload, blinding)) + val intermediatePayloads = blindedHop.route.subsequentNodes.dropRight(1).map(n => NodePayload(n.blindedPublicKey, OutgoingBlindedPerHopPayload.createIntermediatePayload(n.encryptedPayload))) + val finalPayload = NodePayload(blindedHop.route.blindedNodes.last.blindedPublicKey, OutgoingBlindedPerHopPayload.createFinalPayload(amount, totalAmount, expiry, blindedHop.route.blindedNodes.last.encryptedPayload, customTlvs)) + introductionPayload +: intermediatePayloads :+ finalPayload + } + val introductionAmount = amount + blindedHop.paymentInfo.fee(amount) + val introductionExpiry = expiry + blindedHop.paymentInfo.cltvExpiryDelta + PaymentPayloads(introductionAmount, introductionExpiry, payloads) + } + + override def buildPayloads(paymentHash: ByteVector32, route: Route): Either[OutgoingPaymentError, PaymentPayloads] = { + validateRoute(route).map(blindedHop => { + val blindedPayloads = buildBlindedPayloads(route.amount, blindedHop) + if (route.hops.isEmpty) { + // We are the introduction node of the blinded route. + blindedPayloads + } else { + Recipient.buildPayloads(blindedPayloads.amount, blindedPayloads.expiry, blindedPayloads.payloads, route.hops) + } + }) + } +} + +object BlindedRecipient { + def apply(invoice: Bolt12Invoice, totalAmount: MilliSatoshi, expiry: CltvExpiry, customTlvs: Seq[GenericTlv]): BlindedRecipient = { + val blindedHops = invoice.blindedPaths.zip(invoice.blindedPathsInfo).map { + case (route, info) => + // We don't know the scids of channels inside the blinded route, but it's useful to have an ID to refer to a + // given edge in the graph, so we create a dummy one for the duration of the payment attempt. + val dummyId = ShortChannelId.generateLocalAlias() + BlindedHop(dummyId, route, info) + } + BlindedRecipient(invoice.nodeId, invoice.features, totalAmount, expiry, blindedHops, customTlvs) + } +} + +/** + * A payment recipient that can be reached through a trampoline node (such recipients usually cannot be found in the + * public graph). Splitting a payment across multiple trampoline nodes is not supported yet, but can easily be added + * with a new field containing a bigger recipient total amount. + * + * Note that we don't need to support the case where we'd use multiple trampoline hops in the same route: since we have + * access to the network graph, it's always more efficient to find a channel route to the last trampoline node. + */ case class ClearTrampolineRecipient(invoice: Bolt11Invoice, totalAmount: MilliSatoshi, expiry: CltvExpiry, @@ -137,7 +206,7 @@ case class ClearTrampolineRecipient(invoice: Bolt11Invoice, private def validateRoute(route: Route): Either[OutgoingPaymentError, NodeHop] = { route.finalHop_opt match { case Some(trampolineHop: NodeHop) => Right(trampolineHop) - case None => Left(MissingTrampolineHop(trampolineNodeId)) + case _ => Left(MissingTrampolineHop(trampolineNodeId)) } } @@ -147,7 +216,7 @@ case class ClearTrampolineRecipient(invoice: Bolt11Invoice, trampolineOnion <- createTrampolinePacket(paymentHash, trampolineHop) } yield { val trampolinePayload = NodePayload(trampolineHop.nodeId, FinalPayload.Standard.createTrampolinePayload(route.amount, trampolineAmount, trampolineExpiry, trampolinePaymentSecret, trampolineOnion.packet)) - Recipient.buildPayloads(route.amount, trampolineExpiry, trampolinePayload, route.hops) + Recipient.buildPayloads(route.amount, trampolineExpiry, Seq(trampolinePayload), route.hops) } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/RouteCalculation.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/RouteCalculation.scala index b8de39d65a..1f91d50f34 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/router/RouteCalculation.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/RouteCalculation.scala @@ -22,7 +22,7 @@ import com.softwaremill.quicklens.ModifyPimp import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.eclair.Logs.LogCategory import fr.acinq.eclair._ -import fr.acinq.eclair.payment.send.{ClearRecipient, ClearTrampolineRecipient, Recipient, SpontaneousRecipient} +import fr.acinq.eclair.payment.send._ import fr.acinq.eclair.router.Graph.GraphStructure.DirectedGraph.graphEdgeToHop import fr.acinq.eclair.router.Graph.GraphStructure.{DirectedGraph, GraphEdge} import fr.acinq.eclair.router.Graph.{InfiniteLoop, NegativeProbability, RichWeight} @@ -134,6 +134,21 @@ object RouteCalculation { val amountToSend = recipient.totalAmount - pendingAmount val maxFee = totalMaxFee - pendingChannelFee (targetNodeId, amountToSend, maxFee, Set.empty) + case recipient: BlindedRecipient => + // Blinded routes all end at a different (blinded) node, so we create graph edges in which they lead to the same node. + val targetNodeId = randomKey().publicKey + val extraEdges = recipient.extraEdges + .map(_.copy(targetNodeId = targetNodeId)) + .filterNot(e => ignoredEdges.exists(_.shortChannelId == e.shortChannelId)) + // For blinded routes, the maximum htlc field is used to indicate the maximum amount that can be sent through the route. + .map(e => GraphEdge(e).copy(balance_opt = e.htlcMaximum_opt)) + .toSet + val amountToSend = recipient.totalAmount - pendingAmount + // When we are the introduction node and includeLocalChannelCost is false, we cannot easily remove the fee for + // the first hop in the blinded route (we would need to decrypt the route and fetch the corresponding channel). + // In that case, we will slightly over-estimate the fee we're paying, but at least we won't exceed our fee budget. + val maxFee = totalMaxFee - pendingChannelFee - r.pendingPayments.map(_.blindedFee).sum + (targetNodeId, amountToSend, maxFee, extraEdges) case recipient: ClearTrampolineRecipient => // Trampoline payments require finding routes to the trampoline node, not the final recipient. // This also ensures that we correctly take the trampoline fee into account only once, even when using MPP to @@ -146,11 +161,17 @@ object RouteCalculation { } private def addFinalHop(recipient: Recipient, routes: Seq[Route]): Seq[Route] = { - routes.map(route => { + routes.flatMap(route => { recipient match { - case _: ClearRecipient => route - case _: SpontaneousRecipient => route - case recipient: ClearTrampolineRecipient => route.copy(finalHop_opt = Some(recipient.trampolineHop)) + case _: ClearRecipient => Some(route) + case _: SpontaneousRecipient => Some(route) + case recipient: ClearTrampolineRecipient => Some(route.copy(finalHop_opt = Some(recipient.trampolineHop))) + case recipient: BlindedRecipient => + route.hops.lastOption.flatMap { + hop => recipient.blindedHops.find(_.dummyId == hop.shortChannelId) + }.map { + blindedHop => Route(route.amount, route.hops.dropRight(1), Some(blindedHop)) + } } }) } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala index 56d3970ab7..32c43707f3 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala @@ -28,6 +28,7 @@ import fr.acinq.eclair._ import fr.acinq.eclair.blockchain.bitcoind.ZmqWatcher import fr.acinq.eclair.blockchain.bitcoind.ZmqWatcher.{ValidateResult, WatchExternalChannelSpent, WatchExternalChannelSpentTriggered} import fr.acinq.eclair.channel._ +import fr.acinq.eclair.crypto.Sphinx.RouteBlinding.BlindedRoute import fr.acinq.eclair.crypto.TransportHandler import fr.acinq.eclair.db.NetworkDb import fr.acinq.eclair.io.Peer.PeerRoutingMessage @@ -497,6 +498,24 @@ object Router { sealed trait FinalHop extends Hop + /** + * A directed hop over a blinded route composed of multiple (blinded) channels. + * Since a blinded route has to be used from start to end, we model it as a single virtual hop. + * + * @param dummyId dummy identifier to allow indexing in maps: unlike normal scid aliases, this one doesn't exist + * in our routing tables and should be used carefully. + * @param route blinded route covered by that hop. + * @param paymentInfo payment information about the blinded route. + */ + case class BlindedHop(dummyId: Alias, route: BlindedRoute, paymentInfo: OfferTypes.PaymentInfo) extends FinalHop { + // @formatter:off + override val nodeId = route.introductionNodeId + override val nextNodeId = route.blindedNodes.last.blindedPublicKey + override val cltvExpiryDelta = paymentInfo.cltvExpiryDelta + override def fee(amount: MilliSatoshi): MilliSatoshi = paymentInfo.fee(amount) + // @formatter:on + } + /** * A directed hop between two trampoline nodes. * These nodes need not be connected and we don't need to know a route between them. @@ -568,7 +587,19 @@ object Router { */ val trampolineFee: MilliSatoshi = finalHop_opt.collect { case hop: NodeHop => hop.fee(amount) }.getOrElse(0 msat) + /** + * Fee paid for the blinded route, if any. + * Note that when we are the introduction node for the blinded route, we cannot easily compute the fee without the + * cost for the first local channel. + */ + val blindedFee: MilliSatoshi = finalHop_opt.collect { case hop: BlindedHop => hop.fee(amount) }.getOrElse(0 msat) + /** Fee paid for the channel hops towards the recipient or the source of the final hop, if any. */ + + /** + * Fee paid for the channel hops towards the recipient or the source of the final hop. + * Note that this doesn't include the fees for the final hop, if one exits. + */ def channelFee(includeLocalChannelCost: Boolean): MilliSatoshi = { val hopsToPay = if (includeLocalChannelCost) hops else hops.drop(1) val amountToSend = hopsToPay.foldRight(amount) { case (hop, amount1) => amount1 + hop.fee(amount1) } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/OfferCodecs.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/OfferCodecs.scala index 91314bd229..ee2b1b6207 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/OfferCodecs.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/OfferCodecs.scala @@ -39,9 +39,8 @@ object OfferCodecs { private val absoluteExpiry: Codec[AbsoluteExpiry] = tlvField(tu64overflow.as[TimestampSecond].as[AbsoluteExpiry]) private val blindedNodeCodec: Codec[BlindedNode] = (("nodeId" | publicKey) :: ("encryptedData" | variableSizeBytes(uint16, bytes))).as[BlindedNode] - - private val pathCodec: Codec[BlindedRoute] = (("firstNodeId" | publicKey) :: ("blinding" | publicKey) :: ("path" | listOfN(uint8, blindedNodeCodec).xmap[Seq[BlindedNode]](_.toSeq, _.toList))).as[BlindedRoute] - + private val blindedNodesCodec: Codec[Seq[BlindedNode]] = listOfN(uint8, blindedNodeCodec).xmap(_.toSeq, _.toList) + private val pathCodec: Codec[BlindedRoute] = (("firstNodeId" | publicKey) :: ("blinding" | publicKey) :: ("path" | blindedNodesCodec)).as[BlindedRoute] private val paths: Codec[Paths] = tlvField(list(pathCodec).xmap[Seq[BlindedRoute]](_.toSeq, _.toList).as[Paths]) private val issuer: Codec[Issuer] = tlvField(utf8.as[Issuer]) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/PaymentOnion.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/PaymentOnion.scala index eca19fb94a..ab26f44ed8 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/PaymentOnion.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/PaymentOnion.scala @@ -191,13 +191,13 @@ object PaymentOnion { import OnionPaymentPayloadTlv._ /* - * PerHopPayload - * | - * | - * +--------------+---------------+ - * | | - * | | - * IntermediatePayload FinalPayload + * PerHopPayload + * | + * | + * +------------------------------+-----------------------------+ + * | | | + * | | | + * IntermediatePayload FinalPayload OutgoingBlindedPerHopPayload * | | * | | * +---------+---------+ +------+------+ @@ -437,6 +437,33 @@ object PaymentOnion { } } + /** + * An opaque blinded payload (used when sending to a blinded route, never used to decode incoming payloads). + * We cannot use the other payload types because we cannot decrypt the recipient encrypted data, so we don't even + * know if those payloads are valid. + */ + case class OutgoingBlindedPerHopPayload(records: TlvStream[OnionPaymentPayloadTlv]) extends PerHopPayload { + require(records.get[EncryptedRecipientData].nonEmpty, "blinded per-hop payload must contain encrypted data") + } + + object OutgoingBlindedPerHopPayload { + def createIntroductionPayload(encryptedRecipientData: ByteVector, blinding: PublicKey): OutgoingBlindedPerHopPayload = { + OutgoingBlindedPerHopPayload(TlvStream(Seq(EncryptedRecipientData(encryptedRecipientData), BlindingPoint(blinding)))) + } + + def createIntermediatePayload(encryptedRecipientData: ByteVector): OutgoingBlindedPerHopPayload = { + OutgoingBlindedPerHopPayload(TlvStream(Seq(EncryptedRecipientData(encryptedRecipientData)))) + } + + def createFinalPayload(amount: MilliSatoshi, totalAmount: MilliSatoshi, expiry: CltvExpiry, encryptedRecipientData: ByteVector, customTlvs: Seq[GenericTlv] = Nil): OutgoingBlindedPerHopPayload = { + OutgoingBlindedPerHopPayload(TlvStream(Seq(AmountToForward(amount), TotalAmount(totalAmount), OutgoingCltv(expiry), EncryptedRecipientData(encryptedRecipientData)), customTlvs)) + } + + def createFinalIntroductionPayload(amount: MilliSatoshi, totalAmount: MilliSatoshi, expiry: CltvExpiry, blinding: PublicKey, encryptedRecipientData: ByteVector, customTlvs: Seq[GenericTlv] = Nil): OutgoingBlindedPerHopPayload = { + OutgoingBlindedPerHopPayload(TlvStream(Seq(AmountToForward(amount), TotalAmount(totalAmount), OutgoingCltv(expiry), EncryptedRecipientData(encryptedRecipientData), BlindingPoint(blinding)), customTlvs)) + } + } + } object PaymentOnionCodecs { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/RouteBlinding.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/RouteBlinding.scala index cbe392ef63..7fa8c9ed0e 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/RouteBlinding.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/RouteBlinding.scala @@ -134,7 +134,7 @@ object RouteBlindingEncryptedDataCodecs { // @formatter:off case class RouteBlindingDecryptedData(tlvs: TlvStream[RouteBlindingEncryptedDataTlv], nextBlinding: PublicKey) - sealed trait InvalidEncryptedData + sealed trait InvalidEncryptedData { def message: String } case class CannotDecryptData(message: String) extends InvalidEncryptedData case class CannotDecodeData(message: String) extends InvalidEncryptedData // @formatter:on diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/channel/FuzzySpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/channel/FuzzySpec.scala index d06c59de45..c805309591 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/channel/FuzzySpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/channel/FuzzySpec.scala @@ -121,7 +121,7 @@ class FuzzySpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with Channe // allow overpaying (no more than 2 times the required amount) val amount = requiredAmount + Random.nextInt(requiredAmount.toLong.toInt).msat val expiry = (Channel.MIN_CLTV_EXPIRY_DELTA + 1).toCltvExpiry(currentBlockHeight = BlockHeight(400000)) - val Right(payment) = OutgoingPaymentPacket.buildOutgoingPayment(self, Upstream.Local(UUID.randomUUID()), invoice.paymentHash, makeSingleHopRoute(amount, invoice.nodeId), ClearRecipient(invoice, amount, expiry, Nil)) + val Right(payment) = OutgoingPaymentPacket.buildOutgoingPayment(self, randomKey(), Upstream.Local(UUID.randomUUID()), invoice.paymentHash, makeSingleHopRoute(amount, invoice.nodeId), ClearRecipient(invoice, amount, expiry, Nil)) payment.cmd } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/ChannelStateTestsHelperMethods.scala b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/ChannelStateTestsHelperMethods.scala index 10bcc52e1d..8266c0d2f4 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/ChannelStateTestsHelperMethods.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/ChannelStateTestsHelperMethods.scala @@ -359,7 +359,7 @@ trait ChannelStateTestsBase extends Assertions with Eventually { val paymentHash = Crypto.sha256(paymentPreimage) val expiry = cltvExpiryDelta.toCltvExpiry(currentBlockHeight) val recipient = SpontaneousRecipient(destination, amount, expiry, paymentPreimage) - val Right(payment) = OutgoingPaymentPacket.buildOutgoingPayment(replyTo, upstream, paymentHash, makeSingleHopRoute(amount, destination), recipient) + val Right(payment) = OutgoingPaymentPacket.buildOutgoingPayment(replyTo, randomKey(), upstream, paymentHash, makeSingleHopRoute(amount, destination), recipient) (paymentPreimage, payment.cmd.copy(commit = false)) } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/f/ShutdownStateSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/f/ShutdownStateSpec.scala index b473943559..b9f06744b5 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/f/ShutdownStateSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/f/ShutdownStateSpec.scala @@ -59,7 +59,7 @@ class ShutdownStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike wit // alice sends an HTLC to bob val h1 = Crypto.sha256(r1) val recipient1 = SpontaneousRecipient(TestConstants.Bob.nodeParams.nodeId, 300_000_000 msat, CltvExpiryDelta(144).toCltvExpiry(currentBlockHeight), r1) - val Right(cmd1) = OutgoingPaymentPacket.buildOutgoingPayment(sender.ref, Upstream.Local(UUID.randomUUID), h1, makeSingleHopRoute(recipient1.totalAmount, recipient1.nodeId), recipient1).map(_.cmd.copy(commit = false)) + val Right(cmd1) = OutgoingPaymentPacket.buildOutgoingPayment(sender.ref, TestConstants.Alice.nodeParams.privateKey, Upstream.Local(UUID.randomUUID), h1, makeSingleHopRoute(recipient1.totalAmount, recipient1.nodeId), recipient1).map(_.cmd.copy(commit = false)) alice ! cmd1 sender.expectMsgType[RES_SUCCESS[CMD_ADD_HTLC]] val htlc1 = alice2bob.expectMsgType[UpdateAddHtlc] @@ -68,7 +68,7 @@ class ShutdownStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike wit // alice sends another HTLC to bob val h2 = Crypto.sha256(r2) val recipient2 = SpontaneousRecipient(TestConstants.Bob.nodeParams.nodeId, 200_000_000 msat, CltvExpiryDelta(144).toCltvExpiry(currentBlockHeight), r2) - val Right(cmd2) = OutgoingPaymentPacket.buildOutgoingPayment(sender.ref, Upstream.Local(UUID.randomUUID), h2, makeSingleHopRoute(recipient2.totalAmount, recipient2.nodeId), recipient2).map(_.cmd.copy(commit = false)) + val Right(cmd2) = OutgoingPaymentPacket.buildOutgoingPayment(sender.ref, TestConstants.Alice.nodeParams.privateKey, Upstream.Local(UUID.randomUUID), h2, makeSingleHopRoute(recipient2.totalAmount, recipient2.nodeId), recipient2).map(_.cmd.copy(commit = false)) alice ! cmd2 sender.expectMsgType[RES_SUCCESS[CMD_ADD_HTLC]] val htlc2 = alice2bob.expectMsgType[UpdateAddHtlc] diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/json/JsonSerializersSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/json/JsonSerializersSpec.scala index ed6eb2c56c..39d2ef757d 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/json/JsonSerializersSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/json/JsonSerializersSpec.scala @@ -191,9 +191,9 @@ class JsonSerializersSpec extends AnyFunSuite with Matchers { } test("Bolt 12 invoice") { - val ref = "lni1qvsyxjtl6luzd9t3pr62xr7eemp6awnejusgf6gw45q75vcfqqqqqqqyyz9ut9uduhtztjgpxm06394g5qkw7v79g4czw6zxsl3lnrsvljj0qzqrq83yqzscd9h8vmmfvdjjqamfw35zqmtpdeujqenfv4kxgucvqsqsqqgqzzsq83r88cmqur4u9xxgmykeuhmf4stsrsl8h9keu8q9pf5wsq69zaauq2vu8z4tg6zltltxekud2a59r9l8j7uy4ydnfe752tqz7a8g03fcvqgzw2q6uut0ldazm8squqsmq3gls2ut7gzcpfty8zh0d5vmf30dlyssqwh4qdsxe8n3d9d38e8g85r966yjqr7xuljey2yxar05x3mx46qz8qhef7n5hg0fxx7phe7jtag5evuj28rsas80gglufwf3y8qqqqyjjqqqqt7sz3qqqqqqqqqqq05qqqqqqqqqrcjqqqqpgptpd35kxeg7yypugee7xc8qa0pf3jxe9k0976dvzuqu8eaedk0pcpg2dr5qx3gh00pqqyujvgy04twhrv0h3vtzvhjmqcde62ug3ygp9hr66wrzdm425238zc26v5nswjf8d5syymmz9qzx9gqxhc4zq5v4r4x98jgyqd0sk2fae803crnevusngv9wq7jl8cf5e5eny56p9spquypwqgq8kvq5quzqqpqj84t0szcxqqywkwkudz29aaspxgx2n6mw2fh2ckwdnwylkgpcypc66qe7tapqdq39vzrhp7nkwfg9gj2ztk658g0ecgalqdjh4gmuruzqqsv99asjssw802xu2ls93fzzl64c5jehhxvp0m7au9klpg4pnedtwl0zgw648gwtalak3hr5uxxl5qzp7txlud9y7qfk038cq4pusj8aqymsxqgzq07szwgq" + val ref = "lni1qvsxlc5vp2m0rvmjcxn2y34wv0m5lyc7sdj7zksgn35dvxgqqqqqqqqyypmpsc7ww3cxguwl27ela95ykset7t8tlvyfy7a200eujcnhczws6zqyrvhqd5s2p4kkjmnfd4skcgr0venx2ussmvpexwyy4tcadvgg89l9aljus6709kx235hhqrk6n8dey98uyuftzdqrshx2mcmnvj7gxa709vhgcrqr7hcdp7l7x9t2au8dj8tjreqyfvuqyqkp4czrrxpn3hdrdqu8k3teynrl4nf977deq2ja53zkefmsr0nyjgqp3dtytq67durd2jjupmnjdtvvlsuw7lsm6tvcyrrqx7pwaunazjuqn2uk9gvpzj0aeagku2h2wv8vcfmfekflxmfmgu0kqqa94fuhuhyn6jxefk7k5rfgejltw7cwa4fhjl67trweukvsw4wqq7ec790vdjar7qwtsnlfrq84fmq02ah49fvnnsj06j5wzgwqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqm9crdyqqqrcss83y2e9lqnu7tht4ntvp24fksw26hwf5yrg6dyk2jz472efs2rjh4ycsra98j4l2k35fg7qhvapz26js2rh0j5n36pzlt9kaprvl3zd9s29egq33khv3m9gszev88kpfrveu8g5xr8khk6tev8jmpxg3pxfhpcx6f4jtlm4ltwgpwqgqpduzqyec5qspkg0zqq788krw2w2kstvsz3dekms304ykkh395zl6chm34vdu03yuvgwzm0580zu6sp2f07uwa4crgvkgucd8zdpt5vu302nc" val pr = Invoice.fromString(ref).get - JsonSerializers.serialization.write(pr)(JsonSerializers.formats) shouldBe """{"amount":123456,"nodeId":"03c4673e360e0ebc298c8d92d9e5f69ac1701c3e7b96d9e1c050a68e80345177bc","paymentHash":"51951d4c53c904035f0b293dc9df1c0e7967213430ae07a5f3e134cd33325341","description":"invoice with many fields","features":{"activated":{"var_onion_optin":"mandatory","option_route_blinding":"mandatory"},"unknown":[]},"blindedPaths":[{"introductionNodeId":"03c4673e360e0ebc298c8d92d9e5f69ac1701c3e7b96d9e1c050a68e80345177bc","blindedNodeIds":["027281ae716ffb7a2d9e00e021b0451f82b8bf20580a56438aef6d19b4c5edf921"]}],"createdAt":1654654654,"expiresAt":1654658254,"serialized":"lni1qvsyxjtl6luzd9t3pr62xr7eemp6awnejusgf6gw45q75vcfqqqqqqqyyz9ut9uduhtztjgpxm06394g5qkw7v79g4czw6zxsl3lnrsvljj0qzqrq83yqzscd9h8vmmfvdjjqamfw35zqmtpdeujqenfv4kxgucvqsqsqqgqzzsq83r88cmqur4u9xxgmykeuhmf4stsrsl8h9keu8q9pf5wsq69zaauq2vu8z4tg6zltltxekud2a59r9l8j7uy4ydnfe752tqz7a8g03fcvqgzw2q6uut0ldazm8squqsmq3gls2ut7gzcpfty8zh0d5vmf30dlyssqwh4qdsxe8n3d9d38e8g85r966yjqr7xuljey2yxar05x3mx46qz8qhef7n5hg0fxx7phe7jtag5evuj28rsas80gglufwf3y8qqqqyjjqqqqt7sz3qqqqqqqqqqq05qqqqqqqqqrcjqqqqpgptpd35kxeg7yypugee7xc8qa0pf3jxe9k0976dvzuqu8eaedk0pcpg2dr5qx3gh00pqqyujvgy04twhrv0h3vtzvhjmqcde62ug3ygp9hr66wrzdm425238zc26v5nswjf8d5syymmz9qzx9gqxhc4zq5v4r4x98jgyqd0sk2fae803crnevusngv9wq7jl8cf5e5eny56p9spquypwqgq8kvq5quzqqpqj84t0szcxqqywkwkudz29aaspxgx2n6mw2fh2ckwdnwylkgpcypc66qe7tapqdq39vzrhp7nkwfg9gj2ztk658g0ecgalqdjh4gmuruzqqsv99asjssw802xu2ls93fzzl64c5jehhxvp0m7au9klpg4pnedtwl0zgw648gwtalak3hr5uxxl5qzp7txlud9y7qfk038cq4pusj8aqymsxqgzq07szwgq"}""" + JsonSerializers.serialization.write(pr)(JsonSerializers.formats) shouldBe """{"amount":456001234,"nodeId":"03c48ac97e09f3cbbaeb35b02aaa6d072b57726841a34d25952157caca60a1caf5","paymentHash":"2cb0e7b052366787450c33daf6d2f2c3cb6132221326e1c1b49ac97fdd7eb720","description":"minimal offer","features":{"activated":{},"unknown":[]},"blindedPaths":[{"introductionNodeId":"03933884aaf1d6b108397e5efe5c86bcf2d8ca8d2f700eda99db9214fc2712b134","blindedNodeIds":["02c1ae043198338dda368387b457924c7facd25f79b902a5da4456ca7701be6492","03782eef27d14b809ab962a181149fdcf516e2aea730ecc2769cd93f36d3b471f6"]}],"createdAt":1668002363,"expiresAt":1668009563,"serialized":"lni1qvsxlc5vp2m0rvmjcxn2y34wv0m5lyc7sdj7zksgn35dvxgqqqqqqqqyypmpsc7ww3cxguwl27ela95ykset7t8tlvyfy7a200eujcnhczws6zqyrvhqd5s2p4kkjmnfd4skcgr0venx2ussmvpexwyy4tcadvgg89l9aljus6709kx235hhqrk6n8dey98uyuftzdqrshx2mcmnvj7gxa709vhgcrqr7hcdp7l7x9t2au8dj8tjreqyfvuqyqkp4czrrxpn3hdrdqu8k3teynrl4nf977deq2ja53zkefmsr0nyjgqp3dtytq67durd2jjupmnjdtvvlsuw7lsm6tvcyrrqx7pwaunazjuqn2uk9gvpzj0aeagku2h2wv8vcfmfekflxmfmgu0kqqa94fuhuhyn6jxefk7k5rfgejltw7cwa4fhjl67trweukvsw4wqq7ec790vdjar7qwtsnlfrq84fmq02ah49fvnnsj06j5wzgwqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqm9crdyqqqrcss83y2e9lqnu7tht4ntvp24fksw26hwf5yrg6dyk2jz472efs2rjh4ycsra98j4l2k35fg7qhvapz26js2rh0j5n36pzlt9kaprvl3zd9s29egq33khv3m9gszev88kpfrveu8g5xr8khk6tev8jmpxg3pxfhpcx6f4jtlm4ltwgpwqgqpduzqyec5qspkg0zqq788krw2w2kstvsz3dekms304ykkh395zl6chm34vdu03yuvgwzm0580zu6sp2f07uwa4crgvkgucd8zdpt5vu302nc"}""" } test("GlobalBalance serializer") { diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/message/OnionMessagesSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/message/OnionMessagesSpec.scala index d19ef2b605..030349f3ee 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/message/OnionMessagesSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/message/OnionMessagesSpec.scala @@ -225,13 +225,9 @@ class OnionMessagesSpec extends AnyFunSuite { val carol = randomKey() val sessionKey = randomKey() val blindingSecret = randomKey() - val pathId = randomBytes(65201) - val Success((_, messageForAlice)) = buildMessage(sessionKey, blindingSecret, IntermediateNode(alice.publicKey) :: IntermediateNode(bob.publicKey) :: Nil, Recipient(carol.publicKey, Some(pathId)), Nil) - println(messageForAlice.onionRoutingPacket.payload.length) - // Checking that the onion is relayed properly process(alice, messageForAlice) match { case SendMessage(nextNodeId, onionForBob) => diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/Bolt12InvoiceSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/Bolt12InvoiceSpec.scala index b7c811d4b7..3b12de42e8 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/Bolt12InvoiceSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/Bolt12InvoiceSpec.scala @@ -22,13 +22,12 @@ import fr.acinq.bitcoin.scalacompat.{Block, ByteVector32, Crypto} import fr.acinq.eclair.FeatureSupport.{Mandatory, Optional} import fr.acinq.eclair.Features.{BasicMultiPartPayment, VariableLengthOnion} import fr.acinq.eclair.crypto.Sphinx -import fr.acinq.eclair.crypto.Sphinx.RouteBlinding.BlindedRoute import fr.acinq.eclair.payment.Bolt12Invoice.{hrp, signatureTag} import fr.acinq.eclair.wire.protocol.OfferCodecs.{invoiceRequestTlvCodec, invoiceTlvCodec} import fr.acinq.eclair.wire.protocol.OfferTypes._ import fr.acinq.eclair.wire.protocol.RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec -import fr.acinq.eclair.wire.protocol.RouteBlindingEncryptedDataTlv.{AllowedFeatures, PaymentConstraints} -import fr.acinq.eclair.wire.protocol.{GenericTlv, OfferTypes, RouteBlindingEncryptedDataTlv, TlvStream} +import fr.acinq.eclair.wire.protocol.RouteBlindingEncryptedDataTlv.{AllowedFeatures, PathId, PaymentConstraints} +import fr.acinq.eclair.wire.protocol.{GenericTlv, OfferTypes, TlvStream} import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Feature, FeatureSupport, Features, MilliSatoshiLong, TimestampSecond, TimestampSecondLong, UInt64, randomBytes32, randomBytes64, randomKey} import org.scalatest.funsuite.AnyFunSuite import scodec.bits._ @@ -51,7 +50,7 @@ class Bolt12InvoiceSpec extends AnyFunSuite { } def createPaymentBlindedRoute(nodeId: PublicKey, sessionKey: PrivateKey = randomKey(), pathId: ByteVector = randomBytes32()): PaymentBlindedRoute = { - val selfPayload = blindedRouteDataCodec.encode(TlvStream(Seq(RouteBlindingEncryptedDataTlv.PathId(pathId), PaymentConstraints(CltvExpiry(1234567), 0 msat), AllowedFeatures(Features.empty)))).require.bytes + val selfPayload = blindedRouteDataCodec.encode(TlvStream(Seq(PathId(pathId), PaymentConstraints(CltvExpiry(1234567), 0 msat), AllowedFeatures(Features.empty)))).require.bytes PaymentBlindedRoute(Sphinx.RouteBlinding.create(sessionKey, Seq(nodeId), Seq(selfPayload)).route, PaymentInfo(1 msat, 2, CltvExpiryDelta(3), 4 msat, 5 msat, Features.empty)) } @@ -274,18 +273,18 @@ class Bolt12InvoiceSpec extends AnyFunSuite { PaymentHash(randomBytes32()), ) // This minimal invoice is valid. - val signed = signInvoiceTlvs(TlvStream[InvoiceTlv](tlvs), nodeKey) + val signed = signInvoiceTlvs(TlvStream(tlvs), nodeKey) val signedEncoded = Bech32.encodeBytes(hrp, invoiceTlvCodec.encode(signed).require.bytes.toArray, Bech32.Encoding.Beck32WithoutChecksum) assert(Bolt12Invoice.fromString(signedEncoded).isSuccess) // But removing any TLV makes it invalid. for (tlv <- tlvs) { val incomplete = tlvs.filterNot(_ == tlv) - val incompleteSigned = signInvoiceTlvs(TlvStream[InvoiceTlv](incomplete), nodeKey) + val incompleteSigned = signInvoiceTlvs(TlvStream(incomplete), nodeKey) val incompleteSignedEncoded = Bech32.encodeBytes(hrp, invoiceTlvCodec.encode(incompleteSigned).require.bytes.toArray, Bech32.Encoding.Beck32WithoutChecksum) assert(Bolt12Invoice.fromString(incompleteSignedEncoded).isFailure) } // Missing signature is also invalid. - val unsignedEncoded = Bech32.encodeBytes(hrp, invoiceTlvCodec.encode(TlvStream[InvoiceTlv](tlvs)).require.bytes.toArray, Bech32.Encoding.Beck32WithoutChecksum) + val unsignedEncoded = Bech32.encodeBytes(hrp, invoiceTlvCodec.encode(TlvStream(tlvs)).require.bytes.toArray, Bech32.Encoding.Beck32WithoutChecksum) assert(Bolt12Invoice.fromString(unsignedEncoded).isFailure) } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartPaymentLifecycleSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartPaymentLifecycleSpec.scala index 1ea2063575..ae9707958f 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartPaymentLifecycleSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartPaymentLifecycleSpec.scala @@ -32,7 +32,7 @@ import fr.acinq.eclair.payment.send.PaymentError.RetryExhausted import fr.acinq.eclair.payment.send.PaymentInitiator.SendPaymentConfig import fr.acinq.eclair.payment.send.PaymentLifecycle.SendPaymentToRoute import fr.acinq.eclair.payment.send._ -import fr.acinq.eclair.router.BaseRouterSpec.channelHopFromUpdate +import fr.acinq.eclair.router.BaseRouterSpec.{blindedRouteFromHops, channelHopFromUpdate} import fr.acinq.eclair.router.Graph.WeightRatios import fr.acinq.eclair.router.Router._ import fr.acinq.eclair.router.{Announcements, RouteNotFound} @@ -178,6 +178,33 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS metricsListener.expectNoMessage() } + test("successful first attempt (blinded)") { f => + import f._ + + assert(payFsm.stateName == WAIT_FOR_PAYMENT_REQUEST) + val (_, hop_be, recipient) = blindedRouteFromHops(finalAmount, expiry, Seq(channelHopFromUpdate(b, e, channelUpdate_be)), blindedRouteExpiry, paymentPreimage) + val payment = SendMultiPartPayment(sender.ref, recipient, 1, routeParams) + sender.send(payFsm, payment) + + router.expectMsg(RouteRequest(nodeParams.nodeId, recipient, routeParams.copy(randomize = false), allowMultiPart = true, paymentContext = Some(cfg.paymentContext))) + assert(payFsm.stateName == WAIT_FOR_ROUTES) + + val routes = Seq( + Route(600_000 msat, Seq(hop_ab_1), Some(hop_be)), + Route(400_000 msat, Seq(hop_ab_2), Some(hop_be)), + ) + router.send(payFsm, RouteResponse(routes)) + val childPayments = childPayFsm.expectMsgType[SendPaymentToRoute] :: childPayFsm.expectMsgType[SendPaymentToRoute] :: Nil + assert(childPayments.map(_.route).toSet == routes.map(r => Right(r)).toSet) + childPayments.foreach(childPayment => assert(childPayment.recipient == recipient)) + assert(childPayments.map(_.amount).toSet == Set(400_000 msat, 600_000 msat)) + assert(payFsm.stateName == PAYMENT_IN_PROGRESS) + + val result = fulfillPendingPayments(f, 2, recipient.nodeId, finalAmount) + assert(result.amountWithFees == 1_000_200.msat) + assert(result.nonTrampolineFees == 200.msat) + } + test("successful retry") { f => import f._ @@ -190,7 +217,7 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS childPayFsm.expectNoMessage(100 millis) val childId = payFsm.stateData.asInstanceOf[PaymentProgress].pending.keys.head - childPayFsm.send(payFsm, PaymentFailed(childId, paymentHash, Seq(RemoteFailure(failingRoute.amount, failingRoute.hops, Sphinx.DecryptedFailurePacket(b, PermanentChannelFailure))))) + childPayFsm.send(payFsm, PaymentFailed(childId, paymentHash, Seq(RemoteFailure(failingRoute.amount, failingRoute.fullRoute, Sphinx.DecryptedFailurePacket(b, PermanentChannelFailure))))) // We retry ignoring the failing channel. router.expectMsg(RouteRequest(nodeParams.nodeId, clearRecipient, routeParams.copy(randomize = true), allowMultiPart = true, ignore = Ignore(Set.empty, Set(ChannelDesc(channelId_be, b, e))), paymentContext = Some(cfg.paymentContext))) router.send(payFsm, RouteResponse(Seq(Route(400_000 msat, hop_ac_1 :: hop_ce :: Nil, None), Route(600_000 msat, hop_ad :: hop_de :: Nil, None)))) @@ -264,7 +291,7 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS childPayFsm.expectNoMessage(100 millis) val (failedId, failedRoute) :: Nil = payFsm.stateData.asInstanceOf[PaymentProgress].pending.toSeq - childPayFsm.send(payFsm, PaymentFailed(failedId, paymentHash, Seq(LocalFailure(failedRoute.amount, failedRoute.hops, RemoteCannotAffordFeesForNewHtlc(randomBytes32(), finalAmount, 15 sat, 0 sat, 15 sat))))) + childPayFsm.send(payFsm, PaymentFailed(failedId, paymentHash, Seq(LocalFailure(failedRoute.amount, failedRoute.fullRoute, RemoteCannotAffordFeesForNewHtlc(randomBytes32(), finalAmount, 15 sat, 0 sat, 15 sat))))) // We retry without the failing channel. router.expectMsg(RouteRequest( @@ -290,7 +317,7 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS childPayFsm.expectNoMessage(100 millis) val (failedId, failedRoute) :: (_, pendingRoute) :: Nil = payFsm.stateData.asInstanceOf[PaymentProgress].pending.toSeq - childPayFsm.send(payFsm, PaymentFailed(failedId, paymentHash, Seq(LocalFailure(failedRoute.amount, failedRoute.hops, ChannelUnavailable(randomBytes32()))))) + childPayFsm.send(payFsm, PaymentFailed(failedId, paymentHash, Seq(LocalFailure(failedRoute.amount, failedRoute.fullRoute, ChannelUnavailable(randomBytes32()))))) // If the router doesn't find routes, we will retry without ignoring the channel: it may work with a different split // of the amount to send. @@ -337,7 +364,7 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS // B changed his fees and expiry after the invoice was issued. val channelUpdate = channelUpdate_be.copy(feeBaseMsat = 250 msat, feeProportionalMillionths = 150, cltvExpiryDelta = CltvExpiryDelta(24)) val childId = payFsm.stateData.asInstanceOf[PaymentProgress].pending.keys.head - childPayFsm.send(payFsm, PaymentFailed(childId, paymentHash, Seq(RemoteFailure(route.amount, route.hops, Sphinx.DecryptedFailurePacket(b, FeeInsufficient(finalAmount, channelUpdate)))))) + childPayFsm.send(payFsm, PaymentFailed(childId, paymentHash, Seq(RemoteFailure(route.amount, route.fullRoute, Sphinx.DecryptedFailurePacket(b, FeeInsufficient(finalAmount, channelUpdate)))))) // We update the routing hints accordingly before requesting a new route. val extraEdge1 = extraEdge.copy(feeBase = 250 msat, feeProportionalMillionths = 150, cltvExpiryDelta = CltvExpiryDelta(24)) assert(router.expectMsgType[RouteRequest].target.extraEdges == Seq(extraEdge1)) @@ -361,13 +388,33 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS // NB: we need a channel update with a valid signature, otherwise we'll ignore the node instead of this specific channel. val channelUpdate = Announcements.makeChannelUpdate(channelUpdate_be.chainHash, priv_b, e, channelUpdate_be.shortChannelId, channelUpdate_be.cltvExpiryDelta, channelUpdate_be.htlcMinimumMsat, channelUpdate_be.feeBaseMsat, channelUpdate_be.feeProportionalMillionths, channelUpdate_be.htlcMaximumMsat) val childId = payFsm.stateData.asInstanceOf[PaymentProgress].pending.keys.head - childPayFsm.send(payFsm, PaymentFailed(childId, paymentHash, Seq(RemoteFailure(route.amount, route.hops, Sphinx.DecryptedFailurePacket(b, TemporaryChannelFailure(channelUpdate)))))) + childPayFsm.send(payFsm, PaymentFailed(childId, paymentHash, Seq(RemoteFailure(route.amount, route.fullRoute, Sphinx.DecryptedFailurePacket(b, TemporaryChannelFailure(channelUpdate)))))) // We update the routing hints accordingly before requesting a new route and ignore the channel. val routeRequest = router.expectMsgType[RouteRequest] assert(routeRequest.target.extraEdges == Seq(extraEdge)) assert(routeRequest.ignore.channels.map(_.shortChannelId) == Set(channelUpdate.shortChannelId)) } + test("retry with ignored blinded route") { f => + import f._ + + val (_, hop_be, recipient) = blindedRouteFromHops(finalAmount, expiry, Seq(channelHopFromUpdate(b, e, channelUpdate_be)), blindedRouteExpiry, paymentPreimage) + val payment = SendMultiPartPayment(sender.ref, recipient, 3, routeParams) + sender.send(payFsm, payment) + assert(router.expectMsgType[RouteRequest].target.extraEdges == recipient.extraEdges) + val route = Route(finalAmount, Seq(hop_ab_1), Some(hop_be)) + router.send(payFsm, RouteResponse(Seq(route))) + childPayFsm.expectMsgType[SendPaymentToRoute] + childPayFsm.expectNoMessage(100 millis) + + // The blinded route fails to relay the payment. + val childId = payFsm.stateData.asInstanceOf[PaymentProgress].pending.keys.head + childPayFsm.send(payFsm, PaymentFailed(childId, paymentHash, Seq(RemoteFailure(route.amount, route.fullRoute, Sphinx.DecryptedFailurePacket(b, InvalidOnionBlinding(randomBytes32())))))) + // We retry and ignore that blinded route. + val routeRequest = router.expectMsgType[RouteRequest] + assert(routeRequest.ignore.channels.map(_.shortChannelId) == Set(hop_be.dummyId)) + } + test("update routing hints") { () => val recipient = ClearRecipient(e, Features.empty, finalAmount, expiry, randomBytes32(), Seq( ExtraEdge(a, b, ShortChannelId(1), 10 msat, 0, CltvExpiryDelta(12), 1 msat, None), @@ -479,7 +526,7 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS childPayFsm.expectMsgType[SendPaymentToRoute] val (failedId, failedRoute) = payFsm.stateData.asInstanceOf[PaymentProgress].pending.head - val result = abortAfterFailure(f, PaymentFailed(failedId, paymentHash, Seq(RemoteFailure(failedRoute.amount, failedRoute.hops, Sphinx.DecryptedFailurePacket(e, IncorrectOrUnknownPaymentDetails(600_000 msat, BlockHeight(0))))))) + val result = abortAfterFailure(f, PaymentFailed(failedId, paymentHash, Seq(RemoteFailure(failedRoute.amount, failedRoute.fullRoute, Sphinx.DecryptedFailurePacket(e, IncorrectOrUnknownPaymentDetails(600_000 msat, BlockHeight(0))))))) assert(result.failures.length == 1) val metrics = metricsListener.expectMsgType[PathFindingExperimentMetrics] @@ -500,7 +547,7 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS childPayFsm.expectMsgType[SendPaymentToRoute] val (failedId, failedRoute) = payFsm.stateData.asInstanceOf[PaymentProgress].pending.head - val result = abortAfterFailure(f, PaymentFailed(failedId, paymentHash, Seq(LocalFailure(failedRoute.amount, failedRoute.hops, HtlcsTimedoutDownstream(channelId = ByteVector32.One, htlcs = Set.empty))))) + val result = abortAfterFailure(f, PaymentFailed(failedId, paymentHash, Seq(LocalFailure(failedRoute.amount, failedRoute.fullRoute, HtlcsTimedoutDownstream(channelId = ByteVector32.One, htlcs = Set.empty))))) assert(result.failures.length == 1) } @@ -533,7 +580,7 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS childPayFsm.expectMsgType[SendPaymentToRoute] val (failedId, failedRoute) :: (successId, successRoute) :: Nil = payFsm.stateData.asInstanceOf[PaymentProgress].pending.toSeq - childPayFsm.send(payFsm, PaymentFailed(failedId, paymentHash, Seq(UnreadableRemoteFailure(failedRoute.amount, failedRoute.hops)))) + childPayFsm.send(payFsm, PaymentFailed(failedId, paymentHash, Seq(UnreadableRemoteFailure(failedRoute.amount, failedRoute.fullRoute)))) router.expectMsgType[RouteRequest] val result = fulfillPendingPayments(f, 1, e, finalAmount) @@ -553,11 +600,11 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS childPayFsm.expectMsgType[SendPaymentToRoute] val (failedId, failedRoute) :: (successId, successRoute) :: Nil = payFsm.stateData.asInstanceOf[PaymentProgress].pending.toSeq - childPayFsm.send(payFsm, PaymentFailed(failedId, paymentHash, Seq(RemoteFailure(failedRoute.amount, failedRoute.hops, Sphinx.DecryptedFailurePacket(e, PaymentTimeout))))) + childPayFsm.send(payFsm, PaymentFailed(failedId, paymentHash, Seq(RemoteFailure(failedRoute.amount, failedRoute.fullRoute, Sphinx.DecryptedFailurePacket(e, PaymentTimeout))))) awaitCond(payFsm.stateName == PAYMENT_ABORTED) sender.watch(payFsm) - childPayFsm.send(payFsm, PaymentSent(cfg.id, paymentHash, paymentPreimage, finalAmount, e, Seq(PaymentSent.PartialPayment(successId, successRoute.amount, successRoute.channelFee(false), randomBytes32(), Some(successRoute.hops))))) + childPayFsm.send(payFsm, PaymentSent(cfg.id, paymentHash, paymentPreimage, finalAmount, e, Seq(PaymentSent.PartialPayment(successId, successRoute.amount, successRoute.channelFee(false), randomBytes32(), Some(successRoute.fullRoute))))) sender.expectMsg(PreimageReceived(paymentHash, paymentPreimage)) val result = sender.expectMsgType[PaymentSent] assert(result.id == cfg.id) @@ -586,12 +633,12 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS childPayFsm.expectMsgType[SendPaymentToRoute] val (childId, route) :: (failedId, failedRoute) :: Nil = payFsm.stateData.asInstanceOf[PaymentProgress].pending.toSeq - childPayFsm.send(payFsm, PaymentSent(cfg.id, paymentHash, paymentPreimage, finalAmount, e, Seq(PaymentSent.PartialPayment(childId, route.amount, route.channelFee(false), randomBytes32(), Some(route.hops))))) + childPayFsm.send(payFsm, PaymentSent(cfg.id, paymentHash, paymentPreimage, finalAmount, e, Seq(PaymentSent.PartialPayment(childId, route.amount, route.channelFee(false), randomBytes32(), Some(route.fullRoute))))) sender.expectMsg(PreimageReceived(paymentHash, paymentPreimage)) awaitCond(payFsm.stateName == PAYMENT_SUCCEEDED) sender.watch(payFsm) - childPayFsm.send(payFsm, PaymentFailed(failedId, paymentHash, Seq(RemoteFailure(failedRoute.amount, failedRoute.hops, Sphinx.DecryptedFailurePacket(e, PaymentTimeout))))) + childPayFsm.send(payFsm, PaymentFailed(failedId, paymentHash, Seq(RemoteFailure(failedRoute.amount, failedRoute.fullRoute, Sphinx.DecryptedFailurePacket(e, PaymentTimeout))))) val result = sender.expectMsgType[PaymentSent] assert(result.parts.length == 1 && result.parts.head.id == childId) assert(result.amountWithFees < finalAmount) // we got the preimage without paying the full amount @@ -611,7 +658,7 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS assert(pending.size == childCount) val partialPayments = pending.map { - case (childId, route) => PaymentSent.PartialPayment(childId, route.amount, route.channelFee(false), randomBytes32(), Some(route.hops)) + case (childId, route) => PaymentSent.PartialPayment(childId, route.amount, route.channelFee(false) + route.blindedFee, randomBytes32(), Some(route.fullRoute)) } partialPayments.foreach(pp => childPayFsm.send(payFsm, PaymentSent(cfg.id, paymentHash, paymentPreimage, finalAmount, e, Seq(pp)))) sender.expectMsg(PreimageReceived(paymentHash, paymentPreimage)) @@ -666,6 +713,7 @@ object MultiPartPaymentLifecycleSpec { val paymentPreimage = randomBytes32() val paymentHash = Crypto.sha256(paymentPreimage) val expiry = CltvExpiry(1105) + val blindedRouteExpiry = CltvExpiry(500_000) val finalAmount = 1_000_000 msat val routeParams = PathFindingConf( randomize = false, diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentInitiatorSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentInitiatorSpec.scala index 7627e47c6d..3c08a631ac 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentInitiatorSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentInitiatorSpec.scala @@ -32,8 +32,9 @@ import fr.acinq.eclair.payment.send.MultiPartPaymentLifecycle.SendMultiPartPayme import fr.acinq.eclair.payment.send.PaymentError.UnsupportedFeatures import fr.acinq.eclair.payment.send.PaymentInitiator._ import fr.acinq.eclair.payment.send._ -import fr.acinq.eclair.router.RouteNotFound import fr.acinq.eclair.router.Router._ +import fr.acinq.eclair.router.{BlindedRouteCreation, RouteNotFound} +import fr.acinq.eclair.wire.protocol.OfferTypes.{InvoiceRequest, Offer} import fr.acinq.eclair.wire.protocol._ import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Feature, Features, InvoiceFeature, MilliSatoshiLong, NodeParams, PaymentFinalExpiryConf, TestConstants, TestKitBaseClass, TimestampSecond, UnknownFeature, randomBytes32, randomKey} import org.scalatest.funsuite.FixtureAnyFunSuiteLike @@ -51,6 +52,7 @@ class PaymentInitiatorSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike object Tags { val DisableMPP = "mpp_disabled" + val DisableRouteBlinding = "route_blinding_disabled" val RandomizeFinalExpiry = "random_final_expiry" } @@ -58,22 +60,31 @@ class PaymentInitiatorSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike val featuresWithoutMpp: Features[InvoiceFeature] = Features( VariableLengthOnion -> Mandatory, - PaymentSecret -> Mandatory + PaymentSecret -> Mandatory, + RouteBlinding -> Optional, ) val featuresWithMpp: Features[InvoiceFeature] = Features( VariableLengthOnion -> Mandatory, PaymentSecret -> Mandatory, BasicMultiPartPayment -> Optional, + RouteBlinding -> Optional, ) val featuresWithTrampoline: Features[InvoiceFeature] = Features( VariableLengthOnion -> Mandatory, PaymentSecret -> Mandatory, BasicMultiPartPayment -> Optional, + RouteBlinding -> Optional, TrampolinePaymentPrototype -> Optional, ) + val featuresWithoutRouteBlinding: Features[InvoiceFeature] = Features( + VariableLengthOnion -> Mandatory, + PaymentSecret -> Mandatory, + BasicMultiPartPayment -> Optional, + ) + case class FakePaymentFactory(payFsm: TestProbe, multiPartPayFsm: TestProbe) extends PaymentInitiator.MultiPartPaymentFactory { // @formatter:off override def spawnOutgoingPayment(context: ActorContext, cfg: SendPaymentConfig): ActorRef = { @@ -88,7 +99,13 @@ class PaymentInitiatorSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike } override def withFixture(test: OneArgTest): Outcome = { - val features = if (test.tags.contains(Tags.DisableMPP)) featuresWithoutMpp else featuresWithMpp + val features = if (test.tags.contains(Tags.DisableMPP)) { + featuresWithoutMpp + } else if (test.tags.contains(Tags.DisableRouteBlinding)) { + featuresWithoutRouteBlinding + } else { + featuresWithMpp + } val paymentFinalExpiry = if (test.tags.contains(Tags.RandomizeFinalExpiry)) { PaymentFinalExpiryConf(CltvExpiryDelta(50), CltvExpiryDelta(200)) } else { @@ -276,6 +293,82 @@ class PaymentInitiatorSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike sender.expectMsg(NoPendingPayment(Right(invoice.paymentHash))) } + def createBolt12Invoice(features: Features[InvoiceFeature]): Bolt12Invoice = { + val offer = Offer(None, "Bolt12 r0cks", e, features, Block.RegtestGenesisBlock.hash) + val invoiceRequest = InvoiceRequest(offer, finalAmount, 1, features, randomKey(), Block.RegtestGenesisBlock.hash) + val blindedRoute = BlindedRouteCreation.createBlindedRouteWithoutHops(e, hex"2a2a2a2a", 1 msat, CltvExpiry(500_000)).route + val paymentInfo = OfferTypes.PaymentInfo(1_000 msat, 0, CltvExpiryDelta(24), 0 msat, finalAmount, Features.empty) + Bolt12Invoice(offer, invoiceRequest, paymentPreimage, priv_e.privateKey, CltvExpiryDelta(6), features, Seq(PaymentBlindedRoute(blindedRoute, paymentInfo))) + } + + test("forward single-part blinded payment") { f => + import f._ + val invoice = createBolt12Invoice(Features(VariableLengthOnion -> Mandatory, RouteBlinding -> Mandatory)) + val req = SendPaymentToNode(finalAmount, invoice, 1, routeParams = nodeParams.routerConf.pathFindingExperimentConf.getRandomConf().getDefaultRouteParams) + sender.send(initiator, req) + val id = sender.expectMsgType[UUID] + payFsm.expectMsg(SendPaymentConfig(id, id, None, paymentHash, invoice.nodeId, Upstream.Local(id), Some(invoice), storeInDb = true, publishEvent = true, recordPathFindingMetrics = true)) + val payment = payFsm.expectMsgType[PaymentLifecycle.SendPaymentToNode] + assert(payment.amount == finalAmount) + assert(payment.recipient.nodeId == invoice.nodeId) + assert(payment.recipient.totalAmount == finalAmount) + assert(payment.recipient.extraEdges.nonEmpty) + assert(payment.recipient.expiry == req.finalExpiry(nodeParams)) + assert(payment.recipient.isInstanceOf[BlindedRecipient]) + + sender.send(initiator, GetPayment(Left(id))) + sender.expectMsg(PaymentIsPending(id, invoice.paymentHash, PendingPaymentToNode(sender.ref, req))) + sender.send(initiator, GetPayment(Right(invoice.paymentHash))) + sender.expectMsg(PaymentIsPending(id, invoice.paymentHash, PendingPaymentToNode(sender.ref, req))) + + val pf = PaymentFailed(id, invoice.paymentHash, Nil) + payFsm.send(initiator, pf) + sender.expectMsg(pf) + eventListener.expectNoMessage(100 millis) + + sender.send(initiator, GetPayment(Left(id))) + sender.expectMsg(NoPendingPayment(Left(id))) + } + + test("forward multi-part blinded payment") { f => + import f._ + val invoice = createBolt12Invoice(Features(VariableLengthOnion -> Mandatory, BasicMultiPartPayment -> Optional, RouteBlinding -> Mandatory)) + val req = SendPaymentToNode(finalAmount, invoice, 1, routeParams = nodeParams.routerConf.pathFindingExperimentConf.getRandomConf().getDefaultRouteParams) + sender.send(initiator, req) + val id = sender.expectMsgType[UUID] + multiPartPayFsm.expectMsg(SendPaymentConfig(id, id, None, paymentHash, invoice.nodeId, Upstream.Local(id), Some(invoice), storeInDb = true, publishEvent = true, recordPathFindingMetrics = true)) + val payment = multiPartPayFsm.expectMsgType[SendMultiPartPayment] + assert(payment.recipient.nodeId == invoice.nodeId) + assert(payment.recipient.totalAmount == finalAmount) + assert(payment.recipient.extraEdges.nonEmpty) + assert(payment.recipient.expiry == req.finalExpiry(nodeParams)) + assert(payment.recipient.isInstanceOf[BlindedRecipient]) + + sender.send(initiator, GetPayment(Left(id))) + sender.expectMsg(PaymentIsPending(id, invoice.paymentHash, PendingPaymentToNode(sender.ref, req))) + sender.send(initiator, GetPayment(Right(invoice.paymentHash))) + sender.expectMsg(PaymentIsPending(id, invoice.paymentHash, PendingPaymentToNode(sender.ref, req))) + + val ps = PaymentSent(id, invoice.paymentHash, paymentPreimage, finalAmount, invoice.nodeId, Seq(PartialPayment(UUID.randomUUID(), finalAmount, 0 msat, randomBytes32(), None))) + payFsm.send(initiator, ps) + sender.expectMsg(ps) + eventListener.expectNoMessage(100 millis) + + sender.send(initiator, GetPayment(Left(id))) + sender.expectMsg(NoPendingPayment(Left(id))) + } + + test("reject blinded payment when route blinding deactivated", Tag(Tags.DisableRouteBlinding)) { f => + import f._ + val invoice = createBolt12Invoice(Features(VariableLengthOnion -> Mandatory, BasicMultiPartPayment -> Optional, RouteBlinding -> Mandatory)) + val req = SendPaymentToNode(finalAmount, invoice, 1, routeParams = nodeParams.routerConf.pathFindingExperimentConf.getRandomConf().getDefaultRouteParams) + sender.send(initiator, req) + val id = sender.expectMsgType[UUID] + val fail = sender.expectMsgType[PaymentFailed] + assert(fail.id == id) + assert(fail.failures == LocalFailure(finalAmount, Nil, UnsupportedFeatures(invoice.features)) :: Nil) + } + test("forward trampoline payment") { f => import f._ val ignoredRoutingHints = List(List(ExtraHop(b, channelUpdate_bc.shortChannelId, feeBase = 10 msat, feeProportionalMillionths = 1, cltvExpiryDelta = CltvExpiryDelta(12)))) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala index e97cf86ff5..ede7440e0a 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala @@ -39,7 +39,7 @@ import fr.acinq.eclair.payment.send.PaymentInitiator.SendPaymentConfig import fr.acinq.eclair.payment.send.PaymentLifecycle._ import fr.acinq.eclair.payment.send.{ClearRecipient, PaymentLifecycle} import fr.acinq.eclair.router.Announcements.makeChannelUpdate -import fr.acinq.eclair.router.BaseRouterSpec.{channelAnnouncement, channelHopFromUpdate} +import fr.acinq.eclair.router.BaseRouterSpec.{blindedRouteFromHops, channelAnnouncement, channelHopFromUpdate} import fr.acinq.eclair.router.Graph.WeightRatios import fr.acinq.eclair.router.Router._ import fr.acinq.eclair.router._ @@ -58,6 +58,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val defaultAmountMsat = 142_000_000 msat val defaultExpiry = Channel.MIN_CLTV_EXPIRY_DELTA.toCltvExpiry(BlockHeight(40_000)) + val defaultRouteExpiry = CltvExpiry(100_000) val defaultPaymentPreimage = randomBytes32() val defaultPaymentHash = Crypto.sha256(defaultPaymentPreimage) val defaultOrigin = Origin.LocalCold(UUID.randomUUID()) @@ -78,10 +79,10 @@ class PaymentLifecycleSpec extends BaseRouterSpec { eventListener: TestProbe, metricsListener: TestProbe) - def createPaymentLifecycle(storeInDb: Boolean = true, publishEvent: Boolean = true, recordMetrics: Boolean = true): PaymentFixture = { + def createPaymentLifecycle(invoice: Invoice, storeInDb: Boolean = true, publishEvent: Boolean = true, recordMetrics: Boolean = true): PaymentFixture = { val (id, parentId) = (UUID.randomUUID(), UUID.randomUUID()) val nodeParams = TestConstants.Alice.nodeParams.copy(nodeKeyManager = testNodeKeyManager, channelKeyManager = testChannelKeyManager) - val cfg = SendPaymentConfig(id, parentId, Some(defaultExternalId), defaultPaymentHash, d, Upstream.Local(id), Some(defaultInvoice), storeInDb, publishEvent, recordMetrics) + val cfg = SendPaymentConfig(id, parentId, Some(defaultExternalId), defaultPaymentHash, invoice.nodeId, Upstream.Local(id), Some(invoice), storeInDb, publishEvent, recordMetrics) val (routerForwarder, register, sender, monitor, eventListener, metricsListener) = (TestProbe(), TestProbe(), TestProbe(), TestProbe(), TestProbe(), TestProbe()) val paymentFSM = TestFSMRef(new PaymentLifecycle(nodeParams, cfg, routerForwarder.ref, register.ref)) paymentFSM ! SubscribeTransitionCallBack(monitor.ref) @@ -99,7 +100,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { } test("send to route") { () => - val payFixture = createPaymentLifecycle(recordMetrics = false) + val payFixture = createPaymentLifecycle(defaultInvoice, recordMetrics = false) import payFixture._ import cfg._ @@ -127,7 +128,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { } test("send to route (node_id only)") { routerFixture => - val payFixture = createPaymentLifecycle(recordMetrics = false) + val payFixture = createPaymentLifecycle(defaultInvoice, recordMetrics = false) import payFixture._ import cfg._ @@ -156,7 +157,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { } test("send to route (nodes not found in the graph)") { routerFixture => - val payFixture = createPaymentLifecycle(recordMetrics = false) + val payFixture = createPaymentLifecycle(defaultInvoice, recordMetrics = false) import payFixture._ val brokenRoute = SendPaymentToRoute(sender.ref, Left(PredefinedNodeRoute(defaultAmountMsat, Seq(randomKey().publicKey, randomKey().publicKey, randomKey().publicKey))), defaultRecipient) @@ -173,7 +174,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { } test("send to route (channels not found in the graph)") { routerFixture => - val payFixture = createPaymentLifecycle(recordMetrics = false) + val payFixture = createPaymentLifecycle(defaultInvoice, recordMetrics = false) import payFixture._ val brokenRoute = SendPaymentToRoute(sender.ref, Left(PredefinedChannelRoute(defaultAmountMsat, randomKey().publicKey, Seq(ShortChannelId(1), ShortChannelId(2)))), defaultRecipient) @@ -190,7 +191,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { } test("send to route (routing hints)") { routerFixture => - val payFixture = createPaymentLifecycle(recordMetrics = false) + val payFixture = createPaymentLifecycle(defaultInvoice, recordMetrics = false) import payFixture._ import cfg._ @@ -219,8 +220,33 @@ class PaymentLifecycleSpec extends BaseRouterSpec { assert(routerForwarder.expectMsgType[RouteDidRelay].route.hops.map(_.nodeId) === Seq(a, b, c)) } + test("send to route (blinded route)") { () => + val (invoice, blindedHop, recipient) = blindedRouteFromHops(defaultAmountMsat, defaultExpiry, Seq(channelHopFromUpdate(b, c, update_bc)), defaultRouteExpiry, defaultPaymentPreimage) + val route = Route(defaultAmountMsat, Seq(channelHopFromUpdate(a, b, update_ab)), Some(blindedHop)) + val payFixture = createPaymentLifecycle(invoice) + import payFixture._ + + val request = SendPaymentToRoute(sender.ref, Right(route), recipient) + sender.send(paymentFSM, request) + routerForwarder.expectNoMessage(100 millis) // we don't need the router, we have the pre-computed route + awaitCond(nodeParams.db.payments.getOutgoingPayment(cfg.id).exists(_.status == OutgoingPaymentStatus.Pending)) + val Some(outgoing) = nodeParams.db.payments.getOutgoingPayment(cfg.id) + assert(outgoing.amount == defaultAmountMsat) + assert(outgoing.recipientAmount == defaultAmountMsat) + assert(outgoing.invoice.contains(invoice)) + assert(outgoing.status == OutgoingPaymentStatus.Pending) + + sender.send(paymentFSM, addCompleted(HtlcResult.RemoteFulfill(UpdateFulfillHtlc(ByteVector32.Zeroes, 0, defaultPaymentPreimage)))) + val ps = sender.expectMsgType[PaymentSent] + assert(ps.id == cfg.parentId) + assert(ps.parts.head.route.contains(route.hops ++ Seq(blindedHop))) + awaitCond(nodeParams.db.payments.getOutgoingPayment(cfg.id).exists(_.status.isInstanceOf[OutgoingPaymentStatus.Succeeded])) + + assert(routerForwarder.expectMsgType[RouteDidRelay].route === route) + } + test("payment failed (route not found)") { routerFixture => - val payFixture = createPaymentLifecycle() + val payFixture = createPaymentLifecycle(defaultInvoice) import payFixture._ import cfg._ @@ -245,7 +271,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { } test("payment failed (route too expensive)") { routerFixture => - val payFixture = createPaymentLifecycle() + val payFixture = createPaymentLifecycle(defaultInvoice) import payFixture._ import cfg._ @@ -277,7 +303,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { } test("payment failed (cannot build onion)") { routerFixture => - val payFixture = createPaymentLifecycle() + val payFixture = createPaymentLifecycle(defaultInvoice) import payFixture._ import cfg._ @@ -297,7 +323,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { } test("payment failed (unparsable failure)") { routerFixture => - val payFixture = createPaymentLifecycle() + val payFixture = createPaymentLifecycle(defaultInvoice) import payFixture._ import cfg._ @@ -342,7 +368,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { } test("payment failed (local error)") { routerFixture => - val payFixture = createPaymentLifecycle() + val payFixture = createPaymentLifecycle(defaultInvoice) import payFixture._ import cfg._ @@ -363,7 +389,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { } test("payment failed (register error)") { routerFixture => - val payFixture = createPaymentLifecycle() + val payFixture = createPaymentLifecycle(defaultInvoice) import payFixture._ import cfg._ @@ -383,7 +409,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { } test("payment failed (first hop returns an UpdateFailMalformedHtlc)") { routerFixture => - val payFixture = createPaymentLifecycle() + val payFixture = createPaymentLifecycle(defaultInvoice) import payFixture._ import cfg._ @@ -406,7 +432,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { } test("payment failed (first htlc failed on-chain)") { routerFixture => - val payFixture = createPaymentLifecycle() + val payFixture = createPaymentLifecycle(defaultInvoice) import payFixture._ import cfg._ @@ -429,7 +455,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { } test("payment failed (disconnected before signing the first htlc)") { routerFixture => - val payFixture = createPaymentLifecycle() + val payFixture = createPaymentLifecycle(defaultInvoice) import payFixture._ import cfg._ @@ -453,7 +479,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { } test("payment failed (TemporaryChannelFailure)") { routerFixture => - val payFixture = createPaymentLifecycle() + val payFixture = createPaymentLifecycle(defaultInvoice) import payFixture._ val request = SendPaymentToNode(sender.ref, defaultRecipient, 2, defaultRouteParams) @@ -483,7 +509,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { } test("payment failed (Update)") { routerFixture => - val payFixture = createPaymentLifecycle() + val payFixture = createPaymentLifecycle(defaultInvoice) import payFixture._ import cfg._ @@ -538,7 +564,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { } test("payment failed (Update in last attempt)") { routerFixture => - val payFixture = createPaymentLifecycle() + val payFixture = createPaymentLifecycle(defaultInvoice) import payFixture._ val request = SendPaymentToNode(sender.ref, defaultRecipient, 1, defaultRouteParams) @@ -562,7 +588,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { } test("payment failed (Update in assisted route)") { routerFixture => - val payFixture = createPaymentLifecycle() + val payFixture = createPaymentLifecycle(defaultInvoice) import payFixture._ import cfg._ @@ -606,7 +632,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { } test("payment failed (Update disabled in assisted route)") { routerFixture => - val payFixture = createPaymentLifecycle() + val payFixture = createPaymentLifecycle(defaultInvoice) import payFixture._ import cfg._ @@ -636,7 +662,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { } def testPermanentFailure(router: ActorRef, failure: FailureMessage): Unit = { - val payFixture = createPaymentLifecycle() + val payFixture = createPaymentLifecycle(defaultInvoice) import payFixture._ import cfg._ @@ -673,8 +699,40 @@ class PaymentLifecycleSpec extends BaseRouterSpec { testPermanentFailure(routerFixture.router, FailureMessageCodecs.failureMessageCodec.decode(hex"4011".bits).require.value) } + test("payment failed (blinded route)") { routerFixture => + val (invoice, blindedHop, recipient) = blindedRouteFromHops(defaultAmountMsat, defaultExpiry, Seq(channelHopFromUpdate(b, c, update_bc)), defaultRouteExpiry, defaultPaymentPreimage) + assert(recipient.extraEdges.length == 1) + val payFixture = createPaymentLifecycle(invoice) + import payFixture._ + + val request = SendPaymentToNode(sender.ref, recipient, 2, defaultRouteParams) + sender.send(paymentFSM, request) + + routerForwarder.expectMsgType[RouteRequest] + routerForwarder.forward(routerFixture.router) + awaitCond(paymentFSM.stateName == WAITING_FOR_PAYMENT_COMPLETE) + val WaitingForComplete(_, cmd1, Nil, sharedSecrets, _, route) = paymentFSM.stateData + register.expectMsg(ForwardShortId(paymentFSM.toTyped, scid_ab, cmd1)) + + // The payment fails inside the blinded route: the introduction node sends back an error. + val failure = InvalidOnionBlinding(randomBytes32()) + val failureOnion = Sphinx.FailurePacket.create(sharedSecrets.head._1, failure) + sender.send(paymentFSM, addCompleted(HtlcResult.RemoteFail(UpdateFailHtlc(ByteVector32.Zeroes, 0, failureOnion)))) + + // We retry but we exclude the failed blinded route. + val routeRequest = routerForwarder.expectMsgType[RouteRequest] + assert(routeRequest.target == recipient) + assert(routeRequest.ignore.channels.map(_.shortChannelId) == Set(blindedHop.dummyId)) + routerForwarder.forward(routerFixture.router) + + // Without the blinded route, the router cannot find a route to the recipient. + val failed = sender.expectMsgType[PaymentFailed] + assert(failed.failures == Seq(RemoteFailure(defaultAmountMsat, route.hops ++ Seq(blindedHop), Sphinx.DecryptedFailurePacket(b, failure)), LocalFailure(defaultAmountMsat, Nil, RouteNotFound))) + awaitCond(nodeParams.db.payments.getOutgoingPayment(cfg.id).exists(_.status.isInstanceOf[OutgoingPaymentStatus.Failed])) + } + test("payment succeeded") { routerFixture => - val payFixture = createPaymentLifecycle() + val payFixture = createPaymentLifecycle(defaultInvoice) import payFixture._ import cfg._ @@ -730,7 +788,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { watcher.send(router, ValidateResult(chan_bh, Right((Transaction(version = 0, txIn = Nil, txOut = TxOut(1000000 sat, write(pay2wsh(Scripts.multiSig2of2(funding_b, funding_h)))) :: Nil, lockTime = 0), UtxoStatus.Unspent)))) watcher.expectMsgType[WatchExternalChannelSpent] - val payFixture = createPaymentLifecycle() + val payFixture = createPaymentLifecycle(defaultInvoice) import payFixture._ // we send a payment to H @@ -764,6 +822,37 @@ class PaymentLifecycleSpec extends BaseRouterSpec { assert(routerForwarder.expectMsgType[RouteDidRelay].route.hops.map(_.shortChannelId) == Seq(update_ab, channelUpdate_bh).map(_.shortChannelId)) } + test("payment success (blinded route)") { routerFixture => + val (invoice, blindedHop, recipient) = blindedRouteFromHops(defaultAmountMsat, defaultExpiry, Seq(channelHopFromUpdate(b, c, update_bc)), defaultRouteExpiry, defaultPaymentPreimage) + assert(recipient.extraEdges.length == 1) + val payFixture = createPaymentLifecycle(invoice) + import payFixture._ + + val request = SendPaymentToNode(sender.ref, recipient, 2, defaultRouteParams) + sender.send(paymentFSM, request) + routerForwarder.expectMsgType[RouteRequest] + routerForwarder.forward(routerFixture.router) + awaitCond(nodeParams.db.payments.getOutgoingPayment(cfg.id).exists(_.status == OutgoingPaymentStatus.Pending)) + val Some(outgoing) = nodeParams.db.payments.getOutgoingPayment(cfg.id) + assert(outgoing.copy(createdAt = 0 unixms) == OutgoingPayment(cfg.id, cfg.parentId, Some(defaultExternalId), defaultPaymentHash, PaymentType.Blinded, defaultAmountMsat, defaultAmountMsat, recipient.nodeId, 0 unixms, Some(invoice), OutgoingPaymentStatus.Pending)) + sender.send(paymentFSM, addCompleted(HtlcResult.RemoteFulfill(UpdateFulfillHtlc(ByteVector32.Zeroes, 0, defaultPaymentPreimage)))) + + val ps = eventListener.expectMsgType[PaymentSent] + assert(ps.id == cfg.parentId) + assert(ps.feesPaid == blindedHop.fee(defaultAmountMsat)) + assert(ps.recipientAmount == defaultAmountMsat) + assert(ps.paymentPreimage == defaultPaymentPreimage) + awaitCond(nodeParams.db.payments.getOutgoingPayment(cfg.id).exists(_.status.isInstanceOf[OutgoingPaymentStatus.Succeeded])) + + val metrics = metricsListener.expectMsgType[PathFindingExperimentMetrics] + assert(metrics.status == "SUCCESS") + assert(metrics.amount == defaultAmountMsat) + assert(metrics.fees == blindedHop.fee(defaultAmountMsat)) + metricsListener.expectNoMessage(100 millis) + + assert(routerForwarder.expectMsgType[RouteDidRelay].route.hops.map(_.shortChannelId) == Seq(update_ab.shortChannelId)) + } + test("filter errors properly") { () => val failures = Seq( LocalFailure(defaultAmountMsat, Nil, RouteNotFound), @@ -782,6 +871,10 @@ class PaymentLifecycleSpec extends BaseRouterSpec { test("ignore failed nodes/channels") { () => val route_abcd = channelHopFromUpdate(a, b, update_ab) :: channelHopFromUpdate(b, c, update_bc) :: channelHopFromUpdate(c, d, update_cd) :: Nil + val update_de = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_d, e, ShortChannelId(1729), CltvExpiryDelta(3), htlcMinimumMsat = 0 msat, feeBaseMsat = 10 msat, feeProportionalMillionths = 4, htlcMaximumMsat = htlcMaximum) + val (_, blindedHop_bc, _) = blindedRouteFromHops(defaultAmountMsat, defaultExpiry, Seq(channelHopFromUpdate(b, c, update_bc), channelHopFromUpdate(c, d, update_cd)), defaultRouteExpiry, defaultPaymentPreimage) + val blindedRoute_abc = channelHopFromUpdate(a, b, update_ab) :: blindedHop_bc :: Nil + val (_, blindedHop_de, _) = blindedRouteFromHops(defaultAmountMsat, defaultExpiry, Seq(channelHopFromUpdate(d, e, update_de)), defaultRouteExpiry, defaultPaymentPreimage) val testCases = Seq( // local failures -> ignore first channel if there is one (LocalFailure(defaultAmountMsat, Nil, RouteNotFound), Set.empty, Set.empty), @@ -795,9 +888,14 @@ class PaymentLifecycleSpec extends BaseRouterSpec { (RemoteFailure(defaultAmountMsat, route_abcd, Sphinx.DecryptedFailurePacket(b, PermanentChannelFailure)), Set.empty, Set(ChannelDesc(scid_bc, b, c))), (RemoteFailure(defaultAmountMsat, route_abcd, Sphinx.DecryptedFailurePacket(c, UnknownNextPeer)), Set.empty, Set(ChannelDesc(scid_cd, c, d))), (RemoteFailure(defaultAmountMsat, route_abcd, Sphinx.DecryptedFailurePacket(b, FeeInsufficient(100 msat, update_bc))), Set.empty, Set.empty), - // unreadable remote failures -> blacklist all nodes except our direct peer and the final recipient + (RemoteFailure(defaultAmountMsat, blindedRoute_abc, Sphinx.DecryptedFailurePacket(b, InvalidOnionBlinding(randomBytes32()))), Set.empty, Set(ChannelDesc(blindedHop_bc.dummyId, blindedHop_bc.nodeId, blindedHop_bc.nextNodeId))), + (RemoteFailure(defaultAmountMsat, blindedRoute_abc, Sphinx.DecryptedFailurePacket(blindedHop_bc.route.blindedNodeIds(1), InvalidOnionBlinding(randomBytes32()))), Set.empty, Set(ChannelDesc(blindedHop_bc.dummyId, blindedHop_bc.nodeId, blindedHop_bc.nextNodeId))), + // unreadable remote failures -> blacklist all nodes except our direct peer, the final recipient or the last hop (UnreadableRemoteFailure(defaultAmountMsat, channelHopFromUpdate(a, b, update_ab) :: Nil), Set.empty, Set.empty), - (UnreadableRemoteFailure(defaultAmountMsat, channelHopFromUpdate(a, b, update_ab) :: channelHopFromUpdate(b, c, update_bc) :: channelHopFromUpdate(c, d, update_cd) :: NodeHop(d, e, CltvExpiryDelta(24), 0 msat) :: Nil), Set(c, d), Set.empty) + (UnreadableRemoteFailure(defaultAmountMsat, channelHopFromUpdate(a, b, update_ab) :: channelHopFromUpdate(b, c, update_bc) :: channelHopFromUpdate(c, d, update_cd) :: Nil), Set(c), Set.empty), + (UnreadableRemoteFailure(defaultAmountMsat, channelHopFromUpdate(a, b, update_ab) :: channelHopFromUpdate(b, c, update_bc) :: channelHopFromUpdate(c, d, update_cd) :: channelHopFromUpdate(d, e, update_de) :: Nil), Set(c, d), Set.empty), + (UnreadableRemoteFailure(defaultAmountMsat, channelHopFromUpdate(a, b, update_ab) :: channelHopFromUpdate(b, c, update_bc) :: channelHopFromUpdate(c, d, update_cd) :: NodeHop(d, e, CltvExpiryDelta(24), 0 msat) :: Nil), Set(c), Set.empty), + (UnreadableRemoteFailure(defaultAmountMsat, channelHopFromUpdate(a, b, update_ab) :: channelHopFromUpdate(b, c, update_bc) :: channelHopFromUpdate(c, d, update_cd) :: blindedHop_de :: Nil), Set(c), Set.empty), ) for ((failure, expectedNodes, expectedChannels) <- testCases) { @@ -817,7 +915,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { } test("disable database and events") { routerFixture => - val payFixture = createPaymentLifecycle(storeInDb = false, publishEvent = false, recordMetrics = false) + val payFixture = createPaymentLifecycle(defaultInvoice, storeInDb = false, publishEvent = false, recordMetrics = false) import payFixture._ import cfg._ @@ -838,7 +936,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { } test("send to route (no retry on error") { () => - val payFixture = createPaymentLifecycle() + val payFixture = createPaymentLifecycle(defaultInvoice) import payFixture._ import cfg._ diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentPacketSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentPacketSpec.scala index a607d9bad8..2777a4534d 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentPacketSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentPacketSpec.scala @@ -23,21 +23,25 @@ import fr.acinq.eclair.FeatureSupport.{Mandatory, Optional} import fr.acinq.eclair.Features._ import fr.acinq.eclair.channel._ import fr.acinq.eclair.channel.fsm.Channel +import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.payment.IncomingPaymentPacket.{ChannelRelayPacket, FinalPacket, NodeRelayPacket, decrypt} import fr.acinq.eclair.payment.OutgoingPaymentPacket._ -import fr.acinq.eclair.payment.send.{ClearRecipient, ClearTrampolineRecipient} -import fr.acinq.eclair.router.BaseRouterSpec.channelHopFromUpdate +import fr.acinq.eclair.payment.send.{BlindedRecipient, ClearRecipient, ClearTrampolineRecipient} +import fr.acinq.eclair.router.BaseRouterSpec.{blindedRouteFromHops, channelHopFromUpdate} +import fr.acinq.eclair.router.BlindedRouteCreation import fr.acinq.eclair.router.Router.{NodeHop, Route} import fr.acinq.eclair.transactions.Transactions.InputInfo +import fr.acinq.eclair.wire.protocol.OfferTypes.{InvoiceRequest, Offer, PaymentInfo} import fr.acinq.eclair.wire.protocol.OnionPaymentPayloadTlv.{AmountToForward, OutgoingCltv, PaymentData} import fr.acinq.eclair.wire.protocol.PaymentOnion.{FinalPayload, IntermediatePayload} import fr.acinq.eclair.wire.protocol._ -import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Features, InvoiceFeature, MilliSatoshi, MilliSatoshiLong, ShortChannelId, TestConstants, TimestampSecondLong, UInt64, nodeFee, randomBytes32, randomKey} +import fr.acinq.eclair.{BlockHeight, CltvExpiry, CltvExpiryDelta, Features, InvoiceFeature, MilliSatoshi, MilliSatoshiLong, ShortChannelId, TestConstants, TimestampSecondLong, UInt64, nodeFee, randomBytes32, randomKey} import org.scalatest.BeforeAndAfterAll import org.scalatest.funsuite.AnyFunSuite import scodec.bits.{ByteVector, HexStringSyntax} import java.util.UUID +import scala.util.Success /** * Created by PM on 31/05/2016. @@ -61,7 +65,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { def testBuildOutgoingPayment(): Unit = { val recipient = ClearRecipient(e, Features.empty, finalAmount, finalExpiry, paymentSecret) - val Right(payment) = buildOutgoingPayment(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, Route(finalAmount, hops, None), recipient) + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, Route(finalAmount, hops, None), recipient) assert(payment.outgoingChannel == channelUpdate_ab.shortChannelId) assert(payment.cmd.amount == amount_ab) assert(payment.cmd.cltvExpiry == expiry_ab) @@ -119,7 +123,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { test("build outgoing payment for direct peer") { val recipient = ClearRecipient(b, Features.empty, finalAmount, finalExpiry, paymentSecret, paymentMetadata_opt = Some(paymentMetadata)) val route = Route(finalAmount, hops.take(1), None) - val Right(payment) = buildOutgoingPayment(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, route, recipient) + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, route, recipient) assert(payment.cmd.amount == finalAmount) assert(payment.cmd.cltvExpiry == finalExpiry) assert(payment.cmd.paymentHash == paymentHash) @@ -140,7 +144,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { test("build outgoing payment with greater amount and expiry") { val recipient = ClearRecipient(b, Features.empty, finalAmount, finalExpiry, paymentSecret, paymentMetadata_opt = Some(paymentMetadata)) val route = Route(finalAmount, hops.take(1), None) - val Right(payment) = buildOutgoingPayment(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, route, recipient) + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, route, recipient) // let's peel the onion val add_b = UpdateAddHtlc(randomBytes32(), 0, finalAmount + 100.msat, paymentHash, finalExpiry + CltvExpiryDelta(6), payment.cmd.onion, None) @@ -152,6 +156,125 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { assert(payload_b.asInstanceOf[FinalPayload.Standard].paymentSecret == paymentSecret) } + test("build outgoing blinded payment") { + val (invoice, route, recipient) = longBlindedHops(hex"deadbeef") + assert(recipient.extraEdges.length == 1) + assert(recipient.extraEdges.head.sourceNodeId == c) + assert(recipient.extraEdges.head.targetNodeId == invoice.nodeId) + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, route, recipient) + assert(payment.outgoingChannel == channelUpdate_ab.shortChannelId) + assert(payment.cmd.amount >= amount_ab) + assert(payment.cmd.cltvExpiry == expiry_ab) + assert(payment.cmd.nextBlindingKey_opt.isEmpty) + + val add_b = UpdateAddHtlc(randomBytes32(), 0, payment.cmd.amount, payment.cmd.paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion, payment.cmd.nextBlindingKey_opt) + val Right(relay_b@ChannelRelayPacket(_, payload_b, packet_c)) = decrypt(add_b, priv_b.privateKey, Features.empty) + assert(packet_c.payload.length == PaymentOnionCodecs.paymentOnionPayloadLength) + assert(relay_b.amountToForward >= amount_bc) + assert(relay_b.outgoingCltv == expiry_bc) + assert(payload_b.outgoingChannelId == channelUpdate_bc.shortChannelId) + assert(relay_b.relayFeeMsat == fee_b) + assert(relay_b.expiryDelta == channelUpdate_bc.cltvExpiryDelta) + assert(payload_b.isInstanceOf[IntermediatePayload.ChannelRelay.Standard]) + + val add_c = UpdateAddHtlc(randomBytes32(), 1, amount_bc, paymentHash, expiry_bc, packet_c, None) + val Right(relay_c@ChannelRelayPacket(_, payload_c, packet_d)) = decrypt(add_c, priv_c.privateKey, Features(RouteBlinding -> Optional)) + assert(packet_d.payload.length == PaymentOnionCodecs.paymentOnionPayloadLength) + assert(relay_c.amountToForward == amount_cd) + assert(relay_c.outgoingCltv == expiry_cd) + assert(payload_c.outgoingChannelId == channelUpdate_cd.shortChannelId) + assert(relay_c.relayFeeMsat == fee_c) + assert(relay_c.expiryDelta == channelUpdate_cd.cltvExpiryDelta) + assert(payload_c.isInstanceOf[IntermediatePayload.ChannelRelay.Blinded]) + val blinding_d = payload_c.asInstanceOf[IntermediatePayload.ChannelRelay.Blinded].nextBlinding + + val add_d = UpdateAddHtlc(randomBytes32(), 2, amount_cd, paymentHash, expiry_cd, packet_d, Some(blinding_d)) + val Right(relay_d@ChannelRelayPacket(_, payload_d, packet_e)) = decrypt(add_d, priv_d.privateKey, Features(RouteBlinding -> Optional)) + assert(packet_e.payload.length == PaymentOnionCodecs.paymentOnionPayloadLength) + assert(relay_d.amountToForward == amount_de) + assert(relay_d.outgoingCltv == expiry_de) + assert(payload_d.outgoingChannelId == channelUpdate_de.shortChannelId) + assert(relay_d.relayFeeMsat == fee_d) + assert(relay_d.expiryDelta == channelUpdate_de.cltvExpiryDelta) + assert(payload_d.isInstanceOf[IntermediatePayload.ChannelRelay.Blinded]) + val blinding_e = payload_d.asInstanceOf[IntermediatePayload.ChannelRelay.Blinded].nextBlinding + + val add_e = UpdateAddHtlc(randomBytes32(), 2, amount_de, paymentHash, expiry_de, packet_e, Some(blinding_e)) + val Right(FinalPacket(_, payload_e)) = decrypt(add_e, priv_e.privateKey, Features(RouteBlinding -> Optional)) + assert(payload_e.amount == finalAmount) + assert(payload_e.totalAmount == finalAmount) + assert(payload_e.expiry == finalExpiry) + assert(payload_e.isInstanceOf[FinalPayload.Blinded]) + assert(payload_e.asInstanceOf[FinalPayload.Blinded].pathId == hex"deadbeef") + } + + test("build outgoing blinded payment for introduction node") { + // a -> b -> c where c uses a 0-hop blinded route. + val recipientKey = randomKey() + val features = Features[InvoiceFeature](VariableLengthOnion -> Mandatory, BasicMultiPartPayment -> Optional, RouteBlinding -> Mandatory) + val offer = Offer(None, "Bolt12 r0cks", recipientKey.publicKey, features, Block.RegtestGenesisBlock.hash) + val invoiceRequest = InvoiceRequest(offer, amount_bc, 1, features, randomKey(), Block.RegtestGenesisBlock.hash) + val blindedRoute = BlindedRouteCreation.createBlindedRouteWithoutHops(c, hex"deadbeef", 1 msat, CltvExpiry(500_000)).route + val paymentInfo = PaymentInfo(0 msat, 0, CltvExpiryDelta(0), 1 msat, amount_bc, Features.empty) + val invoice = Bolt12Invoice(offer, invoiceRequest, paymentPreimage, recipientKey, CltvExpiryDelta(6), features, Seq(PaymentBlindedRoute(blindedRoute, paymentInfo))) + val recipient = BlindedRecipient(invoice, amount_bc, expiry_bc, Nil) + val hops = Seq(channelHopFromUpdate(a, b, channelUpdate_ab), channelHopFromUpdate(b, c, channelUpdate_bc)) + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, Route(amount_bc, hops, Some(recipient.blindedHops.head)), recipient) + assert(payment.outgoingChannel == channelUpdate_ab.shortChannelId) + assert(payment.cmd.amount == amount_ab) + assert(payment.cmd.cltvExpiry == expiry_ab) + assert(payment.cmd.nextBlindingKey_opt.isEmpty) + + val add_b = UpdateAddHtlc(randomBytes32(), 0, payment.cmd.amount, payment.cmd.paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion, payment.cmd.nextBlindingKey_opt) + val Right(relay_b@ChannelRelayPacket(_, payload_b, packet_c)) = decrypt(add_b, priv_b.privateKey, Features.empty) + assert(packet_c.payload.length == PaymentOnionCodecs.paymentOnionPayloadLength) + assert(relay_b.amountToForward >= amount_bc) + assert(relay_b.outgoingCltv == expiry_bc) + assert(payload_b.outgoingChannelId == channelUpdate_bc.shortChannelId) + assert(relay_b.relayFeeMsat == fee_b) + assert(relay_b.expiryDelta == channelUpdate_bc.cltvExpiryDelta) + assert(payload_b.isInstanceOf[IntermediatePayload.ChannelRelay.Standard]) + + val add_c = UpdateAddHtlc(randomBytes32(), 1, amount_bc, paymentHash, expiry_bc, packet_c, None) + val Right(FinalPacket(_, payload_c)) = decrypt(add_c, priv_c.privateKey, Features(RouteBlinding -> Optional)) + assert(payload_c.amount == amount_bc) + assert(payload_c.totalAmount == amount_bc) + assert(payload_c.expiry == expiry_bc) + assert(payload_c.isInstanceOf[FinalPayload.Blinded]) + assert(payload_c.asInstanceOf[FinalPayload.Blinded].pathId == hex"deadbeef") + } + + test("build outgoing blinded payment starting at our node") { + val (route, recipient) = singleBlindedHop(hex"123456") + assert(recipient.extraEdges.length == 1) + assert(recipient.extraEdges.head.sourceNodeId == a) + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, route, recipient) + assert(payment.outgoingChannel == channelUpdate_ab.shortChannelId) + assert(payment.cmd.amount == finalAmount) + assert(payment.cmd.cltvExpiry == finalExpiry) + assert(payment.cmd.nextBlindingKey_opt.nonEmpty) + + val add_b = UpdateAddHtlc(randomBytes32(), 0, payment.cmd.amount, payment.cmd.paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion, payment.cmd.nextBlindingKey_opt) + val Right(FinalPacket(_, payload_b)) = decrypt(add_b, priv_b.privateKey, Features(RouteBlinding -> Optional)) + assert(payload_b.amount == finalAmount) + assert(payload_b.totalAmount == finalAmount) + assert(payload_b.expiry == finalExpiry) + assert(payload_b.isInstanceOf[FinalPayload.Blinded]) + assert(payload_b.asInstanceOf[FinalPayload.Blinded].pathId == hex"123456") + } + + test("build outgoing blinded payment with greater amount and expiry") { + val (route, recipient) = singleBlindedHop(hex"123456") + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, route, recipient) + assert(payment.outgoingChannel == channelUpdate_ab.shortChannelId) + + val add_b = UpdateAddHtlc(randomBytes32(), 0, payment.cmd.amount + 100.msat, payment.cmd.paymentHash, payment.cmd.cltvExpiry + CltvExpiryDelta(6), payment.cmd.onion, payment.cmd.nextBlindingKey_opt) + val Right(FinalPacket(_, payload_b)) = decrypt(add_b, priv_b.privateKey, Features(RouteBlinding -> Optional)) + assert(payload_b.amount == finalAmount) + assert(payload_b.totalAmount == finalAmount) + assert(payload_b.expiry == finalExpiry) + } + test("build outgoing trampoline payment") { // simple trampoline route to e: // .----. @@ -162,7 +285,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { val recipient = ClearTrampolineRecipient(invoice, finalAmount, finalExpiry, trampolineHop, randomBytes32()) assert(recipient.trampolineAmount == amount_bc) assert(recipient.trampolineExpiry == expiry_bc) - val Right(payment) = buildOutgoingPayment(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, Route(recipient.trampolineAmount, trampolineChannelHops, Some(trampolineHop)), recipient) + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, Route(recipient.trampolineAmount, trampolineChannelHops, Some(trampolineHop)), recipient) assert(payment.outgoingChannel == channelUpdate_ab.shortChannelId) assert(payment.cmd.amount == amount_ab) assert(payment.cmd.cltvExpiry == expiry_ab) @@ -189,7 +312,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { // c forwards the trampoline payment to e through d. val recipient_e = ClearRecipient(e, Features.empty, inner_c.amountToForward, inner_c.outgoingCltv, randomBytes32(), nextTrampolineOnion_opt = Some(trampolinePacket_e)) - val Right(payment_e) = buildOutgoingPayment(ActorRef.noSender, Upstream.Trampoline(Seq(add_c)), paymentHash, Route(inner_c.amountToForward, afterTrampolineChannelHops, None), recipient_e) + val Right(payment_e) = buildOutgoingPayment(ActorRef.noSender, priv_c.privateKey, Upstream.Trampoline(Seq(add_c)), paymentHash, Route(inner_c.amountToForward, afterTrampolineChannelHops, None), recipient_e) assert(payment_e.outgoingChannel == channelUpdate_cd.shortChannelId) assert(payment_e.cmd.amount == amount_cd) assert(payment_e.cmd.cltvExpiry == expiry_cd) @@ -215,7 +338,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { val recipient = ClearTrampolineRecipient(invoice, finalAmount, finalExpiry, trampolineHop, randomBytes32()) assert(recipient.trampolineAmount == amount_bc) assert(recipient.trampolineExpiry == expiry_bc) - val Right(payment) = buildOutgoingPayment(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, Route(recipient.trampolineAmount, trampolineChannelHops, Some(trampolineHop)), recipient) + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, Route(recipient.trampolineAmount, trampolineChannelHops, Some(trampolineHop)), recipient) assert(payment.outgoingChannel == channelUpdate_ab.shortChannelId) assert(payment.cmd.amount == amount_ab) assert(payment.cmd.cltvExpiry == expiry_ab) @@ -240,7 +363,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { // c forwards the trampoline payment to e through d. val recipient_e = ClearRecipient(e, Features.empty, inner_c.amountToForward, inner_c.outgoingCltv, inner_c.paymentSecret.get, invoice.extraEdges, inner_c.paymentMetadata) - val Right(payment_e) = buildOutgoingPayment(ActorRef.noSender, Upstream.Trampoline(Seq(add_c)), paymentHash, Route(inner_c.amountToForward, afterTrampolineChannelHops, None), recipient_e) + val Right(payment_e) = buildOutgoingPayment(ActorRef.noSender, priv_c.privateKey, Upstream.Trampoline(Seq(add_c)), paymentHash, Route(inner_c.amountToForward, afterTrampolineChannelHops, None), recipient_e) assert(payment_e.outgoingChannel == channelUpdate_cd.shortChannelId) assert(payment_e.cmd.amount == amount_cd) assert(payment_e.cmd.cltvExpiry == expiry_cd) @@ -259,7 +382,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { val routingHintOverflow = List(List.fill(7)(Bolt11Invoice.ExtraHop(randomKey().publicKey, ShortChannelId(1), 10 msat, 100, CltvExpiryDelta(12)))) val invoice = Bolt11Invoice(Block.RegtestGenesisBlock.hash, Some(finalAmount), paymentHash, priv_e.privateKey, Left("#reckless"), CltvExpiryDelta(18), None, None, routingHintOverflow) val recipient = ClearTrampolineRecipient(invoice, finalAmount, finalExpiry, trampolineHop, randomBytes32()) - val Left(failure) = buildOutgoingPayment(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, Route(recipient.trampolineAmount, trampolineChannelHops, Some(trampolineHop)), recipient) + val Left(failure) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, Route(recipient.trampolineAmount, trampolineChannelHops, Some(trampolineHop)), recipient) assert(failure.isInstanceOf[CannotCreateOnion]) } @@ -268,13 +391,37 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { val invoiceFeatures = Features[InvoiceFeature](VariableLengthOnion -> Mandatory, PaymentSecret -> Mandatory, BasicMultiPartPayment -> Optional, PaymentMetadata -> Optional, TrampolinePaymentPrototype -> Optional) val invoice = Bolt11Invoice(Block.RegtestGenesisBlock.hash, Some(finalAmount), paymentHash, priv_e.privateKey, Left("Much payment very metadata"), CltvExpiryDelta(9), features = invoiceFeatures, paymentMetadata = Some(paymentMetadata)) val recipient = ClearTrampolineRecipient(invoice, finalAmount, finalExpiry, trampolineHop, randomBytes32()) - val Left(failure) = buildOutgoingPayment(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, Route(recipient.trampolineAmount, trampolineChannelHops, Some(trampolineHop)), recipient) + val Left(failure) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, Route(recipient.trampolineAmount, trampolineChannelHops, Some(trampolineHop)), recipient) assert(failure.isInstanceOf[CannotCreateOnion]) } + test("fail to build outgoing payment with invalid route") { + val recipient = ClearRecipient(e, Features.empty, finalAmount, finalExpiry, paymentSecret) + val route = Route(finalAmount, hops.dropRight(1), None) // route doesn't reach e + val Left(failure) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, route, recipient) + assert(failure == InvalidRouteRecipient(e, d)) + } + + test("fail to build outgoing trampoline payment with invalid route") { + val invoiceFeatures = Features[InvoiceFeature](VariableLengthOnion -> Mandatory, PaymentSecret -> Mandatory, BasicMultiPartPayment -> Optional, PaymentMetadata -> Optional, TrampolinePaymentPrototype -> Optional) + val invoice = Bolt11Invoice(Block.RegtestGenesisBlock.hash, None, paymentHash, priv_e.privateKey, Left("invoice"), CltvExpiryDelta(6), paymentSecret = paymentSecret, features = invoiceFeatures) + val recipient = ClearTrampolineRecipient(invoice, finalAmount, finalExpiry, trampolineHop, randomBytes32()) + val route = Route(finalAmount, trampolineChannelHops, None) // missing trampoline hop + val Left(failure) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, route, recipient) + assert(failure == MissingTrampolineHop(c)) + } + + test("fail to build outgoing blinded payment with invalid route") { + val (_, route, recipient) = longBlindedHops(hex"deadbeef") + assert(buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, route, recipient).isRight) + val routeMissingBlindedHop = route.copy(finalHop_opt = None) + val Left(failure) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, routeMissingBlindedHop, recipient) + assert(failure == MissingBlindedHop(Set(c))) + } + test("fail to decrypt when the onion is invalid") { val recipient = ClearRecipient(e, Features.empty, finalAmount, finalExpiry, paymentSecret) - val Right(payment) = buildOutgoingPayment(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, Route(finalAmount, hops, None), recipient) + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, Route(finalAmount, hops, None), recipient) val add = UpdateAddHtlc(randomBytes32(), 1, payment.cmd.amount, paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion.copy(payload = payment.cmd.onion.payload.reverse), None) val Left(failure) = decrypt(add, priv_b.privateKey, Features.empty) assert(failure.isInstanceOf[InvalidOnionHmac]) @@ -284,7 +431,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { val invoiceFeatures = Features[InvoiceFeature](VariableLengthOnion -> Mandatory, PaymentSecret -> Mandatory, BasicMultiPartPayment -> Optional, PaymentMetadata -> Optional, TrampolinePaymentPrototype -> Optional) val invoice = Bolt11Invoice(Block.RegtestGenesisBlock.hash, None, paymentHash, priv_e.privateKey, Left("invoice"), CltvExpiryDelta(6), paymentSecret = paymentSecret, features = invoiceFeatures, paymentMetadata = Some(hex"010203")) val recipient = ClearTrampolineRecipient(invoice, finalAmount, finalExpiry, trampolineHop, randomBytes32()) - val Right(payment) = buildOutgoingPayment(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, Route(recipient.trampolineAmount, trampolineChannelHops, Some(trampolineHop)), recipient) + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, Route(recipient.trampolineAmount, trampolineChannelHops, Some(trampolineHop)), recipient) val add_b = UpdateAddHtlc(randomBytes32(), 1, payment.cmd.amount, paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion, None) val Right(ChannelRelayPacket(_, _, packet_c)) = decrypt(add_b, priv_b.privateKey, Features.empty) @@ -294,7 +441,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { // c forwards an invalid trampoline onion to e through d. val recipient_e = ClearRecipient(e, Features.empty, inner_c.amountToForward, inner_c.outgoingCltv, randomBytes32(), nextTrampolineOnion_opt = Some(trampolinePacket_e.copy(payload = trampolinePacket_e.payload.reverse))) - val Right(payment_e) = buildOutgoingPayment(ActorRef.noSender, Upstream.Trampoline(Seq(add_c)), paymentHash, Route(inner_c.amountToForward, afterTrampolineChannelHops, None), recipient_e) + val Right(payment_e) = buildOutgoingPayment(ActorRef.noSender, priv_c.privateKey, Upstream.Trampoline(Seq(add_c)), paymentHash, Route(inner_c.amountToForward, afterTrampolineChannelHops, None), recipient_e) assert(payment_e.outgoingChannel == channelUpdate_cd.shortChannelId) val add_d = UpdateAddHtlc(randomBytes32(), 3, payment_e.cmd.amount, paymentHash, payment_e.cmd.cltvExpiry, payment_e.cmd.onion, None) val Right(ChannelRelayPacket(_, _, packet_e)) = decrypt(add_d, priv_d.privateKey, Features.empty) @@ -306,16 +453,47 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { test("fail to decrypt when payment hash doesn't match associated data") { val recipient = ClearRecipient(e, Features.empty, finalAmount, finalExpiry, paymentSecret) - val Right(payment) = buildOutgoingPayment(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash.reverse, Route(finalAmount, hops, None), recipient) + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash.reverse, Route(finalAmount, hops, None), recipient) val add = UpdateAddHtlc(randomBytes32(), 1, payment.cmd.amount, paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion, None) val Left(failure) = decrypt(add, priv_b.privateKey, Features.empty) assert(failure.isInstanceOf[InvalidOnionHmac]) } + test("fail to decrypt when blinded route data is invalid") { + val (route, recipient) = { + val features = Features[InvoiceFeature](VariableLengthOnion -> Mandatory, BasicMultiPartPayment -> Optional, RouteBlinding -> Mandatory) + val offer = Offer(None, "Bolt12 r0cks", c, features, Block.RegtestGenesisBlock.hash) + val invoiceRequest = InvoiceRequest(offer, amount_bc, 1, features, randomKey(), Block.RegtestGenesisBlock.hash) + // We send the wrong blinded payload to the introduction node. + val tmpBlindedRoute = BlindedRouteCreation.createBlindedRouteFromHops(Seq(channelHopFromUpdate(b, c, channelUpdate_bc)), hex"deadbeef", 1 msat, CltvExpiry(500_000)).route + val blindedRoute = tmpBlindedRoute.copy(blindedNodes = tmpBlindedRoute.blindedNodes.reverse) + val paymentInfo = OfferTypes.PaymentInfo(fee_b, 0, channelUpdate_bc.cltvExpiryDelta, 0 msat, amount_bc, Features.empty) + val invoice = Bolt12Invoice(offer, invoiceRequest, paymentPreimage, priv_c.privateKey, CltvExpiryDelta(6), features, Seq(PaymentBlindedRoute(blindedRoute, paymentInfo))) + val recipient = BlindedRecipient(invoice, amount_bc, expiry_bc, Nil) + val route = Route(amount_bc, Seq(channelHopFromUpdate(a, b, channelUpdate_ab)), Some(recipient.blindedHops.head)) + (route, recipient) + } + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, route, recipient) + assert(payment.outgoingChannel == channelUpdate_ab.shortChannelId) + assert(payment.cmd.amount == amount_bc + fee_b) + + val add_b = UpdateAddHtlc(randomBytes32(), 0, payment.cmd.amount, payment.cmd.paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion, payment.cmd.nextBlindingKey_opt) + val Left(failure) = decrypt(add_b, priv_b.privateKey, Features(RouteBlinding -> Optional)) + assert(failure.isInstanceOf[InvalidOnionBlinding]) + } + + test("fail to decrypt blinded payment when route blinding is disabled") { + val (route, recipient) = singleBlindedHop(hex"00000000") + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, route, recipient) + val add_b = UpdateAddHtlc(randomBytes32(), 0, payment.cmd.amount, payment.cmd.paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion, payment.cmd.nextBlindingKey_opt) + val Left(failure) = decrypt(add_b, priv_b.privateKey, Features.empty) // b doesn't support route blinding + assert(failure == InvalidOnionPayload(UInt64(10), 0)) + } + test("fail to decrypt at the final node when amount has been modified by next-to-last node") { val recipient = ClearRecipient(b, Features.empty, finalAmount, finalExpiry, paymentSecret) val route = Route(finalAmount, hops.take(1), None) - val Right(payment) = buildOutgoingPayment(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, route, recipient) + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, route, recipient) val add = UpdateAddHtlc(randomBytes32(), 1, payment.cmd.amount - 100.msat, paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion, None) val Left(failure) = decrypt(add, priv_b.privateKey, Features.empty) assert(failure == FinalIncorrectHtlcAmount(payment.cmd.amount - 100.msat)) @@ -324,12 +502,64 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { test("fail to decrypt at the final node when expiry has been modified by next-to-last node") { val recipient = ClearRecipient(b, Features.empty, finalAmount, finalExpiry, paymentSecret) val route = Route(finalAmount, hops.take(1), None) - val Right(payment) = buildOutgoingPayment(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, route, recipient) + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, route, recipient) val add = UpdateAddHtlc(randomBytes32(), 1, payment.cmd.amount, paymentHash, payment.cmd.cltvExpiry - CltvExpiryDelta(12), payment.cmd.onion, None) val Left(failure) = decrypt(add, priv_b.privateKey, Features.empty) assert(failure == FinalIncorrectCltvExpiry(payment.cmd.cltvExpiry - CltvExpiryDelta(12))) } + test("fail to decrypt blinded payment at the final node when amount is too low") { + val (route, recipient) = shortBlindedHops() + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_c.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, route, recipient) + assert(payment.outgoingChannel == channelUpdate_cd.shortChannelId) + assert(payment.cmd.amount == amount_cd) + + // A smaller amount is sent to d, who doesn't know that it's invalid. + val add_d = UpdateAddHtlc(randomBytes32(), 0, amount_de, paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion, payment.cmd.nextBlindingKey_opt) + val Right(relay_d@ChannelRelayPacket(_, payload_d, packet_e)) = decrypt(add_d, priv_d.privateKey, Features(RouteBlinding -> Optional)) + assert(payload_d.outgoingChannelId == channelUpdate_de.shortChannelId) + assert(relay_d.amountToForward < amount_de) + assert(payload_d.isInstanceOf[IntermediatePayload.ChannelRelay.Blinded]) + val blinding_e = payload_d.asInstanceOf[IntermediatePayload.ChannelRelay.Blinded].nextBlinding + + // When e receives a smaller amount than expected, it rejects the payment. + val add_e = UpdateAddHtlc(randomBytes32(), 0, relay_d.amountToForward, paymentHash, relay_d.outgoingCltv, packet_e, Some(blinding_e)) + val Left(failure) = decrypt(add_e, priv_e.privateKey, Features(RouteBlinding -> Optional)) + assert(failure.isInstanceOf[InvalidOnionBlinding]) + } + + test("fail to decrypt blinded payment at the final node when expiry is too low") { + val (route, recipient) = shortBlindedHops() + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_c.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, route, recipient) + assert(payment.outgoingChannel == channelUpdate_cd.shortChannelId) + assert(payment.cmd.cltvExpiry == expiry_cd) + + // A smaller expiry is sent to d, who doesn't know that it's invalid. + val add_d = UpdateAddHtlc(randomBytes32(), 0, payment.cmd.amount, paymentHash, expiry_de, payment.cmd.onion, payment.cmd.nextBlindingKey_opt) + val Right(relay_d@ChannelRelayPacket(_, payload_d, packet_e)) = decrypt(add_d, priv_d.privateKey, Features(RouteBlinding -> Optional)) + assert(payload_d.outgoingChannelId == channelUpdate_de.shortChannelId) + assert(relay_d.outgoingCltv < expiry_de) + assert(payload_d.isInstanceOf[IntermediatePayload.ChannelRelay.Blinded]) + val blinding_e = payload_d.asInstanceOf[IntermediatePayload.ChannelRelay.Blinded].nextBlinding + + // When e receives a smaller expiry than expected, it rejects the payment. + val add_e = UpdateAddHtlc(randomBytes32(), 0, relay_d.amountToForward, paymentHash, relay_d.outgoingCltv, packet_e, Some(blinding_e)) + val Left(failure) = decrypt(add_e, priv_e.privateKey, Features(RouteBlinding -> Optional)) + assert(failure.isInstanceOf[InvalidOnionBlinding]) + } + + test("fail to decrypt blinded payment at intermediate node when expiry is too high") { + val routeExpiry = expiry_de - channelUpdate_de.cltvExpiryDelta + val (route, recipient) = shortBlindedHops(routeExpiry) + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_c.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, route, recipient) + assert(payment.outgoingChannel == channelUpdate_cd.shortChannelId) + assert(payment.cmd.cltvExpiry > expiry_de) + + val add_d = UpdateAddHtlc(randomBytes32(), 0, payment.cmd.amount, paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion, payment.cmd.nextBlindingKey_opt) + val Left(failure) = decrypt(add_d, priv_d.privateKey, Features(RouteBlinding -> Optional)) + assert(failure.isInstanceOf[InvalidOnionBlinding]) + } + // Create a trampoline payment to e: // .----. // / \ @@ -340,7 +570,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { val invoiceFeatures = Features[InvoiceFeature](VariableLengthOnion -> Mandatory, PaymentSecret -> Mandatory, BasicMultiPartPayment -> Optional, TrampolinePaymentPrototype -> Optional) val invoice = Bolt11Invoice(Block.RegtestGenesisBlock.hash, None, paymentHash, priv_e.privateKey, Left("invoice"), CltvExpiryDelta(6), paymentSecret = paymentSecret, features = invoiceFeatures) val recipient = ClearTrampolineRecipient(invoice, finalAmount, finalExpiry, trampolineHop, randomBytes32()) - val Right(payment) = buildOutgoingPayment(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, Route(recipient.trampolineAmount, trampolineChannelHops, Some(trampolineHop)), recipient) + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, Route(recipient.trampolineAmount, trampolineChannelHops, Some(trampolineHop)), recipient) val add_b = UpdateAddHtlc(randomBytes32(), 1, payment.cmd.amount, paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion, None) val Right(ChannelRelayPacket(_, _, packet_c)) = decrypt(add_b, priv_b.privateKey, Features.empty) @@ -355,7 +585,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { // c forwards an invalid amount to e through (the outer total amount doesn't match the inner amount). val invalidTotalAmount = inner_c.amountToForward - 1.msat val recipient_e = ClearRecipient(e, Features.empty, invalidTotalAmount, inner_c.outgoingCltv, randomBytes32(), nextTrampolineOnion_opt = Some(trampolinePacket_e)) - val Right(payment_e) = buildOutgoingPayment(ActorRef.noSender, Upstream.Trampoline(Seq(add_c)), paymentHash, Route(invalidTotalAmount, afterTrampolineChannelHops, None), recipient_e) + val Right(payment_e) = buildOutgoingPayment(ActorRef.noSender, priv_c.privateKey, Upstream.Trampoline(Seq(add_c)), paymentHash, Route(invalidTotalAmount, afterTrampolineChannelHops, None), recipient_e) val add_d = UpdateAddHtlc(randomBytes32(), 3, payment_e.cmd.amount, paymentHash, payment_e.cmd.cltvExpiry, payment_e.cmd.onion, None) val Right(ChannelRelayPacket(_, payload_d, packet_e)) = decrypt(add_d, priv_d.privateKey, Features.empty) @@ -371,7 +601,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { // c forwards an invalid amount to e through (the outer expiry doesn't match the inner expiry). val invalidExpiry = inner_c.outgoingCltv - CltvExpiryDelta(12) val recipient_e = ClearRecipient(e, Features.empty, inner_c.amountToForward, invalidExpiry, randomBytes32(), nextTrampolineOnion_opt = Some(trampolinePacket_e)) - val Right(payment_e) = buildOutgoingPayment(ActorRef.noSender, Upstream.Trampoline(Seq(add_c)), paymentHash, Route(inner_c.amountToForward, afterTrampolineChannelHops, None), recipient_e) + val Right(payment_e) = buildOutgoingPayment(ActorRef.noSender, priv_c.privateKey, Upstream.Trampoline(Seq(add_c)), paymentHash, Route(inner_c.amountToForward, afterTrampolineChannelHops, None), recipient_e) val add_d = UpdateAddHtlc(randomBytes32(), 3, payment_e.cmd.amount, paymentHash, payment_e.cmd.cltvExpiry, payment_e.cmd.onion, None) val Right(ChannelRelayPacket(_, payload_d, packet_e)) = decrypt(add_d, priv_d.privateKey, Features.empty) @@ -395,6 +625,56 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { assert(failure == FinalIncorrectCltvExpiry(expiry_bc - CltvExpiryDelta(12))) } + test("build htlc failure onion") { + // a -> b -> c -> d -> e + val recipient = ClearRecipient(e, Features.empty, finalAmount, finalExpiry, paymentSecret) + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, Route(finalAmount, hops, None), recipient) + val add_b = UpdateAddHtlc(randomBytes32(), 0, amount_ab, paymentHash, expiry_ab, payment.cmd.onion, None) + val Right(ChannelRelayPacket(_, _, packet_c)) = decrypt(add_b, priv_b.privateKey, Features.empty) + val add_c = UpdateAddHtlc(randomBytes32(), 1, amount_bc, paymentHash, expiry_bc, packet_c, None) + val Right(ChannelRelayPacket(_, _, packet_d)) = decrypt(add_c, priv_c.privateKey, Features.empty) + val add_d = UpdateAddHtlc(randomBytes32(), 2, amount_cd, paymentHash, expiry_cd, packet_d, None) + val Right(ChannelRelayPacket(_, _, packet_e)) = decrypt(add_d, priv_d.privateKey, Features.empty) + val add_e = UpdateAddHtlc(randomBytes32(), 3, amount_de, paymentHash, expiry_de, packet_e, None) + val Right(FinalPacket(_, payload_e)) = decrypt(add_e, priv_e.privateKey, Features.empty) + assert(payload_e.isInstanceOf[FinalPayload.Standard]) + + // e returns a failure + val failure = IncorrectOrUnknownPaymentDetails(finalAmount, BlockHeight(currentBlockCount)) + val Right(fail_e) = buildHtlcFailure(priv_e.privateKey, CMD_FAIL_HTLC(add_e.id, Right(failure)), add_e) + assert(fail_e.id == add_e.id) + val Right(fail_d) = buildHtlcFailure(priv_d.privateKey, CMD_FAIL_HTLC(add_d.id, Left(fail_e.reason)), add_d) + assert(fail_d.id == add_d.id) + val Right(fail_c) = buildHtlcFailure(priv_c.privateKey, CMD_FAIL_HTLC(add_c.id, Left(fail_d.reason)), add_c) + assert(fail_c.id == add_c.id) + val Right(fail_b) = buildHtlcFailure(priv_b.privateKey, CMD_FAIL_HTLC(add_b.id, Left(fail_c.reason)), add_b) + assert(fail_b.id == add_b.id) + val Success(Sphinx.DecryptedFailurePacket(failingNode, decryptedFailure)) = Sphinx.FailurePacket.decrypt(fail_b.reason, payment.sharedSecrets) + assert(failingNode == e) + assert(decryptedFailure == failure) + } + + test("build htlc failure onion (blinded payment)") { + // a -> b -> c -> d -> e, blinded after c + val (_, route, recipient) = longBlindedHops(hex"0451") + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, route, recipient) + val add_b = UpdateAddHtlc(randomBytes32(), 0, payment.cmd.amount, payment.cmd.paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion, payment.cmd.nextBlindingKey_opt) + val Right(ChannelRelayPacket(_, _, packet_c)) = decrypt(add_b, priv_b.privateKey, Features.empty) + val add_c = UpdateAddHtlc(randomBytes32(), 1, amount_bc, paymentHash, expiry_bc, packet_c, None) + val Right(_: ChannelRelayPacket) = decrypt(add_c, priv_c.privateKey, Features(RouteBlinding -> Optional)) + + // only the introduction node is allowed to send an `update_fail_htlc` message: downstream nodes must send + // `update_fail_malformed_htlc` which doesn't use onion encryption + val failure = InvalidOnionBlinding(Sphinx.hash(add_c.onionRoutingPacket)) + val Right(fail_c) = buildHtlcFailure(priv_c.privateKey, CMD_FAIL_HTLC(add_c.id, Right(failure)), add_c) + assert(fail_c.id == add_c.id) + val Right(fail_b) = buildHtlcFailure(priv_b.privateKey, CMD_FAIL_HTLC(add_b.id, Left(fail_c.reason)), add_b) + assert(fail_b.id == add_b.id) + val Success(Sphinx.DecryptedFailurePacket(failingNode, decryptedFailure)) = Sphinx.FailurePacket.decrypt(fail_b.reason, payment.sharedSecrets) + assert(failingNode == c) + assert(decryptedFailure == failure) + } + } object PaymentPacketSpec { @@ -411,9 +691,9 @@ object PaymentPacketSpec { } } - def randomExtendedPrivateKey: ExtendedPrivateKey = DeterministicWallet.generate(randomBytes32()) + def randomExtendedPrivateKey(): ExtendedPrivateKey = DeterministicWallet.generate(randomBytes32()) - val (priv_a, priv_b, priv_c, priv_d, priv_e) = (TestConstants.Alice.nodeKeyManager.nodeKey, TestConstants.Bob.nodeKeyManager.nodeKey, randomExtendedPrivateKey, randomExtendedPrivateKey, randomExtendedPrivateKey) + val (priv_a, priv_b, priv_c, priv_d, priv_e) = (TestConstants.Alice.nodeKeyManager.nodeKey, TestConstants.Bob.nodeKeyManager.nodeKey, randomExtendedPrivateKey(), randomExtendedPrivateKey(), randomExtendedPrivateKey()) val (a, b, c, d, e) = (priv_a.publicKey, priv_b.publicKey, priv_c.publicKey, priv_d.publicKey, priv_e.publicKey) val sig = Crypto.sign(Crypto.sha256(ByteVector.empty), priv_a.privateKey) val defaultChannelUpdate = ChannelUpdate(sig, Block.RegtestGenesisBlock.hash, ShortChannelId(0), 0 unixsec, ChannelUpdate.MessageFlags(dontForward = false), ChannelUpdate.ChannelFlags.DUMMY, CltvExpiryDelta(0), 42000 msat, 0 msat, 0, 500_000_000 msat) @@ -453,6 +733,32 @@ object PaymentPacketSpec { val expiry_ab = expiry_bc + channelUpdate_bc.cltvExpiryDelta val amount_ab = amount_bc + fee_b + // fully blinded route a -> b + def singleBlindedHop(pathId: ByteVector = hex"deadbeef", routeExpiry: CltvExpiry = CltvExpiry(500_000)): (Route, BlindedRecipient) = { + val (_, blindedHop, recipient) = blindedRouteFromHops(finalAmount, finalExpiry, Seq(channelHopFromUpdate(a, b, channelUpdate_ab)), routeExpiry, paymentPreimage, pathId) + (Route(finalAmount, Nil, Some(blindedHop)), recipient) + } + + // route c -> d -> e, blinded after d + def shortBlindedHops(routeExpiry: CltvExpiry = CltvExpiry(500_000)): (Route, BlindedRecipient) = { + val (_, blindedHop, recipient) = blindedRouteFromHops(finalAmount, finalExpiry, Seq(channelHopFromUpdate(d, e, channelUpdate_de)), routeExpiry, paymentPreimage) + (Route(finalAmount, Seq(channelHopFromUpdate(c, d, channelUpdate_cd)), Some(blindedHop)), recipient) + } + + // route a -> b -> c -> d -> e, blinded after c + def longBlindedHops(pathId: ByteVector): (Bolt12Invoice, Route, BlindedRecipient) = { + val hopsToBlind = Seq( + channelHopFromUpdate(c, d, channelUpdate_cd), + channelHopFromUpdate(d, e, channelUpdate_de), + ) + val (invoice, blindedHop, recipient) = blindedRouteFromHops(finalAmount, finalExpiry, hopsToBlind, CltvExpiry(500_000), paymentPreimage, pathId) + val hops = Seq( + channelHopFromUpdate(a, b, channelUpdate_ab), + channelHopFromUpdate(b, c, channelUpdate_bc), + ) + (invoice, Route(finalAmount, hops, Some(blindedHop)), recipient) + } + // simple trampoline route to e: // .----. // / \ diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PostRestartHtlcCleanerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PostRestartHtlcCleanerSpec.scala index 9d7e66ab60..b1a58a4a95 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PostRestartHtlcCleanerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PostRestartHtlcCleanerSpec.scala @@ -633,8 +633,8 @@ class PostRestartHtlcCleanerSpec extends TestKitBaseClass with FixtureAnyFunSuit else if (htlcId == 1L) Some(nonRelayedHtlc2In.add) else None } - def localNodeId: PublicKey = randomExtendedPrivateKey.publicKey - def remoteNodeId: PublicKey = randomExtendedPrivateKey.publicKey + def localNodeId: PublicKey = randomExtendedPrivateKey().publicKey + def remoteNodeId: PublicKey = randomExtendedPrivateKey().publicKey def capacity: Satoshi = Long.MaxValue.sat def availableBalanceForReceive: MilliSatoshi = Long.MaxValue.msat def availableBalanceForSend: MilliSatoshi = 0.msat @@ -721,7 +721,7 @@ object PostRestartHtlcCleanerSpec { val (paymentHash1, paymentHash2, paymentHash3) = (Crypto.sha256(preimage1), Crypto.sha256(preimage2), Crypto.sha256(preimage3)) def buildHtlc(htlcId: Long, channelId: ByteVector32, paymentHash: ByteVector32): UpdateAddHtlc = { - val Right(payment) = buildOutgoingPayment(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, Route(finalAmount, hops, None), SpontaneousRecipient(e, finalAmount, finalExpiry, randomBytes32())) + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, Route(finalAmount, hops, None), SpontaneousRecipient(e, finalAmount, finalExpiry, randomBytes32())) UpdateAddHtlc(channelId, htlcId, payment.cmd.amount, paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion, None) } @@ -730,7 +730,7 @@ object PostRestartHtlcCleanerSpec { def buildHtlcOut(htlcId: Long, channelId: ByteVector32, paymentHash: ByteVector32): DirectedHtlc = OutgoingHtlc(buildHtlc(htlcId, channelId, paymentHash)) def buildFinalHtlc(htlcId: Long, channelId: ByteVector32, paymentHash: ByteVector32): DirectedHtlc = { - val Right(payment) = buildOutgoingPayment(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, Route(finalAmount, Seq(channelHopFromUpdate(a, b, channelUpdate_ab)), None), SpontaneousRecipient(b, finalAmount, finalExpiry, randomBytes32())) + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, Route(finalAmount, Seq(channelHopFromUpdate(a, b, channelUpdate_ab)), None), SpontaneousRecipient(b, finalAmount, finalExpiry, randomBytes32())) IncomingHtlc(UpdateAddHtlc(channelId, htlcId, payment.cmd.amount, paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion, None)) } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/RelayerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/RelayerSpec.scala index f0f07e895a..96c65a553d 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/RelayerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/RelayerSpec.scala @@ -90,7 +90,7 @@ class RelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("applicat } // we use this to build a valid onion - val Right(payment) = buildOutgoingPayment(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, Route(finalAmount, hops, None), ClearRecipient(e, Features.empty, finalAmount, finalExpiry, paymentSecret)) + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, Route(finalAmount, hops, None), ClearRecipient(e, Features.empty, finalAmount, finalExpiry, paymentSecret)) // and then manually build an htlc val add_ab = UpdateAddHtlc(randomBytes32(), 123456, payment.cmd.amount, payment.cmd.paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion, None) relayer ! RelayForward(add_ab) @@ -100,7 +100,7 @@ class RelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("applicat test("relay an htlc-add at the final node to the payment handler") { f => import f._ - val Right(payment) = buildOutgoingPayment(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, Route(finalAmount, hops.take(1), None), ClearRecipient(b, Features.empty, finalAmount, finalExpiry, paymentSecret)) + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, Route(finalAmount, hops.take(1), None), ClearRecipient(b, Features.empty, finalAmount, finalExpiry, paymentSecret)) val add_ab = UpdateAddHtlc(channelId_ab, 123456, payment.cmd.amount, payment.cmd.paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion, None) relayer ! RelayForward(add_ab) @@ -119,7 +119,7 @@ class RelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("applicat val finalTrampolinePayload = NodePayload(b, FinalPayload.Standard.createPayload(finalAmount, totalAmount, finalExpiry, paymentSecret)) val Right(trampolineOnion) = buildOnion(PaymentOnionCodecs.trampolineOnionPayloadLength, Seq(finalTrampolinePayload), paymentHash) val recipient = ClearRecipient(b, nodeParams.features.invoiceFeatures(), finalAmount, finalExpiry, randomBytes32(), nextTrampolineOnion_opt = Some(trampolineOnion.packet)) - val Right(payment) = buildOutgoingPayment(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, Route(finalAmount, Seq(channelHopFromUpdate(priv_a.publicKey, b, channelUpdate_ab)), None), recipient) + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, Route(finalAmount, Seq(channelHopFromUpdate(priv_a.publicKey, b, channelUpdate_ab)), None), recipient) assert(payment.cmd.amount == finalAmount) assert(payment.cmd.cltvExpiry == finalExpiry) val add_ab = UpdateAddHtlc(channelId_ab, 123456, payment.cmd.amount, payment.cmd.paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion, None) @@ -140,7 +140,7 @@ class RelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("applicat import f._ // we use this to build a valid onion - val Right(payment) = buildOutgoingPayment(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, Route(finalAmount, hops, None), ClearRecipient(e, Features.empty, finalAmount, finalExpiry, paymentSecret)) + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, Route(finalAmount, hops, None), ClearRecipient(e, Features.empty, finalAmount, finalExpiry, paymentSecret)) // and then manually build an htlc with an invalid onion (hmac) val add_ab = UpdateAddHtlc(channelId_ab, 123456, payment.cmd.amount, payment.cmd.paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion.copy(hmac = payment.cmd.onion.hmac.reverse), None) relayer ! RelayForward(add_ab) @@ -161,7 +161,7 @@ class RelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("applicat val invoice = Bolt11Invoice(Block.RegtestGenesisBlock.hash, None, paymentHash, priv_c.privateKey, Left("invoice"), CltvExpiryDelta(6), paymentSecret = paymentSecret, features = invoiceFeatures) val trampolineHop = NodeHop(b, c, channelUpdate_bc.cltvExpiryDelta, fee_b) val recipient = ClearTrampolineRecipient(invoice, finalAmount, finalExpiry, trampolineHop, randomBytes32()) - val Right(payment) = buildOutgoingPayment(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, Route(recipient.trampolineAmount, Seq(channelHopFromUpdate(priv_a.publicKey, b, channelUpdate_ab)), Some(trampolineHop)), recipient) + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, Route(recipient.trampolineAmount, Seq(channelHopFromUpdate(priv_a.publicKey, b, channelUpdate_ab)), Some(trampolineHop)), recipient) // and then manually build an htlc val add_ab = UpdateAddHtlc(channelId_ab, 123456, payment.cmd.amount, payment.cmd.paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion, None) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/router/BaseRouterSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/router/BaseRouterSpec.scala index 3aef6d1cf7..259a038d58 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/router/BaseRouterSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/router/BaseRouterSpec.scala @@ -29,10 +29,13 @@ import fr.acinq.eclair.channel._ import fr.acinq.eclair.crypto.TransportHandler import fr.acinq.eclair.crypto.keymanager.{LocalChannelKeyManager, LocalNodeKeyManager} import fr.acinq.eclair.io.Peer.PeerRoutingMessage +import fr.acinq.eclair.payment.send.BlindedRecipient +import fr.acinq.eclair.payment.{Bolt12Invoice, PaymentBlindedRoute} import fr.acinq.eclair.router.Announcements._ import fr.acinq.eclair.router.BaseRouterSpec.channelAnnouncement import fr.acinq.eclair.router.Router._ import fr.acinq.eclair.transactions.Scripts +import fr.acinq.eclair.wire.protocol.OfferTypes.{InvoiceRequest, Offer} import fr.acinq.eclair.wire.protocol._ import org.scalatest.Outcome import org.scalatest.funsuite.FixtureAnyFunSuiteLike @@ -225,9 +228,11 @@ abstract class BaseRouterSpec extends TestKitBaseClass with FixtureAnyFunSuiteLi withFixture(test.toNoArgTest(FixtureParam(nodeParams, router, watcher))) } } + } object BaseRouterSpec { + def channelAnnouncement(shortChannelId: RealShortChannelId, node1_priv: PrivateKey, node2_priv: PrivateKey, funding1_priv: PrivateKey, funding2_priv: PrivateKey) = { val witness = Announcements.generateChannelAnnouncementWitness(Block.RegtestGenesisBlock.hash, shortChannelId, node1_priv.publicKey, node2_priv.publicKey, funding1_priv.publicKey, funding2_priv.publicKey, Features.empty) val node1_sig = Announcements.signChannelAnnouncement(witness, node1_priv) @@ -240,4 +245,39 @@ object BaseRouterSpec { def channelHopFromUpdate(nodeId: PublicKey, nextNodeId: PublicKey, channelUpdate: ChannelUpdate): ChannelHop = { ChannelHop(channelUpdate.shortChannelId, nodeId, nextNodeId, HopRelayParams.FromAnnouncement(channelUpdate)) } + + def blindedRouteFromHops(amount: MilliSatoshi, + expiry: CltvExpiry, + hops: Seq[ChannelHop], + routeExpiry: CltvExpiry, + preimage: ByteVector32 = randomBytes32(), + pathId: ByteVector = randomBytes(32)): (Bolt12Invoice, BlindedHop, BlindedRecipient) = { + val (invoice, recipient) = blindedRoutesFromPaths(amount, expiry, Seq(hops), routeExpiry, preimage, pathId) + (invoice, recipient.blindedHops.head, recipient) + } + + def blindedRoutesFromPaths(amount: MilliSatoshi, + expiry: CltvExpiry, + paths: Seq[Seq[ChannelHop]], + routeExpiry: CltvExpiry, + preimage: ByteVector32 = randomBytes32(), + pathId: ByteVector = randomBytes(32)): (Bolt12Invoice, BlindedRecipient) = { + val recipientKey = randomKey() + val features = Features[InvoiceFeature]( + Features.VariableLengthOnion -> FeatureSupport.Mandatory, + Features.BasicMultiPartPayment -> FeatureSupport.Optional, + Features.RouteBlinding -> FeatureSupport.Mandatory + ) + val offer = Offer(None, "Bolt12 r0cks", recipientKey.publicKey, features, Block.RegtestGenesisBlock.hash) + val invoiceRequest = InvoiceRequest(offer, amount, 1, features, randomKey(), Block.RegtestGenesisBlock.hash) + val blindedRoutes = paths.map(hops => { + val blindedRoute = BlindedRouteCreation.createBlindedRouteFromHops(hops, pathId, 1 msat, routeExpiry).route + val paymentInfo = BlindedRouteCreation.aggregatePaymentInfo(amount, hops) + PaymentBlindedRoute(blindedRoute, paymentInfo) + }) + val invoice = Bolt12Invoice(offer, invoiceRequest, preimage, recipientKey, CltvExpiryDelta(6), features, blindedRoutes) + val recipient = BlindedRecipient(invoice, amount, expiry, Nil) + (invoice, recipient) + } + } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/router/RouterSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/router/RouterSpec.scala index cefde92823..95a0fe4707 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/router/RouterSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/router/RouterSpec.scala @@ -27,10 +27,11 @@ import fr.acinq.eclair.channel.{AvailableBalanceChanged, CommitmentsSpec, LocalC import fr.acinq.eclair.crypto.TransportHandler import fr.acinq.eclair.io.Peer.PeerRoutingMessage import fr.acinq.eclair.payment.Bolt11Invoice.ExtraHop +import fr.acinq.eclair.payment.Invoice.ExtraEdge import fr.acinq.eclair.payment.send.{ClearRecipient, ClearTrampolineRecipient, SpontaneousRecipient} import fr.acinq.eclair.payment.{Bolt11Invoice, Invoice} import fr.acinq.eclair.router.Announcements.{makeChannelUpdate, makeNodeAnnouncement} -import fr.acinq.eclair.router.BaseRouterSpec.channelAnnouncement +import fr.acinq.eclair.router.BaseRouterSpec.{blindedRoutesFromPaths, channelAnnouncement} import fr.acinq.eclair.router.Graph.RoutingHeuristics import fr.acinq.eclair.router.RouteCalculationSpec.{DEFAULT_AMOUNT_MSAT, DEFAULT_EXPIRY, DEFAULT_ROUTE_PARAMS, route2NodeIds} import fr.acinq.eclair.router.Router._ @@ -479,6 +480,84 @@ class RouterSpec extends BaseRouterSpec { assert(route2.finalHop_opt.contains(trampolineHop)) } + test("routes found (with blinded hops)") { fixture => + import fixture._ + val sender = TestProbe() + val r = randomKey().publicKey + val hopsToRecipient = Seq( + ChannelHop(ShortChannelId(10000), b, r, HopRelayParams.FromHint(ExtraEdge(b, r, ShortChannelId(10000), 800 msat, 0, CltvExpiryDelta(36), 1 msat, Some(400_000 msat)))) :: Nil, + ChannelHop(ShortChannelId(10001), c, r, HopRelayParams.FromHint(ExtraEdge(c, r, ShortChannelId(10001), 500 msat, 0, CltvExpiryDelta(36), 1 msat, Some(400_000 msat)))) :: Nil, + ) + + { + // Amount split between both blinded routes: + val (_, recipient) = blindedRoutesFromPaths(600_000 msat, DEFAULT_EXPIRY, hopsToRecipient, DEFAULT_EXPIRY) + sender.send(router, RouteRequest(a, recipient, DEFAULT_ROUTE_PARAMS, allowMultiPart = true)) + val routes = sender.expectMsgType[RouteResponse].routes + assert(routes.length == 2) + assert(routes.flatMap(_.finalHop_opt) == recipient.blindedHops) + assert(routes.map(route => route2NodeIds(route)).toSet == Set(Seq(a, b), Seq(a, b, c))) + assert(routes.map(route => route.blindedFee + route.channelFee(false)).toSet == Set(510 msat, 800 msat)) + } + { + // One blinded route is ignored, we use the other one: + val (_, recipient) = blindedRoutesFromPaths(300_000 msat, DEFAULT_EXPIRY, hopsToRecipient, DEFAULT_EXPIRY) + val ignored = Ignore(Set.empty, Set(ChannelDesc(recipient.extraEdges.last.shortChannelId, recipient.extraEdges.last.sourceNodeId, recipient.extraEdges.last.targetNodeId))) + sender.send(router, RouteRequest(a, recipient, DEFAULT_ROUTE_PARAMS, ignore = ignored)) + val routes = sender.expectMsgType[RouteResponse].routes + assert(routes.length == 1) + assert(routes.head.finalHop_opt.nonEmpty) + assert(route2NodeIds(routes.head) == Seq(a, b)) + assert(routes.head.blindedFee == 800.msat) + } + { + // One blinded route is ignored, the other one doesn't have enough capacity: + val (_, recipient) = blindedRoutesFromPaths(500_000 msat, DEFAULT_EXPIRY, hopsToRecipient, DEFAULT_EXPIRY) + val ignored = Ignore(Set.empty, Set(ChannelDesc(recipient.extraEdges.last.shortChannelId, recipient.extraEdges.last.sourceNodeId, recipient.extraEdges.last.targetNodeId))) + sender.send(router, RouteRequest(a, recipient, DEFAULT_ROUTE_PARAMS, allowMultiPart = true, ignore = ignored)) + sender.expectMsg(Failure(RouteNotFound)) + } + { + // One blinded route is pending, we use the other one: + val (_, recipient) = blindedRoutesFromPaths(600_000 msat, DEFAULT_EXPIRY, hopsToRecipient, DEFAULT_EXPIRY) + sender.send(router, RouteRequest(a, recipient, DEFAULT_ROUTE_PARAMS, allowMultiPart = true)) + val routes1 = sender.expectMsgType[RouteResponse].routes + assert(routes1.length == 2) + sender.send(router, RouteRequest(a, recipient, DEFAULT_ROUTE_PARAMS, allowMultiPart = true, pendingPayments = Seq(routes1.head))) + val routes2 = sender.expectMsgType[RouteResponse].routes + assert(routes2 == routes1.tail) + } + { + // One blinded route is pending, we send two htlcs to the other one: + val (_, recipient) = blindedRoutesFromPaths(600_000 msat, DEFAULT_EXPIRY, hopsToRecipient, DEFAULT_EXPIRY) + sender.send(router, RouteRequest(a, recipient, DEFAULT_ROUTE_PARAMS, allowMultiPart = true)) + val routes1 = sender.expectMsgType[RouteResponse].routes + assert(routes1.length == 2) + sender.send(router, RouteRequest(a, recipient, DEFAULT_ROUTE_PARAMS, allowMultiPart = true, pendingPayments = Seq(routes1.head))) + val routes2 = sender.expectMsgType[RouteResponse].routes + assert(routes2 == routes1.tail) + sender.send(router, RouteRequest(a, recipient, DEFAULT_ROUTE_PARAMS, allowMultiPart = true, pendingPayments = Seq(routes1.head, routes2.head.copy(amount = routes2.head.amount - 25_000.msat)))) + val routes3 = sender.expectMsgType[RouteResponse].routes + assert(routes3.length == 1) + assert(routes3.head.amount == 25_000.msat) + } + { + // One blinded route is pending, we cannot use the other one because of the fee budget: + val (_, recipient) = blindedRoutesFromPaths(600_000 msat, DEFAULT_EXPIRY, hopsToRecipient, DEFAULT_EXPIRY) + val routeParams1 = DEFAULT_ROUTE_PARAMS.copy(boundaries = SearchBoundaries(5000 msat, 0.0, 6, CltvExpiryDelta(1008))) + sender.send(router, RouteRequest(a, recipient, routeParams1, allowMultiPart = true)) + val routes1 = sender.expectMsgType[RouteResponse].routes + assert(routes1.length == 2) + assert(routes1.head.blindedFee + routes1.head.channelFee(false) == 800.msat) + val routeParams2 = DEFAULT_ROUTE_PARAMS.copy(boundaries = SearchBoundaries(1000 msat, 0.0, 6, CltvExpiryDelta(1008))) + sender.send(router, RouteRequest(a, recipient, routeParams2, allowMultiPart = true, pendingPayments = Seq(routes1.head))) + sender.expectMsg(Failure(RouteNotFound)) + val routeParams3 = DEFAULT_ROUTE_PARAMS.copy(boundaries = SearchBoundaries(1500 msat, 0.0, 6, CltvExpiryDelta(1008))) + sender.send(router, RouteRequest(a, recipient, routeParams3, allowMultiPart = true, pendingPayments = Seq(routes1.head))) + assert(sender.expectMsgType[RouteResponse].routes.length == 1) + } + } + test("route not found (channel disabled)") { fixture => import fixture._ val sender = TestProbe() From 855ce3b5ab0e54a70549a8bc62d1c941386ab2b1 Mon Sep 17 00:00:00 2001 From: t-bast Date: Fri, 16 Dec 2022 16:29:57 +0100 Subject: [PATCH 2/2] fixup! Send payments to blinded routes --- eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala b/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala index 6101ab5693..ed4f63a023 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala @@ -566,10 +566,7 @@ class EclairImpl(appKit: Kit) extends Eclair with Logging { randomKey(), randomKey(), intermediateNodes.map(OnionMessages.IntermediateNode(_)), - destination match { - case Left(key) => OnionMessages.Recipient(key, None) - case Right(route) => OnionMessages.BlindedPath(route) - }, + destination match { case Left(key) => OnionMessages.Recipient(key, None) case Right(route) => OnionMessages.BlindedPath(route) }, replyRoute.map(OnionMessagePayloadTlv.ReplyPath(_) :: Nil).getOrElse(Nil), userCustomTlvs) match { case Success((nextNodeId, message)) =>