From 6101933d9a9000106d1de2ca93a6ee10fdca269e Mon Sep 17 00:00:00 2001 From: Thomas HUET Date: Fri, 14 Oct 2022 15:50:08 +0200 Subject: [PATCH 1/3] Blind payments using BlindedHop --- .../scala/fr/acinq/eclair/db/PaymentsDb.scala | 3 +- .../acinq/eclair/json/JsonSerializers.scala | 30 ++++-- .../acinq/eclair/payment/Bolt11Invoice.scala | 14 ++- .../acinq/eclair/payment/Bolt12Invoice.scala | 30 ++++-- .../fr/acinq/eclair/payment/Invoice.scala | 30 +++++- .../acinq/eclair/payment/PaymentPacket.scala | 101 +++++++++++++----- .../eclair/payment/relay/NodeRelay.scala | 8 +- .../send/MultiPartPaymentLifecycle.scala | 14 +-- .../payment/send/PaymentInitiator.scala | 28 ++--- .../payment/send/PaymentLifecycle.scala | 89 ++++++++------- .../acinq/eclair/router/BalanceEstimate.scala | 6 +- .../scala/fr/acinq/eclair/router/Graph.scala | 23 ++-- .../eclair/router/RouteCalculation.scala | 8 +- .../scala/fr/acinq/eclair/router/Router.scala | 33 ++++-- .../eclair/wire/protocol/OfferTypes.scala | 6 +- .../eclair/wire/protocol/PaymentOnion.scala | 70 +++++++++--- .../eclair/wire/protocol/RouteBlinding.scala | 6 +- .../fr/acinq/eclair/channel/FuzzySpec.scala | 8 +- .../ChannelStateTestsHelperMethods.scala | 6 +- .../channel/states/f/ShutdownStateSpec.scala | 8 +- .../fr/acinq/eclair/crypto/SphinxSpec.scala | 12 +-- .../eclair/payment/Bolt12InvoiceSpec.scala | 10 +- .../eclair/payment/MultiPartHandlerSpec.scala | 4 +- .../MultiPartPaymentLifecycleSpec.scala | 69 +++++++----- .../eclair/payment/PaymentInitiatorSpec.scala | 29 ++--- .../eclair/payment/PaymentLifecycleSpec.scala | 61 +++++------ .../eclair/payment/PaymentPacketSpec.scala | 62 +++++------ .../payment/PostRestartHtlcCleanerSpec.scala | 4 +- .../payment/relay/NodeRelayerSpec.scala | 14 +-- .../eclair/payment/relay/RelayerSpec.scala | 16 +-- .../eclair/router/RouteCalculationSpec.scala | 10 +- .../fr/acinq/eclair/router/RouterSpec.scala | 18 ++-- .../wire/protocol/PaymentOnionSpec.scala | 21 ++-- 33 files changed, 523 insertions(+), 328 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala index e4e7002f27..b4b848eca6 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala @@ -19,7 +19,7 @@ package fr.acinq.eclair.db import fr.acinq.bitcoin.scalacompat.ByteVector32 import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.eclair.payment._ -import fr.acinq.eclair.router.Router.{ChannelHop, Hop, NodeHop} +import fr.acinq.eclair.router.Router.{BlindedHop, ChannelHop, Hop, NodeHop} import fr.acinq.eclair.{MilliSatoshi, ShortChannelId, TimestampMilli} import scodec.bits.ByteVector @@ -227,6 +227,7 @@ object HopSummary { val shortChannelId = h match { case ch: ChannelHop => Some(ch.shortChannelId) case _: NodeHop => None + case _: BlindedHop => None } HopSummary(h.nodeId, h.nextNodeId, shortChannelId) } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/json/JsonSerializers.scala b/eclair-core/src/main/scala/fr/acinq/eclair/json/JsonSerializers.scala index dfd05bb85b..df3af56dd6 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/json/JsonSerializers.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/json/JsonSerializers.scala @@ -31,10 +31,11 @@ import fr.acinq.eclair.io.Peer import fr.acinq.eclair.message.OnionMessages import fr.acinq.eclair.payment.PaymentFailure.PaymentFailedSummary import fr.acinq.eclair.payment._ -import fr.acinq.eclair.router.Router.{ChannelRelayParams, Route} +import fr.acinq.eclair.router.Router.{BlindedHop, ChannelHop, ChannelRelayParams, Route} import fr.acinq.eclair.transactions.DirectedHtlc import fr.acinq.eclair.transactions.Transactions._ import fr.acinq.eclair.wire.protocol.MessageOnionCodecs.blindedRouteCodec +import fr.acinq.eclair.wire.protocol.OfferTypes.PaymentInfo import fr.acinq.eclair.wire.protocol._ import fr.acinq.eclair.{Alias, BlockHeight, CltvExpiry, CltvExpiryDelta, Feature, FeatureSupport, MilliSatoshi, ShortChannelId, TimestampMilli, TimestampSecond, UInt64, UnknownFeature} import org.json4s @@ -294,9 +295,14 @@ object ColorSerializer extends MinimalSerializer({ }) // @formatter:off -private case class ChannelHopJson(nodeId: PublicKey, nextNodeId: PublicKey, source: ChannelRelayParams) -private case class RouteFullJson(amount: MilliSatoshi, hops: Seq[ChannelHopJson]) -object RouteFullSerializer extends ConvertClassSerializer[Route](route => RouteFullJson(route.amount, route.hops.map(h => ChannelHopJson(h.nodeId, h.nextNodeId, h.params)))) +private sealed trait HopJson +private case class ChannelHopJson(nodeId: PublicKey, nextNodeId: PublicKey, source: ChannelRelayParams) extends HopJson +private case class BlindedHopJson(nodeId: PublicKey, nextNodeId: PublicKey, paymentInfo: PaymentInfo) extends HopJson +private case class RouteFullJson(amount: MilliSatoshi, hops: Seq[HopJson]) +object RouteFullSerializer extends ConvertClassSerializer[Route](route => RouteFullJson(route.amount, route.hops.map { + case h: ChannelHop => ChannelHopJson(h.nodeId, h.nextNodeId, h.params) + case h: BlindedHop => BlindedHopJson(h.nodeId, h.nextNodeId, h.paymentInfo) +})) private case class RouteNodeIdsJson(amount: MilliSatoshi, nodeIds: Seq[PublicKey]) object RouteNodeIdsSerializer extends ConvertClassSerializer[Route](route => { @@ -307,8 +313,12 @@ object RouteNodeIdsSerializer extends ConvertClassSerializer[Route](route => { RouteNodeIdsJson(route.amount, nodeIds) }) -private case class RouteShortChannelIdsJson(amount: MilliSatoshi, shortChannelIds: Seq[ShortChannelId]) -object RouteShortChannelIdsSerializer extends ConvertClassSerializer[Route](route => RouteShortChannelIdsJson(route.amount, route.hops.map(_.shortChannelId))) +private case class RouteShortChannelIdsJson(amount: MilliSatoshi, shortChannelIds: Seq[String]) +object RouteShortChannelIdsSerializer extends ConvertClassSerializer[Route](route => + RouteShortChannelIdsJson(route.amount, route.hops.map { + case hop: ChannelHop => hop.shortChannelId.toString + case _: BlindedHop => "blinded" + })) // @formatter:on // @formatter:off @@ -395,7 +405,7 @@ object InvoiceSerializer extends MinimalSerializer({ case p: Bolt12Invoice => val fieldList = List( JField("amount", JLong(p.amount.toLong)), - JField("nodeId", JString(p.nodeId.toString())), + JField("nodeId", JString(p.signingNodeId.toString())), JField("paymentHash", JString(p.paymentHash.toString())), p.description.fold(string => JField("description", JString(string)), hash => JField("descriptionHash", JString(hash.toHex))), JField("features", Extraction.decompose(p.features)( @@ -404,10 +414,10 @@ object InvoiceSerializer extends MinimalSerializer({ FeatureSupportSerializer + UnknownFeatureSerializer )), - JField("blindedPaths", JArray(p.blindedPaths.map(path => { + JField("blindedPaths", JArray(p.extraEdges.map(path => { JObject(List( - JField("introductionNodeId", JString(path.introductionNodeId.toString())), - JField("blindedNodeIds", JArray(path.blindedNodes.map(n => JString(n.blindedPublicKey.toString())).toList)) + JField("introductionNodeId", JString(path.path.introductionNodeId.toString())), + JField("blindedNodeIds", JArray(path.path.blindedNodes.map(n => JString(n.blindedPublicKey.toString())).toList)) )) }).toList)), JField("createdAt", JLong(p.createdAt.toLong)), diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Bolt11Invoice.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Bolt11Invoice.scala index 7099374910..682bd16246 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Bolt11Invoice.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Bolt11Invoice.scala @@ -19,7 +19,10 @@ package fr.acinq.eclair.payment import fr.acinq.bitcoin.scalacompat.Crypto.{PrivateKey, PublicKey} import fr.acinq.bitcoin.scalacompat.{Block, ByteVector32, ByteVector64, Crypto} import fr.acinq.bitcoin.{Base58, Base58Check, Bech32} -import fr.acinq.eclair.{CltvExpiryDelta, Feature, FeatureSupport, Features, InvoiceFeature, MilliSatoshi, MilliSatoshiLong, ShortChannelId, TimestampSecond, randomBytes32} +import fr.acinq.eclair.wire.protocol.PaymentOnion.FinalPayload.Partial +import fr.acinq.eclair.wire.protocol.PaymentOnion.{FinalPayload, PerHopPayload} +import fr.acinq.eclair.wire.protocol.{GenericTlv, OnionRoutingPacket, PaymentOnion} +import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Feature, FeatureSupport, Features, InvoiceFeature, MilliSatoshi, MilliSatoshiLong, ShortChannelId, TimestampSecond, randomBytes32} import scodec.bits.{BitVector, ByteOrdering, ByteVector} import scodec.codecs.{list, ubyte} import scodec.{Codec, Err} @@ -129,6 +132,15 @@ case class Bolt11Invoice(prefix: String, amount_opt: Option[MilliSatoshi], creat val int5s = eight2fiveCodec.decode(data).require.value Bech32.encode(hrp, int5s.toArray, Bech32.Encoding.Bech32) } + + override def singlePartFinalPayload(amount: MilliSatoshi, expiry: CltvExpiry, userCustomTlvs: Seq[GenericTlv]): FinalPayload.Standard = + FinalPayload.Standard.createSinglePartPayload(amount, expiry, paymentSecret, paymentMetadata, userCustomTlvs) + + override def multiPartFinalPayload(totalAmount: MilliSatoshi, expiry: CltvExpiry, userCustomTlvs: Seq[GenericTlv]): FinalPayload.Standard.Partial = + FinalPayload.Standard.createMultiPartPayload(totalAmount, expiry, paymentSecret, paymentMetadata, userCustomTlvs = userCustomTlvs) + + override def trampolinePayload(totalAmount: MilliSatoshi, expiry: CltvExpiry, trampolineSecret: ByteVector32, trampolinePacket: OnionRoutingPacket): FinalPayload.Standard.Partial = + FinalPayload.Standard.createTrampolinePayload(totalAmount, expiry, trampolineSecret, trampolinePacket, paymentMetadata) } object Bolt11Invoice { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Bolt12Invoice.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Bolt12Invoice.scala index 5405c546a4..b3bd99b27e 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Bolt12Invoice.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Bolt12Invoice.scala @@ -17,14 +17,15 @@ package fr.acinq.eclair.payment import fr.acinq.bitcoin.Bech32 -import fr.acinq.bitcoin.scalacompat.Crypto.PrivateKey +import fr.acinq.bitcoin.scalacompat.Crypto.{PrivateKey, PublicKey} import fr.acinq.bitcoin.scalacompat.{Block, ByteVector32, ByteVector64, Crypto} import fr.acinq.eclair.crypto.Sphinx -import fr.acinq.eclair.crypto.Sphinx.RouteBlinding +import fr.acinq.eclair.payment.Invoice.BlindedEdge import fr.acinq.eclair.wire.protocol.OfferTypes._ import fr.acinq.eclair.wire.protocol.OnionRoutingCodecs.{InvalidTlvPayload, MissingRequiredTlv} -import fr.acinq.eclair.wire.protocol.{OfferCodecs, OfferTypes, TlvStream} -import fr.acinq.eclair.{CltvExpiryDelta, Features, InvoiceFeature, MilliSatoshi, MilliSatoshiLong, TimestampSecond, UInt64, randomBytes32} +import fr.acinq.eclair.wire.protocol.PaymentOnion.{BlindedPerHopPayload, FinalPayload} +import fr.acinq.eclair.wire.protocol.{GenericTlv, OfferCodecs, OfferTypes, OnionRoutingPacket, TlvStream} +import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Features, InvoiceFeature, MilliSatoshi, MilliSatoshiLong, TimestampSecond, UInt64, randomKey} import scodec.bits.ByteVector import java.util.concurrent.TimeUnit @@ -41,19 +42,20 @@ case class Bolt12Invoice(records: TlvStream[InvoiceTlv]) extends Invoice { val amount: MilliSatoshi = records.get[Amount].map(_.amount).get override val amount_opt: Option[MilliSatoshi] = Some(amount) - override val nodeId: Crypto.PublicKey = records.get[NodeId].get.publicKey + override val nodeId: Crypto.PublicKey = randomKey().publicKey override val paymentHash: ByteVector32 = records.get[PaymentHash].get.hash - override val paymentSecret: ByteVector32 = randomBytes32() override val paymentMetadata: Option[ByteVector] = None override val description: Either[String, ByteVector32] = Left(records.get[Description].get.description) - override val extraEdges: Seq[Invoice.ExtraEdge] = Seq.empty // TODO: the blinded paths need to be converted to graph edges + override val extraEdges: Seq[BlindedEdge] = records.get[Paths].get.paths.zip(records.get[PaymentPathsInfo].get.paymentInfo).map { + case (path, payInfo) => BlindedEdge(path, payInfo, nodeId) + } override val createdAt: TimestampSecond = records.get[CreatedAt].get.timestamp override val relativeExpiry: FiniteDuration = FiniteDuration(records.get[RelativeExpiry].map(_.seconds).getOrElse(DEFAULT_EXPIRY_SECONDS), TimeUnit.SECONDS) override val minFinalCltvExpiryDelta: CltvExpiryDelta = records.get[Cltv].map(_.minFinalCltvExpiry).getOrElse(DEFAULT_MIN_FINAL_EXPIRY_DELTA) override val features: Features[InvoiceFeature] = records.get[FeaturesTlv].map(_.features.invoiceFeatures()).getOrElse(Features.empty) + val signingNodeId: PublicKey = records.get[NodeId].get.publicKey val chain: ByteVector32 = records.get[Chain].map(_.hash).getOrElse(Block.LivenetGenesisBlock.hash) val offerId: Option[ByteVector32] = records.get[OfferId].map(_.offerId) - val blindedPaths: Seq[RouteBlinding.BlindedRoute] = records.get[Paths].get.paths val issuer: Option[String] = records.get[Issuer].map(_.issuer) val quantity: Option[Long] = records.get[Quantity].map(_.quantity) val refundFor: Option[ByteVector32] = records.get[RefundFor].map(_.refundedPaymentHash) @@ -67,7 +69,7 @@ case class Bolt12Invoice(records: TlvStream[InvoiceTlv]) extends Invoice { // It is assumed that the request is valid for this offer. def isValidFor(offer: Offer, request: InvoiceRequest): Boolean = { - nodeId == offer.nodeId && + signingNodeId == offer.nodeId && checkSignature() && offerId.contains(request.offerId) && request.chain == chain && @@ -91,7 +93,7 @@ case class Bolt12Invoice(records: TlvStream[InvoiceTlv]) extends Invoice { } def checkSignature(): Boolean = { - verifySchnorr(signatureTag("signature"), rootHash(OfferTypes.removeSignature(records), OfferCodecs.invoiceTlvCodec), signature, OfferTypes.xOnlyPublicKey(nodeId)) + verifySchnorr(signatureTag("signature"), rootHash(OfferTypes.removeSignature(records), OfferCodecs.invoiceTlvCodec), signature, OfferTypes.xOnlyPublicKey(signingNodeId)) } override def toString: String = { @@ -99,6 +101,14 @@ case class Bolt12Invoice(records: TlvStream[InvoiceTlv]) extends Invoice { Bech32.encodeBytes(hrp, data.toArray, Bech32.Encoding.Beck32WithoutChecksum) } + override def singlePartFinalPayload(amount: MilliSatoshi, expiry: CltvExpiry, userCustomTlvs: Seq[GenericTlv]): BlindedPerHopPayload = + FinalPayload.Blinded.createSinglePartPayload(amount, userCustomTlvs) + + override def multiPartFinalPayload(totalAmount: MilliSatoshi, expiry: CltvExpiry, userCustomTlvs: Seq[GenericTlv]): FinalPayload.Blinded.Partial = + FinalPayload.Blinded.createMultiPartPayload(totalAmount, userCustomTlvs) + + override def trampolinePayload(totalAmount: MilliSatoshi, expiry: CltvExpiry, trampolineSecret: ByteVector32, trampolinePacket: OnionRoutingPacket): FinalPayload.Blinded.Partial = + FinalPayload.Blinded.createTrampolinePayload(totalAmount, trampolinePacket) } object Bolt12Invoice { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Invoice.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Invoice.scala index 0dccf730c3..db577e9f5e 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Invoice.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Invoice.scala @@ -18,9 +18,13 @@ package fr.acinq.eclair.payment import fr.acinq.bitcoin.scalacompat.ByteVector32 import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey +import fr.acinq.eclair.crypto.Sphinx.RouteBlinding.BlindedRoute import fr.acinq.eclair.payment.relay.Relayer -import fr.acinq.eclair.wire.protocol.ChannelUpdate -import fr.acinq.eclair.{CltvExpiryDelta, Features, InvoiceFeature, MilliSatoshi, MilliSatoshiLong, ShortChannelId, TimestampSecond} +import fr.acinq.eclair.wire.protocol.OfferTypes.PaymentInfo +import fr.acinq.eclair.wire.protocol.PaymentOnion.FinalPayload.Partial +import fr.acinq.eclair.wire.protocol.PaymentOnion.PerHopPayload +import fr.acinq.eclair.wire.protocol.{ChannelUpdate, GenericTlv, OnionRoutingPacket} +import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Features, InvoiceFeature, MilliSatoshi, MilliSatoshiLong, ShortChannelId, TimestampSecond} import scodec.bits.ByteVector import scala.concurrent.duration.FiniteDuration @@ -35,8 +39,6 @@ trait Invoice { val paymentHash: ByteVector32 - val paymentSecret: ByteVector32 - val paymentMetadata: Option[ByteVector] val description: Either[String, ByteVector32] @@ -52,6 +54,12 @@ trait Invoice { def isExpired(): Boolean = createdAt + relativeExpiry.toSeconds <= TimestampSecond.now() def toString: String + + def singlePartFinalPayload(amount: MilliSatoshi, expiry: CltvExpiry, userCustomTlvs: Seq[GenericTlv] = Nil): PerHopPayload + + def multiPartFinalPayload(totalAmount: MilliSatoshi, expiry: CltvExpiry, userCustomTlvs: Seq[GenericTlv] = Nil): Partial + + def trampolinePayload(totalAmount: MilliSatoshi, expiry: CltvExpiry, trampolineSecret: ByteVector32, trampolinePacket: OnionRoutingPacket): Partial } object Invoice { @@ -59,12 +67,14 @@ object Invoice { sealed trait ExtraEdge { // @formatter:off def sourceNodeId: PublicKey + def targetNodeId: PublicKey + def shortChannelId: ShortChannelId def feeBase: MilliSatoshi def feeProportionalMillionths: Long def cltvExpiryDelta: CltvExpiryDelta def htlcMinimum: MilliSatoshi def htlcMaximum_opt: Option[MilliSatoshi] - def relayFees: Relayer.RelayFees = Relayer.RelayFees(feeBase = feeBase, feeProportionalMillionths = feeProportionalMillionths) + final def relayFees: Relayer.RelayFees = Relayer.RelayFees(feeBase = feeBase, feeProportionalMillionths = feeProportionalMillionths) // @formatter:on } @@ -81,6 +91,16 @@ object Invoice { def update(u: ChannelUpdate): BasicEdge = copy(feeBase = u.feeBaseMsat, feeProportionalMillionths = u.feeProportionalMillionths, cltvExpiryDelta = u.cltvExpiryDelta) } + case class BlindedEdge(path: BlindedRoute, payInfo: PaymentInfo, targetNodeId: PublicKey) extends ExtraEdge { + override val sourceNodeId: PublicKey = path.introductionNodeId + override val shortChannelId: ShortChannelId = ShortChannelId.generateLocalAlias() + override val feeBase: MilliSatoshi = payInfo.feeBase + override val feeProportionalMillionths: Long = payInfo.feeProportionalMillionths + override val cltvExpiryDelta: CltvExpiryDelta = payInfo.cltvExpiryDelta + override val htlcMinimum: MilliSatoshi = payInfo.minHtlc + override val htlcMaximum_opt: Option[MilliSatoshi] = Some(payInfo.maxHtlc) + } + def fromString(input: String): Try[Invoice] = { if (input.toLowerCase.startsWith("lni")) { Bolt12Invoice.fromString(input) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentPacket.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentPacket.scala index 2a4cb71588..c6ab330b29 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentPacket.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentPacket.scala @@ -22,15 +22,18 @@ import fr.acinq.bitcoin.scalacompat.ByteVector32 import fr.acinq.bitcoin.scalacompat.Crypto.{PrivateKey, PublicKey} import fr.acinq.eclair.channel.{CMD_ADD_HTLC, CMD_FAIL_HTLC, CannotExtractSharedSecret, Origin} import fr.acinq.eclair.crypto.Sphinx -import fr.acinq.eclair.router.Router.{ChannelHop, Hop, NodeHop} -import fr.acinq.eclair.wire.protocol.PaymentOnion.{FinalPayload, IntermediatePayload, PerHopPayload} +import fr.acinq.eclair.router.Router +import fr.acinq.eclair.router.Router.{BlindedHop, ChannelHop, ConnectedHop, Hop, NodeHop} +import fr.acinq.eclair.wire.protocol.OnionPaymentPayloadTlv.{BlindingPoint, EncryptedRecipientData} +import fr.acinq.eclair.wire.protocol.PaymentOnion.{BlindedPerHopPayload, FinalPayload, IntermediatePayload, PerHopPayload} +import fr.acinq.eclair.wire.protocol.RouteBlindingEncryptedDataCodecs.RouteBlindingDecryptedData import fr.acinq.eclair.wire.protocol._ import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Feature, Features, MilliSatoshi, ShortChannelId, UInt64, randomBytes32, randomKey} import scodec.bits.ByteVector import scodec.{Attempt, DecodeResult} import java.util.UUID -import scala.util.Try +import scala.util.{Failure, Try} /** * Created by t-bast on 08/10/2019. @@ -196,7 +199,7 @@ object IncomingPaymentPacket { case innerPayload => // We merge contents from the outer and inner payloads. // We must use the inner payload's total amount and payment secret because the payment may be split between multiple trampoline payments (#reckless). - Right(FinalPacket(add, FinalPayload.Standard.createMultiPartPayload(outerPayload.amount, innerPayload.totalAmount, outerPayload.expiry, innerPayload.paymentSecret, innerPayload.paymentMetadata))) + Right(FinalPacket(add, FinalPayload.Standard.createMultiPartPayload(innerPayload.totalAmount, outerPayload.expiry, innerPayload.paymentSecret, innerPayload.paymentMetadata).withAmount(outerPayload.amount))) } } } @@ -232,6 +235,18 @@ object OutgoingPaymentPacket { Sphinx.create(sessionKey, packetPayloadLength, nodes, payloadsBin, Some(associatedData)) } + def buildBlindedPayloads(hop: BlindedHop, finalPayloads: Seq[PerHopPayload]): Seq[PerHopPayload]= + if (hop.route.blindedNodes.length == 1) { + val additionalTlvs = Seq(EncryptedRecipientData(hop.route.encryptedPayloads.head), BlindingPoint(hop.route.blindingKey)) + val blindedPayload = BlindedPerHopPayload(finalPayloads.head.records.copy(records = finalPayloads.head.records.records.toSeq ++ additionalTlvs)) + blindedPayload +: finalPayloads.drop(1) + } else { + val firstBlinded = BlindedPerHopPayload(TlvStream(EncryptedRecipientData(hop.route.encryptedPayloads.head), BlindingPoint(hop.route.blindingKey))) + val intermediateBlinded = hop.route.encryptedPayloads.drop(1).dropRight(1).map(data => BlindedPerHopPayload(TlvStream(EncryptedRecipientData(data)))) + val lastBlinded = BlindedPerHopPayload(finalPayloads.head.records.copy(records = finalPayloads.head.records.records.toSeq :+ EncryptedRecipientData(hop.route.encryptedPayloads.last))) + firstBlinded +: (intermediateBlinded ++ (lastBlinded +: finalPayloads.drop(1))) + } + /** * Build the onion payloads for each hop. * @@ -242,14 +257,22 @@ object OutgoingPaymentPacket { * - firstExpiry is the cltv expiry for the first htlc in the route * - a sequence of payloads that will be used to build the onion */ - def buildPayloads(hops: Seq[Hop], finalPayload: FinalPayload): (MilliSatoshi, CltvExpiry, Seq[PerHopPayload]) = { - hops.reverse.foldLeft((finalPayload.amount, finalPayload.expiry, Seq[PerHopPayload](finalPayload))) { - case ((amount, expiry, payloads), hop) => - val payload = hop match { - case hop: ChannelHop => IntermediatePayload.ChannelRelay.Standard(hop.shortChannelId, amount, expiry) - case hop: NodeHop => IntermediatePayload.NodeRelay.Standard(amount, expiry, hop.nextNodeId) + def buildPayloads(hops: Seq[Hop], finalPayload: PerHopPayload, lastAmount: MilliSatoshi, lastExpiry: CltvExpiry): (MilliSatoshi, CltvExpiry, Seq[PerHopPayload]) = { + val (firstAmount, firstExpiry, payloads) = hops.drop(1).reverse.foldLeft((lastAmount, lastExpiry, Seq(finalPayload))) { + case ((amount, expiry, finalPayloads), hop) => + val payloads = hop match { + case hop: ChannelHop => IntermediatePayload.ChannelRelay.Standard(hop.shortChannelId, amount, expiry) +: finalPayloads + case hop: NodeHop => IntermediatePayload.NodeRelay.Standard(amount, expiry, hop.nextNodeId) +: finalPayloads + case hop: BlindedHop => buildBlindedPayloads(hop, finalPayloads) } - (amount + hop.fee(amount), expiry + hop.cltvExpiryDelta, payload +: payloads) + (amount + hop.fee(amount), expiry + hop.cltvExpiryDelta, payloads) + } + // The first payload would be for us, so we don't need to build it. + // However a single blinded hop can contain many payloads, in that case we need to build the blinded payloads and drop the first one. + // The fees for our part of the blinded hop have already been deducted in `buildCommand`. + hops.head match { + case hop: BlindedHop => (firstAmount + hop.fee(firstAmount), firstExpiry + hop.cltvExpiryDelta, buildBlindedPayloads(hop, payloads).drop(1)) + case _ => (firstAmount, firstExpiry, payloads) } } @@ -263,18 +286,22 @@ object OutgoingPaymentPacket { * - firstExpiry is the cltv expiry for the first htlc in the route * - the onion to include in the HTLC */ - private def buildPacket(packetPayloadLength: Int, paymentHash: ByteVector32, hops: Seq[Hop], finalPayload: FinalPayload): Try[(MilliSatoshi, CltvExpiry, Sphinx.PacketAndSecrets)] = { - val (firstAmount, firstExpiry, payloads) = buildPayloads(hops.drop(1), finalPayload) - val nodes = hops.map(_.nextNodeId) + private def buildPacket(packetPayloadLength: Int, paymentHash: ByteVector32, hops: Seq[Hop], finalPayload: PerHopPayload, amount: MilliSatoshi, expiry: CltvExpiry): Try[(MilliSatoshi, CltvExpiry, Sphinx.PacketAndSecrets)] = { + val (firstAmount, firstExpiry, payloads) = buildPayloads(hops, finalPayload, amount, expiry) + val nodes = hops.flatMap { + case hop: ChannelHop => Seq(hop.nextNodeId) + case hop: NodeHop => Seq(hop.nextNodeId) + case hop: BlindedHop => hop.route.blindedNodeIds.drop(1) + } // BOLT 2 requires that associatedData == paymentHash buildOnion(packetPayloadLength, nodes, payloads, paymentHash).map(onion => (firstAmount, firstExpiry, onion)) } - def buildPaymentPacket(paymentHash: ByteVector32, hops: Seq[Hop], finalPayload: FinalPayload): Try[(MilliSatoshi, CltvExpiry, Sphinx.PacketAndSecrets)] = - buildPacket(PaymentOnionCodecs.paymentOnionPayloadLength, paymentHash, hops, finalPayload) + def buildPaymentPacket(paymentHash: ByteVector32, hops: Seq[Hop], finalPayload: PerHopPayload, amount: MilliSatoshi, expiry: CltvExpiry): Try[(MilliSatoshi, CltvExpiry, Sphinx.PacketAndSecrets)] = + buildPacket(PaymentOnionCodecs.paymentOnionPayloadLength, paymentHash, hops, finalPayload, amount, expiry) - def buildTrampolinePacket(paymentHash: ByteVector32, hops: Seq[Hop], finalPayload: FinalPayload): Try[(MilliSatoshi, CltvExpiry, Sphinx.PacketAndSecrets)] = - buildPacket(PaymentOnionCodecs.trampolineOnionPayloadLength, paymentHash, hops, finalPayload) + def buildTrampolinePacket(paymentHash: ByteVector32, hops: Seq[Hop], finalPayload: PerHopPayload, amount: MilliSatoshi, expiry: CltvExpiry): Try[(MilliSatoshi, CltvExpiry, Sphinx.PacketAndSecrets)] = + buildPacket(PaymentOnionCodecs.trampolineOnionPayloadLength, paymentHash, hops, finalPayload, amount, expiry) /** * Build an encrypted trampoline onion packet when the final recipient doesn't support trampoline. @@ -288,15 +315,15 @@ object OutgoingPaymentPacket { * - firstExpiry is the cltv expiry for the first trampoline node in the route * - the trampoline onion to include in final payload of a normal onion */ - def buildTrampolineToLegacyPacket(invoice: Bolt11Invoice, hops: Seq[NodeHop], finalPayload: FinalPayload): Try[(MilliSatoshi, CltvExpiry, Sphinx.PacketAndSecrets)] = { + def buildTrampolineToLegacyPacket(invoice: Bolt11Invoice, hops: Seq[NodeHop], finalPayload: PerHopPayload, amount: MilliSatoshi, expiry: CltvExpiry): Try[(MilliSatoshi, CltvExpiry, Sphinx.PacketAndSecrets)] = { // NB: the final payload will never reach the recipient, since the next-to-last node in the trampoline route will convert that to a non-trampoline payment. // We use the smallest final payload possible, otherwise we may overflow the trampoline onion size. - val dummyFinalPayload = FinalPayload.Standard.createSinglePartPayload(finalPayload.amount, finalPayload.expiry, randomBytes32(), None) - val (firstAmount, firstExpiry, payloads) = hops.drop(1).reverse.foldLeft((finalPayload.amount, finalPayload.expiry, Seq[PerHopPayload](dummyFinalPayload))) { + val dummyFinalPayload = FinalPayload.Standard.createSinglePartPayload(amount, expiry, randomBytes32(), None) + val (firstAmount, firstExpiry, payloads) = hops.drop(1).reverse.foldLeft((amount, expiry, Seq[PerHopPayload](dummyFinalPayload))) { case ((amount, expiry, payloads), hop) => // The next-to-last node in the trampoline route must receive invoice data to indicate the conversion to a non-trampoline payment. val payload = if (payloads.length == 1) { - IntermediatePayload.NodeRelay.Standard.createNodeRelayToNonTrampolinePayload(finalPayload.amount, finalPayload.totalAmount, finalPayload.expiry, hop.nextNodeId, invoice) + IntermediatePayload.NodeRelay.Standard.createNodeRelayToNonTrampolinePayload(amount, amount, expiry, hop.nextNodeId, invoice) } else { IntermediatePayload.NodeRelay.Standard(amount, expiry, hop.nextNodeId) } @@ -322,10 +349,34 @@ object OutgoingPaymentPacket { * * @return the command and the onion shared secrets (used to decrypt the error in case of payment failure) */ - def buildCommand(replyTo: ActorRef, upstream: Upstream, paymentHash: ByteVector32, hops: Seq[ChannelHop], finalPayload: FinalPayload): Try[(CMD_ADD_HTLC, Seq[(ByteVector32, PublicKey)])] = { - buildPaymentPacket(paymentHash, hops, finalPayload).map { + def buildCommand(replyTo: ActorRef, + privateKey: PrivateKey, + upstream: Upstream, + paymentHash: ByteVector32, + hops: Seq[ConnectedHop], + finalPayload: PerHopPayload, + amount: MilliSatoshi, + expiry: CltvExpiry): Try[(ShortChannelId, CMD_ADD_HTLC, Seq[(ByteVector32, PublicKey)])] = { + val (shortChannelId, nextBlindingKey_opt, nextHops) = hops.head match { + case hop: ChannelHop => (hop.shortChannelId, None, hops) + case hop: Router.BlindedHop => + RouteBlindingEncryptedDataCodecs.decode(privateKey, hop.route.blindingKey, hop.route.encryptedPayloads.head) match { + case Left(e) => return Failure(e) + case Right(RouteBlindingDecryptedData(encryptedDataTlvs, nextBlindingKey)) => + IntermediatePayload.ChannelRelay.Blinded.validate(TlvStream(EncryptedRecipientData(ByteVector.empty)), encryptedDataTlvs, nextBlindingKey) match { + case Left(invalidTlv) => return Failure(RouteBlindingEncryptedDataCodecs.CannotDecodeData(invalidTlv.failureMessage.message)) + case Right(payload) => + // TODO(trampoline-to-blind): Check fees and CLTV. As long as we are the sender it's fine but it is needed if we trampoline the payment for someone else. + val amountWithFees = amount + hop.paymentInfo.fee(amount) + val remainingFee = payload.amountToForward(amountWithFees) - amount + val tailPaymentInfo = hop.paymentInfo.copy(feeBase = remainingFee, feeProportionalMillionths = 0, cltvExpiryDelta = hop.paymentInfo.cltvExpiryDelta - payload.cltvExpiryDelta) + (payload.outgoingChannelId, Some(nextBlindingKey), Seq(hop.copy(paymentInfo = tailPaymentInfo))) + } + } + } + buildPaymentPacket(paymentHash, nextHops, finalPayload, amount, expiry).map { case (firstAmount, firstExpiry, onion) => - CMD_ADD_HTLC(replyTo, firstAmount, paymentHash, firstExpiry, onion.packet, None, Origin.Hot(replyTo, upstream), commit = true) -> onion.sharedSecrets + (shortChannelId, CMD_ADD_HTLC(replyTo, firstAmount, paymentHash, firstExpiry, onion.packet, nextBlindingKey_opt, Origin.Hot(replyTo, upstream), commit = true), onion.sharedSecrets) } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala index 5dcdfba6cc..3828305e92 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala @@ -311,14 +311,15 @@ class NodeRelay private(nodeParams: NodeParams, val paymentSecret = payloadOut.paymentSecret.get // NB: we've verified that there was a payment secret in validateRelay if (Features(features).hasFeature(Features.BasicMultiPartPayment)) { context.log.debug("sending the payment to non-trampoline recipient using MPP") - val payment = SendMultiPartPayment(payFsmAdapters, paymentSecret, payloadOut.outgoingNodeId, payloadOut.amountToForward, payloadOut.outgoingCltv, nodeParams.maxPaymentAttempts, payloadOut.paymentMetadata, extraEdges, routeParams) + val finalPayload = FinalPayload.Standard.createMultiPartPayload(payloadOut.amountToForward, payloadOut.outgoingCltv, paymentSecret, payloadOut.paymentMetadata) + val payment = SendMultiPartPayment(payFsmAdapters, payloadOut.outgoingNodeId, finalPayload, payloadOut.amountToForward, payloadOut.outgoingCltv, nodeParams.maxPaymentAttempts, extraEdges, routeParams) val payFSM = outgoingPaymentFactory.spawnOutgoingPayFSM(context, paymentCfg, multiPart = true) payFSM ! payment payFSM } else { context.log.debug("sending the payment to non-trampoline recipient without MPP") val finalPayload = FinalPayload.Standard.createSinglePartPayload(payloadOut.amountToForward, payloadOut.outgoingCltv, paymentSecret, payloadOut.paymentMetadata) - val payment = SendPaymentToNode(payFsmAdapters, payloadOut.outgoingNodeId, finalPayload, nodeParams.maxPaymentAttempts, extraEdges, routeParams) + val payment = SendPaymentToNode(payFsmAdapters, payloadOut.outgoingNodeId, finalPayload, payloadOut.amountToForward, payloadOut.outgoingCltv, nodeParams.maxPaymentAttempts, extraEdges, routeParams) val payFSM = outgoingPaymentFactory.spawnOutgoingPayFSM(context, paymentCfg, multiPart = false) payFSM ! payment payFSM @@ -327,7 +328,8 @@ class NodeRelay private(nodeParams: NodeParams, context.log.debug("sending the payment to the next trampoline node") val payFSM = outgoingPaymentFactory.spawnOutgoingPayFSM(context, paymentCfg, multiPart = true) val paymentSecret = randomBytes32() // we generate a new secret to protect against probing attacks - val payment = SendMultiPartPayment(payFsmAdapters, paymentSecret, payloadOut.outgoingNodeId, payloadOut.amountToForward, payloadOut.outgoingCltv, nodeParams.maxPaymentAttempts, None, routeParams = routeParams, additionalTlvs = Seq(OnionPaymentPayloadTlv.TrampolineOnion(packetOut))) + val finalPayload = FinalPayload.Standard.createTrampolinePayload(payloadOut.amountToForward, payloadOut.outgoingCltv, paymentSecret, packetOut, payloadOut.paymentMetadata) + val payment = SendMultiPartPayment(payFsmAdapters, payloadOut.outgoingNodeId, finalPayload, payloadOut.amountToForward, payloadOut.outgoingCltv, nodeParams.maxPaymentAttempts, routeParams = routeParams) payFSM ! payment payFSM } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/MultiPartPaymentLifecycle.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/MultiPartPaymentLifecycle.scala index d0bc976e6e..6d758c1b11 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/MultiPartPaymentLifecycle.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/MultiPartPaymentLifecycle.scala @@ -30,10 +30,8 @@ import fr.acinq.eclair.payment._ import fr.acinq.eclair.payment.send.PaymentInitiator.SendPaymentConfig import fr.acinq.eclair.payment.send.PaymentLifecycle.SendPaymentToRoute import fr.acinq.eclair.router.Router._ -import fr.acinq.eclair.wire.protocol.PaymentOnion.FinalPayload -import fr.acinq.eclair.wire.protocol._ +import fr.acinq.eclair.wire.protocol.PaymentOnion.FinalPayload.Partial import fr.acinq.eclair.{CltvExpiry, FSMDiagnosticActorLogging, Logs, MilliSatoshi, MilliSatoshiLong, NodeParams, TimestampMilli} -import scodec.bits.ByteVector import java.util.UUID import java.util.concurrent.TimeUnit @@ -317,16 +315,13 @@ object MultiPartPaymentLifecycle { * @param userCustomTlvs when provided, additional user-defined custom tlvs that will be added to the onion sent to the target node. */ case class SendMultiPartPayment(replyTo: ActorRef, - paymentSecret: ByteVector32, targetNodeId: PublicKey, + finalPayload: Partial, totalAmount: MilliSatoshi, targetExpiry: CltvExpiry, maxAttempts: Int, - paymentMetadata: Option[ByteVector], extraEdges: Seq[ExtraEdge] = Nil, - routeParams: RouteParams, - additionalTlvs: Seq[OnionPaymentPayloadTlv] = Nil, - userCustomTlvs: Seq[GenericTlv] = Nil) { + routeParams: RouteParams) { require(totalAmount > 0.msat, s"total amount must be > 0") } @@ -405,8 +400,7 @@ object MultiPartPaymentLifecycle { Some(cfg.paymentContext)) private def createChildPayment(replyTo: ActorRef, route: Route, request: SendMultiPartPayment): SendPaymentToRoute = { - val finalPayload = FinalPayload.Standard.createMultiPartPayload(route.amount, request.totalAmount, request.targetExpiry, request.paymentSecret, request.paymentMetadata, request.additionalTlvs, request.userCustomTlvs) - SendPaymentToRoute(replyTo, Right(route), finalPayload) + SendPaymentToRoute(replyTo, Right(route), request.finalPayload.withAmount(route.amount), route.amount, request.targetExpiry) } /** When we receive an error from the final recipient or payment gets settled on chain, we should fail the whole payment, it's useless to retry. */ diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentInitiator.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentInitiator.scala index 84d44de269..bdcfd98221 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentInitiator.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentInitiator.scala @@ -19,7 +19,6 @@ package fr.acinq.eclair.payment.send import akka.actor.{Actor, ActorContext, ActorLogging, ActorRef, Props} import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.bitcoin.scalacompat.{ByteVector32, Crypto} -import fr.acinq.eclair.Features.BasicMultiPartPayment import fr.acinq.eclair.channel.fsm.Channel import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.payment.OutgoingPaymentPacket.Upstream @@ -55,13 +54,14 @@ class PaymentInitiator(nodeParams: NodeParams, outgoingPaymentFactory: PaymentIn if (!nodeParams.features.invoiceFeatures().areSupported(r.invoice.features)) { sender() ! PaymentFailed(paymentId, r.paymentHash, LocalFailure(r.recipientAmount, Nil, UnsupportedFeatures(r.invoice.features)) :: Nil) } else if (Features.canUseFeature(nodeParams.features.invoiceFeatures(), r.invoice.features, Features.BasicMultiPartPayment)) { + val finalPayload = r.invoice.multiPartFinalPayload(r.recipientAmount, finalExpiry, r.userCustomTlvs) val fsm = outgoingPaymentFactory.spawnOutgoingMultiPartPayment(context, paymentCfg) - fsm ! MultiPartPaymentLifecycle.SendMultiPartPayment(self, r.invoice.paymentSecret, r.recipientNodeId, r.recipientAmount, finalExpiry, r.maxAttempts, r.invoice.paymentMetadata, r.invoice.extraEdges, r.routeParams, userCustomTlvs = r.userCustomTlvs) + fsm ! MultiPartPaymentLifecycle.SendMultiPartPayment(self, r.recipientNodeId, finalPayload, r.recipientAmount, finalExpiry, r.maxAttempts, r.invoice.extraEdges, r.routeParams) context become main(pending + (paymentId -> PendingPaymentToNode(sender(), r))) } else { - val finalPayload = FinalPayload.Standard.createSinglePartPayload(r.recipientAmount, finalExpiry, r.invoice.paymentSecret, r.invoice.paymentMetadata, r.userCustomTlvs) + val finalPayload = r.invoice.singlePartFinalPayload(r.recipientAmount, finalExpiry, r.userCustomTlvs) val fsm = outgoingPaymentFactory.spawnOutgoingPayment(context, paymentCfg) - fsm ! PaymentLifecycle.SendPaymentToNode(self, r.recipientNodeId, finalPayload, r.maxAttempts, r.invoice.extraEdges, r.routeParams) + fsm ! PaymentLifecycle.SendPaymentToNode(self, r.recipientNodeId, finalPayload, r.recipientAmount, finalExpiry, r.maxAttempts, r.invoice.extraEdges, r.routeParams) context become main(pending + (paymentId -> PendingPaymentToNode(sender(), r))) } @@ -72,7 +72,7 @@ class PaymentInitiator(nodeParams: NodeParams, outgoingPaymentFactory: PaymentIn val finalExpiry = Channel.MIN_CLTV_EXPIRY_DELTA.toCltvExpiry(nodeParams.currentBlockHeight + 1) val finalPayload = FinalPayload.Standard(TlvStream(Seq(OnionPaymentPayloadTlv.AmountToForward(r.recipientAmount), OnionPaymentPayloadTlv.OutgoingCltv(finalExpiry), OnionPaymentPayloadTlv.PaymentData(randomBytes32(), r.recipientAmount), OnionPaymentPayloadTlv.KeySend(r.paymentPreimage)), r.userCustomTlvs)) val fsm = outgoingPaymentFactory.spawnOutgoingPayment(context, paymentCfg) - fsm ! PaymentLifecycle.SendPaymentToNode(self, r.recipientNodeId, finalPayload, r.maxAttempts, routeParams = r.routeParams) + fsm ! PaymentLifecycle.SendPaymentToNode(self, r.recipientNodeId, finalPayload, r.recipientAmount, finalExpiry, r.maxAttempts, routeParams = r.routeParams) context become main(pending + (paymentId -> PendingSpontaneousPayment(sender(), r))) case r: SendTrampolinePayment => @@ -108,8 +108,9 @@ class PaymentInitiator(nodeParams: NodeParams, outgoingPaymentFactory: PaymentIn // We generate a random secret for the payment to the first trampoline node. val trampolineSecret = r.trampolineSecret.getOrElse(randomBytes32()) sender() ! SendPaymentToRouteResponse(paymentId, parentPaymentId, Some(trampolineSecret)) + val finalPayload = r.invoice.trampolinePayload(trampolineAmount, trampolineExpiry, trampolineSecret, trampolineOnion).withAmount(r.amount) val payFsm = outgoingPaymentFactory.spawnOutgoingPayment(context, paymentCfg) - payFsm ! PaymentLifecycle.SendPaymentToRoute(self, Left(r.route), FinalPayload.Standard.createMultiPartPayload(r.amount, trampolineAmount, trampolineExpiry, trampolineSecret, r.invoice.paymentMetadata, Seq(OnionPaymentPayloadTlv.TrampolineOnion(trampolineOnion))), r.invoice.extraEdges) + payFsm ! PaymentLifecycle.SendPaymentToRoute(self, Left(r.route), finalPayload, r.amount, trampolineExpiry, r.invoice.extraEdges) context become main(pending + (paymentId -> PendingPaymentToRoute(sender(), r))) case Failure(t) => log.warning("cannot send outgoing trampoline payment: {}", t.getMessage) @@ -117,8 +118,9 @@ class PaymentInitiator(nodeParams: NodeParams, outgoingPaymentFactory: PaymentIn } case Nil => sender() ! SendPaymentToRouteResponse(paymentId, parentPaymentId, None) + val finalPayload = r.invoice.multiPartFinalPayload(r.recipientAmount, finalExpiry).withAmount(r.amount) val payFsm = outgoingPaymentFactory.spawnOutgoingPayment(context, paymentCfg) - payFsm ! PaymentLifecycle.SendPaymentToRoute(self, Left(r.route), FinalPayload.Standard.createMultiPartPayload(r.amount, r.recipientAmount, finalExpiry, r.invoice.paymentSecret, r.invoice.paymentMetadata), r.invoice.extraEdges) + payFsm ! PaymentLifecycle.SendPaymentToRoute(self, Left(r.route), finalPayload, r.amount, finalExpiry, r.invoice.extraEdges) context become main(pending + (paymentId -> PendingPaymentToRoute(sender(), r))) case _ => sender() ! PaymentFailed(paymentId, r.paymentHash, LocalFailure(r.recipientAmount, Nil, TrampolineMultiNodeNotSupported) :: Nil) @@ -191,17 +193,18 @@ class PaymentInitiator(nodeParams: NodeParams, outgoingPaymentFactory: PaymentIn NodeHop(nodeParams.nodeId, trampolineNodeId, nodeParams.channelConf.expiryDelta, 0 msat), NodeHop(trampolineNodeId, r.recipientNodeId, trampolineExpiryDelta, trampolineFees) // for now we only use a single trampoline hop ) + val finalExpiry = r.finalExpiry(nodeParams.currentBlockHeight) val finalPayload = if (r.invoice.features.hasFeature(Features.BasicMultiPartPayment)) { - FinalPayload.Standard.createMultiPartPayload(r.recipientAmount, r.recipientAmount, r.finalExpiry(nodeParams.currentBlockHeight), r.invoice.paymentSecret, r.invoice.paymentMetadata) + r.invoice.multiPartFinalPayload(r.recipientAmount, finalExpiry).withAmount(r.recipientAmount) } else { - FinalPayload.Standard.createSinglePartPayload(r.recipientAmount, r.finalExpiry(nodeParams.currentBlockHeight), r.invoice.paymentSecret, r.invoice.paymentMetadata) + r.invoice.singlePartFinalPayload(r.recipientAmount, finalExpiry) } // We assume that the trampoline node supports multi-part payments (it should). val trampolinePacket_opt = if (r.invoice.features.hasFeature(Features.TrampolinePaymentPrototype)) { - OutgoingPaymentPacket.buildTrampolinePacket(r.paymentHash, trampolineRoute, finalPayload) + OutgoingPaymentPacket.buildTrampolinePacket(r.paymentHash, trampolineRoute, finalPayload, r.recipientAmount, finalExpiry) } else { r.invoice match { - case invoice: Bolt11Invoice => OutgoingPaymentPacket.buildTrampolineToLegacyPacket(invoice, trampolineRoute, finalPayload) + case invoice: Bolt11Invoice => OutgoingPaymentPacket.buildTrampolineToLegacyPacket(invoice, trampolineRoute, finalPayload, r.recipientAmount, finalExpiry) case _ => Failure(new Exception("Trampoline to legacy is only supported for Bolt11 invoices.")) } } @@ -216,8 +219,9 @@ class PaymentInitiator(nodeParams: NodeParams, outgoingPaymentFactory: PaymentIn val trampolineSecret = randomBytes32() buildTrampolinePayment(r, r.trampolineNodeId, trampolineFees, trampolineExpiryDelta).map { case (trampolineAmount, trampolineExpiry, trampolineOnion) => + val finalPayload = r.invoice.trampolinePayload(trampolineAmount, trampolineExpiry, trampolineSecret, trampolineOnion) val fsm = outgoingPaymentFactory.spawnOutgoingMultiPartPayment(context, paymentCfg) - fsm ! MultiPartPaymentLifecycle.SendMultiPartPayment(self, trampolineSecret, r.trampolineNodeId, trampolineAmount, trampolineExpiry, nodeParams.maxPaymentAttempts, r.invoice.paymentMetadata, r.invoice.extraEdges, r.routeParams, Seq(OnionPaymentPayloadTlv.TrampolineOnion(trampolineOnion))) + fsm ! MultiPartPaymentLifecycle.SendMultiPartPayment(self, r.trampolineNodeId, finalPayload, trampolineAmount, trampolineExpiry, nodeParams.maxPaymentAttempts, r.invoice.extraEdges, r.routeParams) } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentLifecycle.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentLifecycle.scala index 561dd517df..6f08fb84bc 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentLifecycle.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentLifecycle.scala @@ -55,21 +55,21 @@ class PaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, router: A when(WAITING_FOR_REQUEST) { case Event(c: SendPaymentToRoute, WaitingForRequest) => - log.debug("sending {} to route {}", c.finalPayload.amount, c.printRoute()) + log.debug("sending {} to route {}", c.amount, c.printRoute()) c.route.fold( - hops => router ! FinalizeRoute(c.finalPayload.amount, hops, c.extraEdges, paymentContext = Some(cfg.paymentContext)), + hops => router ! FinalizeRoute(c.amount, hops, c.extraEdges, paymentContext = Some(cfg.paymentContext)), route => self ! RouteResponse(route :: Nil) ) if (cfg.storeInDb) { - paymentsDb.addOutgoingPayment(OutgoingPayment(id, cfg.parentId, cfg.externalId, paymentHash, PaymentType.Standard, c.finalPayload.amount, cfg.recipientAmount, cfg.recipientNodeId, TimestampMilli.now(), cfg.invoice, OutgoingPaymentStatus.Pending)) + paymentsDb.addOutgoingPayment(OutgoingPayment(id, cfg.parentId, cfg.externalId, paymentHash, PaymentType.Standard, c.amount, cfg.recipientAmount, cfg.recipientNodeId, TimestampMilli.now(), cfg.invoice, OutgoingPaymentStatus.Pending)) } goto(WAITING_FOR_ROUTE) using WaitingForRoute(c, Nil, Ignore.empty) case Event(c: SendPaymentToNode, WaitingForRequest) => - log.debug("sending {} to {}", c.finalPayload.amount, c.targetNodeId) - router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.finalPayload.amount, c.maxFee, c.extraEdges, routeParams = c.routeParams, paymentContext = Some(cfg.paymentContext)) + log.debug("sending {} to {}", c.amount, c.targetNodeId) + router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.amount, c.maxFee, c.extraEdges, routeParams = c.routeParams, paymentContext = Some(cfg.paymentContext)) if (cfg.storeInDb) { - paymentsDb.addOutgoingPayment(OutgoingPayment(id, cfg.parentId, cfg.externalId, paymentHash, PaymentType.Standard, c.finalPayload.amount, cfg.recipientAmount, cfg.recipientNodeId, TimestampMilli.now(), cfg.invoice, OutgoingPaymentStatus.Pending)) + paymentsDb.addOutgoingPayment(OutgoingPayment(id, cfg.parentId, cfg.externalId, paymentHash, PaymentType.Standard, c.amount, cfg.recipientAmount, cfg.recipientNodeId, TimestampMilli.now(), cfg.invoice, OutgoingPaymentStatus.Pending)) } goto(WAITING_FOR_ROUTE) using WaitingForRoute(c, Nil, Ignore.empty) } @@ -77,20 +77,20 @@ class PaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, router: A when(WAITING_FOR_ROUTE) { case Event(RouteResponse(route +: _), WaitingForRoute(c, failures, ignore)) => log.info(s"route found: attempt=${failures.size + 1}/${c.maxAttempts} route=${route.printNodes()} channels=${route.printChannels()}") - OutgoingPaymentPacket.buildCommand(self, cfg.upstream, paymentHash, route.hops, c.finalPayload) match { - case Success((cmd, sharedSecrets)) => - register ! Register.ForwardShortId(self.toTyped[Register.ForwardShortIdFailure[CMD_ADD_HTLC]], route.hops.head.shortChannelId, cmd) + OutgoingPaymentPacket.buildCommand(self, nodeParams.privateKey, cfg.upstream, paymentHash, route.hops, c.finalPayload, c.amount, c.expiry) match { + case Success((channelId, cmd, sharedSecrets)) => + register ! Register.ForwardShortId(self.toTyped[Register.ForwardShortIdFailure[CMD_ADD_HTLC]], channelId, cmd) goto(WAITING_FOR_PAYMENT_COMPLETE) using WaitingForComplete(c, cmd, failures, sharedSecrets, ignore, route) case Failure(t) => log.warning("cannot send outgoing payment: {}", t.getMessage) - Metrics.PaymentError.withTag(Tags.Failure, Tags.FailureType(LocalFailure(c.finalPayload.amount, Nil, t))).increment() - myStop(c, Left(PaymentFailed(id, paymentHash, failures :+ LocalFailure(c.finalPayload.amount, Nil, t)))) + Metrics.PaymentError.withTag(Tags.Failure, Tags.FailureType(LocalFailure(c.amount, Nil, t))).increment() + myStop(c, Left(PaymentFailed(id, paymentHash, failures :+ LocalFailure(c.amount, Nil, t)))) } case Event(Status.Failure(t), WaitingForRoute(c, failures, _)) => log.warning("router error: {}", t.getMessage) - Metrics.PaymentError.withTag(Tags.Failure, Tags.FailureType(LocalFailure(c.finalPayload.amount, Nil, t))).increment() - myStop(c, Left(PaymentFailed(id, paymentHash, failures :+ LocalFailure(c.finalPayload.amount, Nil, t)))) + Metrics.PaymentError.withTag(Tags.Failure, Tags.FailureType(LocalFailure(c.amount, Nil, t))).increment() + myStop(c, Left(PaymentFailed(id, paymentHash, failures :+ LocalFailure(c.amount, Nil, t)))) } when(WAITING_FOR_PAYMENT_COMPLETE) { @@ -105,7 +105,7 @@ class PaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, router: A case Event(RES_ADD_SETTLED(_, htlc, fulfill: HtlcResult.Fulfill), d: WaitingForComplete) => router ! Router.RouteDidRelay(d.route) Metrics.PaymentAttempt.withTag(Tags.MultiPart, value = false).record(d.failures.size + 1) - val p = PartialPayment(id, d.c.finalPayload.amount, d.cmd.amount - d.c.finalPayload.amount, htlc.channelId, Some(cfg.fullRoute(d.route))) + val p = PartialPayment(id, d.c.amount, d.cmd.amount - d.c.amount, htlc.channelId, Some(cfg.fullRoute(d.route))) myStop(d.c, Right(cfg.createPaymentSent(fulfill.paymentPreimage, p :: Nil))) case Event(RES_ADD_SETTLED(_, _, fail: HtlcResult.Fail), d: WaitingForComplete) => @@ -139,7 +139,7 @@ class PaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, router: A data.c match { case sendPaymentToNode: SendPaymentToNode => val ignore1 = PaymentFailure.updateIgnored(failure, data.ignore) - router ! RouteRequest(nodeParams.nodeId, data.c.targetNodeId, data.c.finalPayload.amount, sendPaymentToNode.maxFee, data.c.extraEdges, ignore1, sendPaymentToNode.routeParams, paymentContext = Some(cfg.paymentContext)) + router ! RouteRequest(nodeParams.nodeId, data.c.targetNodeId, data.c.amount, sendPaymentToNode.maxFee, data.c.extraEdges, ignore1, sendPaymentToNode.routeParams, paymentContext = Some(cfg.paymentContext)) goto(WAITING_FOR_ROUTE) using WaitingForRoute(data.c, data.failures :+ failure, ignore1) case _: SendPaymentToRoute => log.error("unexpected retry during SendPaymentToRoute") @@ -153,11 +153,11 @@ class PaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, router: A private def handleLocalFail(d: WaitingForComplete, t: Throwable, isFatal: Boolean) = { t match { case UpdateMalformedException => Metrics.PaymentError.withTag(Tags.Failure, Tags.FailureType.Malformed).increment() - case _ => Metrics.PaymentError.withTag(Tags.Failure, Tags.FailureType(LocalFailure(d.c.finalPayload.amount, cfg.fullRoute(d.route), t))).increment() + case _ => Metrics.PaymentError.withTag(Tags.Failure, Tags.FailureType(LocalFailure(d.c.amount, cfg.fullRoute(d.route), t))).increment() } // we only retry if the error isn't fatal, and we haven't exhausted the max number of retried val doRetry = !isFatal && (d.failures.size + 1 < d.c.maxAttempts) - val localFailure = LocalFailure(d.c.finalPayload.amount, cfg.fullRoute(d.route), t) + val localFailure = LocalFailure(d.c.amount, cfg.fullRoute(d.route), t) if (doRetry) { log.info(s"received an error message from local, trying to use a different channel (failure=${t.getMessage})") retry(localFailure, d) @@ -170,10 +170,10 @@ class PaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, router: A import d._ ((Sphinx.FailurePacket.decrypt(fail.reason, sharedSecrets) match { case success@Success(e) => - Metrics.PaymentError.withTag(Tags.Failure, Tags.FailureType(RemoteFailure(d.c.finalPayload.amount, Nil, e))).increment() + Metrics.PaymentError.withTag(Tags.Failure, Tags.FailureType(RemoteFailure(d.c.amount, Nil, e))).increment() success case failure@Failure(_) => - Metrics.PaymentError.withTag(Tags.Failure, Tags.FailureType(UnreadableRemoteFailure(d.c.finalPayload.amount, Nil))).increment() + Metrics.PaymentError.withTag(Tags.Failure, Tags.FailureType(UnreadableRemoteFailure(d.c.amount, Nil))).increment() failure }) match { case res@Success(Sphinx.DecryptedFailurePacket(nodeId, failureMessage)) => @@ -185,7 +185,7 @@ class PaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, router: A failureMessage match { case TemporaryChannelFailure(update) => d.route.hops.find(_.nodeId == nodeId) match { - case Some(failingHop) if ChannelRelayParams.areSame(failingHop.params, ChannelRelayParams.FromAnnouncement(update), ignoreHtlcSize = true) => + case Some(failingHop: ChannelHop) if ChannelRelayParams.areSame(failingHop.params, ChannelRelayParams.FromAnnouncement(update), ignoreHtlcSize = true) => router ! Router.ChannelCouldNotRelay(stoppedRoute.amount, failingHop) case _ => // otherwise the relay parameters may have changed, so it's not necessarily a liquidity issue } @@ -197,7 +197,7 @@ class PaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, router: A case Success(e@Sphinx.DecryptedFailurePacket(nodeId, failureMessage)) if nodeId == c.targetNodeId => // if destination node returns an error, we fail the payment immediately log.warning(s"received an error message from target nodeId=$nodeId, failing the payment (failure=$failureMessage)") - myStop(c, Left(PaymentFailed(id, paymentHash, failures :+ RemoteFailure(d.c.finalPayload.amount, cfg.fullRoute(route), e)))) + myStop(c, Left(PaymentFailed(id, paymentHash, failures :+ RemoteFailure(d.c.amount, cfg.fullRoute(route), e)))) case res if failures.size + 1 >= c.maxAttempts => // otherwise we never try more than maxAttempts, no matter the kind of error returned val failure = res match { @@ -207,24 +207,24 @@ class PaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, router: A case failureMessage: Update => handleUpdate(nodeId, failureMessage, d) case _ => } - RemoteFailure(d.c.finalPayload.amount, cfg.fullRoute(route), e) + RemoteFailure(d.c.amount, cfg.fullRoute(route), e) case Failure(t) => log.warning(s"cannot parse returned error ${fail.reason.toHex} with sharedSecrets=$sharedSecrets: ${t.getMessage}") - UnreadableRemoteFailure(d.c.finalPayload.amount, cfg.fullRoute(route)) + UnreadableRemoteFailure(d.c.amount, cfg.fullRoute(route)) } log.warning(s"too many failed attempts, failing the payment") myStop(c, Left(PaymentFailed(id, paymentHash, failures :+ failure))) case Failure(t) => log.warning(s"cannot parse returned error: ${t.getMessage}, route=${route.printNodes()}") - val failure = UnreadableRemoteFailure(d.c.finalPayload.amount, cfg.fullRoute(route)) + val failure = UnreadableRemoteFailure(d.c.amount, cfg.fullRoute(route)) retry(failure, d) case Success(e@Sphinx.DecryptedFailurePacket(nodeId, failureMessage: Node)) => log.info(s"received 'Node' type error message from nodeId=$nodeId, trying to route around it (failure=$failureMessage)") - val failure = RemoteFailure(d.c.finalPayload.amount, cfg.fullRoute(route), e) + val failure = RemoteFailure(d.c.amount, cfg.fullRoute(route), e) retry(failure, d) case Success(e@Sphinx.DecryptedFailurePacket(nodeId, failureMessage: Update)) => log.info(s"received 'Update' type error message from nodeId=$nodeId, retrying payment (failure=$failureMessage)") - val failure = RemoteFailure(d.c.finalPayload.amount, cfg.fullRoute(route), e) + val failure = RemoteFailure(d.c.amount, cfg.fullRoute(route), e) if (Announcements.checkSig(failureMessage.update, nodeId)) { val extraEdges1 = handleUpdate(nodeId, failureMessage, d) val ignore1 = PaymentFailure.updateIgnored(failure, ignore) @@ -234,7 +234,7 @@ class PaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, router: A log.error("unexpected retry during SendPaymentToRoute") stop(FSM.Normal) case c: SendPaymentToNode => - router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.finalPayload.amount, c.maxFee, extraEdges1, ignore1, c.routeParams, paymentContext = Some(cfg.paymentContext)) + router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.amount, c.maxFee, extraEdges1, ignore1, c.routeParams, paymentContext = Some(cfg.paymentContext)) goto(WAITING_FOR_ROUTE) using WaitingForRoute(c, failures :+ failure, ignore1) } } else { @@ -245,13 +245,13 @@ class PaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, router: A log.error("unexpected retry during SendPaymentToRoute") stop(FSM.Normal) case c: SendPaymentToNode => - router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.finalPayload.amount, c.maxFee, c.extraEdges, ignore + nodeId, c.routeParams, paymentContext = Some(cfg.paymentContext)) + router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.amount, c.maxFee, c.extraEdges, ignore + nodeId, c.routeParams, paymentContext = Some(cfg.paymentContext)) goto(WAITING_FOR_ROUTE) using WaitingForRoute(c, failures :+ failure, ignore + nodeId) } } case Success(e@Sphinx.DecryptedFailurePacket(nodeId, failureMessage)) => log.info(s"received an error message from nodeId=$nodeId, trying to use a different channel (failure=$failureMessage)") - val failure = RemoteFailure(d.c.finalPayload.amount, cfg.fullRoute(route), e) + val failure = RemoteFailure(d.c.amount, cfg.fullRoute(route), e) retry(failure, d) } } @@ -263,7 +263,7 @@ class PaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, router: A */ private def handleUpdate(nodeId: PublicKey, failure: Update, data: WaitingForComplete): Seq[ExtraEdge] = { val extraEdges1 = data.route.hops.find(_.nodeId == nodeId) match { - case Some(hop) => hop.params match { + case Some(hop: ChannelHop) => hop.params match { case ann: ChannelRelayParams.FromAnnouncement => if (ann.channelUpdate.shortChannelId != failure.update.shortChannelId) { // it is possible that nodes in the route prefer using a different channel (to the same N+1 node) than the one we requested, that's fine @@ -286,19 +286,22 @@ class PaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, router: A if (failure.update.channelFlags.isEnabled) { data.c.extraEdges.map { case edge: BasicEdge if edge.sourceNodeId == nodeId && edge.targetNodeId == hop.nextNodeId => edge.update(failure.update) - case edge: BasicEdge => edge + case edge => edge } } else { // if the channel is disabled, we temporarily exclude it: this is necessary because the routing hint doesn't // contain channel flags to indicate that it's disabled // we want the exclusion to be router-wide so that sister payments in the case of MPP are aware the channel is faulty data.c.extraEdges - .find { case edge: BasicEdge => edge.sourceNodeId == nodeId && edge.targetNodeId == hop.nextNodeId } - .foreach { case edge: BasicEdge => router ! ExcludeChannel(ChannelDesc(edge.shortChannelId, edge.sourceNodeId, edge.targetNodeId), Some(nodeParams.routerConf.channelExcludeDuration)) } + .find(edge => edge.sourceNodeId == nodeId && edge.targetNodeId == hop.nextNodeId) + .foreach(edge => router ! ExcludeChannel(ChannelDesc(edge.shortChannelId, edge.sourceNodeId, edge.targetNodeId), Some(nodeParams.routerConf.channelExcludeDuration))) // we remove this edge for our next payment attempt - data.c.extraEdges.filterNot { case edge: BasicEdge => edge.sourceNodeId == nodeId && edge.targetNodeId == hop.nextNodeId } + data.c.extraEdges.filterNot(edge => edge.sourceNodeId == nodeId && edge.targetNodeId == hop.nextNodeId) } } + case Some(_: BlindedHop) => + log.error(s"received update for blinded route, this should never happen") + data.c.extraEdges case None => log.error(s"couldn't find node=$nodeId in the route, this should never happen") data.c.extraEdges @@ -359,7 +362,7 @@ class PaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, router: A } request match { case request: SendPaymentToNode => - context.system.eventStream.publish(PathFindingExperimentMetrics(cfg.paymentHash, request.finalPayload.amount, fees, status, duration, now, isMultiPart = false, request.routeParams.experimentName, cfg.recipientNodeId, request.extraEdges)) + context.system.eventStream.publish(PathFindingExperimentMetrics(cfg.paymentHash, request.amount, fees, status, duration, now, isMultiPart = false, request.routeParams.experimentName, cfg.recipientNodeId, request.extraEdges)) case _: SendPaymentToRoute => () } } @@ -390,7 +393,9 @@ object PaymentLifecycle { sealed trait SendPayment { // @formatter:off def replyTo: ActorRef - def finalPayload: FinalPayload + def finalPayload: PerHopPayload + def amount: MilliSatoshi + def expiry: CltvExpiry def extraEdges: Seq[ExtraEdge] def targetNodeId: PublicKey def maxAttempts: Int @@ -405,7 +410,9 @@ object PaymentLifecycle { */ case class SendPaymentToRoute(replyTo: ActorRef, route: Either[PredefinedRoute, Route], - finalPayload: FinalPayload, + finalPayload: PerHopPayload, + amount: MilliSatoshi, + expiry: CltvExpiry, extraEdges: Seq[ExtraEdge] = Nil) extends SendPayment { require(route.fold(!_.isEmpty, _.hops.nonEmpty), "payment route must not be empty") @@ -432,13 +439,15 @@ object PaymentLifecycle { */ case class SendPaymentToNode(replyTo: ActorRef, targetNodeId: PublicKey, - finalPayload: FinalPayload, + finalPayload: PerHopPayload, + amount: MilliSatoshi, + expiry: CltvExpiry, maxAttempts: Int, extraEdges: Seq[ExtraEdge] = Nil, routeParams: RouteParams) extends SendPayment { - require(finalPayload.amount > 0.msat, s"amount must be > 0") + require(amount > 0.msat, s"amount must be > 0") - val maxFee: MilliSatoshi = routeParams.getMaxFee(finalPayload.amount) + val maxFee: MilliSatoshi = routeParams.getMaxFee(amount) } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/BalanceEstimate.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/BalanceEstimate.scala index 18ea2a810d..46982b4bd1 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/router/BalanceEstimate.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/BalanceEstimate.scala @@ -295,16 +295,18 @@ case class GraphWithBalanceEstimates(graph: DirectedGraph, private val balances: def routeCouldRelay(route: Route): GraphWithBalanceEstimates = { val (balances1, _) = route.hops.foldRight((balances, route.amount)) { - case (hop, (balances, amount)) => + case (hop: ChannelHop, (balances, amount)) => (balances.channelCouldSend(hop, amount), amount + hop.fee(amount)) + case (_, x) => x } GraphWithBalanceEstimates(graph, balances1) } def routeDidRelay(route: Route): GraphWithBalanceEstimates = { val (balances1, _) = route.hops.foldRight((balances, route.amount)) { - case (hop, (balances, amount)) => + case (hop: ChannelHop, (balances, amount)) => (balances.channelDidSend(hop, amount), amount + hop.fee(amount)) + case (_, x) => x } GraphWithBalanceEstimates(graph, balances1) } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/Graph.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/Graph.scala index 554a25dbea..d03cba33db 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/router/Graph.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/Graph.scala @@ -18,8 +18,8 @@ package fr.acinq.eclair.router import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.bitcoin.scalacompat.{Btc, BtcDouble, MilliBtc, Satoshi} -import fr.acinq.eclair.payment.relay.Relayer.RelayFees import fr.acinq.eclair.payment.Invoice +import fr.acinq.eclair.payment.relay.Relayer.RelayFees import fr.acinq.eclair.router.Graph.GraphStructure.{DirectedGraph, GraphEdge} import fr.acinq.eclair.router.Router._ import fr.acinq.eclair.wire.protocol.ChannelUpdate @@ -459,17 +459,16 @@ object Graph { balance_opt = pc.getBalanceSameSideAs(u) ) - def apply(e: Invoice.ExtraEdge): GraphEdge = e match { - case e@Invoice.BasicEdge(sourceNodeId, targetNodeId, shortChannelId, _, _, _) => - val maxBtc = 21e6.btc - GraphEdge( - desc = ChannelDesc(shortChannelId, sourceNodeId, targetNodeId), - params = ChannelRelayParams.FromHint(e), - // Bolt 11 routing hints don't include the channel's capacity, so we assume it's big enough - capacity = maxBtc.toSatoshi, - // we assume channels provided as hints have enough balance to handle the payment - balance_opt = Some(maxBtc.toMilliSatoshi) - ) + def apply(e: Invoice.ExtraEdge): GraphEdge = { + val maxBtc = 21e6.btc + GraphEdge( + desc = ChannelDesc(e.shortChannelId, e.sourceNodeId, e.targetNodeId), + params = ChannelRelayParams.FromHint(e), + // Bolt 11 routing hints don't include the channel's capacity, so we assume it's big enough + capacity = maxBtc.toSatoshi, + // we assume channels provided as hints have enough balance to handle the payment + balance_opt = Some(maxBtc.toMilliSatoshi) + ) } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/RouteCalculation.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/RouteCalculation.scala index b15763634b..6e358f7daa 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/router/RouteCalculation.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/RouteCalculation.scala @@ -368,9 +368,11 @@ object RouteCalculation { /** Update used capacity by taking into account an HTLC sent to the given route. */ private def updateUsedCapacity(route: Route, usedCapacity: mutable.Map[ShortChannelId, MilliSatoshi]): Unit = { - route.hops.reverse.foldLeft(route.amount) { case (amount, hop) => - usedCapacity.updateWith(hop.shortChannelId)(previous => Some(amount + previous.getOrElse(0 msat))) - amount + hop.fee(amount) + route.hops.reverse.foldLeft(route.amount) { + case (amount, hop: ChannelHop) => + usedCapacity.updateWith(hop.shortChannelId)(previous => Some(amount + previous.getOrElse(0 msat))) + amount + hop.fee(amount) + case (amount, hop) => amount + hop.fee(amount) } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala index 4d35bd3f5f..314d65cc19 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala @@ -28,6 +28,7 @@ import fr.acinq.eclair._ import fr.acinq.eclair.blockchain.bitcoind.ZmqWatcher import fr.acinq.eclair.blockchain.bitcoind.ZmqWatcher.{ValidateResult, WatchExternalChannelSpent, WatchExternalChannelSpentTriggered} import fr.acinq.eclair.channel._ +import fr.acinq.eclair.crypto.Sphinx.RouteBlinding.BlindedRoute import fr.acinq.eclair.crypto.TransportHandler import fr.acinq.eclair.db.NetworkDb import fr.acinq.eclair.io.Peer.PeerRoutingMessage @@ -38,6 +39,7 @@ import fr.acinq.eclair.remote.EclairInternalsSerializer.RemoteTypes import fr.acinq.eclair.router.Graph.GraphStructure.DirectedGraph import fr.acinq.eclair.router.Graph.{HeuristicsConstants, WeightRatios} import fr.acinq.eclair.router.Monitoring.Metrics +import fr.acinq.eclair.wire.protocol.OfferTypes.PaymentInfo import fr.acinq.eclair.wire.protocol._ import java.util.UUID @@ -405,9 +407,8 @@ object Router { } } } - // @formatter:on - trait Hop { + sealed trait Hop { /** @return the id of the start node. */ def nodeId: PublicKey @@ -424,7 +425,10 @@ object Router { def cltvExpiryDelta: CltvExpiryDelta } - // @formatter:off + sealed trait ConnectedHop extends Hop { + def length: Int + } + /** Channel routing parameters */ sealed trait ChannelRelayParams { def cltvExpiryDelta: CltvExpiryDelta @@ -465,11 +469,17 @@ object Router { * @param shortChannelId scid that will be used to build the payment onion. * @param params source for the channel parameters. */ - case class ChannelHop(shortChannelId: ShortChannelId, nodeId: PublicKey, nextNodeId: PublicKey, params: ChannelRelayParams) extends Hop { - // @formatter:off - override def cltvExpiryDelta: CltvExpiryDelta = params.cltvExpiryDelta + case class ChannelHop(shortChannelId: ShortChannelId, nodeId: PublicKey, nextNodeId: PublicKey, params: ChannelRelayParams) extends ConnectedHop { + override val cltvExpiryDelta: CltvExpiryDelta = params.cltvExpiryDelta override def fee(amount: MilliSatoshi): MilliSatoshi = params.fee(amount) - // @formatter:on + override val length = 1 + } + + case class BlindedHop(route: BlindedRoute, paymentInfo: PaymentInfo, nextNodeId: PublicKey) extends ConnectedHop { + override def nodeId: PublicKey = route.introductionNodeId + override def cltvExpiryDelta: CltvExpiryDelta = paymentInfo.cltvExpiryDelta + override def length: Int = route.blindedNodes.length - 1 + override def fee(amount: MilliSatoshi): MilliSatoshi = paymentInfo.fee(amount) } /** @@ -535,10 +545,10 @@ object Router { */ case class PaymentContext(id: UUID, parentId: UUID, paymentHash: ByteVector32) - case class Route(amount: MilliSatoshi, hops: Seq[ChannelHop]) { + case class Route(amount: MilliSatoshi, hops: Seq[ConnectedHop]) { require(hops.nonEmpty, "route cannot be empty") - val length = hops.length + val length: Int = hops.map(_.length).sum def fee(includeLocalChannelCost: Boolean): MilliSatoshi = { val hopsToPay = if (includeLocalChannelCost) hops else hops.drop(1) @@ -548,7 +558,10 @@ object Router { def printNodes(): String = hops.map(_.nextNodeId).mkString("->") - def printChannels(): String = hops.map(_.shortChannelId).mkString("->") + def printChannels(): String = hops.map { + case hop: ChannelHop => hop.shortChannelId.toString + case _: BlindedHop => "blinded" + }.mkString("->") def stopAt(nodeId: PublicKey): Route = { val amountAtStop = hops.reverse.takeWhile(_.nextNodeId != nodeId).foldLeft(amount) { case (amount1, hop) => amount1 + hop.fee(amount1) } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/OfferTypes.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/OfferTypes.scala index 61c3c047eb..abb284cb18 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/OfferTypes.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/OfferTypes.scala @@ -22,7 +22,7 @@ import fr.acinq.bitcoin.scalacompat.{Block, ByteVector32, ByteVector64, Crypto, import fr.acinq.eclair.crypto.Sphinx.RouteBlinding.BlindedRoute import fr.acinq.eclair.wire.protocol.OnionRoutingCodecs.{InvalidTlvPayload, MissingRequiredTlv} import fr.acinq.eclair.wire.protocol.TlvCodecs.genericTlv -import fr.acinq.eclair.{CltvExpiryDelta, Feature, Features, InvoiceFeature, MilliSatoshi, TimestampSecond, UInt64} +import fr.acinq.eclair.{CltvExpiryDelta, Feature, Features, InvoiceFeature, MilliSatoshi, TimestampSecond, UInt64, nodeFee} import fr.acinq.secp256k1.Secp256k1JvmKt import scodec.Codec import scodec.bits.ByteVector @@ -65,7 +65,9 @@ object OfferTypes { cltvExpiryDelta: CltvExpiryDelta, minHtlc: MilliSatoshi, maxHtlc: MilliSatoshi, - allowedFeatures: Features[Feature]) + allowedFeatures: Features[Feature]) { + def fee(amount: MilliSatoshi): MilliSatoshi = nodeFee(feeBase, feeProportionalMillionths, amount) + } case class PaymentPathsInfo(paymentInfo: Seq[PaymentInfo]) extends InvoiceTlv diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/PaymentOnion.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/PaymentOnion.scala index 4ea07846c6..8eb6a29de7 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/PaymentOnion.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/PaymentOnion.scala @@ -22,7 +22,7 @@ import fr.acinq.eclair.payment.Bolt11Invoice import fr.acinq.eclair.wire.protocol.CommonCodecs._ import fr.acinq.eclair.wire.protocol.OnionRoutingCodecs.{ForbiddenTlv, InvalidTlvPayload, MissingRequiredTlv} import fr.acinq.eclair.wire.protocol.TlvCodecs._ -import fr.acinq.eclair.{CltvExpiry, Features, MilliSatoshi, MilliSatoshiLong, ShortChannelId, UInt64} +import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Features, MilliSatoshi, MilliSatoshiLong, ShortChannelId, UInt64} import scodec.bits.{BitVector, ByteVector} /** @@ -217,6 +217,9 @@ object PaymentOnion { def records: TlvStream[OnionPaymentPayloadTlv] } + /** An opaque blinded payload. */ + case class BlindedPerHopPayload(records: TlvStream[OnionPaymentPayloadTlv]) extends PerHopPayload + /** Per-hop payload for an intermediate node. */ sealed trait IntermediatePayload extends PerHopPayload @@ -264,7 +267,8 @@ object PaymentOnion { val paymentConstraints = blindedRecords.get[RouteBlindingEncryptedDataTlv.PaymentConstraints].get val allowedFeatures = blindedRecords.get[RouteBlindingEncryptedDataTlv.AllowedFeatures].map(_.features).getOrElse(Features.empty) override def amountToForward(incomingAmount: MilliSatoshi): MilliSatoshi = ((incomingAmount - paymentRelay.feeBase).toLong * 1_000_000 + 1_000_000 + paymentRelay.feeProportionalMillionths - 1).msat / (1_000_000 + paymentRelay.feeProportionalMillionths) - override def outgoingCltv(incomingCltv: CltvExpiry): CltvExpiry = incomingCltv - paymentRelay.cltvExpiryDelta + val cltvExpiryDelta: CltvExpiryDelta = paymentRelay.cltvExpiryDelta + override def outgoingCltv(incomingCltv: CltvExpiry): CltvExpiry = incomingCltv - cltvExpiryDelta // @formatter:on } @@ -350,18 +354,23 @@ object PaymentOnion { // @formatter:off def amount: MilliSatoshi def totalAmount: MilliSatoshi - def expiry: CltvExpiry // @formatter:on } object FinalPayload { + /** An incomplete payload missing the amount used for multipart payments before we know the size of each part. */ + sealed trait Partial { + def records: TlvStream[OnionPaymentPayloadTlv] + def withAmount(amount: MilliSatoshi): PerHopPayload + } + case class Standard(records: TlvStream[OnionPaymentPayloadTlv]) extends FinalPayload { override val amount = records.get[AmountToForward].get.amount override val totalAmount = records.get[PaymentData].map(_.totalAmount match { case MilliSatoshi(0) => amount case totalAmount => totalAmount }).getOrElse(amount) - override val expiry = records.get[OutgoingCltv].get.cltv + val expiry = records.get[OutgoingCltv].get.cltv val paymentSecret = records.get[PaymentData].get.secret val paymentPreimage = records.get[KeySend].map(_.paymentPreimage) val paymentMetadata = records.get[PaymentMetadata].map(_.data) @@ -385,20 +394,35 @@ object PaymentOnion { Standard(TlvStream(tlvs, userCustomTlvs)) } - def createMultiPartPayload(amount: MilliSatoshi, totalAmount: MilliSatoshi, expiry: CltvExpiry, paymentSecret: ByteVector32, paymentMetadata: Option[ByteVector], additionalTlvs: Seq[OnionPaymentPayloadTlv] = Nil, userCustomTlvs: Seq[GenericTlv] = Nil): Standard = { + case class Partial(records: TlvStream[OnionPaymentPayloadTlv]) extends FinalPayload.Partial { + override def withAmount(amount: MilliSatoshi): Standard = Standard(records.copy(records = AmountToForward(amount) +: records.records.toSeq)) + } + + def createMultiPartPayload(totalAmount: MilliSatoshi, expiry: CltvExpiry, paymentSecret: ByteVector32, paymentMetadata: Option[ByteVector], additionalTlvs: Seq[OnionPaymentPayloadTlv] = Nil, userCustomTlvs: Seq[GenericTlv] = Nil): Partial = { val tlvs = Seq( - Some(AmountToForward(amount)), Some(OutgoingCltv(expiry)), Some(PaymentData(paymentSecret, totalAmount)), paymentMetadata.map(m => PaymentMetadata(m)) ).flatten - Standard(TlvStream(tlvs ++ additionalTlvs, userCustomTlvs)) + Partial(TlvStream(tlvs ++ additionalTlvs, userCustomTlvs)) } + def createMultiPartPayload(amount: MilliSatoshi, totalAmount: MilliSatoshi, expiry: CltvExpiry, paymentSecret: ByteVector32, paymentMetadata: Option[ByteVector]): Standard = + createMultiPartPayload(totalAmount, expiry, paymentSecret, paymentMetadata).withAmount(amount) + /** Create a trampoline outer payload. */ - def createTrampolinePayload(amount: MilliSatoshi, totalAmount: MilliSatoshi, expiry: CltvExpiry, paymentSecret: ByteVector32, trampolinePacket: OnionRoutingPacket): Standard = { - Standard(TlvStream(AmountToForward(amount), OutgoingCltv(expiry), PaymentData(paymentSecret, totalAmount), TrampolineOnion(trampolinePacket))) + def createTrampolinePayload(totalAmount: MilliSatoshi, expiry: CltvExpiry, paymentSecret: ByteVector32, trampolinePacket: OnionRoutingPacket, paymentMetadata: Option[ByteVector]): Partial = { + val tlvs = Seq( + Some(OutgoingCltv(expiry)), + Some(PaymentData(paymentSecret, totalAmount)), + paymentMetadata.map(m => PaymentMetadata(m)), + Some(TrampolineOnion(trampolinePacket)), + ).flatten + Partial(TlvStream(tlvs)) } + + def createTrampolinePayload(amount: MilliSatoshi, totalAmount: MilliSatoshi, expiry: CltvExpiry, paymentSecret: ByteVector32, trampolinePacket: OnionRoutingPacket): Standard = + createTrampolinePayload(totalAmount, expiry, paymentSecret, trampolinePacket, None).withAmount(amount) } /** @@ -407,7 +431,6 @@ object PaymentOnion { case class Blinded(records: TlvStream[OnionPaymentPayloadTlv], blindedRecords: TlvStream[RouteBlindingEncryptedDataTlv]) extends FinalPayload { override val amount = records.get[AmountToForward].get.amount override val totalAmount = records.get[TotalAmount].map(_.totalAmount).getOrElse(amount) - override val expiry = records.get[OutgoingCltv].get.cltv val pathId = blindedRecords.get[RouteBlindingEncryptedDataTlv.PathId].get.data val paymentConstraints = blindedRecords.get[RouteBlindingEncryptedDataTlv.PaymentConstraints].get val allowedFeatures = blindedRecords.get[RouteBlindingEncryptedDataTlv.AllowedFeatures].map(_.features).getOrElse(Features.empty) @@ -416,13 +439,11 @@ object PaymentOnion { object Blinded { def validate(records: TlvStream[OnionPaymentPayloadTlv], blindedRecords: TlvStream[RouteBlindingEncryptedDataTlv]): Either[InvalidTlvPayload, Blinded] = { if (records.get[AmountToForward].isEmpty) return Left(MissingRequiredTlv(UInt64(2))) - if (records.get[OutgoingCltv].isEmpty) return Left(MissingRequiredTlv(UInt64(4))) if (records.get[EncryptedRecipientData].isEmpty) return Left(MissingRequiredTlv(UInt64(10))) // Bolt 4: MUST return an error if the payload contains other tlv fields than `encrypted_recipient_data`, `current_blinding_point`, `amt_to_forward`, `outgoing_cltv_value` and `total_amount_msat`. if (records.unknown.nonEmpty) return Left(ForbiddenTlv(records.unknown.head.tag)) records.records.find { case _: AmountToForward => false - case _: OutgoingCltv => false case _: EncryptedRecipientData => false case _: BlindingPoint => false case _: TotalAmount => false @@ -433,9 +454,32 @@ object PaymentOnion { } BlindedRouteData.validPaymentRecipientData(blindedRecords).map(blindedRecords => Blinded(records, blindedRecords)) } + + def createSinglePartPayload(amount: MilliSatoshi, userCustomTlvs: Seq[GenericTlv] = Nil): BlindedPerHopPayload = { + val tlvs = Seq( + AmountToForward(amount), + TotalAmount(amount), + ) + BlindedPerHopPayload(TlvStream(tlvs, userCustomTlvs)) + } + + case class Partial(records: TlvStream[OnionPaymentPayloadTlv]) extends FinalPayload.Partial { + override def withAmount(amount: MilliSatoshi): BlindedPerHopPayload = BlindedPerHopPayload(records.copy(records = AmountToForward(amount) +: records.records.toSeq)) + } + + def createMultiPartPayload(totalAmount: MilliSatoshi, userCustomTlvs: Seq[GenericTlv] = Nil): Partial = { + val tlvs = Seq( + TotalAmount(totalAmount), + ) + Partial(TlvStream(tlvs, userCustomTlvs)) + } + + def createTrampolinePayload(totalAmount: MilliSatoshi, trampolinePacket: OnionRoutingPacket): Partial = { + // Trampoline is not compatible with blinded payloads yet. + Partial(TlvStream()) + } } } - } object PaymentOnionCodecs { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/RouteBlinding.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/RouteBlinding.scala index cffd7459f9..fa6aa5e102 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/RouteBlinding.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/RouteBlinding.scala @@ -135,9 +135,9 @@ object RouteBlindingEncryptedDataCodecs { // @formatter:off case class RouteBlindingDecryptedData(tlvs: TlvStream[RouteBlindingEncryptedDataTlv], nextBlinding: PublicKey) - sealed trait InvalidEncryptedData - case class CannotDecryptData(message: String) extends InvalidEncryptedData - case class CannotDecodeData(message: String) extends InvalidEncryptedData + sealed trait InvalidEncryptedData extends Exception + case class CannotDecryptData(message: String) extends Exception(message) with InvalidEncryptedData + case class CannotDecodeData(message: String) extends Exception(message) with InvalidEncryptedData // @formatter:on /** diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/channel/FuzzySpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/channel/FuzzySpec.scala index 6d9ad568f5..5ac54683f6 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/channel/FuzzySpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/channel/FuzzySpec.scala @@ -34,7 +34,7 @@ import fr.acinq.eclair.payment._ import fr.acinq.eclair.payment.receive.MultiPartHandler.ReceiveStandardPayment import fr.acinq.eclair.payment.receive.PaymentHandler import fr.acinq.eclair.payment.relay.Relayer -import fr.acinq.eclair.router.Router.ChannelHop +import fr.acinq.eclair.router.Router.{ChannelHop, ChannelRelayParams} import fr.acinq.eclair.wire.protocol._ import grizzled.slf4j.Logging import org.scalatest.funsuite.FixtureAnyFunSuiteLike @@ -118,11 +118,11 @@ class FuzzySpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with Channe // we don't want to be below htlcMinimumMsat val requiredAmount = 1000000 msat - def buildCmdAdd(paymentHash: ByteVector32, dest: PublicKey, paymentSecret: ByteVector32): CMD_ADD_HTLC = { + def buildCmdAdd(paymentHash: ByteVector32, dest: PublicKey, invoice: Invoice): CMD_ADD_HTLC = { // allow overpaying (no more than 2 times the required amount) val amount = requiredAmount + Random.nextInt(requiredAmount.toLong.toInt).msat val expiry = (Channel.MIN_CLTV_EXPIRY_DELTA + 1).toCltvExpiry(currentBlockHeight = BlockHeight(400000)) - OutgoingPaymentPacket.buildCommand(self, Upstream.Local(UUID.randomUUID()), paymentHash, ChannelHop(null, null, dest, null) :: Nil, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(amount, expiry, paymentSecret, None)).get._1 + OutgoingPaymentPacket.buildCommand(self, randomKey(), Upstream.Local(UUID.randomUUID()), paymentHash, ChannelHop(null, null, dest, ChannelRelayParams.FromHint(Invoice.BasicEdge(null, dest, null, 1 msat, 2, CltvExpiryDelta(3)))) :: Nil, invoice.singlePartFinalPayload(amount, expiry), amount, expiry).get._2 } def initiatePaymentOrStop(remaining: Int): Unit = @@ -130,7 +130,7 @@ class FuzzySpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with Channe paymentHandler ! ReceiveStandardPayment(Some(requiredAmount), Left("One coffee")) context become { case req: Invoice => - sendChannel ! buildCmdAdd(req.paymentHash, req.nodeId, req.paymentSecret) + sendChannel ! buildCmdAdd(req.paymentHash, req.nodeId, req) context become { case RES_SUCCESS(_: CMD_ADD_HTLC, _) => () case RES_ADD_SETTLED(_, htlc, _: HtlcResult.Fulfill) => diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/ChannelStateTestsHelperMethods.scala b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/ChannelStateTestsHelperMethods.scala index 5db972bd54..b8ed6efbfc 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/ChannelStateTestsHelperMethods.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/ChannelStateTestsHelperMethods.scala @@ -33,9 +33,9 @@ import fr.acinq.eclair.channel.fsm.Channel import fr.acinq.eclair.channel.publish.TxPublisher import fr.acinq.eclair.channel.publish.TxPublisher.PublishReplaceableTx import fr.acinq.eclair.channel.states.ChannelStateTestsBase.FakeTxPublisherFactory -import fr.acinq.eclair.payment.OutgoingPaymentPacket +import fr.acinq.eclair.payment.{Invoice, OutgoingPaymentPacket} import fr.acinq.eclair.payment.OutgoingPaymentPacket.Upstream -import fr.acinq.eclair.router.Router.ChannelHop +import fr.acinq.eclair.router.Router.{ChannelHop, ChannelRelayParams} import fr.acinq.eclair.transactions.Transactions import fr.acinq.eclair.transactions.Transactions._ import fr.acinq.eclair.wire.protocol._ @@ -357,7 +357,7 @@ trait ChannelStateTestsBase extends Assertions with Eventually { def makeCmdAdd(amount: MilliSatoshi, cltvExpiryDelta: CltvExpiryDelta, destination: PublicKey, paymentPreimage: ByteVector32, currentBlockHeight: BlockHeight, upstream: Upstream, replyTo: ActorRef = TestProbe().ref): (ByteVector32, CMD_ADD_HTLC) = { val paymentHash: ByteVector32 = Crypto.sha256(paymentPreimage) val expiry = cltvExpiryDelta.toCltvExpiry(currentBlockHeight) - val cmd = OutgoingPaymentPacket.buildCommand(replyTo, upstream, paymentHash, ChannelHop(null, null, destination, null) :: Nil, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(amount, expiry, randomBytes32(), None)).get._1.copy(commit = false) + val cmd = OutgoingPaymentPacket.buildCommand(replyTo, randomKey(), upstream, paymentHash, ChannelHop(null, null, destination, ChannelRelayParams.FromHint(Invoice.BasicEdge(null, destination, null, 0 msat, 111, CltvExpiryDelta(55)))) :: Nil, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(amount, expiry, randomBytes32(), None), amount, expiry).get._2.copy(commit = false) (paymentPreimage, cmd) } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/f/ShutdownStateSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/f/ShutdownStateSpec.scala index 4f79a07e74..b88aa2b64f 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/f/ShutdownStateSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/f/ShutdownStateSpec.scala @@ -29,9 +29,9 @@ import fr.acinq.eclair.channel.states.{ChannelStateTestsBase, ChannelStateTestsT import fr.acinq.eclair.payment.OutgoingPaymentPacket.Upstream import fr.acinq.eclair.payment._ import fr.acinq.eclair.payment.relay.Relayer._ -import fr.acinq.eclair.router.Router.ChannelHop +import fr.acinq.eclair.router.Router.{ChannelHop, ChannelRelayParams} import fr.acinq.eclair.wire.protocol.{ClosingSigned, CommitSig, Error, FailureMessageCodecs, PaymentOnion, PermanentChannelFailure, RevokeAndAck, Shutdown, UpdateAddHtlc, UpdateFailHtlc, UpdateFailMalformedHtlc, UpdateFee, UpdateFulfillHtlc} -import fr.acinq.eclair.{BlockHeight, CltvExpiry, CltvExpiryDelta, MilliSatoshiLong, TestConstants, TestKitBaseClass, randomBytes32} +import fr.acinq.eclair.{BlockHeight, CltvExpiry, CltvExpiryDelta, MilliSatoshiLong, TestConstants, TestKitBaseClass, randomBytes32, randomKey} import org.scalatest.funsuite.FixtureAnyFunSuiteLike import org.scalatest.{Outcome, Tag} import scodec.bits.ByteVector @@ -60,7 +60,7 @@ class ShutdownStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike wit val h1 = Crypto.sha256(r1) val amount1 = 300000000 msat val expiry1 = CltvExpiryDelta(144).toCltvExpiry(currentBlockHeight) - val cmd1 = OutgoingPaymentPacket.buildCommand(sender.ref, Upstream.Local(UUID.randomUUID), h1, ChannelHop(null, null, TestConstants.Bob.nodeParams.nodeId, null) :: Nil, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(amount1, expiry1, randomBytes32(), None)).get._1.copy(commit = false) + val cmd1 = OutgoingPaymentPacket.buildCommand(sender.ref, randomKey(), Upstream.Local(UUID.randomUUID), h1, ChannelHop(null, null, TestConstants.Bob.nodeParams.nodeId, ChannelRelayParams.FromHint(Invoice.BasicEdge(null, TestConstants.Bob.nodeParams.nodeId, null, 1 msat, 2, CltvExpiryDelta(3)))) :: Nil, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(amount1, expiry1, randomBytes32(), None), amount1, expiry1).get._2.copy(commit = false) alice ! cmd1 sender.expectMsgType[RES_SUCCESS[CMD_ADD_HTLC]] val htlc1 = alice2bob.expectMsgType[UpdateAddHtlc] @@ -70,7 +70,7 @@ class ShutdownStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike wit val h2 = Crypto.sha256(r2) val amount2 = 200000000 msat val expiry2 = CltvExpiryDelta(144).toCltvExpiry(currentBlockHeight) - val cmd2 = OutgoingPaymentPacket.buildCommand(sender.ref, Upstream.Local(UUID.randomUUID), h2, ChannelHop(null, null, TestConstants.Bob.nodeParams.nodeId, null) :: Nil, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(amount2, expiry2, randomBytes32(), None)).get._1.copy(commit = false) + val cmd2 = OutgoingPaymentPacket.buildCommand(sender.ref, randomKey(), Upstream.Local(UUID.randomUUID), h2, ChannelHop(null, null, TestConstants.Bob.nodeParams.nodeId, ChannelRelayParams.FromHint(Invoice.BasicEdge(null, TestConstants.Bob.nodeParams.nodeId, null, 1 msat, 2, CltvExpiryDelta(3)))) :: Nil, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(amount2, expiry2, randomBytes32(), None), amount2, expiry2).get._2.copy(commit = false) alice ! cmd2 sender.expectMsgType[RES_SUCCESS[CMD_ADD_HTLC]] val htlc2 = alice2bob.expectMsgType[UpdateAddHtlc] diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/crypto/SphinxSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/crypto/SphinxSpec.scala index 1af09e8c68..2ee5885700 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/crypto/SphinxSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/crypto/SphinxSpec.scala @@ -438,12 +438,12 @@ class SphinxSpec extends AnyFunSuite { // The sender includes the correct encrypted recipient data in each blinded node's payload. TlvStream[OnionPaymentPayloadTlv](OnionPaymentPayloadTlv.EncryptedRecipientData(blindedRoute.encryptedPayloads(1))), TlvStream[OnionPaymentPayloadTlv](OnionPaymentPayloadTlv.EncryptedRecipientData(blindedRoute.encryptedPayloads(2))), - TlvStream[OnionPaymentPayloadTlv](OnionPaymentPayloadTlv.AmountToForward(100_000 msat), OnionPaymentPayloadTlv.OutgoingCltv(CltvExpiry(749000)), OnionPaymentPayloadTlv.EncryptedRecipientData(blindedRoute.encryptedPayloads(3))), + TlvStream[OnionPaymentPayloadTlv](OnionPaymentPayloadTlv.AmountToForward(100_000 msat), OnionPaymentPayloadTlv.EncryptedRecipientData(blindedRoute.encryptedPayloads(3))), ).map(tlvs => PaymentOnionCodecs.perHopPayloadCodec.encode(tlvs).require.bytes) val nodeIds = Seq(alice, bob).map(_.publicKey) ++ blindedRoute.blindedNodeIds.tail val Success(PacketAndSecrets(onion, sharedSecrets)) = create(sessionKey, 1300, nodeIds, payloads, associatedData) - assert(serializePaymentOnion(onion) == hex"0002531fe6068134503d2723133227c867ac8fa6c83c537e9a44c3c5bdbdcb1fe337dadf610256c6ab518495dce9cdedf9391e21a71daddfe667387c384267a4c6453777590fc38e591b4f04a1e96bd1dec4af605d6adda2690de4ebe5d56ad2013b520af2a3c49316bc590ee83e8c31b1eb11ff766dad27ca993326b1ed582fb451a2ad87fbf6601134c6341c4a2deb6850e25a355be68dbb6923dc89444fdd74a0f700433b667bda345926099f5547b07e97ad903e8a01566a78ae177366239e793dac719de805565b6d0a942f42722a79dba29ebf4f9ec40cf579191716aac3a79f78c1d43398fba3f304786435976102a924ba4ba3de6150c829ce01c25428f2f5d05ef023be7d590ecdf6603730db3948f80ca1ed3d85227e64ef77200b9b557f427b6e1073cfa0e63e4485441768b98ab11ba8104a6cee1d7af7bb9d3167503ea010fabcd207b0b37a68b84be55663802d96faee291e8241b5e6c4b38e0c6d17ef6ba7bbe93f02046975bb01b7f766fcfc5a755af11a90cc7eb3505986b56e07a7855534d03b79f0dfbfe645b0d6d4185c038771fd25b800aa26b2ed2e30b1e713659468618a2fea04fcd04732a6fa9e77db73d0efa5253e123d5c2306ddb9ebf7bc897b559cf9870715039c6183082d762b6417f99d0f71ff7c060f6b564ad6827edaffa72eefcc4ce633a8da8d41c19d8f6aebd8878869eb518ccc16dccae6a94c690957598ce0295c1c46af5d7a2f0955b5400526bfd1430f554562614b5d00feff3946427be520cce52bfe9b6a9c2b1da6701c8ca628a69d6d40e20dd69d6e879d7a052d9c16f52c26e3bf745daeb3578c211475f2953e3c42308af89f3fd3c93bb4ba7320b35721bfdf2ad3db94b711fbdccdbe8465d9ff7bc9a293861dcea15bfa4f64993e9a751f571ab24a3219446483968821aa19a8d89ec611d686ff5f8fdc340aa8185ae29b01e60fb5a4c5c4bf8054c711522fc74e1d60976c33d2dfd782bbd555b8d06af6e688b3f541f1275706d045c607eea5926c49ced5bd368914f5ef793c3d6c1ab08dae689f0d71d64ec9c136cd38ac038cfa37846e3df7ce4bf63f44fce412bf3c9b8f21eabc34186a9c660b23fb7f3fa26cc9d830b40b499c613c2569d5e5f10823854471d3ac8bf655b020c37309fbaa0d0af5f14babd9485347ccd891bbd1e3b73e800c500be25073ee8a3844aca1cb9fa06d5579532da09a480cbec171b2ca9f83985d1a8cf60092fedaa88d4ccc711243298beb3d9d46c87542072aebb33d5a5ee671d4974b93c901eb1b5b4eaefc3669a7daa5154dced8cdc1bf49c1ba829bcbdee4e1f2f703c983872a7bff0669c9322c13a7cfb3f7f98b7ddcb47042a4786368a182f9c667d495438b6dee2d2a6ad0f8795ac499c3c3e9d584f6cf8279497fecc51c9203510858d738cec815a13d35d220ea297333068d8b64f4bcb627d127ab1e7732c840da45d35647e9e319bac2e95bb49f070e32772e2a8a6b55ca35d2391de4269cd6c5030203ab14abfca973a032b6ce10e958f1be2399c98ee70da0363c2f9a4e52546d8eef0b63cbab415a9341dbb9099df5e1ba2a83c2be15a96518741eacbe0f5d45e81ed5ddb76438a45cc5bb8d87abba0dd8c9181eff8b1f7c3939f3600883a3139515c53a07429247db278384d727d9b3b327c0f47dd4319d12e24ac2713f8c828217491df60f5b002cc58476a7b857dffb148179ffa5c62060d26dc3a9df11beccf77929e5d752d7351e58dc7f5265946792e7733886240efa0994868aa28a66754dccee99abd37a78558c858ddc9ca52aee32e263dd5165cdaf8ff74dfa9b61506af68b2fc9c0b887d3e49cc534040221f72fe6ec705e3964ad1e6d686840dd821c7a386baecb841369c98f5b493820be03c3b726cba925c72b05ea3d1b") + assert(serializePaymentOnion(onion) == hex"0002531fe6068134503d2723133227c867ac8fa6c83c537e9a44c3c5bdbdcb1fe337dadf610256c6ab518495dce9cdedf9391e21a71dadae10fd58e9f942b47910d5930eb0cd8a93d1a9d82f6d15eb30469544f953cbeea2690de4ebe5d56ad2013b520af2a3c49316bc590ee83e8c31b1eb11ff766dad27ca993326b1ed582fb451a2ad87fbf6601134c6341c4a2deb6850e25a355be68dbb6923dc89444fdd74a0f700433b667bda345926099f5547b07e97ad903e8a01566a78ae177366239e793dac719de805565b6d0ac99f607988fe1d495e3c521226c1aa365ad8bb0ce3ba89f7bd12f763d5dde227786435976102a924ba4ba3de6150c829ce01c25428f2f5d05ef023be7d590ecdf6603730db3948f80ca1ed3d85227e64ef77200b9b557f427b6e1073cfa0e63e4485441768b98ab11ba8104a6cee1d7af7bb9a965b018f15a640032384ae5bf66aae631c36e27b1659e5e24cf59e2ba7cc91b38e0c6d17ef6ba7bbe93f02046975bb01b7f766fcfc5a755af11a90cc7eb3505986b56e07a7855534d03b79f0dfbfe645b0d6d4185c038771fd25b800aa26b2ed2e30b1e713659468618a2fea04fcd047328c0e59dbc2cafdc6dc8ae27d2ef1a66edde4421aeee6aa237db172fa79d6f0048e082d762b6419b54c7ec7aa4b5d328d8d8cf9c900d730d737b3f1fe3d1d8a0bffeb2a02c87d780607c2cbfde72956d0d943e02df35141a10e22124f6a06a006e3b02acd5fdfc7215ed4e9165dba47ac8889e62551bd699b488fb2facce52bfe9b6a9c2b1da6701c8ca628a69d6d40e20dd69d6e879d7a6dbfe891468eb48deeeb9557c64fce6f5f3513b9661bf002d1b97b46cb02f017644d1c44134d2d584ebee9e23fa8df8fbbdb517312e58e6bebd4a05cbe4bb2c7e724a592c978ecb99ab961dd2534e2bf4e1d19a1b3424c5beb7d52ef358ca9da68ee450df5d3192aca77382c443e14249842189234ede17a70f69ea2cb818083a763495047d17d5ebd645ace0f24d3d55333143edc87b7ef9edccbcd84a040db2938b9b06bb7185abe233eb7e3eb4188190bb442e3a0e39e5979f460a07465b9094544f715341b8b50545314f31ccff9ebed799ea6c71a6786bee35170a2aa240ec748802f4569c1bd2021a71303d1c1813d66f09e398de0f7cccb0e93e156c3b5e0bc8bca59b66ab24ddcdcfb195eb73cfb85dcf997a163ad186e0ae2c31e3789c77b550a52e8e9948c909b268946526cb76552952a88698745664b1158c2670b0028de729da3e63b912294de3b2c7d1ce4dffc5ec4dc33e18bb54d87d463dd60717fb123f18a04e2ce39c3280b0facf8b46909b137ae029f48886b5ca7afb003f005369982fdb6ec74a853fdeadfd583e4c31fba838428f86901d5c2bc6ff02a705fcc8b903bc64fcd800034df880c85a513bee7e30cb687704f41d7d850dc72fd631d8f8b2d9b5f7deee3a2a55d0def17313a571b3c94882b89c8fc537a0599fa67ab5f16db2e282c59437ab820fc67954f81182a8e43f7e0c8521e8bac614280051745145f7d8cebc8b54bb3077ddb8fbea92c7f1e42a48ff16d73af6b6f5002a85de92c61d27f6fa321ee67482a5cfa35fcfe8242438256edbb89bc9dfa82546735567367876323afb2c7d37f6ec1196d75e5059c6f90bd28093d991305d8acb01c0a1e64503bed932419b9925b357e56201f9c774b17170b060b764c80840f989f10c57ba473f2d642ad0ef91d61cb1311c523f088954c9fd8684992d6fc8b3e45d2eafd39d48fb83c74a8eba569023a8374dcc68086925c9c86be344d960cef3289647a2c9511676e365ca09ef55e7f80d99875280c6e08767f1efd3943bbd76ce5622dbe1dfa8b8513c5092054400a1f1935f7c6619ec3c49e62a83a06d98595dd87e9999b7cb8d1faef71") // Alice can decrypt the onion as usual. val Right(DecryptedPacket(onionPayloadAlice, packetForBob, sharedSecretAlice)) = peel(alice, associatedData, onion) @@ -522,10 +522,10 @@ class SphinxSpec extends AnyFunSuite { assert(Seq(sharedSecretAlice, sharedSecretBob, sharedSecretCarol, sharedSecretDave, sharedSecretEve) == sharedSecrets.map(_._1)) val packets = Seq(packetForBob, packetForCarol, packetForDave, packetForEve, packetForNobody) - assert(packets(0).hmac == ByteVector32(hex"0b462fb9321df3f139d2efccdc54471840e5cb50b4f7dae44df9c8c3e5ffabde")) - assert(packets(1).hmac == ByteVector32(hex"5c7b8d4f3061b3e58194edfb76ac339932c61ff77b024192508c9628a0206bb7")) - assert(packets(2).hmac == ByteVector32(hex"6a8df602e649e459b456df92327d7cf28132b735d38d3692c7c199e27d298c85")) - assert(packets(3).hmac == ByteVector32(hex"6db2bc62c58cd931570f8b7eb13b96e40b8a34e13655eb4f4b3a3ec87824403d")) + assert(packets(0).hmac == ByteVector32(hex"7ab0b5d9a7dcf322e4047f68a5bd85515d6d79c79f3b2664ac6199287cf10aed")) + assert(packets(1).hmac == ByteVector32(hex"01cbaf4492e6750e4117f07794616cd67109b2515b1f571d2c4a52d38e5e7a94")) + assert(packets(2).hmac == ByteVector32(hex"6d2aca5357fc52e30ba75c8cdaf17ee4a9cbd7b428b6f98dcba472fe173ba6d0")) + assert(packets(3).hmac == ByteVector32(hex"47467bcedc35f4182ea05711a296138c0bd79d0c643a3a39aa72cb3514f12c58")) assert(packets(4).hmac == ByteVector32(hex"0000000000000000000000000000000000000000000000000000000000000000")) } } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/Bolt12InvoiceSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/Bolt12InvoiceSpec.scala index 77a7728a6b..df44ff4dd0 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/Bolt12InvoiceSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/Bolt12InvoiceSpec.scala @@ -339,8 +339,8 @@ class Bolt12InvoiceSpec extends AnyFunSuite { assert(codedDecoded.description == Left(description)) assert(codedDecoded.features == features) assert(codedDecoded.issuer.contains(issuer)) - assert(codedDecoded.nodeId.value.drop(1) == nodeKey.publicKey.value.drop(1)) - assert(codedDecoded.blindedPaths == Seq(path)) + assert(codedDecoded.signingNodeId.value.drop(1) == nodeKey.publicKey.value.drop(1)) + assert(codedDecoded.extraEdges.map(_.path) == Seq(path)) assert(codedDecoded.quantity.contains(quantity)) assert(codedDecoded.payerKey.contains(payerKey)) assert(codedDecoded.payerNote.contains(payerNote)) @@ -376,7 +376,7 @@ class Bolt12InvoiceSpec extends AnyFunSuite { val encodedInvoice = "lni1qvsxlc5vp2m0rvmjcxn2y34wv0m5lyc7sdj7zksgn35dvxgqqqqqqqqyyrfgrkke8dp3jww26jz8zgvhxhdhzgj062ejxecv9uqsdhh2x9lnjzqrkudsqzstd45ku6tdv9kzqarfwqg8sqj075ch7pgu0ah2cqnxchxw46mv2a66js86hxz5u3ala0mtc7syqupdypsecj08jzgq82kzfmd8ncs9mufkaea9dr305na9vccycmjmlfspqvxsr2nmet6yjwzmjtrmqspxnyt9wl9jv46ep5t49amw3xpj82hk6qqjy0yn6ww6ektzyys7qrm6zcul88r27ysuqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqpdcmqqqqq83pqf8l2vtlq5w87m4vqfnvtn82adk9wadfgratnp2wg7l7ha4u0gzqwf3qrm6y88mr3y2du7fzqjpamgedldayx8nenfwwtfmy877hpvs33e8zsprzmlnns23qshlyweee7p4m365legtkdgvy6s02rdqsv38mwnmk8p88cz03dt7zuqsqzmcyqvpkh2g4088w2xu7uvu6zvsxwrh2vgvppgnmf0vyqhqwqv6w8lgeulalcq6xznps7gw9h0rtfpwxftz4l7j2nnuzj3gpy86kg34awtdq" val decodedInvoice = Bolt12Invoice.fromString(encodedInvoice).get assert(decodedInvoice.amount == invoice.amount) - assert(decodedInvoice.nodeId == invoice.nodeId) + assert(decodedInvoice.signingNodeId == invoice.signingNodeId) assert(decodedInvoice.paymentHash == invoice.paymentHash) assert(decodedInvoice.description == invoice.description) assert(decodedInvoice.payerKey == invoice.payerKey) @@ -405,7 +405,7 @@ class Bolt12InvoiceSpec extends AnyFunSuite { val encodedInvoice = "lni1qvsxlc5vp2m0rvmjcxn2y34wv0m5lyc7sdj7zksgn35dvxgqqqqqqqqyypmpsc7ww3cxguwl27ela95ykset7t8tlvyfy7a200eujcnhczws6zqyrvhqd5s2p4kkjmnfd4skcgr0venx2ussnqpufzkf0cyl8ja6av6mq242d5rjk4mjdpq6xnf9j5s40jk2vzsu4agr8f5tqgegums2pxkyxcarfk6fyzdk37akrn808xrptvzzj222gv9szqervpzvaxzejaejwul8wkjuldd0qpjxpt85vlp3mncpyx30dgrzduqr99dq04sehw2nh3kqcnmj87gn9x5fcln9njcshnjcqc4c4d9vvw98fxeqm2037p4e82jce87n6nud6gncvysuqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqpktsx6gqqq83pq0zg4jt7p8euhwhtxkcz42ndqu44wunggx356fv4y9tu4jnq58902f3q86209t74drgj3upwe6zy449q58wl9f8r5z97ktd6zxelzy6tq5tjsprzmlnns23q9jcw0vzjxencw3gvx0d0d5hjc09kzv3zzvnwrsd5ntyhlht7kuszuqsqzmcyqh2ej2lvwj9chganv56tasj2a4x9expx44tr65u9cw8xyrdzvqnd09g60evuy5gqs08hxmx4rd2npqfdekmqjc4zvf5qf0v65uta9glq" val decodedInvoice = Bolt12Invoice.fromString(encodedInvoice).get assert(decodedInvoice.amount == invoice.amount) - assert(decodedInvoice.nodeId == invoice.nodeId) + assert(decodedInvoice.signingNodeId == invoice.signingNodeId) assert(decodedInvoice.paymentHash == invoice.paymentHash) assert(decodedInvoice.description == invoice.description) assert(decodedInvoice.payerKey == invoice.payerKey) @@ -441,7 +441,7 @@ class Bolt12InvoiceSpec extends AnyFunSuite { val encodedInvoice = "lni1qvsyxjtl6luzd9t3pr62xr7eemp6awnejusgf6gw45q75vcfqqqqqqqyyqcnw8ucesh0ttrka67a62qyf04tsprv4ul6uyrpctdm596q7av2zzqrdhwsqzsndanxvetjypmkjargypch2ctww35hg7gsnqpj0t74n8dryfh5vz9ed2cy9lj43064sgga830x0mxgh6vkxgsyxnczy9ysc4m9zqvmruq7clt4dfxuwjn8hmc240m0pm4yclacwtkugtaqzq75v83x5evkfwaj4amaac7e84kf9l6zcr28nyv7mx09jv87zvdvcuqr9d5ex7wdrd3g7vjxjnztctuk2tuasa5xs8klwadygqaq5dtner75zpfmptt0jv7mha7s60gft0nh8efmcysuqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqmwaqqqqq9q3v9kxjcm9gp3xjemndphhqtnrdak3uggry7hatxw6xgn0gcytj64sgtl9tzl4tqs360z7vlkv305evv3qgd8jqq2gycs8cmgrl28nvm3wlqqheha0t570rgaszg7mzvvzvwmx9s92nmyujkfgq33dleec9gs8up5r8hpz5vcfzxv706ag9yrde627yfhscttac8lw9u5u3g3udvpwqgqz9uzqt5ag0q6zkyft7jwxxcgr9etqk2psjcc44rzye2yzvx5mw7qw694lzka89xnn49qt6yh8am5xtdr5jy3mkzg49xwnz2zvx2z3a7rdajg" val decodedInvoice = Bolt12Invoice.fromString(encodedInvoice).get assert(decodedInvoice.amount == invoice.amount) - assert(decodedInvoice.nodeId == invoice.nodeId) + assert(decodedInvoice.signingNodeId == invoice.signingNodeId) assert(decodedInvoice.paymentHash == invoice.paymentHash) assert(decodedInvoice.description == invoice.description) assert(decodedInvoice.payerKey == invoice.payerKey) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartHandlerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartHandlerSpec.scala index 135b81ba3b..e8318b6b29 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartHandlerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartHandlerSpec.scala @@ -261,8 +261,8 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike sender.send(handlerWithRouteBlinding, ReceiveOfferPayment(privKey, offer, invoiceReq)) val invoice = sender.expectMsgType[Bolt12Invoice] assert(invoice.amount == 25_000.msat) - assert(invoice.nodeId == privKey.publicKey) - assert(invoice.blindedPaths.nonEmpty) + assert(invoice.signingNodeId == privKey.publicKey) + assert(invoice.extraEdges.nonEmpty) assert(invoice.features.hasFeature(RouteBlinding, Some(Mandatory))) assert(invoice.description == Left("a blinded coffee please")) assert(invoice.offerId.contains(offer.offerId)) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartPaymentLifecycleSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartPaymentLifecycleSpec.scala index 2d72e08bec..82a156abf5 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartPaymentLifecycleSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartPaymentLifecycleSpec.scala @@ -80,7 +80,9 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS import f._ assert(payFsm.stateName == WAIT_FOR_PAYMENT_REQUEST) - val payment = SendMultiPartPayment(sender.ref, randomBytes32(), e, finalAmount, expiry, 1, None, routeParams = routeParams.copy(randomize = true)) + val paymentSecret = randomBytes32() + val finalPayload = FinalPayload.Standard.createMultiPartPayload(finalAmount, expiry, paymentSecret, None) + val payment = SendMultiPartPayment(sender.ref, e, finalPayload, finalAmount, expiry, 1, routeParams = routeParams.copy(randomize = true)) sender.send(payFsm, payment) router.expectMsg(RouteRequest(nodeParams.nodeId, e, finalAmount, maxFee, routeParams = routeParams.copy(randomize = false), allowMultiPart = true, paymentContext = Some(cfg.paymentContext))) @@ -91,10 +93,10 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS val childPayment = childPayFsm.expectMsgType[SendPaymentToRoute] assert(childPayment.route == Right(singleRoute)) assert(childPayment.finalPayload.isInstanceOf[FinalPayload.Standard]) - assert(childPayment.finalPayload.expiry == expiry) - assert(childPayment.finalPayload.asInstanceOf[FinalPayload.Standard].paymentSecret == payment.paymentSecret) - assert(childPayment.finalPayload.amount == finalAmount) - assert(childPayment.finalPayload.totalAmount == finalAmount) + assert(childPayment.finalPayload.asInstanceOf[FinalPayload.Standard].expiry == expiry) + assert(childPayment.finalPayload.asInstanceOf[FinalPayload.Standard].paymentSecret == paymentSecret) + assert(childPayment.finalPayload.asInstanceOf[FinalPayload.Standard].amount == finalAmount) + assert(childPayment.finalPayload.asInstanceOf[FinalPayload.Standard].totalAmount == finalAmount) assert(payFsm.stateName == PAYMENT_IN_PROGRESS) val result = fulfillPendingPayments(f, 1) @@ -114,7 +116,9 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS import f._ assert(payFsm.stateName == WAIT_FOR_PAYMENT_REQUEST) - val payment = SendMultiPartPayment(sender.ref, randomBytes32(), e, 1200000 msat, expiry, 1, Some(hex"012345"), routeParams = routeParams.copy(randomize = false)) + val paymentSecret = randomBytes32() + val finalPayload = FinalPayload.Standard.createMultiPartPayload(1200000 msat, expiry, paymentSecret, Some(hex"012345")) + val payment = SendMultiPartPayment(sender.ref, e, finalPayload, 1200000 msat, expiry, 1, routeParams = routeParams.copy(randomize = false)) sender.send(payFsm, payment) router.expectMsg(RouteRequest(nodeParams.nodeId, e, 1200000 msat, maxFee, routeParams = routeParams.copy(randomize = false), allowMultiPart = true, paymentContext = Some(cfg.paymentContext))) @@ -127,12 +131,12 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS router.send(payFsm, RouteResponse(routes)) val childPayments = childPayFsm.expectMsgType[SendPaymentToRoute] :: childPayFsm.expectMsgType[SendPaymentToRoute] :: Nil assert(childPayments.map(_.route).toSet == routes.map(r => Right(r)).toSet) - assert(childPayments.map(_.finalPayload.expiry).toSet == Set(expiry)) + assert(childPayments.map(_.finalPayload.asInstanceOf[FinalPayload.Standard].expiry).toSet == Set(expiry)) childPayments.foreach(childPayment => assert(childPayment.finalPayload.isInstanceOf[FinalPayload.Standard])) - assert(childPayments.map(_.finalPayload.asInstanceOf[FinalPayload.Standard].paymentSecret).toSet == Set(payment.paymentSecret)) + assert(childPayments.map(_.finalPayload.asInstanceOf[FinalPayload.Standard].paymentSecret).toSet == Set(paymentSecret)) assert(childPayments.map(_.finalPayload.asInstanceOf[FinalPayload.Standard].paymentMetadata).toSet == Set(Some(hex"012345"))) - assert(childPayments.map(_.finalPayload.amount).toSet == Set(500000 msat, 700000 msat)) - assert(childPayments.map(_.finalPayload.totalAmount).toSet == Set(1200000 msat)) + assert(childPayments.map(_.finalPayload.asInstanceOf[FinalPayload.Standard].amount).toSet == Set(500000 msat, 700000 msat)) + assert(childPayments.map(_.finalPayload.asInstanceOf[FinalPayload.Standard].totalAmount).toSet == Set(1200000 msat)) assert(payFsm.stateName == PAYMENT_IN_PROGRESS) val result = fulfillPendingPayments(f, 2) @@ -154,7 +158,8 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS // We include a bunch of additional tlv records. val trampolineTlv = OnionPaymentPayloadTlv.TrampolineOnion(OnionRoutingPacket(0, ByteVector.fill(33)(0), ByteVector.fill(400)(0), randomBytes32())) val userCustomTlv = GenericTlv(UInt64(561), hex"deadbeef") - val payment = SendMultiPartPayment(sender.ref, randomBytes32(), e, finalAmount + 1000.msat, expiry, 1, None, routeParams = routeParams, additionalTlvs = Seq(trampolineTlv), userCustomTlvs = Seq(userCustomTlv)) + val finalPayload = FinalPayload.Standard.createMultiPartPayload(finalAmount + 1000.msat, expiry, randomBytes32(), None, additionalTlvs = Seq(trampolineTlv), userCustomTlvs = Seq(userCustomTlv)) + val payment = SendMultiPartPayment(sender.ref, e, finalPayload, finalAmount + 1000.msat, expiry, 1, routeParams = routeParams) sender.send(payFsm, payment) router.expectMsgType[RouteRequest] router.send(payFsm, RouteResponse(Seq(Route(500000 msat, hop_ab_1 :: hop_be :: Nil), Route(501000 msat, hop_ac_1 :: hop_ce :: Nil)))) @@ -178,7 +183,8 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS test("successful retry") { f => import f._ - val payment = SendMultiPartPayment(sender.ref, randomBytes32(), e, finalAmount, expiry, 3, None, routeParams = routeParams) + val finalPayload = FinalPayload.Standard.createMultiPartPayload(finalAmount, expiry, randomBytes32(), None) + val payment = SendMultiPartPayment(sender.ref, e, finalPayload, finalAmount, expiry, 3, routeParams = routeParams) sender.send(payFsm, payment) router.expectMsgType[RouteRequest] val failingRoute = Route(finalAmount, hop_ab_1 :: hop_be :: Nil) @@ -211,7 +217,8 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS test("retry failures while waiting for routes") { f => import f._ - val payment = SendMultiPartPayment(sender.ref, randomBytes32(), e, finalAmount, expiry, 3, None, routeParams = routeParams) + val finalPayload = FinalPayload.Standard.createMultiPartPayload(finalAmount, expiry, randomBytes32(), None) + val payment = SendMultiPartPayment(sender.ref, e, finalPayload, finalAmount, expiry, 3, routeParams = routeParams) sender.send(payFsm, payment) router.expectMsgType[RouteRequest] router.send(payFsm, RouteResponse(Seq(Route(400000 msat, hop_ab_1 :: hop_be :: Nil), Route(600000 msat, hop_ab_2 :: hop_be :: Nil)))) @@ -253,7 +260,8 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS test("retry local channel failures") { f => import f._ - val payment = SendMultiPartPayment(sender.ref, randomBytes32(), e, finalAmount, expiry, 3, None, routeParams = routeParams) + val finalPayload = FinalPayload.Standard.createMultiPartPayload(finalAmount, expiry, randomBytes32(), None) + val payment = SendMultiPartPayment(sender.ref, e, finalPayload, finalAmount, expiry, 3, routeParams = routeParams) sender.send(payFsm, payment) router.expectMsgType[RouteRequest] router.send(payFsm, RouteResponse(Seq(Route(finalAmount, hop_ab_1 :: hop_be :: Nil)))) @@ -278,7 +286,8 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS test("retry without ignoring channels") { f => import f._ - val payment = SendMultiPartPayment(sender.ref, randomBytes32(), e, finalAmount, expiry, 3, None, routeParams = routeParams) + val finalPayload = FinalPayload.Standard.createMultiPartPayload(finalAmount, expiry, randomBytes32(), None) + val payment = SendMultiPartPayment(sender.ref, e, finalPayload, finalAmount, expiry, 3, routeParams = routeParams) sender.send(payFsm, payment) router.expectMsgType[RouteRequest] router.send(payFsm, RouteResponse(Seq(Route(500000 msat, hop_ab_1 :: hop_be :: Nil), Route(500000 msat, hop_ab_1 :: hop_be :: Nil)))) @@ -320,9 +329,10 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS test("retry with updated routing hints") { f => import f._ + val finalPayload = FinalPayload.Standard.createMultiPartPayload(finalAmount, expiry, randomBytes32(), None) // The B -> E channel is private and provided in the invoice routing hints. val extraEdge = Invoice.BasicEdge(b, e, hop_be.shortChannelId, hop_be.params.relayFees.feeBase, hop_be.params.relayFees.feeProportionalMillionths, hop_be.params.cltvExpiryDelta) - val payment = SendMultiPartPayment(sender.ref, randomBytes32(), e, finalAmount, expiry, 3, None, routeParams = routeParams, extraEdges = List(extraEdge)) + val payment = SendMultiPartPayment(sender.ref, e, finalPayload, finalAmount, expiry, 3, routeParams = routeParams, extraEdges = List(extraEdge)) sender.send(payFsm, payment) assert(router.expectMsgType[RouteRequest].extraEdges.head == extraEdge) val route = Route(finalAmount, hop_ab_1 :: hop_be :: Nil) @@ -342,9 +352,10 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS test("retry with ignored routing hints (temporary channel failure)") { f => import f._ + val finalPayload = FinalPayload.Standard.createMultiPartPayload(finalAmount, expiry, randomBytes32(), None) // The B -> E channel is private and provided in the invoice routing hints. val extraEdge = Invoice.BasicEdge(b, e, hop_be.shortChannelId, hop_be.params.relayFees.feeBase, hop_be.params.relayFees.feeProportionalMillionths, hop_be.params.cltvExpiryDelta) - val payment = SendMultiPartPayment(sender.ref, randomBytes32(), e, finalAmount, expiry, 3, None, routeParams = routeParams, extraEdges = List(extraEdge)) + val payment = SendMultiPartPayment(sender.ref, e, finalPayload, finalAmount, expiry, 3, routeParams = routeParams, extraEdges = List(extraEdge)) sender.send(payFsm, payment) assert(router.expectMsgType[RouteRequest].extraEdges.head == extraEdge) val route = Route(finalAmount, hop_ab_1 :: hop_be :: Nil) @@ -404,7 +415,8 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS test("abort after too many failed attempts") { f => import f._ - val payment = SendMultiPartPayment(sender.ref, randomBytes32(), e, finalAmount, expiry, 2, None, routeParams = routeParams) + val finalPayload = FinalPayload.Standard.createMultiPartPayload(finalAmount, expiry, randomBytes32(), None) + val payment = SendMultiPartPayment(sender.ref, e, finalPayload, finalAmount, expiry, 2, routeParams = routeParams) sender.send(payFsm, payment) router.expectMsgType[RouteRequest] router.send(payFsm, RouteResponse(Seq(Route(500000 msat, hop_ab_1 :: hop_be :: Nil), Route(500000 msat, hop_ac_1 :: hop_ce :: Nil)))) @@ -435,7 +447,8 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS import f._ sender.watch(payFsm) - val payment = SendMultiPartPayment(sender.ref, randomBytes32(), e, finalAmount, expiry, 5, None, routeParams = routeParams) + val finalPayload = FinalPayload.Standard.createMultiPartPayload(finalAmount, expiry, randomBytes32(), None) + val payment = SendMultiPartPayment(sender.ref, e, finalPayload, finalAmount, expiry, 5, routeParams = routeParams) sender.send(payFsm, payment) router.expectMsgType[RouteRequest] router.send(payFsm, Status.Failure(RouteNotFound)) @@ -465,7 +478,8 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS test("abort if recipient sends error") { f => import f._ - val payment = SendMultiPartPayment(sender.ref, randomBytes32(), e, finalAmount, expiry, 5, None, routeParams = routeParams) + val finalPayload = FinalPayload.Standard.createMultiPartPayload(finalAmount, expiry, randomBytes32(), None) + val payment = SendMultiPartPayment(sender.ref, e, finalPayload, finalAmount, expiry, 5, routeParams = routeParams) sender.send(payFsm, payment) router.expectMsgType[RouteRequest] router.send(payFsm, RouteResponse(Seq(Route(finalAmount, hop_ab_1 :: hop_be :: Nil)))) @@ -486,7 +500,8 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS test("abort if payment gets settled on chain") { f => import f._ - val payment = SendMultiPartPayment(sender.ref, randomBytes32(), e, finalAmount, expiry, 5, None, routeParams = routeParams) + val finalPayload = FinalPayload.Standard.createMultiPartPayload(finalAmount, expiry, randomBytes32(), None) + val payment = SendMultiPartPayment(sender.ref, e, finalPayload, finalAmount, expiry, 5, routeParams = routeParams) sender.send(payFsm, payment) router.expectMsgType[RouteRequest] router.send(payFsm, RouteResponse(Seq(Route(finalAmount, hop_ab_1 :: hop_be :: Nil)))) @@ -500,7 +515,8 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS test("abort if recipient sends error during retry") { f => import f._ - val payment = SendMultiPartPayment(sender.ref, randomBytes32(), e, finalAmount, expiry, 5, None, routeParams = routeParams) + val finalPayload = FinalPayload.Standard.createMultiPartPayload(finalAmount, expiry, randomBytes32(), None) + val payment = SendMultiPartPayment(sender.ref, e, finalPayload, finalAmount, expiry, 5, routeParams = routeParams) sender.send(payFsm, payment) router.expectMsgType[RouteRequest] router.send(payFsm, RouteResponse(Seq(Route(400000 msat, hop_ab_1 :: hop_be :: Nil), Route(600000 msat, hop_ac_1 :: hop_ce :: Nil)))) @@ -518,7 +534,8 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS test("receive partial success after retriable failure (recipient spec violation)") { f => import f._ - val payment = SendMultiPartPayment(sender.ref, randomBytes32(), e, finalAmount, expiry, 5, None, routeParams = routeParams) + val finalPayload = FinalPayload.Standard.createMultiPartPayload(finalAmount, expiry, randomBytes32(), None) + val payment = SendMultiPartPayment(sender.ref, e, finalPayload, finalAmount, expiry, 5, routeParams = routeParams) sender.send(payFsm, payment) router.expectMsgType[RouteRequest] router.send(payFsm, RouteResponse(Seq(Route(400000 msat, hop_ab_1 :: hop_be :: Nil), Route(600000 msat, hop_ac_1 :: hop_ce :: Nil)))) @@ -538,7 +555,8 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS test("receive partial success after abort (recipient spec violation)") { f => import f._ - val payment = SendMultiPartPayment(sender.ref, randomBytes32(), e, finalAmount, expiry, 5, None, routeParams = routeParams) + val finalPayload = FinalPayload.Standard.createMultiPartPayload(finalAmount, expiry, randomBytes32(), None) + val payment = SendMultiPartPayment(sender.ref, e, finalPayload, finalAmount, expiry, 5, routeParams = routeParams) sender.send(payFsm, payment) router.expectMsgType[RouteRequest] router.send(payFsm, RouteResponse(Seq(Route(400000 msat, hop_ab_1 :: hop_be :: Nil), Route(600000 msat, hop_ac_1 :: hop_ce :: Nil)))) @@ -571,7 +589,8 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS test("receive partial failure after success (recipient spec violation)") { f => import f._ - val payment = SendMultiPartPayment(sender.ref, randomBytes32(), e, finalAmount, expiry, 5, None, routeParams = routeParams) + val finalPayload = FinalPayload.Standard.createMultiPartPayload(finalAmount, expiry, randomBytes32(), None) + val payment = SendMultiPartPayment(sender.ref, e, finalPayload, finalAmount, expiry, 5, routeParams = routeParams) sender.send(payFsm, payment) router.expectMsgType[RouteRequest] router.send(payFsm, RouteResponse(Seq(Route(400000 msat, hop_ab_1 :: hop_be :: Nil), Route(600000 msat, hop_ac_1 :: hop_ce :: Nil)))) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentInitiatorSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentInitiatorSpec.scala index d5fa5c2249..44a437a3aa 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentInitiatorSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentInitiatorSpec.scala @@ -155,7 +155,7 @@ class PaymentInitiatorSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike sender.send(initiator, request) val payment = sender.expectMsgType[SendPaymentToRouteResponse] payFsm.expectMsg(SendPaymentConfig(payment.paymentId, payment.parentId, None, paymentHash, finalAmount, c, Upstream.Local(payment.paymentId), Some(invoice), storeInDb = true, publishEvent = true, recordPathFindingMetrics = false, Nil)) - payFsm.expectMsg(PaymentLifecycle.SendPaymentToRoute(initiator, Left(route), FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight + 1), invoice.paymentSecret, invoice.paymentMetadata))) + payFsm.expectMsg(PaymentLifecycle.SendPaymentToRoute(initiator, Left(route), invoice.singlePartFinalPayload(finalAmount, finalExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight + 1)), finalAmount, finalExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight + 1))) sender.send(initiator, GetPayment(Left(payment.paymentId))) sender.expectMsg(PaymentIsPending(payment.paymentId, invoice.paymentHash, PendingPaymentToRoute(sender.ref, request))) @@ -180,7 +180,7 @@ class PaymentInitiatorSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike sender.send(initiator, req) val id = sender.expectMsgType[UUID] payFsm.expectMsg(SendPaymentConfig(id, id, None, paymentHash, finalAmount, c, Upstream.Local(id), Some(invoice), storeInDb = true, publishEvent = true, recordPathFindingMetrics = true, Nil)) - payFsm.expectMsg(PaymentLifecycle.SendPaymentToNode(initiator, c, FinalPayload.Standard(TlvStream(OnionPaymentPayloadTlv.AmountToForward(finalAmount), OnionPaymentPayloadTlv.OutgoingCltv(req.finalExpiry(nodeParams.currentBlockHeight)), OnionPaymentPayloadTlv.PaymentData(invoice.paymentSecret, finalAmount))), 1, routeParams = nodeParams.routerConf.pathFindingExperimentConf.getRandomConf().getDefaultRouteParams)) + payFsm.expectMsg(PaymentLifecycle.SendPaymentToNode(initiator, c, FinalPayload.Standard(TlvStream(OnionPaymentPayloadTlv.AmountToForward(finalAmount), OnionPaymentPayloadTlv.OutgoingCltv(req.finalExpiry(nodeParams.currentBlockHeight)), OnionPaymentPayloadTlv.PaymentData(invoice.paymentSecret, finalAmount))), finalAmount, req.finalExpiry(nodeParams.currentBlockHeight), 1, routeParams = nodeParams.routerConf.pathFindingExperimentConf.getRandomConf().getDefaultRouteParams)) sender.send(initiator, GetPayment(Left(id))) sender.expectMsg(PaymentIsPending(id, invoice.paymentHash, PendingPaymentToNode(sender.ref, req))) @@ -203,7 +203,8 @@ class PaymentInitiatorSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike sender.send(initiator, req) val id = sender.expectMsgType[UUID] multiPartPayFsm.expectMsg(SendPaymentConfig(id, id, None, paymentHash, finalAmount + 100.msat, c, Upstream.Local(id), Some(invoice), storeInDb = true, publishEvent = true, recordPathFindingMetrics = true, Nil)) - multiPartPayFsm.expectMsg(SendMultiPartPayment(initiator, invoice.paymentSecret, c, finalAmount + 100.msat, req.finalExpiry(nodeParams.currentBlockHeight), 1, invoice.paymentMetadata, routeParams = nodeParams.routerConf.pathFindingExperimentConf.getRandomConf().getDefaultRouteParams)) + val finalPayload = invoice.multiPartFinalPayload(finalAmount + 100.msat, req.finalExpiry(nodeParams.currentBlockHeight)) + multiPartPayFsm.expectMsg(SendMultiPartPayment(initiator, c, finalPayload, finalAmount + 100.msat, req.finalExpiry(nodeParams.currentBlockHeight), 1, routeParams = nodeParams.routerConf.pathFindingExperimentConf.getRandomConf().getDefaultRouteParams)) sender.send(initiator, GetPayment(Left(id))) sender.expectMsg(PaymentIsPending(id, invoice.paymentHash, PendingPaymentToNode(sender.ref, req))) @@ -231,10 +232,10 @@ class PaymentInitiatorSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike assert(msg.replyTo == initiator) assert(msg.route == Left(route)) assert(msg.finalPayload.isInstanceOf[FinalPayload.Standard]) - assert(msg.finalPayload.amount == finalAmount / 2) - assert(msg.finalPayload.expiry == req.finalExpiry(nodeParams.currentBlockHeight)) + assert(msg.finalPayload.asInstanceOf[FinalPayload.Standard].amount == finalAmount / 2) + assert(msg.finalPayload.asInstanceOf[FinalPayload.Standard].expiry == req.finalExpiry(nodeParams.currentBlockHeight)) assert(msg.finalPayload.asInstanceOf[FinalPayload.Standard].paymentSecret == invoice.paymentSecret) - assert(msg.finalPayload.totalAmount == finalAmount) + assert(msg.finalPayload.asInstanceOf[FinalPayload.Standard].totalAmount == finalAmount) sender.send(initiator, GetPayment(Left(payment.paymentId))) sender.expectMsg(PaymentIsPending(payment.paymentId, invoice.paymentHash, PendingPaymentToRoute(sender.ref, req))) @@ -266,15 +267,15 @@ class PaymentInitiatorSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike sender.expectMsg(PaymentIsPending(id, invoice.paymentHash, PendingTrampolinePayment(sender.ref, Nil, req))) val msg = multiPartPayFsm.expectMsgType[SendMultiPartPayment] - assert(msg.paymentSecret !== invoice.paymentSecret) // we should not leak the invoice secret to the trampoline node + assert(msg.finalPayload.records.get[OnionPaymentPayloadTlv.PaymentData].get.secret !== invoice.paymentSecret) // we should not leak the invoice secret to the trampoline node assert(msg.targetNodeId == b) assert(msg.targetExpiry.toLong == currentBlockCount + 9 + 12 + 1) assert(msg.totalAmount == finalAmount + trampolineFees) - assert(msg.additionalTlvs.head.isInstanceOf[OnionPaymentPayloadTlv.TrampolineOnion]) + assert(msg.finalPayload.records.get[OnionPaymentPayloadTlv.TrampolineOnion].nonEmpty) assert(msg.maxAttempts == nodeParams.maxPaymentAttempts) // Verify that the trampoline node can correctly peel the trampoline onion. - val trampolineOnion = msg.additionalTlvs.head.asInstanceOf[OnionPaymentPayloadTlv.TrampolineOnion].packet + val trampolineOnion = msg.finalPayload.records.get[OnionPaymentPayloadTlv.TrampolineOnion].get.packet val Right(decrypted) = Sphinx.peel(priv_b.privateKey, Some(invoice.paymentHash), trampolineOnion) assert(!decrypted.isLastPacket) val Right(trampolinePayload) = IntermediatePayload.NodeRelay.Standard.validate(PaymentOnionCodecs.perHopPayloadCodec.decode(decrypted.payload.bits).require.value) @@ -306,14 +307,14 @@ class PaymentInitiatorSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike multiPartPayFsm.expectMsgType[SendPaymentConfig] val msg = multiPartPayFsm.expectMsgType[SendMultiPartPayment] - assert(msg.paymentSecret !== invoice.paymentSecret) // we should not leak the invoice secret to the trampoline node + assert(msg.finalPayload.records.get[OnionPaymentPayloadTlv.PaymentData].get.secret !== invoice.paymentSecret) // we should not leak the invoice secret to the trampoline node assert(msg.targetNodeId == b) assert(msg.targetExpiry.toLong == currentBlockCount + 9 + 12 + 1) assert(msg.totalAmount == finalAmount + trampolineFees) - assert(msg.additionalTlvs.head.isInstanceOf[OnionPaymentPayloadTlv.TrampolineOnion]) + assert(msg.finalPayload.records.get[OnionPaymentPayloadTlv.TrampolineOnion].nonEmpty) // Verify that the trampoline node can correctly peel the trampoline onion. - val trampolineOnion = msg.additionalTlvs.head.asInstanceOf[OnionPaymentPayloadTlv.TrampolineOnion].packet + val trampolineOnion = msg.finalPayload.records.get[OnionPaymentPayloadTlv.TrampolineOnion].get.packet val Right(decrypted) = Sphinx.peel(priv_b.privateKey, Some(invoice.paymentHash), trampolineOnion) assert(!decrypted.isLastPacket) val Right(trampolinePayload) = IntermediatePayload.NodeRelay.Standard.validate(PaymentOnionCodecs.perHopPayloadCodec.decode(decrypted.payload.bits).require.value) @@ -460,9 +461,9 @@ class PaymentInitiatorSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike val msg = payFsm.expectMsgType[PaymentLifecycle.SendPaymentToRoute] assert(msg.route == Left(route)) assert(msg.finalPayload.isInstanceOf[FinalPayload.Standard]) - assert(msg.finalPayload.amount == finalAmount + trampolineFees) + assert(msg.finalPayload.asInstanceOf[FinalPayload.Standard].amount == finalAmount + trampolineFees) assert(msg.finalPayload.asInstanceOf[FinalPayload.Standard].paymentSecret == payment.trampolineSecret.get) - assert(msg.finalPayload.totalAmount == finalAmount + trampolineFees) + assert(msg.finalPayload.asInstanceOf[FinalPayload.Standard].totalAmount == finalAmount + trampolineFees) val trampolineOnion = msg.finalPayload.records.get[OnionPaymentPayloadTlv.TrampolineOnion] assert(trampolineOnion.nonEmpty) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala index 3bbcd723cb..b3261fd7b5 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala @@ -44,6 +44,7 @@ import fr.acinq.eclair.router.Graph.WeightRatios import fr.acinq.eclair.router.Router._ import fr.acinq.eclair.router._ import fr.acinq.eclair.transactions.Scripts +import fr.acinq.eclair.wire.protocol.PaymentOnion.FinalPayload import fr.acinq.eclair.wire.protocol._ import scodec.bits.ByteVector @@ -105,7 +106,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { // pre-computed route going from A to D val route = Route(defaultAmountMsat, channelHopFromUpdate(a, b, update_ab) :: channelHopFromUpdate(b, c, update_bc) :: channelHopFromUpdate(c, d, update_cd) :: Nil) - val request = SendPaymentToRoute(sender.ref, Right(route), PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata)) + val request = SendPaymentToRoute(sender.ref, Right(route), PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata), defaultAmountMsat, defaultExpiry) sender.send(paymentFSM, request) routerForwarder.expectNoMessage(100 millis) // we don't need the router, we have the pre-computed route val Transition(_, WAITING_FOR_REQUEST, WAITING_FOR_ROUTE) = monitor.expectMsgClass(classOf[Transition[_]]) @@ -133,7 +134,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { // pre-computed route going from A to D val route = PredefinedNodeRoute(Seq(a, b, c, d)) - val request = SendPaymentToRoute(sender.ref, Left(route), PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata)) + val request = SendPaymentToRoute(sender.ref, Left(route), PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata), defaultAmountMsat, defaultExpiry) sender.send(paymentFSM, request) routerForwarder.expectMsg(FinalizeRoute(defaultAmountMsat, route, paymentContext = Some(cfg.paymentContext))) @@ -159,7 +160,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val payFixture = createPaymentLifecycle(recordMetrics = false) import payFixture._ - val brokenRoute = SendPaymentToRoute(sender.ref, Left(PredefinedNodeRoute(Seq(randomKey().publicKey, randomKey().publicKey, randomKey().publicKey))), PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata)) + val brokenRoute = SendPaymentToRoute(sender.ref, Left(PredefinedNodeRoute(Seq(randomKey().publicKey, randomKey().publicKey, randomKey().publicKey))), PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata), defaultAmountMsat, defaultExpiry) sender.send(paymentFSM, brokenRoute) routerForwarder.expectMsgType[FinalizeRoute] routerForwarder.forward(routerFixture.router) @@ -176,7 +177,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val payFixture = createPaymentLifecycle(recordMetrics = false) import payFixture._ - val brokenRoute = SendPaymentToRoute(sender.ref, Left(PredefinedChannelRoute(randomKey().publicKey, Seq(ShortChannelId(1), ShortChannelId(2)))), PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata)) + val brokenRoute = SendPaymentToRoute(sender.ref, Left(PredefinedChannelRoute(randomKey().publicKey, Seq(ShortChannelId(1), ShortChannelId(2)))), PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata), defaultAmountMsat, defaultExpiry) sender.send(paymentFSM, brokenRoute) routerForwarder.expectMsgType[FinalizeRoute] routerForwarder.forward(routerFixture.router) @@ -197,7 +198,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val recipient = randomKey().publicKey val route = PredefinedNodeRoute(Seq(a, b, c, recipient)) val extraEdges = Seq(BasicEdge(c, recipient, ShortChannelId(561), 1 msat, 100, CltvExpiryDelta(144))) - val request = SendPaymentToRoute(sender.ref, Left(route), PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata), extraEdges) + val request = SendPaymentToRoute(sender.ref, Left(route), PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata), defaultAmountMsat, defaultExpiry, extraEdges) sender.send(paymentFSM, request) routerForwarder.expectMsg(FinalizeRoute(defaultAmountMsat, route, extraEdges, paymentContext = Some(cfg.paymentContext))) @@ -223,7 +224,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { import payFixture._ import cfg._ - val request = SendPaymentToNode(sender.ref, f, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata), 5, routeParams = defaultRouteParams) + val request = SendPaymentToNode(sender.ref, f, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata), defaultAmountMsat, defaultExpiry, 5, routeParams = defaultRouteParams) sender.send(paymentFSM, request) val Transition(_, WAITING_FOR_REQUEST, WAITING_FOR_ROUTE) = monitor.expectMsgClass(classOf[Transition[_]]) val routeRequest = routerForwarder.expectMsgType[RouteRequest] @@ -256,7 +257,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { "my-test-experiment", experimentPercentage = 100 ).getDefaultRouteParams - val request = SendPaymentToNode(sender.ref, d, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata), 5, routeParams = routeParams) + val request = SendPaymentToNode(sender.ref, d, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata), defaultAmountMsat, defaultExpiry, 5, routeParams = routeParams) sender.send(paymentFSM, request) val routeRequest = routerForwarder.expectMsgType[RouteRequest] val Transition(_, WAITING_FOR_REQUEST, WAITING_FOR_ROUTE) = monitor.expectMsgClass(classOf[Transition[_]]) @@ -281,7 +282,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { import cfg._ val paymentMetadataTooBig = ByteVector.fromValidHex("01" * 1300) - val request = SendPaymentToNode(sender.ref, d, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, Some(paymentMetadataTooBig)), 5, routeParams = defaultRouteParams) + val request = SendPaymentToNode(sender.ref, d, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, Some(paymentMetadataTooBig)), defaultAmountMsat, defaultExpiry, 5, routeParams = defaultRouteParams) sender.send(paymentFSM, request) val Transition(_, WAITING_FOR_REQUEST, WAITING_FOR_ROUTE) = monitor.expectMsgClass(classOf[Transition[_]]) val routeRequest = routerForwarder.expectMsgType[RouteRequest] @@ -300,7 +301,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { import payFixture._ import cfg._ - val request = SendPaymentToNode(sender.ref, d, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata), 2, routeParams = defaultRouteParams) + val request = SendPaymentToNode(sender.ref, d, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata), defaultAmountMsat, defaultExpiry, 2, routeParams = defaultRouteParams) sender.send(paymentFSM, request) routerForwarder.expectMsg(defaultRouteRequest(a, d, cfg)) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && nodeParams.db.payments.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.Pending)) @@ -345,7 +346,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { import payFixture._ import cfg._ - val request = SendPaymentToNode(sender.ref, d, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata), 2, routeParams = defaultRouteParams) + val request = SendPaymentToNode(sender.ref, d, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata), defaultAmountMsat, defaultExpiry, 2, routeParams = defaultRouteParams) sender.send(paymentFSM, request) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && nodeParams.db.payments.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.Pending)) routerForwarder.expectMsgType[RouteRequest] @@ -366,7 +367,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { import payFixture._ import cfg._ - val request = SendPaymentToNode(sender.ref, d, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata), 2, routeParams = defaultRouteParams) + val request = SendPaymentToNode(sender.ref, d, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata), defaultAmountMsat, defaultExpiry, 2, routeParams = defaultRouteParams) sender.send(paymentFSM, request) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && nodeParams.db.payments.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.Pending)) routerForwarder.expectMsgType[RouteRequest] @@ -386,7 +387,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { import payFixture._ import cfg._ - val request = SendPaymentToNode(sender.ref, d, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata), 2, routeParams = defaultRouteParams) + val request = SendPaymentToNode(sender.ref, d, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata), defaultAmountMsat, defaultExpiry, 2, routeParams = defaultRouteParams) sender.send(paymentFSM, request) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && nodeParams.db.payments.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.Pending)) @@ -409,7 +410,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { import payFixture._ import cfg._ - val request = SendPaymentToNode(sender.ref, d, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata), 2, routeParams = defaultRouteParams) + val request = SendPaymentToNode(sender.ref, d, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata), defaultAmountMsat, defaultExpiry, 2, routeParams = defaultRouteParams) sender.send(paymentFSM, request) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && nodeParams.db.payments.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.Pending)) @@ -432,7 +433,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { import payFixture._ import cfg._ - val request = SendPaymentToNode(sender.ref, d, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata), 2, routeParams = defaultRouteParams) + val request = SendPaymentToNode(sender.ref, d, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata), defaultAmountMsat, defaultExpiry, 2, routeParams = defaultRouteParams) sender.send(paymentFSM, request) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && nodeParams.db.payments.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.Pending)) @@ -455,7 +456,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val payFixture = createPaymentLifecycle() import payFixture._ - val request = SendPaymentToNode(sender.ref, d, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata), 2, routeParams = defaultRouteParams) + val request = SendPaymentToNode(sender.ref, d, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata), defaultAmountMsat, defaultExpiry, 2, routeParams = defaultRouteParams) sender.send(paymentFSM, request) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE) val WaitingForRoute(_, Nil, _) = paymentFSM.stateData @@ -486,7 +487,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { import payFixture._ import cfg._ - val request = SendPaymentToNode(sender.ref, d, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata), 5, routeParams = defaultRouteParams) + val request = SendPaymentToNode(sender.ref, d, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata), defaultAmountMsat, defaultExpiry, 5, routeParams = defaultRouteParams) sender.send(paymentFSM, request) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && nodeParams.db.payments.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.Pending)) @@ -540,7 +541,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val payFixture = createPaymentLifecycle() import payFixture._ - val request = SendPaymentToNode(sender.ref, d, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata), 1, routeParams = defaultRouteParams) + val request = SendPaymentToNode(sender.ref, d, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata), defaultAmountMsat, defaultExpiry, 1, routeParams = defaultRouteParams) sender.send(paymentFSM, request) routerForwarder.expectMsg(defaultRouteRequest(nodeParams.nodeId, d, cfg)) routerForwarder.forward(routerFixture.router) @@ -571,7 +572,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { BasicEdge(c, d, scid_cd, update_cd.feeBaseMsat, update_cd.feeProportionalMillionths, update_cd.cltvExpiryDelta) ) - val request = SendPaymentToNode(sender.ref, d, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata), 5, extraEdges = extraEdges, routeParams = defaultRouteParams) + val request = SendPaymentToNode(sender.ref, d, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata), defaultAmountMsat, defaultExpiry, 5, extraEdges = extraEdges, routeParams = defaultRouteParams) sender.send(paymentFSM, request) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && nodeParams.db.payments.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.Pending)) @@ -612,7 +613,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { // we build an assisted route for channel cd val extraEdges = Seq(BasicEdge(c, d, scid_cd, update_cd.feeBaseMsat, update_cd.feeProportionalMillionths, update_cd.cltvExpiryDelta)) - val request = SendPaymentToNode(sender.ref, d, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata), 1, extraEdges = extraEdges, routeParams = defaultRouteParams) + val request = SendPaymentToNode(sender.ref, d, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata), defaultAmountMsat, defaultExpiry, 1, extraEdges = extraEdges, routeParams = defaultRouteParams) sender.send(paymentFSM, request) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && nodeParams.db.payments.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.Pending)) @@ -628,7 +629,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val failureOnion = Sphinx.FailurePacket.wrap(Sphinx.FailurePacket.create(sharedSecrets1(1)._1, failure), sharedSecrets1.head._1) sender.send(paymentFSM, addCompleted(HtlcResult.RemoteFail(UpdateFailHtlc(ByteVector32.Zeroes, 0, failureOnion)))) - assert(routerForwarder.expectMsgType[RouteCouldRelay].route.hops.map(_.shortChannelId) == Seq(update_ab, update_bc).map(_.shortChannelId)) + assert(routerForwarder.expectMsgType[RouteCouldRelay].route.hops.map(_.asInstanceOf[ChannelHop].shortChannelId) == Seq(update_ab, update_bc).map(_.shortChannelId)) routerForwarder.expectMsg(ExcludeChannel(ChannelDesc(update_cd.shortChannelId, c, d), Some(nodeParams.routerConf.channelExcludeDuration))) routerForwarder.expectMsg(channelUpdate_cd_disabled) } @@ -638,7 +639,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { import payFixture._ import cfg._ - val request = SendPaymentToNode(sender.ref, d, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata), 2, routeParams = defaultRouteParams) + val request = SendPaymentToNode(sender.ref, d, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata), defaultAmountMsat, defaultExpiry, 2, routeParams = defaultRouteParams) sender.send(paymentFSM, request) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && nodeParams.db.payments.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.Pending)) @@ -676,7 +677,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { import payFixture._ import cfg._ - val request = SendPaymentToNode(sender.ref, d, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata), 5, routeParams = defaultRouteParams) + val request = SendPaymentToNode(sender.ref, d, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata), defaultAmountMsat, defaultExpiry, 5, routeParams = defaultRouteParams) sender.send(paymentFSM, request) routerForwarder.expectMsgType[RouteRequest] val Transition(_, WAITING_FOR_REQUEST, WAITING_FOR_ROUTE) = monitor.expectMsgClass(classOf[Transition[_]]) @@ -703,7 +704,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { assert(metrics.fees == 730.msat) metricsListener.expectNoMessage() - assert(routerForwarder.expectMsgType[RouteDidRelay].route.hops.map(_.shortChannelId) == Seq(update_ab, update_bc, update_cd).map(_.shortChannelId)) + assert(routerForwarder.expectMsgType[RouteDidRelay].route.hops.map(_.asInstanceOf[ChannelHop].shortChannelId) == Seq(update_ab, update_bc, update_cd).map(_.shortChannelId)) } test("payment succeeded to a channel with fees=0") { routerFixture => @@ -732,7 +733,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { import payFixture._ // we send a payment to H - val request = SendPaymentToNode(sender.ref, h, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata), 5, routeParams = defaultRouteParams) + val request = SendPaymentToNode(sender.ref, h, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata), defaultAmountMsat, defaultExpiry, 5, routeParams = defaultRouteParams) sender.send(paymentFSM, request) routerForwarder.expectMsgType[RouteRequest] @@ -744,13 +745,13 @@ class PaymentLifecycleSpec extends BaseRouterSpec { sender.send(paymentFSM, addCompleted(HtlcResult.OnChainFulfill(defaultPaymentPreimage))) val paymentOK = sender.expectMsgType[PaymentSent] val PaymentSent(_, _, paymentOK.paymentPreimage, finalAmount, _, PartialPayment(_, partAmount, fee, ByteVector32.Zeroes, _, _) :: Nil) = eventListener.expectMsgType[PaymentSent] - assert(partAmount == request.finalPayload.amount) + assert(partAmount == request.finalPayload.asInstanceOf[FinalPayload.Standard].amount) assert(finalAmount == defaultAmountMsat) // NB: A -> B doesn't pay fees because it's our direct neighbor // NB: B -> H doesn't asks for fees at all assert(fee == 0.msat) - assert(paymentOK.recipientAmount == request.finalPayload.amount) + assert(paymentOK.recipientAmount == request.finalPayload.asInstanceOf[FinalPayload.Standard].amount) val metrics = metricsListener.expectMsgType[PathFindingExperimentMetrics] assert(metrics.status == "SUCCESS") @@ -759,7 +760,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { assert(metrics.fees == 0.msat) metricsListener.expectNoMessage() - assert(routerForwarder.expectMsgType[RouteDidRelay].route.hops.map(_.shortChannelId) == Seq(update_ab, channelUpdate_bh).map(_.shortChannelId)) + assert(routerForwarder.expectMsgType[RouteDidRelay].route.hops.map(_.asInstanceOf[ChannelHop].shortChannelId) == Seq(update_ab, channelUpdate_bh).map(_.shortChannelId)) } test("filter errors properly") { () => @@ -795,7 +796,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { (RemoteFailure(defaultAmountMsat, route_abcd, Sphinx.DecryptedFailurePacket(b, FeeInsufficient(100 msat, update_bc))), Set.empty, Set.empty), // unreadable remote failures -> blacklist all nodes except our direct peer and the final recipient (UnreadableRemoteFailure(defaultAmountMsat, channelHopFromUpdate(a, b, update_ab) :: Nil), Set.empty, Set.empty), - (UnreadableRemoteFailure(defaultAmountMsat, channelHopFromUpdate(a, b, update_ab) :: channelHopFromUpdate(b, c, update_bc) :: channelHopFromUpdate(c, d, update_cd) :: ChannelHop(ShortChannelId(5656986L), d, e, null) :: Nil), Set(c, d), Set.empty) + (UnreadableRemoteFailure(defaultAmountMsat, channelHopFromUpdate(a, b, update_ab) :: channelHopFromUpdate(b, c, update_bc) :: channelHopFromUpdate(c, d, update_cd) :: ChannelHop(ShortChannelId(5656986L), d, e, ChannelRelayParams.FromHint(Invoice.BasicEdge(d, e, ShortChannelId(5656986L), 0 msat, 111, CltvExpiryDelta(55)))) :: Nil), Set(c, d), Set.empty) ) for ((failure, expectedNodes, expectedChannels) <- testCases) { @@ -819,7 +820,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { import payFixture._ import cfg._ - val request = SendPaymentToNode(sender.ref, d, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata), 3, routeParams = defaultRouteParams) + val request = SendPaymentToNode(sender.ref, d, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata), defaultAmountMsat, defaultExpiry, 3, routeParams = defaultRouteParams) sender.send(paymentFSM, request) routerForwarder.expectMsgType[RouteRequest] val Transition(_, WAITING_FOR_REQUEST, WAITING_FOR_ROUTE) = monitor.expectMsgClass(classOf[Transition[_]]) @@ -842,7 +843,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { // pre-computed route going from A to D val route = Route(defaultAmountMsat, channelHopFromUpdate(a, b, update_ab) :: channelHopFromUpdate(b, c, update_bc) :: channelHopFromUpdate(c, d, update_cd) :: Nil) - val request = SendPaymentToRoute(sender.ref, Right(route), PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata)) + val request = SendPaymentToRoute(sender.ref, Right(route), PaymentOnion.FinalPayload.Standard.createSinglePartPayload(defaultAmountMsat, defaultExpiry, defaultInvoice.paymentSecret, defaultInvoice.paymentMetadata), defaultAmountMsat, defaultExpiry) sender.send(paymentFSM, request) routerForwarder.expectNoMessage(100 millis) // we don't need the router, we have the pre-computed route val Transition(_, WAITING_FOR_REQUEST, WAITING_FOR_ROUTE) = monitor.expectMsgClass(classOf[Transition[_]]) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentPacketSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentPacketSpec.scala index 6798da1c10..4737d3c91c 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentPacketSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentPacketSpec.scala @@ -64,7 +64,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { def testBuildOnion(): Unit = { val Right(finalPayload) = FinalPayload.Standard.validate(TlvStream(AmountToForward(finalAmount), OutgoingCltv(finalExpiry), PaymentData(paymentSecret, 0 msat))) - val Success((firstAmount, firstExpiry, onion)) = buildPaymentPacket(paymentHash, hops, finalPayload) + val Success((firstAmount, firstExpiry, onion)) = buildPaymentPacket(paymentHash, hops, finalPayload, finalAmount, finalExpiry) assert(firstAmount == amount_ab) assert(firstExpiry == expiry_ab) assert(onion.packet.payload.length == PaymentOnionCodecs.paymentOnionPayloadLength) @@ -110,7 +110,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { assert(payload_e.isInstanceOf[FinalPayload.Standard]) assert(payload_e.amount == finalAmount) assert(payload_e.totalAmount == finalAmount) - assert(payload_e.expiry == finalExpiry) + assert(payload_e.asInstanceOf[FinalPayload.Standard].expiry == finalExpiry) assert(payload_e.asInstanceOf[FinalPayload.Standard].paymentSecret == paymentSecret) } @@ -119,7 +119,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { } test("build a command including the onion") { - val Success((add, _)) = buildCommand(ActorRef.noSender, Upstream.Local(UUID.randomUUID), paymentHash, hops, FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, paymentSecret, None)) + val Success((_, add, _)) = buildCommand(ActorRef.noSender, randomKey(), Upstream.Local(UUID.randomUUID), paymentHash, hops, FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, paymentSecret, None), finalAmount, finalExpiry) assert(add.amount > finalAmount) assert(add.cltvExpiry == finalExpiry + channelUpdate_de.cltvExpiryDelta + channelUpdate_cd.cltvExpiryDelta + channelUpdate_bc.cltvExpiryDelta) assert(add.paymentHash == paymentHash) @@ -130,7 +130,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { } test("build a command with no hops") { - val Success((add, _)) = buildCommand(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, hops.take(1), FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, paymentSecret, Some(paymentMetadata))) + val Success((_, add, _)) = buildCommand(ActorRef.noSender, randomKey(), Upstream.Local(UUID.randomUUID()), paymentHash, hops.take(1), FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, paymentSecret, Some(paymentMetadata)), finalAmount, finalExpiry) assert(add.amount == finalAmount) assert(add.cltvExpiry == finalExpiry) assert(add.paymentHash == paymentHash) @@ -143,7 +143,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { assert(payload_b.isInstanceOf[FinalPayload.Standard]) assert(payload_b.amount == finalAmount) assert(payload_b.totalAmount == finalAmount) - assert(payload_b.expiry == finalExpiry) + assert(payload_b.asInstanceOf[FinalPayload.Standard].expiry == finalExpiry) assert(payload_b.asInstanceOf[FinalPayload.Standard].paymentSecret == paymentSecret) assert(payload_b.asInstanceOf[FinalPayload.Standard].paymentMetadata.contains(paymentMetadata)) } @@ -154,11 +154,11 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { // / \ / \ // a -> b -> c d e - val Success((amount_ac, expiry_ac, trampolineOnion)) = buildTrampolinePacket(paymentHash, trampolineHops, FinalPayload.Standard.createMultiPartPayload(finalAmount, finalAmount * 3, finalExpiry, paymentSecret, Some(hex"010203"))) + val Success((amount_ac, expiry_ac, trampolineOnion)) = buildTrampolinePacket(paymentHash, trampolineHops, FinalPayload.Standard.createMultiPartPayload(finalAmount, finalAmount * 3, finalExpiry, paymentSecret, Some(hex"010203")), finalAmount, finalExpiry) assert(amount_ac == amount_bc) assert(expiry_ac == expiry_bc) - val Success((firstAmount, firstExpiry, onion)) = buildPaymentPacket(paymentHash, trampolineChannelHops, FinalPayload.Standard.createTrampolinePayload(amount_ac, amount_ac, expiry_ac, randomBytes32(), trampolineOnion.packet)) + val Success((firstAmount, firstExpiry, onion)) = buildPaymentPacket(paymentHash, trampolineChannelHops, FinalPayload.Standard.createTrampolinePayload(amount_ac, amount_ac, expiry_ac, randomBytes32(), trampolineOnion.packet), amount_ac, expiry_ac) assert(firstAmount == amount_ab) assert(firstExpiry == expiry_ab) @@ -182,7 +182,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { assert(inner_c.paymentMetadata.isEmpty) // c forwards the trampoline payment to d. - val Success((amount_d, expiry_d, onion_d)) = buildPaymentPacket(paymentHash, channelHopFromUpdate(c, d, channelUpdate_cd) :: Nil, FinalPayload.Standard.createTrampolinePayload(amount_cd, amount_cd, expiry_cd, randomBytes32(), packet_d)) + val Success((amount_d, expiry_d, onion_d)) = buildPaymentPacket(paymentHash, channelHopFromUpdate(c, d, channelUpdate_cd) :: Nil, FinalPayload.Standard.createTrampolinePayload(amount_cd, amount_cd, expiry_cd, randomBytes32(), packet_d), amount_cd, expiry_cd) assert(amount_d == amount_cd) assert(expiry_d == expiry_cd) val add_d = UpdateAddHtlc(randomBytes32(), 3, amount_d, paymentHash, expiry_d, onion_d.packet, None) @@ -200,7 +200,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { assert(inner_d.paymentMetadata.isEmpty) // d forwards the trampoline payment to e. - val Success((amount_e, expiry_e, onion_e)) = buildPaymentPacket(paymentHash, channelHopFromUpdate(d, e, channelUpdate_de) :: Nil, FinalPayload.Standard.createTrampolinePayload(amount_de, amount_de, expiry_de, randomBytes32(), packet_e)) + val Success((amount_e, expiry_e, onion_e)) = buildPaymentPacket(paymentHash, channelHopFromUpdate(d, e, channelUpdate_de) :: Nil, FinalPayload.Standard.createTrampolinePayload(amount_de, amount_de, expiry_de, randomBytes32(), packet_e), amount_de, expiry_de) assert(amount_e == amount_de) assert(expiry_e == expiry_de) val add_e = UpdateAddHtlc(randomBytes32(), 4, amount_e, paymentHash, expiry_e, onion_e.packet, None) @@ -218,11 +218,11 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { val routingHints = List(List(Bolt11Invoice.ExtraHop(randomKey().publicKey, ShortChannelId(42), 10 msat, 100, CltvExpiryDelta(144)))) val invoiceFeatures = Features[InvoiceFeature](VariableLengthOnion -> Mandatory, PaymentSecret -> Mandatory, BasicMultiPartPayment -> Optional) val invoice = Bolt11Invoice(Block.RegtestGenesisBlock.hash, Some(finalAmount), paymentHash, priv_a.privateKey, Left("#reckless"), CltvExpiryDelta(18), None, None, routingHints, features = invoiceFeatures, paymentMetadata = Some(hex"010203")) - val Success((amount_ac, expiry_ac, trampolineOnion)) = buildTrampolineToLegacyPacket(invoice, trampolineHops, FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, invoice.paymentSecret, None)) + val Success((amount_ac, expiry_ac, trampolineOnion)) = buildTrampolineToLegacyPacket(invoice, trampolineHops, FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, invoice.paymentSecret, None), finalAmount, finalExpiry) assert(amount_ac == amount_bc) assert(expiry_ac == expiry_bc) - val Success((firstAmount, firstExpiry, onion)) = buildPaymentPacket(paymentHash, trampolineChannelHops, FinalPayload.Standard.createTrampolinePayload(amount_ac, amount_ac, expiry_ac, randomBytes32(), trampolineOnion.packet)) + val Success((firstAmount, firstExpiry, onion)) = buildPaymentPacket(paymentHash, trampolineChannelHops, FinalPayload.Standard.createTrampolinePayload(amount_ac, amount_ac, expiry_ac, randomBytes32(), trampolineOnion.packet), amount_ac, expiry_ac) assert(firstAmount == amount_ab) assert(firstExpiry == expiry_ab) @@ -243,7 +243,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { assert(inner_c.paymentSecret.isEmpty) // c forwards the trampoline payment to d. - val Success((amount_d, expiry_d, onion_d)) = buildPaymentPacket(paymentHash, channelHopFromUpdate(c, d, channelUpdate_cd) :: Nil, FinalPayload.Standard.createTrampolinePayload(amount_cd, amount_cd, expiry_cd, randomBytes32(), packet_d)) + val Success((amount_d, expiry_d, onion_d)) = buildPaymentPacket(paymentHash, channelHopFromUpdate(c, d, channelUpdate_cd) :: Nil, FinalPayload.Standard.createTrampolinePayload(amount_cd, amount_cd, expiry_cd, randomBytes32(), packet_d), amount_cd, expiry_cd) assert(amount_d == amount_cd) assert(expiry_d == expiry_cd) val add_d = UpdateAddHtlc(randomBytes32(), 3, amount_d, paymentHash, expiry_d, onion_d.packet, None) @@ -265,19 +265,19 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { test("fail to build a trampoline payment when too much invoice data is provided") { val routingHintOverflow = List(List.fill(7)(Bolt11Invoice.ExtraHop(randomKey().publicKey, ShortChannelId(1), 10 msat, 100, CltvExpiryDelta(12)))) val invoice = Bolt11Invoice(Block.RegtestGenesisBlock.hash, Some(finalAmount), paymentHash, priv_a.privateKey, Left("#reckless"), CltvExpiryDelta(18), None, None, routingHintOverflow) - assert(buildTrampolineToLegacyPacket(invoice, trampolineHops, FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, invoice.paymentSecret, invoice.paymentMetadata)).isFailure) + assert(buildTrampolineToLegacyPacket(invoice, trampolineHops, FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, invoice.paymentSecret, invoice.paymentMetadata), finalAmount, finalExpiry).isFailure) } test("fail to decrypt when the onion is invalid") { - val Success((firstAmount, firstExpiry, onion)) = buildPaymentPacket(paymentHash, hops, FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, paymentSecret, None)) + val Success((firstAmount, firstExpiry, onion)) = buildPaymentPacket(paymentHash, hops, FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, paymentSecret, None), finalAmount, finalExpiry) val add = UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet.copy(payload = onion.packet.payload.reverse), None) val Left(failure) = decrypt(add, priv_b.privateKey, Features.empty) assert(failure.isInstanceOf[InvalidOnionHmac]) } test("fail to decrypt when the trampoline onion is invalid") { - val Success((amount_ac, expiry_ac, trampolineOnion)) = buildTrampolinePacket(paymentHash, trampolineHops, FinalPayload.Standard.createMultiPartPayload(finalAmount, finalAmount * 2, finalExpiry, paymentSecret, None)) - val Success((firstAmount, firstExpiry, onion)) = buildPaymentPacket(paymentHash, trampolineChannelHops, FinalPayload.Standard.createTrampolinePayload(amount_ac, amount_ac, expiry_ac, randomBytes32(), trampolineOnion.packet.copy(payload = trampolineOnion.packet.payload.reverse))) + val Success((amount_ac, expiry_ac, trampolineOnion)) = buildTrampolinePacket(paymentHash, trampolineHops, FinalPayload.Standard.createMultiPartPayload(finalAmount, finalAmount * 2, finalExpiry, paymentSecret, None), finalAmount, finalExpiry) + val Success((firstAmount, firstExpiry, onion)) = buildPaymentPacket(paymentHash, trampolineChannelHops, FinalPayload.Standard.createTrampolinePayload(amount_ac, amount_ac, expiry_ac, randomBytes32(), trampolineOnion.packet.copy(payload = trampolineOnion.packet.payload.reverse)), amount_ac, expiry_ac) val add_b = UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet, None) val Right(ChannelRelayPacket(_, _, packet_c)) = decrypt(add_b, priv_b.privateKey, Features.empty) val add_c = UpdateAddHtlc(randomBytes32(), 2, amount_bc, paymentHash, expiry_bc, packet_c, None) @@ -286,59 +286,59 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { } test("fail to decrypt when payment hash doesn't match associated data") { - val Success((firstAmount, firstExpiry, onion)) = buildPaymentPacket(paymentHash.reverse, hops, FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, paymentSecret, None)) + val Success((firstAmount, firstExpiry, onion)) = buildPaymentPacket(paymentHash.reverse, hops, FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, paymentSecret, None), finalAmount, finalExpiry) val add = UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet, None) val Left(failure) = decrypt(add, priv_b.privateKey, Features.empty) assert(failure.isInstanceOf[InvalidOnionHmac]) } test("fail to decrypt at the final node when amount has been modified by next-to-last node") { - val Success((firstAmount, firstExpiry, onion)) = buildPaymentPacket(paymentHash, hops.take(1), FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, paymentSecret, None)) + val Success((firstAmount, firstExpiry, onion)) = buildPaymentPacket(paymentHash, hops.take(1), FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, paymentSecret, None), finalAmount, finalExpiry) val add = UpdateAddHtlc(randomBytes32(), 1, firstAmount - 100.msat, paymentHash, firstExpiry, onion.packet, None) val Left(failure) = decrypt(add, priv_b.privateKey, Features.empty) assert(failure == FinalIncorrectHtlcAmount(firstAmount - 100.msat)) } test("fail to decrypt at the final node when expiry has been modified by next-to-last node") { - val Success((firstAmount, firstExpiry, onion)) = buildPaymentPacket(paymentHash, hops.take(1), FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, paymentSecret, None)) + val Success((firstAmount, firstExpiry, onion)) = buildPaymentPacket(paymentHash, hops.take(1), FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, paymentSecret, None), finalAmount, finalExpiry) val add = UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry - CltvExpiryDelta(12), onion.packet, None) val Left(failure) = decrypt(add, priv_b.privateKey, Features.empty) assert(failure == FinalIncorrectCltvExpiry(firstExpiry - CltvExpiryDelta(12))) } test("fail to decrypt at the final trampoline node when amount has been modified by next-to-last trampoline") { - val Success((amount_ac, expiry_ac, trampolineOnion)) = buildTrampolinePacket(paymentHash, trampolineHops, FinalPayload.Standard.createMultiPartPayload(finalAmount, finalAmount, finalExpiry, paymentSecret, None)) - val Success((firstAmount, firstExpiry, onion)) = buildPaymentPacket(paymentHash, trampolineChannelHops, FinalPayload.Standard.createTrampolinePayload(amount_ac, amount_ac, expiry_ac, randomBytes32(), trampolineOnion.packet)) + val Success((amount_ac, expiry_ac, trampolineOnion)) = buildTrampolinePacket(paymentHash, trampolineHops, FinalPayload.Standard.createMultiPartPayload(finalAmount, finalAmount, finalExpiry, paymentSecret, None), finalAmount, finalExpiry) + val Success((firstAmount, firstExpiry, onion)) = buildPaymentPacket(paymentHash, trampolineChannelHops, FinalPayload.Standard.createTrampolinePayload(amount_ac, amount_ac, expiry_ac, randomBytes32(), trampolineOnion.packet), amount_ac, expiry_ac) val Right(ChannelRelayPacket(_, _, packet_c)) = decrypt(UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet, None), priv_b.privateKey, Features.empty) val Right(NodeRelayPacket(_, _, _, packet_d)) = decrypt(UpdateAddHtlc(randomBytes32(), 2, amount_bc, paymentHash, expiry_bc, packet_c, None), priv_c.privateKey, Features.empty) // c forwards the trampoline payment to d. - val Success((amount_d, expiry_d, onion_d)) = buildPaymentPacket(paymentHash, channelHopFromUpdate(c, d, channelUpdate_cd) :: Nil, FinalPayload.Standard.createTrampolinePayload(amount_cd, amount_cd, expiry_cd, randomBytes32(), packet_d)) + val Success((amount_d, expiry_d, onion_d)) = buildPaymentPacket(paymentHash, channelHopFromUpdate(c, d, channelUpdate_cd) :: Nil, FinalPayload.Standard.createTrampolinePayload(amount_cd, amount_cd, expiry_cd, randomBytes32(), packet_d), amount_cd, expiry_cd) val Right(NodeRelayPacket(_, _, _, packet_e)) = decrypt(UpdateAddHtlc(randomBytes32(), 3, amount_d, paymentHash, expiry_d, onion_d.packet, None), priv_d.privateKey, Features.empty) // d forwards an invalid amount to e (the outer total amount doesn't match the inner amount). val invalidTotalAmount = amount_de + 100.msat - val Success((amount_e, expiry_e, onion_e)) = buildPaymentPacket(paymentHash, channelHopFromUpdate(d, e, channelUpdate_de) :: Nil, FinalPayload.Standard.createTrampolinePayload(amount_de, invalidTotalAmount, expiry_de, randomBytes32(), packet_e)) + val Success((amount_e, expiry_e, onion_e)) = buildPaymentPacket(paymentHash, channelHopFromUpdate(d, e, channelUpdate_de) :: Nil, FinalPayload.Standard.createTrampolinePayload(amount_de, invalidTotalAmount, expiry_de, randomBytes32(), packet_e), amount_de, expiry_de) val Left(failure) = decrypt(UpdateAddHtlc(randomBytes32(), 4, amount_e, paymentHash, expiry_e, onion_e.packet, None), priv_e.privateKey, Features.empty) assert(failure == FinalIncorrectHtlcAmount(invalidTotalAmount)) } test("fail to decrypt at the final trampoline node when expiry has been modified by next-to-last trampoline") { - val Success((amount_ac, expiry_ac, trampolineOnion)) = buildTrampolinePacket(paymentHash, trampolineHops, FinalPayload.Standard.createMultiPartPayload(finalAmount, finalAmount, finalExpiry, paymentSecret, None)) - val Success((firstAmount, firstExpiry, onion)) = buildPaymentPacket(paymentHash, trampolineChannelHops, FinalPayload.Standard.createTrampolinePayload(amount_ac, amount_ac, expiry_ac, randomBytes32(), trampolineOnion.packet)) + val Success((amount_ac, expiry_ac, trampolineOnion)) = buildTrampolinePacket(paymentHash, trampolineHops, FinalPayload.Standard.createMultiPartPayload(finalAmount, finalAmount, finalExpiry, paymentSecret, None), finalAmount, finalExpiry) + val Success((firstAmount, firstExpiry, onion)) = buildPaymentPacket(paymentHash, trampolineChannelHops, FinalPayload.Standard.createTrampolinePayload(amount_ac, amount_ac, expiry_ac, randomBytes32(), trampolineOnion.packet), amount_ac, expiry_ac) val Right(ChannelRelayPacket(_, _, packet_c)) = decrypt(UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet, None), priv_b.privateKey, Features.empty) val Right(NodeRelayPacket(_, _, _, packet_d)) = decrypt(UpdateAddHtlc(randomBytes32(), 2, amount_bc, paymentHash, expiry_bc, packet_c, None), priv_c.privateKey, Features.empty) // c forwards the trampoline payment to d. - val Success((amount_d, expiry_d, onion_d)) = buildPaymentPacket(paymentHash, channelHopFromUpdate(c, d, channelUpdate_cd) :: Nil, FinalPayload.Standard.createTrampolinePayload(amount_cd, amount_cd, expiry_cd, randomBytes32(), packet_d)) + val Success((amount_d, expiry_d, onion_d)) = buildPaymentPacket(paymentHash, channelHopFromUpdate(c, d, channelUpdate_cd) :: Nil, FinalPayload.Standard.createTrampolinePayload(amount_cd, amount_cd, expiry_cd, randomBytes32(), packet_d), amount_cd, expiry_cd) val Right(NodeRelayPacket(_, _, _, packet_e)) = decrypt(UpdateAddHtlc(randomBytes32(), 3, amount_d, paymentHash, expiry_d, onion_d.packet, None), priv_d.privateKey, Features.empty) // d forwards an invalid expiry to e (the outer expiry doesn't match the inner expiry). val invalidExpiry = expiry_de - CltvExpiryDelta(12) - val Success((amount_e, expiry_e, onion_e)) = buildPaymentPacket(paymentHash, channelHopFromUpdate(d, e, channelUpdate_de) :: Nil, FinalPayload.Standard.createTrampolinePayload(amount_de, amount_de, invalidExpiry, randomBytes32(), packet_e)) + val Success((amount_e, expiry_e, onion_e)) = buildPaymentPacket(paymentHash, channelHopFromUpdate(d, e, channelUpdate_de) :: Nil, FinalPayload.Standard.createTrampolinePayload(amount_de, amount_de, invalidExpiry, randomBytes32(), packet_e), amount_de, invalidExpiry) val Left(failure) = decrypt(UpdateAddHtlc(randomBytes32(), 4, amount_e, paymentHash, expiry_e, onion_e.packet, None), priv_e.privateKey, Features.empty) assert(failure == FinalIncorrectCltvExpiry(invalidExpiry)) } test("fail to decrypt at intermediate trampoline node when amount is invalid") { - val Success((amount_ac, expiry_ac, trampolineOnion)) = buildTrampolinePacket(paymentHash, trampolineHops, FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, paymentSecret, None)) - val Success((firstAmount, firstExpiry, onion)) = buildPaymentPacket(paymentHash, trampolineChannelHops, FinalPayload.Standard.createTrampolinePayload(amount_ac, amount_ac, expiry_ac, randomBytes32(), trampolineOnion.packet)) + val Success((amount_ac, expiry_ac, trampolineOnion)) = buildTrampolinePacket(paymentHash, trampolineHops, FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, paymentSecret, None), finalAmount, finalExpiry) + val Success((firstAmount, firstExpiry, onion)) = buildPaymentPacket(paymentHash, trampolineChannelHops, FinalPayload.Standard.createTrampolinePayload(amount_ac, amount_ac, expiry_ac, randomBytes32(), trampolineOnion.packet), amount_ac, expiry_ac) val Right(ChannelRelayPacket(_, _, packet_c)) = decrypt(UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet, None), priv_b.privateKey, Features.empty) // A trampoline relay is very similar to a final node: it can validate that the HTLC amount matches the onion outer amount. val Left(failure) = decrypt(UpdateAddHtlc(randomBytes32(), 2, amount_bc - 100.msat, paymentHash, expiry_bc, packet_c, None), priv_c.privateKey, Features.empty) @@ -346,8 +346,8 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { } test("fail to decrypt at intermediate trampoline node when expiry is invalid") { - val Success((amount_ac, expiry_ac, trampolineOnion)) = buildTrampolinePacket(paymentHash, trampolineHops, FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, paymentSecret, None)) - val Success((firstAmount, firstExpiry, onion)) = buildPaymentPacket(paymentHash, trampolineChannelHops, FinalPayload.Standard.createTrampolinePayload(amount_ac, amount_ac, expiry_ac, randomBytes32(), trampolineOnion.packet)) + val Success((amount_ac, expiry_ac, trampolineOnion)) = buildTrampolinePacket(paymentHash, trampolineHops, FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, paymentSecret, None), finalAmount, finalExpiry) + val Success((firstAmount, firstExpiry, onion)) = buildPaymentPacket(paymentHash, trampolineChannelHops, FinalPayload.Standard.createTrampolinePayload(amount_ac, amount_ac, expiry_ac, randomBytes32(), trampolineOnion.packet), amount_ac, expiry_ac) val Right(ChannelRelayPacket(_, _, packet_c)) = decrypt(UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet, None), priv_b.privateKey, Features.empty) // A trampoline relay is very similar to a final node: it can validate that the HTLC expiry matches the onion outer expiry. val Left(failure) = decrypt(UpdateAddHtlc(randomBytes32(), 2, amount_bc, paymentHash, expiry_bc - CltvExpiryDelta(12), packet_c, None), priv_c.privateKey, Features.empty) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PostRestartHtlcCleanerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PostRestartHtlcCleanerSpec.scala index 92b304967b..c67ac93e9d 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PostRestartHtlcCleanerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PostRestartHtlcCleanerSpec.scala @@ -720,7 +720,7 @@ object PostRestartHtlcCleanerSpec { val (paymentHash1, paymentHash2, paymentHash3) = (Crypto.sha256(preimage1), Crypto.sha256(preimage2), Crypto.sha256(preimage3)) def buildHtlc(htlcId: Long, channelId: ByteVector32, paymentHash: ByteVector32): UpdateAddHtlc = { - val Success((cmd, _)) = buildCommand(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, hops, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, randomBytes32(), None)) + val Success((_, cmd, _)) = buildCommand(ActorRef.noSender, randomKey(), Upstream.Local(UUID.randomUUID()), paymentHash, hops, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, randomBytes32(), None), finalAmount, finalExpiry) UpdateAddHtlc(channelId, htlcId, cmd.amount, cmd.paymentHash, cmd.cltvExpiry, cmd.onion, None) } @@ -729,7 +729,7 @@ object PostRestartHtlcCleanerSpec { def buildHtlcOut(htlcId: Long, channelId: ByteVector32, paymentHash: ByteVector32): DirectedHtlc = OutgoingHtlc(buildHtlc(htlcId, channelId, paymentHash)) def buildFinalHtlc(htlcId: Long, channelId: ByteVector32, paymentHash: ByteVector32): DirectedHtlc = { - val Success((cmd, _)) = buildCommand(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, channelHopFromUpdate(a, TestConstants.Bob.nodeParams.nodeId, channelUpdate_ab) :: Nil, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, randomBytes32(), None)) + val Success((_, cmd, _)) = buildCommand(ActorRef.noSender, randomKey(), Upstream.Local(UUID.randomUUID()), paymentHash, channelHopFromUpdate(a, TestConstants.Bob.nodeParams.nodeId, channelUpdate_ab) :: Nil, PaymentOnion.FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, randomBytes32(), None), finalAmount, finalExpiry) IncomingHtlc(UpdateAddHtlc(channelId, htlcId, cmd.amount, cmd.paymentHash, cmd.cltvExpiry, cmd.onion, None)) } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/NodeRelayerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/NodeRelayerSpec.scala index 97dd0140bc..f401b88508 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/NodeRelayerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/NodeRelayerSpec.scala @@ -726,12 +726,12 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl val outgoingCfg = mockPayFSM.expectMessageType[SendPaymentConfig] validateOutgoingCfg(outgoingCfg, Upstream.Trampoline(incomingMultiPart.map(_.add))) val outgoingPayment = mockPayFSM.expectMessageType[SendMultiPartPayment] - assert(outgoingPayment.paymentSecret == invoice.paymentSecret) // we should use the provided secret - assert(outgoingPayment.paymentMetadata == invoice.paymentMetadata) // we should use the provided metadata + assert(outgoingPayment.finalPayload.records.get[OnionPaymentPayloadTlv.PaymentData].get.secret == invoice.paymentSecret) // we should use the provided secret + assert(outgoingPayment.finalPayload.records.get[OnionPaymentPayloadTlv.PaymentMetadata].map(_.data) == invoice.paymentMetadata) // we should use the provided metadata assert(outgoingPayment.totalAmount == outgoingAmount) assert(outgoingPayment.targetExpiry == outgoingExpiry) assert(outgoingPayment.targetNodeId == outgoingNodeId) - assert(outgoingPayment.additionalTlvs == Nil) + assert(outgoingPayment.finalPayload.records.get[OnionPaymentPayloadTlv.TrampolineOnion].isEmpty) assert(outgoingPayment.extraEdges.head.asInstanceOf[BasicEdge].shortChannelId == ShortChannelId(42)) assert(outgoingPayment.extraEdges.head.asInstanceOf[BasicEdge].sourceNodeId == hints.head.nodeId) assert(outgoingPayment.extraEdges.head.asInstanceOf[BasicEdge].targetNodeId == outgoingNodeId) @@ -774,8 +774,8 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl validateOutgoingCfg(outgoingCfg, Upstream.Trampoline(incomingMultiPart.map(_.add))) val outgoingPayment = mockPayFSM.expectMessageType[SendPaymentToNode] assert(outgoingPayment.finalPayload.isInstanceOf[FinalPayload.Standard]) - assert(outgoingPayment.finalPayload.amount == outgoingAmount) - assert(outgoingPayment.finalPayload.expiry == outgoingExpiry) + assert(outgoingPayment.finalPayload.asInstanceOf[FinalPayload.Standard].amount == outgoingAmount) + assert(outgoingPayment.finalPayload.asInstanceOf[FinalPayload.Standard].expiry == outgoingExpiry) assert(outgoingPayment.finalPayload.asInstanceOf[FinalPayload.Standard].paymentMetadata == invoice.paymentMetadata) // we should use the provided metadata assert(outgoingPayment.targetNodeId == outgoingNodeId) assert(outgoingPayment.extraEdges.head.asInstanceOf[BasicEdge].shortChannelId == ShortChannelId(42)) @@ -835,11 +835,11 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl } def validateOutgoingPayment(outgoingPayment: SendMultiPartPayment): Unit = { - assert(outgoingPayment.paymentSecret !== incomingSecret) // we should generate a new outgoing secret + assert(outgoingPayment.finalPayload.records.get[OnionPaymentPayloadTlv.PaymentData].get.secret !== incomingSecret) // we should generate a new outgoing secret assert(outgoingPayment.totalAmount == outgoingAmount) assert(outgoingPayment.targetExpiry == outgoingExpiry) assert(outgoingPayment.targetNodeId == outgoingNodeId) - assert(outgoingPayment.additionalTlvs == Seq(OnionPaymentPayloadTlv.TrampolineOnion(nextTrampolinePacket))) + assert(outgoingPayment.finalPayload.records.get[OnionPaymentPayloadTlv.TrampolineOnion].get.packet == nextTrampolinePacket) assert(outgoingPayment.extraEdges == Nil) } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/RelayerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/RelayerSpec.scala index d06d3f55fd..def22a89f1 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/RelayerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/RelayerSpec.scala @@ -88,7 +88,7 @@ class RelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("applicat } // we use this to build a valid onion - val Success((cmd, _)) = buildCommand(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, hops, FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, paymentSecret, None)) + val Success((_, cmd, _)) = buildCommand(ActorRef.noSender, randomKey(), Upstream.Local(UUID.randomUUID()), paymentHash, hops, FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, paymentSecret, None), finalAmount, finalExpiry) // and then manually build an htlc val add_ab = UpdateAddHtlc(channelId = randomBytes32(), id = 123456, cmd.amount, cmd.paymentHash, cmd.cltvExpiry, cmd.onion, None) relayer ! RelayForward(add_ab) @@ -98,7 +98,7 @@ class RelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("applicat test("relay an htlc-add at the final node to the payment handler") { f => import f._ - val Success((cmd, _)) = buildCommand(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, hops.take(1), FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, paymentSecret, None)) + val Success((_, cmd, _)) = buildCommand(ActorRef.noSender, randomKey(), Upstream.Local(UUID.randomUUID()), paymentHash, hops.take(1), FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, paymentSecret, None), finalAmount, finalExpiry) val add_ab = UpdateAddHtlc(channelId = channelId_ab, id = 123456, cmd.amount, cmd.paymentHash, cmd.cltvExpiry, cmd.onion, None) relayer ! RelayForward(add_ab) @@ -118,10 +118,10 @@ class RelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("applicat // We simulate a payment split between multiple trampoline routes. val totalAmount = finalAmount * 3 val trampolineHops = NodeHop(a, b, channelUpdate_ab.cltvExpiryDelta, 0 msat) :: Nil - val Success((trampolineAmount, trampolineExpiry, trampolineOnion)) = OutgoingPaymentPacket.buildTrampolinePacket(paymentHash, trampolineHops, FinalPayload.Standard.createMultiPartPayload(finalAmount, totalAmount, finalExpiry, paymentSecret, None)) + val Success((trampolineAmount, trampolineExpiry, trampolineOnion)) = OutgoingPaymentPacket.buildTrampolinePacket(paymentHash, trampolineHops, FinalPayload.Standard.createMultiPartPayload(finalAmount, totalAmount, finalExpiry, paymentSecret, None), finalAmount, finalExpiry) assert(trampolineAmount == finalAmount) assert(trampolineExpiry == finalExpiry) - val Success((cmd, _)) = buildCommand(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, channelHopFromUpdate(a, b, channelUpdate_ab) :: Nil, FinalPayload.Standard.createTrampolinePayload(trampolineAmount, trampolineAmount, trampolineExpiry, randomBytes32(), trampolineOnion.packet)) + val Success((_, cmd, _)) = buildCommand(ActorRef.noSender, randomKey(), Upstream.Local(UUID.randomUUID()), paymentHash, channelHopFromUpdate(a, b, channelUpdate_ab) :: Nil, FinalPayload.Standard.createTrampolinePayload(trampolineAmount, trampolineAmount, trampolineExpiry, randomBytes32(), trampolineOnion.packet), trampolineAmount, trampolineExpiry) assert(cmd.amount == finalAmount) assert(cmd.cltvExpiry == finalExpiry) val add_ab = UpdateAddHtlc(channelId = channelId_ab, id = 123456, cmd.amount, cmd.paymentHash, cmd.cltvExpiry, cmd.onion, None) @@ -133,7 +133,7 @@ class RelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("applicat assert(fp.payload.isInstanceOf[FinalPayload.Standard]) assert(fp.payload.amount == finalAmount) assert(fp.payload.totalAmount == totalAmount) - assert(fp.payload.expiry == finalExpiry) + assert(fp.payload.asInstanceOf[FinalPayload.Standard].expiry == finalExpiry) assert(fp.payload.asInstanceOf[FinalPayload.Standard].paymentSecret == paymentSecret) register.expectNoMessage(50 millis) @@ -143,7 +143,7 @@ class RelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("applicat import f._ // we use this to build a valid onion - val Success((cmd, _)) = buildCommand(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, hops, FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, paymentSecret, None)) + val Success((_, cmd, _)) = buildCommand(ActorRef.noSender, randomKey(), Upstream.Local(UUID.randomUUID()), paymentHash, hops, FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, paymentSecret, None), finalAmount, finalExpiry) // and then manually build an htlc with an invalid onion (hmac) val add_ab = UpdateAddHtlc(channelId = channelId_ab, id = 123456, cmd.amount, cmd.paymentHash, cmd.cltvExpiry, cmd.onion.copy(hmac = cmd.onion.hmac.reverse), None) @@ -164,8 +164,8 @@ class RelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("applicat // we use this to build a valid trampoline onion inside a normal onion val trampolineHops = NodeHop(a, b, channelUpdate_ab.cltvExpiryDelta, 0 msat) :: NodeHop(b, c, channelUpdate_bc.cltvExpiryDelta, fee_b) :: Nil - val Success((trampolineAmount, trampolineExpiry, trampolineOnion)) = OutgoingPaymentPacket.buildTrampolinePacket(paymentHash, trampolineHops, FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, paymentSecret, None)) - val Success((cmd, _)) = buildCommand(ActorRef.noSender, Upstream.Local(UUID.randomUUID()), paymentHash, channelHopFromUpdate(a, b, channelUpdate_ab) :: Nil, FinalPayload.Standard.createTrampolinePayload(trampolineAmount, trampolineAmount, trampolineExpiry, randomBytes32(), trampolineOnion.packet)) + val Success((trampolineAmount, trampolineExpiry, trampolineOnion)) = OutgoingPaymentPacket.buildTrampolinePacket(paymentHash, trampolineHops, FinalPayload.Standard.createSinglePartPayload(finalAmount, finalExpiry, paymentSecret, None), finalAmount, finalExpiry) + val Success((_, cmd, _)) = buildCommand(ActorRef.noSender, randomKey(), Upstream.Local(UUID.randomUUID()), paymentHash, channelHopFromUpdate(a, b, channelUpdate_ab) :: Nil, FinalPayload.Standard.createTrampolinePayload(trampolineAmount, trampolineAmount, trampolineExpiry, randomBytes32(), trampolineOnion.packet), trampolineAmount, trampolineExpiry) // and then manually build an htlc val add_ab = UpdateAddHtlc(channelId = channelId_ab, id = 123456, cmd.amount, cmd.paymentHash, cmd.cltvExpiry, cmd.onion, None) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/router/RouteCalculationSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/router/RouteCalculationSpec.scala index 3783965545..c16ea57f3f 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/router/RouteCalculationSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/router/RouteCalculationSpec.scala @@ -505,12 +505,12 @@ class RouteCalculationSpec extends AnyFunSuite with ParallelTestExecution { val Success(route1 :: Nil) = findRoute(g, a, e, DEFAULT_AMOUNT_MSAT, DEFAULT_MAX_FEE, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = BlockHeight(400000)) assert(route2Ids(route1) == 1 :: 2 :: 3 :: 4 :: Nil) - assert(route1.hops(1).params.relayFees.feeBase == 10.msat) + assert(route1.hops(1).asInstanceOf[ChannelHop].params.relayFees.feeBase == 10.msat) val extraGraphEdges = Set(makeEdge(2L, b, c, 5 msat, 5)) val Success(route2 :: Nil) = findRoute(g, a, e, DEFAULT_AMOUNT_MSAT, DEFAULT_MAX_FEE, numRoutes = 1, extraEdges = extraGraphEdges, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = BlockHeight(400000)) assert(route2Ids(route2) == 1 :: 2 :: 3 :: 4 :: Nil) - assert(route2.hops(1).params.relayFees.feeBase == 5.msat) + assert(route2.hops(1).asInstanceOf[ChannelHop].params.relayFees.feeBase == 5.msat) } test("compute ignored channels") { @@ -1959,17 +1959,17 @@ object RouteCalculationSpec { def hops2Ids(hops: Seq[ChannelHop]): Seq[Long] = hops.map(hop => hop.shortChannelId.toLong) - def route2Ids(route: Route): Seq[Long] = hops2Ids(route.hops) + def route2Ids(route: Route): Seq[Long] = hops2Ids(route.hops.map(_.asInstanceOf[ChannelHop])) def routes2Ids(routes: Seq[Route]): Set[Seq[Long]] = routes.map(route2Ids).toSet - def route2Edges(route: Route): Seq[GraphEdge] = route.hops.map(hop => GraphEdge(ChannelDesc(hop.shortChannelId, hop.nodeId, hop.nextNodeId), hop.params, 0 sat, None)) + def route2Edges(route: Route): Seq[GraphEdge] = route.hops.map(hop => GraphEdge(ChannelDesc(hop.asInstanceOf[ChannelHop].shortChannelId, hop.nodeId, hop.nextNodeId), hop.asInstanceOf[ChannelHop].params, 0 sat, None)) def route2Nodes(route: Route): Seq[(PublicKey, PublicKey)] = route.hops.map(hop => (hop.nodeId, hop.nextNodeId)) def checkIgnoredChannels(routes: Seq[Route], shortChannelIds: Long*): Unit = { shortChannelIds.foreach(shortChannelId => routes.foreach(route => { - assert(route.hops.forall(_.shortChannelId.toLong != shortChannelId), route) + assert(route.hops.forall(_.asInstanceOf[ChannelHop].shortChannelId.toLong != shortChannelId), route) })) } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/router/RouterSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/router/RouterSpec.scala index 0ae271a686..6369e05992 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/router/RouterSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/router/RouterSpec.scala @@ -526,7 +526,7 @@ class RouterSpec extends BaseRouterSpec { // the route hasn't changed (nodes are the same) assert(response.routes.head.hops.map(_.nodeId) == preComputedRoute.nodes.dropRight(1)) assert(response.routes.head.hops.map(_.nextNodeId) == preComputedRoute.nodes.drop(1)) - assert(response.routes.head.hops.map(_.params) == Seq(ChannelRelayParams.FromAnnouncement(update_ab), ChannelRelayParams.FromAnnouncement(update_bc), ChannelRelayParams.FromAnnouncement(update_cd))) + assert(response.routes.head.hops.map(_.asInstanceOf[ChannelHop].params) == Seq(ChannelRelayParams.FromAnnouncement(update_ab), ChannelRelayParams.FromAnnouncement(update_bc), ChannelRelayParams.FromAnnouncement(update_cd))) } test("given a pre-defined channels route add the proper channel updates") { fixture => @@ -540,7 +540,7 @@ class RouterSpec extends BaseRouterSpec { // the route hasn't changed (nodes are the same) assert(response.routes.head.hops.map(_.nodeId) == Seq(a, b, c)) assert(response.routes.head.hops.map(_.nextNodeId) == Seq(b, c, d)) - assert(response.routes.head.hops.map(_.params) == Seq(ChannelRelayParams.FromAnnouncement(update_ab), ChannelRelayParams.FromAnnouncement(update_bc), ChannelRelayParams.FromAnnouncement(update_cd))) + assert(response.routes.head.hops.map(_.asInstanceOf[ChannelHop].params) == Seq(ChannelRelayParams.FromAnnouncement(update_ab), ChannelRelayParams.FromAnnouncement(update_bc), ChannelRelayParams.FromAnnouncement(update_cd))) } test("given a pre-defined private channels route add the proper channel updates") { fixture => @@ -554,7 +554,7 @@ class RouterSpec extends BaseRouterSpec { val response = sender.expectMsgType[RouteResponse] assert(response.routes.length == 1) val route = response.routes.head - assert(route.hops.map(_.params) == Seq(ChannelRelayParams.FromAnnouncement(update_ag_private))) + assert(route.hops.map(_.asInstanceOf[ChannelHop].params) == Seq(ChannelRelayParams.FromAnnouncement(update_ag_private))) assert(route.hops.head.nodeId == a) assert(route.hops.head.nextNodeId == g) } @@ -565,7 +565,7 @@ class RouterSpec extends BaseRouterSpec { val response = sender.expectMsgType[RouteResponse] assert(response.routes.length == 1) val route = response.routes.head - assert(route.hops.map(_.params) == Seq(ChannelRelayParams.FromAnnouncement(update_ag_private))) + assert(route.hops.map(_.asInstanceOf[ChannelHop].params) == Seq(ChannelRelayParams.FromAnnouncement(update_ag_private))) assert(route.hops.head.nodeId == a) assert(route.hops.head.nextNodeId == g) } @@ -577,7 +577,7 @@ class RouterSpec extends BaseRouterSpec { val route = response.routes.head assert(route.hops.map(_.nodeId) == Seq(a, g)) assert(route.hops.map(_.nextNodeId) == Seq(g, h)) - assert(route.hops.map(_.params) == Seq(ChannelRelayParams.FromAnnouncement(update_ag_private), ChannelRelayParams.FromAnnouncement(update_gh))) + assert(route.hops.map(_.asInstanceOf[ChannelHop].params) == Seq(ChannelRelayParams.FromAnnouncement(update_ag_private), ChannelRelayParams.FromAnnouncement(update_gh))) } } @@ -598,8 +598,8 @@ class RouterSpec extends BaseRouterSpec { val route = response.routes.head assert(route.hops.map(_.nodeId) == Seq(a, b)) assert(route.hops.map(_.nextNodeId) == Seq(b, targetNodeId)) - assert(route.hops.head.params == ChannelRelayParams.FromAnnouncement(update_ab)) - assert(route.hops.last.params == ChannelRelayParams.FromHint(invoiceRoutingHint)) + assert(route.hops.head.asInstanceOf[ChannelHop].params == ChannelRelayParams.FromAnnouncement(update_ab)) + assert(route.hops.last.asInstanceOf[ChannelHop].params == ChannelRelayParams.FromHint(invoiceRoutingHint)) } { val invoiceRoutingHint = Invoice.BasicEdge(h, targetNodeId, RealShortChannelId(BlockHeight(420000), 516, 1105), 10 msat, 150, CltvExpiryDelta(96)) @@ -613,8 +613,8 @@ class RouterSpec extends BaseRouterSpec { val route = response.routes.head assert(route.hops.map(_.nodeId) == Seq(a, g, h)) assert(route.hops.map(_.nextNodeId) == Seq(g, h, targetNodeId)) - assert(route.hops.map(_.params).dropRight(1) == Seq(ChannelRelayParams.FromAnnouncement(update_ag_private), ChannelRelayParams.FromAnnouncement(update_gh))) - assert(route.hops.last.params == ChannelRelayParams.FromHint(invoiceRoutingHint)) + assert(route.hops.map(_.asInstanceOf[ChannelHop].params).dropRight(1) == Seq(ChannelRelayParams.FromAnnouncement(update_ag_private), ChannelRelayParams.FromAnnouncement(update_gh))) + assert(route.hops.last.asInstanceOf[ChannelHop].params == ChannelRelayParams.FromHint(invoiceRoutingHint)) } } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/PaymentOnionSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/PaymentOnionSpec.scala index d025c876b6..a5e69602ad 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/PaymentOnionSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/PaymentOnionSpec.scala @@ -201,9 +201,9 @@ class PaymentOnionSpec extends AnyFunSuite { RouteBlindingEncryptedDataTlv.PaymentConstraints(CltvExpiry(1500), 1 msat), ) val testCases = Map( - TlvStream[OnionPaymentPayloadTlv](AmountToForward(561 msat), OutgoingCltv(CltvExpiry(42)), EncryptedRecipientData(hex"deadbeef")) -> hex"0d 02020231 04012a 0a04deadbeef", - TlvStream[OnionPaymentPayloadTlv](AmountToForward(561 msat), OutgoingCltv(CltvExpiry(42)), EncryptedRecipientData(hex"deadbeef"), BlindingPoint(PublicKey(hex"036d6caac248af96f6afa7f904f550253a0f3ef3f5aa2fe6838a95b216691468e2"))) -> hex"30 02020231 04012a 0a04deadbeef 0c21036d6caac248af96f6afa7f904f550253a0f3ef3f5aa2fe6838a95b216691468e2", - TlvStream[OnionPaymentPayloadTlv](AmountToForward(561 msat), OutgoingCltv(CltvExpiry(42)), EncryptedRecipientData(hex"deadbeef"), BlindingPoint(PublicKey(hex"036d6caac248af96f6afa7f904f550253a0f3ef3f5aa2fe6838a95b216691468e2")), TotalAmount(1105 msat)) -> hex"34 02020231 04012a 0a04deadbeef 0c21036d6caac248af96f6afa7f904f550253a0f3ef3f5aa2fe6838a95b216691468e2 12020451", + TlvStream[OnionPaymentPayloadTlv](AmountToForward(561 msat), EncryptedRecipientData(hex"deadbeef")) -> hex"0a 02020231 0a04deadbeef", + TlvStream[OnionPaymentPayloadTlv](AmountToForward(561 msat), EncryptedRecipientData(hex"deadbeef"), BlindingPoint(PublicKey(hex"036d6caac248af96f6afa7f904f550253a0f3ef3f5aa2fe6838a95b216691468e2"))) -> hex"2d 02020231 0a04deadbeef 0c21036d6caac248af96f6afa7f904f550253a0f3ef3f5aa2fe6838a95b216691468e2", + TlvStream[OnionPaymentPayloadTlv](AmountToForward(561 msat), EncryptedRecipientData(hex"deadbeef"), BlindingPoint(PublicKey(hex"036d6caac248af96f6afa7f904f550253a0f3ef3f5aa2fe6838a95b216691468e2")), TotalAmount(1105 msat)) -> hex"31 02020231 0a04deadbeef 0c21036d6caac248af96f6afa7f904f550253a0f3ef3f5aa2fe6838a95b216691468e2 12020451", ) for ((expected, bin) <- testCases) { @@ -211,7 +211,6 @@ class PaymentOnionSpec extends AnyFunSuite { assert(decoded == expected) val Right(payload) = FinalPayload.Blinded.validate(decoded, blindedTlvs) assert(payload.amount == 561.msat) - assert(payload.expiry == CltvExpiry(42)) assert(payload.pathId == hex"2a2a2a2a") val encoded = perHopPayloadCodec.encode(expected).require.bytes assert(encoded == bin) @@ -312,13 +311,13 @@ class PaymentOnionSpec extends AnyFunSuite { RouteBlindingEncryptedDataTlv.PaymentConstraints(CltvExpiry(1500), 1 msat), ) val testCases = Seq( - (MissingRequiredTlv(UInt64(2)), hex"0d 04012a 0a080123456789abcdef"), // missing amount - (MissingRequiredTlv(UInt64(4)), hex"0e 02020231 0a080123456789abcdef"), // missing expiry - (MissingRequiredTlv(UInt64(10)), hex"07 02020231 04012a"), // missing encrypted data - (ForbiddenTlv(UInt64(0)), hex"1b 02020231 04012a 06080000000000000451 0a080123456789abcdef"), // forbidden outgoing_channel_id - (ForbiddenTlv(UInt64(0)), hex"35 02020231 04012a 0822eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f2836866190451 0a080123456789abcdef"), // forbidden payment_data - (ForbiddenTlv(UInt64(0)), hex"17 02020231 04012a 0a080123456789abcdef 1004deadbeef"), // forbidden payment_metadata - (ForbiddenTlv(UInt64(65535)), hex"17 02020231 04012a 0a080123456789abcdef fdffff0206c1"), // forbidden unknown tlv + (MissingRequiredTlv(UInt64(2)), hex"0a 0a080123456789abcdef"), // missing amount + (MissingRequiredTlv(UInt64(10)), hex"04 02020231"), // missing encrypted data + (ForbiddenTlv(UInt64(0)), hex"11 02020231 04012a 0a080123456789abcdef"), // forbidden expiry + (ForbiddenTlv(UInt64(0)), hex"18 02020231 06080000000000000451 0a080123456789abcdef"), // forbidden outgoing_channel_id + (ForbiddenTlv(UInt64(0)), hex"32 02020231 0822eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f2836866190451 0a080123456789abcdef"), // forbidden payment_data + (ForbiddenTlv(UInt64(0)), hex"14 02020231 0a080123456789abcdef 1004deadbeef"), // forbidden payment_metadata + (ForbiddenTlv(UInt64(65535)), hex"14 02020231 0a080123456789abcdef fdffff0206c1"), // forbidden unknown tlv ) for ((expectedErr, bin) <- testCases) { From d2ed2b0334826453beddddddabe5fe543ae00701 Mon Sep 17 00:00:00 2001 From: Thomas HUET Date: Wed, 26 Oct 2022 15:48:25 +0200 Subject: [PATCH 2/3] Fix payments to blind --- .../acinq/eclair/payment/send/PaymentLifecycle.scala | 4 ++++ .../main/scala/fr/acinq/eclair/router/Graph.scala | 11 +++++++++-- .../fr/acinq/eclair/router/RouteCalculation.scala | 12 ++++++++++-- .../main/scala/fr/acinq/eclair/router/Router.scala | 7 +++++++ .../acinq/eclair/payment/MultiPartHandlerSpec.scala | 2 ++ .../acinq/eclair/router/RouteCalculationSpec.scala | 4 ++-- 6 files changed, 34 insertions(+), 6 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentLifecycle.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentLifecycle.scala index 6f08fb84bc..235112a70d 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentLifecycle.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentLifecycle.scala @@ -298,6 +298,10 @@ class PaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, router: A // we remove this edge for our next payment attempt data.c.extraEdges.filterNot(edge => edge.sourceNodeId == nodeId && edge.targetNodeId == hop.nextNodeId) } + case _: ChannelRelayParams.FromBlindedPath => + // Should not be reachable + log.error("received an update for a blinded route (shortChannelId={} nodeId={} enabled={} update={})", failure.update.shortChannelId, nodeId, failure.update.channelFlags.isEnabled, failure.update) + data.c.extraEdges } case Some(_: BlindedHop) => log.error(s"received update for blinded route, this should never happen") diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/Graph.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/Graph.scala index d03cba33db..a6f0145bda 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/router/Graph.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/Graph.scala @@ -463,7 +463,10 @@ object Graph { val maxBtc = 21e6.btc GraphEdge( desc = ChannelDesc(e.shortChannelId, e.sourceNodeId, e.targetNodeId), - params = ChannelRelayParams.FromHint(e), + params = e match { + case e: Invoice.BasicEdge => ChannelRelayParams.FromHint(e) + case e: Invoice.BlindedEdge => ChannelRelayParams.FromBlindedPath(e.path, e.payInfo) + }, // Bolt 11 routing hints don't include the channel's capacity, so we assume it's big enough capacity = maxBtc.toSatoshi, // we assume channels provided as hints have enough balance to handle the payment @@ -640,7 +643,11 @@ object Graph { new DirectedGraph(mutableMap.toMap) } - def graphEdgeToHop(graphEdge: GraphEdge): ChannelHop = ChannelHop(graphEdge.desc.shortChannelId, graphEdge.desc.a, graphEdge.desc.b, graphEdge.params) + def graphEdgeToHop(graphEdge: GraphEdge): ConnectedHop = + graphEdge.params match { + case blind: ChannelRelayParams.FromBlindedPath => BlindedHop(blind.path, blind.paymentInfo, blind.path.blindedNodeIds.last) + case params => ChannelHop(graphEdge.desc.shortChannelId, graphEdge.desc.a, graphEdge.desc.b, params) + } } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/RouteCalculation.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/RouteCalculation.scala index 6e358f7daa..b8100b679a 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/router/RouteCalculation.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/RouteCalculation.scala @@ -22,6 +22,7 @@ import com.softwaremill.quicklens.ModifyPimp import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.eclair.Logs.LogCategory import fr.acinq.eclair._ +import fr.acinq.eclair.payment.Invoice import fr.acinq.eclair.router.Graph.GraphStructure.DirectedGraph.graphEdgeToHop import fr.acinq.eclair.router.Graph.GraphStructure.{DirectedGraph, GraphEdge} import fr.acinq.eclair.router.Graph.{InfiniteLoop, NegativeProbability, RichWeight} @@ -48,7 +49,11 @@ object RouteCalculation { fr.route match { case PredefinedNodeRoute(hops) => // split into sublists [(a,b),(b,c), ...] then get the edges between each of those pairs - hops.sliding(2).map { case List(v1, v2) => g.getEdgesBetween(v1, v2) }.toList match { + hops.sliding(2).map { + case List(v1, v2) if v1 == localNodeId && v2 == localNodeId => + Seq(GraphEdge(Invoice.BasicEdge(localNodeId, localNodeId, ShortChannelId.toSelf, 0 msat, 0, CltvExpiryDelta(0)))) + case List(v1, v2) => g.getEdgesBetween(v1, v2) + }.toList match { case edges if edges.nonEmpty && edges.forall(_.nonEmpty) => // select the largest edge (using balance when available, otherwise capacity). val selectedEdges = edges.map(es => es.maxBy(e => e.balance_opt.getOrElse(e.capacity.toMilliSatoshi))) @@ -98,7 +103,10 @@ object RouteCalculation { paymentHash_opt = r.paymentContext.map(_.paymentHash))) { implicit val sender: ActorRef = ctx.self // necessary to preserve origin when sending messages to other actors - val extraEdges = r.extraEdges.map(GraphEdge(_)).filterNot(_.desc.a == r.source).toSet // we ignore routing hints for our own channels, we have more accurate information + val extraEdges = r.extraEdges.filter { + case edge: Invoice.BasicEdge => edge.sourceNodeId != r.source // we ignore routing hints for our own channels, we have more accurate information + case _: Invoice.BlindedEdge => true + }.map(GraphEdge(_)).toSet val ignoredEdges = r.ignore.channels ++ d.excludedChannels val params = r.routeParams val routesToFind = if (params.randomize) DEFAULT_ROUTES_COUNT else 1 diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala index 314d65cc19..3f11acf821 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala @@ -453,6 +453,13 @@ object Router { override def htlcMinimum: MilliSatoshi = extraHop.htlcMinimum override def htlcMaximum_opt: Option[MilliSatoshi] = extraHop.htlcMaximum_opt } + /** It's a blinded route we learnt about from an invoice */ + case class FromBlindedPath(path: BlindedRoute, paymentInfo: PaymentInfo) extends ChannelRelayParams { + override def cltvExpiryDelta: CltvExpiryDelta = paymentInfo.cltvExpiryDelta + override def relayFees: Relayer.RelayFees = Relayer.RelayFees(paymentInfo.feeBase, paymentInfo.feeProportionalMillionths) + override def htlcMinimum: MilliSatoshi = paymentInfo.minHtlc + override def htlcMaximum_opt: Option[MilliSatoshi] = Some(paymentInfo.maxHtlc) + } def areSame(a: ChannelRelayParams, b: ChannelRelayParams, ignoreHtlcSize: Boolean = false): Boolean = a.cltvExpiryDelta == b.cltvExpiryDelta && diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartHandlerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartHandlerSpec.scala index e8318b6b29..ea09b51b4a 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartHandlerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartHandlerSpec.scala @@ -30,6 +30,8 @@ import fr.acinq.eclair.payment.PaymentReceived.PartialPayment import fr.acinq.eclair.payment.receive.MultiPartHandler._ import fr.acinq.eclair.payment.receive.MultiPartPaymentFSM.HtlcPart import fr.acinq.eclair.payment.receive.{MultiPartPaymentFSM, PaymentHandler} +import fr.acinq.eclair.router.Router +import fr.acinq.eclair.router.Router.RouteResponse import fr.acinq.eclair.wire.protocol.OfferTypes.{InvoiceRequest, Offer} import fr.acinq.eclair.wire.protocol.OnionPaymentPayloadTlv.{AmountToForward, EncryptedRecipientData, OutgoingCltv} import fr.acinq.eclair.wire.protocol.PaymentOnion.FinalPayload diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/router/RouteCalculationSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/router/RouteCalculationSpec.scala index c16ea57f3f..54bb5ba930 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/router/RouteCalculationSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/router/RouteCalculationSpec.scala @@ -1957,9 +1957,9 @@ object RouteCalculationSpec { htlcMaximumMsat = maxHtlc.getOrElse(500_000_000 msat) ) - def hops2Ids(hops: Seq[ChannelHop]): Seq[Long] = hops.map(hop => hop.shortChannelId.toLong) + def hops2Ids(hops: Seq[ConnectedHop]): Seq[Long] = hops.map(hop => hop.asInstanceOf[ChannelHop].shortChannelId.toLong) - def route2Ids(route: Route): Seq[Long] = hops2Ids(route.hops.map(_.asInstanceOf[ChannelHop])) + def route2Ids(route: Route): Seq[Long] = hops2Ids(route.hops) def routes2Ids(routes: Seq[Route]): Set[Seq[Long]] = routes.map(route2Ids).toSet From 4c2e4156e65228bdef1b58d99f5cf66c8216986f Mon Sep 17 00:00:00 2001 From: Thomas HUET Date: Thu, 3 Nov 2022 12:18:35 +0100 Subject: [PATCH 3/3] cosmetic changes --- .../acinq/eclair/payment/PaymentPacket.scala | 8 ++++++- .../send/MultiPartPaymentLifecycle.scala | 23 ++++++++----------- .../payment/send/PaymentLifecycle.scala | 4 ++++ .../scala/fr/acinq/eclair/router/Router.scala | 2 ++ .../eclair/wire/protocol/PaymentOnion.scala | 1 + 5 files changed, 24 insertions(+), 14 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentPacket.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentPacket.scala index c6ab330b29..9b3541a32e 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentPacket.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentPacket.scala @@ -252,6 +252,8 @@ object OutgoingPaymentPacket { * * @param hops the hops as computed by the router + extra routes from the invoice * @param finalPayload payload data for the final node (amount, expiry, etc) + * @param lastAmount amount for the final node + * @param lastExpiry expiry for the final node * @return a (firstAmount, firstExpiry, payloads) tuple where: * - firstAmount is the amount for the first htlc in the route * - firstExpiry is the cltv expiry for the first htlc in the route @@ -281,6 +283,8 @@ object OutgoingPaymentPacket { * * @param hops the hops as computed by the router + extra routes from the invoice, including ourselves in the first hop * @param finalPayload payload data for the final node (amount, expiry, etc) + * @param amount amount for the final node + * @param expiry expiry for the final node * @return a (firstAmount, firstExpiry, onion) tuple where: * - firstAmount is the amount for the first htlc in the route * - firstExpiry is the cltv expiry for the first htlc in the route @@ -310,6 +314,8 @@ object OutgoingPaymentPacket { * @param invoice Bolt 11 invoice (features and routing hints will be provided to the next-to-last node). * @param hops the trampoline hops (including ourselves in the first hop, and the non-trampoline final recipient in the last hop). * @param finalPayload payload data for the final node (amount, expiry, etc) + * @param amount amount for the final node + * @param expiry expiry for the final node * @return a (firstAmount, firstExpiry, onion) tuple where: * - firstAmount is the amount for the trampoline node in the route * - firstExpiry is the cltv expiry for the first trampoline node in the route @@ -347,7 +353,7 @@ object OutgoingPaymentPacket { /** * Build the command to add an HTLC with the given final payload and using the provided hops. * - * @return the command and the onion shared secrets (used to decrypt the error in case of payment failure) + * @return the channel id to send the command to, the command and the onion shared secrets (used to decrypt the error in case of payment failure) */ def buildCommand(replyTo: ActorRef, privateKey: PrivateKey, diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/MultiPartPaymentLifecycle.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/MultiPartPaymentLifecycle.scala index 6d758c1b11..927f1b0fa5 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/MultiPartPaymentLifecycle.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/MultiPartPaymentLifecycle.scala @@ -30,7 +30,7 @@ import fr.acinq.eclair.payment._ import fr.acinq.eclair.payment.send.PaymentInitiator.SendPaymentConfig import fr.acinq.eclair.payment.send.PaymentLifecycle.SendPaymentToRoute import fr.acinq.eclair.router.Router._ -import fr.acinq.eclair.wire.protocol.PaymentOnion.FinalPayload.Partial +import fr.acinq.eclair.wire.protocol.PaymentOnion.FinalPayload import fr.acinq.eclair.{CltvExpiry, FSMDiagnosticActorLogging, Logs, MilliSatoshi, MilliSatoshiLong, NodeParams, TimestampMilli} import java.util.UUID @@ -302,21 +302,18 @@ object MultiPartPaymentLifecycle { * Send a payment to a given node. The payment may be split into multiple child payments, for which a path-finding * algorithm will run to find suitable payment routes. * - * @param paymentSecret payment secret to protect against probing (usually from a Bolt 11 invoice). - * @param targetNodeId target node (may be the final recipient when using source-routing, or the first trampoline - * node when using trampoline). - * @param totalAmount total amount to send to the target node. - * @param targetExpiry expiry at the target node (CLTV for the target node's received HTLCs). - * @param maxAttempts maximum number of retries. - * @param paymentMetadata payment metadata (usually from the Bolt 11 invoice). - * @param extraEdges routing hints (usually from a Bolt 11 invoice). - * @param routeParams parameters to fine-tune the routing algorithm. - * @param additionalTlvs when provided, additional tlvs that will be added to the onion sent to the target node. - * @param userCustomTlvs when provided, additional user-defined custom tlvs that will be added to the onion sent to the target node. + * @param targetNodeId target node (may be the final recipient when using source-routing, or the first trampoline + * node when using trampoline). + * @param finalPayload payload for the recipient (with the amount missing, it will be added later) + * @param totalAmount total amount to send to the target node. + * @param targetExpiry expiry at the target node (CLTV for the target node's received HTLCs). + * @param maxAttempts maximum number of retries. + * @param extraEdges routing hints (usually from a Bolt 11 invoice). + * @param routeParams parameters to fine-tune the routing algorithm. */ case class SendMultiPartPayment(replyTo: ActorRef, targetNodeId: PublicKey, - finalPayload: Partial, + finalPayload: FinalPayload.Partial, totalAmount: MilliSatoshi, targetExpiry: CltvExpiry, maxAttempts: Int, diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentLifecycle.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentLifecycle.scala index 235112a70d..ee078862e8 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentLifecycle.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentLifecycle.scala @@ -411,6 +411,8 @@ object PaymentLifecycle { * * @param route payment route to use. * @param finalPayload onion payload for the target node. + * @param amount amount for the target node. + * @param expiry expiry for the target node. */ case class SendPaymentToRoute(replyTo: ActorRef, route: Either[PredefinedRoute, Route], @@ -437,6 +439,8 @@ object PaymentLifecycle { * @param targetNodeId target node (may be the final recipient when using source-routing, or the first trampoline * node when using trampoline). * @param finalPayload onion payload for the target node. + * @param amount amount for the target node. + * @param expiry expiry for the target node. * @param maxAttempts maximum number of retries. * @param extraEdges routing hints (usually from a Bolt 11 invoice). * @param routeParams parameters to fine-tune the routing algorithm. diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala index 3f11acf821..f4a681044e 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala @@ -407,6 +407,7 @@ object Router { } } } + // @formatter:on sealed trait Hop { /** @return the id of the start node. */ @@ -429,6 +430,7 @@ object Router { def length: Int } + // @formatter:off /** Channel routing parameters */ sealed trait ChannelRelayParams { def cltvExpiryDelta: CltvExpiryDelta diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/PaymentOnion.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/PaymentOnion.scala index 8eb6a29de7..3caf82cbfb 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/PaymentOnion.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/PaymentOnion.scala @@ -480,6 +480,7 @@ object PaymentOnion { } } } + } object PaymentOnionCodecs {