diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/crypto/Sphinx.scala b/eclair-core/src/main/scala/fr/acinq/eclair/crypto/Sphinx.scala index bb50bdd555..1f9f27f37e 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/crypto/Sphinx.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/crypto/Sphinx.scala @@ -367,6 +367,7 @@ object Sphinx extends Logging { val subsequentNodes: Seq[BlindedNode] = blindedNodes.tail val blindedNodeIds: Seq[PublicKey] = blindedNodes.map(_.blindedPublicKey) val encryptedPayloads: Seq[ByteVector] = blindedNodes.map(_.encryptedPayload) + val length: Int = blindedNodes.length - 1 } /** diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala index de4398403e..5f24180c9d 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala @@ -19,7 +19,7 @@ package fr.acinq.eclair.db import fr.acinq.bitcoin.scalacompat.ByteVector32 import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.eclair.payment._ -import fr.acinq.eclair.router.Router.{ChannelHop, Hop, NodeHop} +import fr.acinq.eclair.router.Router.{BlindedHop, ChannelHop, Hop, NodeHop} import fr.acinq.eclair.{MilliSatoshi, Paginated, ShortChannelId, TimestampMilli} import scodec.bits.ByteVector @@ -226,6 +226,7 @@ object HopSummary { def apply(h: Hop): HopSummary = { val shortChannelId = h match { case ch: ChannelHop => Some(ch.shortChannelId) + case _: BlindedHop => None case _: NodeHop => None } HopSummary(h.nodeId, h.nextNodeId, shortChannelId) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/json/JsonSerializers.scala b/eclair-core/src/main/scala/fr/acinq/eclair/json/JsonSerializers.scala index 4fb7ea15bc..84db12f7b5 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/json/JsonSerializers.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/json/JsonSerializers.scala @@ -31,7 +31,7 @@ import fr.acinq.eclair.io.Peer import fr.acinq.eclair.message.OnionMessages import fr.acinq.eclair.payment.PaymentFailure.PaymentFailedSummary import fr.acinq.eclair.payment._ -import fr.acinq.eclair.router.Router.{HopRelayParams, NodeHop, Route} +import fr.acinq.eclair.router.Router._ import fr.acinq.eclair.transactions.DirectedHtlc import fr.acinq.eclair.transactions.Transactions._ import fr.acinq.eclair.wire.protocol.MessageOnionCodecs.blindedRouteCodec @@ -296,12 +296,14 @@ object ColorSerializer extends MinimalSerializer({ // @formatter:off private sealed trait HopJson private case class ChannelHopJson(nodeId: PublicKey, nextNodeId: PublicKey, source: HopRelayParams) extends HopJson +private case class BlindedHopJson(nodeId: PublicKey, nextNodeId: PublicKey, paymentInfo: OfferTypes.PaymentInfo) extends HopJson private case class NodeHopJson(nodeId: PublicKey, nextNodeId: PublicKey, fee: MilliSatoshi, cltvExpiryDelta: CltvExpiryDelta) extends HopJson private case class RouteFullJson(amount: MilliSatoshi, hops: Seq[HopJson]) object RouteFullSerializer extends ConvertClassSerializer[Route](route => { val channelHops = route.hops.map(h => ChannelHopJson(h.nodeId, h.nextNodeId, h.params)) val finalHop_opt = route.finalHop_opt.map { case h: NodeHop => NodeHopJson(h.nodeId, h.nextNodeId, h.fee, h.cltvExpiryDelta) + case h: BlindedHop => BlindedHopJson(h.nodeId, h.nextNodeId, h.paymentInfo) } RouteFullJson(route.amount, channelHops ++ finalHop_opt.toSeq) }) @@ -315,6 +317,8 @@ object RouteNodeIdsSerializer extends ConvertClassSerializer[Route](route => { val finalNodeIds = route.finalHop_opt match { case Some(hop: NodeHop) if channelNodeIds.nonEmpty => Seq(hop.nextNodeId) case Some(hop: NodeHop) => Seq(hop.nodeId, hop.nextNodeId) + case Some(hop: BlindedHop) if channelNodeIds.nonEmpty => hop.route.blindedNodeIds.tail + case Some(hop: BlindedHop) => hop.route.introductionNodeId +: hop.route.blindedNodeIds.tail case None => Nil } RouteNodeIdsJson(route.amount, channelNodeIds ++ finalNodeIds) @@ -325,6 +329,7 @@ object RouteShortChannelIdsSerializer extends ConvertClassSerializer[Route](rout val hops = route.hops.map(_.shortChannelId) val finalHop = route.finalHop_opt.map { case _: NodeHop => "trampoline" + case _: BlindedHop => "blinded" } RouteShortChannelIdsJson(route.amount, hops, finalHop) }) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentEvents.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentEvents.scala index 98a579785c..cb4d139c00 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentEvents.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentEvents.scala @@ -24,8 +24,8 @@ import fr.acinq.eclair.payment.send.PaymentError.RetryExhausted import fr.acinq.eclair.payment.send.PaymentInitiator.SendPaymentConfig import fr.acinq.eclair.payment.send.{ClearRecipient, Recipient} import fr.acinq.eclair.router.Announcements -import fr.acinq.eclair.router.Router.{ChannelDesc, ChannelHop, Hop, Ignore} -import fr.acinq.eclair.wire.protocol.{ChannelDisabled, ChannelUpdate, Node, TemporaryChannelFailure} +import fr.acinq.eclair.router.Router._ +import fr.acinq.eclair.wire.protocol._ import fr.acinq.eclair.{MilliSatoshi, ShortChannelId, TimestampMilli} import scodec.bits.ByteVector @@ -183,11 +183,15 @@ object PaymentFailure { .isDefined /** Ignore the channel outgoing from the given nodeId in the given route. */ - private def ignoreNodeOutgoingChannel(nodeId: PublicKey, hops: Seq[Hop], ignore: Ignore): Ignore = { + private def ignoreNodeOutgoingEdge(nodeId: PublicKey, hops: Seq[Hop], ignore: Ignore): Ignore = { hops.collectFirst { case hop: ChannelHop if hop.nodeId == nodeId => ChannelDesc(hop.shortChannelId, hop.nodeId, hop.nextNodeId) + case hop: BlindedHop if hop.nodeId == nodeId => ChannelDesc(hop.dummyId, hop.nodeId, hop.nextNodeId) + // The error comes from inside the blinded route: this is a spec violation, errors should always come from the + // introduction node, so we definitely want to ignore this blinded route when this happens. + case hop: BlindedHop if hop.route.blindedNodeIds.contains(nodeId) => ChannelDesc(hop.dummyId, hop.nodeId, hop.nextNodeId) } match { - case Some(faultyChannel) => ignore + faultyChannel + case Some(faultyEdge) => ignore + faultyEdge case None => ignore } } @@ -207,7 +211,7 @@ object PaymentFailure { case _ => false } if (shouldIgnore) { - ignoreNodeOutgoingChannel(nodeId, hops, ignore) + ignoreNodeOutgoingEdge(nodeId, hops, ignore) } else { // We were using an outdated channel update, we should retry with the new one and nobody should be penalized. ignore @@ -217,10 +221,14 @@ object PaymentFailure { ignore + nodeId } case RemoteFailure(_, hops, Sphinx.DecryptedFailurePacket(nodeId, _)) => - ignoreNodeOutgoingChannel(nodeId, hops, ignore) + ignoreNodeOutgoingEdge(nodeId, hops, ignore) case UnreadableRemoteFailure(_, hops) => - // We don't know which node is sending garbage, let's blacklist all nodes except the one we are directly connected to and the final recipient. - val blacklist = hops.map(_.nextNodeId).drop(1).dropRight(1).toSet + // We don't know which node is sending garbage, let's blacklist all nodes except: + // - the one we are directly connected to: it would be too restrictive for retries + // - the final recipient: they have no incentive to send garbage since they want that payment + // - the introduction point of a blinded route: we don't want a node before the blinded path to force us to ignore that blinded path + // - the trampoline node: we don't want a node before the trampoline node to force us to ignore that trampoline node + val blacklist = hops.collect { case hop: ChannelHop => hop }.map(_.nextNodeId).drop(1).dropRight(1).toSet ignore ++ blacklist case LocalFailure(_, hops, _) => hops.headOption match { case Some(hop: ChannelHop) => diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentPacket.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentPacket.scala index a063d74c7b..9410dbce35 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentPacket.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentPacket.scala @@ -23,7 +23,7 @@ import fr.acinq.bitcoin.scalacompat.Crypto.{PrivateKey, PublicKey} import fr.acinq.eclair.channel.{CMD_ADD_HTLC, CMD_FAIL_HTLC, CannotExtractSharedSecret, Origin} import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.payment.send.Recipient -import fr.acinq.eclair.router.Router.Route +import fr.acinq.eclair.router.Router.{BlindedHop, ChannelHop, Route} import fr.acinq.eclair.wire.protocol.PaymentOnion.{FinalPayload, IntermediatePayload, PerHopPayload} import fr.acinq.eclair.wire.protocol._ import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Feature, Features, MilliSatoshi, ShortChannelId, UInt64, randomKey} @@ -182,6 +182,8 @@ object IncomingPaymentPacket { case payload if add.amountMsat < payload.paymentConstraints.minAmount => Left(InvalidOnionBlinding(Sphinx.hash(add.onionRoutingPacket))) case payload if add.cltvExpiry > payload.paymentConstraints.maxCltvExpiry => Left(InvalidOnionBlinding(Sphinx.hash(add.onionRoutingPacket))) case payload if !Features.areCompatible(Features.empty, payload.allowedFeatures) => Left(InvalidOnionBlinding(Sphinx.hash(add.onionRoutingPacket))) + case payload if add.amountMsat < payload.amount => Left(InvalidOnionBlinding(Sphinx.hash(add.onionRoutingPacket))) + case payload if add.cltvExpiry < payload.expiry => Left(InvalidOnionBlinding(Sphinx.hash(add.onionRoutingPacket))) case payload => Right(FinalPacket(add, payload)) } } @@ -231,8 +233,10 @@ object OutgoingPaymentPacket { sealed trait OutgoingPaymentError extends Throwable case class CannotCreateOnion(message: String) extends OutgoingPaymentError { override def getMessage: String = message } + case class CannotDecryptBlindedRoute(message: String) extends OutgoingPaymentError { override def getMessage: String = message } case class InvalidRouteRecipient(expected: PublicKey, actual: PublicKey) extends OutgoingPaymentError { override def getMessage: String = s"expected route to $expected, got route to $actual" } case class MissingTrampolineHop(trampolineNodeId: PublicKey) extends OutgoingPaymentError { override def getMessage: String = s"expected route to trampoline node $trampolineNodeId" } + case class MissingBlindedHop(introductionNodeIds: Set[PublicKey]) extends OutgoingPaymentError { override def getMessage: String = s"expected blinded route using one of the following introduction nodes: ${introductionNodeIds.mkString(", ")}" } case object EmptyRoute extends OutgoingPaymentError { override def getMessage: String = "route cannot be empty" } sealed trait Upstream @@ -261,15 +265,41 @@ object OutgoingPaymentPacket { } } + private case class OutgoingPaymentWithChannel(shortChannelId: ShortChannelId, nextBlinding_opt: Option[PublicKey], payment: PaymentPayloads) + + private def getOutgoingChannel(privateKey: PrivateKey, payment: PaymentPayloads, route: Route): Either[OutgoingPaymentError, OutgoingPaymentWithChannel] = { + route.hops.headOption match { + case Some(hop) => Right(OutgoingPaymentWithChannel(hop.shortChannelId, None, payment)) + case None => route.finalHop_opt match { + case Some(hop: BlindedHop) => + // We are the introduction node of the blinded route: we need to decrypt the first payload. + val firstBlinding = hop.route.introductionNode.blindingEphemeralKey + val firstEncryptedPayload = hop.route.introductionNode.encryptedPayload + RouteBlindingEncryptedDataCodecs.decode(privateKey, firstBlinding, firstEncryptedPayload) match { + case Left(e) => Left(CannotDecryptBlindedRoute(e.message)) + case Right(decoded) => + val tlvs = TlvStream[OnionPaymentPayloadTlv](OnionPaymentPayloadTlv.EncryptedRecipientData(firstEncryptedPayload), OnionPaymentPayloadTlv.BlindingPoint(firstBlinding)) + IntermediatePayload.ChannelRelay.Blinded.validate(tlvs, decoded.tlvs, decoded.nextBlinding) match { + case Left(e) => Left(CannotDecryptBlindedRoute(e.failureMessage.message)) + case Right(payload) => + val payment1 = PaymentPayloads(payload.amountToForward(payment.amount), payload.outgoingCltv(payment.expiry), payment.payloads.tail) + Right(OutgoingPaymentWithChannel(payload.outgoingChannelId, Some(decoded.nextBlinding), payment1)) + } + } + case _ => Left(EmptyRoute) + } + } + } + /** Build the command to add an HTLC for the given recipient using the provided route. */ - def buildOutgoingPayment(replyTo: ActorRef, upstream: Upstream, paymentHash: ByteVector32, route: Route, recipient: Recipient): Either[OutgoingPaymentError, OutgoingPaymentPacket] = { - val outgoingChannel = route.hops.head.shortChannelId + def buildOutgoingPayment(replyTo: ActorRef, privateKey: PrivateKey, upstream: Upstream, paymentHash: ByteVector32, route: Route, recipient: Recipient): Either[OutgoingPaymentError, OutgoingPaymentPacket] = { for { - payment <- recipient.buildPayloads(paymentHash, route) - onion <- buildOnion(PaymentOnionCodecs.paymentOnionPayloadLength, payment.payloads, paymentHash) // BOLT 2 requires that associatedData == paymentHash + paymentTmp <- recipient.buildPayloads(paymentHash, route) + outgoing <- getOutgoingChannel(privateKey, paymentTmp, route) + onion <- buildOnion(PaymentOnionCodecs.paymentOnionPayloadLength, outgoing.payment.payloads, paymentHash) // BOLT 2 requires that associatedData == paymentHash } yield { - val cmd = CMD_ADD_HTLC(replyTo, payment.amount, paymentHash, payment.expiry, onion.packet, None, Origin.Hot(replyTo, upstream), commit = true) - OutgoingPaymentPacket(cmd, outgoingChannel, onion.sharedSecrets) + val cmd = CMD_ADD_HTLC(replyTo, outgoing.payment.amount, paymentHash, outgoing.payment.expiry, onion.packet, outgoing.nextBlinding_opt, Origin.Hot(replyTo, upstream), commit = true) + OutgoingPaymentPacket(cmd, outgoing.shortChannelId, onion.sharedSecrets) } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/MultiPartPaymentLifecycle.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/MultiPartPaymentLifecycle.scala index 68574d0d0f..97fb1d8721 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/MultiPartPaymentLifecycle.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/MultiPartPaymentLifecycle.scala @@ -97,7 +97,7 @@ class MultiPartPaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, if (cfg.storeInDb && d.pending.isEmpty && d.failures.isEmpty) { // In cases where we fail early (router error during the first attempt), the DB won't have an entry for that // payment, which may be confusing for users. - val dummyPayment = OutgoingPayment(id, cfg.parentId, cfg.externalId, paymentHash, PaymentType.Standard, d.request.recipient.totalAmount, d.request.recipient.totalAmount, d.request.recipient.nodeId, TimestampMilli.now(), cfg.invoice, OutgoingPaymentStatus.Pending) + val dummyPayment = OutgoingPayment(id, cfg.parentId, cfg.externalId, paymentHash, cfg.paymentType, d.request.recipient.totalAmount, d.request.recipient.totalAmount, d.request.recipient.nodeId, TimestampMilli.now(), cfg.invoice, OutgoingPaymentStatus.Pending) nodeParams.db.payments.addOutgoingPayment(dummyPayment) nodeParams.db.payments.updateOutgoingPayment(PaymentFailed(id, paymentHash, failure :: Nil)) } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentInitiator.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentInitiator.scala index 2a05451dc3..f0fc579f2d 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentInitiator.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentInitiator.scala @@ -21,6 +21,7 @@ import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.bitcoin.scalacompat.{ByteVector32, Crypto} import fr.acinq.eclair.channel.fsm.Channel import fr.acinq.eclair.crypto.Sphinx +import fr.acinq.eclair.db.PaymentType import fr.acinq.eclair.payment.OutgoingPaymentPacket.Upstream import fr.acinq.eclair.payment._ import fr.acinq.eclair.payment.send.PaymentError._ @@ -50,22 +51,20 @@ class PaymentInitiator(nodeParams: NodeParams, outgoingPaymentFactory: PaymentIn } val paymentCfg = SendPaymentConfig(paymentId, paymentId, r.externalId, r.paymentHash, r.invoice.nodeId, Upstream.Local(paymentId), Some(r.invoice), storeInDb = true, publishEvent = true, recordPathFindingMetrics = true) val finalExpiry = r.finalExpiry(nodeParams) - r.invoice match { - case invoice: Bolt11Invoice => - val recipient = ClearRecipient(invoice, r.recipientAmount, finalExpiry, r.userCustomTlvs) - if (!nodeParams.features.invoiceFeatures().areSupported(recipient.features)) { - sender() ! PaymentFailed(paymentId, r.paymentHash, LocalFailure(r.recipientAmount, Nil, UnsupportedFeatures(recipient.features)) :: Nil) - } else if (Features.canUseFeature(nodeParams.features.invoiceFeatures(), recipient.features, Features.BasicMultiPartPayment)) { - val fsm = outgoingPaymentFactory.spawnOutgoingMultiPartPayment(context, paymentCfg) - fsm ! MultiPartPaymentLifecycle.SendMultiPartPayment(self, recipient, r.maxAttempts, r.routeParams) - context become main(pending + (paymentId -> PendingPaymentToNode(sender(), r))) - } else { - val fsm = outgoingPaymentFactory.spawnOutgoingPayment(context, paymentCfg) - fsm ! PaymentLifecycle.SendPaymentToNode(self, recipient, r.maxAttempts, r.routeParams) - context become main(pending + (paymentId -> PendingPaymentToNode(sender(), r))) - } - case _: Bolt12Invoice => - sender() ! PaymentFailed(paymentId, r.paymentHash, LocalFailure(r.recipientAmount, Nil, new IllegalArgumentException("payments to Bolt12 invoices are not supported yet")) :: Nil) + val recipient = r.invoice match { + case invoice: Bolt11Invoice => ClearRecipient(invoice, r.recipientAmount, finalExpiry, r.userCustomTlvs) + case invoice: Bolt12Invoice => BlindedRecipient(invoice, r.recipientAmount, finalExpiry, r.userCustomTlvs) + } + if (!nodeParams.features.invoiceFeatures().areSupported(recipient.features)) { + sender() ! PaymentFailed(paymentId, r.paymentHash, LocalFailure(r.recipientAmount, Nil, UnsupportedFeatures(recipient.features)) :: Nil) + } else if (Features.canUseFeature(nodeParams.features.invoiceFeatures(), recipient.features, Features.BasicMultiPartPayment)) { + val fsm = outgoingPaymentFactory.spawnOutgoingMultiPartPayment(context, paymentCfg) + fsm ! MultiPartPaymentLifecycle.SendMultiPartPayment(self, recipient, r.maxAttempts, r.routeParams) + context become main(pending + (paymentId -> PendingPaymentToNode(sender(), r))) + } else { + val fsm = outgoingPaymentFactory.spawnOutgoingPayment(context, paymentCfg) + fsm ! PaymentLifecycle.SendPaymentToNode(self, recipient, r.maxAttempts, r.routeParams) + context become main(pending + (paymentId -> PendingPaymentToNode(sender(), r))) } case r: SendSpontaneousPayment => @@ -119,18 +118,16 @@ class PaymentInitiator(nodeParams: NodeParams, outgoingPaymentFactory: PaymentIn sender() ! PaymentFailed(paymentId, r.paymentHash, LocalFailure(r.recipientAmount, Nil, t) :: Nil) } case None => - r.invoice match { - case invoice: Bolt11Invoice => - sender() ! SendPaymentToRouteResponse(paymentId, parentPaymentId, None) - val paymentCfg = SendPaymentConfig(paymentId, parentPaymentId, r.externalId, r.paymentHash, r.recipientNodeId, Upstream.Local(paymentId), Some(r.invoice), storeInDb = true, publishEvent = true, recordPathFindingMetrics = false) - val finalExpiry = r.finalExpiry(nodeParams) - val recipient = ClearRecipient(invoice, r.recipientAmount, finalExpiry, Nil) - val payFsm = outgoingPaymentFactory.spawnOutgoingPayment(context, paymentCfg) - payFsm ! PaymentLifecycle.SendPaymentToRoute(self, Left(r.route), recipient) - context become main(pending + (paymentId -> PendingPaymentToRoute(sender(), r))) - case _: Bolt12Invoice => - sender() ! PaymentFailed(paymentId, r.paymentHash, LocalFailure(r.recipientAmount, Nil, new IllegalArgumentException("payments to Bolt12 invoices are not supported yet")) :: Nil) + sender() ! SendPaymentToRouteResponse(paymentId, parentPaymentId, None) + val paymentCfg = SendPaymentConfig(paymentId, parentPaymentId, r.externalId, r.paymentHash, r.recipientNodeId, Upstream.Local(paymentId), Some(r.invoice), storeInDb = true, publishEvent = true, recordPathFindingMetrics = false) + val finalExpiry = r.finalExpiry(nodeParams) + val recipient = r.invoice match { + case invoice: Bolt11Invoice => ClearRecipient(invoice, r.recipientAmount, finalExpiry, Nil) + case invoice: Bolt12Invoice => BlindedRecipient(invoice, r.recipientAmount, finalExpiry, Nil) } + val payFsm = outgoingPaymentFactory.spawnOutgoingPayment(context, paymentCfg) + payFsm ! PaymentLifecycle.SendPaymentToRoute(self, Left(r.route), recipient) + context become main(pending + (paymentId -> PendingPaymentToRoute(sender(), r))) case _ => sender() ! PaymentFailed(paymentId, r.paymentHash, LocalFailure(r.recipientAmount, Nil, TrampolineMultiNodeNotSupported) :: Nil) } @@ -195,18 +192,15 @@ class PaymentInitiator(nodeParams: NodeParams, outgoingPaymentFactory: PaymentIn } private def buildTrampolineRecipient(r: SendRequestedPayment, trampolineHop: NodeHop): Try[ClearTrampolineRecipient] = { + // We generate a random secret for the payment to the trampoline node. + val trampolineSecret = r match { + case r: SendPaymentToRoute => r.trampoline_opt.map(_.paymentSecret).getOrElse(randomBytes32()) + case _ => randomBytes32() + } + val finalExpiry = r.finalExpiry(nodeParams) r.invoice match { - case invoice: Bolt11Invoice => - // We generate a random secret for the payment to the trampoline node. - val trampolineSecret = r match { - case r: SendPaymentToRoute => r.trampoline_opt.map(_.paymentSecret).getOrElse(randomBytes32()) - case _ => randomBytes32() - } - val finalExpiry = r.finalExpiry(nodeParams) - val recipient = ClearTrampolineRecipient(invoice, r.recipientAmount, finalExpiry, trampolineHop, trampolineSecret) - Success(recipient) - case _: Bolt12Invoice => - Failure(new IllegalArgumentException("payments to Bolt12 invoices are not supported yet")) + case invoice: Bolt11Invoice => Success(ClearTrampolineRecipient(invoice, r.recipientAmount, finalExpiry, trampolineHop, trampolineSecret)) + case _: Bolt12Invoice => Failure(new IllegalArgumentException("trampoline blinded payments are not supported yet")) } } @@ -403,9 +397,13 @@ object PaymentInitiator { storeInDb: Boolean, // e.g. for trampoline we don't want to store in the DB when we're relaying payments publishEvent: Boolean, recordPathFindingMetrics: Boolean) { - def createPaymentSent(recipient: Recipient, preimage: ByteVector32, parts: Seq[PaymentSent.PartialPayment]) = PaymentSent(parentId, paymentHash, preimage, recipient.totalAmount, recipient.nodeId, parts) + val paymentContext: PaymentContext = PaymentContext(id, parentId, paymentHash) + val paymentType = invoice match { + case Some(_: Bolt12Invoice) => PaymentType.Blinded + case _ => PaymentType.Standard + } - def paymentContext: PaymentContext = PaymentContext(id, parentId, paymentHash) + def createPaymentSent(recipient: Recipient, preimage: ByteVector32, parts: Seq[PaymentSent.PartialPayment]) = PaymentSent(parentId, paymentHash, preimage, recipient.totalAmount, recipient.nodeId, parts) } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentLifecycle.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentLifecycle.scala index c5e0476e96..066e8a0ab5 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentLifecycle.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentLifecycle.scala @@ -60,7 +60,7 @@ class PaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, router: A route => self ! RouteResponse(route :: Nil) ) if (cfg.storeInDb) { - paymentsDb.addOutgoingPayment(OutgoingPayment(id, cfg.parentId, cfg.externalId, paymentHash, PaymentType.Standard, request.amount, request.recipient.totalAmount, request.recipient.nodeId, TimestampMilli.now(), cfg.invoice, OutgoingPaymentStatus.Pending)) + paymentsDb.addOutgoingPayment(OutgoingPayment(id, cfg.parentId, cfg.externalId, paymentHash, cfg.paymentType, request.amount, request.recipient.totalAmount, request.recipient.nodeId, TimestampMilli.now(), cfg.invoice, OutgoingPaymentStatus.Pending)) } goto(WAITING_FOR_ROUTE) using WaitingForRoute(request, Nil, Ignore.empty) @@ -68,7 +68,7 @@ class PaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, router: A log.debug("sending {} to {}", request.amount, request.recipient.nodeId) router ! RouteRequest(nodeParams.nodeId, request.recipient, request.routeParams, paymentContext = Some(cfg.paymentContext)) if (cfg.storeInDb) { - paymentsDb.addOutgoingPayment(OutgoingPayment(id, cfg.parentId, cfg.externalId, paymentHash, PaymentType.Standard, request.amount, request.recipient.totalAmount, request.recipient.nodeId, TimestampMilli.now(), cfg.invoice, OutgoingPaymentStatus.Pending)) + paymentsDb.addOutgoingPayment(OutgoingPayment(id, cfg.parentId, cfg.externalId, paymentHash, cfg.paymentType, request.amount, request.recipient.totalAmount, request.recipient.nodeId, TimestampMilli.now(), cfg.invoice, OutgoingPaymentStatus.Pending)) } goto(WAITING_FOR_ROUTE) using WaitingForRoute(request, Nil, Ignore.empty) } @@ -76,7 +76,7 @@ class PaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, router: A when(WAITING_FOR_ROUTE) { case Event(RouteResponse(route +: _), WaitingForRoute(request, failures, ignore)) => log.info(s"route found: attempt=${failures.size + 1}/${request.maxAttempts} route=${route.printNodes()} channels=${route.printChannels()}") - OutgoingPaymentPacket.buildOutgoingPayment(self, cfg.upstream, paymentHash, route, request.recipient) match { + OutgoingPaymentPacket.buildOutgoingPayment(self, nodeParams.privateKey, cfg.upstream, paymentHash, route, request.recipient) match { case Right(payment) => register ! Register.ForwardShortId(self.toTyped[Register.ForwardShortIdFailure[CMD_ADD_HTLC]], payment.outgoingChannel, payment.cmd) goto(WAITING_FOR_PAYMENT_COMPLETE) using WaitingForComplete(request, payment.cmd, failures, payment.sharedSecrets, ignore, route) @@ -252,6 +252,19 @@ class PaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, router: A goto(WAITING_FOR_ROUTE) using WaitingForRoute(request, failures :+ failure, ignore + nodeId) } } + case Success(e@Sphinx.DecryptedFailurePacket(nodeId, _: InvalidOnionBlinding)) => + // there was a failure inside the blinded route we used: we cannot know why it failed, so let's ignore it. + log.info(s"received an error coming from nodeId=$nodeId inside the blinded route, retrying with different blinded routes") + val failure = RemoteFailure(request.amount, route.fullRoute, e) + val ignore1 = PaymentFailure.updateIgnored(failure, ignore) + request match { + case _: SendPaymentToRoute => + log.error("unexpected retry during SendPaymentToRoute") + stop(FSM.Normal) + case request: SendPaymentToNode => + router ! RouteRequest(nodeParams.nodeId, recipient, request.routeParams, ignore1, paymentContext = Some(cfg.paymentContext)) + goto(WAITING_FOR_ROUTE) using WaitingForRoute(request, failures :+ failure, ignore1) + } case Success(e@Sphinx.DecryptedFailurePacket(nodeId, failureMessage)) => log.info(s"received an error message from nodeId=$nodeId, trying to use a different channel (failure=$failureMessage)") val failure = RemoteFailure(request.amount, route.fullRoute, e) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/Recipient.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/Recipient.scala index ed0690c400..3e599bd255 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/Recipient.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/Recipient.scala @@ -21,9 +21,9 @@ import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.payment.Invoice.ExtraEdge import fr.acinq.eclair.payment.OutgoingPaymentPacket._ -import fr.acinq.eclair.payment.{Bolt11Invoice, OutgoingPaymentPacket} -import fr.acinq.eclair.router.Router.{ChannelHop, NodeHop, Route} -import fr.acinq.eclair.wire.protocol.PaymentOnion.{FinalPayload, IntermediatePayload} +import fr.acinq.eclair.payment.{Bolt11Invoice, Bolt12Invoice, OutgoingPaymentPacket} +import fr.acinq.eclair.router.Router._ +import fr.acinq.eclair.wire.protocol.PaymentOnion.{FinalPayload, IntermediatePayload, OutgoingBlindedPerHopPayload} import fr.acinq.eclair.wire.protocol.{GenericTlv, OnionRoutingPacket, PaymentOnionCodecs} import fr.acinq.eclair.{CltvExpiry, Features, InvoiceFeature, MilliSatoshi, MilliSatoshiLong, ShortChannelId} import scodec.bits.ByteVector @@ -54,9 +54,9 @@ sealed trait Recipient { object Recipient { /** Iteratively build all the payloads for a payment relayed through channel hops. */ - def buildPayloads(finalAmount: MilliSatoshi, finalExpiry: CltvExpiry, finalPayload: NodePayload, hops: Seq[ChannelHop]): PaymentPayloads = { + def buildPayloads(finalAmount: MilliSatoshi, finalExpiry: CltvExpiry, finalPayloads: Seq[NodePayload], hops: Seq[ChannelHop]): PaymentPayloads = { // We ignore the first hop since the route starts at our node. - hops.tail.foldRight(PaymentPayloads(finalAmount, finalExpiry, Seq(finalPayload))) { + hops.tail.foldRight(PaymentPayloads(finalAmount, finalExpiry, finalPayloads)) { case (hop, current) => val payload = NodePayload(hop.nodeId, IntermediatePayload.ChannelRelay.Standard(hop.shortChannelId, current.amount, current.expiry)) PaymentPayloads(current.amount + hop.fee(current.amount), current.expiry + hop.cltvExpiryDelta, payload +: current.payloads) @@ -80,7 +80,7 @@ case class ClearRecipient(nodeId: PublicKey, case Some(trampolinePacket) => NodePayload(nodeId, FinalPayload.Standard.createTrampolinePayload(route.amount, totalAmount, expiry, paymentSecret, trampolinePacket)) case None => NodePayload(nodeId, FinalPayload.Standard.createPayload(route.amount, totalAmount, expiry, paymentSecret, paymentMetadata_opt, customTlvs)) } - Recipient.buildPayloads(route.amount, expiry, finalPayload, route.hops) + Recipient.buildPayloads(route.amount, expiry, Seq(finalPayload), route.hops) }) } } @@ -111,12 +111,81 @@ case class SpontaneousRecipient(nodeId: PublicKey, override def buildPayloads(paymentHash: ByteVector32, route: Route): Either[OutgoingPaymentError, PaymentPayloads] = { ClearRecipient.validateRoute(nodeId, route).map(_ => { val finalPayload = NodePayload(nodeId, FinalPayload.Standard.createKeySendPayload(route.amount, totalAmount, expiry, preimage, customTlvs)) - Recipient.buildPayloads(totalAmount, expiry, finalPayload, route.hops) + Recipient.buildPayloads(totalAmount, expiry, Seq(finalPayload), route.hops) }) } } -/** A payment recipient that can be reached through a given trampoline node (usually not found in the routing graph). */ +/** A payment recipient that hides its real identity using route blinding. */ +case class BlindedRecipient(nodeId: PublicKey, + features: Features[InvoiceFeature], + totalAmount: MilliSatoshi, + expiry: CltvExpiry, + blindedHops: Seq[BlindedHop], + customTlvs: Seq[GenericTlv] = Nil) extends Recipient { + require(blindedHops.nonEmpty, "blinded routes must be provided") + + override val extraEdges = blindedHops.map { h => + ExtraEdge(h.route.introductionNodeId, nodeId, h.dummyId, h.paymentInfo.feeBase, h.paymentInfo.feeProportionalMillionths, h.paymentInfo.cltvExpiryDelta, h.paymentInfo.minHtlc, Some(h.paymentInfo.maxHtlc)) + } + + private def validateRoute(route: Route): Either[OutgoingPaymentError, BlindedHop] = { + route.finalHop_opt match { + case Some(blindedHop: BlindedHop) => Right(blindedHop) + case _ => Left(MissingBlindedHop(blindedHops.map(_.route.introductionNodeId).toSet)) + } + } + + private def buildBlindedPayloads(amount: MilliSatoshi, blindedHop: BlindedHop): PaymentPayloads = { + val blinding = blindedHop.route.introductionNode.blindingEphemeralKey + val payloads = if (blindedHop.route.subsequentNodes.isEmpty) { + // The recipient is also the introduction node. + Seq(NodePayload(blindedHop.route.introductionNodeId, OutgoingBlindedPerHopPayload.createFinalIntroductionPayload(amount, totalAmount, expiry, blinding, blindedHop.route.introductionNode.encryptedPayload, customTlvs))) + } else { + val introductionPayload = NodePayload(blindedHop.route.introductionNodeId, OutgoingBlindedPerHopPayload.createIntroductionPayload(blindedHop.route.introductionNode.encryptedPayload, blinding)) + val intermediatePayloads = blindedHop.route.subsequentNodes.dropRight(1).map(n => NodePayload(n.blindedPublicKey, OutgoingBlindedPerHopPayload.createIntermediatePayload(n.encryptedPayload))) + val finalPayload = NodePayload(blindedHop.route.blindedNodes.last.blindedPublicKey, OutgoingBlindedPerHopPayload.createFinalPayload(amount, totalAmount, expiry, blindedHop.route.blindedNodes.last.encryptedPayload, customTlvs)) + introductionPayload +: intermediatePayloads :+ finalPayload + } + val introductionAmount = amount + blindedHop.paymentInfo.fee(amount) + val introductionExpiry = expiry + blindedHop.paymentInfo.cltvExpiryDelta + PaymentPayloads(introductionAmount, introductionExpiry, payloads) + } + + override def buildPayloads(paymentHash: ByteVector32, route: Route): Either[OutgoingPaymentError, PaymentPayloads] = { + validateRoute(route).map(blindedHop => { + val blindedPayloads = buildBlindedPayloads(route.amount, blindedHop) + if (route.hops.isEmpty) { + // We are the introduction node of the blinded route. + blindedPayloads + } else { + Recipient.buildPayloads(blindedPayloads.amount, blindedPayloads.expiry, blindedPayloads.payloads, route.hops) + } + }) + } +} + +object BlindedRecipient { + def apply(invoice: Bolt12Invoice, totalAmount: MilliSatoshi, expiry: CltvExpiry, customTlvs: Seq[GenericTlv]): BlindedRecipient = { + val blindedHops = invoice.blindedPaths.zip(invoice.blindedPathsInfo).map { + case (route, info) => + // We don't know the scids of channels inside the blinded route, but it's useful to have an ID to refer to a + // given edge in the graph, so we create a dummy one for the duration of the payment attempt. + val dummyId = ShortChannelId.generateLocalAlias() + BlindedHop(dummyId, route, info) + } + BlindedRecipient(invoice.nodeId, invoice.features, totalAmount, expiry, blindedHops, customTlvs) + } +} + +/** + * A payment recipient that can be reached through a trampoline node (such recipients usually cannot be found in the + * public graph). Splitting a payment across multiple trampoline nodes is not supported yet, but can easily be added + * with a new field containing a bigger recipient total amount. + * + * Note that we don't need to support the case where we'd use multiple trampoline hops in the same route: since we have + * access to the network graph, it's always more efficient to find a channel route to the last trampoline node. + */ case class ClearTrampolineRecipient(invoice: Bolt11Invoice, totalAmount: MilliSatoshi, expiry: CltvExpiry, @@ -137,7 +206,7 @@ case class ClearTrampolineRecipient(invoice: Bolt11Invoice, private def validateRoute(route: Route): Either[OutgoingPaymentError, NodeHop] = { route.finalHop_opt match { case Some(trampolineHop: NodeHop) => Right(trampolineHop) - case None => Left(MissingTrampolineHop(trampolineNodeId)) + case _ => Left(MissingTrampolineHop(trampolineNodeId)) } } @@ -147,7 +216,7 @@ case class ClearTrampolineRecipient(invoice: Bolt11Invoice, trampolineOnion <- createTrampolinePacket(paymentHash, trampolineHop) } yield { val trampolinePayload = NodePayload(trampolineHop.nodeId, FinalPayload.Standard.createTrampolinePayload(route.amount, trampolineAmount, trampolineExpiry, trampolinePaymentSecret, trampolineOnion.packet)) - Recipient.buildPayloads(route.amount, trampolineExpiry, trampolinePayload, route.hops) + Recipient.buildPayloads(route.amount, trampolineExpiry, Seq(trampolinePayload), route.hops) } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/RouteCalculation.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/RouteCalculation.scala index b8de39d65a..1f91d50f34 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/router/RouteCalculation.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/RouteCalculation.scala @@ -22,7 +22,7 @@ import com.softwaremill.quicklens.ModifyPimp import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.eclair.Logs.LogCategory import fr.acinq.eclair._ -import fr.acinq.eclair.payment.send.{ClearRecipient, ClearTrampolineRecipient, Recipient, SpontaneousRecipient} +import fr.acinq.eclair.payment.send._ import fr.acinq.eclair.router.Graph.GraphStructure.DirectedGraph.graphEdgeToHop import fr.acinq.eclair.router.Graph.GraphStructure.{DirectedGraph, GraphEdge} import fr.acinq.eclair.router.Graph.{InfiniteLoop, NegativeProbability, RichWeight} @@ -134,6 +134,21 @@ object RouteCalculation { val amountToSend = recipient.totalAmount - pendingAmount val maxFee = totalMaxFee - pendingChannelFee (targetNodeId, amountToSend, maxFee, Set.empty) + case recipient: BlindedRecipient => + // Blinded routes all end at a different (blinded) node, so we create graph edges in which they lead to the same node. + val targetNodeId = randomKey().publicKey + val extraEdges = recipient.extraEdges + .map(_.copy(targetNodeId = targetNodeId)) + .filterNot(e => ignoredEdges.exists(_.shortChannelId == e.shortChannelId)) + // For blinded routes, the maximum htlc field is used to indicate the maximum amount that can be sent through the route. + .map(e => GraphEdge(e).copy(balance_opt = e.htlcMaximum_opt)) + .toSet + val amountToSend = recipient.totalAmount - pendingAmount + // When we are the introduction node and includeLocalChannelCost is false, we cannot easily remove the fee for + // the first hop in the blinded route (we would need to decrypt the route and fetch the corresponding channel). + // In that case, we will slightly over-estimate the fee we're paying, but at least we won't exceed our fee budget. + val maxFee = totalMaxFee - pendingChannelFee - r.pendingPayments.map(_.blindedFee).sum + (targetNodeId, amountToSend, maxFee, extraEdges) case recipient: ClearTrampolineRecipient => // Trampoline payments require finding routes to the trampoline node, not the final recipient. // This also ensures that we correctly take the trampoline fee into account only once, even when using MPP to @@ -146,11 +161,17 @@ object RouteCalculation { } private def addFinalHop(recipient: Recipient, routes: Seq[Route]): Seq[Route] = { - routes.map(route => { + routes.flatMap(route => { recipient match { - case _: ClearRecipient => route - case _: SpontaneousRecipient => route - case recipient: ClearTrampolineRecipient => route.copy(finalHop_opt = Some(recipient.trampolineHop)) + case _: ClearRecipient => Some(route) + case _: SpontaneousRecipient => Some(route) + case recipient: ClearTrampolineRecipient => Some(route.copy(finalHop_opt = Some(recipient.trampolineHop))) + case recipient: BlindedRecipient => + route.hops.lastOption.flatMap { + hop => recipient.blindedHops.find(_.dummyId == hop.shortChannelId) + }.map { + blindedHop => Route(route.amount, route.hops.dropRight(1), Some(blindedHop)) + } } }) } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala index 56d3970ab7..32c43707f3 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala @@ -28,6 +28,7 @@ import fr.acinq.eclair._ import fr.acinq.eclair.blockchain.bitcoind.ZmqWatcher import fr.acinq.eclair.blockchain.bitcoind.ZmqWatcher.{ValidateResult, WatchExternalChannelSpent, WatchExternalChannelSpentTriggered} import fr.acinq.eclair.channel._ +import fr.acinq.eclair.crypto.Sphinx.RouteBlinding.BlindedRoute import fr.acinq.eclair.crypto.TransportHandler import fr.acinq.eclair.db.NetworkDb import fr.acinq.eclair.io.Peer.PeerRoutingMessage @@ -497,6 +498,24 @@ object Router { sealed trait FinalHop extends Hop + /** + * A directed hop over a blinded route composed of multiple (blinded) channels. + * Since a blinded route has to be used from start to end, we model it as a single virtual hop. + * + * @param dummyId dummy identifier to allow indexing in maps: unlike normal scid aliases, this one doesn't exist + * in our routing tables and should be used carefully. + * @param route blinded route covered by that hop. + * @param paymentInfo payment information about the blinded route. + */ + case class BlindedHop(dummyId: Alias, route: BlindedRoute, paymentInfo: OfferTypes.PaymentInfo) extends FinalHop { + // @formatter:off + override val nodeId = route.introductionNodeId + override val nextNodeId = route.blindedNodes.last.blindedPublicKey + override val cltvExpiryDelta = paymentInfo.cltvExpiryDelta + override def fee(amount: MilliSatoshi): MilliSatoshi = paymentInfo.fee(amount) + // @formatter:on + } + /** * A directed hop between two trampoline nodes. * These nodes need not be connected and we don't need to know a route between them. @@ -568,7 +587,19 @@ object Router { */ val trampolineFee: MilliSatoshi = finalHop_opt.collect { case hop: NodeHop => hop.fee(amount) }.getOrElse(0 msat) + /** + * Fee paid for the blinded route, if any. + * Note that when we are the introduction node for the blinded route, we cannot easily compute the fee without the + * cost for the first local channel. + */ + val blindedFee: MilliSatoshi = finalHop_opt.collect { case hop: BlindedHop => hop.fee(amount) }.getOrElse(0 msat) + /** Fee paid for the channel hops towards the recipient or the source of the final hop, if any. */ + + /** + * Fee paid for the channel hops towards the recipient or the source of the final hop. + * Note that this doesn't include the fees for the final hop, if one exits. + */ def channelFee(includeLocalChannelCost: Boolean): MilliSatoshi = { val hopsToPay = if (includeLocalChannelCost) hops else hops.drop(1) val amountToSend = hopsToPay.foldRight(amount) { case (hop, amount1) => amount1 + hop.fee(amount1) } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/OfferCodecs.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/OfferCodecs.scala index 91314bd229..ee2b1b6207 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/OfferCodecs.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/OfferCodecs.scala @@ -39,9 +39,8 @@ object OfferCodecs { private val absoluteExpiry: Codec[AbsoluteExpiry] = tlvField(tu64overflow.as[TimestampSecond].as[AbsoluteExpiry]) private val blindedNodeCodec: Codec[BlindedNode] = (("nodeId" | publicKey) :: ("encryptedData" | variableSizeBytes(uint16, bytes))).as[BlindedNode] - - private val pathCodec: Codec[BlindedRoute] = (("firstNodeId" | publicKey) :: ("blinding" | publicKey) :: ("path" | listOfN(uint8, blindedNodeCodec).xmap[Seq[BlindedNode]](_.toSeq, _.toList))).as[BlindedRoute] - + private val blindedNodesCodec: Codec[Seq[BlindedNode]] = listOfN(uint8, blindedNodeCodec).xmap(_.toSeq, _.toList) + private val pathCodec: Codec[BlindedRoute] = (("firstNodeId" | publicKey) :: ("blinding" | publicKey) :: ("path" | blindedNodesCodec)).as[BlindedRoute] private val paths: Codec[Paths] = tlvField(list(pathCodec).xmap[Seq[BlindedRoute]](_.toSeq, _.toList).as[Paths]) private val issuer: Codec[Issuer] = tlvField(utf8.as[Issuer]) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/PaymentOnion.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/PaymentOnion.scala index eca19fb94a..ab26f44ed8 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/PaymentOnion.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/PaymentOnion.scala @@ -191,13 +191,13 @@ object PaymentOnion { import OnionPaymentPayloadTlv._ /* - * PerHopPayload - * | - * | - * +--------------+---------------+ - * | | - * | | - * IntermediatePayload FinalPayload + * PerHopPayload + * | + * | + * +------------------------------+-----------------------------+ + * | | | + * | | | + * IntermediatePayload FinalPayload OutgoingBlindedPerHopPayload * | | * | | * +---------+---------+ +------+------+ @@ -437,6 +437,33 @@ object PaymentOnion { } } + /** + * An opaque blinded payload (used when sending to a blinded route, never used to decode incoming payloads). + * We cannot use the other payload types because we cannot decrypt the recipient encrypted data, so we don't even + * know if those payloads are valid. + */ + case class OutgoingBlindedPerHopPayload(records: TlvStream[OnionPaymentPayloadTlv]) extends PerHopPayload { + require(records.get[EncryptedRecipientData].nonEmpty, "blinded per-hop payload must contain encrypted data") + } + + object OutgoingBlindedPerHopPayload { + def createIntroductionPayload(encryptedRecipientData: ByteVector, blinding: PublicKey): OutgoingBlindedPerHopPayload = { + OutgoingBlindedPerHopPayload(TlvStream(Seq(EncryptedRecipientData(encryptedRecipientData), BlindingPoint(blinding)))) + } + + def createIntermediatePayload(encryptedRecipientData: ByteVector): OutgoingBlindedPerHopPayload = { + OutgoingBlindedPerHopPayload(TlvStream(Seq(EncryptedRecipientData(encryptedRecipientData)))) + } + + def createFinalPayload(amount: MilliSatoshi, totalAmount: MilliSatoshi, expiry: CltvExpiry, encryptedRecipientData: ByteVector, customTlvs: Seq[GenericTlv] = Nil): OutgoingBlindedPerHopPayload = { + OutgoingBlindedPerHopPayload(TlvStream(Seq(AmountToForward(amount), TotalAmount(totalAmount), OutgoingCltv(expiry), EncryptedRecipientData(encryptedRecipientData)), customTlvs)) + } + + def createFinalIntroductionPayload(amount: MilliSatoshi, totalAmount: MilliSatoshi, expiry: CltvExpiry, blinding: PublicKey, encryptedRecipientData: ByteVector, customTlvs: Seq[GenericTlv] = Nil): OutgoingBlindedPerHopPayload = { + OutgoingBlindedPerHopPayload(TlvStream(Seq(AmountToForward(amount), TotalAmount(totalAmount), OutgoingCltv(expiry), EncryptedRecipientData(encryptedRecipientData), BlindingPoint(blinding)), customTlvs)) + } + } + } object PaymentOnionCodecs { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/RouteBlinding.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/RouteBlinding.scala index cbe392ef63..7fa8c9ed0e 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/RouteBlinding.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/RouteBlinding.scala @@ -134,7 +134,7 @@ object RouteBlindingEncryptedDataCodecs { // @formatter:off case class RouteBlindingDecryptedData(tlvs: TlvStream[RouteBlindingEncryptedDataTlv], nextBlinding: PublicKey) - sealed trait InvalidEncryptedData + sealed trait InvalidEncryptedData { def message: String } case class CannotDecryptData(message: String) extends InvalidEncryptedData case class CannotDecodeData(message: String) extends InvalidEncryptedData // @formatter:on diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/channel/FuzzySpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/channel/FuzzySpec.scala index d06c59de45..c805309591 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/channel/FuzzySpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/channel/FuzzySpec.scala @@ -121,7 +121,7 @@ class FuzzySpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with Channe // allow overpaying (no more than 2 times the required amount) val amount = requiredAmount + Random.nextInt(requiredAmount.toLong.toInt).msat val expiry = (Channel.MIN_CLTV_EXPIRY_DELTA + 1).toCltvExpiry(currentBlockHeight = BlockHeight(400000)) - val Right(payment) = OutgoingPaymentPacket.buildOutgoingPayment(self, Upstream.Local(UUID.randomUUID()), invoice.paymentHash, makeSingleHopRoute(amount, invoice.nodeId), ClearRecipient(invoice, amount, expiry, Nil)) + val Right(payment) = OutgoingPaymentPacket.buildOutgoingPayment(self, randomKey(), Upstream.Local(UUID.randomUUID()), invoice.paymentHash, makeSingleHopRoute(amount, invoice.nodeId), ClearRecipient(invoice, amount, expiry, Nil)) payment.cmd } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/ChannelStateTestsHelperMethods.scala b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/ChannelStateTestsHelperMethods.scala index 10bcc52e1d..8266c0d2f4 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/ChannelStateTestsHelperMethods.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/ChannelStateTestsHelperMethods.scala @@ -359,7 +359,7 @@ trait ChannelStateTestsBase extends Assertions with Eventually { val paymentHash = Crypto.sha256(paymentPreimage) val expiry = cltvExpiryDelta.toCltvExpiry(currentBlockHeight) val recipient = SpontaneousRecipient(destination, amount, expiry, paymentPreimage) - val Right(payment) = OutgoingPaymentPacket.buildOutgoingPayment(replyTo, upstream, paymentHash, makeSingleHopRoute(amount, destination), recipient) + val Right(payment) = OutgoingPaymentPacket.buildOutgoingPayment(replyTo, randomKey(), upstream, paymentHash, makeSingleHopRoute(amount, destination), recipient) (paymentPreimage, payment.cmd.copy(commit = false)) } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/f/ShutdownStateSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/f/ShutdownStateSpec.scala index b473943559..b9f06744b5 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/f/ShutdownStateSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/f/ShutdownStateSpec.scala @@ -59,7 +59,7 @@ class ShutdownStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike wit // alice sends an HTLC to bob val h1 = Crypto.sha256(r1) val recipient1 = SpontaneousRecipient(TestConstants.Bob.nodeParams.nodeId, 300_000_000 msat, CltvExpiryDelta(144).toCltvExpiry(currentBlockHeight), r1) - val Right(cmd1) = OutgoingPaymentPacket.buildOutgoingPayment(sender.ref, Upstream.Local(UUID.randomUUID), h1, makeSingleHopRoute(recipient1.totalAmount, recipient1.nodeId), recipient1).map(_.cmd.copy(commit = false)) + val Right(cmd1) = OutgoingPaymentPacket.buildOutgoingPayment(sender.ref, TestConstants.Alice.nodeParams.privateKey, Upstream.Local(UUID.randomUUID), h1, makeSingleHopRoute(recipient1.totalAmount, recipient1.nodeId), recipient1).map(_.cmd.copy(commit = false)) alice ! cmd1 sender.expectMsgType[RES_SUCCESS[CMD_ADD_HTLC]] val htlc1 = alice2bob.expectMsgType[UpdateAddHtlc] @@ -68,7 +68,7 @@ class ShutdownStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike wit // alice sends another HTLC to bob val h2 = Crypto.sha256(r2) val recipient2 = SpontaneousRecipient(TestConstants.Bob.nodeParams.nodeId, 200_000_000 msat, CltvExpiryDelta(144).toCltvExpiry(currentBlockHeight), r2) - val Right(cmd2) = OutgoingPaymentPacket.buildOutgoingPayment(sender.ref, Upstream.Local(UUID.randomUUID), h2, makeSingleHopRoute(recipient2.totalAmount, recipient2.nodeId), recipient2).map(_.cmd.copy(commit = false)) + val Right(cmd2) = OutgoingPaymentPacket.buildOutgoingPayment(sender.ref, TestConstants.Alice.nodeParams.privateKey, Upstream.Local(UUID.randomUUID), h2, makeSingleHopRoute(recipient2.totalAmount, recipient2.nodeId), recipient2).map(_.cmd.copy(commit = false)) alice ! cmd2 sender.expectMsgType[RES_SUCCESS[CMD_ADD_HTLC]] val htlc2 = alice2bob.expectMsgType[UpdateAddHtlc] diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/json/JsonSerializersSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/json/JsonSerializersSpec.scala index ed6eb2c56c..39d2ef757d 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/json/JsonSerializersSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/json/JsonSerializersSpec.scala @@ -191,9 +191,9 @@ class JsonSerializersSpec extends AnyFunSuite with Matchers { } test("Bolt 12 invoice") { - val ref = "lni1qvsyxjtl6luzd9t3pr62xr7eemp6awnejusgf6gw45q75vcfqqqqqqqyyz9ut9uduhtztjgpxm06394g5qkw7v79g4czw6zxsl3lnrsvljj0qzqrq83yqzscd9h8vmmfvdjjqamfw35zqmtpdeujqenfv4kxgucvqsqsqqgqzzsq83r88cmqur4u9xxgmykeuhmf4stsrsl8h9keu8q9pf5wsq69zaauq2vu8z4tg6zltltxekud2a59r9l8j7uy4ydnfe752tqz7a8g03fcvqgzw2q6uut0ldazm8squqsmq3gls2ut7gzcpfty8zh0d5vmf30dlyssqwh4qdsxe8n3d9d38e8g85r966yjqr7xuljey2yxar05x3mx46qz8qhef7n5hg0fxx7phe7jtag5evuj28rsas80gglufwf3y8qqqqyjjqqqqt7sz3qqqqqqqqqqq05qqqqqqqqqrcjqqqqpgptpd35kxeg7yypugee7xc8qa0pf3jxe9k0976dvzuqu8eaedk0pcpg2dr5qx3gh00pqqyujvgy04twhrv0h3vtzvhjmqcde62ug3ygp9hr66wrzdm425238zc26v5nswjf8d5syymmz9qzx9gqxhc4zq5v4r4x98jgyqd0sk2fae803crnevusngv9wq7jl8cf5e5eny56p9spquypwqgq8kvq5quzqqpqj84t0szcxqqywkwkudz29aaspxgx2n6mw2fh2ckwdnwylkgpcypc66qe7tapqdq39vzrhp7nkwfg9gj2ztk658g0ecgalqdjh4gmuruzqqsv99asjssw802xu2ls93fzzl64c5jehhxvp0m7au9klpg4pnedtwl0zgw648gwtalak3hr5uxxl5qzp7txlud9y7qfk038cq4pusj8aqymsxqgzq07szwgq" + val ref = "lni1qvsxlc5vp2m0rvmjcxn2y34wv0m5lyc7sdj7zksgn35dvxgqqqqqqqqyypmpsc7ww3cxguwl27ela95ykset7t8tlvyfy7a200eujcnhczws6zqyrvhqd5s2p4kkjmnfd4skcgr0venx2ussmvpexwyy4tcadvgg89l9aljus6709kx235hhqrk6n8dey98uyuftzdqrshx2mcmnvj7gxa709vhgcrqr7hcdp7l7x9t2au8dj8tjreqyfvuqyqkp4czrrxpn3hdrdqu8k3teynrl4nf977deq2ja53zkefmsr0nyjgqp3dtytq67durd2jjupmnjdtvvlsuw7lsm6tvcyrrqx7pwaunazjuqn2uk9gvpzj0aeagku2h2wv8vcfmfekflxmfmgu0kqqa94fuhuhyn6jxefk7k5rfgejltw7cwa4fhjl67trweukvsw4wqq7ec790vdjar7qwtsnlfrq84fmq02ah49fvnnsj06j5wzgwqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqm9crdyqqqrcss83y2e9lqnu7tht4ntvp24fksw26hwf5yrg6dyk2jz472efs2rjh4ycsra98j4l2k35fg7qhvapz26js2rh0j5n36pzlt9kaprvl3zd9s29egq33khv3m9gszev88kpfrveu8g5xr8khk6tev8jmpxg3pxfhpcx6f4jtlm4ltwgpwqgqpduzqyec5qspkg0zqq788krw2w2kstvsz3dekms304ykkh395zl6chm34vdu03yuvgwzm0580zu6sp2f07uwa4crgvkgucd8zdpt5vu302nc" val pr = Invoice.fromString(ref).get - JsonSerializers.serialization.write(pr)(JsonSerializers.formats) shouldBe """{"amount":123456,"nodeId":"03c4673e360e0ebc298c8d92d9e5f69ac1701c3e7b96d9e1c050a68e80345177bc","paymentHash":"51951d4c53c904035f0b293dc9df1c0e7967213430ae07a5f3e134cd33325341","description":"invoice with many fields","features":{"activated":{"var_onion_optin":"mandatory","option_route_blinding":"mandatory"},"unknown":[]},"blindedPaths":[{"introductionNodeId":"03c4673e360e0ebc298c8d92d9e5f69ac1701c3e7b96d9e1c050a68e80345177bc","blindedNodeIds":["027281ae716ffb7a2d9e00e021b0451f82b8bf20580a56438aef6d19b4c5edf921"]}],"createdAt":1654654654,"expiresAt":1654658254,"serialized":"lni1qvsyxjtl6luzd9t3pr62xr7eemp6awnejusgf6gw45q75vcfqqqqqqqyyz9ut9uduhtztjgpxm06394g5qkw7v79g4czw6zxsl3lnrsvljj0qzqrq83yqzscd9h8vmmfvdjjqamfw35zqmtpdeujqenfv4kxgucvqsqsqqgqzzsq83r88cmqur4u9xxgmykeuhmf4stsrsl8h9keu8q9pf5wsq69zaauq2vu8z4tg6zltltxekud2a59r9l8j7uy4ydnfe752tqz7a8g03fcvqgzw2q6uut0ldazm8squqsmq3gls2ut7gzcpfty8zh0d5vmf30dlyssqwh4qdsxe8n3d9d38e8g85r966yjqr7xuljey2yxar05x3mx46qz8qhef7n5hg0fxx7phe7jtag5evuj28rsas80gglufwf3y8qqqqyjjqqqqt7sz3qqqqqqqqqqq05qqqqqqqqqrcjqqqqpgptpd35kxeg7yypugee7xc8qa0pf3jxe9k0976dvzuqu8eaedk0pcpg2dr5qx3gh00pqqyujvgy04twhrv0h3vtzvhjmqcde62ug3ygp9hr66wrzdm425238zc26v5nswjf8d5syymmz9qzx9gqxhc4zq5v4r4x98jgyqd0sk2fae803crnevusngv9wq7jl8cf5e5eny56p9spquypwqgq8kvq5quzqqpqj84t0szcxqqywkwkudz29aaspxgx2n6mw2fh2ckwdnwylkgpcypc66qe7tapqdq39vzrhp7nkwfg9gj2ztk658g0ecgalqdjh4gmuruzqqsv99asjssw802xu2ls93fzzl64c5jehhxvp0m7au9klpg4pnedtwl0zgw648gwtalak3hr5uxxl5qzp7txlud9y7qfk038cq4pusj8aqymsxqgzq07szwgq"}""" + JsonSerializers.serialization.write(pr)(JsonSerializers.formats) shouldBe """{"amount":456001234,"nodeId":"03c48ac97e09f3cbbaeb35b02aaa6d072b57726841a34d25952157caca60a1caf5","paymentHash":"2cb0e7b052366787450c33daf6d2f2c3cb6132221326e1c1b49ac97fdd7eb720","description":"minimal offer","features":{"activated":{},"unknown":[]},"blindedPaths":[{"introductionNodeId":"03933884aaf1d6b108397e5efe5c86bcf2d8ca8d2f700eda99db9214fc2712b134","blindedNodeIds":["02c1ae043198338dda368387b457924c7facd25f79b902a5da4456ca7701be6492","03782eef27d14b809ab962a181149fdcf516e2aea730ecc2769cd93f36d3b471f6"]}],"createdAt":1668002363,"expiresAt":1668009563,"serialized":"lni1qvsxlc5vp2m0rvmjcxn2y34wv0m5lyc7sdj7zksgn35dvxgqqqqqqqqyypmpsc7ww3cxguwl27ela95ykset7t8tlvyfy7a200eujcnhczws6zqyrvhqd5s2p4kkjmnfd4skcgr0venx2ussmvpexwyy4tcadvgg89l9aljus6709kx235hhqrk6n8dey98uyuftzdqrshx2mcmnvj7gxa709vhgcrqr7hcdp7l7x9t2au8dj8tjreqyfvuqyqkp4czrrxpn3hdrdqu8k3teynrl4nf977deq2ja53zkefmsr0nyjgqp3dtytq67durd2jjupmnjdtvvlsuw7lsm6tvcyrrqx7pwaunazjuqn2uk9gvpzj0aeagku2h2wv8vcfmfekflxmfmgu0kqqa94fuhuhyn6jxefk7k5rfgejltw7cwa4fhjl67trweukvsw4wqq7ec790vdjar7qwtsnlfrq84fmq02ah49fvnnsj06j5wzgwqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqm9crdyqqqrcss83y2e9lqnu7tht4ntvp24fksw26hwf5yrg6dyk2jz472efs2rjh4ycsra98j4l2k35fg7qhvapz26js2rh0j5n36pzlt9kaprvl3zd9s29egq33khv3m9gszev88kpfrveu8g5xr8khk6tev8jmpxg3pxfhpcx6f4jtlm4ltwgpwqgqpduzqyec5qspkg0zqq788krw2w2kstvsz3dekms304ykkh395zl6chm34vdu03yuvgwzm0580zu6sp2f07uwa4crgvkgucd8zdpt5vu302nc"}""" } test("GlobalBalance serializer") { diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/message/OnionMessagesSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/message/OnionMessagesSpec.scala index d19ef2b605..030349f3ee 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/message/OnionMessagesSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/message/OnionMessagesSpec.scala @@ -225,13 +225,9 @@ class OnionMessagesSpec extends AnyFunSuite { val carol = randomKey() val sessionKey = randomKey() val blindingSecret = randomKey() - val pathId = randomBytes(65201) - val Success((_, messageForAlice)) = buildMessage(sessionKey, blindingSecret, IntermediateNode(alice.publicKey) :: IntermediateNode(bob.publicKey) :: Nil, Recipient(carol.publicKey, Some(pathId)), Nil) - println(messageForAlice.onionRoutingPacket.payload.length) - // Checking that the onion is relayed properly process(alice, messageForAlice) match { case SendMessage(nextNodeId, onionForBob) => diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/Bolt12InvoiceSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/Bolt12InvoiceSpec.scala index b7c811d4b7..3b12de42e8 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/Bolt12InvoiceSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/Bolt12InvoiceSpec.scala @@ -22,13 +22,12 @@ import fr.acinq.bitcoin.scalacompat.{Block, ByteVector32, Crypto} import fr.acinq.eclair.FeatureSupport.{Mandatory, Optional} import fr.acinq.eclair.Features.{BasicMultiPartPayment, VariableLengthOnion} import fr.acinq.eclair.crypto.Sphinx -import fr.acinq.eclair.crypto.Sphinx.RouteBlinding.BlindedRoute import fr.acinq.eclair.payment.Bolt12Invoice.{hrp, signatureTag} import fr.acinq.eclair.wire.protocol.OfferCodecs.{invoiceRequestTlvCodec, invoiceTlvCodec} import fr.acinq.eclair.wire.protocol.OfferTypes._ import fr.acinq.eclair.wire.protocol.RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec -import fr.acinq.eclair.wire.protocol.RouteBlindingEncryptedDataTlv.{AllowedFeatures, PaymentConstraints} -import fr.acinq.eclair.wire.protocol.{GenericTlv, OfferTypes, RouteBlindingEncryptedDataTlv, TlvStream} +import fr.acinq.eclair.wire.protocol.RouteBlindingEncryptedDataTlv.{AllowedFeatures, PathId, PaymentConstraints} +import fr.acinq.eclair.wire.protocol.{GenericTlv, OfferTypes, TlvStream} import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Feature, FeatureSupport, Features, MilliSatoshiLong, TimestampSecond, TimestampSecondLong, UInt64, randomBytes32, randomBytes64, randomKey} import org.scalatest.funsuite.AnyFunSuite import scodec.bits._ @@ -51,7 +50,7 @@ class Bolt12InvoiceSpec extends AnyFunSuite { } def createPaymentBlindedRoute(nodeId: PublicKey, sessionKey: PrivateKey = randomKey(), pathId: ByteVector = randomBytes32()): PaymentBlindedRoute = { - val selfPayload = blindedRouteDataCodec.encode(TlvStream(Seq(RouteBlindingEncryptedDataTlv.PathId(pathId), PaymentConstraints(CltvExpiry(1234567), 0 msat), AllowedFeatures(Features.empty)))).require.bytes + val selfPayload = blindedRouteDataCodec.encode(TlvStream(Seq(PathId(pathId), PaymentConstraints(CltvExpiry(1234567), 0 msat), AllowedFeatures(Features.empty)))).require.bytes PaymentBlindedRoute(Sphinx.RouteBlinding.create(sessionKey, Seq(nodeId), Seq(selfPayload)).route, PaymentInfo(1 msat, 2, CltvExpiryDelta(3), 4 msat, 5 msat, Features.empty)) } @@ -274,18 +273,18 @@ class Bolt12InvoiceSpec extends AnyFunSuite { PaymentHash(randomBytes32()), ) // This minimal invoice is valid. - val signed = signInvoiceTlvs(TlvStream[InvoiceTlv](tlvs), nodeKey) + val signed = signInvoiceTlvs(TlvStream(tlvs), nodeKey) val signedEncoded = Bech32.encodeBytes(hrp, invoiceTlvCodec.encode(signed).require.bytes.toArray, Bech32.Encoding.Beck32WithoutChecksum) assert(Bolt12Invoice.fromString(signedEncoded).isSuccess) // But removing any TLV makes it invalid. for (tlv <- tlvs) { val incomplete = tlvs.filterNot(_ == tlv) - val incompleteSigned = signInvoiceTlvs(TlvStream[InvoiceTlv](incomplete), nodeKey) + val incompleteSigned = signInvoiceTlvs(TlvStream(incomplete), nodeKey) val incompleteSignedEncoded = Bech32.encodeBytes(hrp, invoiceTlvCodec.encode(incompleteSigned).require.bytes.toArray, Bech32.Encoding.Beck32WithoutChecksum) assert(Bolt12Invoice.fromString(incompleteSignedEncoded).isFailure) } // Missing signature is also invalid. - val unsignedEncoded = Bech32.encodeBytes(hrp, invoiceTlvCodec.encode(TlvStream[InvoiceTlv](tlvs)).require.bytes.toArray, Bech32.Encoding.Beck32WithoutChecksum) + val unsignedEncoded = Bech32.encodeBytes(hrp, invoiceTlvCodec.encode(TlvStream(tlvs)).require.bytes.toArray, Bech32.Encoding.Beck32WithoutChecksum) assert(Bolt12Invoice.fromString(unsignedEncoded).isFailure) } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartPaymentLifecycleSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartPaymentLifecycleSpec.scala index 1ea2063575..ae9707958f 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartPaymentLifecycleSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartPaymentLifecycleSpec.scala @@ -32,7 +32,7 @@ import fr.acinq.eclair.payment.send.PaymentError.RetryExhausted import fr.acinq.eclair.payment.send.PaymentInitiator.SendPaymentConfig import fr.acinq.eclair.payment.send.PaymentLifecycle.SendPaymentToRoute import fr.acinq.eclair.payment.send._ -import fr.acinq.eclair.router.BaseRouterSpec.channelHopFromUpdate +import fr.acinq.eclair.router.BaseRouterSpec.{blindedRouteFromHops, channelHopFromUpdate} import fr.acinq.eclair.router.Graph.WeightRatios import fr.acinq.eclair.router.Router._ import fr.acinq.eclair.router.{Announcements, RouteNotFound} @@ -178,6 +178,33 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS metricsListener.expectNoMessage() } + test("successful first attempt (blinded)") { f => + import f._ + + assert(payFsm.stateName == WAIT_FOR_PAYMENT_REQUEST) + val (_, hop_be, recipient) = blindedRouteFromHops(finalAmount, expiry, Seq(channelHopFromUpdate(b, e, channelUpdate_be)), blindedRouteExpiry, paymentPreimage) + val payment = SendMultiPartPayment(sender.ref, recipient, 1, routeParams) + sender.send(payFsm, payment) + + router.expectMsg(RouteRequest(nodeParams.nodeId, recipient, routeParams.copy(randomize = false), allowMultiPart = true, paymentContext = Some(cfg.paymentContext))) + assert(payFsm.stateName == WAIT_FOR_ROUTES) + + val routes = Seq( + Route(600_000 msat, Seq(hop_ab_1), Some(hop_be)), + Route(400_000 msat, Seq(hop_ab_2), Some(hop_be)), + ) + router.send(payFsm, RouteResponse(routes)) + val childPayments = childPayFsm.expectMsgType[SendPaymentToRoute] :: childPayFsm.expectMsgType[SendPaymentToRoute] :: Nil + assert(childPayments.map(_.route).toSet == routes.map(r => Right(r)).toSet) + childPayments.foreach(childPayment => assert(childPayment.recipient == recipient)) + assert(childPayments.map(_.amount).toSet == Set(400_000 msat, 600_000 msat)) + assert(payFsm.stateName == PAYMENT_IN_PROGRESS) + + val result = fulfillPendingPayments(f, 2, recipient.nodeId, finalAmount) + assert(result.amountWithFees == 1_000_200.msat) + assert(result.nonTrampolineFees == 200.msat) + } + test("successful retry") { f => import f._ @@ -190,7 +217,7 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS childPayFsm.expectNoMessage(100 millis) val childId = payFsm.stateData.asInstanceOf[PaymentProgress].pending.keys.head - childPayFsm.send(payFsm, PaymentFailed(childId, paymentHash, Seq(RemoteFailure(failingRoute.amount, failingRoute.hops, Sphinx.DecryptedFailurePacket(b, PermanentChannelFailure))))) + childPayFsm.send(payFsm, PaymentFailed(childId, paymentHash, Seq(RemoteFailure(failingRoute.amount, failingRoute.fullRoute, Sphinx.DecryptedFailurePacket(b, PermanentChannelFailure))))) // We retry ignoring the failing channel. router.expectMsg(RouteRequest(nodeParams.nodeId, clearRecipient, routeParams.copy(randomize = true), allowMultiPart = true, ignore = Ignore(Set.empty, Set(ChannelDesc(channelId_be, b, e))), paymentContext = Some(cfg.paymentContext))) router.send(payFsm, RouteResponse(Seq(Route(400_000 msat, hop_ac_1 :: hop_ce :: Nil, None), Route(600_000 msat, hop_ad :: hop_de :: Nil, None)))) @@ -264,7 +291,7 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS childPayFsm.expectNoMessage(100 millis) val (failedId, failedRoute) :: Nil = payFsm.stateData.asInstanceOf[PaymentProgress].pending.toSeq - childPayFsm.send(payFsm, PaymentFailed(failedId, paymentHash, Seq(LocalFailure(failedRoute.amount, failedRoute.hops, RemoteCannotAffordFeesForNewHtlc(randomBytes32(), finalAmount, 15 sat, 0 sat, 15 sat))))) + childPayFsm.send(payFsm, PaymentFailed(failedId, paymentHash, Seq(LocalFailure(failedRoute.amount, failedRoute.fullRoute, RemoteCannotAffordFeesForNewHtlc(randomBytes32(), finalAmount, 15 sat, 0 sat, 15 sat))))) // We retry without the failing channel. router.expectMsg(RouteRequest( @@ -290,7 +317,7 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS childPayFsm.expectNoMessage(100 millis) val (failedId, failedRoute) :: (_, pendingRoute) :: Nil = payFsm.stateData.asInstanceOf[PaymentProgress].pending.toSeq - childPayFsm.send(payFsm, PaymentFailed(failedId, paymentHash, Seq(LocalFailure(failedRoute.amount, failedRoute.hops, ChannelUnavailable(randomBytes32()))))) + childPayFsm.send(payFsm, PaymentFailed(failedId, paymentHash, Seq(LocalFailure(failedRoute.amount, failedRoute.fullRoute, ChannelUnavailable(randomBytes32()))))) // If the router doesn't find routes, we will retry without ignoring the channel: it may work with a different split // of the amount to send. @@ -337,7 +364,7 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS // B changed his fees and expiry after the invoice was issued. val channelUpdate = channelUpdate_be.copy(feeBaseMsat = 250 msat, feeProportionalMillionths = 150, cltvExpiryDelta = CltvExpiryDelta(24)) val childId = payFsm.stateData.asInstanceOf[PaymentProgress].pending.keys.head - childPayFsm.send(payFsm, PaymentFailed(childId, paymentHash, Seq(RemoteFailure(route.amount, route.hops, Sphinx.DecryptedFailurePacket(b, FeeInsufficient(finalAmount, channelUpdate)))))) + childPayFsm.send(payFsm, PaymentFailed(childId, paymentHash, Seq(RemoteFailure(route.amount, route.fullRoute, Sphinx.DecryptedFailurePacket(b, FeeInsufficient(finalAmount, channelUpdate)))))) // We update the routing hints accordingly before requesting a new route. val extraEdge1 = extraEdge.copy(feeBase = 250 msat, feeProportionalMillionths = 150, cltvExpiryDelta = CltvExpiryDelta(24)) assert(router.expectMsgType[RouteRequest].target.extraEdges == Seq(extraEdge1)) @@ -361,13 +388,33 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS // NB: we need a channel update with a valid signature, otherwise we'll ignore the node instead of this specific channel. val channelUpdate = Announcements.makeChannelUpdate(channelUpdate_be.chainHash, priv_b, e, channelUpdate_be.shortChannelId, channelUpdate_be.cltvExpiryDelta, channelUpdate_be.htlcMinimumMsat, channelUpdate_be.feeBaseMsat, channelUpdate_be.feeProportionalMillionths, channelUpdate_be.htlcMaximumMsat) val childId = payFsm.stateData.asInstanceOf[PaymentProgress].pending.keys.head - childPayFsm.send(payFsm, PaymentFailed(childId, paymentHash, Seq(RemoteFailure(route.amount, route.hops, Sphinx.DecryptedFailurePacket(b, TemporaryChannelFailure(channelUpdate)))))) + childPayFsm.send(payFsm, PaymentFailed(childId, paymentHash, Seq(RemoteFailure(route.amount, route.fullRoute, Sphinx.DecryptedFailurePacket(b, TemporaryChannelFailure(channelUpdate)))))) // We update the routing hints accordingly before requesting a new route and ignore the channel. val routeRequest = router.expectMsgType[RouteRequest] assert(routeRequest.target.extraEdges == Seq(extraEdge)) assert(routeRequest.ignore.channels.map(_.shortChannelId) == Set(channelUpdate.shortChannelId)) } + test("retry with ignored blinded route") { f => + import f._ + + val (_, hop_be, recipient) = blindedRouteFromHops(finalAmount, expiry, Seq(channelHopFromUpdate(b, e, channelUpdate_be)), blindedRouteExpiry, paymentPreimage) + val payment = SendMultiPartPayment(sender.ref, recipient, 3, routeParams) + sender.send(payFsm, payment) + assert(router.expectMsgType[RouteRequest].target.extraEdges == recipient.extraEdges) + val route = Route(finalAmount, Seq(hop_ab_1), Some(hop_be)) + router.send(payFsm, RouteResponse(Seq(route))) + childPayFsm.expectMsgType[SendPaymentToRoute] + childPayFsm.expectNoMessage(100 millis) + + // The blinded route fails to relay the payment. + val childId = payFsm.stateData.asInstanceOf[PaymentProgress].pending.keys.head + childPayFsm.send(payFsm, PaymentFailed(childId, paymentHash, Seq(RemoteFailure(route.amount, route.fullRoute, Sphinx.DecryptedFailurePacket(b, InvalidOnionBlinding(randomBytes32())))))) + // We retry and ignore that blinded route. + val routeRequest = router.expectMsgType[RouteRequest] + assert(routeRequest.ignore.channels.map(_.shortChannelId) == Set(hop_be.dummyId)) + } + test("update routing hints") { () => val recipient = ClearRecipient(e, Features.empty, finalAmount, expiry, randomBytes32(), Seq( ExtraEdge(a, b, ShortChannelId(1), 10 msat, 0, CltvExpiryDelta(12), 1 msat, None), @@ -479,7 +526,7 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS childPayFsm.expectMsgType[SendPaymentToRoute] val (failedId, failedRoute) = payFsm.stateData.asInstanceOf[PaymentProgress].pending.head - val result = abortAfterFailure(f, PaymentFailed(failedId, paymentHash, Seq(RemoteFailure(failedRoute.amount, failedRoute.hops, Sphinx.DecryptedFailurePacket(e, IncorrectOrUnknownPaymentDetails(600_000 msat, BlockHeight(0))))))) + val result = abortAfterFailure(f, PaymentFailed(failedId, paymentHash, Seq(RemoteFailure(failedRoute.amount, failedRoute.fullRoute, Sphinx.DecryptedFailurePacket(e, IncorrectOrUnknownPaymentDetails(600_000 msat, BlockHeight(0))))))) assert(result.failures.length == 1) val metrics = metricsListener.expectMsgType[PathFindingExperimentMetrics] @@ -500,7 +547,7 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS childPayFsm.expectMsgType[SendPaymentToRoute] val (failedId, failedRoute) = payFsm.stateData.asInstanceOf[PaymentProgress].pending.head - val result = abortAfterFailure(f, PaymentFailed(failedId, paymentHash, Seq(LocalFailure(failedRoute.amount, failedRoute.hops, HtlcsTimedoutDownstream(channelId = ByteVector32.One, htlcs = Set.empty))))) + val result = abortAfterFailure(f, PaymentFailed(failedId, paymentHash, Seq(LocalFailure(failedRoute.amount, failedRoute.fullRoute, HtlcsTimedoutDownstream(channelId = ByteVector32.One, htlcs = Set.empty))))) assert(result.failures.length == 1) } @@ -533,7 +580,7 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS childPayFsm.expectMsgType[SendPaymentToRoute] val (failedId, failedRoute) :: (successId, successRoute) :: Nil = payFsm.stateData.asInstanceOf[PaymentProgress].pending.toSeq - childPayFsm.send(payFsm, PaymentFailed(failedId, paymentHash, Seq(UnreadableRemoteFailure(failedRoute.amount, failedRoute.hops)))) + childPayFsm.send(payFsm, PaymentFailed(failedId, paymentHash, Seq(UnreadableRemoteFailure(failedRoute.amount, failedRoute.fullRoute)))) router.expectMsgType[RouteRequest] val result = fulfillPendingPayments(f, 1, e, finalAmount) @@ -553,11 +600,11 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS childPayFsm.expectMsgType[SendPaymentToRoute] val (failedId, failedRoute) :: (successId, successRoute) :: Nil = payFsm.stateData.asInstanceOf[PaymentProgress].pending.toSeq - childPayFsm.send(payFsm, PaymentFailed(failedId, paymentHash, Seq(RemoteFailure(failedRoute.amount, failedRoute.hops, Sphinx.DecryptedFailurePacket(e, PaymentTimeout))))) + childPayFsm.send(payFsm, PaymentFailed(failedId, paymentHash, Seq(RemoteFailure(failedRoute.amount, failedRoute.fullRoute, Sphinx.DecryptedFailurePacket(e, PaymentTimeout))))) awaitCond(payFsm.stateName == PAYMENT_ABORTED) sender.watch(payFsm) - childPayFsm.send(payFsm, PaymentSent(cfg.id, paymentHash, paymentPreimage, finalAmount, e, Seq(PaymentSent.PartialPayment(successId, successRoute.amount, successRoute.channelFee(false), randomBytes32(), Some(successRoute.hops))))) + childPayFsm.send(payFsm, PaymentSent(cfg.id, paymentHash, paymentPreimage, finalAmount, e, Seq(PaymentSent.PartialPayment(successId, successRoute.amount, successRoute.channelFee(false), randomBytes32(), Some(successRoute.fullRoute))))) sender.expectMsg(PreimageReceived(paymentHash, paymentPreimage)) val result = sender.expectMsgType[PaymentSent] assert(result.id == cfg.id) @@ -586,12 +633,12 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS childPayFsm.expectMsgType[SendPaymentToRoute] val (childId, route) :: (failedId, failedRoute) :: Nil = payFsm.stateData.asInstanceOf[PaymentProgress].pending.toSeq - childPayFsm.send(payFsm, PaymentSent(cfg.id, paymentHash, paymentPreimage, finalAmount, e, Seq(PaymentSent.PartialPayment(childId, route.amount, route.channelFee(false), randomBytes32(), Some(route.hops))))) + childPayFsm.send(payFsm, PaymentSent(cfg.id, paymentHash, paymentPreimage, finalAmount, e, Seq(PaymentSent.PartialPayment(childId, route.amount, route.channelFee(false), randomBytes32(), Some(route.fullRoute))))) sender.expectMsg(PreimageReceived(paymentHash, paymentPreimage)) awaitCond(payFsm.stateName == PAYMENT_SUCCEEDED) sender.watch(payFsm) - childPayFsm.send(payFsm, PaymentFailed(failedId, paymentHash, Seq(RemoteFailure(failedRoute.amount, failedRoute.hops, Sphinx.DecryptedFailurePacket(e, PaymentTimeout))))) + childPayFsm.send(payFsm, PaymentFailed(failedId, paymentHash, Seq(RemoteFailure(failedRoute.amount, failedRoute.fullRoute, Sphinx.DecryptedFailurePacket(e, PaymentTimeout))))) val result = sender.expectMsgType[PaymentSent] assert(result.parts.length == 1 && result.parts.head.id == childId) assert(result.amountWithFees < finalAmount) // we got the preimage without paying the full amount @@ -611,7 +658,7 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS assert(pending.size == childCount) val partialPayments = pending.map { - case (childId, route) => PaymentSent.PartialPayment(childId, route.amount, route.channelFee(false), randomBytes32(), Some(route.hops)) + case (childId, route) => PaymentSent.PartialPayment(childId, route.amount, route.channelFee(false) + route.blindedFee, randomBytes32(), Some(route.fullRoute)) } partialPayments.foreach(pp => childPayFsm.send(payFsm, PaymentSent(cfg.id, paymentHash, paymentPreimage, finalAmount, e, Seq(pp)))) sender.expectMsg(PreimageReceived(paymentHash, paymentPreimage)) @@ -666,6 +713,7 @@ object MultiPartPaymentLifecycleSpec { val paymentPreimage = randomBytes32() val paymentHash = Crypto.sha256(paymentPreimage) val expiry = CltvExpiry(1105) + val blindedRouteExpiry = CltvExpiry(500_000) val finalAmount = 1_000_000 msat val routeParams = PathFindingConf( randomize = false, diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentInitiatorSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentInitiatorSpec.scala index 7627e47c6d..3c08a631ac 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentInitiatorSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentInitiatorSpec.scala @@ -32,8 +32,9 @@ import fr.acinq.eclair.payment.send.MultiPartPaymentLifecycle.SendMultiPartPayme import fr.acinq.eclair.payment.send.PaymentError.UnsupportedFeatures import fr.acinq.eclair.payment.send.PaymentInitiator._ import fr.acinq.eclair.payment.send._ -import fr.acinq.eclair.router.RouteNotFound import fr.acinq.eclair.router.Router._ +import fr.acinq.eclair.router.{BlindedRouteCreation, RouteNotFound} +import fr.acinq.eclair.wire.protocol.OfferTypes.{InvoiceRequest, Offer} import fr.acinq.eclair.wire.protocol._ import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Feature, Features, InvoiceFeature, MilliSatoshiLong, NodeParams, PaymentFinalExpiryConf, TestConstants, TestKitBaseClass, TimestampSecond, UnknownFeature, randomBytes32, randomKey} import org.scalatest.funsuite.FixtureAnyFunSuiteLike @@ -51,6 +52,7 @@ class PaymentInitiatorSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike object Tags { val DisableMPP = "mpp_disabled" + val DisableRouteBlinding = "route_blinding_disabled" val RandomizeFinalExpiry = "random_final_expiry" } @@ -58,22 +60,31 @@ class PaymentInitiatorSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike val featuresWithoutMpp: Features[InvoiceFeature] = Features( VariableLengthOnion -> Mandatory, - PaymentSecret -> Mandatory + PaymentSecret -> Mandatory, + RouteBlinding -> Optional, ) val featuresWithMpp: Features[InvoiceFeature] = Features( VariableLengthOnion -> Mandatory, PaymentSecret -> Mandatory, BasicMultiPartPayment -> Optional, + RouteBlinding -> Optional, ) val featuresWithTrampoline: Features[InvoiceFeature] = Features( VariableLengthOnion -> Mandatory, PaymentSecret -> Mandatory, BasicMultiPartPayment -> Optional, + RouteBlinding -> Optional, TrampolinePaymentPrototype -> Optional, ) + val featuresWithoutRouteBlinding: Features[InvoiceFeature] = Features( + VariableLengthOnion -> Mandatory, + PaymentSecret -> Mandatory, + BasicMultiPartPayment -> Optional, + ) + case class FakePaymentFactory(payFsm: TestProbe, multiPartPayFsm: TestProbe) extends PaymentInitiator.MultiPartPaymentFactory { // @formatter:off override def spawnOutgoingPayment(context: ActorContext, cfg: SendPaymentConfig): ActorRef = { @@ -88,7 +99,13 @@ class PaymentInitiatorSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike } override def withFixture(test: OneArgTest): Outcome = { - val features = if (test.tags.contains(Tags.DisableMPP)) featuresWithoutMpp else featuresWithMpp + val features = if (test.tags.contains(Tags.DisableMPP)) { + featuresWithoutMpp + } else if (test.tags.contains(Tags.DisableRouteBlinding)) { + featuresWithoutRouteBlinding + } else { + featuresWithMpp + } val paymentFinalExpiry = if (test.tags.contains(Tags.RandomizeFinalExpiry)) { PaymentFinalExpiryConf(CltvExpiryDelta(50), CltvExpiryDelta(200)) } else { @@ -276,6 +293,82 @@ class PaymentInitiatorSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike sender.expectMsg(NoPendingPayment(Right(invoice.paymentHash))) } + def createBolt12Invoice(features: Features[InvoiceFeature]): Bolt12Invoice = { + val offer = Offer(None, "Bolt12 r0cks", e, features, Block.RegtestGenesisBlock.hash) + val invoiceRequest = InvoiceRequest(offer, finalAmount, 1, features, randomKey(), Block.RegtestGenesisBlock.hash) + val blindedRoute = BlindedRouteCreation.createBlindedRouteWithoutHops(e, hex"2a2a2a2a", 1 msat, CltvExpiry(500_000)).route + val paymentInfo = OfferTypes.PaymentInfo(1_000 msat, 0, CltvExpiryDelta(24), 0 msat, finalAmount, Features.empty) + Bolt12Invoice(offer, invoiceRequest, paymentPreimage, priv_e.privateKey, CltvExpiryDelta(6), features, Seq(PaymentBlindedRoute(blindedRoute, paymentInfo))) + } + + test("forward single-part blinded payment") { f => + import f._ + val invoice = createBolt12Invoice(Features(VariableLengthOnion -> Mandatory, RouteBlinding -> Mandatory)) + val req = SendPaymentToNode(finalAmount, invoice, 1, routeParams = nodeParams.routerConf.pathFindingExperimentConf.getRandomConf().getDefaultRouteParams) + sender.send(initiator, req) + val id = sender.expectMsgType[UUID] + payFsm.expectMsg(SendPaymentConfig(id, id, None, paymentHash, invoice.nodeId, Upstream.Local(id), Some(invoice), storeInDb = true, publishEvent = true, recordPathFindingMetrics = true)) + val payment = payFsm.expectMsgType[PaymentLifecycle.SendPaymentToNode] + assert(payment.amount == finalAmount) + assert(payment.recipient.nodeId == invoice.nodeId) + assert(payment.recipient.totalAmount == finalAmount) + assert(payment.recipient.extraEdges.nonEmpty) + assert(payment.recipient.expiry == req.finalExpiry(nodeParams)) + assert(payment.recipient.isInstanceOf[BlindedRecipient]) + + sender.send(initiator, GetPayment(Left(id))) + sender.expectMsg(PaymentIsPending(id, invoice.paymentHash, PendingPaymentToNode(sender.ref, req))) + sender.send(initiator, GetPayment(Right(invoice.paymentHash))) + sender.expectMsg(PaymentIsPending(id, invoice.paymentHash, PendingPaymentToNode(sender.ref, req))) + + val pf = PaymentFailed(id, invoice.paymentHash, Nil) + payFsm.send(initiator, pf) + sender.expectMsg(pf) + eventListener.expectNoMessage(100 millis) + + sender.send(initiator, GetPayment(Left(id))) + sender.expectMsg(NoPendingPayment(Left(id))) + } + + test("forward multi-part blinded payment") { f => + import f._ + val invoice = createBolt12Invoice(Features(VariableLengthOnion -> Mandatory, BasicMultiPartPayment -> Optional, RouteBlinding -> Mandatory)) + val req = SendPaymentToNode(finalAmount, invoice, 1, routeParams = nodeParams.routerConf.pathFindingExperimentConf.getRandomConf().getDefaultRouteParams) + sender.send(initiator, req) + val id = sender.expectMsgType[UUID] + multiPartPayFsm.expectMsg(SendPaymentConfig(id, id, None, paymentHash, invoice.nodeId, Upstream.Local(id), Some(invoice), storeInDb = true, publishEvent = true, recordPathFindingMetrics = true)) + val payment = multiPartPayFsm.expectMsgType[SendMultiPartPayment] + assert(payment.recipient.nodeId == invoice.nodeId) + assert(payment.recipient.totalAmount == finalAmount) + assert(payment.recipient.extraEdges.nonEmpty) + assert(payment.recipient.expiry == req.finalExpiry(nodeParams)) + assert(payment.recipient.isInstanceOf[BlindedRecipient]) + + sender.send(initiator, GetPayment(Left(id))) + sender.expectMsg(PaymentIsPending(id, invoice.paymentHash, PendingPaymentToNode(sender.ref, req))) + sender.send(initiator, GetPayment(Right(invoice.paymentHash))) + sender.expectMsg(PaymentIsPending(id, invoice.paymentHash, PendingPaymentToNode(sender.ref, req))) + + val ps = PaymentSent(id, invoice.paymentHash, paymentPreimage, finalAmount, invoice.nodeId, Seq(PartialPayment(UUID.randomUUID(), finalAmount, 0 msat, randomBytes32(), None))) + payFsm.send(initiator, ps) + sender.expectMsg(ps) + eventListener.expectNoMessage(100 millis) + + sender.send(initiator, GetPayment(Left(id))) + sender.expectMsg(NoPendingPayment(Left(id))) + } + + test("reject blinded payment when route blinding deactivated", Tag(Tags.DisableRouteBlinding)) { f => + import f._ + val invoice = createBolt12Invoice(Features(VariableLengthOnion -> Mandatory, BasicMultiPartPayment -> Optional, RouteBlinding -> Mandatory)) + val req = SendPaymentToNode(finalAmount, invoice, 1, routeParams = nodeParams.routerConf.pathFindingExperimentConf.getRandomConf().getDefaultRouteParams) + sender.send(initiator, req) + val id = sender.expectMsgType[UUID] + val fail = sender.expectMsgType[PaymentFailed] + assert(fail.id == id) + assert(fail.failures == LocalFailure(finalAmount, Nil, UnsupportedFeatures(invoice.features)) :: Nil) + } + test("forward trampoline payment") { f => import f._ val ignoredRoutingHints = List(List(ExtraHop(b, channelUpdate_bc.shortChannelId, feeBase = 10 msat, feeProportionalMillionths = 1, cltvExpiryDelta = CltvExpiryDelta(12)))) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala index e97cf86ff5..ede7440e0a 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala @@ -39,7 +39,7 @@ import fr.acinq.eclair.payment.send.PaymentInitiator.SendPaymentConfig import fr.acinq.eclair.payment.send.PaymentLifecycle._ import fr.acinq.eclair.payment.send.{ClearRecipient, PaymentLifecycle} import fr.acinq.eclair.router.Announcements.makeChannelUpdate -import fr.acinq.eclair.router.BaseRouterSpec.{channelAnnouncement, channelHopFromUpdate} +import fr.acinq.eclair.router.BaseRouterSpec.{blindedRouteFromHops, channelAnnouncement, channelHopFromUpdate} import fr.acinq.eclair.router.Graph.WeightRatios import fr.acinq.eclair.router.Router._ import fr.acinq.eclair.router._ @@ -58,6 +58,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val defaultAmountMsat = 142_000_000 msat val defaultExpiry = Channel.MIN_CLTV_EXPIRY_DELTA.toCltvExpiry(BlockHeight(40_000)) + val defaultRouteExpiry = CltvExpiry(100_000) val defaultPaymentPreimage = randomBytes32() val defaultPaymentHash = Crypto.sha256(defaultPaymentPreimage) val defaultOrigin = Origin.LocalCold(UUID.randomUUID()) @@ -78,10 +79,10 @@ class PaymentLifecycleSpec extends BaseRouterSpec { eventListener: TestProbe, metricsListener: TestProbe) - def createPaymentLifecycle(storeInDb: Boolean = true, publishEvent: Boolean = true, recordMetrics: Boolean = true): PaymentFixture = { + def createPaymentLifecycle(invoice: Invoice, storeInDb: Boolean = true, publishEvent: Boolean = true, recordMetrics: Boolean = true): PaymentFixture = { val (id, parentId) = (UUID.randomUUID(), UUID.randomUUID()) val nodeParams = TestConstants.Alice.nodeParams.copy(nodeKeyManager = testNodeKeyManager, channelKeyManager = testChannelKeyManager) - val cfg = SendPaymentConfig(id, parentId, Some(defaultExternalId), defaultPaymentHash, d, Upstream.Local(id), Some(defaultInvoice), storeInDb, publishEvent, recordMetrics) + val cfg = SendPaymentConfig(id, parentId, Some(defaultExternalId), defaultPaymentHash, invoice.nodeId, Upstream.Local(id), Some(invoice), storeInDb, publishEvent, recordMetrics) val (routerForwarder, register, sender, monitor, eventListener, metricsListener) = (TestProbe(), TestProbe(), TestProbe(), TestProbe(), TestProbe(), TestProbe()) val paymentFSM = TestFSMRef(new PaymentLifecycle(nodeParams, cfg, routerForwarder.ref, register.ref)) paymentFSM ! SubscribeTransitionCallBack(monitor.ref) @@ -99,7 +100,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { } test("send to route") { () => - val payFixture = createPaymentLifecycle(recordMetrics = false) + val payFixture = createPaymentLifecycle(defaultInvoice, recordMetrics = false) import payFixture._ import cfg._ @@ -127,7 +128,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { } test("send to route (node_id only)") { routerFixture => - val payFixture = createPaymentLifecycle(recordMetrics = false) + val payFixture = createPaymentLifecycle(defaultInvoice, recordMetrics = false) import payFixture._ import cfg._ @@ -156,7 +157,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { } test("send to route (nodes not found in the graph)") { routerFixture => - val payFixture = createPaymentLifecycle(recordMetrics = false) + val payFixture = createPaymentLifecycle(defaultInvoice, recordMetrics = false) import payFixture._ val brokenRoute = SendPaymentToRoute(sender.ref, Left(PredefinedNodeRoute(defaultAmountMsat, Seq(randomKey().publicKey, randomKey().publicKey, randomKey().publicKey))), defaultRecipient) @@ -173,7 +174,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { } test("send to route (channels not found in the graph)") { routerFixture => - val payFixture = createPaymentLifecycle(recordMetrics = false) + val payFixture = createPaymentLifecycle(defaultInvoice, recordMetrics = false) import payFixture._ val brokenRoute = SendPaymentToRoute(sender.ref, Left(PredefinedChannelRoute(defaultAmountMsat, randomKey().publicKey, Seq(ShortChannelId(1), ShortChannelId(2)))), defaultRecipient) @@ -190,7 +191,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { } test("send to route (routing hints)") { routerFixture => - val payFixture = createPaymentLifecycle(recordMetrics = false) + val payFixture = createPaymentLifecycle(defaultInvoice, recordMetrics = false) import payFixture._ import cfg._ @@ -219,8 +220,33 @@ class PaymentLifecycleSpec extends BaseRouterSpec { assert(routerForwarder.expectMsgType[RouteDidRelay].route.hops.map(_.nodeId) === Seq(a, b, c)) } + test("send to route (blinded route)") { () => + val (invoice, blindedHop, recipient) = blindedRouteFromHops(defaultAmountMsat, defaultExpiry, Seq(channelHopFromUpdate(b, c, update_bc)), defaultRouteExpiry, defaultPaymentPreimage) + val route = Route(defaultAmountMsat, Seq(channelHopFromUpdate(a, b, update_ab)), Some(blindedHop)) + val payFixture = createPaymentLifecycle(invoice) + import payFixture._ + + val request = SendPaymentToRoute(sender.ref, Right(route), recipient) + sender.send(paymentFSM, request) + routerForwarder.expectNoMessage(100 millis) // we don't need the router, we have the pre-computed route + awaitCond(nodeParams.db.payments.getOutgoingPayment(cfg.id).exists(_.status == OutgoingPaymentStatus.Pending)) + val Some(outgoing) = nodeParams.db.payments.getOutgoingPayment(cfg.id) + assert(outgoing.amount == defaultAmountMsat) + assert(outgoing.recipientAmount == defaultAmountMsat) + assert(outgoing.invoice.contains(invoice)) + assert(outgoing.status == OutgoingPaymentStatus.Pending) + + sender.send(paymentFSM, addCompleted(HtlcResult.RemoteFulfill(UpdateFulfillHtlc(ByteVector32.Zeroes, 0, defaultPaymentPreimage)))) + val ps = sender.expectMsgType[PaymentSent] + assert(ps.id == cfg.parentId) + assert(ps.parts.head.route.contains(route.hops ++ Seq(blindedHop))) + awaitCond(nodeParams.db.payments.getOutgoingPayment(cfg.id).exists(_.status.isInstanceOf[OutgoingPaymentStatus.Succeeded])) + + assert(routerForwarder.expectMsgType[RouteDidRelay].route === route) + } + test("payment failed (route not found)") { routerFixture => - val payFixture = createPaymentLifecycle() + val payFixture = createPaymentLifecycle(defaultInvoice) import payFixture._ import cfg._ @@ -245,7 +271,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { } test("payment failed (route too expensive)") { routerFixture => - val payFixture = createPaymentLifecycle() + val payFixture = createPaymentLifecycle(defaultInvoice) import payFixture._ import cfg._ @@ -277,7 +303,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { } test("payment failed (cannot build onion)") { routerFixture => - val payFixture = createPaymentLifecycle() + val payFixture = createPaymentLifecycle(defaultInvoice) import payFixture._ import cfg._ @@ -297,7 +323,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { } test("payment failed (unparsable failure)") { routerFixture => - val payFixture = createPaymentLifecycle() + val payFixture = createPaymentLifecycle(defaultInvoice) import payFixture._ import cfg._ @@ -342,7 +368,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { } test("payment failed (local error)") { routerFixture => - val payFixture = createPaymentLifecycle() + val payFixture = createPaymentLifecycle(defaultInvoice) import payFixture._ import cfg._ @@ -363,7 +389,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { } test("payment failed (register error)") { routerFixture => - val payFixture = createPaymentLifecycle() + val payFixture = createPaymentLifecycle(defaultInvoice) import payFixture._ import cfg._ @@ -383,7 +409,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { } test("payment failed (first hop returns an UpdateFailMalformedHtlc)") { routerFixture => - val payFixture = createPaymentLifecycle() + val payFixture = createPaymentLifecycle(defaultInvoice) import payFixture._ import cfg._ @@ -406,7 +432,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { } test("payment failed (first htlc failed on-chain)") { routerFixture => - val payFixture = createPaymentLifecycle() + val payFixture = createPaymentLifecycle(defaultInvoice) import payFixture._ import cfg._ @@ -429,7 +455,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { } test("payment failed (disconnected before signing the first htlc)") { routerFixture => - val payFixture = createPaymentLifecycle() + val payFixture = createPaymentLifecycle(defaultInvoice) import payFixture._ import cfg._ @@ -453,7 +479,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { } test("payment failed (TemporaryChannelFailure)") { routerFixture => - val payFixture = createPaymentLifecycle() + val payFixture = createPaymentLifecycle(defaultInvoice) import payFixture._ val request = SendPaymentToNode(sender.ref, defaultRecipient, 2, defaultRouteParams) @@ -483,7 +509,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { } test("payment failed (Update)") { routerFixture => - val payFixture = createPaymentLifecycle() + val payFixture = createPaymentLifecycle(defaultInvoice) import payFixture._ import cfg._ @@ -538,7 +564,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { } test("payment failed (Update in last attempt)") { routerFixture => - val payFixture = createPaymentLifecycle() + val payFixture = createPaymentLifecycle(defaultInvoice) import payFixture._ val request = SendPaymentToNode(sender.ref, defaultRecipient, 1, defaultRouteParams) @@ -562,7 +588,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { } test("payment failed (Update in assisted route)") { routerFixture => - val payFixture = createPaymentLifecycle() + val payFixture = createPaymentLifecycle(defaultInvoice) import payFixture._ import cfg._ @@ -606,7 +632,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { } test("payment failed (Update disabled in assisted route)") { routerFixture => - val payFixture = createPaymentLifecycle() + val payFixture = createPaymentLifecycle(defaultInvoice) import payFixture._ import cfg._ @@ -636,7 +662,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { } def testPermanentFailure(router: ActorRef, failure: FailureMessage): Unit = { - val payFixture = createPaymentLifecycle() + val payFixture = createPaymentLifecycle(defaultInvoice) import payFixture._ import cfg._ @@ -673,8 +699,40 @@ class PaymentLifecycleSpec extends BaseRouterSpec { testPermanentFailure(routerFixture.router, FailureMessageCodecs.failureMessageCodec.decode(hex"4011".bits).require.value) } + test("payment failed (blinded route)") { routerFixture => + val (invoice, blindedHop, recipient) = blindedRouteFromHops(defaultAmountMsat, defaultExpiry, Seq(channelHopFromUpdate(b, c, update_bc)), defaultRouteExpiry, defaultPaymentPreimage) + assert(recipient.extraEdges.length == 1) + val payFixture = createPaymentLifecycle(invoice) + import payFixture._ + + val request = SendPaymentToNode(sender.ref, recipient, 2, defaultRouteParams) + sender.send(paymentFSM, request) + + routerForwarder.expectMsgType[RouteRequest] + routerForwarder.forward(routerFixture.router) + awaitCond(paymentFSM.stateName == WAITING_FOR_PAYMENT_COMPLETE) + val WaitingForComplete(_, cmd1, Nil, sharedSecrets, _, route) = paymentFSM.stateData + register.expectMsg(ForwardShortId(paymentFSM.toTyped, scid_ab, cmd1)) + + // The payment fails inside the blinded route: the introduction node sends back an error. + val failure = InvalidOnionBlinding(randomBytes32()) + val failureOnion = Sphinx.FailurePacket.create(sharedSecrets.head._1, failure) + sender.send(paymentFSM, addCompleted(HtlcResult.RemoteFail(UpdateFailHtlc(ByteVector32.Zeroes, 0, failureOnion)))) + + // We retry but we exclude the failed blinded route. + val routeRequest = routerForwarder.expectMsgType[RouteRequest] + assert(routeRequest.target == recipient) + assert(routeRequest.ignore.channels.map(_.shortChannelId) == Set(blindedHop.dummyId)) + routerForwarder.forward(routerFixture.router) + + // Without the blinded route, the router cannot find a route to the recipient. + val failed = sender.expectMsgType[PaymentFailed] + assert(failed.failures == Seq(RemoteFailure(defaultAmountMsat, route.hops ++ Seq(blindedHop), Sphinx.DecryptedFailurePacket(b, failure)), LocalFailure(defaultAmountMsat, Nil, RouteNotFound))) + awaitCond(nodeParams.db.payments.getOutgoingPayment(cfg.id).exists(_.status.isInstanceOf[OutgoingPaymentStatus.Failed])) + } + test("payment succeeded") { routerFixture => - val payFixture = createPaymentLifecycle() + val payFixture = createPaymentLifecycle(defaultInvoice) import payFixture._ import cfg._ @@ -730,7 +788,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { watcher.send(router, ValidateResult(chan_bh, Right((Transaction(version = 0, txIn = Nil, txOut = TxOut(1000000 sat, write(pay2wsh(Scripts.multiSig2of2(funding_b, funding_h)))) :: Nil, lockTime = 0), UtxoStatus.Unspent)))) watcher.expectMsgType[WatchExternalChannelSpent] - val payFixture = createPaymentLifecycle() + val payFixture = createPaymentLifecycle(defaultInvoice) import payFixture._ // we send a payment to H @@ -764,6 +822,37 @@ class PaymentLifecycleSpec extends BaseRouterSpec { assert(routerForwarder.expectMsgType[RouteDidRelay].route.hops.map(_.shortChannelId) == Seq(update_ab, channelUpdate_bh).map(_.shortChannelId)) } + test("payment success (blinded route)") { routerFixture => + val (invoice, blindedHop, recipient) = blindedRouteFromHops(defaultAmountMsat, defaultExpiry, Seq(channelHopFromUpdate(b, c, update_bc)), defaultRouteExpiry, defaultPaymentPreimage) + assert(recipient.extraEdges.length == 1) + val payFixture = createPaymentLifecycle(invoice) + import payFixture._ + + val request = SendPaymentToNode(sender.ref, recipient, 2, defaultRouteParams) + sender.send(paymentFSM, request) + routerForwarder.expectMsgType[RouteRequest] + routerForwarder.forward(routerFixture.router) + awaitCond(nodeParams.db.payments.getOutgoingPayment(cfg.id).exists(_.status == OutgoingPaymentStatus.Pending)) + val Some(outgoing) = nodeParams.db.payments.getOutgoingPayment(cfg.id) + assert(outgoing.copy(createdAt = 0 unixms) == OutgoingPayment(cfg.id, cfg.parentId, Some(defaultExternalId), defaultPaymentHash, PaymentType.Blinded, defaultAmountMsat, defaultAmountMsat, recipient.nodeId, 0 unixms, Some(invoice), OutgoingPaymentStatus.Pending)) + sender.send(paymentFSM, addCompleted(HtlcResult.RemoteFulfill(UpdateFulfillHtlc(ByteVector32.Zeroes, 0, defaultPaymentPreimage)))) + + val ps = eventListener.expectMsgType[PaymentSent] + assert(ps.id == cfg.parentId) + assert(ps.feesPaid == blindedHop.fee(defaultAmountMsat)) + assert(ps.recipientAmount == defaultAmountMsat) + assert(ps.paymentPreimage == defaultPaymentPreimage) + awaitCond(nodeParams.db.payments.getOutgoingPayment(cfg.id).exists(_.status.isInstanceOf[OutgoingPaymentStatus.Succeeded])) + + val metrics = metricsListener.expectMsgType[PathFindingExperimentMetrics] + assert(metrics.status == "SUCCESS") + assert(metrics.amount == defaultAmountMsat) + assert(metrics.fees == blindedHop.fee(defaultAmountMsat)) + metricsListener.expectNoMessage(100 millis) + + assert(routerForwarder.expectMsgType[RouteDidRelay].route.hops.map(_.shortChannelId) == Seq(update_ab.shortChannelId)) + } + test("filter errors properly") { () => val failures = Seq( LocalFailure(defaultAmountMsat, Nil, RouteNotFound), @@ -782,6 +871,10 @@ class PaymentLifecycleSpec extends BaseRouterSpec { test("ignore failed nodes/channels") { () => val route_abcd = channelHopFromUpdate(a, b, update_ab) :: channelHopFromUpdate(b, c, update_bc) :: channelHopFromUpdate(c, d, update_cd) :: Nil + val update_de = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_d, e, ShortChannelId(1729), CltvExpiryDelta(3), htlcMinimumMsat = 0 msat, feeBaseMsat = 10 msat, feeProportionalMillionths = 4, htlcMaximumMsat = htlcMaximum) + val (_, blindedHop_bc, _) = blindedRouteFromHops(defaultAmountMsat, defaultExpiry, Seq(channelHopFromUpdate(b, c, update_bc), channelHopFromUpdate(c, d, update_cd)), defaultRouteExpiry, defaultPaymentPreimage) + val blindedRoute_abc = channelHopFromUpdate(a, b, update_ab) :: blindedHop_bc :: Nil + val (_, blindedHop_de, _) = blindedRouteFromHops(defaultAmountMsat, defaultExpiry, Seq(channelHopFromUpdate(d, e, update_de)), defaultRouteExpiry, defaultPaymentPreimage) val testCases = Seq( // local failures -> ignore first channel if there is one (LocalFailure(defaultAmountMsat, Nil, RouteNotFound), Set.empty, Set.empty), @@ -795,9 +888,14 @@ class PaymentLifecycleSpec extends BaseRouterSpec { (RemoteFailure(defaultAmountMsat, route_abcd, Sphinx.DecryptedFailurePacket(b, PermanentChannelFailure)), Set.empty, Set(ChannelDesc(scid_bc, b, c))), (RemoteFailure(defaultAmountMsat, route_abcd, Sphinx.DecryptedFailurePacket(c, UnknownNextPeer)), Set.empty, Set(ChannelDesc(scid_cd, c, d))), (RemoteFailure(defaultAmountMsat, route_abcd, Sphinx.DecryptedFailurePacket(b, FeeInsufficient(100 msat, update_bc))), Set.empty, Set.empty), - // unreadable remote failures -> blacklist all nodes except our direct peer and the final recipient + (RemoteFailure(defaultAmountMsat, blindedRoute_abc, Sphinx.DecryptedFailurePacket(b, InvalidOnionBlinding(randomBytes32()))), Set.empty, Set(ChannelDesc(blindedHop_bc.dummyId, blindedHop_bc.nodeId, blindedHop_bc.nextNodeId))), + (RemoteFailure(defaultAmountMsat, blindedRoute_abc, Sphinx.DecryptedFailurePacket(blindedHop_bc.route.blindedNodeIds(1), InvalidOnionBlinding(randomBytes32()))), Set.empty, Set(ChannelDesc(blindedHop_bc.dummyId, blindedHop_bc.nodeId, blindedHop_bc.nextNodeId))), + // unreadable remote failures -> blacklist all nodes except our direct peer, the final recipient or the last hop (UnreadableRemoteFailure(defaultAmountMsat, channelHopFromUpdate(a, b, update_ab) :: Nil), Set.empty, Set.empty), - (UnreadableRemoteFailure(defaultAmountMsat, channelHopFromUpdate(a, b, update_ab) :: channelHopFromUpdate(b, c, update_bc) :: channelHopFromUpdate(c, d, update_cd) :: NodeHop(d, e, CltvExpiryDelta(24), 0 msat) :: Nil), Set(c, d), Set.empty) + (UnreadableRemoteFailure(defaultAmountMsat, channelHopFromUpdate(a, b, update_ab) :: channelHopFromUpdate(b, c, update_bc) :: channelHopFromUpdate(c, d, update_cd) :: Nil), Set(c), Set.empty), + (UnreadableRemoteFailure(defaultAmountMsat, channelHopFromUpdate(a, b, update_ab) :: channelHopFromUpdate(b, c, update_bc) :: channelHopFromUpdate(c, d, update_cd) :: channelHopFromUpdate(d, e, update_de) :: Nil), Set(c, d), Set.empty), + (UnreadableRemoteFailure(defaultAmountMsat, channelHopFromUpdate(a, b, update_ab) :: channelHopFromUpdate(b, c, update_bc) :: channelHopFromUpdate(c, d, update_cd) :: NodeHop(d, e, CltvExpiryDelta(24), 0 msat) :: Nil), Set(c), Set.empty), + (UnreadableRemoteFailure(defaultAmountMsat, channelHopFromUpdate(a, b, update_ab) :: channelHopFromUpdate(b, c, update_bc) :: channelHopFromUpdate(c, d, update_cd) :: blindedHop_de :: Nil), Set(c), Set.empty), ) for ((failure, expectedNodes, expectedChannels) <- testCases) { @@ -817,7 +915,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { } test("disable database and events") { routerFixture => - val payFixture = createPaymentLifecycle(storeInDb = false, publishEvent = false, recordMetrics = false) + val payFixture = createPaymentLifecycle(defaultInvoice, storeInDb = false, publishEvent = false, recordMetrics = false) import payFixture._ import cfg._ @@ -838,7 +936,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { } test("send to route (no retry on error") { () => - val payFixture = createPaymentLifecycle() + val payFixture = createPaymentLifecycle(defaultInvoice) import payFixture._ import cfg._ diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentPacketSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentPacketSpec.scala index a607d9bad8..2777a4534d 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentPacketSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentPacketSpec.scala @@ -23,21 +23,25 @@ import fr.acinq.eclair.FeatureSupport.{Mandatory, Optional} import fr.acinq.eclair.Features._ import fr.acinq.eclair.channel._ import fr.acinq.eclair.channel.fsm.Channel +import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.payment.IncomingPaymentPacket.{ChannelRelayPacket, FinalPacket, NodeRelayPacket, decrypt} import fr.acinq.eclair.payment.OutgoingPaymentPacket._ -import fr.acinq.eclair.payment.send.{ClearRecipient, ClearTrampolineRecipient} -import fr.acinq.eclair.router.BaseRouterSpec.channelHopFromUpdate +import fr.acinq.eclair.payment.send.{BlindedRecipient, ClearRecipient, ClearTrampolineRecipient} +import fr.acinq.eclair.router.BaseRouterSpec.{blindedRouteFromHops, channelHopFromUpdate} +import fr.acinq.eclair.router.BlindedRouteCreation import fr.acinq.eclair.router.Router.{NodeHop, Route} import fr.acinq.eclair.transactions.Transactions.InputInfo +import fr.acinq.eclair.wire.protocol.OfferTypes.{InvoiceRequest, Offer, PaymentInfo} import fr.acinq.eclair.wire.protocol.OnionPaymentPayloadTlv.{AmountToForward, OutgoingCltv, PaymentData} import fr.acinq.eclair.wire.protocol.PaymentOnion.{FinalPayload, IntermediatePayload} import fr.acinq.eclair.wire.protocol._ -import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Features, InvoiceFeature, MilliSatoshi, MilliSatoshiLong, ShortChannelId, TestConstants, TimestampSecondLong, UInt64, nodeFee, randomBytes32, randomKey} +import fr.acinq.eclair.{BlockHeight, CltvExpiry, CltvExpiryDelta, Features, InvoiceFeature, MilliSatoshi, MilliSatoshiLong, ShortChannelId, TestConstants, TimestampSecondLong, UInt64, nodeFee, randomBytes32, randomKey} import org.scalatest.BeforeAndAfterAll import org.scalatest.funsuite.AnyFunSuite import scodec.bits.{ByteVector, HexStringSyntax} import java.util.UUID +import scala.util.Success /** * Created by PM on 31/05/2016. @@ -61,7 +65,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { def testBuildOutgoingPayment(): Unit = { val recipient = ClearRecipient(e, Features.empty, finalAmount, finalExpiry, paymentSecret) - val Right(payment) = buildOutgoingPayment(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, Route(finalAmount, hops, None), recipient) + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, Route(finalAmount, hops, None), recipient) assert(payment.outgoingChannel == channelUpdate_ab.shortChannelId) assert(payment.cmd.amount == amount_ab) assert(payment.cmd.cltvExpiry == expiry_ab) @@ -119,7 +123,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { test("build outgoing payment for direct peer") { val recipient = ClearRecipient(b, Features.empty, finalAmount, finalExpiry, paymentSecret, paymentMetadata_opt = Some(paymentMetadata)) val route = Route(finalAmount, hops.take(1), None) - val Right(payment) = buildOutgoingPayment(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, route, recipient) + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, route, recipient) assert(payment.cmd.amount == finalAmount) assert(payment.cmd.cltvExpiry == finalExpiry) assert(payment.cmd.paymentHash == paymentHash) @@ -140,7 +144,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { test("build outgoing payment with greater amount and expiry") { val recipient = ClearRecipient(b, Features.empty, finalAmount, finalExpiry, paymentSecret, paymentMetadata_opt = Some(paymentMetadata)) val route = Route(finalAmount, hops.take(1), None) - val Right(payment) = buildOutgoingPayment(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, route, recipient) + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, route, recipient) // let's peel the onion val add_b = UpdateAddHtlc(randomBytes32(), 0, finalAmount + 100.msat, paymentHash, finalExpiry + CltvExpiryDelta(6), payment.cmd.onion, None) @@ -152,6 +156,125 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { assert(payload_b.asInstanceOf[FinalPayload.Standard].paymentSecret == paymentSecret) } + test("build outgoing blinded payment") { + val (invoice, route, recipient) = longBlindedHops(hex"deadbeef") + assert(recipient.extraEdges.length == 1) + assert(recipient.extraEdges.head.sourceNodeId == c) + assert(recipient.extraEdges.head.targetNodeId == invoice.nodeId) + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, route, recipient) + assert(payment.outgoingChannel == channelUpdate_ab.shortChannelId) + assert(payment.cmd.amount >= amount_ab) + assert(payment.cmd.cltvExpiry == expiry_ab) + assert(payment.cmd.nextBlindingKey_opt.isEmpty) + + val add_b = UpdateAddHtlc(randomBytes32(), 0, payment.cmd.amount, payment.cmd.paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion, payment.cmd.nextBlindingKey_opt) + val Right(relay_b@ChannelRelayPacket(_, payload_b, packet_c)) = decrypt(add_b, priv_b.privateKey, Features.empty) + assert(packet_c.payload.length == PaymentOnionCodecs.paymentOnionPayloadLength) + assert(relay_b.amountToForward >= amount_bc) + assert(relay_b.outgoingCltv == expiry_bc) + assert(payload_b.outgoingChannelId == channelUpdate_bc.shortChannelId) + assert(relay_b.relayFeeMsat == fee_b) + assert(relay_b.expiryDelta == channelUpdate_bc.cltvExpiryDelta) + assert(payload_b.isInstanceOf[IntermediatePayload.ChannelRelay.Standard]) + + val add_c = UpdateAddHtlc(randomBytes32(), 1, amount_bc, paymentHash, expiry_bc, packet_c, None) + val Right(relay_c@ChannelRelayPacket(_, payload_c, packet_d)) = decrypt(add_c, priv_c.privateKey, Features(RouteBlinding -> Optional)) + assert(packet_d.payload.length == PaymentOnionCodecs.paymentOnionPayloadLength) + assert(relay_c.amountToForward == amount_cd) + assert(relay_c.outgoingCltv == expiry_cd) + assert(payload_c.outgoingChannelId == channelUpdate_cd.shortChannelId) + assert(relay_c.relayFeeMsat == fee_c) + assert(relay_c.expiryDelta == channelUpdate_cd.cltvExpiryDelta) + assert(payload_c.isInstanceOf[IntermediatePayload.ChannelRelay.Blinded]) + val blinding_d = payload_c.asInstanceOf[IntermediatePayload.ChannelRelay.Blinded].nextBlinding + + val add_d = UpdateAddHtlc(randomBytes32(), 2, amount_cd, paymentHash, expiry_cd, packet_d, Some(blinding_d)) + val Right(relay_d@ChannelRelayPacket(_, payload_d, packet_e)) = decrypt(add_d, priv_d.privateKey, Features(RouteBlinding -> Optional)) + assert(packet_e.payload.length == PaymentOnionCodecs.paymentOnionPayloadLength) + assert(relay_d.amountToForward == amount_de) + assert(relay_d.outgoingCltv == expiry_de) + assert(payload_d.outgoingChannelId == channelUpdate_de.shortChannelId) + assert(relay_d.relayFeeMsat == fee_d) + assert(relay_d.expiryDelta == channelUpdate_de.cltvExpiryDelta) + assert(payload_d.isInstanceOf[IntermediatePayload.ChannelRelay.Blinded]) + val blinding_e = payload_d.asInstanceOf[IntermediatePayload.ChannelRelay.Blinded].nextBlinding + + val add_e = UpdateAddHtlc(randomBytes32(), 2, amount_de, paymentHash, expiry_de, packet_e, Some(blinding_e)) + val Right(FinalPacket(_, payload_e)) = decrypt(add_e, priv_e.privateKey, Features(RouteBlinding -> Optional)) + assert(payload_e.amount == finalAmount) + assert(payload_e.totalAmount == finalAmount) + assert(payload_e.expiry == finalExpiry) + assert(payload_e.isInstanceOf[FinalPayload.Blinded]) + assert(payload_e.asInstanceOf[FinalPayload.Blinded].pathId == hex"deadbeef") + } + + test("build outgoing blinded payment for introduction node") { + // a -> b -> c where c uses a 0-hop blinded route. + val recipientKey = randomKey() + val features = Features[InvoiceFeature](VariableLengthOnion -> Mandatory, BasicMultiPartPayment -> Optional, RouteBlinding -> Mandatory) + val offer = Offer(None, "Bolt12 r0cks", recipientKey.publicKey, features, Block.RegtestGenesisBlock.hash) + val invoiceRequest = InvoiceRequest(offer, amount_bc, 1, features, randomKey(), Block.RegtestGenesisBlock.hash) + val blindedRoute = BlindedRouteCreation.createBlindedRouteWithoutHops(c, hex"deadbeef", 1 msat, CltvExpiry(500_000)).route + val paymentInfo = PaymentInfo(0 msat, 0, CltvExpiryDelta(0), 1 msat, amount_bc, Features.empty) + val invoice = Bolt12Invoice(offer, invoiceRequest, paymentPreimage, recipientKey, CltvExpiryDelta(6), features, Seq(PaymentBlindedRoute(blindedRoute, paymentInfo))) + val recipient = BlindedRecipient(invoice, amount_bc, expiry_bc, Nil) + val hops = Seq(channelHopFromUpdate(a, b, channelUpdate_ab), channelHopFromUpdate(b, c, channelUpdate_bc)) + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, Route(amount_bc, hops, Some(recipient.blindedHops.head)), recipient) + assert(payment.outgoingChannel == channelUpdate_ab.shortChannelId) + assert(payment.cmd.amount == amount_ab) + assert(payment.cmd.cltvExpiry == expiry_ab) + assert(payment.cmd.nextBlindingKey_opt.isEmpty) + + val add_b = UpdateAddHtlc(randomBytes32(), 0, payment.cmd.amount, payment.cmd.paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion, payment.cmd.nextBlindingKey_opt) + val Right(relay_b@ChannelRelayPacket(_, payload_b, packet_c)) = decrypt(add_b, priv_b.privateKey, Features.empty) + assert(packet_c.payload.length == PaymentOnionCodecs.paymentOnionPayloadLength) + assert(relay_b.amountToForward >= amount_bc) + assert(relay_b.outgoingCltv == expiry_bc) + assert(payload_b.outgoingChannelId == channelUpdate_bc.shortChannelId) + assert(relay_b.relayFeeMsat == fee_b) + assert(relay_b.expiryDelta == channelUpdate_bc.cltvExpiryDelta) + assert(payload_b.isInstanceOf[IntermediatePayload.ChannelRelay.Standard]) + + val add_c = UpdateAddHtlc(randomBytes32(), 1, amount_bc, paymentHash, expiry_bc, packet_c, None) + val Right(FinalPacket(_, payload_c)) = decrypt(add_c, priv_c.privateKey, Features(RouteBlinding -> Optional)) + assert(payload_c.amount == amount_bc) + assert(payload_c.totalAmount == amount_bc) + assert(payload_c.expiry == expiry_bc) + assert(payload_c.isInstanceOf[FinalPayload.Blinded]) + assert(payload_c.asInstanceOf[FinalPayload.Blinded].pathId == hex"deadbeef") + } + + test("build outgoing blinded payment starting at our node") { + val (route, recipient) = singleBlindedHop(hex"123456") + assert(recipient.extraEdges.length == 1) + assert(recipient.extraEdges.head.sourceNodeId == a) + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, route, recipient) + assert(payment.outgoingChannel == channelUpdate_ab.shortChannelId) + assert(payment.cmd.amount == finalAmount) + assert(payment.cmd.cltvExpiry == finalExpiry) + assert(payment.cmd.nextBlindingKey_opt.nonEmpty) + + val add_b = UpdateAddHtlc(randomBytes32(), 0, payment.cmd.amount, payment.cmd.paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion, payment.cmd.nextBlindingKey_opt) + val Right(FinalPacket(_, payload_b)) = decrypt(add_b, priv_b.privateKey, Features(RouteBlinding -> Optional)) + assert(payload_b.amount == finalAmount) + assert(payload_b.totalAmount == finalAmount) + assert(payload_b.expiry == finalExpiry) + assert(payload_b.isInstanceOf[FinalPayload.Blinded]) + assert(payload_b.asInstanceOf[FinalPayload.Blinded].pathId == hex"123456") + } + + test("build outgoing blinded payment with greater amount and expiry") { + val (route, recipient) = singleBlindedHop(hex"123456") + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, route, recipient) + assert(payment.outgoingChannel == channelUpdate_ab.shortChannelId) + + val add_b = UpdateAddHtlc(randomBytes32(), 0, payment.cmd.amount + 100.msat, payment.cmd.paymentHash, payment.cmd.cltvExpiry + CltvExpiryDelta(6), payment.cmd.onion, payment.cmd.nextBlindingKey_opt) + val Right(FinalPacket(_, payload_b)) = decrypt(add_b, priv_b.privateKey, Features(RouteBlinding -> Optional)) + assert(payload_b.amount == finalAmount) + assert(payload_b.totalAmount == finalAmount) + assert(payload_b.expiry == finalExpiry) + } + test("build outgoing trampoline payment") { // simple trampoline route to e: // .----. @@ -162,7 +285,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { val recipient = ClearTrampolineRecipient(invoice, finalAmount, finalExpiry, trampolineHop, randomBytes32()) assert(recipient.trampolineAmount == amount_bc) assert(recipient.trampolineExpiry == expiry_bc) - val Right(payment) = buildOutgoingPayment(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, Route(recipient.trampolineAmount, trampolineChannelHops, Some(trampolineHop)), recipient) + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, Route(recipient.trampolineAmount, trampolineChannelHops, Some(trampolineHop)), recipient) assert(payment.outgoingChannel == channelUpdate_ab.shortChannelId) assert(payment.cmd.amount == amount_ab) assert(payment.cmd.cltvExpiry == expiry_ab) @@ -189,7 +312,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { // c forwards the trampoline payment to e through d. val recipient_e = ClearRecipient(e, Features.empty, inner_c.amountToForward, inner_c.outgoingCltv, randomBytes32(), nextTrampolineOnion_opt = Some(trampolinePacket_e)) - val Right(payment_e) = buildOutgoingPayment(ActorRef.noSender, Upstream.Trampoline(Seq(add_c)), paymentHash, Route(inner_c.amountToForward, afterTrampolineChannelHops, None), recipient_e) + val Right(payment_e) = buildOutgoingPayment(ActorRef.noSender, priv_c.privateKey, Upstream.Trampoline(Seq(add_c)), paymentHash, Route(inner_c.amountToForward, afterTrampolineChannelHops, None), recipient_e) assert(payment_e.outgoingChannel == channelUpdate_cd.shortChannelId) assert(payment_e.cmd.amount == amount_cd) assert(payment_e.cmd.cltvExpiry == expiry_cd) @@ -215,7 +338,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { val recipient = ClearTrampolineRecipient(invoice, finalAmount, finalExpiry, trampolineHop, randomBytes32()) assert(recipient.trampolineAmount == amount_bc) assert(recipient.trampolineExpiry == expiry_bc) - val Right(payment) = buildOutgoingPayment(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, Route(recipient.trampolineAmount, trampolineChannelHops, Some(trampolineHop)), recipient) + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, Route(recipient.trampolineAmount, trampolineChannelHops, Some(trampolineHop)), recipient) assert(payment.outgoingChannel == channelUpdate_ab.shortChannelId) assert(payment.cmd.amount == amount_ab) assert(payment.cmd.cltvExpiry == expiry_ab) @@ -240,7 +363,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { // c forwards the trampoline payment to e through d. val recipient_e = ClearRecipient(e, Features.empty, inner_c.amountToForward, inner_c.outgoingCltv, inner_c.paymentSecret.get, invoice.extraEdges, inner_c.paymentMetadata) - val Right(payment_e) = buildOutgoingPayment(ActorRef.noSender, Upstream.Trampoline(Seq(add_c)), paymentHash, Route(inner_c.amountToForward, afterTrampolineChannelHops, None), recipient_e) + val Right(payment_e) = buildOutgoingPayment(ActorRef.noSender, priv_c.privateKey, Upstream.Trampoline(Seq(add_c)), paymentHash, Route(inner_c.amountToForward, afterTrampolineChannelHops, None), recipient_e) assert(payment_e.outgoingChannel == channelUpdate_cd.shortChannelId) assert(payment_e.cmd.amount == amount_cd) assert(payment_e.cmd.cltvExpiry == expiry_cd) @@ -259,7 +382,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { val routingHintOverflow = List(List.fill(7)(Bolt11Invoice.ExtraHop(randomKey().publicKey, ShortChannelId(1), 10 msat, 100, CltvExpiryDelta(12)))) val invoice = Bolt11Invoice(Block.RegtestGenesisBlock.hash, Some(finalAmount), paymentHash, priv_e.privateKey, Left("#reckless"), CltvExpiryDelta(18), None, None, routingHintOverflow) val recipient = ClearTrampolineRecipient(invoice, finalAmount, finalExpiry, trampolineHop, randomBytes32()) - val Left(failure) = buildOutgoingPayment(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, Route(recipient.trampolineAmount, trampolineChannelHops, Some(trampolineHop)), recipient) + val Left(failure) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, Route(recipient.trampolineAmount, trampolineChannelHops, Some(trampolineHop)), recipient) assert(failure.isInstanceOf[CannotCreateOnion]) } @@ -268,13 +391,37 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { val invoiceFeatures = Features[InvoiceFeature](VariableLengthOnion -> Mandatory, PaymentSecret -> Mandatory, BasicMultiPartPayment -> Optional, PaymentMetadata -> Optional, TrampolinePaymentPrototype -> Optional) val invoice = Bolt11Invoice(Block.RegtestGenesisBlock.hash, Some(finalAmount), paymentHash, priv_e.privateKey, Left("Much payment very metadata"), CltvExpiryDelta(9), features = invoiceFeatures, paymentMetadata = Some(paymentMetadata)) val recipient = ClearTrampolineRecipient(invoice, finalAmount, finalExpiry, trampolineHop, randomBytes32()) - val Left(failure) = buildOutgoingPayment(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, Route(recipient.trampolineAmount, trampolineChannelHops, Some(trampolineHop)), recipient) + val Left(failure) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, Route(recipient.trampolineAmount, trampolineChannelHops, Some(trampolineHop)), recipient) assert(failure.isInstanceOf[CannotCreateOnion]) } + test("fail to build outgoing payment with invalid route") { + val recipient = ClearRecipient(e, Features.empty, finalAmount, finalExpiry, paymentSecret) + val route = Route(finalAmount, hops.dropRight(1), None) // route doesn't reach e + val Left(failure) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, route, recipient) + assert(failure == InvalidRouteRecipient(e, d)) + } + + test("fail to build outgoing trampoline payment with invalid route") { + val invoiceFeatures = Features[InvoiceFeature](VariableLengthOnion -> Mandatory, PaymentSecret -> Mandatory, BasicMultiPartPayment -> Optional, PaymentMetadata -> Optional, TrampolinePaymentPrototype -> Optional) + val invoice = Bolt11Invoice(Block.RegtestGenesisBlock.hash, None, paymentHash, priv_e.privateKey, Left("invoice"), CltvExpiryDelta(6), paymentSecret = paymentSecret, features = invoiceFeatures) + val recipient = ClearTrampolineRecipient(invoice, finalAmount, finalExpiry, trampolineHop, randomBytes32()) + val route = Route(finalAmount, trampolineChannelHops, None) // missing trampoline hop + val Left(failure) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, route, recipient) + assert(failure == MissingTrampolineHop(c)) + } + + test("fail to build outgoing blinded payment with invalid route") { + val (_, route, recipient) = longBlindedHops(hex"deadbeef") + assert(buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, route, recipient).isRight) + val routeMissingBlindedHop = route.copy(finalHop_opt = None) + val Left(failure) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, routeMissingBlindedHop, recipient) + assert(failure == MissingBlindedHop(Set(c))) + } + test("fail to decrypt when the onion is invalid") { val recipient = ClearRecipient(e, Features.empty, finalAmount, finalExpiry, paymentSecret) - val Right(payment) = buildOutgoingPayment(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, Route(finalAmount, hops, None), recipient) + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, Route(finalAmount, hops, None), recipient) val add = UpdateAddHtlc(randomBytes32(), 1, payment.cmd.amount, paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion.copy(payload = payment.cmd.onion.payload.reverse), None) val Left(failure) = decrypt(add, priv_b.privateKey, Features.empty) assert(failure.isInstanceOf[InvalidOnionHmac]) @@ -284,7 +431,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { val invoiceFeatures = Features[InvoiceFeature](VariableLengthOnion -> Mandatory, PaymentSecret -> Mandatory, BasicMultiPartPayment -> Optional, PaymentMetadata -> Optional, TrampolinePaymentPrototype -> Optional) val invoice = Bolt11Invoice(Block.RegtestGenesisBlock.hash, None, paymentHash, priv_e.privateKey, Left("invoice"), CltvExpiryDelta(6), paymentSecret = paymentSecret, features = invoiceFeatures, paymentMetadata = Some(hex"010203")) val recipient = ClearTrampolineRecipient(invoice, finalAmount, finalExpiry, trampolineHop, randomBytes32()) - val Right(payment) = buildOutgoingPayment(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, Route(recipient.trampolineAmount, trampolineChannelHops, Some(trampolineHop)), recipient) + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, Route(recipient.trampolineAmount, trampolineChannelHops, Some(trampolineHop)), recipient) val add_b = UpdateAddHtlc(randomBytes32(), 1, payment.cmd.amount, paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion, None) val Right(ChannelRelayPacket(_, _, packet_c)) = decrypt(add_b, priv_b.privateKey, Features.empty) @@ -294,7 +441,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { // c forwards an invalid trampoline onion to e through d. val recipient_e = ClearRecipient(e, Features.empty, inner_c.amountToForward, inner_c.outgoingCltv, randomBytes32(), nextTrampolineOnion_opt = Some(trampolinePacket_e.copy(payload = trampolinePacket_e.payload.reverse))) - val Right(payment_e) = buildOutgoingPayment(ActorRef.noSender, Upstream.Trampoline(Seq(add_c)), paymentHash, Route(inner_c.amountToForward, afterTrampolineChannelHops, None), recipient_e) + val Right(payment_e) = buildOutgoingPayment(ActorRef.noSender, priv_c.privateKey, Upstream.Trampoline(Seq(add_c)), paymentHash, Route(inner_c.amountToForward, afterTrampolineChannelHops, None), recipient_e) assert(payment_e.outgoingChannel == channelUpdate_cd.shortChannelId) val add_d = UpdateAddHtlc(randomBytes32(), 3, payment_e.cmd.amount, paymentHash, payment_e.cmd.cltvExpiry, payment_e.cmd.onion, None) val Right(ChannelRelayPacket(_, _, packet_e)) = decrypt(add_d, priv_d.privateKey, Features.empty) @@ -306,16 +453,47 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { test("fail to decrypt when payment hash doesn't match associated data") { val recipient = ClearRecipient(e, Features.empty, finalAmount, finalExpiry, paymentSecret) - val Right(payment) = buildOutgoingPayment(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash.reverse, Route(finalAmount, hops, None), recipient) + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash.reverse, Route(finalAmount, hops, None), recipient) val add = UpdateAddHtlc(randomBytes32(), 1, payment.cmd.amount, paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion, None) val Left(failure) = decrypt(add, priv_b.privateKey, Features.empty) assert(failure.isInstanceOf[InvalidOnionHmac]) } + test("fail to decrypt when blinded route data is invalid") { + val (route, recipient) = { + val features = Features[InvoiceFeature](VariableLengthOnion -> Mandatory, BasicMultiPartPayment -> Optional, RouteBlinding -> Mandatory) + val offer = Offer(None, "Bolt12 r0cks", c, features, Block.RegtestGenesisBlock.hash) + val invoiceRequest = InvoiceRequest(offer, amount_bc, 1, features, randomKey(), Block.RegtestGenesisBlock.hash) + // We send the wrong blinded payload to the introduction node. + val tmpBlindedRoute = BlindedRouteCreation.createBlindedRouteFromHops(Seq(channelHopFromUpdate(b, c, channelUpdate_bc)), hex"deadbeef", 1 msat, CltvExpiry(500_000)).route + val blindedRoute = tmpBlindedRoute.copy(blindedNodes = tmpBlindedRoute.blindedNodes.reverse) + val paymentInfo = OfferTypes.PaymentInfo(fee_b, 0, channelUpdate_bc.cltvExpiryDelta, 0 msat, amount_bc, Features.empty) + val invoice = Bolt12Invoice(offer, invoiceRequest, paymentPreimage, priv_c.privateKey, CltvExpiryDelta(6), features, Seq(PaymentBlindedRoute(blindedRoute, paymentInfo))) + val recipient = BlindedRecipient(invoice, amount_bc, expiry_bc, Nil) + val route = Route(amount_bc, Seq(channelHopFromUpdate(a, b, channelUpdate_ab)), Some(recipient.blindedHops.head)) + (route, recipient) + } + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, route, recipient) + assert(payment.outgoingChannel == channelUpdate_ab.shortChannelId) + assert(payment.cmd.amount == amount_bc + fee_b) + + val add_b = UpdateAddHtlc(randomBytes32(), 0, payment.cmd.amount, payment.cmd.paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion, payment.cmd.nextBlindingKey_opt) + val Left(failure) = decrypt(add_b, priv_b.privateKey, Features(RouteBlinding -> Optional)) + assert(failure.isInstanceOf[InvalidOnionBlinding]) + } + + test("fail to decrypt blinded payment when route blinding is disabled") { + val (route, recipient) = singleBlindedHop(hex"00000000") + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, route, recipient) + val add_b = UpdateAddHtlc(randomBytes32(), 0, payment.cmd.amount, payment.cmd.paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion, payment.cmd.nextBlindingKey_opt) + val Left(failure) = decrypt(add_b, priv_b.privateKey, Features.empty) // b doesn't support route blinding + assert(failure == InvalidOnionPayload(UInt64(10), 0)) + } + test("fail to decrypt at the final node when amount has been modified by next-to-last node") { val recipient = ClearRecipient(b, Features.empty, finalAmount, finalExpiry, paymentSecret) val route = Route(finalAmount, hops.take(1), None) - val Right(payment) = buildOutgoingPayment(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, route, recipient) + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, route, recipient) val add = UpdateAddHtlc(randomBytes32(), 1, payment.cmd.amount - 100.msat, paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion, None) val Left(failure) = decrypt(add, priv_b.privateKey, Features.empty) assert(failure == FinalIncorrectHtlcAmount(payment.cmd.amount - 100.msat)) @@ -324,12 +502,64 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { test("fail to decrypt at the final node when expiry has been modified by next-to-last node") { val recipient = ClearRecipient(b, Features.empty, finalAmount, finalExpiry, paymentSecret) val route = Route(finalAmount, hops.take(1), None) - val Right(payment) = buildOutgoingPayment(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, route, recipient) + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, route, recipient) val add = UpdateAddHtlc(randomBytes32(), 1, payment.cmd.amount, paymentHash, payment.cmd.cltvExpiry - CltvExpiryDelta(12), payment.cmd.onion, None) val Left(failure) = decrypt(add, priv_b.privateKey, Features.empty) assert(failure == FinalIncorrectCltvExpiry(payment.cmd.cltvExpiry - CltvExpiryDelta(12))) } + test("fail to decrypt blinded payment at the final node when amount is too low") { + val (route, recipient) = shortBlindedHops() + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_c.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, route, recipient) + assert(payment.outgoingChannel == channelUpdate_cd.shortChannelId) + assert(payment.cmd.amount == amount_cd) + + // A smaller amount is sent to d, who doesn't know that it's invalid. + val add_d = UpdateAddHtlc(randomBytes32(), 0, amount_de, paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion, payment.cmd.nextBlindingKey_opt) + val Right(relay_d@ChannelRelayPacket(_, payload_d, packet_e)) = decrypt(add_d, priv_d.privateKey, Features(RouteBlinding -> Optional)) + assert(payload_d.outgoingChannelId == channelUpdate_de.shortChannelId) + assert(relay_d.amountToForward < amount_de) + assert(payload_d.isInstanceOf[IntermediatePayload.ChannelRelay.Blinded]) + val blinding_e = payload_d.asInstanceOf[IntermediatePayload.ChannelRelay.Blinded].nextBlinding + + // When e receives a smaller amount than expected, it rejects the payment. + val add_e = UpdateAddHtlc(randomBytes32(), 0, relay_d.amountToForward, paymentHash, relay_d.outgoingCltv, packet_e, Some(blinding_e)) + val Left(failure) = decrypt(add_e, priv_e.privateKey, Features(RouteBlinding -> Optional)) + assert(failure.isInstanceOf[InvalidOnionBlinding]) + } + + test("fail to decrypt blinded payment at the final node when expiry is too low") { + val (route, recipient) = shortBlindedHops() + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_c.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, route, recipient) + assert(payment.outgoingChannel == channelUpdate_cd.shortChannelId) + assert(payment.cmd.cltvExpiry == expiry_cd) + + // A smaller expiry is sent to d, who doesn't know that it's invalid. + val add_d = UpdateAddHtlc(randomBytes32(), 0, payment.cmd.amount, paymentHash, expiry_de, payment.cmd.onion, payment.cmd.nextBlindingKey_opt) + val Right(relay_d@ChannelRelayPacket(_, payload_d, packet_e)) = decrypt(add_d, priv_d.privateKey, Features(RouteBlinding -> Optional)) + assert(payload_d.outgoingChannelId == channelUpdate_de.shortChannelId) + assert(relay_d.outgoingCltv < expiry_de) + assert(payload_d.isInstanceOf[IntermediatePayload.ChannelRelay.Blinded]) + val blinding_e = payload_d.asInstanceOf[IntermediatePayload.ChannelRelay.Blinded].nextBlinding + + // When e receives a smaller expiry than expected, it rejects the payment. + val add_e = UpdateAddHtlc(randomBytes32(), 0, relay_d.amountToForward, paymentHash, relay_d.outgoingCltv, packet_e, Some(blinding_e)) + val Left(failure) = decrypt(add_e, priv_e.privateKey, Features(RouteBlinding -> Optional)) + assert(failure.isInstanceOf[InvalidOnionBlinding]) + } + + test("fail to decrypt blinded payment at intermediate node when expiry is too high") { + val routeExpiry = expiry_de - channelUpdate_de.cltvExpiryDelta + val (route, recipient) = shortBlindedHops(routeExpiry) + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_c.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, route, recipient) + assert(payment.outgoingChannel == channelUpdate_cd.shortChannelId) + assert(payment.cmd.cltvExpiry > expiry_de) + + val add_d = UpdateAddHtlc(randomBytes32(), 0, payment.cmd.amount, paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion, payment.cmd.nextBlindingKey_opt) + val Left(failure) = decrypt(add_d, priv_d.privateKey, Features(RouteBlinding -> Optional)) + assert(failure.isInstanceOf[InvalidOnionBlinding]) + } + // Create a trampoline payment to e: // .----. // / \ @@ -340,7 +570,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { val invoiceFeatures = Features[InvoiceFeature](VariableLengthOnion -> Mandatory, PaymentSecret -> Mandatory, BasicMultiPartPayment -> Optional, TrampolinePaymentPrototype -> Optional) val invoice = Bolt11Invoice(Block.RegtestGenesisBlock.hash, None, paymentHash, priv_e.privateKey, Left("invoice"), CltvExpiryDelta(6), paymentSecret = paymentSecret, features = invoiceFeatures) val recipient = ClearTrampolineRecipient(invoice, finalAmount, finalExpiry, trampolineHop, randomBytes32()) - val Right(payment) = buildOutgoingPayment(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, Route(recipient.trampolineAmount, trampolineChannelHops, Some(trampolineHop)), recipient) + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, Route(recipient.trampolineAmount, trampolineChannelHops, Some(trampolineHop)), recipient) val add_b = UpdateAddHtlc(randomBytes32(), 1, payment.cmd.amount, paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion, None) val Right(ChannelRelayPacket(_, _, packet_c)) = decrypt(add_b, priv_b.privateKey, Features.empty) @@ -355,7 +585,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { // c forwards an invalid amount to e through (the outer total amount doesn't match the inner amount). val invalidTotalAmount = inner_c.amountToForward - 1.msat val recipient_e = ClearRecipient(e, Features.empty, invalidTotalAmount, inner_c.outgoingCltv, randomBytes32(), nextTrampolineOnion_opt = Some(trampolinePacket_e)) - val Right(payment_e) = buildOutgoingPayment(ActorRef.noSender, Upstream.Trampoline(Seq(add_c)), paymentHash, Route(invalidTotalAmount, afterTrampolineChannelHops, None), recipient_e) + val Right(payment_e) = buildOutgoingPayment(ActorRef.noSender, priv_c.privateKey, Upstream.Trampoline(Seq(add_c)), paymentHash, Route(invalidTotalAmount, afterTrampolineChannelHops, None), recipient_e) val add_d = UpdateAddHtlc(randomBytes32(), 3, payment_e.cmd.amount, paymentHash, payment_e.cmd.cltvExpiry, payment_e.cmd.onion, None) val Right(ChannelRelayPacket(_, payload_d, packet_e)) = decrypt(add_d, priv_d.privateKey, Features.empty) @@ -371,7 +601,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { // c forwards an invalid amount to e through (the outer expiry doesn't match the inner expiry). val invalidExpiry = inner_c.outgoingCltv - CltvExpiryDelta(12) val recipient_e = ClearRecipient(e, Features.empty, inner_c.amountToForward, invalidExpiry, randomBytes32(), nextTrampolineOnion_opt = Some(trampolinePacket_e)) - val Right(payment_e) = buildOutgoingPayment(ActorRef.noSender, Upstream.Trampoline(Seq(add_c)), paymentHash, Route(inner_c.amountToForward, afterTrampolineChannelHops, None), recipient_e) + val Right(payment_e) = buildOutgoingPayment(ActorRef.noSender, priv_c.privateKey, Upstream.Trampoline(Seq(add_c)), paymentHash, Route(inner_c.amountToForward, afterTrampolineChannelHops, None), recipient_e) val add_d = UpdateAddHtlc(randomBytes32(), 3, payment_e.cmd.amount, paymentHash, payment_e.cmd.cltvExpiry, payment_e.cmd.onion, None) val Right(ChannelRelayPacket(_, payload_d, packet_e)) = decrypt(add_d, priv_d.privateKey, Features.empty) @@ -395,6 +625,56 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { assert(failure == FinalIncorrectCltvExpiry(expiry_bc - CltvExpiryDelta(12))) } + test("build htlc failure onion") { + // a -> b -> c -> d -> e + val recipient = ClearRecipient(e, Features.empty, finalAmount, finalExpiry, paymentSecret) + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, Route(finalAmount, hops, None), recipient) + val add_b = UpdateAddHtlc(randomBytes32(), 0, amount_ab, paymentHash, expiry_ab, payment.cmd.onion, None) + val Right(ChannelRelayPacket(_, _, packet_c)) = decrypt(add_b, priv_b.privateKey, Features.empty) + val add_c = UpdateAddHtlc(randomBytes32(), 1, amount_bc, paymentHash, expiry_bc, packet_c, None) + val Right(ChannelRelayPacket(_, _, packet_d)) = decrypt(add_c, priv_c.privateKey, Features.empty) + val add_d = UpdateAddHtlc(randomBytes32(), 2, amount_cd, paymentHash, expiry_cd, packet_d, None) + val Right(ChannelRelayPacket(_, _, packet_e)) = decrypt(add_d, priv_d.privateKey, Features.empty) + val add_e = UpdateAddHtlc(randomBytes32(), 3, amount_de, paymentHash, expiry_de, packet_e, None) + val Right(FinalPacket(_, payload_e)) = decrypt(add_e, priv_e.privateKey, Features.empty) + assert(payload_e.isInstanceOf[FinalPayload.Standard]) + + // e returns a failure + val failure = IncorrectOrUnknownPaymentDetails(finalAmount, BlockHeight(currentBlockCount)) + val Right(fail_e) = buildHtlcFailure(priv_e.privateKey, CMD_FAIL_HTLC(add_e.id, Right(failure)), add_e) + assert(fail_e.id == add_e.id) + val Right(fail_d) = buildHtlcFailure(priv_d.privateKey, CMD_FAIL_HTLC(add_d.id, Left(fail_e.reason)), add_d) + assert(fail_d.id == add_d.id) + val Right(fail_c) = buildHtlcFailure(priv_c.privateKey, CMD_FAIL_HTLC(add_c.id, Left(fail_d.reason)), add_c) + assert(fail_c.id == add_c.id) + val Right(fail_b) = buildHtlcFailure(priv_b.privateKey, CMD_FAIL_HTLC(add_b.id, Left(fail_c.reason)), add_b) + assert(fail_b.id == add_b.id) + val Success(Sphinx.DecryptedFailurePacket(failingNode, decryptedFailure)) = Sphinx.FailurePacket.decrypt(fail_b.reason, payment.sharedSecrets) + assert(failingNode == e) + assert(decryptedFailure == failure) + } + + test("build htlc failure onion (blinded payment)") { + // a -> b -> c -> d -> e, blinded after c + val (_, route, recipient) = longBlindedHops(hex"0451") + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, route, recipient) + val add_b = UpdateAddHtlc(randomBytes32(), 0, payment.cmd.amount, payment.cmd.paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion, payment.cmd.nextBlindingKey_opt) + val Right(ChannelRelayPacket(_, _, packet_c)) = decrypt(add_b, priv_b.privateKey, Features.empty) + val add_c = UpdateAddHtlc(randomBytes32(), 1, amount_bc, paymentHash, expiry_bc, packet_c, None) + val Right(_: ChannelRelayPacket) = decrypt(add_c, priv_c.privateKey, Features(RouteBlinding -> Optional)) + + // only the introduction node is allowed to send an `update_fail_htlc` message: downstream nodes must send + // `update_fail_malformed_htlc` which doesn't use onion encryption + val failure = InvalidOnionBlinding(Sphinx.hash(add_c.onionRoutingPacket)) + val Right(fail_c) = buildHtlcFailure(priv_c.privateKey, CMD_FAIL_HTLC(add_c.id, Right(failure)), add_c) + assert(fail_c.id == add_c.id) + val Right(fail_b) = buildHtlcFailure(priv_b.privateKey, CMD_FAIL_HTLC(add_b.id, Left(fail_c.reason)), add_b) + assert(fail_b.id == add_b.id) + val Success(Sphinx.DecryptedFailurePacket(failingNode, decryptedFailure)) = Sphinx.FailurePacket.decrypt(fail_b.reason, payment.sharedSecrets) + assert(failingNode == c) + assert(decryptedFailure == failure) + } + } object PaymentPacketSpec { @@ -411,9 +691,9 @@ object PaymentPacketSpec { } } - def randomExtendedPrivateKey: ExtendedPrivateKey = DeterministicWallet.generate(randomBytes32()) + def randomExtendedPrivateKey(): ExtendedPrivateKey = DeterministicWallet.generate(randomBytes32()) - val (priv_a, priv_b, priv_c, priv_d, priv_e) = (TestConstants.Alice.nodeKeyManager.nodeKey, TestConstants.Bob.nodeKeyManager.nodeKey, randomExtendedPrivateKey, randomExtendedPrivateKey, randomExtendedPrivateKey) + val (priv_a, priv_b, priv_c, priv_d, priv_e) = (TestConstants.Alice.nodeKeyManager.nodeKey, TestConstants.Bob.nodeKeyManager.nodeKey, randomExtendedPrivateKey(), randomExtendedPrivateKey(), randomExtendedPrivateKey()) val (a, b, c, d, e) = (priv_a.publicKey, priv_b.publicKey, priv_c.publicKey, priv_d.publicKey, priv_e.publicKey) val sig = Crypto.sign(Crypto.sha256(ByteVector.empty), priv_a.privateKey) val defaultChannelUpdate = ChannelUpdate(sig, Block.RegtestGenesisBlock.hash, ShortChannelId(0), 0 unixsec, ChannelUpdate.MessageFlags(dontForward = false), ChannelUpdate.ChannelFlags.DUMMY, CltvExpiryDelta(0), 42000 msat, 0 msat, 0, 500_000_000 msat) @@ -453,6 +733,32 @@ object PaymentPacketSpec { val expiry_ab = expiry_bc + channelUpdate_bc.cltvExpiryDelta val amount_ab = amount_bc + fee_b + // fully blinded route a -> b + def singleBlindedHop(pathId: ByteVector = hex"deadbeef", routeExpiry: CltvExpiry = CltvExpiry(500_000)): (Route, BlindedRecipient) = { + val (_, blindedHop, recipient) = blindedRouteFromHops(finalAmount, finalExpiry, Seq(channelHopFromUpdate(a, b, channelUpdate_ab)), routeExpiry, paymentPreimage, pathId) + (Route(finalAmount, Nil, Some(blindedHop)), recipient) + } + + // route c -> d -> e, blinded after d + def shortBlindedHops(routeExpiry: CltvExpiry = CltvExpiry(500_000)): (Route, BlindedRecipient) = { + val (_, blindedHop, recipient) = blindedRouteFromHops(finalAmount, finalExpiry, Seq(channelHopFromUpdate(d, e, channelUpdate_de)), routeExpiry, paymentPreimage) + (Route(finalAmount, Seq(channelHopFromUpdate(c, d, channelUpdate_cd)), Some(blindedHop)), recipient) + } + + // route a -> b -> c -> d -> e, blinded after c + def longBlindedHops(pathId: ByteVector): (Bolt12Invoice, Route, BlindedRecipient) = { + val hopsToBlind = Seq( + channelHopFromUpdate(c, d, channelUpdate_cd), + channelHopFromUpdate(d, e, channelUpdate_de), + ) + val (invoice, blindedHop, recipient) = blindedRouteFromHops(finalAmount, finalExpiry, hopsToBlind, CltvExpiry(500_000), paymentPreimage, pathId) + val hops = Seq( + channelHopFromUpdate(a, b, channelUpdate_ab), + channelHopFromUpdate(b, c, channelUpdate_bc), + ) + (invoice, Route(finalAmount, hops, Some(blindedHop)), recipient) + } + // simple trampoline route to e: // .----. // / \ diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PostRestartHtlcCleanerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PostRestartHtlcCleanerSpec.scala index 9d7e66ab60..b1a58a4a95 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PostRestartHtlcCleanerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PostRestartHtlcCleanerSpec.scala @@ -633,8 +633,8 @@ class PostRestartHtlcCleanerSpec extends TestKitBaseClass with FixtureAnyFunSuit else if (htlcId == 1L) Some(nonRelayedHtlc2In.add) else None } - def localNodeId: PublicKey = randomExtendedPrivateKey.publicKey - def remoteNodeId: PublicKey = randomExtendedPrivateKey.publicKey + def localNodeId: PublicKey = randomExtendedPrivateKey().publicKey + def remoteNodeId: PublicKey = randomExtendedPrivateKey().publicKey def capacity: Satoshi = Long.MaxValue.sat def availableBalanceForReceive: MilliSatoshi = Long.MaxValue.msat def availableBalanceForSend: MilliSatoshi = 0.msat @@ -721,7 +721,7 @@ object PostRestartHtlcCleanerSpec { val (paymentHash1, paymentHash2, paymentHash3) = (Crypto.sha256(preimage1), Crypto.sha256(preimage2), Crypto.sha256(preimage3)) def buildHtlc(htlcId: Long, channelId: ByteVector32, paymentHash: ByteVector32): UpdateAddHtlc = { - val Right(payment) = buildOutgoingPayment(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, Route(finalAmount, hops, None), SpontaneousRecipient(e, finalAmount, finalExpiry, randomBytes32())) + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, Route(finalAmount, hops, None), SpontaneousRecipient(e, finalAmount, finalExpiry, randomBytes32())) UpdateAddHtlc(channelId, htlcId, payment.cmd.amount, paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion, None) } @@ -730,7 +730,7 @@ object PostRestartHtlcCleanerSpec { def buildHtlcOut(htlcId: Long, channelId: ByteVector32, paymentHash: ByteVector32): DirectedHtlc = OutgoingHtlc(buildHtlc(htlcId, channelId, paymentHash)) def buildFinalHtlc(htlcId: Long, channelId: ByteVector32, paymentHash: ByteVector32): DirectedHtlc = { - val Right(payment) = buildOutgoingPayment(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, Route(finalAmount, Seq(channelHopFromUpdate(a, b, channelUpdate_ab)), None), SpontaneousRecipient(b, finalAmount, finalExpiry, randomBytes32())) + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, Route(finalAmount, Seq(channelHopFromUpdate(a, b, channelUpdate_ab)), None), SpontaneousRecipient(b, finalAmount, finalExpiry, randomBytes32())) IncomingHtlc(UpdateAddHtlc(channelId, htlcId, payment.cmd.amount, paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion, None)) } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/RelayerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/RelayerSpec.scala index f0f07e895a..96c65a553d 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/RelayerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/RelayerSpec.scala @@ -90,7 +90,7 @@ class RelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("applicat } // we use this to build a valid onion - val Right(payment) = buildOutgoingPayment(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, Route(finalAmount, hops, None), ClearRecipient(e, Features.empty, finalAmount, finalExpiry, paymentSecret)) + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, Route(finalAmount, hops, None), ClearRecipient(e, Features.empty, finalAmount, finalExpiry, paymentSecret)) // and then manually build an htlc val add_ab = UpdateAddHtlc(randomBytes32(), 123456, payment.cmd.amount, payment.cmd.paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion, None) relayer ! RelayForward(add_ab) @@ -100,7 +100,7 @@ class RelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("applicat test("relay an htlc-add at the final node to the payment handler") { f => import f._ - val Right(payment) = buildOutgoingPayment(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, Route(finalAmount, hops.take(1), None), ClearRecipient(b, Features.empty, finalAmount, finalExpiry, paymentSecret)) + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, Route(finalAmount, hops.take(1), None), ClearRecipient(b, Features.empty, finalAmount, finalExpiry, paymentSecret)) val add_ab = UpdateAddHtlc(channelId_ab, 123456, payment.cmd.amount, payment.cmd.paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion, None) relayer ! RelayForward(add_ab) @@ -119,7 +119,7 @@ class RelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("applicat val finalTrampolinePayload = NodePayload(b, FinalPayload.Standard.createPayload(finalAmount, totalAmount, finalExpiry, paymentSecret)) val Right(trampolineOnion) = buildOnion(PaymentOnionCodecs.trampolineOnionPayloadLength, Seq(finalTrampolinePayload), paymentHash) val recipient = ClearRecipient(b, nodeParams.features.invoiceFeatures(), finalAmount, finalExpiry, randomBytes32(), nextTrampolineOnion_opt = Some(trampolineOnion.packet)) - val Right(payment) = buildOutgoingPayment(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, Route(finalAmount, Seq(channelHopFromUpdate(priv_a.publicKey, b, channelUpdate_ab)), None), recipient) + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, Route(finalAmount, Seq(channelHopFromUpdate(priv_a.publicKey, b, channelUpdate_ab)), None), recipient) assert(payment.cmd.amount == finalAmount) assert(payment.cmd.cltvExpiry == finalExpiry) val add_ab = UpdateAddHtlc(channelId_ab, 123456, payment.cmd.amount, payment.cmd.paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion, None) @@ -140,7 +140,7 @@ class RelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("applicat import f._ // we use this to build a valid onion - val Right(payment) = buildOutgoingPayment(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, Route(finalAmount, hops, None), ClearRecipient(e, Features.empty, finalAmount, finalExpiry, paymentSecret)) + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, Route(finalAmount, hops, None), ClearRecipient(e, Features.empty, finalAmount, finalExpiry, paymentSecret)) // and then manually build an htlc with an invalid onion (hmac) val add_ab = UpdateAddHtlc(channelId_ab, 123456, payment.cmd.amount, payment.cmd.paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion.copy(hmac = payment.cmd.onion.hmac.reverse), None) relayer ! RelayForward(add_ab) @@ -161,7 +161,7 @@ class RelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("applicat val invoice = Bolt11Invoice(Block.RegtestGenesisBlock.hash, None, paymentHash, priv_c.privateKey, Left("invoice"), CltvExpiryDelta(6), paymentSecret = paymentSecret, features = invoiceFeatures) val trampolineHop = NodeHop(b, c, channelUpdate_bc.cltvExpiryDelta, fee_b) val recipient = ClearTrampolineRecipient(invoice, finalAmount, finalExpiry, trampolineHop, randomBytes32()) - val Right(payment) = buildOutgoingPayment(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, Route(recipient.trampolineAmount, Seq(channelHopFromUpdate(priv_a.publicKey, b, channelUpdate_ab)), Some(trampolineHop)), recipient) + val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, Route(recipient.trampolineAmount, Seq(channelHopFromUpdate(priv_a.publicKey, b, channelUpdate_ab)), Some(trampolineHop)), recipient) // and then manually build an htlc val add_ab = UpdateAddHtlc(channelId_ab, 123456, payment.cmd.amount, payment.cmd.paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion, None) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/router/BaseRouterSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/router/BaseRouterSpec.scala index 3aef6d1cf7..259a038d58 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/router/BaseRouterSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/router/BaseRouterSpec.scala @@ -29,10 +29,13 @@ import fr.acinq.eclair.channel._ import fr.acinq.eclair.crypto.TransportHandler import fr.acinq.eclair.crypto.keymanager.{LocalChannelKeyManager, LocalNodeKeyManager} import fr.acinq.eclair.io.Peer.PeerRoutingMessage +import fr.acinq.eclair.payment.send.BlindedRecipient +import fr.acinq.eclair.payment.{Bolt12Invoice, PaymentBlindedRoute} import fr.acinq.eclair.router.Announcements._ import fr.acinq.eclair.router.BaseRouterSpec.channelAnnouncement import fr.acinq.eclair.router.Router._ import fr.acinq.eclair.transactions.Scripts +import fr.acinq.eclair.wire.protocol.OfferTypes.{InvoiceRequest, Offer} import fr.acinq.eclair.wire.protocol._ import org.scalatest.Outcome import org.scalatest.funsuite.FixtureAnyFunSuiteLike @@ -225,9 +228,11 @@ abstract class BaseRouterSpec extends TestKitBaseClass with FixtureAnyFunSuiteLi withFixture(test.toNoArgTest(FixtureParam(nodeParams, router, watcher))) } } + } object BaseRouterSpec { + def channelAnnouncement(shortChannelId: RealShortChannelId, node1_priv: PrivateKey, node2_priv: PrivateKey, funding1_priv: PrivateKey, funding2_priv: PrivateKey) = { val witness = Announcements.generateChannelAnnouncementWitness(Block.RegtestGenesisBlock.hash, shortChannelId, node1_priv.publicKey, node2_priv.publicKey, funding1_priv.publicKey, funding2_priv.publicKey, Features.empty) val node1_sig = Announcements.signChannelAnnouncement(witness, node1_priv) @@ -240,4 +245,39 @@ object BaseRouterSpec { def channelHopFromUpdate(nodeId: PublicKey, nextNodeId: PublicKey, channelUpdate: ChannelUpdate): ChannelHop = { ChannelHop(channelUpdate.shortChannelId, nodeId, nextNodeId, HopRelayParams.FromAnnouncement(channelUpdate)) } + + def blindedRouteFromHops(amount: MilliSatoshi, + expiry: CltvExpiry, + hops: Seq[ChannelHop], + routeExpiry: CltvExpiry, + preimage: ByteVector32 = randomBytes32(), + pathId: ByteVector = randomBytes(32)): (Bolt12Invoice, BlindedHop, BlindedRecipient) = { + val (invoice, recipient) = blindedRoutesFromPaths(amount, expiry, Seq(hops), routeExpiry, preimage, pathId) + (invoice, recipient.blindedHops.head, recipient) + } + + def blindedRoutesFromPaths(amount: MilliSatoshi, + expiry: CltvExpiry, + paths: Seq[Seq[ChannelHop]], + routeExpiry: CltvExpiry, + preimage: ByteVector32 = randomBytes32(), + pathId: ByteVector = randomBytes(32)): (Bolt12Invoice, BlindedRecipient) = { + val recipientKey = randomKey() + val features = Features[InvoiceFeature]( + Features.VariableLengthOnion -> FeatureSupport.Mandatory, + Features.BasicMultiPartPayment -> FeatureSupport.Optional, + Features.RouteBlinding -> FeatureSupport.Mandatory + ) + val offer = Offer(None, "Bolt12 r0cks", recipientKey.publicKey, features, Block.RegtestGenesisBlock.hash) + val invoiceRequest = InvoiceRequest(offer, amount, 1, features, randomKey(), Block.RegtestGenesisBlock.hash) + val blindedRoutes = paths.map(hops => { + val blindedRoute = BlindedRouteCreation.createBlindedRouteFromHops(hops, pathId, 1 msat, routeExpiry).route + val paymentInfo = BlindedRouteCreation.aggregatePaymentInfo(amount, hops) + PaymentBlindedRoute(blindedRoute, paymentInfo) + }) + val invoice = Bolt12Invoice(offer, invoiceRequest, preimage, recipientKey, CltvExpiryDelta(6), features, blindedRoutes) + val recipient = BlindedRecipient(invoice, amount, expiry, Nil) + (invoice, recipient) + } + } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/router/RouterSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/router/RouterSpec.scala index cefde92823..95a0fe4707 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/router/RouterSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/router/RouterSpec.scala @@ -27,10 +27,11 @@ import fr.acinq.eclair.channel.{AvailableBalanceChanged, CommitmentsSpec, LocalC import fr.acinq.eclair.crypto.TransportHandler import fr.acinq.eclair.io.Peer.PeerRoutingMessage import fr.acinq.eclair.payment.Bolt11Invoice.ExtraHop +import fr.acinq.eclair.payment.Invoice.ExtraEdge import fr.acinq.eclair.payment.send.{ClearRecipient, ClearTrampolineRecipient, SpontaneousRecipient} import fr.acinq.eclair.payment.{Bolt11Invoice, Invoice} import fr.acinq.eclair.router.Announcements.{makeChannelUpdate, makeNodeAnnouncement} -import fr.acinq.eclair.router.BaseRouterSpec.channelAnnouncement +import fr.acinq.eclair.router.BaseRouterSpec.{blindedRoutesFromPaths, channelAnnouncement} import fr.acinq.eclair.router.Graph.RoutingHeuristics import fr.acinq.eclair.router.RouteCalculationSpec.{DEFAULT_AMOUNT_MSAT, DEFAULT_EXPIRY, DEFAULT_ROUTE_PARAMS, route2NodeIds} import fr.acinq.eclair.router.Router._ @@ -479,6 +480,84 @@ class RouterSpec extends BaseRouterSpec { assert(route2.finalHop_opt.contains(trampolineHop)) } + test("routes found (with blinded hops)") { fixture => + import fixture._ + val sender = TestProbe() + val r = randomKey().publicKey + val hopsToRecipient = Seq( + ChannelHop(ShortChannelId(10000), b, r, HopRelayParams.FromHint(ExtraEdge(b, r, ShortChannelId(10000), 800 msat, 0, CltvExpiryDelta(36), 1 msat, Some(400_000 msat)))) :: Nil, + ChannelHop(ShortChannelId(10001), c, r, HopRelayParams.FromHint(ExtraEdge(c, r, ShortChannelId(10001), 500 msat, 0, CltvExpiryDelta(36), 1 msat, Some(400_000 msat)))) :: Nil, + ) + + { + // Amount split between both blinded routes: + val (_, recipient) = blindedRoutesFromPaths(600_000 msat, DEFAULT_EXPIRY, hopsToRecipient, DEFAULT_EXPIRY) + sender.send(router, RouteRequest(a, recipient, DEFAULT_ROUTE_PARAMS, allowMultiPart = true)) + val routes = sender.expectMsgType[RouteResponse].routes + assert(routes.length == 2) + assert(routes.flatMap(_.finalHop_opt) == recipient.blindedHops) + assert(routes.map(route => route2NodeIds(route)).toSet == Set(Seq(a, b), Seq(a, b, c))) + assert(routes.map(route => route.blindedFee + route.channelFee(false)).toSet == Set(510 msat, 800 msat)) + } + { + // One blinded route is ignored, we use the other one: + val (_, recipient) = blindedRoutesFromPaths(300_000 msat, DEFAULT_EXPIRY, hopsToRecipient, DEFAULT_EXPIRY) + val ignored = Ignore(Set.empty, Set(ChannelDesc(recipient.extraEdges.last.shortChannelId, recipient.extraEdges.last.sourceNodeId, recipient.extraEdges.last.targetNodeId))) + sender.send(router, RouteRequest(a, recipient, DEFAULT_ROUTE_PARAMS, ignore = ignored)) + val routes = sender.expectMsgType[RouteResponse].routes + assert(routes.length == 1) + assert(routes.head.finalHop_opt.nonEmpty) + assert(route2NodeIds(routes.head) == Seq(a, b)) + assert(routes.head.blindedFee == 800.msat) + } + { + // One blinded route is ignored, the other one doesn't have enough capacity: + val (_, recipient) = blindedRoutesFromPaths(500_000 msat, DEFAULT_EXPIRY, hopsToRecipient, DEFAULT_EXPIRY) + val ignored = Ignore(Set.empty, Set(ChannelDesc(recipient.extraEdges.last.shortChannelId, recipient.extraEdges.last.sourceNodeId, recipient.extraEdges.last.targetNodeId))) + sender.send(router, RouteRequest(a, recipient, DEFAULT_ROUTE_PARAMS, allowMultiPart = true, ignore = ignored)) + sender.expectMsg(Failure(RouteNotFound)) + } + { + // One blinded route is pending, we use the other one: + val (_, recipient) = blindedRoutesFromPaths(600_000 msat, DEFAULT_EXPIRY, hopsToRecipient, DEFAULT_EXPIRY) + sender.send(router, RouteRequest(a, recipient, DEFAULT_ROUTE_PARAMS, allowMultiPart = true)) + val routes1 = sender.expectMsgType[RouteResponse].routes + assert(routes1.length == 2) + sender.send(router, RouteRequest(a, recipient, DEFAULT_ROUTE_PARAMS, allowMultiPart = true, pendingPayments = Seq(routes1.head))) + val routes2 = sender.expectMsgType[RouteResponse].routes + assert(routes2 == routes1.tail) + } + { + // One blinded route is pending, we send two htlcs to the other one: + val (_, recipient) = blindedRoutesFromPaths(600_000 msat, DEFAULT_EXPIRY, hopsToRecipient, DEFAULT_EXPIRY) + sender.send(router, RouteRequest(a, recipient, DEFAULT_ROUTE_PARAMS, allowMultiPart = true)) + val routes1 = sender.expectMsgType[RouteResponse].routes + assert(routes1.length == 2) + sender.send(router, RouteRequest(a, recipient, DEFAULT_ROUTE_PARAMS, allowMultiPart = true, pendingPayments = Seq(routes1.head))) + val routes2 = sender.expectMsgType[RouteResponse].routes + assert(routes2 == routes1.tail) + sender.send(router, RouteRequest(a, recipient, DEFAULT_ROUTE_PARAMS, allowMultiPart = true, pendingPayments = Seq(routes1.head, routes2.head.copy(amount = routes2.head.amount - 25_000.msat)))) + val routes3 = sender.expectMsgType[RouteResponse].routes + assert(routes3.length == 1) + assert(routes3.head.amount == 25_000.msat) + } + { + // One blinded route is pending, we cannot use the other one because of the fee budget: + val (_, recipient) = blindedRoutesFromPaths(600_000 msat, DEFAULT_EXPIRY, hopsToRecipient, DEFAULT_EXPIRY) + val routeParams1 = DEFAULT_ROUTE_PARAMS.copy(boundaries = SearchBoundaries(5000 msat, 0.0, 6, CltvExpiryDelta(1008))) + sender.send(router, RouteRequest(a, recipient, routeParams1, allowMultiPart = true)) + val routes1 = sender.expectMsgType[RouteResponse].routes + assert(routes1.length == 2) + assert(routes1.head.blindedFee + routes1.head.channelFee(false) == 800.msat) + val routeParams2 = DEFAULT_ROUTE_PARAMS.copy(boundaries = SearchBoundaries(1000 msat, 0.0, 6, CltvExpiryDelta(1008))) + sender.send(router, RouteRequest(a, recipient, routeParams2, allowMultiPart = true, pendingPayments = Seq(routes1.head))) + sender.expectMsg(Failure(RouteNotFound)) + val routeParams3 = DEFAULT_ROUTE_PARAMS.copy(boundaries = SearchBoundaries(1500 msat, 0.0, 6, CltvExpiryDelta(1008))) + sender.send(router, RouteRequest(a, recipient, routeParams3, allowMultiPart = true, pendingPayments = Seq(routes1.head))) + assert(sender.expectMsgType[RouteResponse].routes.length == 1) + } + } + test("route not found (channel disabled)") { fixture => import fixture._ val sender = TestProbe()