diff --git a/eclair-core/src/main/resources/reference.conf b/eclair-core/src/main/resources/reference.conf index 13fddf73c8..60334ef6aa 100644 --- a/eclair-core/src/main/resources/reference.conf +++ b/eclair-core/src/main/resources/reference.conf @@ -35,7 +35,7 @@ eclair { node-alias = "eclair" node-color = "49daaa" - global-features = "" + global-features = "0200" // variable_length_onion local-features = "8a" // initial_routing_sync + option_data_loss_protect + option_channel_range_queries override-features = [ // optional per-node features # { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala b/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala index 50dde47091..7b485e0461 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala @@ -29,7 +29,8 @@ import fr.acinq.eclair.channel._ import fr.acinq.eclair.db.{IncomingPayment, NetworkFee, OutgoingPayment, Stats} import fr.acinq.eclair.io.Peer.{GetPeerInfo, PeerInfo} import fr.acinq.eclair.io.{NodeURI, Peer} -import fr.acinq.eclair.payment.PaymentLifecycle._ +import fr.acinq.eclair.payment.PaymentInitiator.SendPaymentRequest +import fr.acinq.eclair.payment.PaymentLifecycle.ReceivePayment import fr.acinq.eclair.payment._ import fr.acinq.eclair.router.{ChannelDesc, RouteRequest, RouteResponse, Router} import fr.acinq.eclair.wire.{ChannelAnnouncement, ChannelUpdate, NodeAddress, NodeAnnouncement} @@ -186,7 +187,7 @@ class EclairImpl(appKit: Kit) extends Eclair { } override def sendToRoute(route: Seq[PublicKey], amount: MilliSatoshi, paymentHash: ByteVector32, finalCltvExpiryDelta: CltvExpiryDelta)(implicit timeout: Timeout): Future[UUID] = { - (appKit.paymentInitiator ? SendPaymentToRoute(amount, paymentHash, route, finalCltvExpiryDelta)).mapTo[UUID] + (appKit.paymentInitiator ? SendPaymentRequest(amount, paymentHash, route.last, 1, finalCltvExpiryDelta, route)).mapTo[UUID] } override def send(recipientNodeId: PublicKey, amount: MilliSatoshi, paymentHash: ByteVector32, invoice_opt: Option[PaymentRequest], maxAttempts_opt: Option[Int], feeThreshold_opt: Option[Satoshi], maxFeePct_opt: Option[Double])(implicit timeout: Timeout): Future[UUID] = { @@ -202,12 +203,12 @@ class EclairImpl(appKit: Kit) extends Eclair { case Some(invoice) if invoice.isExpired => Future.failed(new IllegalArgumentException("invoice has expired")) case Some(invoice) => val sendPayment = invoice.minFinalCltvExpiryDelta match { - case Some(minFinalCltvExpiryDelta) => SendPayment(amount, paymentHash, recipientNodeId, invoice.routingInfo, minFinalCltvExpiryDelta, maxAttempts = maxAttempts, routeParams = Some(routeParams)) - case None => SendPayment(amount, paymentHash, recipientNodeId, invoice.routingInfo, maxAttempts = maxAttempts, routeParams = Some(routeParams)) + case Some(minFinalCltvExpiryDelta) => SendPaymentRequest(amount, paymentHash, recipientNodeId, maxAttempts, minFinalCltvExpiryDelta, assistedRoutes = invoice.routingInfo, routeParams = Some(routeParams)) + case None => SendPaymentRequest(amount, paymentHash, recipientNodeId, maxAttempts, assistedRoutes = invoice.routingInfo, routeParams = Some(routeParams)) } (appKit.paymentInitiator ? sendPayment).mapTo[UUID] case None => - val sendPayment = SendPayment(amount, paymentHash, recipientNodeId, maxAttempts = maxAttempts, routeParams = Some(routeParams)) + val sendPayment = SendPaymentRequest(amount, paymentHash, recipientNodeId, maxAttempts = maxAttempts, routeParams = Some(routeParams)) (appKit.paymentInitiator ? sendPayment).mapTo[UUID] } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/Features.scala b/eclair-core/src/main/scala/fr/acinq/eclair/Features.scala index 41f8e358ff..145961614c 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/Features.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/Features.scala @@ -16,10 +16,7 @@ package fr.acinq.eclair - -import java.util.BitSet - -import scodec.bits.ByteVector +import scodec.bits.{BitVector, ByteVector} /** * Created by PM on 13/02/2017. @@ -38,18 +35,29 @@ object Features { val VARIABLE_LENGTH_ONION_MANDATORY = 8 val VARIABLE_LENGTH_ONION_OPTIONAL = 9 - def hasFeature(features: BitSet, bit: Int): Boolean = features.get(bit) + // Note that BitVector indexes from left to right whereas the specification indexes from right to left. + // This is why we have to reverse the bits to check if a feature is set. + + def hasFeature(features: BitVector, bit: Int): Boolean = if (features.sizeLessThanOrEqual(bit)) false else features.reverse.get(bit) - def hasFeature(features: ByteVector, bit: Int): Boolean = hasFeature(BitSet.valueOf(features.reverse.toArray), bit) + def hasFeature(features: ByteVector, bit: Int): Boolean = hasFeature(features.bits, bit) + + /** + * We currently don't distinguish mandatory and optional. Interpreting VARIABLE_LENGTH_ONION_MANDATORY strictly would + * be very restrictive and probably fork us out of the network. + * We may implement this distinction later, but for now both flags are interpreted as an optional support. + */ + def hasVariableLengthOnion(features: ByteVector): Boolean = hasFeature(features, VARIABLE_LENGTH_ONION_MANDATORY) || hasFeature(features, VARIABLE_LENGTH_ONION_OPTIONAL) /** * Check that the features that we understand are correctly specified, and that there are no mandatory features that - * we don't understand (even bits) + * we don't understand (even bits). */ - def areSupported(bitset: BitSet): Boolean = { - val supportedMandatoryFeatures = Set(OPTION_DATA_LOSS_PROTECT_MANDATORY) - for (i <- 0 until bitset.length() by 2) { - if (bitset.get(i) && !supportedMandatoryFeatures.contains(i)) return false + def areSupported(features: BitVector): Boolean = { + val supportedMandatoryFeatures = Set[Long](OPTION_DATA_LOSS_PROTECT_MANDATORY, VARIABLE_LENGTH_ONION_MANDATORY) + val reversed = features.reverse + for (i <- 0L until reversed.length by 2) { + if (reversed.get(i) && !supportedMandatoryFeatures.contains(i)) return false } true @@ -59,5 +67,6 @@ object Features { * A feature set is supported if all even bits are supported. * We just ignore unknown odd bits. */ - def areSupported(features: ByteVector): Boolean = areSupported(BitSet.valueOf(features.reverse.toArray)) + def areSupported(features: ByteVector): Boolean = areSupported(features.bits) + } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/io/Switchboard.scala b/eclair-core/src/main/scala/fr/acinq/eclair/io/Switchboard.scala index 7d29142d2e..2ee641c62f 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/io/Switchboard.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/io/Switchboard.scala @@ -32,6 +32,7 @@ import fr.acinq.eclair.router.Rebroadcast import fr.acinq.eclair.transactions.{IN, OUT} import fr.acinq.eclair.wire.{TemporaryNodeFailure, UpdateAddHtlc} import grizzled.slf4j.Logging +import scodec.bits.ByteVector import scala.util.Success @@ -57,7 +58,7 @@ class Switchboard(nodeParams: NodeParams, authenticator: ActorRef, watcher: Acto }) val peers = nodeParams.db.peers.listPeers() - checkBrokenHtlcsLink(channels, nodeParams.privateKey) match { + checkBrokenHtlcsLink(channels, nodeParams.privateKey, nodeParams.globalFeatures) match { case Nil => () case brokenHtlcs => val brokenHtlcKiller = context.system.actorOf(Props[HtlcReaper], name = "htlc-reaper") @@ -165,7 +166,7 @@ object Switchboard extends Logging { * * This check will detect this and will allow us to fast-fail HTLCs and thus preserve channels. */ - def checkBrokenHtlcsLink(channels: Seq[HasCommitments], privateKey: PrivateKey): Seq[UpdateAddHtlc] = { + def checkBrokenHtlcsLink(channels: Seq[HasCommitments], privateKey: PrivateKey, features: ByteVector): Seq[UpdateAddHtlc] = { // We are interested in incoming HTLCs, that have been *cross-signed* (otherwise they wouldn't have been relayed). // They signed it first, so the HTLC will first appear in our commitment tx, and later on in their commitment when @@ -174,7 +175,7 @@ object Switchboard extends Logging { .flatMap(_.commitments.remoteCommit.spec.htlcs) .filter(_.direction == OUT) .map(_.add) - .map(Relayer.decryptPacket(_, privateKey)) + .map(Relayer.decryptPacket(_, privateKey, features)) .collect { case Right(RelayPayload(add, _, _)) => add } // we only consider htlcs that are relayed, not the ones for which we are the final node // Here we do it differently because we need the origin information. diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Autoprobe.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Autoprobe.scala index 2494ad3ffc..40ca336c44 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Autoprobe.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Autoprobe.scala @@ -19,7 +19,8 @@ package fr.acinq.eclair.payment import akka.actor.{Actor, ActorLogging, ActorRef, Props} import fr.acinq.bitcoin.Crypto.PublicKey import fr.acinq.eclair.crypto.Sphinx.DecryptedFailurePacket -import fr.acinq.eclair.payment.PaymentLifecycle.{PaymentFailed, PaymentResult, RemoteFailure, SendPayment} +import fr.acinq.eclair.payment.PaymentInitiator.SendPaymentRequest +import fr.acinq.eclair.payment.PaymentLifecycle.{PaymentFailed, PaymentResult, RemoteFailure} import fr.acinq.eclair.router.{Announcements, Data, PublicChannel} import fr.acinq.eclair.wire.IncorrectOrUnknownPaymentDetails import fr.acinq.eclair.{LongToBtcAmount, NodeParams, randomBytes32, secureRandom} @@ -27,9 +28,9 @@ import fr.acinq.eclair.{LongToBtcAmount, NodeParams, randomBytes32, secureRandom import scala.concurrent.duration._ /** - * This actor periodically probes the network by sending payments to random nodes. The payments will eventually fail - * because the recipient doesn't know the preimage, but it allows us to test channels and improve routing for real payments. - */ + * This actor periodically probes the network by sending payments to random nodes. The payments will eventually fail + * because the recipient doesn't know the preimage, but it allows us to test channels and improve routing for real payments. + */ class Autoprobe(nodeParams: NodeParams, router: ActorRef, paymentInitiator: ActorRef) extends Actor with ActorLogging { import Autoprobe._ @@ -54,7 +55,7 @@ class Autoprobe(nodeParams: NodeParams, router: ActorRef, paymentInitiator: Acto case Some(targetNodeId) => val paymentHash = randomBytes32 // we don't even know the preimage (this needs to be a secure random!) log.info(s"sending payment probe to node=$targetNodeId payment_hash=$paymentHash") - paymentInitiator ! SendPayment(PAYMENT_AMOUNT_MSAT, paymentHash, targetNodeId, maxAttempts = 1) + paymentInitiator ! SendPaymentRequest(PAYMENT_AMOUNT_MSAT, paymentHash, targetNodeId, maxAttempts = 1) case None => log.info(s"could not find a destination, re-scheduling") scheduleProbe() diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentInitiator.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentInitiator.scala index de9bef237a..a3db9e6dc8 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentInitiator.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentInitiator.scala @@ -17,25 +17,49 @@ package fr.acinq.eclair.payment import java.util.UUID + import akka.actor.{Actor, ActorLogging, ActorRef, Props} -import fr.acinq.eclair.NodeParams -import fr.acinq.eclair.payment.PaymentLifecycle.GenericSendPayment +import fr.acinq.bitcoin.ByteVector32 +import fr.acinq.bitcoin.Crypto.PublicKey +import fr.acinq.eclair.channel.Channel +import fr.acinq.eclair.payment.PaymentLifecycle.{SendPayment, SendPaymentToRoute} +import fr.acinq.eclair.payment.PaymentRequest.ExtraHop +import fr.acinq.eclair.router.RouteParams +import fr.acinq.eclair.wire.Onion.FinalLegacyPayload +import fr.acinq.eclair.{CltvExpiryDelta, MilliSatoshi, NodeParams} /** - * Created by PM on 29/08/2016. - */ + * Created by PM on 29/08/2016. + */ class PaymentInitiator(nodeParams: NodeParams, router: ActorRef, register: ActorRef) extends Actor with ActorLogging { override def receive: Receive = { - case c: GenericSendPayment => + case p: PaymentInitiator.SendPaymentRequest => val paymentId = UUID.randomUUID() + // We add one block in order to not have our htlc fail when a new block has just been found. + val finalExpiry = (p.finalExpiryDelta + 1).toCltvExpiry val payFsm = context.actorOf(PaymentLifecycle.props(nodeParams, paymentId, router, register)) - payFsm forward c + // NB: we only generate legacy payment onions for now for maximum compatibility. + p.predefinedRoute match { + case Nil => payFsm forward SendPayment(p.paymentHash, p.targetNodeId, FinalLegacyPayload(p.amount, finalExpiry), p.maxAttempts, p.assistedRoutes, p.routeParams) + case hops => payFsm forward SendPaymentToRoute(p.paymentHash, hops, FinalLegacyPayload(p.amount, finalExpiry)) + } sender ! paymentId } } object PaymentInitiator { + def props(nodeParams: NodeParams, router: ActorRef, register: ActorRef) = Props(classOf[PaymentInitiator], nodeParams, router, register) + + case class SendPaymentRequest(amount: MilliSatoshi, + paymentHash: ByteVector32, + targetNodeId: PublicKey, + maxAttempts: Int, + finalExpiryDelta: CltvExpiryDelta = Channel.MIN_CLTV_EXPIRY_DELTA, + predefinedRoute: Seq[PublicKey] = Nil, + assistedRoutes: Seq[Seq[ExtraHop]] = Nil, + routeParams: Option[RouteParams] = None) + } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentLifecycle.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentLifecycle.scala index 839859bf2b..d06566ca7e 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentLifecycle.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentLifecycle.scala @@ -22,12 +22,13 @@ import akka.actor.{ActorRef, FSM, Props, Status} import fr.acinq.bitcoin.ByteVector32 import fr.acinq.bitcoin.Crypto.PublicKey import fr.acinq.eclair._ -import fr.acinq.eclair.channel.{AddHtlcFailed, CMD_ADD_HTLC, Channel, Register} +import fr.acinq.eclair.channel.{AddHtlcFailed, CMD_ADD_HTLC, Register} import fr.acinq.eclair.crypto.{Sphinx, TransportHandler} import fr.acinq.eclair.db.{OutgoingPayment, OutgoingPaymentStatus} import fr.acinq.eclair.payment.PaymentLifecycle._ import fr.acinq.eclair.payment.PaymentRequest.ExtraHop import fr.acinq.eclair.router._ +import fr.acinq.eclair.wire.Onion._ import fr.acinq.eclair.wire._ import scodec.Attempt import scodec.bits.ByteVector @@ -46,14 +47,14 @@ class PaymentLifecycle(nodeParams: NodeParams, id: UUID, router: ActorRef, regis when(WAITING_FOR_REQUEST) { case Event(c: SendPaymentToRoute, WaitingForRequest) => - val send = SendPayment(c.amount, c.paymentHash, c.hops.last, finalCltvExpiryDelta = c.finalCltvExpiryDelta, maxAttempts = 1) - paymentsDb.addOutgoingPayment(OutgoingPayment(id, c.paymentHash, None, c.amount, Platform.currentTime, None, OutgoingPaymentStatus.PENDING)) + val send = SendPayment(c.paymentHash, c.hops.last, c.finalPayload, maxAttempts = 1) + paymentsDb.addOutgoingPayment(OutgoingPayment(id, c.paymentHash, None, c.finalPayload.amount, Platform.currentTime, None, OutgoingPaymentStatus.PENDING)) router ! FinalizeRoute(c.hops) goto(WAITING_FOR_ROUTE) using WaitingForRoute(sender, send, failures = Nil) case Event(c: SendPayment, WaitingForRequest) => - router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.amount, c.assistedRoutes, routeParams = c.routeParams) - paymentsDb.addOutgoingPayment(OutgoingPayment(id, c.paymentHash, None, c.amount, Platform.currentTime, None, OutgoingPaymentStatus.PENDING)) + router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.finalPayload.amount, c.assistedRoutes, routeParams = c.routeParams) + paymentsDb.addOutgoingPayment(OutgoingPayment(id, c.paymentHash, None, c.finalPayload.amount, Platform.currentTime, None, OutgoingPaymentStatus.PENDING)) goto(WAITING_FOR_ROUTE) using WaitingForRoute(sender, c, failures = Nil) } @@ -61,10 +62,7 @@ class PaymentLifecycle(nodeParams: NodeParams, id: UUID, router: ActorRef, regis case Event(RouteResponse(hops, ignoreNodes, ignoreChannels), WaitingForRoute(s, c, failures)) => log.info(s"route found: attempt=${failures.size + 1}/${c.maxAttempts} route=${hops.map(_.nextNodeId).mkString("->")} channels=${hops.map(_.lastUpdate.shortChannelId).mkString("->")}") val firstHop = hops.head - // we add one block in order to not have our htlc fail when a new block has just been found - val finalExpiry = (c.finalCltvExpiryDelta + 1).toCltvExpiry - - val (cmd, sharedSecrets) = buildCommand(id, c.amount, finalExpiry, c.paymentHash, hops) + val (cmd, sharedSecrets) = buildCommand(id, c.paymentHash, hops, c.finalPayload) register ! Register.ForwardShortId(firstHop.lastUpdate.shortChannelId, cmd) goto(WAITING_FOR_PAYMENT_COMPLETE) using WaitingForComplete(s, c, cmd, failures, sharedSecrets, ignoreNodes, ignoreChannels, hops) @@ -80,7 +78,7 @@ class PaymentLifecycle(nodeParams: NodeParams, id: UUID, router: ActorRef, regis case Event(fulfill: UpdateFulfillHtlc, WaitingForComplete(s, c, cmd, _, _, _, _, hops)) => paymentsDb.updateOutgoingPayment(id, OutgoingPaymentStatus.SUCCEEDED, preimage = Some(fulfill.paymentPreimage)) reply(s, PaymentSucceeded(id, cmd.amount, c.paymentHash, fulfill.paymentPreimage, hops)) - context.system.eventStream.publish(PaymentSent(id, c.amount, cmd.amount - c.amount, cmd.paymentHash, fulfill.paymentPreimage, fulfill.channelId)) + context.system.eventStream.publish(PaymentSent(id, c.finalPayload.amount, cmd.amount - c.finalPayload.amount, cmd.paymentHash, fulfill.paymentPreimage, fulfill.channelId)) stop(FSM.Normal) case Event(fail: UpdateFailHtlc, WaitingForComplete(s, c, _, failures, sharedSecrets, ignoreNodes, ignoreChannels, hops)) => @@ -110,12 +108,12 @@ class PaymentLifecycle(nodeParams: NodeParams, id: UUID, router: ActorRef, regis // in that case we don't know which node is sending garbage, let's try to blacklist all nodes except the one we are directly connected to and the destination node val blacklist = hops.map(_.nextNodeId).drop(1).dropRight(1) log.warning(s"blacklisting intermediate nodes=${blacklist.mkString(",")}") - router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.amount, c.assistedRoutes, ignoreNodes ++ blacklist, ignoreChannels, c.routeParams) + router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.finalPayload.amount, c.assistedRoutes, ignoreNodes ++ blacklist, ignoreChannels, c.routeParams) goto(WAITING_FOR_ROUTE) using WaitingForRoute(s, c, failures :+ UnreadableRemoteFailure(hops)) case Success(e@Sphinx.DecryptedFailurePacket(nodeId, failureMessage: Node)) => log.info(s"received 'Node' type error message from nodeId=$nodeId, trying to route around it (failure=$failureMessage)") // let's try to route around this node - router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.amount, c.assistedRoutes, ignoreNodes + nodeId, ignoreChannels, c.routeParams) + router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.finalPayload.amount, c.assistedRoutes, ignoreNodes + nodeId, ignoreChannels, c.routeParams) goto(WAITING_FOR_ROUTE) using WaitingForRoute(s, c, failures :+ RemoteFailure(hops, e)) case Success(e@Sphinx.DecryptedFailurePacket(nodeId, failureMessage: Update)) => log.info(s"received 'Update' type error message from nodeId=$nodeId, retrying payment (failure=$failureMessage)") @@ -143,18 +141,18 @@ class PaymentLifecycle(nodeParams: NodeParams, id: UUID, router: ActorRef, regis // in any case, we forward the update to the router router ! failureMessage.update // let's try again, router will have updated its state - router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.amount, c.assistedRoutes, ignoreNodes, ignoreChannels, c.routeParams) + router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.finalPayload.amount, c.assistedRoutes, ignoreNodes, ignoreChannels, c.routeParams) } else { // this node is fishy, it gave us a bad sig!! let's filter it out log.warning(s"got bad signature from node=$nodeId update=${failureMessage.update}") - router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.amount, c.assistedRoutes, ignoreNodes + nodeId, ignoreChannels, c.routeParams) + router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.finalPayload.amount, c.assistedRoutes, ignoreNodes + nodeId, ignoreChannels, c.routeParams) } goto(WAITING_FOR_ROUTE) using WaitingForRoute(s, c, failures :+ RemoteFailure(hops, e)) case Success(e@Sphinx.DecryptedFailurePacket(nodeId, failureMessage)) => log.info(s"received an error message from nodeId=$nodeId, trying to use a different channel (failure=$failureMessage)") // let's try again without the channel outgoing from nodeId val faultyChannel = hops.find(_.nodeId == nodeId).map(hop => ChannelDesc(hop.lastUpdate.shortChannelId, hop.nodeId, hop.nextNodeId)) - router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.amount, c.assistedRoutes, ignoreNodes, ignoreChannels ++ faultyChannel.toSet, c.routeParams) + router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.finalPayload.amount, c.assistedRoutes, ignoreNodes, ignoreChannels ++ faultyChannel.toSet, c.routeParams) goto(WAITING_FOR_ROUTE) using WaitingForRoute(s, c, failures :+ RemoteFailure(hops, e)) } @@ -174,7 +172,7 @@ class PaymentLifecycle(nodeParams: NodeParams, id: UUID, router: ActorRef, regis } else { log.info(s"received an error message from local, trying to use a different channel (failure=${t.getMessage})") val faultyChannel = ChannelDesc(hops.head.lastUpdate.shortChannelId, hops.head.nodeId, hops.head.nextNodeId) - router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.amount, c.assistedRoutes, ignoreNodes, ignoreChannels + faultyChannel, c.routeParams) + router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.finalPayload.amount, c.assistedRoutes, ignoreNodes, ignoreChannels + faultyChannel, c.routeParams) goto(WAITING_FOR_ROUTE) using WaitingForRoute(s, c, failures :+ LocalFailure(t)) } @@ -198,16 +196,14 @@ object PaymentLifecycle { // @formatter:off case class ReceivePayment(amount_opt: Option[MilliSatoshi], description: String, expirySeconds_opt: Option[Long] = None, extraHops: List[List[ExtraHop]] = Nil, fallbackAddress: Option[String] = None, paymentPreimage: Option[ByteVector32] = None) - sealed trait GenericSendPayment - case class SendPaymentToRoute(amount: MilliSatoshi, paymentHash: ByteVector32, hops: Seq[PublicKey], finalCltvExpiryDelta: CltvExpiryDelta = Channel.MIN_CLTV_EXPIRY_DELTA) extends GenericSendPayment - case class SendPayment(amount: MilliSatoshi, - paymentHash: ByteVector32, + case class SendPaymentToRoute(paymentHash: ByteVector32, hops: Seq[PublicKey], finalPayload: FinalPayload) + case class SendPayment(paymentHash: ByteVector32, targetNodeId: PublicKey, - assistedRoutes: Seq[Seq[ExtraHop]] = Nil, - finalCltvExpiryDelta: CltvExpiryDelta = Channel.MIN_CLTV_EXPIRY_DELTA, + finalPayload: FinalPayload, maxAttempts: Int, - routeParams: Option[RouteParams] = None) extends GenericSendPayment { - require(amount > 0.msat, s"amountMsat must be > 0") + assistedRoutes: Seq[Seq[ExtraHop]] = Nil, + routeParams: Option[RouteParams] = None) { + require(finalPayload.amount > 0.msat, s"amount must be > 0") } sealed trait PaymentResult @@ -227,41 +223,45 @@ object PaymentLifecycle { case object WAITING_FOR_REQUEST extends State case object WAITING_FOR_ROUTE extends State case object WAITING_FOR_PAYMENT_COMPLETE extends State - // @formatter:on - def buildOnion(nodes: Seq[PublicKey], payloads: Seq[PerHopPayload], associatedData: ByteVector32): Sphinx.PacketAndSecrets = { require(nodes.size == payloads.size) val sessionKey = randomKey - val payloadsbin: Seq[ByteVector] = payloads - .map(OnionCodecs.perHopPayloadCodec.encode) + val payloadsBin: Seq[ByteVector] = payloads + .map { + case p: FinalPayload => OnionCodecs.finalPerHopPayloadCodec.encode(p) + case p: RelayPayload => OnionCodecs.relayPerHopPayloadCodec.encode(p) + } .map { case Attempt.Successful(bitVector) => bitVector.toByteVector case Attempt.Failure(cause) => throw new RuntimeException(s"serialization error: $cause") } - Sphinx.PaymentPacket.create(sessionKey, nodes, payloadsbin, associatedData) + Sphinx.PaymentPacket.create(sessionKey, nodes, payloadsBin, associatedData) } /** + * Build the onion payloads for each hop. * - * @param finalAmount the final htlc amount in millisatoshis - * @param finalExpiry the final htlc expiry in number of blocks - * @param hops the hops as computed by the router + extra routes from payment request - * @return a (firstAmountMsat, firstExpiry, payloads) tuple where: - * - firstAmountMsat is the amount for the first htlc in the route + * @param hops the hops as computed by the router + extra routes from payment request + * @param finalPayload payload data for the final node (amount, expiry, additional tlv records, etc) + * @return a (firstAmount, firstExpiry, payloads) tuple where: + * - firstAmount is the amount for the first htlc in the route * - firstExpiry is the cltv expiry for the first htlc in the route * - a sequence of payloads that will be used to build the onion */ - def buildPayloads(finalAmount: MilliSatoshi, finalExpiry: CltvExpiry, hops: Seq[Hop]): (MilliSatoshi, CltvExpiry, Seq[PerHopPayload]) = - hops.reverse.foldLeft((finalAmount, finalExpiry, PerHopPayload(ShortChannelId(0L), finalAmount, finalExpiry) :: Nil)) { - case ((msat, expiry, payloads), hop) => - val nextFee = nodeFee(hop.lastUpdate.feeBaseMsat, hop.lastUpdate.feeProportionalMillionths, msat) - (msat + nextFee, expiry + hop.lastUpdate.cltvExpiryDelta, PerHopPayload(hop.lastUpdate.shortChannelId, msat, expiry) +: payloads) + def buildPayloads(hops: Seq[Hop], finalPayload: FinalPayload): (MilliSatoshi, CltvExpiry, Seq[PerHopPayload]) = { + hops.reverse.foldLeft((finalPayload.amount, finalPayload.expiry, Seq[PerHopPayload](finalPayload))) { + case ((amount, expiry, payloads), hop) => + val nextFee = nodeFee(hop.lastUpdate.feeBaseMsat, hop.lastUpdate.feeProportionalMillionths, amount) + // Since we don't have any scenario where we add tlv data for intermediate hops, we use legacy payloads. + val payload = RelayLegacyPayload(hop.lastUpdate.shortChannelId, amount, expiry) + (amount + nextFee, expiry + hop.lastUpdate.cltvExpiryDelta, payload +: payloads) } + } - def buildCommand(id: UUID, finalAmount: MilliSatoshi, finalExpiry: CltvExpiry, paymentHash: ByteVector32, hops: Seq[Hop]): (CMD_ADD_HTLC, Seq[(ByteVector32, PublicKey)]) = { - val (firstAmount, firstExpiry, payloads) = buildPayloads(finalAmount, finalExpiry, hops.drop(1)) + def buildCommand(id: UUID, paymentHash: ByteVector32, hops: Seq[Hop], finalPayload: FinalPayload): (CMD_ADD_HTLC, Seq[(ByteVector32, PublicKey)]) = { + val (firstAmount, firstExpiry, payloads) = buildPayloads(hops.drop(1), finalPayload) val nodes = hops.map(_.nextNodeId) // BOLT 2 requires that associatedData == paymentHash val onion = buildOnion(nodes, payloads, paymentHash) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Relayer.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Relayer.scala index 673da84e34..1ee391d05e 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Relayer.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Relayer.scala @@ -28,14 +28,14 @@ import fr.acinq.eclair.db.OutgoingPaymentStatus import fr.acinq.eclair.payment.PaymentLifecycle.{PaymentFailed, PaymentSucceeded} import fr.acinq.eclair.router.Announcements import fr.acinq.eclair.wire._ -import fr.acinq.eclair.{CltvExpiryDelta, LongToBtcAmount, MilliSatoshi, NodeParams, ShortChannelId, nodeFee} +import fr.acinq.eclair.{CltvExpiryDelta, Features, LongToBtcAmount, MilliSatoshi, NodeParams, ShortChannelId, UInt64, nodeFee} import grizzled.slf4j.Logging +import scodec.bits.ByteVector import scodec.{Attempt, DecodeResult} import scala.collection.mutable // @formatter:off - sealed trait Origin case class Local(id: UUID, sender: Option[ActorRef]) extends Origin // we don't persist reference to local actors case class Relayed(originChannelId: ByteVector32, originHtlcId: Long, amountIn: MilliSatoshi, amountOut: MilliSatoshi) extends Origin @@ -48,10 +48,8 @@ case class ForwardFailMalformed(fail: UpdateFailMalformedHtlc, to: Origin, htlc: case object GetUsableBalances case class UsableBalances(remoteNodeId: PublicKey, shortChannelId: ShortChannelId, canSend: MilliSatoshi, canReceive: MilliSatoshi, isPublic: Boolean) - // @formatter:on - /** * Created by PM on 01/02/2017. */ @@ -99,7 +97,7 @@ class Relayer(nodeParams: NodeParams, register: ActorRef, paymentHandler: ActorR case ForwardAdd(add, previousFailures) => log.debug(s"received forwarding request for htlc #${add.id} paymentHash=${add.paymentHash} from channelId=${add.channelId}") - decryptPacket(add, nodeParams.privateKey) match { + decryptPacket(add, nodeParams.privateKey, nodeParams.globalFeatures) match { case Right(p: FinalPayload) => handleFinal(p) match { case Left(cmdFail) => @@ -112,17 +110,21 @@ class Relayer(nodeParams: NodeParams, register: ActorRef, paymentHandler: ActorR case Right(r: RelayPayload) => handleRelay(r, channelUpdates, node2channels, previousFailures, nodeParams.chainHash) match { case RelayFailure(cmdFail) => - log.info(s"rejecting htlc #${add.id} paymentHash=${add.paymentHash} from channelId=${add.channelId} to shortChannelId=${r.payload.shortChannelId} reason=${cmdFail.reason}") + log.info(s"rejecting htlc #${add.id} paymentHash=${add.paymentHash} from channelId=${add.channelId} to shortChannelId=${r.payload.outgoingChannelId} reason=${cmdFail.reason}") commandBuffer ! CommandBuffer.CommandSend(add.channelId, add.id, cmdFail) case RelaySuccess(selectedShortChannelId, cmdAdd) => log.info(s"forwarding htlc #${add.id} paymentHash=${add.paymentHash} from channelId=${add.channelId} to shortChannelId=$selectedShortChannelId") register ! Register.ForwardShortId(selectedShortChannelId, cmdAdd) } - case Left(badOnion) => + case Left(badOnion: BadOnion) => log.warning(s"couldn't parse onion: reason=${badOnion.message}") val cmdFail = CMD_FAIL_MALFORMED_HTLC(add.id, badOnion.onionHash, badOnion.code, commit = true) log.info(s"rejecting htlc #${add.id} paymentHash=${add.paymentHash} from channelId=${add.channelId} reason=malformed onionHash=${cmdFail.onionHash} failureCode=${cmdFail.failureCode}") commandBuffer ! CommandBuffer.CommandSend(add.channelId, add.id, cmdFail) + case Left(failure) => + log.warning(s"couldn't process onion: reason=${failure.message}") + val cmdFail = CMD_FAIL_HTLC(add.id, Right(failure), commit = true) + commandBuffer ! CommandBuffer.CommandSend(add.channelId, add.id, cmdFail) } case Status.Failure(Register.ForwardShortIdFailure(Register.ForwardShortId(shortChannelId, CMD_ADD_HTLC(_, _, _, _, Right(add), _, _)))) => @@ -213,37 +215,39 @@ object Relayer extends Logging { // @formatter:off sealed trait NextPayload - case class FinalPayload(add: UpdateAddHtlc, payload: PerHopPayload) extends NextPayload - case class RelayPayload(add: UpdateAddHtlc, payload: PerHopPayload, nextPacket: OnionRoutingPacket) extends NextPayload { - val relayFeeMsat: MilliSatoshi = add.amountMsat - payload.amtToForward - val expiryDelta: CltvExpiryDelta = add.cltvExpiry - payload.outgoingCltvValue + case class FinalPayload(add: UpdateAddHtlc, payload: Onion.FinalPayload) extends NextPayload + case class RelayPayload(add: UpdateAddHtlc, payload: Onion.RelayPayload, nextPacket: OnionRoutingPacket) extends NextPayload { + val relayFeeMsat: MilliSatoshi = add.amountMsat - payload.amountToForward + val expiryDelta: CltvExpiryDelta = add.cltvExpiry - payload.outgoingCltv } // @formatter:on /** - * Decrypt the onion of a received htlc, and find out if the payment is to be relayed, - * or if our node is the last one in the route + * Decrypt the onion of a received htlc, and find out if the payment is to be relayed, or if our node is the last one + * in the route. * * @param add incoming htlc * @param privateKey this node's private key * @return the payload for the next hop or an error. */ - def decryptPacket(add: UpdateAddHtlc, privateKey: PrivateKey): Either[BadOnion, NextPayload] = + def decryptPacket(add: UpdateAddHtlc, privateKey: PrivateKey, features: ByteVector): Either[FailureMessage, NextPayload] = Sphinx.PaymentPacket.peel(privateKey, add.paymentHash, add.onionRoutingPacket) match { case Right(p@Sphinx.DecryptedPacket(payload, nextPacket, _)) => - OnionCodecs.perHopPayloadCodec.decode(payload.bits) match { + val codec = if (p.isLastPacket) OnionCodecs.finalPerHopPayloadCodec else OnionCodecs.relayPerHopPayloadCodec + codec.decode(payload.bits) match { + case Attempt.Successful(DecodeResult(_: Onion.TlvFormat, _)) if !Features.hasVariableLengthOnion(features) => Left(InvalidRealm) case Attempt.Successful(DecodeResult(perHopPayload, remainder)) => if (remainder.nonEmpty) { logger.warn(s"${remainder.length} bits remaining after per-hop payload decoding: there might be an issue with the onion codec") } - if (p.isLastPacket) { - Right(FinalPayload(add, perHopPayload)) - } else { - Right(RelayPayload(add, perHopPayload, nextPacket)) + perHopPayload match { + case finalPayload: Onion.FinalPayload => Right(FinalPayload(add, finalPayload)) + case relayPayload: Onion.RelayPayload => Right(RelayPayload(add, relayPayload, nextPacket)) } - case Attempt.Failure(_) => - // Onion is correctly encrypted but the content of the per-hop payload couldn't be parsed. - Left(InvalidOnionPayload(Sphinx.PaymentPacket.hash(add.onionRoutingPacket))) + case Attempt.Failure(e: OnionCodecs.MissingRequiredTlv) => Left(e.failureMessage) + // Onion is correctly encrypted but the content of the per-hop payload couldn't be parsed. + // It's hard to provide tag and offset information from scodec failures, so we currently don't do it. + case Attempt.Failure(_) => Left(InvalidOnionPayload(UInt64(0), 0)) } case Left(badOnion) => Left(badOnion) } @@ -251,20 +255,18 @@ object Relayer extends Logging { /** * Handle an incoming htlc when we are the last node * - * @param finalPayload payload + * @param p final payload * @return either: * - a CMD_FAIL_HTLC to be sent back upstream * - an UpdateAddHtlc to forward */ - def handleFinal(finalPayload: FinalPayload): Either[CMD_FAIL_HTLC, UpdateAddHtlc] = { - import finalPayload.add - finalPayload.payload match { - case PerHopPayload(_, finalAmountToForward, _) if finalAmountToForward > add.amountMsat => - Left(CMD_FAIL_HTLC(add.id, Right(FinalIncorrectHtlcAmount(add.amountMsat)), commit = true)) - case PerHopPayload(_, _, finalOutgoingCltvValue) if finalOutgoingCltvValue != add.cltvExpiry => - Left(CMD_FAIL_HTLC(add.id, Right(FinalIncorrectCltvExpiry(add.cltvExpiry)), commit = true)) - case _ => - Right(add) + def handleFinal(p: FinalPayload): Either[CMD_FAIL_HTLC, UpdateAddHtlc] = { + if (p.add.amountMsat < p.payload.amount) { + Left(CMD_FAIL_HTLC(p.add.id, Right(FinalIncorrectHtlcAmount(p.add.amountMsat)), commit = true)) + } else if (p.add.cltvExpiry != p.payload.expiry) { + Left(CMD_FAIL_HTLC(p.add.id, Right(FinalIncorrectCltvExpiry(p.add.cltvExpiry)), commit = true)) + } else { + Right(p.add) } } @@ -284,7 +286,7 @@ object Relayer extends Logging { */ def handleRelay(relayPayload: RelayPayload, channelUpdates: Map[ShortChannelId, OutgoingChannel], node2channels: mutable.Map[PublicKey, mutable.Set[ShortChannelId]] with mutable.MultiMap[PublicKey, ShortChannelId], previousFailures: Seq[AddHtlcFailed], chainHash: ByteVector32)(implicit log: LoggingAdapter): RelayResult = { import relayPayload._ - log.info(s"relaying htlc #${add.id} paymentHash={} from channelId={} to requestedShortChannelId={} previousAttempts={}", add.paymentHash, add.channelId, relayPayload.payload.shortChannelId, previousFailures.size) + log.info(s"relaying htlc #${add.id} paymentHash={} from channelId={} to requestedShortChannelId={} previousAttempts={}", add.paymentHash, add.channelId, relayPayload.payload.outgoingChannelId, previousFailures.size) val alreadyTried = previousFailures.flatMap(_.channelUpdate).map(_.shortChannelId) selectPreferredChannel(relayPayload, channelUpdates, node2channels, alreadyTried) .flatMap(selectedShortChannelId => channelUpdates.get(selectedShortChannelId).map(_.channelUpdate)) match { @@ -292,7 +294,7 @@ object Relayer extends Logging { // no more channels to try val error = previousFailures // we return the error for the initially requested channel if it exists - .find(_.channelUpdate.map(_.shortChannelId).contains(relayPayload.payload.shortChannelId)) + .find(_.channelUpdate.map(_.shortChannelId).contains(relayPayload.payload.outgoingChannelId)) // otherwise we return the error for the first channel tried .getOrElse(previousFailures.head) RelayFailure(CMD_FAIL_HTLC(add.id, Right(translateError(error)), commit = true)) @@ -309,7 +311,7 @@ object Relayer extends Logging { */ def selectPreferredChannel(relayPayload: RelayPayload, channelUpdates: Map[ShortChannelId, OutgoingChannel], node2channels: mutable.Map[PublicKey, mutable.Set[ShortChannelId]] with mutable.MultiMap[PublicKey, ShortChannelId], alreadyTried: Seq[ShortChannelId])(implicit log: LoggingAdapter): Option[ShortChannelId] = { import relayPayload.add - val requestedShortChannelId = relayPayload.payload.shortChannelId + val requestedShortChannelId = relayPayload.payload.outgoingChannelId log.debug(s"selecting next channel for htlc #${add.id} paymentHash={} from channelId={} to requestedShortChannelId={} previousAttempts={}", add.paymentHash, add.channelId, requestedShortChannelId, alreadyTried.size) // first we find out what is the next node val nextNodeId_opt = channelUpdates.get(requestedShortChannelId) match { @@ -334,7 +336,7 @@ object Relayer extends Logging { (shortChannelId, channelInfo_opt, relayResult) } .collect { case (shortChannelId, Some(channelInfo), _: RelaySuccess) => (shortChannelId, channelInfo.commitments.availableBalanceForSend) } - .filter(_._2 > relayPayload.payload.amtToForward) // we only keep channels that have enough balance to handle this payment + .filter(_._2 > relayPayload.payload.amountToForward) // we only keep channels that have enough balance to handle this payment .toList // needed for ordering .sortBy(_._2) // we want to use the channel with the lowest available balance that can process the payment .headOption match { @@ -367,14 +369,14 @@ object Relayer extends Logging { RelayFailure(CMD_FAIL_HTLC(add.id, Right(UnknownNextPeer), commit = true)) case Some(channelUpdate) if !Announcements.isEnabled(channelUpdate.channelFlags) => RelayFailure(CMD_FAIL_HTLC(add.id, Right(ChannelDisabled(channelUpdate.messageFlags, channelUpdate.channelFlags, channelUpdate)), commit = true)) - case Some(channelUpdate) if payload.amtToForward < channelUpdate.htlcMinimumMsat => - RelayFailure(CMD_FAIL_HTLC(add.id, Right(AmountBelowMinimum(payload.amtToForward, channelUpdate)), commit = true)) + case Some(channelUpdate) if payload.amountToForward < channelUpdate.htlcMinimumMsat => + RelayFailure(CMD_FAIL_HTLC(add.id, Right(AmountBelowMinimum(payload.amountToForward, channelUpdate)), commit = true)) case Some(channelUpdate) if relayPayload.expiryDelta != channelUpdate.cltvExpiryDelta => - RelayFailure(CMD_FAIL_HTLC(add.id, Right(IncorrectCltvExpiry(payload.outgoingCltvValue, channelUpdate)), commit = true)) - case Some(channelUpdate) if relayPayload.relayFeeMsat < nodeFee(channelUpdate.feeBaseMsat, channelUpdate.feeProportionalMillionths, payload.amtToForward) => + RelayFailure(CMD_FAIL_HTLC(add.id, Right(IncorrectCltvExpiry(payload.outgoingCltv, channelUpdate)), commit = true)) + case Some(channelUpdate) if relayPayload.relayFeeMsat < nodeFee(channelUpdate.feeBaseMsat, channelUpdate.feeProportionalMillionths, payload.amountToForward) => RelayFailure(CMD_FAIL_HTLC(add.id, Right(FeeInsufficient(add.amountMsat, channelUpdate)), commit = true)) case Some(channelUpdate) => - RelaySuccess(channelUpdate.shortChannelId, CMD_ADD_HTLC(payload.amtToForward, add.paymentHash, payload.outgoingCltvValue, nextPacket, upstream = Right(add), commit = true, previousFailures = previousFailures)) + RelaySuccess(channelUpdate.shortChannelId, CMD_ADD_HTLC(payload.amountToForward, add.paymentHash, payload.outgoingCltv, nextPacket, upstream = Right(add), commit = true, previousFailures = previousFailures)) } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/Announcements.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/Announcements.scala index 2baa73e7a9..b1544ee311 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/router/Announcements.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/Announcements.scala @@ -19,7 +19,7 @@ package fr.acinq.eclair.router import fr.acinq.bitcoin.Crypto.{PrivateKey, PublicKey, sha256, verifySignature} import fr.acinq.bitcoin.{ByteVector32, ByteVector64, Crypto, LexicographicalOrdering} import fr.acinq.eclair.wire._ -import fr.acinq.eclair.{CltvExpiryDelta, Features, MilliSatoshi, ShortChannelId, serializationResult} +import fr.acinq.eclair.{CltvExpiryDelta, MilliSatoshi, ShortChannelId, serializationResult} import scodec.bits.{BitVector, ByteVector} import shapeless.HNil @@ -27,8 +27,8 @@ import scala.compat.Platform import scala.concurrent.duration._ /** - * Created by PM on 03/02/2017. - */ + * Created by PM on 03/02/2017. + */ object Announcements { def channelAnnouncementWitnessEncode(chainHash: ByteVector32, shortChannelId: ShortChannelId, nodeId1: PublicKey, nodeId2: PublicKey, bitcoinKey1: PublicKey, bitcoinKey2: PublicKey, features: ByteVector, unknownFields: ByteVector): ByteVector = @@ -73,9 +73,8 @@ object Announcements { ) } - def makeNodeAnnouncement(nodeSecret: PrivateKey, alias: String, color: Color, nodeAddresses: List[NodeAddress], timestamp: Long = Platform.currentTime.milliseconds.toSeconds): NodeAnnouncement = { + def makeNodeAnnouncement(nodeSecret: PrivateKey, alias: String, color: Color, nodeAddresses: List[NodeAddress], features: ByteVector, timestamp: Long = Platform.currentTime.milliseconds.toSeconds): NodeAnnouncement = { require(alias.length <= 32) - val features = BitVector.fromLong(1 << Features.VARIABLE_LENGTH_ONION_OPTIONAL).bytes val witness = nodeAnnouncementWitnessEncode(timestamp, nodeSecret.publicKey, color, alias, features, nodeAddresses, unknownFields = ByteVector.empty) val sig = Crypto.sign(witness, nodeSecret) NodeAnnouncement( @@ -90,37 +89,37 @@ object Announcements { } /** - * BOLT 7: - * The creating node MUST set node-id-1 and node-id-2 to the public keys of the - * two nodes who are operating the channel, such that node-id-1 is the numerically-lesser - * of the two DER encoded keys sorted in ascending numerical order, - * - * @return true if localNodeId is node1 - */ + * BOLT 7: + * The creating node MUST set node-id-1 and node-id-2 to the public keys of the + * two nodes who are operating the channel, such that node-id-1 is the numerically-lesser + * of the two DER encoded keys sorted in ascending numerical order, + * + * @return true if localNodeId is node1 + */ def isNode1(localNodeId: PublicKey, remoteNodeId: PublicKey) = LexicographicalOrdering.isLessThan(localNodeId.value, remoteNodeId.value) /** - * BOLT 7: - * The creating node [...] MUST set the direction bit of flags to 0 if - * the creating node is node-id-1 in that message, otherwise 1. - * - * @return true if the node who sent these flags is node1 - */ + * BOLT 7: + * The creating node [...] MUST set the direction bit of flags to 0 if + * the creating node is node-id-1 in that message, otherwise 1. + * + * @return true if the node who sent these flags is node1 + */ def isNode1(channelFlags: Byte): Boolean = (channelFlags & 1) == 0 /** - * A node MAY create and send a channel_update with the disable bit set to - * signal the temporary unavailability of a channel - * - * @return - */ + * A node MAY create and send a channel_update with the disable bit set to + * signal the temporary unavailability of a channel + * + * @return + */ def isEnabled(channelFlags: Byte): Boolean = (channelFlags & 2) == 0 /** - * This method compares channel updates, ignoring fields that don't matter, like signature or timestamp - * - * @return true if channel updates are "equal" - */ + * This method compares channel updates, ignoring fields that don't matter, like signature or timestamp + * + * @return true if channel updates are "equal" + */ def areSame(u1: ChannelUpdate, u2: ChannelUpdate): Boolean = u1.copy(signature = ByteVector64.Zeroes, timestamp = 0) == u2.copy(signature = ByteVector64.Zeroes, timestamp = 0) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala index a42fecb688..7e0bf08f8d 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala @@ -194,7 +194,7 @@ class Router(val nodeParams: NodeParams, watcher: ActorRef, initialized: Option[ // on restart we update our node announcement // note that if we don't currently have public channels, this will be ignored - val nodeAnn = Announcements.makeNodeAnnouncement(nodeParams.privateKey, nodeParams.alias, nodeParams.color, nodeParams.publicAddresses) + val nodeAnn = Announcements.makeNodeAnnouncement(nodeParams.privateKey, nodeParams.alias, nodeParams.color, nodeParams.publicAddresses, nodeParams.globalFeatures) self ! nodeAnn log.info(s"initialization completed, ready to process messages") @@ -289,7 +289,7 @@ class Router(val nodeParams: NodeParams, watcher: ActorRef, initialized: Option[ // in case we just validated our first local channel, we announce the local node if (!d0.nodes.contains(nodeParams.nodeId) && isRelatedTo(c, nodeParams.nodeId)) { log.info("first local channel validated, announcing local node") - val nodeAnn = Announcements.makeNodeAnnouncement(nodeParams.privateKey, nodeParams.alias, nodeParams.color, nodeParams.publicAddresses) + val nodeAnn = Announcements.makeNodeAnnouncement(nodeParams.privateKey, nodeParams.alias, nodeParams.color, nodeParams.publicAddresses, nodeParams.globalFeatures) self ! nodeAnn } Some(PublicChannel(c, tx.txid, capacity, None, None)) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/FailureMessage.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/FailureMessage.scala index b3afca1f24..442545bfa3 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/FailureMessage.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/FailureMessage.scala @@ -18,10 +18,10 @@ package fr.acinq.eclair.wire import fr.acinq.bitcoin.ByteVector32 import fr.acinq.eclair.crypto.Mac32 -import fr.acinq.eclair.wire.CommonCodecs.{cltvExpiry, discriminatorWithDefault, millisatoshi, sha256} +import fr.acinq.eclair.wire.CommonCodecs._ import fr.acinq.eclair.wire.FailureMessageCodecs.failureMessageCodec import fr.acinq.eclair.wire.LightningMessageCodecs.{channelUpdateCodec, lightningMessageCodec} -import fr.acinq.eclair.{CltvExpiry, LongToBtcAmount, MilliSatoshi} +import fr.acinq.eclair.{CltvExpiry, LongToBtcAmount, MilliSatoshi, UInt64} import scodec.codecs._ import scodec.{Attempt, Codec} @@ -49,7 +49,6 @@ case object RequiredNodeFeatureMissing extends Perm with Node { def message = "p case class InvalidOnionVersion(onionHash: ByteVector32) extends BadOnion with Perm { def message = "onion version was not understood by the processing node" } case class InvalidOnionHmac(onionHash: ByteVector32) extends BadOnion with Perm { def message = "onion HMAC was incorrect when it reached the processing node" } case class InvalidOnionKey(onionHash: ByteVector32) extends BadOnion with Perm { def message = "ephemeral key was unparsable by the processing node" } -case class InvalidOnionPayload(onionHash: ByteVector32) extends BadOnion with Perm { def message = "onion per-hop payload could not be parsed" } case class TemporaryChannelFailure(update: ChannelUpdate) extends Update { def message = s"channel ${update.shortChannelId} is currently unavailable" } case object PermanentChannelFailure extends Perm { def message = "channel is permanently unavailable" } case object RequiredChannelFeatureMissing extends Perm { def message = "channel requires features not present in the onion" } @@ -63,6 +62,7 @@ case class ExpiryTooSoon(update: ChannelUpdate) extends Update { def message = " case class FinalIncorrectCltvExpiry(expiry: CltvExpiry) extends FailureMessage { def message = "payment expiry doesn't match the value in the onion" } case class FinalIncorrectHtlcAmount(amount: MilliSatoshi) extends FailureMessage { def message = "payment amount is incorrect in the final htlc" } case object ExpiryTooFar extends FailureMessage { def message = "payment expiry is too far in the future" } +case class InvalidOnionPayload(tag: UInt64, offset: Int) extends Perm { def message = "onion per-hop payload is invalid" } /** * We allow remote nodes to send us unknown failure codes (e.g. deprecated failure codes). @@ -97,7 +97,6 @@ object FailureMessageCodecs { .typecase(NODE | 2, provide(TemporaryNodeFailure)) .typecase(PERM | NODE | 2, provide(PermanentNodeFailure)) .typecase(PERM | NODE | 3, provide(RequiredNodeFeatureMissing)) - .typecase(BADONION | PERM, sha256.as[InvalidOnionPayload]) .typecase(BADONION | PERM | 4, sha256.as[InvalidOnionVersion]) .typecase(BADONION | PERM | 5, sha256.as[InvalidOnionHmac]) .typecase(BADONION | PERM | 6, sha256.as[InvalidOnionKey]) @@ -115,7 +114,8 @@ object FailureMessageCodecs { // PERM | 17 (final_expiry_too_soon) has been deprecated because it allowed probing attacks: IncorrectOrUnknownPaymentDetails should be used instead. .typecase(18, ("expiry" | cltvExpiry).as[FinalIncorrectCltvExpiry]) .typecase(19, ("amountMsat" | millisatoshi).as[FinalIncorrectHtlcAmount]) - .typecase(21, provide(ExpiryTooFar)), + .typecase(21, provide(ExpiryTooFar)) + .typecase(PERM | 22, (("tag" | varint) :: ("offset" | uint16)).as[InvalidOnionPayload]), uint16.xmap(code => { val failureMessage = code match { // @formatter:off diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/Onion.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/Onion.scala index f1737eb606..d1ff115951 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/Onion.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/Onion.scala @@ -18,41 +18,98 @@ package fr.acinq.eclair.wire import fr.acinq.bitcoin.ByteVector32 import fr.acinq.eclair.crypto.Sphinx -import fr.acinq.eclair.{CltvExpiry, MilliSatoshi, ShortChannelId} -import scodec.bits.{BitVector, ByteVector} -import scodec.codecs._ -import scodec.{Codec, DecodeResult, Decoder} +import fr.acinq.eclair.wire.CommonCodecs._ +import fr.acinq.eclair.wire.TlvCodecs._ +import fr.acinq.eclair.{CltvExpiry, MilliSatoshi, ShortChannelId, UInt64} +import scodec.bits.{BitVector, ByteVector, HexStringSyntax} /** * Created by t-bast on 05/07/2019. */ -case class OnionRoutingPacket(version: Int, - publicKey: ByteVector, - payload: ByteVector, - hmac: ByteVector32) +case class OnionRoutingPacket(version: Int, publicKey: ByteVector, payload: ByteVector, hmac: ByteVector32) -case class PerHopPayload(shortChannelId: ShortChannelId, - amtToForward: MilliSatoshi, - outgoingCltvValue: CltvExpiry) +/** Tlv types used inside onion messages. */ +sealed trait OnionTlv extends Tlv + +object OnionTlv { + + /** Amount to forward to the next node. */ + case class AmountToForward(amount: MilliSatoshi) extends OnionTlv + + /** CLTV value to use for the HTLC offered to the next node. */ + case class OutgoingCltv(cltv: CltvExpiry) extends OnionTlv + + /** Id of the channel to use to forward a payment to the next node. */ + case class OutgoingChannelId(shortChannelId: ShortChannelId) extends OnionTlv + +} + +object Onion { + + import OnionTlv._ + + sealed trait PerHopPayloadFormat + + /** Legacy fixed-size 65-bytes onion payload. */ + sealed trait LegacyFormat extends PerHopPayloadFormat + + /** Variable-length onion payload with optional additional tlv records. */ + sealed trait TlvFormat extends PerHopPayloadFormat { + def records: TlvStream[OnionTlv] + } + + /** Per-hop payload from an HTLC's payment onion (after decryption and decoding). */ + sealed trait PerHopPayload + + /** Per-hop payload for an intermediate node. */ + sealed trait RelayPayload extends PerHopPayload with PerHopPayloadFormat { + /** Amount to forward to the next node. */ + val amountToForward: MilliSatoshi + /** CLTV value to use for the HTLC offered to the next node. */ + val outgoingCltv: CltvExpiry + /** Id of the channel to use to forward a payment to the next node. */ + val outgoingChannelId: ShortChannelId + } + + /** Per-hop payload for a final node. */ + sealed trait FinalPayload extends PerHopPayload with PerHopPayloadFormat { + val amount: MilliSatoshi + val expiry: CltvExpiry + } + + case class RelayLegacyPayload(outgoingChannelId: ShortChannelId, amountToForward: MilliSatoshi, outgoingCltv: CltvExpiry) extends RelayPayload with LegacyFormat + + case class FinalLegacyPayload(amount: MilliSatoshi, expiry: CltvExpiry) extends FinalPayload with LegacyFormat + + case class RelayTlvPayload(records: TlvStream[OnionTlv]) extends RelayPayload with TlvFormat { + override val amountToForward = records.get[AmountToForward].get.amount + override val outgoingCltv = records.get[OutgoingCltv].get.cltv + override val outgoingChannelId = records.get[OutgoingChannelId].get.shortChannelId + } + + case class FinalTlvPayload(records: TlvStream[OnionTlv]) extends FinalPayload with TlvFormat { + override val amount = records.get[AmountToForward].get.amount + override val expiry = records.get[OutgoingCltv].get.cltv + } + +} object OnionCodecs { + import Onion._ + import OnionTlv._ + import scodec.codecs._ + import scodec.{Attempt, Codec, DecodeResult, Decoder, Err} + def onionRoutingPacketCodec(payloadLength: Int): Codec[OnionRoutingPacket] = ( ("version" | uint8) :: ("publicKey" | bytes(33)) :: ("onionPayload" | bytes(payloadLength)) :: - ("hmac" | CommonCodecs.bytes32)).as[OnionRoutingPacket] + ("hmac" | bytes32)).as[OnionRoutingPacket] val paymentOnionPacketCodec: Codec[OnionRoutingPacket] = onionRoutingPacketCodec(Sphinx.PaymentPacket.PayloadLength) - val perHopPayloadCodec: Codec[PerHopPayload] = ( - ("realm" | constant(ByteVector.fromByte(0))) :: - ("short_channel_id" | CommonCodecs.shortchannelid) :: - ("amt_to_forward" | CommonCodecs.millisatoshi) :: - ("outgoing_cltv_value" | CommonCodecs.cltvExpiry) :: - ("unused_with_v0_version_on_header" | ignore(8 * 12))).as[PerHopPayload] - /** * The 1.1 BOLT spec changed the onion frame format to use variable-length per-hop payloads. * The first bytes contain a varint encoding the length of the payload data (not including the trailing mac). @@ -60,6 +117,63 @@ object OnionCodecs { * the varint prefix. */ val payloadLengthDecoder = Decoder[Long]((bits: BitVector) => - CommonCodecs.varintoverflow.decode(bits).map(d => DecodeResult(d.value + (bits.length - d.remainder.length) / 8, d.remainder))) + varintoverflow.decode(bits).map(d => DecodeResult(d.value + (bits.length - d.remainder.length) / 8, d.remainder))) + + private val amountToForward: Codec[AmountToForward] = ("amount_msat" | tu64overflow).xmap(amountMsat => AmountToForward(MilliSatoshi(amountMsat)), (a: AmountToForward) => a.amount.toLong) + + private val outgoingCltv: Codec[OutgoingCltv] = ("cltv" | tu32).xmap(cltv => OutgoingCltv(CltvExpiry(cltv)), (c: OutgoingCltv) => c.cltv.toLong) + + private val outgoingChannelId: Codec[OutgoingChannelId] = (("length" | constant(hex"08")) :: ("short_channel_id" | shortchannelid)).as[OutgoingChannelId] + + private val onionTlvCodec = discriminated[OnionTlv].by(varint) + .typecase(UInt64(2), amountToForward) + .typecase(UInt64(4), outgoingCltv) + .typecase(UInt64(6), outgoingChannelId) + + val tlvPerHopPayloadCodec: Codec[TlvStream[OnionTlv]] = TlvCodecs.lengthPrefixedTlvStream[OnionTlv](onionTlvCodec).complete + + private val legacyRelayPerHopPayloadCodec: Codec[RelayLegacyPayload] = ( + ("realm" | constant(ByteVector.fromByte(0))) :: + ("short_channel_id" | shortchannelid) :: + ("amt_to_forward" | millisatoshi) :: + ("outgoing_cltv_value" | cltvExpiry) :: + ("unused_with_v0_version_on_header" | ignore(8 * 12))).as[RelayLegacyPayload] + + private val legacyFinalPerHopPayloadCodec: Codec[FinalLegacyPayload] = ( + ("realm" | constant(ByteVector.fromByte(0))) :: + ("short_channel_id" | ignore(8 * 8)) :: + ("amount" | millisatoshi) :: + ("expiry" | cltvExpiry) :: + ("unused_with_v0_version_on_header" | ignore(8 * 12))).as[FinalLegacyPayload] + + case class MissingRequiredTlv(tag: UInt64) extends Err { + // @formatter:off + val failureMessage: FailureMessage = InvalidOnionPayload(tag, 0) + override def message = failureMessage.message + override def context: List[String] = Nil + override def pushContext(ctx: String): Err = this + // @formatter:on + } + + val relayPerHopPayloadCodec: Codec[RelayPayload] = fallback(tlvPerHopPayloadCodec, legacyRelayPerHopPayloadCodec).narrow({ + case Left(tlvs) if tlvs.get[AmountToForward].isEmpty => Attempt.failure(MissingRequiredTlv(UInt64(2))) + case Left(tlvs) if tlvs.get[OutgoingCltv].isEmpty => Attempt.failure(MissingRequiredTlv(UInt64(4))) + case Left(tlvs) if tlvs.get[OutgoingChannelId].isEmpty => Attempt.failure(MissingRequiredTlv(UInt64(6))) + case Left(tlvs) => Attempt.successful(RelayTlvPayload(tlvs)) + case Right(legacy) => Attempt.successful(legacy) + }, { + case legacy: RelayLegacyPayload => Right(legacy) + case RelayTlvPayload(tlvs) => Left(tlvs) + }) + + val finalPerHopPayloadCodec: Codec[FinalPayload] = fallback(tlvPerHopPayloadCodec, legacyFinalPerHopPayloadCodec).narrow({ + case Left(tlvs) if tlvs.get[AmountToForward].isEmpty => Attempt.failure(MissingRequiredTlv(UInt64(2))) + case Left(tlvs) if tlvs.get[OutgoingCltv].isEmpty => Attempt.failure(MissingRequiredTlv(UInt64(4))) + case Left(tlvs) => Attempt.successful(FinalTlvPayload(tlvs)) + case Right(legacy) => Attempt.successful(legacy) + }, { + case legacy: FinalLegacyPayload => Right(legacy) + case FinalTlvPayload(tlvs) => Left(tlvs) + }) } \ No newline at end of file diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvCodecs.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvCodecs.scala index 664f7182ca..0f56ba666d 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvCodecs.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvCodecs.scala @@ -44,6 +44,14 @@ object TlvCodecs { .\(0x07) { case i if i < 0x0100000000000000L => i }(variableSizeUInt64(7, 0x01000000000000L)) .\(0x08) { case i if i <= UInt64.MaxValue => i }(variableSizeUInt64(8, 0x0100000000000000L)) + /** + * Length-prefixed truncated long (1 to 9 bytes unsigned integer). + * This codec can be safely used for values < `2^63` and will fail otherwise. + */ + val tu64overflow: Codec[Long] = tu64.exmap( + u => if (u <= Long.MaxValue) Attempt.Successful(u.toBigInt.toLong) else Attempt.Failure(Err(s"overflow for value $u")), + l => if (l >= 0) Attempt.Successful(UInt64(l)) else Attempt.Failure(Err(s"uint64 must be positive (actual=$l)"))) + /** * Length-prefixed truncated uint32 (1 to 5 bytes unsigned integer). */ diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvTypes.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvTypes.scala index 54cb65c948..3ea3d85c21 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvTypes.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvTypes.scala @@ -25,10 +25,7 @@ import scala.reflect.ClassTag * Created by t-bast on 20/06/2019. */ -// @formatter:off trait Tlv -sealed trait OnionTlv extends Tlv -// @formatter:on /** * Generic tlv type we fallback to if we don't understand the incoming tlv. diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/EclairImplSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/EclairImplSpec.scala index 1802fc5b6e..3e5f5a8dab 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/EclairImplSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/EclairImplSpec.scala @@ -26,7 +26,8 @@ import fr.acinq.eclair.blockchain.TestWallet import fr.acinq.eclair.channel.{CMD_FORCECLOSE, Register, _} import fr.acinq.eclair.db._ import fr.acinq.eclair.io.Peer.OpenChannel -import fr.acinq.eclair.payment.PaymentLifecycle.{ReceivePayment, SendPayment, SendPaymentToRoute} +import fr.acinq.eclair.payment.PaymentInitiator.SendPaymentRequest +import fr.acinq.eclair.payment.PaymentLifecycle.ReceivePayment import fr.acinq.eclair.payment.PaymentRequest.ExtraHop import fr.acinq.eclair.payment.{LocalPaymentHandler, PaymentRequest} import fr.acinq.eclair.router.RouteCalculationSpec.makeUpdate @@ -94,7 +95,7 @@ class EclairImplSpec extends TestKit(ActorSystem("mySystem")) with fixture.FunSu val nodeId = PublicKey(hex"030bb6a5e0c6b203c7e2180fb78c7ba4bdce46126761d8201b91ddac089cdecc87") eclair.send(nodeId, 123 msat, ByteVector32.Zeroes, invoice_opt = None) - val send = paymentInitiator.expectMsgType[SendPayment] + val send = paymentInitiator.expectMsgType[SendPaymentRequest] assert(send.targetNodeId == nodeId) assert(send.amount == 123.msat) assert(send.paymentHash == ByteVector32.Zeroes) @@ -104,7 +105,7 @@ class EclairImplSpec extends TestKit(ActorSystem("mySystem")) with fixture.FunSu val hints = List(List(ExtraHop(Bob.nodeParams.nodeId, ShortChannelId("569178x2331x1"), feeBase = 10 msat, feeProportionalMillionths = 1, cltvExpiryDelta = CltvExpiryDelta(12)))) val invoice1 = PaymentRequest(Block.RegtestGenesisBlock.hash, Some(123 msat), ByteVector32.Zeroes, randomKey, "description", None, None, hints) eclair.send(nodeId, 123 msat, ByteVector32.Zeroes, invoice_opt = Some(invoice1)) - val send1 = paymentInitiator.expectMsgType[SendPayment] + val send1 = paymentInitiator.expectMsgType[SendPaymentRequest] assert(send1.targetNodeId == nodeId) assert(send1.amount == 123.msat) assert(send1.paymentHash == ByteVector32.Zeroes) @@ -113,15 +114,15 @@ class EclairImplSpec extends TestKit(ActorSystem("mySystem")) with fixture.FunSu // with finalCltvExpiry val invoice2 = PaymentRequest("lntb", Some(123 msat), System.currentTimeMillis() / 1000L, nodeId, List(PaymentRequest.MinFinalCltvExpiry(96), PaymentRequest.PaymentHash(ByteVector32.Zeroes), PaymentRequest.Description("description")), ByteVector.empty) eclair.send(nodeId, 123 msat, ByteVector32.Zeroes, invoice_opt = Some(invoice2)) - val send2 = paymentInitiator.expectMsgType[SendPayment] + val send2 = paymentInitiator.expectMsgType[SendPaymentRequest] assert(send2.targetNodeId == nodeId) assert(send2.amount == 123.msat) assert(send2.paymentHash == ByteVector32.Zeroes) - assert(send2.finalCltvExpiryDelta == CltvExpiryDelta(96)) + assert(send2.finalExpiryDelta == CltvExpiryDelta(96)) // with custom route fees parameters eclair.send(nodeId, 123 msat, ByteVector32.Zeroes, invoice_opt = None, feeThreshold_opt = Some(123 sat), maxFeePct_opt = Some(4.20)) - val send3 = paymentInitiator.expectMsgType[SendPayment] + val send3 = paymentInitiator.expectMsgType[SendPaymentRequest] assert(send3.targetNodeId == nodeId) assert(send3.amount == 123.msat) assert(send3.paymentHash == ByteVector32.Zeroes) @@ -252,12 +253,11 @@ class EclairImplSpec extends TestKit(ActorSystem("mySystem")) with fixture.FunSu val eclair = new EclairImpl(kit) eclair.sendToRoute(route, 1234 msat, ByteVector32.One, CltvExpiryDelta(123)) - val send = paymentInitiator.expectMsgType[SendPaymentToRoute] - - assert(send.hops === route) + val send = paymentInitiator.expectMsgType[SendPaymentRequest] + assert(send.predefinedRoute == route) assert(send.amount === 1234.msat) - assert(send.finalCltvExpiryDelta === CltvExpiryDelta(123)) - assert(send.paymentHash === ByteVector32.One) + assert(send.finalExpiryDelta === CltvExpiryDelta(123)) + assert(send.paymentHash == ByteVector32.One) } } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/FeaturesSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/FeaturesSpec.scala index 102e50e532..30f06c385d 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/FeaturesSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/FeaturesSpec.scala @@ -16,9 +16,6 @@ package fr.acinq.eclair -import java.nio.ByteOrder - -import fr.acinq.bitcoin.Protocol import fr.acinq.eclair.Features._ import org.scalatest.FunSuite import scodec.bits._ @@ -38,21 +35,25 @@ class FeaturesSpec extends FunSuite { assert(hasFeature(hex"02", Features.OPTION_DATA_LOSS_PROTECT_OPTIONAL)) } - test("'initial_routing_sync' and 'data_loss_protect' feature") { - val features = hex"0a" - assert(areSupported(features) && hasFeature(features, OPTION_DATA_LOSS_PROTECT_OPTIONAL) && hasFeature(features, INITIAL_ROUTING_SYNC_BIT_OPTIONAL)) + test("'initial_routing_sync', 'data_loss_protect' and 'variable_length_onion' features") { + val features = hex"010a" + assert(areSupported(features) && hasFeature(features, OPTION_DATA_LOSS_PROTECT_OPTIONAL) && hasFeature(features, INITIAL_ROUTING_SYNC_BIT_OPTIONAL) && hasFeature(features, VARIABLE_LENGTH_ONION_MANDATORY)) } test("'variable_length_onion' feature") { assert(hasFeature(hex"0100", Features.VARIABLE_LENGTH_ONION_MANDATORY)) + assert(hasVariableLengthOnion(hex"0100")) assert(hasFeature(hex"0200", Features.VARIABLE_LENGTH_ONION_OPTIONAL)) + assert(hasVariableLengthOnion(hex"0200")) } test("features compatibility") { - assert(areSupported(Protocol.writeUInt64(1l << INITIAL_ROUTING_SYNC_BIT_OPTIONAL, ByteOrder.BIG_ENDIAN))) - assert(areSupported(Protocol.writeUInt64(1L << OPTION_DATA_LOSS_PROTECT_MANDATORY, ByteOrder.BIG_ENDIAN))) - assert(areSupported(Protocol.writeUInt64(1l << OPTION_DATA_LOSS_PROTECT_OPTIONAL, ByteOrder.BIG_ENDIAN))) - assert(areSupported(Protocol.writeUInt64(1l << VARIABLE_LENGTH_ONION_OPTIONAL, ByteOrder.BIG_ENDIAN))) + assert(areSupported(ByteVector.fromLong(1L << INITIAL_ROUTING_SYNC_BIT_OPTIONAL))) + assert(areSupported(ByteVector.fromLong(1L << OPTION_DATA_LOSS_PROTECT_MANDATORY))) + assert(areSupported(ByteVector.fromLong(1L << OPTION_DATA_LOSS_PROTECT_OPTIONAL))) + assert(areSupported(ByteVector.fromLong(1L << VARIABLE_LENGTH_ONION_OPTIONAL))) + assert(areSupported(ByteVector.fromLong(1L << VARIABLE_LENGTH_ONION_MANDATORY))) + assert(areSupported(hex"0b")) assert(!areSupported(hex"14")) assert(!areSupported(hex"0141")) } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala b/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala index 5d6a9b11b5..fb4838d4fa 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala @@ -27,7 +27,7 @@ import fr.acinq.eclair.db._ import fr.acinq.eclair.io.Peer import fr.acinq.eclair.router.RouterConf import fr.acinq.eclair.wire.{Color, EncodingType, NodeAddress} -import scodec.bits.ByteVector +import scodec.bits.{ByteVector, HexStringSyntax} import scala.concurrent.duration._ @@ -36,6 +36,7 @@ import scala.concurrent.duration._ */ object TestConstants { + val globalFeatures = hex"0200" // variable_length_onion val fundingSatoshis = 1000000L sat val pushMsat = 200000000L msat val feeratePerKw = 10000L @@ -67,7 +68,7 @@ object TestConstants { alias = "alice", color = Color(1, 2, 3), publicAddresses = NodeAddress.fromParts("localhost", 9731).get :: Nil, - globalFeatures = ByteVector.empty, + globalFeatures = globalFeatures, localFeatures = ByteVector(0), overrideFeatures = Map.empty, syncWhitelist = Set.empty, @@ -142,7 +143,7 @@ object TestConstants { alias = "bob", color = Color(4, 5, 6), publicAddresses = NodeAddress.fromParts("localhost", 9732).get :: Nil, - globalFeatures = ByteVector.empty, + globalFeatures = globalFeatures, localFeatures = ByteVector.empty, // no announcement overrideFeatures = Map.empty, syncWhitelist = Set.empty, diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/channel/FuzzySpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/channel/FuzzySpec.scala index 53a0ab5914..f1f464ee1c 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/channel/FuzzySpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/channel/FuzzySpec.scala @@ -30,6 +30,7 @@ import fr.acinq.eclair.channel.states.StateTestsHelperMethods import fr.acinq.eclair.payment.PaymentLifecycle.ReceivePayment import fr.acinq.eclair.payment._ import fr.acinq.eclair.router.Hop +import fr.acinq.eclair.wire.Onion.FinalLegacyPayload import fr.acinq.eclair.wire._ import grizzled.slf4j.Logging import org.scalatest.{Outcome, Tag} @@ -39,8 +40,8 @@ import scala.concurrent.duration._ import scala.util.Random /** - * Created by PM on 05/07/2016. - */ + * Created by PM on 05/07/2016. + */ class FuzzySpec extends TestkitBaseClass with StateTestsHelperMethods with Logging { @@ -95,10 +96,10 @@ class FuzzySpec extends TestkitBaseClass with StateTestsHelperMethods with Loggi // allow overpaying (no more than 2 times the required amount) val amount = MilliSatoshi(requiredAmount + Random.nextInt(requiredAmount)) val expiry = (Channel.MIN_CLTV_EXPIRY_DELTA + 1).toCltvExpiry - PaymentLifecycle.buildCommand(UUID.randomUUID(), amount, expiry, paymentHash, Hop(null, dest, null) :: Nil)._1 + PaymentLifecycle.buildCommand(UUID.randomUUID(), paymentHash, Hop(null, dest, null) :: Nil, FinalLegacyPayload(amount, expiry))._1 } - def initiatePayment(stopping: Boolean) = + def initiatePayment(stopping: Boolean): Unit = if (stopping) { context stop self } else { diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/StateTestsHelperMethods.scala b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/StateTestsHelperMethods.scala index cdcbab0d8b..9ffb4c8634 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/StateTestsHelperMethods.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/StateTestsHelperMethods.scala @@ -28,6 +28,7 @@ import fr.acinq.eclair.channel._ import fr.acinq.eclair.io.Peer import fr.acinq.eclair.payment.PaymentLifecycle import fr.acinq.eclair.router.Hop +import fr.acinq.eclair.wire.Onion.FinalLegacyPayload import fr.acinq.eclair.wire._ import fr.acinq.eclair.{NodeParams, TestConstants, randomBytes32, _} @@ -109,7 +110,7 @@ trait StateTestsHelperMethods extends TestKitBase { val payment_preimage: ByteVector32 = randomBytes32 val payment_hash: ByteVector32 = Crypto.sha256(payment_preimage) val expiry = CltvExpiryDelta(144).toCltvExpiry - val cmd = PaymentLifecycle.buildCommand(UUID.randomUUID, amount, expiry, payment_hash, Hop(null, destination, null) :: Nil)._1.copy(commit = false) + val cmd = PaymentLifecycle.buildCommand(UUID.randomUUID, payment_hash, Hop(null, destination, null) :: Nil, FinalLegacyPayload(amount, expiry))._1.copy(commit = false) (payment_preimage, cmd) } @@ -124,7 +125,7 @@ trait StateTestsHelperMethods extends TestKitBase { (payment_preimage, htlc) } - def fulfillHtlc(id: Long, R: ByteVector32, s: TestFSMRef[State, Data, Channel], r: TestFSMRef[State, Data, Channel], s2r: TestProbe, r2s: TestProbe) = { + def fulfillHtlc(id: Long, R: ByteVector32, s: TestFSMRef[State, Data, Channel], r: TestFSMRef[State, Data, Channel], s2r: TestProbe, r2s: TestProbe): Unit = { val sender = TestProbe() sender.send(s, CMD_FULFILL_HTLC(id, R)) sender.expectMsg("ok") @@ -133,7 +134,7 @@ trait StateTestsHelperMethods extends TestKitBase { awaitCond(r.stateData.asInstanceOf[HasCommitments].commitments.remoteChanges.proposed.contains(fulfill)) } - def crossSign(s: TestFSMRef[State, Data, Channel], r: TestFSMRef[State, Data, Channel], s2r: TestProbe, r2s: TestProbe) = { + def crossSign(s: TestFSMRef[State, Data, Channel], r: TestFSMRef[State, Data, Channel], s2r: TestProbe, r2s: TestProbe): Unit = { val sender = TestProbe() val sCommitIndex = s.stateData.asInstanceOf[HasCommitments].commitments.localCommit.index val rCommitIndex = r.stateData.asInstanceOf[HasCommitments].commitments.localCommit.index @@ -170,12 +171,14 @@ trait StateTestsHelperMethods extends TestKitBase { implicit class ChannelWithTestFeeConf(a: TestFSMRef[State, Data, Channel]) { def feeEstimator: TestFeeEstimator = a.underlyingActor.nodeParams.onChainFeeConf.feeEstimator.asInstanceOf[TestFeeEstimator] + def feeTargets: FeeTargets = a.underlyingActor.nodeParams.onChainFeeConf.feeTargets } implicit class PeerWithTestFeeConf(a: TestFSMRef[Peer.State, Peer.Data, Peer]) { def feeEstimator: TestFeeEstimator = a.underlyingActor.nodeParams.onChainFeeConf.feeEstimator.asInstanceOf[TestFeeEstimator] + def feeTargets: FeeTargets = a.underlyingActor.nodeParams.onChainFeeConf.feeTargets } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/f/ShutdownStateSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/f/ShutdownStateSpec.scala index 5b425eac08..0757f56338 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/f/ShutdownStateSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/f/ShutdownStateSpec.scala @@ -28,6 +28,7 @@ import fr.acinq.eclair.channel._ import fr.acinq.eclair.channel.states.StateTestsHelperMethods import fr.acinq.eclair.payment._ import fr.acinq.eclair.router.Hop +import fr.acinq.eclair.wire.Onion.FinalLegacyPayload import fr.acinq.eclair.wire.{CommitSig, Error, FailureMessageCodecs, PermanentChannelFailure, RevokeAndAck, Shutdown, UpdateAddHtlc, UpdateFailHtlc, UpdateFailMalformedHtlc, UpdateFee, UpdateFulfillHtlc} import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, LongToBtcAmount, TestConstants, TestkitBaseClass, randomBytes32} import org.scalatest.Outcome @@ -56,7 +57,7 @@ class ShutdownStateSpec extends TestkitBaseClass with StateTestsHelperMethods { val h1 = Crypto.sha256(r1) val amount1 = 300000000 msat val expiry1 = CltvExpiryDelta(144).toCltvExpiry - val cmd1 = PaymentLifecycle.buildCommand(UUID.randomUUID, amount1, expiry1, h1, Hop(null, TestConstants.Bob.nodeParams.nodeId, null) :: Nil)._1.copy(commit = false) + val cmd1 = PaymentLifecycle.buildCommand(UUID.randomUUID, h1, Hop(null, TestConstants.Bob.nodeParams.nodeId, null) :: Nil, FinalLegacyPayload(amount1, expiry1))._1.copy(commit = false) sender.send(alice, cmd1) sender.expectMsg("ok") val htlc1 = alice2bob.expectMsgType[UpdateAddHtlc] @@ -66,7 +67,7 @@ class ShutdownStateSpec extends TestkitBaseClass with StateTestsHelperMethods { val h2 = Crypto.sha256(r2) val amount2 = 200000000 msat val expiry2 = CltvExpiryDelta(144).toCltvExpiry - val cmd2 = PaymentLifecycle.buildCommand(UUID.randomUUID, amount2, expiry2, h2, Hop(null, TestConstants.Bob.nodeParams.nodeId, null) :: Nil)._1.copy(commit = false) + val cmd2 = PaymentLifecycle.buildCommand(UUID.randomUUID, h2, Hop(null, TestConstants.Bob.nodeParams.nodeId, null) :: Nil, FinalLegacyPayload(amount2, expiry2))._1.copy(commit = false) sender.send(alice, cmd2) sender.expectMsg("ok") val htlc2 = alice2bob.expectMsgType[UpdateAddHtlc] @@ -216,7 +217,6 @@ class ShutdownStateSpec extends TestkitBaseClass with StateTestsHelperMethods { test("recv CMD_FAIL_HTLC (acknowledge in case of failure)") { f => import f._ val sender = TestProbe() - val r = randomBytes32 val initialState = bob.stateData.asInstanceOf[DATA_SHUTDOWN] sender.send(bob, CMD_FAIL_HTLC(42, Right(PermanentChannelFailure))) // this will fail sender.expectMsg(Failure(UnknownHtlcId(channelId(bob), 42))) @@ -239,7 +239,7 @@ class ShutdownStateSpec extends TestkitBaseClass with StateTestsHelperMethods { import f._ val sender = TestProbe() val initialState = bob.stateData.asInstanceOf[DATA_SHUTDOWN] - sender.send(bob, CMD_FAIL_MALFORMED_HTLC(42, ByteVector32.Zeroes, FailureMessageCodecs.BADONION)) + sender.send(bob, CMD_FAIL_MALFORMED_HTLC(42, randomBytes32, FailureMessageCodecs.BADONION)) sender.expectMsg(Failure(UnknownHtlcId(channelId(bob), 42))) assert(initialState == bob.stateData) } @@ -248,7 +248,7 @@ class ShutdownStateSpec extends TestkitBaseClass with StateTestsHelperMethods { import f._ val sender = TestProbe() val initialState = bob.stateData.asInstanceOf[DATA_SHUTDOWN] - sender.send(bob, CMD_FAIL_MALFORMED_HTLC(42, ByteVector32.Zeroes, 42)) + sender.send(bob, CMD_FAIL_MALFORMED_HTLC(42, randomBytes32, 42)) sender.expectMsg(Failure(InvalidFailureCode(channelId(bob)))) assert(initialState == bob.stateData) } @@ -256,10 +256,8 @@ class ShutdownStateSpec extends TestkitBaseClass with StateTestsHelperMethods { test("recv CMD_FAIL_MALFORMED_HTLC (acknowledge in case of failure)") { f => import f._ val sender = TestProbe() - val r = randomBytes32 val initialState = bob.stateData.asInstanceOf[DATA_SHUTDOWN] - - sender.send(bob, CMD_FAIL_MALFORMED_HTLC(42, ByteVector32.Zeroes, FailureMessageCodecs.BADONION)) // this will fail + sender.send(bob, CMD_FAIL_MALFORMED_HTLC(42, randomBytes32, FailureMessageCodecs.BADONION)) // this will fail sender.expectMsg(Failure(UnknownHtlcId(channelId(bob), 42))) relayerB.expectMsg(CommandBuffer.CommandAck(initialState.channelId, 42)) } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/crypto/SphinxSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/crypto/SphinxSpec.scala index 59a622b48b..1dbdecffbe 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/crypto/SphinxSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/crypto/SphinxSpec.scala @@ -18,7 +18,7 @@ package fr.acinq.eclair.crypto import fr.acinq.bitcoin.ByteVector32 import fr.acinq.bitcoin.Crypto.{PrivateKey, PublicKey} -import fr.acinq.eclair.wire +import fr.acinq.eclair.{UInt64, wire} import fr.acinq.eclair.wire._ import org.scalatest.FunSuite import scodec.bits._ @@ -249,7 +249,7 @@ class SphinxSpec extends FunSuite { val packet = FailurePacket.wrap( FailurePacket.wrap( - FailurePacket.create(sharedSecrets.head, InvalidOnionPayload(ByteVector32.Zeroes)), + FailurePacket.create(sharedSecrets.head, InvalidOnionPayload(UInt64(0), 0)), sharedSecrets(1)), sharedSecrets(2)) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteNetworkDbSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteNetworkDbSpec.scala index d2347b2769..faedbb1499 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteNetworkDbSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteNetworkDbSpec.scala @@ -16,16 +16,17 @@ package fr.acinq.eclair.db -import java.sql.{Connection, DriverManager} +import java.sql.Connection import fr.acinq.bitcoin.Crypto.PrivateKey -import fr.acinq.bitcoin.{Block, Crypto, Satoshi} +import fr.acinq.bitcoin.{Block, Crypto} import fr.acinq.eclair.db.sqlite.SqliteNetworkDb import fr.acinq.eclair.db.sqlite.SqliteUtils._ import fr.acinq.eclair.router.{Announcements, PublicChannel} import fr.acinq.eclair.wire.{Color, NodeAddress, Tor2} import fr.acinq.eclair.{CltvExpiryDelta, LongToBtcAmount, ShortChannelId, TestConstants, randomBytes32, randomKey} import org.scalatest.FunSuite +import scodec.bits.HexStringSyntax import scala.collection.SortedMap @@ -82,15 +83,15 @@ class SqliteNetworkDbSpec extends FunSuite { val sqlite = TestConstants.sqliteInMemory() val db = new SqliteNetworkDb(sqlite) - val node_1 = Announcements.makeNodeAnnouncement(randomKey, "node-alice", Color(100.toByte, 200.toByte, 300.toByte), NodeAddress.fromParts("192.168.1.42", 42000).get :: Nil) - val node_2 = Announcements.makeNodeAnnouncement(randomKey, "node-bob", Color(100.toByte, 200.toByte, 300.toByte), NodeAddress.fromParts("192.168.1.42", 42000).get :: Nil) - val node_3 = Announcements.makeNodeAnnouncement(randomKey, "node-charlie", Color(100.toByte, 200.toByte, 300.toByte), NodeAddress.fromParts("192.168.1.42", 42000).get :: Nil) - val node_4 = Announcements.makeNodeAnnouncement(randomKey, "node-charlie", Color(100.toByte, 200.toByte, 300.toByte), Tor2("aaaqeayeaudaocaj", 42000) :: Nil) + val node_1 = Announcements.makeNodeAnnouncement(randomKey, "node-alice", Color(100.toByte, 200.toByte, 300.toByte), NodeAddress.fromParts("192.168.1.42", 42000).get :: Nil, hex"") + val node_2 = Announcements.makeNodeAnnouncement(randomKey, "node-bob", Color(100.toByte, 200.toByte, 300.toByte), NodeAddress.fromParts("192.168.1.42", 42000).get :: Nil, hex"0200") + val node_3 = Announcements.makeNodeAnnouncement(randomKey, "node-charlie", Color(100.toByte, 200.toByte, 300.toByte), NodeAddress.fromParts("192.168.1.42", 42000).get :: Nil, hex"0200") + val node_4 = Announcements.makeNodeAnnouncement(randomKey, "node-charlie", Color(100.toByte, 200.toByte, 300.toByte), Tor2("aaaqeayeaudaocaj", 42000) :: Nil, hex"00") assert(db.listNodes().toSet === Set.empty) db.addNode(node_1) db.addNode(node_1) // duplicate is ignored - assert(db.getNode(node_1.nodeId) == Some(node_1)) + assert(db.getNode(node_1.nodeId) === Some(node_1)) assert(db.listNodes().size === 1) db.addNode(node_2) db.addNode(node_3) @@ -110,7 +111,7 @@ class SqliteNetworkDbSpec extends FunSuite { def generatePubkeyHigherThan(priv: PrivateKey) = { var res = priv - while(!Announcements.isNode1(priv.publicKey, res.publicKey)) res = randomKey + while (!Announcements.isNode1(priv.publicKey, res.publicKey)) res = randomKey res } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/integration/IntegrationSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/integration/IntegrationSpec.scala index 2a6db6c66e..dd01ce607e 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/integration/IntegrationSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/integration/IntegrationSpec.scala @@ -34,6 +34,7 @@ import fr.acinq.eclair.channel._ import fr.acinq.eclair.crypto.Sphinx.DecryptedFailurePacket import fr.acinq.eclair.io.Peer import fr.acinq.eclair.io.Peer.{Disconnect, PeerRoutingMessage} +import fr.acinq.eclair.payment.PaymentInitiator.SendPaymentRequest import fr.acinq.eclair.payment.PaymentLifecycle.{State => _, _} import fr.acinq.eclair.payment.{LocalPaymentHandler, PaymentRequest} import fr.acinq.eclair.router.Graph.WeightRatios @@ -262,8 +263,7 @@ class IntegrationSpec extends TestKit(ActorSystem("test")) with BitcoindService sender.send(nodes("D").paymentHandler, ReceivePayment(Some(amountMsat), "1 coffee")) val pr = sender.expectMsgType[PaymentRequest] // then we make the actual payment - sender.send(nodes("A").paymentInitiator, - SendPayment(amountMsat, pr.paymentHash, nodes("D").nodeParams.nodeId, routeParams = integrationTestRouteParams, maxAttempts = 1)) + sender.send(nodes("A").paymentInitiator, SendPaymentRequest(amountMsat, pr.paymentHash, nodes("D").nodeParams.nodeId, routeParams = integrationTestRouteParams, maxAttempts = 1)) val paymentId = sender.expectMsgType[UUID](5 seconds) val ps = sender.expectMsgType[PaymentSucceeded](5 seconds) assert(ps.id == paymentId) @@ -287,7 +287,7 @@ class IntegrationSpec extends TestKit(ActorSystem("test")) with BitcoindService sender.send(nodes("D").paymentHandler, ReceivePayment(Some(amountMsat), "1 coffee")) val pr = sender.expectMsgType[PaymentRequest] // then we make the actual payment, do not randomize the route to make sure we route through node B - val sendReq = SendPayment(amountMsat, pr.paymentHash, nodes("D").nodeParams.nodeId, routeParams = integrationTestRouteParams, maxAttempts = 5) + val sendReq = SendPaymentRequest(amountMsat, pr.paymentHash, nodes("D").nodeParams.nodeId, routeParams = integrationTestRouteParams, maxAttempts = 5) sender.send(nodes("A").paymentInitiator, sendReq) // A will receive an error from B that include the updated channel update, then will retry the payment val paymentId = sender.expectMsgType[UUID](5 seconds) @@ -327,7 +327,7 @@ class IntegrationSpec extends TestKit(ActorSystem("test")) with BitcoindService sender.send(nodes("D").paymentHandler, ReceivePayment(Some(amountMsat), "1 coffee")) val pr = sender.expectMsgType[PaymentRequest] // then we make the payment (B-C has a smaller capacity than A-B and C-D) - val sendReq = SendPayment(amountMsat, pr.paymentHash, nodes("D").nodeParams.nodeId, routeParams = integrationTestRouteParams, maxAttempts = 5) + val sendReq = SendPaymentRequest(amountMsat, pr.paymentHash, nodes("D").nodeParams.nodeId, routeParams = integrationTestRouteParams, maxAttempts = 5) sender.send(nodes("A").paymentInitiator, sendReq) // A will first receive an error from C, then retry and route around C: A->B->E->C->D sender.expectMsgType[UUID](5 seconds) @@ -336,7 +336,7 @@ class IntegrationSpec extends TestKit(ActorSystem("test")) with BitcoindService test("send an HTLC A->D with an unknown payment hash") { val sender = TestProbe() - val pr = SendPayment(100000000 msat, randomBytes32, nodes("D").nodeParams.nodeId, routeParams = integrationTestRouteParams, maxAttempts = 5) + val pr = SendPaymentRequest(100000000 msat, randomBytes32, nodes("D").nodeParams.nodeId, routeParams = integrationTestRouteParams, maxAttempts = 5) sender.send(nodes("A").paymentInitiator, pr) // A will receive an error from D and won't retry @@ -356,7 +356,7 @@ class IntegrationSpec extends TestKit(ActorSystem("test")) with BitcoindService val pr = sender.expectMsgType[PaymentRequest] // A send payment of only 1 mBTC - val sendReq = SendPayment(100000000 msat, pr.paymentHash, nodes("D").nodeParams.nodeId, routeParams = integrationTestRouteParams, maxAttempts = 5) + val sendReq = SendPaymentRequest(100000000 msat, pr.paymentHash, nodes("D").nodeParams.nodeId, routeParams = integrationTestRouteParams, maxAttempts = 5) sender.send(nodes("A").paymentInitiator, sendReq) // A will first receive an IncorrectPaymentAmount error from D @@ -376,7 +376,7 @@ class IntegrationSpec extends TestKit(ActorSystem("test")) with BitcoindService val pr = sender.expectMsgType[PaymentRequest] // A send payment of 6 mBTC - val sendReq = SendPayment(600000000 msat, pr.paymentHash, nodes("D").nodeParams.nodeId, routeParams = integrationTestRouteParams, maxAttempts = 5) + val sendReq = SendPaymentRequest(600000000 msat, pr.paymentHash, nodes("D").nodeParams.nodeId, routeParams = integrationTestRouteParams, maxAttempts = 5) sender.send(nodes("A").paymentInitiator, sendReq) // A will first receive an IncorrectPaymentAmount error from D @@ -396,7 +396,7 @@ class IntegrationSpec extends TestKit(ActorSystem("test")) with BitcoindService val pr = sender.expectMsgType[PaymentRequest] // A send payment of 3 mBTC, more than asked but it should still be accepted - val sendReq = SendPayment(300000000 msat, pr.paymentHash, nodes("D").nodeParams.nodeId, routeParams = integrationTestRouteParams, maxAttempts = 5) + val sendReq = SendPaymentRequest(300000000 msat, pr.paymentHash, nodes("D").nodeParams.nodeId, routeParams = integrationTestRouteParams, maxAttempts = 5) sender.send(nodes("A").paymentInitiator, sendReq) sender.expectMsgType[UUID] } @@ -409,7 +409,7 @@ class IntegrationSpec extends TestKit(ActorSystem("test")) with BitcoindService sender.send(nodes("D").paymentHandler, ReceivePayment(Some(amountMsat), "1 payment")) val pr = sender.expectMsgType[PaymentRequest] - val sendReq = SendPayment(amountMsat, pr.paymentHash, nodes("D").nodeParams.nodeId, routeParams = integrationTestRouteParams, maxAttempts = 5) + val sendReq = SendPaymentRequest(amountMsat, pr.paymentHash, nodes("D").nodeParams.nodeId, routeParams = integrationTestRouteParams, maxAttempts = 5) sender.send(nodes("A").paymentInitiator, sendReq) sender.expectMsgType[UUID] sender.expectMsgType[PaymentSucceeded] // the payment FSM will also reply to the sender after the payment is completed @@ -424,8 +424,7 @@ class IntegrationSpec extends TestKit(ActorSystem("test")) with BitcoindService val pr = sender.expectMsgType[PaymentRequest](30 seconds) // the payment is requesting to use a capacity-optimized route which will select node G even though it's a bit more expensive - sender.send(nodes("A").paymentInitiator, - SendPayment(amountMsat, pr.paymentHash, nodes("C").nodeParams.nodeId, maxAttempts = 1, routeParams = integrationTestRouteParams.map(_.copy(ratios = Some(WeightRatios(0, 0, 1)))))) + sender.send(nodes("A").paymentInitiator, SendPaymentRequest(amountMsat, pr.paymentHash, nodes("C").nodeParams.nodeId, maxAttempts = 1, routeParams = integrationTestRouteParams.map(_.copy(ratios = Some(WeightRatios(0, 0, 1)))))) sender.expectMsgType[UUID](max = 60 seconds) awaitCond({ @@ -469,7 +468,7 @@ class IntegrationSpec extends TestKit(ActorSystem("test")) with BitcoindService val preimage = randomBytes32 val paymentHash = Crypto.sha256(preimage) // A sends a payment to F - val paymentReq = SendPayment(100000000 msat, paymentHash, nodes("F1").nodeParams.nodeId, maxAttempts = 1, routeParams = integrationTestRouteParams) + val paymentReq = SendPaymentRequest(100000000 msat, paymentHash, nodes("F1").nodeParams.nodeId, maxAttempts = 1, routeParams = integrationTestRouteParams) val paymentSender = TestProbe() paymentSender.send(nodes("A").paymentInitiator, paymentReq) paymentSender.expectMsgType[UUID](30 seconds) @@ -549,7 +548,7 @@ class IntegrationSpec extends TestKit(ActorSystem("test")) with BitcoindService val preimage = randomBytes32 val paymentHash = Crypto.sha256(preimage) // A sends a payment to F - val paymentReq = SendPayment(100000000 msat, paymentHash, nodes("F2").nodeParams.nodeId, maxAttempts = 1, routeParams = integrationTestRouteParams) + val paymentReq = SendPaymentRequest(100000000 msat, paymentHash, nodes("F2").nodeParams.nodeId, maxAttempts = 1, routeParams = integrationTestRouteParams) val paymentSender = TestProbe() paymentSender.send(nodes("A").paymentInitiator, paymentReq) paymentSender.expectMsgType[UUID](30 seconds) @@ -626,7 +625,7 @@ class IntegrationSpec extends TestKit(ActorSystem("test")) with BitcoindService val preimage: ByteVector = randomBytes32 val paymentHash = Crypto.sha256(preimage) // A sends a payment to F - val paymentReq = SendPayment(100000000 msat, paymentHash, nodes("F3").nodeParams.nodeId, maxAttempts = 1, routeParams = integrationTestRouteParams) + val paymentReq = SendPaymentRequest(100000000 msat, paymentHash, nodes("F3").nodeParams.nodeId, maxAttempts = 1, routeParams = integrationTestRouteParams) val paymentSender = TestProbe() paymentSender.send(nodes("A").paymentInitiator, paymentReq) val paymentId = paymentSender.expectMsgType[UUID] @@ -688,7 +687,7 @@ class IntegrationSpec extends TestKit(ActorSystem("test")) with BitcoindService val preimage: ByteVector = randomBytes32 val paymentHash = Crypto.sha256(preimage) // A sends a payment to F - val paymentReq = SendPayment(100000000 msat, paymentHash, nodes("F4").nodeParams.nodeId, maxAttempts = 1, routeParams = integrationTestRouteParams) + val paymentReq = SendPaymentRequest(100000000 msat, paymentHash, nodes("F4").nodeParams.nodeId, maxAttempts = 1, routeParams = integrationTestRouteParams) val paymentSender = TestProbe() paymentSender.send(nodes("A").paymentInitiator, paymentReq) val paymentId = paymentSender.expectMsgType[UUID](30 seconds) @@ -762,7 +761,7 @@ class IntegrationSpec extends TestKit(ActorSystem("test")) with BitcoindService val amountMsat = 300000000.msat sender.send(paymentHandlerF, ReceivePayment(Some(amountMsat), "1 coffee")) val pr = sender.expectMsgType[PaymentRequest] - val sendReq = SendPayment(300000000 msat, pr.paymentHash, pr.nodeId, routeParams = integrationTestRouteParams, maxAttempts = 1) + val sendReq = SendPaymentRequest(300000000 msat, pr.paymentHash, pr.nodeId, routeParams = integrationTestRouteParams, maxAttempts = 1) sender.send(nodes("A").paymentInitiator, sendReq) val paymentId = sender.expectMsgType[UUID] // we forward the htlc to the payment handler @@ -776,7 +775,7 @@ class IntegrationSpec extends TestKit(ActorSystem("test")) with BitcoindService def send(amountMsat: MilliSatoshi, paymentHandler: ActorRef, paymentInitiator: ActorRef) = { sender.send(paymentHandler, ReceivePayment(Some(amountMsat), "1 coffee")) val pr = sender.expectMsgType[PaymentRequest] - val sendReq = SendPayment(amountMsat, pr.paymentHash, pr.nodeId, routeParams = integrationTestRouteParams, maxAttempts = 1) + val sendReq = SendPaymentRequest(amountMsat, pr.paymentHash, pr.nodeId, routeParams = integrationTestRouteParams, maxAttempts = 1) sender.send(paymentInitiator, sendReq) sender.expectMsgType[UUID] } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/ChannelSelectionSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/ChannelSelectionSpec.scala index 0599412083..e2bbb33201 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/ChannelSelectionSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/ChannelSelectionSpec.scala @@ -22,6 +22,7 @@ import fr.acinq.eclair.channel.{CMD_ADD_HTLC, CMD_FAIL_HTLC} import fr.acinq.eclair.payment.HtlcGenerationSpec.makeCommitments import fr.acinq.eclair.payment.Relayer.{OutgoingChannel, RelayFailure, RelayPayload, RelaySuccess} import fr.acinq.eclair.router.Announcements +import fr.acinq.eclair.wire.Onion.RelayLegacyPayload import fr.acinq.eclair.wire._ import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, LongToBtcAmount, MilliSatoshi, ShortChannelId, TestConstants, randomBytes32, randomKey} import org.scalatest.FunSuite @@ -30,6 +31,8 @@ import scala.collection.mutable class ChannelSelectionSpec extends FunSuite { + implicit val log: akka.event.LoggingAdapter = akka.event.NoLogging + /** * This is just a simplified helper function with random values for fields we are not using here */ @@ -37,42 +40,41 @@ class ChannelSelectionSpec extends FunSuite { Announcements.makeChannelUpdate(Block.RegtestGenesisBlock.hash, randomKey, randomKey.publicKey, shortChannelId, cltvExpiryDelta, htlcMinimumMsat, feeBaseMsat, feeProportionalMillionths, htlcMaximumMsat, enable) test("convert to CMD_FAIL_HTLC/CMD_ADD_HTLC") { + val onionPayload = RelayLegacyPayload(ShortChannelId(12345), 998900 msat, CltvExpiry(60)) val relayPayload = RelayPayload( add = UpdateAddHtlc(randomBytes32, 42, 1000000 msat, randomBytes32, CltvExpiry(70), TestConstants.emptyOnionPacket), - payload = PerHopPayload(ShortChannelId(12345), amtToForward = 998900 msat, outgoingCltvValue = CltvExpiry(60)), + payload = onionPayload, nextPacket = TestConstants.emptyOnionPacket // just a placeholder ) val channelUpdate = dummyUpdate(ShortChannelId(12345), CltvExpiryDelta(10), 100 msat, 1000 msat, 100, 10000000 msat, true) - implicit val log = akka.event.NoLogging - // nominal case - assert(Relayer.relayOrFail(relayPayload, Some(channelUpdate)) === RelaySuccess(ShortChannelId(12345), CMD_ADD_HTLC(relayPayload.payload.amtToForward, relayPayload.add.paymentHash, relayPayload.payload.outgoingCltvValue, relayPayload.nextPacket, upstream = Right(relayPayload.add), commit = true))) + assert(Relayer.relayOrFail(relayPayload, Some(channelUpdate)) === RelaySuccess(ShortChannelId(12345), CMD_ADD_HTLC(relayPayload.payload.amountToForward, relayPayload.add.paymentHash, relayPayload.payload.outgoingCltv, relayPayload.nextPacket, upstream = Right(relayPayload.add), commit = true))) // no channel_update assert(Relayer.relayOrFail(relayPayload, channelUpdate_opt = None) === RelayFailure(CMD_FAIL_HTLC(relayPayload.add.id, Right(UnknownNextPeer), commit = true))) // channel disabled - val channelUpdate_disabled = channelUpdate.copy(channelFlags = Announcements.makeChannelFlags(true, enable = false)) + val channelUpdate_disabled = channelUpdate.copy(channelFlags = Announcements.makeChannelFlags(isNode1 = true, enable = false)) assert(Relayer.relayOrFail(relayPayload, Some(channelUpdate_disabled)) === RelayFailure(CMD_FAIL_HTLC(relayPayload.add.id, Right(ChannelDisabled(channelUpdate_disabled.messageFlags, channelUpdate_disabled.channelFlags, channelUpdate_disabled)), commit = true))) // amount too low - val relayPayload_toolow = relayPayload.copy(payload = relayPayload.payload.copy(amtToForward = 99 msat)) - assert(Relayer.relayOrFail(relayPayload_toolow, Some(channelUpdate)) === RelayFailure(CMD_FAIL_HTLC(relayPayload.add.id, Right(AmountBelowMinimum(relayPayload_toolow.payload.amtToForward, channelUpdate)), commit = true))) + val relayPayload_toolow = relayPayload.copy(payload = onionPayload.copy(amountToForward = 99 msat)) + assert(Relayer.relayOrFail(relayPayload_toolow, Some(channelUpdate)) === RelayFailure(CMD_FAIL_HTLC(relayPayload.add.id, Right(AmountBelowMinimum(relayPayload_toolow.payload.amountToForward, channelUpdate)), commit = true))) // incorrect cltv expiry - val relayPayload_incorrectcltv = relayPayload.copy(payload = relayPayload.payload.copy(outgoingCltvValue = CltvExpiry(42))) - assert(Relayer.relayOrFail(relayPayload_incorrectcltv, Some(channelUpdate)) === RelayFailure(CMD_FAIL_HTLC(relayPayload.add.id, Right(IncorrectCltvExpiry(relayPayload_incorrectcltv.payload.outgoingCltvValue, channelUpdate)), commit = true))) + val relayPayload_incorrectcltv = relayPayload.copy(payload = onionPayload.copy(outgoingCltv = CltvExpiry(42))) + assert(Relayer.relayOrFail(relayPayload_incorrectcltv, Some(channelUpdate)) === RelayFailure(CMD_FAIL_HTLC(relayPayload.add.id, Right(IncorrectCltvExpiry(relayPayload_incorrectcltv.payload.outgoingCltv, channelUpdate)), commit = true))) // insufficient fee - val relayPayload_insufficientfee = relayPayload.copy(payload = relayPayload.payload.copy(amtToForward = 998910 msat)) + val relayPayload_insufficientfee = relayPayload.copy(payload = onionPayload.copy(amountToForward = 998910 msat)) assert(Relayer.relayOrFail(relayPayload_insufficientfee, Some(channelUpdate)) === RelayFailure(CMD_FAIL_HTLC(relayPayload.add.id, Right(FeeInsufficient(relayPayload_insufficientfee.add.amountMsat, channelUpdate)), commit = true))) // note that a generous fee is ok! - val relayPayload_highfee = relayPayload.copy(payload = relayPayload.payload.copy(amtToForward = 900000 msat)) - assert(Relayer.relayOrFail(relayPayload_highfee, Some(channelUpdate)) === RelaySuccess(ShortChannelId(12345), CMD_ADD_HTLC(relayPayload_highfee.payload.amtToForward, relayPayload_highfee.add.paymentHash, relayPayload_highfee.payload.outgoingCltvValue, relayPayload_highfee.nextPacket, upstream = Right(relayPayload.add), commit = true))) + val relayPayload_highfee = relayPayload.copy(payload = onionPayload.copy(amountToForward = 900000 msat)) + assert(Relayer.relayOrFail(relayPayload_highfee, Some(channelUpdate)) === RelaySuccess(ShortChannelId(12345), CMD_ADD_HTLC(relayPayload_highfee.payload.amountToForward, relayPayload_highfee.add.paymentHash, relayPayload_highfee.payload.outgoingCltv, relayPayload_highfee.nextPacket, upstream = Right(relayPayload.add), commit = true))) } test("channel selection") { - + val onionPayload = RelayLegacyPayload(ShortChannelId(12345), 998900 msat, CltvExpiry(60)) val relayPayload = RelayPayload( add = UpdateAddHtlc(randomBytes32, 42, 1000000 msat, randomBytes32, CltvExpiry(70), TestConstants.emptyOnionPacket), - payload = PerHopPayload(ShortChannelId(12345), amtToForward = 998900 msat, outgoingCltvValue = CltvExpiry(60)), + payload = onionPayload, nextPacket = TestConstants.emptyOnionPacket // just a placeholder ) @@ -91,10 +93,6 @@ class ChannelSelectionSpec extends FunSuite { node2channels.put(a, mutable.Set(ShortChannelId(12345), ShortChannelId(11111), ShortChannelId(22222), ShortChannelId(33333))) node2channels.put(b, mutable.Set(ShortChannelId(44444))) - implicit val log = akka.event.NoLogging - - import com.softwaremill.quicklens._ - // select the channel to the same node, with the lowest balance but still high enough to handle the payment assert(Relayer.selectPreferredChannel(relayPayload, channelUpdates, node2channels, Seq.empty) === Some(ShortChannelId(22222))) // select 2nd-to-best channel @@ -104,13 +102,13 @@ class ChannelSelectionSpec extends FunSuite { // all the suitable channels have been tried assert(Relayer.selectPreferredChannel(relayPayload, channelUpdates, node2channels, Seq(ShortChannelId(22222), ShortChannelId(12345), ShortChannelId(11111))) === None) // higher amount payment (have to increased incoming htlc amount for fees to be sufficient) - assert(Relayer.selectPreferredChannel(relayPayload.modify(_.add.amountMsat).setTo(60000000 msat).modify(_.payload.amtToForward).setTo(50000000 msat), channelUpdates, node2channels, Seq.empty) === Some(ShortChannelId(11111))) + assert(Relayer.selectPreferredChannel(relayPayload.copy(add = relayPayload.add.copy(amountMsat = 60000000 msat), payload = onionPayload.copy(amountToForward = 50000000 msat)), channelUpdates, node2channels, Seq.empty) === Some(ShortChannelId(11111))) // lower amount payment - assert(Relayer.selectPreferredChannel(relayPayload.modify(_.payload.amtToForward).setTo(1000 msat), channelUpdates, node2channels, Seq.empty) === Some(ShortChannelId(33333))) + assert(Relayer.selectPreferredChannel(relayPayload.copy(payload = onionPayload.copy(amountToForward = 1000 msat)), channelUpdates, node2channels, Seq.empty) === Some(ShortChannelId(33333))) // payment too high, no suitable channel found - assert(Relayer.selectPreferredChannel(relayPayload.modify(_.payload.amtToForward).setTo(1000000000 msat), channelUpdates, node2channels, Seq.empty) === Some(ShortChannelId(12345))) + assert(Relayer.selectPreferredChannel(relayPayload.copy(payload = onionPayload.copy(amountToForward = 1000000000 msat)), channelUpdates, node2channels, Seq.empty) === Some(ShortChannelId(12345))) // invalid cltv expiry, no suitable channel, we keep the requested one - assert(Relayer.selectPreferredChannel(relayPayload.modify(_.payload.outgoingCltvValue).setTo(CltvExpiry(40)), channelUpdates, node2channels, Seq.empty) === Some(ShortChannelId(12345))) + assert(Relayer.selectPreferredChannel(relayPayload.copy(payload = onionPayload.copy(outgoingCltv = CltvExpiry(40))), channelUpdates, node2channels, Seq.empty) === Some(ShortChannelId(12345))) } } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/HtlcGenerationSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/HtlcGenerationSpec.scala index 0a82253054..6361c3813c 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/HtlcGenerationSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/HtlcGenerationSpec.scala @@ -25,16 +25,22 @@ import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.crypto.Sphinx.{DecryptedPacket, PacketAndSecrets} import fr.acinq.eclair.payment.PaymentLifecycle._ import fr.acinq.eclair.router.Hop -import fr.acinq.eclair.wire.{ChannelUpdate, OnionCodecs, PerHopPayload} -import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, LongToBtcAmount, MilliSatoshi, ShortChannelId, TestConstants, nodeFee, randomBytes32} -import org.scalatest.FunSuite +import fr.acinq.eclair.wire.Onion.{FinalLegacyPayload, FinalTlvPayload, PerHopPayload, RelayLegacyPayload} +import fr.acinq.eclair.wire.OnionTlv.{AmountToForward, OutgoingCltv} +import fr.acinq.eclair.wire._ +import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Globals, LongToBtcAmount, MilliSatoshi, ShortChannelId, TestConstants, nodeFee, randomBytes32} +import org.scalatest.{BeforeAndAfterAll, FunSuite} import scodec.bits.ByteVector /** * Created by PM on 31/05/2016. */ -class HtlcGenerationSpec extends FunSuite { +class HtlcGenerationSpec extends FunSuite with BeforeAndAfterAll { + + override def beforeAll { + Globals.blockCount.set(HtlcGenerationSpec.currentBlockCount) + } test("compute fees") { val feeBaseMsat = 150000 msat @@ -49,89 +55,80 @@ class HtlcGenerationSpec extends FunSuite { import HtlcGenerationSpec._ test("compute payloads with fees and expiry delta") { - - val (firstAmountMsat, firstExpiry, payloads) = buildPayloads(finalAmountMsat, finalExpiry, hops.drop(1)) + val (firstAmountMsat, firstExpiry, payloads) = buildPayloads(hops.drop(1), FinalLegacyPayload(finalAmountMsat, finalExpiry)) + val expectedPayloads = Seq[PerHopPayload]( + RelayLegacyPayload(channelUpdate_bc.shortChannelId, amount_bc, expiry_bc), + RelayLegacyPayload(channelUpdate_cd.shortChannelId, amount_cd, expiry_cd), + RelayLegacyPayload(channelUpdate_de.shortChannelId, amount_de, expiry_de), + FinalLegacyPayload(finalAmountMsat, finalExpiry)) assert(firstAmountMsat === amount_ab) assert(firstExpiry === expiry_ab) - assert(payloads === - PerHopPayload(channelUpdate_bc.shortChannelId, amount_bc, expiry_bc) :: - PerHopPayload(channelUpdate_cd.shortChannelId, amount_cd, expiry_cd) :: - PerHopPayload(channelUpdate_de.shortChannelId, amount_de, expiry_de) :: - PerHopPayload(ShortChannelId(0L), finalAmountMsat, finalExpiry) :: Nil) + assert(payloads === expectedPayloads) } - test("build onion") { - - val (_, _, payloads) = buildPayloads(finalAmountMsat, finalExpiry, hops.drop(1)) + def testBuildOnion(legacy: Boolean): Unit = { + val finalPayload = if (legacy) { + FinalLegacyPayload(finalAmountMsat, finalExpiry) + } else { + FinalTlvPayload(TlvStream[OnionTlv](AmountToForward(finalAmountMsat), OutgoingCltv(finalExpiry))) + } + val (_, _, payloads) = buildPayloads(hops.drop(1), finalPayload) val nodes = hops.map(_.nextNodeId) val PacketAndSecrets(packet_b, _) = buildOnion(nodes, payloads, paymentHash) assert(packet_b.payload.length === Sphinx.PaymentPacket.PayloadLength) // let's peel the onion + testPeelOnion(packet_b) + } + + def testPeelOnion(packet_b: OnionRoutingPacket): Unit = { val Right(DecryptedPacket(bin_b, packet_c, _)) = Sphinx.PaymentPacket.peel(priv_b.privateKey, paymentHash, packet_b) - val payload_b = OnionCodecs.perHopPayloadCodec.decode(bin_b.toBitVector).require.value + val payload_b = OnionCodecs.relayPerHopPayloadCodec.decode(bin_b.toBitVector).require.value assert(packet_c.payload.length === Sphinx.PaymentPacket.PayloadLength) - assert(payload_b.amtToForward === amount_bc) - assert(payload_b.outgoingCltvValue === expiry_bc) + assert(payload_b.amountToForward === amount_bc) + assert(payload_b.outgoingCltv === expiry_bc) val Right(DecryptedPacket(bin_c, packet_d, _)) = Sphinx.PaymentPacket.peel(priv_c.privateKey, paymentHash, packet_c) - val payload_c = OnionCodecs.perHopPayloadCodec.decode(bin_c.toBitVector).require.value + val payload_c = OnionCodecs.relayPerHopPayloadCodec.decode(bin_c.toBitVector).require.value assert(packet_d.payload.length === Sphinx.PaymentPacket.PayloadLength) - assert(payload_c.amtToForward === amount_cd) - assert(payload_c.outgoingCltvValue === expiry_cd) + assert(payload_c.amountToForward === amount_cd) + assert(payload_c.outgoingCltv === expiry_cd) val Right(DecryptedPacket(bin_d, packet_e, _)) = Sphinx.PaymentPacket.peel(priv_d.privateKey, paymentHash, packet_d) - val payload_d = OnionCodecs.perHopPayloadCodec.decode(bin_d.toBitVector).require.value + val payload_d = OnionCodecs.relayPerHopPayloadCodec.decode(bin_d.toBitVector).require.value assert(packet_e.payload.length === Sphinx.PaymentPacket.PayloadLength) - assert(payload_d.amtToForward === amount_de) - assert(payload_d.outgoingCltvValue === expiry_de) + assert(payload_d.amountToForward === amount_de) + assert(payload_d.outgoingCltv === expiry_de) val Right(DecryptedPacket(bin_e, packet_random, _)) = Sphinx.PaymentPacket.peel(priv_e.privateKey, paymentHash, packet_e) - val payload_e = OnionCodecs.perHopPayloadCodec.decode(bin_e.toBitVector).require.value + val payload_e = OnionCodecs.finalPerHopPayloadCodec.decode(bin_e.toBitVector).require.value assert(packet_random.payload.length === Sphinx.PaymentPacket.PayloadLength) - assert(payload_e.amtToForward === finalAmountMsat) - assert(payload_e.outgoingCltvValue === finalExpiry) + assert(payload_e.amount === finalAmountMsat) + assert(payload_e.expiry === finalExpiry) } - test("build a command including the onion") { + test("build onion with final legacy payload") { + testBuildOnion(legacy = true) + } - val (add, _) = buildCommand(UUID.randomUUID, finalAmountMsat, finalExpiry, paymentHash, hops) + test("build onion with final tlv payload") { + testBuildOnion(legacy = false) + } + test("build a command including the onion") { + val (add, _) = buildCommand(UUID.randomUUID, paymentHash, hops, FinalLegacyPayload(finalAmountMsat, finalExpiry)) assert(add.amount > finalAmountMsat) assert(add.cltvExpiry === finalExpiry + channelUpdate_de.cltvExpiryDelta + channelUpdate_cd.cltvExpiryDelta + channelUpdate_bc.cltvExpiryDelta) assert(add.paymentHash === paymentHash) assert(add.onion.payload.length === Sphinx.PaymentPacket.PayloadLength) // let's peel the onion - val Right(DecryptedPacket(bin_b, packet_c, _)) = Sphinx.PaymentPacket.peel(priv_b.privateKey, paymentHash, add.onion) - val payload_b = OnionCodecs.perHopPayloadCodec.decode(bin_b.toBitVector).require.value - assert(packet_c.payload.length === Sphinx.PaymentPacket.PayloadLength) - assert(payload_b.amtToForward === amount_bc) - assert(payload_b.outgoingCltvValue === expiry_bc) - - val Right(DecryptedPacket(bin_c, packet_d, _)) = Sphinx.PaymentPacket.peel(priv_c.privateKey, paymentHash, packet_c) - val payload_c = OnionCodecs.perHopPayloadCodec.decode(bin_c.toBitVector).require.value - assert(packet_d.payload.length === Sphinx.PaymentPacket.PayloadLength) - assert(payload_c.amtToForward === amount_cd) - assert(payload_c.outgoingCltvValue === expiry_cd) - - val Right(DecryptedPacket(bin_d, packet_e, _)) = Sphinx.PaymentPacket.peel(priv_d.privateKey, paymentHash, packet_d) - val payload_d = OnionCodecs.perHopPayloadCodec.decode(bin_d.toBitVector).require.value - assert(packet_e.payload.length === Sphinx.PaymentPacket.PayloadLength) - assert(payload_d.amtToForward === amount_de) - assert(payload_d.outgoingCltvValue === expiry_de) - - val Right(DecryptedPacket(bin_e, packet_random, _)) = Sphinx.PaymentPacket.peel(priv_e.privateKey, paymentHash, packet_e) - val payload_e = OnionCodecs.perHopPayloadCodec.decode(bin_e.toBitVector).require.value - assert(packet_random.payload.length === Sphinx.PaymentPacket.PayloadLength) - assert(payload_e.amtToForward === finalAmountMsat) - assert(payload_e.outgoingCltvValue === finalExpiry) + testPeelOnion(add.onion) } test("build a command with no hops") { - val (add, _) = buildCommand(UUID.randomUUID(), finalAmountMsat, finalExpiry, paymentHash, hops.take(1)) - + val (add, _) = buildCommand(UUID.randomUUID(), paymentHash, hops.take(1), FinalLegacyPayload(finalAmountMsat, finalExpiry)) assert(add.amount === finalAmountMsat) assert(add.cltvExpiry === finalExpiry) assert(add.paymentHash === paymentHash) @@ -139,10 +136,10 @@ class HtlcGenerationSpec extends FunSuite { // let's peel the onion val Right(DecryptedPacket(bin_b, packet_random, _)) = Sphinx.PaymentPacket.peel(priv_b.privateKey, paymentHash, add.onion) - val payload_b = OnionCodecs.perHopPayloadCodec.decode(bin_b.toBitVector).require.value + val payload_b = OnionCodecs.relayPerHopPayloadCodec.decode(bin_b.toBitVector).require.value assert(packet_random.payload.length === Sphinx.PaymentPacket.PayloadLength) - assert(payload_b.amtToForward === finalAmountMsat) - assert(payload_b.outgoingCltvValue === finalExpiry) + assert(payload_b.amountToForward === finalAmountMsat) + assert(payload_b.outgoingCltv === finalExpiry) } } @@ -175,12 +172,12 @@ object HtlcGenerationSpec { Hop(d, e, channelUpdate_de) :: Nil val finalAmountMsat = 42000000 msat - val currentBlockCount = 420000 + val currentBlockCount = 400000 val finalExpiry = CltvExpiry(currentBlockCount) + Channel.MIN_CLTV_EXPIRY_DELTA val paymentPreimage = randomBytes32 val paymentHash = Crypto.sha256(paymentPreimage) - val expiry_de = CltvExpiry(currentBlockCount) + Channel.MIN_CLTV_EXPIRY_DELTA + val expiry_de = finalExpiry val amount_de = finalAmountMsat val fee_d = nodeFee(channelUpdate_de.feeBaseMsat, channelUpdate_de.feeProportionalMillionths, amount_de) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentInitiatorSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentInitiatorSpec.scala new file mode 100644 index 0000000000..a059dcc19c --- /dev/null +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentInitiatorSpec.scala @@ -0,0 +1,61 @@ +/* + * Copyright 2019 ACINQ SAS + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package fr.acinq.eclair.payment + +import java.util.UUID + +import akka.actor.ActorSystem +import akka.testkit.{TestKit, TestProbe} +import fr.acinq.eclair.payment.HtlcGenerationSpec._ +import fr.acinq.eclair.payment.PaymentRequest.ExtraHop +import fr.acinq.eclair.router.{FinalizeRoute, RouteParams, RouteRequest} +import fr.acinq.eclair.{CltvExpiryDelta, LongToBtcAmount, TestConstants} +import org.scalatest.FunSuiteLike + +/** + * Created by t-bast on 25/07/2019. + */ + +class PaymentInitiatorSpec extends TestKit(ActorSystem("test")) with FunSuiteLike { + + test("forward payment with pre-defined route") { + val sender = TestProbe() + val router = TestProbe() + val paymentInitiator = system.actorOf(PaymentInitiator.props(TestConstants.Alice.nodeParams, router.ref, TestProbe().ref)) + + sender.send(paymentInitiator, PaymentInitiator.SendPaymentRequest(finalAmountMsat, paymentHash, c, 1, predefinedRoute = Seq(a, b, c))) + sender.expectMsgType[UUID] + router.expectMsg(FinalizeRoute(Seq(a, b, c))) + } + + test("forward legacy payment") { + val sender = TestProbe() + val router = TestProbe() + val paymentInitiator = system.actorOf(PaymentInitiator.props(TestConstants.Alice.nodeParams, router.ref, TestProbe().ref)) + + val hints = Seq(Seq(ExtraHop(b, channelUpdate_bc.shortChannelId, feeBase = 10 msat, feeProportionalMillionths = 1, cltvExpiryDelta = CltvExpiryDelta(12)))) + val routeParams = RouteParams(randomize = true, 15 msat, 1.5, 5, CltvExpiryDelta(561), None) + sender.send(paymentInitiator, PaymentInitiator.SendPaymentRequest(finalAmountMsat, paymentHash, c, 1, CltvExpiryDelta(42), assistedRoutes = hints, routeParams = Some(routeParams))) + sender.expectMsgType[UUID] + router.expectMsg(RouteRequest(TestConstants.Alice.nodeParams.nodeId, c, finalAmountMsat, assistedRoutes = hints, routeParams = Some(routeParams))) + + sender.send(paymentInitiator, PaymentInitiator.SendPaymentRequest(finalAmountMsat, paymentHash, e, 3)) + sender.expectMsgType[UUID] + router.expectMsg(RouteRequest(TestConstants.Alice.nodeParams.nodeId, e, finalAmountMsat)) + } + +} diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala index b36f29d777..957dbdef11 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala @@ -26,7 +26,7 @@ import fr.acinq.bitcoin.{Block, ByteVector32, Transaction, TxOut} import fr.acinq.eclair._ import fr.acinq.eclair.blockchain.{UtxoStatus, ValidateRequest, ValidateResult, WatchSpentBasic} import fr.acinq.eclair.channel.Register.ForwardShortId -import fr.acinq.eclair.channel.{AddHtlcFailed, ChannelUnavailable} +import fr.acinq.eclair.channel.{AddHtlcFailed, Channel, ChannelUnavailable} import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.db.OutgoingPaymentStatus import fr.acinq.eclair.io.Peer.PeerRoutingMessage @@ -34,7 +34,9 @@ import fr.acinq.eclair.payment.PaymentLifecycle._ import fr.acinq.eclair.router.Announcements.{makeChannelUpdate, makeNodeAnnouncement} import fr.acinq.eclair.router._ import fr.acinq.eclair.transactions.Scripts +import fr.acinq.eclair.wire.Onion.FinalLegacyPayload import fr.acinq.eclair.wire._ +import scodec.bits.HexStringSyntax /** * Created by PM on 29/08/2016. @@ -43,6 +45,7 @@ import fr.acinq.eclair.wire._ class PaymentLifecycleSpec extends BaseRouterSpec { val defaultAmountMsat = 142000000 msat + val defaultExpiry = Channel.MIN_CLTV_EXPIRY_DELTA.toCltvExpiry test("send to route") { fixture => import fixture._ @@ -60,7 +63,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val CurrentState(_, WAITING_FOR_REQUEST) = monitor.expectMsgClass(classOf[CurrentState[_]]) // pre-computed route going from A to D - val request = SendPaymentToRoute(defaultAmountMsat, defaultPaymentHash, Seq(a, b, c, d)) + val request = SendPaymentToRoute(defaultPaymentHash, Seq(a, b, c, d), FinalLegacyPayload(defaultAmountMsat, defaultExpiry)) sender.send(paymentFSM, request) val Transition(_, WAITING_FOR_REQUEST, WAITING_FOR_ROUTE) = monitor.expectMsgClass(classOf[Transition[_]]) @@ -86,7 +89,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { paymentFSM ! SubscribeTransitionCallBack(monitor.ref) val CurrentState(_, WAITING_FOR_REQUEST) = monitor.expectMsgClass(classOf[CurrentState[_]]) - val request = SendPayment(defaultAmountMsat, defaultPaymentHash, f, maxAttempts = 5) + val request = SendPayment(defaultPaymentHash, f, FinalLegacyPayload(defaultAmountMsat, defaultExpiry), maxAttempts = 5) sender.send(paymentFSM, request) val Transition(_, WAITING_FOR_REQUEST, WAITING_FOR_ROUTE) = monitor.expectMsgClass(classOf[Transition[_]]) val routeRequest = routerForwarder.expectMsgType[RouteRequest] @@ -109,7 +112,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { paymentFSM ! SubscribeTransitionCallBack(monitor.ref) val CurrentState(_, WAITING_FOR_REQUEST) = monitor.expectMsgClass(classOf[CurrentState[_]]) - val request = SendPayment(defaultAmountMsat, randomBytes32, d, routeParams = Some(RouteParams(randomize = false, maxFeeBase = 100 msat, maxFeePct = 0.0, routeMaxLength = 20, routeMaxCltv = CltvExpiryDelta(2016), ratios = None)), maxAttempts = 5) + val request = SendPayment(randomBytes32, d, FinalLegacyPayload(defaultAmountMsat, defaultExpiry), maxAttempts = 5, routeParams = Some(RouteParams(randomize = false, maxFeeBase = 100 msat, maxFeePct = 0.0, routeMaxLength = 20, routeMaxCltv = CltvExpiryDelta(2016), ratios = None))) sender.send(paymentFSM, request) val Transition(_, WAITING_FOR_REQUEST, WAITING_FOR_ROUTE) = monitor.expectMsgClass(classOf[Transition[_]]) @@ -132,7 +135,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { paymentFSM ! SubscribeTransitionCallBack(monitor.ref) val CurrentState(_, WAITING_FOR_REQUEST) = monitor.expectMsgClass(classOf[CurrentState[_]]) - val request = SendPayment(defaultAmountMsat, defaultPaymentHash, d, maxAttempts = 2) + val request = SendPayment(defaultPaymentHash, d, FinalLegacyPayload(defaultAmountMsat, defaultExpiry), maxAttempts = 2) sender.send(paymentFSM, request) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.PENDING)) @@ -175,7 +178,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { paymentFSM ! SubscribeTransitionCallBack(monitor.ref) val CurrentState(_, WAITING_FOR_REQUEST) = monitor.expectMsgClass(classOf[CurrentState[_]]) - val request = SendPayment(defaultAmountMsat, randomBytes32, d, maxAttempts = 2) + val request = SendPayment(randomBytes32, d, FinalLegacyPayload(defaultAmountMsat, defaultExpiry), maxAttempts = 2) sender.send(paymentFSM, request) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.PENDING)) @@ -208,7 +211,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { paymentFSM ! SubscribeTransitionCallBack(monitor.ref) val CurrentState(_, WAITING_FOR_REQUEST) = monitor.expectMsgClass(classOf[CurrentState[_]]) - val request = SendPayment(defaultAmountMsat, defaultPaymentHash, d, maxAttempts = 2) + val request = SendPayment(defaultPaymentHash, d, FinalLegacyPayload(defaultAmountMsat, defaultExpiry), maxAttempts = 2) sender.send(paymentFSM, request) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.PENDING)) @@ -239,7 +242,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { paymentFSM ! SubscribeTransitionCallBack(monitor.ref) val CurrentState(_, WAITING_FOR_REQUEST) = monitor.expectMsgClass(classOf[CurrentState[_]]) - val request = SendPayment(defaultAmountMsat, randomBytes32, d, maxAttempts = 2) + val request = SendPayment(randomBytes32, d, FinalLegacyPayload(defaultAmountMsat, defaultExpiry), maxAttempts = 2) sender.send(paymentFSM, request) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE) val WaitingForRoute(_, _, Nil) = paymentFSM.stateData @@ -279,7 +282,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { paymentFSM ! SubscribeTransitionCallBack(monitor.ref) val CurrentState(_, WAITING_FOR_REQUEST) = monitor.expectMsgClass(classOf[CurrentState[_]]) - val request = SendPayment(defaultAmountMsat, randomBytes32, d, maxAttempts = 5) + val request = SendPayment(randomBytes32, d, FinalLegacyPayload(defaultAmountMsat, defaultExpiry), maxAttempts = 5) sender.send(paymentFSM, request) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.PENDING)) @@ -341,7 +344,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { paymentFSM ! SubscribeTransitionCallBack(monitor.ref) val CurrentState(_, WAITING_FOR_REQUEST) = monitor.expectMsgClass(classOf[CurrentState[_]]) - val request = SendPayment(defaultAmountMsat, randomBytes32, d, maxAttempts = 2) + val request = SendPayment(randomBytes32, d, FinalLegacyPayload(defaultAmountMsat, defaultExpiry), maxAttempts = 2) sender.send(paymentFSM, request) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.PENDING)) @@ -389,7 +392,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { paymentFSM ! SubscribeTransitionCallBack(monitor.ref) val CurrentState(_, WAITING_FOR_REQUEST) = monitor.expectMsgClass(classOf[CurrentState[_]]) - val request = SendPayment(defaultAmountMsat, defaultPaymentHash, d, maxAttempts = 5) + val request = SendPayment(defaultPaymentHash, d, FinalLegacyPayload(defaultAmountMsat, defaultExpiry), maxAttempts = 5) sender.send(paymentFSM, request) val Transition(_, WAITING_FOR_REQUEST, WAITING_FOR_ROUTE) = monitor.expectMsgClass(classOf[Transition[_]]) val Transition(_, WAITING_FOR_ROUTE, WAITING_FOR_PAYMENT_COMPLETE) = monitor.expectMsgClass(classOf[Transition[_]]) @@ -397,9 +400,9 @@ class PaymentLifecycleSpec extends BaseRouterSpec { sender.send(paymentFSM, UpdateFulfillHtlc(ByteVector32.Zeroes, 0, defaultPaymentHash)) val paymentOK = sender.expectMsgType[PaymentSucceeded] - val PaymentSent(_, request.amount, fee, request.paymentHash, paymentOK.paymentPreimage, _, _) = eventListener.expectMsgType[PaymentSent] + val PaymentSent(_, request.finalPayload.amount, fee, request.paymentHash, paymentOK.paymentPreimage, _, _) = eventListener.expectMsgType[PaymentSent] assert(fee > 0.msat) - assert(fee === paymentOK.amount - request.amount) + assert(fee === paymentOK.amount - request.finalPayload.amount) awaitCond(paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.SUCCEEDED)) } @@ -414,7 +417,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val (priv_g, priv_funding_g) = (randomKey, randomKey) val (g, funding_g) = (priv_g.publicKey, priv_funding_g.publicKey) - val ann_g = makeNodeAnnouncement(priv_g, "node-G", Color(-30, 10, -50), Nil) + val ann_g = makeNodeAnnouncement(priv_g, "node-G", Color(-30, 10, -50), Nil, hex"0200") val channelId_bg = ShortChannelId(420000, 5, 0) val chan_bg = channelAnnouncement(channelId_bg, priv_b, priv_g, priv_funding_b, priv_funding_g) val channelUpdate_bg = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_b, g, channelId_bg, CltvExpiryDelta(9), htlcMinimumMsat = 0 msat, feeBaseMsat = 0 msat, feeProportionalMillionths = 0, htlcMaximumMsat = 500000000 msat) @@ -439,7 +442,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val CurrentState(_, WAITING_FOR_REQUEST) = monitor.expectMsgClass(classOf[CurrentState[_]]) // we send a payment to G which is just after the - val request = SendPayment(defaultAmountMsat, defaultPaymentHash, g, maxAttempts = 5) + val request = SendPayment(defaultPaymentHash, g, FinalLegacyPayload(defaultAmountMsat, defaultExpiry), maxAttempts = 5) sender.send(paymentFSM, request) // the route will be A -> B -> G where B -> G has a channel_update with fees=0 @@ -449,13 +452,13 @@ class PaymentLifecycleSpec extends BaseRouterSpec { sender.send(paymentFSM, UpdateFulfillHtlc(ByteVector32.Zeroes, 0, defaultPaymentHash)) val paymentOK = sender.expectMsgType[PaymentSucceeded] - val PaymentSent(_, request.amount, fee, request.paymentHash, paymentOK.paymentPreimage, _, _) = eventListener.expectMsgType[PaymentSent] + val PaymentSent(_, request.finalPayload.amount, fee, request.paymentHash, paymentOK.paymentPreimage, _, _) = eventListener.expectMsgType[PaymentSent] // during the route computation the fees were treated as if they were 1msat but when sending the onion we actually put zero // NB: A -> B doesn't pay fees because it's our direct neighbor // NB: B -> G doesn't asks for fees at all assert(fee === 0.msat) - assert(fee === paymentOK.amount - request.amount) + assert(fee === paymentOK.amount - request.finalPayload.amount) } test("filter errors properly") { _ => diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/RelayerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/RelayerSpec.scala index 534a6101b2..299ad3058f 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/RelayerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/RelayerSpec.scala @@ -21,13 +21,16 @@ import java.util.UUID import akka.actor.{ActorRef, Status} import akka.testkit.TestProbe import fr.acinq.bitcoin.ByteVector32 +import fr.acinq.bitcoin.Crypto.PublicKey import fr.acinq.eclair.channel._ import fr.acinq.eclair.crypto.Sphinx -import fr.acinq.eclair.payment.PaymentLifecycle.buildCommand +import fr.acinq.eclair.payment.PaymentLifecycle.{buildCommand, buildOnion} import fr.acinq.eclair.router.Announcements +import fr.acinq.eclair.wire.Onion.{FinalLegacyPayload, FinalTlvPayload, PerHopPayload, RelayTlvPayload} import fr.acinq.eclair.wire._ -import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, LongToBtcAmount, ShortChannelId, TestConstants, TestkitBaseClass, UInt64, randomBytes32} +import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, LongToBtcAmount, ShortChannelId, TestConstants, TestkitBaseClass, UInt64, nodeFee, randomBytes32, randomKey} import org.scalatest.Outcome +import scodec.Attempt import scodec.bits.ByteVector import scala.concurrent.duration._ @@ -41,6 +44,7 @@ class RelayerSpec extends TestkitBaseClass { // let's reuse the existing test data import HtlcGenerationSpec._ + import RelayerSpec._ case class FixtureParam(relayer: ActorRef, register: TestProbe, paymentHandler: TestProbe) @@ -62,7 +66,7 @@ class RelayerSpec extends TestkitBaseClass { val sender = TestProbe() // we use this to build a valid onion - val (cmd, _) = buildCommand(UUID.randomUUID(), finalAmountMsat, finalExpiry, paymentHash, hops) + val (cmd, _) = buildCommand(UUID.randomUUID(), paymentHash, hops, FinalLegacyPayload(finalAmountMsat, finalExpiry)) // and then manually build an htlc val add_ab = UpdateAddHtlc(channelId = channelId_ab, id = 123456, cmd.amount, cmd.paymentHash, cmd.cltvExpiry, cmd.onion) relayer ! LocalChannelUpdate(null, channelId_bc, channelUpdate_bc.shortChannelId, c, None, channelUpdate_bc, makeCommitments(channelId_bc)) @@ -71,6 +75,37 @@ class RelayerSpec extends TestkitBaseClass { val fwd = register.expectMsgType[Register.ForwardShortId[CMD_ADD_HTLC]] assert(fwd.shortChannelId === channelUpdate_bc.shortChannelId) + assert(fwd.message.amount === amount_bc) + assert(fwd.message.cltvExpiry === expiry_bc) + assert(fwd.message.upstream === Right(add_ab)) + + sender.expectNoMsg(100 millis) + paymentHandler.expectNoMsg(100 millis) + } + + test("relay an htlc-add with onion tlv payload") { f => + import f._ + import fr.acinq.eclair.wire.OnionTlv._ + val sender = TestProbe() + + // Use tlv payloads for all hops (final and intermediate) + val finalPayload: Seq[PerHopPayload] = FinalTlvPayload(TlvStream[OnionTlv](AmountToForward(finalAmountMsat), OutgoingCltv(finalExpiry))) :: Nil + val (firstAmountMsat, firstExpiry, payloads) = hops.drop(1).reverse.foldLeft((finalAmountMsat, finalExpiry, finalPayload)) { + case ((amountMsat, expiry, currentPayloads), hop) => + val nextFee = nodeFee(hop.lastUpdate.feeBaseMsat, hop.lastUpdate.feeProportionalMillionths, amountMsat) + val payload = RelayTlvPayload(TlvStream[OnionTlv](AmountToForward(amountMsat), OutgoingCltv(expiry), OutgoingChannelId(hop.lastUpdate.shortChannelId))) + (amountMsat + nextFee, expiry + hop.lastUpdate.cltvExpiryDelta, payload +: currentPayloads) + } + val Sphinx.PacketAndSecrets(onion, _) = buildOnion(hops.map(_.nextNodeId), payloads, paymentHash) + val add_ab = UpdateAddHtlc(channelId_ab, 123456, firstAmountMsat, paymentHash, firstExpiry, onion) + relayer ! LocalChannelUpdate(null, channelId_bc, channelUpdate_bc.shortChannelId, c, None, channelUpdate_bc, makeCommitments(channelId_bc)) + + sender.send(relayer, ForwardAdd(add_ab)) + + val fwd = register.expectMsgType[Register.ForwardShortId[CMD_ADD_HTLC]] + assert(fwd.shortChannelId === channelUpdate_bc.shortChannelId) + assert(fwd.message.amount === amount_bc) + assert(fwd.message.cltvExpiry === expiry_bc) assert(fwd.message.upstream === Right(add_ab)) sender.expectNoMsg(100 millis) @@ -82,7 +117,7 @@ class RelayerSpec extends TestkitBaseClass { val sender = TestProbe() // we use this to build a valid onion - val (cmd, _) = buildCommand(UUID.randomUUID(), finalAmountMsat, finalExpiry, paymentHash, hops) + val (cmd, _) = buildCommand(UUID.randomUUID(), paymentHash, hops, FinalLegacyPayload(finalAmountMsat, finalExpiry)) // and then manually build an htlc val add_ab = UpdateAddHtlc(channelId = channelId_ab, id = 123456, cmd.amount, cmd.paymentHash, cmd.cltvExpiry, cmd.onion) @@ -122,12 +157,27 @@ class RelayerSpec extends TestkitBaseClass { paymentHandler.expectNoMsg(100 millis) } + test("relay an htlc-add at the final node to the payment handler") { f => + import f._ + val sender = TestProbe() + + val (cmd, _) = buildCommand(UUID.randomUUID(), paymentHash, hops.take(1), FinalLegacyPayload(finalAmountMsat, finalExpiry)) + val add_ab = UpdateAddHtlc(channelId = channelId_ab, id = 123456, cmd.amount, cmd.paymentHash, cmd.cltvExpiry, cmd.onion) + sender.send(relayer, ForwardAdd(add_ab)) + + val htlc = paymentHandler.expectMsgType[UpdateAddHtlc] + assert(htlc === add_ab) + + sender.expectNoMsg(100 millis) + register.expectNoMsg(100 millis) + } + test("fail to relay an htlc-add when we have no channel_update for the next channel") { f => import f._ val sender = TestProbe() // we use this to build a valid onion - val (cmd, _) = buildCommand(UUID.randomUUID(), finalAmountMsat, finalExpiry, paymentHash, hops) + val (cmd, _) = buildCommand(UUID.randomUUID(), paymentHash, hops, FinalLegacyPayload(finalAmountMsat, finalExpiry)) // and then manually build an htlc val add_ab = UpdateAddHtlc(channelId = channelId_ab, id = 123456, cmd.amount, cmd.paymentHash, cmd.cltvExpiry, cmd.onion) @@ -147,7 +197,7 @@ class RelayerSpec extends TestkitBaseClass { val sender = TestProbe() // we use this to build a valid onion - val (cmd, _) = buildCommand(UUID.randomUUID(), finalAmountMsat, finalExpiry, paymentHash, hops) + val (cmd, _) = buildCommand(UUID.randomUUID(), paymentHash, hops, FinalLegacyPayload(finalAmountMsat, finalExpiry)) // and then manually build an htlc val add_ab = UpdateAddHtlc(channelId = channelId_ab, id = 123456, cmd.amount, cmd.paymentHash, cmd.cltvExpiry, cmd.onion) relayer ! LocalChannelUpdate(null, channelId_bc, channelUpdate_bc.shortChannelId, c, None, channelUpdate_bc, makeCommitments(channelId_bc)) @@ -174,7 +224,7 @@ class RelayerSpec extends TestkitBaseClass { val sender = TestProbe() // check that payments are sent properly - val (cmd, _) = buildCommand(UUID.randomUUID(), finalAmountMsat, finalExpiry, paymentHash, hops) + val (cmd, _) = buildCommand(UUID.randomUUID(), paymentHash, hops, FinalLegacyPayload(finalAmountMsat, finalExpiry)) val add_ab = UpdateAddHtlc(channelId = channelId_ab, id = 123456, cmd.amount, cmd.paymentHash, cmd.cltvExpiry, cmd.onion) relayer ! LocalChannelUpdate(null, channelId_bc, channelUpdate_bc.shortChannelId, c, None, channelUpdate_bc, makeCommitments(channelId_bc)) @@ -190,7 +240,7 @@ class RelayerSpec extends TestkitBaseClass { // now tell the relayer that the channel is down and try again relayer ! LocalChannelDown(sender.ref, channelId = channelId_bc, shortChannelId = channelUpdate_bc.shortChannelId, remoteNodeId = TestConstants.Bob.nodeParams.nodeId) - val (cmd1, _) = buildCommand(UUID.randomUUID(), finalAmountMsat, finalExpiry, randomBytes32, hops) + val (cmd1, _) = buildCommand(UUID.randomUUID(), randomBytes32, hops, FinalLegacyPayload(finalAmountMsat, finalExpiry)) val add_ab1 = UpdateAddHtlc(channelId = channelId_ab, id = 123456, cmd1.amount, cmd1.paymentHash, cmd1.cltvExpiry, cmd1.onion) sender.send(relayer, ForwardAdd(add_ab1)) @@ -207,7 +257,7 @@ class RelayerSpec extends TestkitBaseClass { val sender = TestProbe() // we use this to build a valid onion - val (cmd, _) = buildCommand(UUID.randomUUID(), finalAmountMsat, finalExpiry, paymentHash, hops) + val (cmd, _) = buildCommand(UUID.randomUUID(), paymentHash, hops, FinalLegacyPayload(finalAmountMsat, finalExpiry)) // and then manually build an htlc val add_ab = UpdateAddHtlc(channelId = channelId_ab, id = 123456, cmd.amount, cmd.paymentHash, cmd.cltvExpiry, cmd.onion) val channelUpdate_bc_disabled = channelUpdate_bc.copy(channelFlags = Announcements.makeChannelFlags(Announcements.isNode1(channelUpdate_bc.channelFlags), enable = false)) @@ -228,7 +278,7 @@ class RelayerSpec extends TestkitBaseClass { val sender = TestProbe() // we use this to build a valid onion - val (cmd, _) = buildCommand(UUID.randomUUID(), finalAmountMsat, finalExpiry, paymentHash, hops) + val (cmd, _) = buildCommand(UUID.randomUUID(), paymentHash, hops, FinalLegacyPayload(finalAmountMsat, finalExpiry)) // and then manually build an htlc with an invalid onion (hmac) val add_ab = UpdateAddHtlc(channelId = channelId_ab, id = 123456, cmd.amount, cmd.paymentHash, cmd.cltvExpiry, cmd.onion.copy(hmac = cmd.onion.hmac.reverse)) relayer ! LocalChannelUpdate(null, channelId_bc, channelUpdate_bc.shortChannelId, c, None, channelUpdate_bc, makeCommitments(channelId_bc)) @@ -244,12 +294,59 @@ class RelayerSpec extends TestkitBaseClass { paymentHandler.expectNoMsg(100 millis) } + test("fail to relay an htlc-add when the onion payload is missing data") { f => + import f._ + import fr.acinq.eclair.wire.OnionTlv._ + + // B is not the last hop and receives an onion missing some routing information. + val invalidPayloads_bc = Seq( + (InvalidOnionPayload(UInt64(2), 0), TlvStream[OnionTlv](OutgoingChannelId(channelUpdate_bc.shortChannelId), OutgoingCltv(expiry_bc))), // Missing forwarding amount. + (InvalidOnionPayload(UInt64(4), 0), TlvStream[OnionTlv](OutgoingChannelId(channelUpdate_bc.shortChannelId), AmountToForward(amount_bc))), // Missing cltv expiry. + (InvalidOnionPayload(UInt64(6), 0), TlvStream[OnionTlv](AmountToForward(amount_bc), OutgoingCltv(expiry_bc)))) // Missing channel id. + val payload_cd = TlvStream[OnionTlv](OutgoingChannelId(channelUpdate_cd.shortChannelId), AmountToForward(amount_cd), OutgoingCltv(expiry_cd)) + + val sender = TestProbe() + relayer ! LocalChannelUpdate(null, channelId_bc, channelUpdate_bc.shortChannelId, c, None, channelUpdate_bc, makeCommitments(channelId_bc)) + + for ((expectedErr, invalidPayload_bc) <- invalidPayloads_bc) { + val onion = buildTlvOnion(Seq(b, c), Seq(invalidPayload_bc, payload_cd), paymentHash) + val add_ab = UpdateAddHtlc(channelId_ab, 123456, amount_ab, paymentHash, expiry_ab, onion) + sender.send(relayer, ForwardAdd(add_ab)) + + register.expectMsg(Register.Forward(channelId_ab, CMD_FAIL_HTLC(add_ab.id, Right(expectedErr), commit = true))) + register.expectNoMsg(100 millis) + paymentHandler.expectNoMsg(100 millis) + } + } + + test("fail to relay an htlc-add when variable length onion is disabled") { f => + import f._ + import fr.acinq.eclair.wire.OnionTlv._ + + val relayer = system.actorOf(Relayer.props(TestConstants.Bob.nodeParams.copy(globalFeatures = ByteVector.empty), register.ref, paymentHandler.ref)) + val sender = TestProbe() + relayer ! LocalChannelUpdate(null, channelId_bc, channelUpdate_bc.shortChannelId, c, None, channelUpdate_bc, makeCommitments(channelId_bc)) + + val payload_bc = TlvStream[OnionTlv](OutgoingChannelId(channelUpdate_bc.shortChannelId), AmountToForward(amount_bc), OutgoingCltv(expiry_bc)) + val payload_cd = TlvStream[OnionTlv](OutgoingChannelId(channelUpdate_cd.shortChannelId), AmountToForward(amount_cd), OutgoingCltv(expiry_cd)) + + val onion = buildTlvOnion(Seq(b, c), Seq(payload_bc, payload_cd), paymentHash) + val add_ab = UpdateAddHtlc(channelId_ab, 123456, amount_ab, paymentHash, expiry_ab, onion) + sender.send(relayer, ForwardAdd(add_ab)) + + register.expectMsg(Register.Forward(channelId_ab, CMD_FAIL_HTLC(add_ab.id, Right(InvalidRealm), commit = true))) + register.expectNoMsg(100 millis) + paymentHandler.expectNoMsg(100 millis) + } + test("fail to relay an htlc-add when amount is below the next hop's requirements") { f => import f._ val sender = TestProbe() // we use this to build a valid onion - val (cmd, _) = buildCommand(UUID.randomUUID(), channelUpdate_bc.htlcMinimumMsat - (1 msat), finalExpiry, paymentHash, hops.map(hop => hop.copy(lastUpdate = hop.lastUpdate.copy(feeBaseMsat = 0 msat, feeProportionalMillionths = 0)))) + val finalPayload = FinalLegacyPayload(channelUpdate_bc.htlcMinimumMsat - (1 msat), finalExpiry) + val zeroFeeHops = hops.map(hop => hop.copy(lastUpdate = hop.lastUpdate.copy(feeBaseMsat = 0 msat, feeProportionalMillionths = 0))) + val (cmd, _) = buildCommand(UUID.randomUUID(), paymentHash, zeroFeeHops, finalPayload) // and then manually build an htlc val add_ab = UpdateAddHtlc(channelId = channelId_ab, id = 123456, cmd.amount, cmd.paymentHash, cmd.cltvExpiry, cmd.onion) relayer ! LocalChannelUpdate(null, channelId_bc, channelUpdate_bc.shortChannelId, c, None, channelUpdate_bc, makeCommitments(channelId_bc)) @@ -269,7 +366,7 @@ class RelayerSpec extends TestkitBaseClass { val sender = TestProbe() val hops1 = hops.updated(1, hops(1).copy(lastUpdate = hops(1).lastUpdate.copy(cltvExpiryDelta = CltvExpiryDelta(0)))) - val (cmd, _) = buildCommand(UUID.randomUUID(), finalAmountMsat, finalExpiry, paymentHash, hops1) + val (cmd, _) = buildCommand(UUID.randomUUID(), paymentHash, hops1, FinalLegacyPayload(finalAmountMsat, finalExpiry)) // and then manually build an htlc val add_ab = UpdateAddHtlc(channelId = channelId_ab, id = 123456, cmd.amount, cmd.paymentHash, cmd.cltvExpiry, cmd.onion) relayer ! LocalChannelUpdate(null, channelId_bc, channelUpdate_bc.shortChannelId, c, None, channelUpdate_bc, makeCommitments(channelId_bc)) @@ -289,7 +386,7 @@ class RelayerSpec extends TestkitBaseClass { val sender = TestProbe() val hops1 = hops.updated(1, hops(1).copy(lastUpdate = hops(1).lastUpdate.copy(feeBaseMsat = hops(1).lastUpdate.feeBaseMsat / 2))) - val (cmd, _) = buildCommand(UUID.randomUUID(), finalAmountMsat, finalExpiry, paymentHash, hops1) + val (cmd, _) = buildCommand(UUID.randomUUID(), paymentHash, hops1, FinalLegacyPayload(finalAmountMsat, finalExpiry)) // and then manually build an htlc val add_ab = UpdateAddHtlc(channelId = channelId_ab, id = 123456, cmd.amount, cmd.paymentHash, cmd.cltvExpiry, cmd.onion) relayer ! LocalChannelUpdate(null, channelId_bc, channelUpdate_bc.shortChannelId, c, None, channelUpdate_bc, makeCommitments(channelId_bc)) @@ -310,7 +407,7 @@ class RelayerSpec extends TestkitBaseClass { // to simulate this we use a zero-hop route A->B where A is the 'attacker' val hops1 = hops.head :: Nil - val (cmd, _) = buildCommand(UUID.randomUUID(), finalAmountMsat, finalExpiry, paymentHash, hops1) + val (cmd, _) = buildCommand(UUID.randomUUID(), paymentHash, hops1, FinalLegacyPayload(finalAmountMsat, finalExpiry)) // and then manually build an htlc with a wrong expiry val add_ab = UpdateAddHtlc(channelId = channelId_ab, id = 123456, cmd.amount - (1 msat), cmd.paymentHash, cmd.cltvExpiry, cmd.onion) relayer ! LocalChannelUpdate(null, channelId_bc, channelUpdate_bc.shortChannelId, c, None, channelUpdate_bc, makeCommitments(channelId_bc)) @@ -331,7 +428,7 @@ class RelayerSpec extends TestkitBaseClass { // to simulate this we use a zero-hop route A->B where A is the 'attacker' val hops1 = hops.head :: Nil - val (cmd, _) = buildCommand(UUID.randomUUID(), finalAmountMsat, finalExpiry, paymentHash, hops1) + val (cmd, _) = buildCommand(UUID.randomUUID(), paymentHash, hops1, FinalLegacyPayload(finalAmountMsat, finalExpiry)) // and then manually build an htlc with a wrong expiry val add_ab = UpdateAddHtlc(channelId = channelId_ab, id = 123456, cmd.amount, cmd.paymentHash, cmd.cltvExpiry - CltvExpiryDelta(1), cmd.onion) relayer ! LocalChannelUpdate(null, channelId_bc, channelUpdate_bc.shortChannelId, c, None, channelUpdate_bc, makeCommitments(channelId_bc)) @@ -346,6 +443,28 @@ class RelayerSpec extends TestkitBaseClass { paymentHandler.expectNoMsg(100 millis) } + test("fail an htlc-add at the final node when the onion payload is missing data") { f => + import f._ + import fr.acinq.eclair.wire.OnionTlv._ + + // B is the last hop and receives an onion missing some payment information. + val invalidFinalPayloads = Seq( + (InvalidOnionPayload(UInt64(2), 0), TlvStream[OnionTlv](OutgoingCltv(expiry_bc))), // Missing forwarding amount. + (InvalidOnionPayload(UInt64(4), 0), TlvStream[OnionTlv](AmountToForward(amount_bc)))) // Missing cltv expiry. + + val sender = TestProbe() + + for ((expectedErr, invalidFinalPayload) <- invalidFinalPayloads) { + val onion = buildTlvOnion(Seq(b), Seq(invalidFinalPayload), paymentHash) + val add_ab = UpdateAddHtlc(channelId_ab, 123456, amount_ab, paymentHash, expiry_ab, onion) + sender.send(relayer, ForwardAdd(add_ab)) + + register.expectMsg(Register.Forward(channelId_ab, CMD_FAIL_HTLC(add_ab.id, Right(expectedErr), commit = true))) + register.expectNoMsg(100 millis) + paymentHandler.expectNoMsg(100 millis) + } + } + test("correctly translates errors returned by channel when attempting to add an htlc") { f => import f._ val sender = TestProbe() @@ -445,3 +564,20 @@ class RelayerSpec extends TestkitBaseClass { assert(usableBalances5.size === 1) } } + +object RelayerSpec { + + /** Build onion from arbitrary tlv stream (potentially invalid). */ + def buildTlvOnion(nodes: Seq[PublicKey], payloads: Seq[TlvStream[OnionTlv]], associatedData: ByteVector32): OnionRoutingPacket = { + require(nodes.size == payloads.size) + val sessionKey = randomKey + val payloadsBin: Seq[ByteVector] = payloads + .map(OnionCodecs.tlvPerHopPayloadCodec.encode) + .map { + case Attempt.Successful(bitVector) => bitVector.toByteVector + case Attempt.Failure(cause) => throw new RuntimeException(s"serialization error: $cause") + } + Sphinx.PaymentPacket.create(sessionKey, nodes, payloadsBin, associatedData).packet + } + +} \ No newline at end of file diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/router/AnnouncementsSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/router/AnnouncementsSpec.scala index 2e01c4d516..4f628ede35 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/router/AnnouncementsSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/router/AnnouncementsSpec.scala @@ -48,7 +48,7 @@ class AnnouncementsSpec extends FunSuite { } test("create valid signed node announcement") { - val ann = makeNodeAnnouncement(Alice.nodeParams.privateKey, Alice.nodeParams.alias, Alice.nodeParams.color, Alice.nodeParams.publicAddresses) + val ann = makeNodeAnnouncement(Alice.nodeParams.privateKey, Alice.nodeParams.alias, Alice.nodeParams.color, Alice.nodeParams.publicAddresses, Alice.nodeParams.globalFeatures) assert(Features.hasFeature(ann.features, Features.VARIABLE_LENGTH_ONION_OPTIONAL)) assert(checkSig(ann)) assert(checkSig(ann.copy(timestamp = 153)) === false) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/router/BaseRouterSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/router/BaseRouterSpec.scala index 628cc489a3..c358490b47 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/router/BaseRouterSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/router/BaseRouterSpec.scala @@ -30,7 +30,7 @@ import fr.acinq.eclair.transactions.Scripts import fr.acinq.eclair.wire._ import fr.acinq.eclair.{TestkitBaseClass, randomKey, _} import org.scalatest.Outcome -import scodec.bits.ByteVector +import scodec.bits.{ByteVector, HexStringSyntax} import scala.concurrent.duration._ @@ -55,12 +55,12 @@ abstract class BaseRouterSpec extends TestkitBaseClass { val (priv_funding_a, priv_funding_b, priv_funding_c, priv_funding_d, priv_funding_e, priv_funding_f) = (randomKey, randomKey, randomKey, randomKey, randomKey, randomKey) val (funding_a, funding_b, funding_c, funding_d, funding_e, funding_f) = (priv_funding_a.publicKey, priv_funding_b.publicKey, priv_funding_c.publicKey, priv_funding_d.publicKey, priv_funding_e.publicKey, priv_funding_f.publicKey) - val ann_a = makeNodeAnnouncement(priv_a, "node-A", Color(15, 10, -70), Nil) - val ann_b = makeNodeAnnouncement(priv_b, "node-B", Color(50, 99, -80), Nil) - val ann_c = makeNodeAnnouncement(priv_c, "node-C", Color(123, 100, -40), Nil) - val ann_d = makeNodeAnnouncement(priv_d, "node-D", Color(-120, -20, 60), Nil) - val ann_e = makeNodeAnnouncement(priv_e, "node-E", Color(-50, 0, 10), Nil) - val ann_f = makeNodeAnnouncement(priv_f, "node-F", Color(30, 10, -50), Nil) + val ann_a = makeNodeAnnouncement(priv_a, "node-A", Color(15, 10, -70), Nil, hex"0200") + val ann_b = makeNodeAnnouncement(priv_b, "node-B", Color(50, 99, -80), Nil, hex"") + val ann_c = makeNodeAnnouncement(priv_c, "node-C", Color(123, 100, -40), Nil, hex"0200") + val ann_d = makeNodeAnnouncement(priv_d, "node-D", Color(-120, -20, 60), Nil, hex"00") + val ann_e = makeNodeAnnouncement(priv_e, "node-E", Color(-50, 0, 10), Nil, hex"00") + val ann_f = makeNodeAnnouncement(priv_f, "node-F", Color(30, 10, -50), Nil, hex"00") val channelId_ab = ShortChannelId(420000, 1, 0) val channelId_bc = ShortChannelId(420000, 2, 0) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/router/RoutingSyncSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/router/RoutingSyncSpec.scala index fefa68b5e4..9b1f6f24b4 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/router/RoutingSyncSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/router/RoutingSyncSpec.scala @@ -30,13 +30,13 @@ import fr.acinq.eclair.router.BaseRouterSpec.channelAnnouncement import fr.acinq.eclair.transactions.Scripts import fr.acinq.eclair.wire._ import org.scalatest.FunSuiteLike +import scodec.bits.HexStringSyntax import scala.collection.immutable.TreeMap import scala.collection.{SortedSet, immutable, mutable} import scala.compat.Platform import scala.concurrent.duration._ - class RoutingSyncSpec extends TestKit(ActorSystem("test")) with FunSuiteLike { import RoutingSyncSpec._ @@ -120,7 +120,7 @@ class RoutingSyncSpec extends TestKit(ActorSystem("test")) with FunSuiteLike { } def countUpdates(channels: Map[ShortChannelId, PublicChannel]) = channels.values.foldLeft(0) { - case (count, pc) => count + pc.update_1_opt.map(_ => 1).getOrElse(0) + pc.update_2_opt.map(_ => 1).getOrElse(0) + case (count, pc) => count + pc.update_1_opt.map(_ => 1).getOrElse(0) + pc.update_2_opt.map(_ => 1).getOrElse(0) } test("sync with standard channel queries") { @@ -151,7 +151,7 @@ class RoutingSyncSpec extends TestKit(ActorSystem("test")) with FunSuiteLike { // add some updates to bob and resync fakeRoutingInfo.take(10).values.foreach { - case (pc, na1, na2) => + case (pc, _, _) => sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, pc.update_2_opt.get)) } awaitCond(bob.stateData.channels.size === 10 && countUpdates(bob.stateData.channels) === 10 * 2) @@ -172,7 +172,7 @@ class RoutingSyncSpec extends TestKit(ActorSystem("test")) with FunSuiteLike { awaitCond(alice.stateData.channels === bob.stateData.channels, max = 60 seconds) } - def syncWithExtendedQueries(requestNodeAnnouncements: Boolean) = { + def syncWithExtendedQueries(requestNodeAnnouncements: Boolean): Unit = { val watcher = system.actorOf(Props(new YesWatcher())) val alice = TestFSMRef(new Router(Alice.nodeParams.copy(routerConf = Alice.nodeParams.routerConf.copy(requestNodeAnnouncements = requestNodeAnnouncements)), watcher)) val bob = TestFSMRef(new Router(Bob.nodeParams, watcher)) @@ -200,7 +200,7 @@ class RoutingSyncSpec extends TestKit(ActorSystem("test")) with FunSuiteLike { // add some updates to bob and resync fakeRoutingInfo.take(10).values.foreach { - case (pc, na1, na2) => + case (pc, _, _) => sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, pc.update_2_opt.get)) } awaitCond(bob.stateData.channels.size === 10 && countUpdates(bob.stateData.channels) === 10 * 2) @@ -319,8 +319,8 @@ object RoutingSyncSpec { val channelAnn_12 = channelAnnouncement(shortChannelId, priv1, priv2, priv_funding1, priv_funding2) val channelUpdate_12 = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv1, priv2.publicKey, shortChannelId, cltvExpiryDelta = CltvExpiryDelta(7), 0 msat, feeBaseMsat = 766000 msat, feeProportionalMillionths = 10, 500000000L msat, timestamp = timestamp) val channelUpdate_21 = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv2, priv1.publicKey, shortChannelId, cltvExpiryDelta = CltvExpiryDelta(7), 0 msat, feeBaseMsat = 766000 msat, feeProportionalMillionths = 10, 500000000L msat, timestamp = timestamp) - val nodeAnnouncement_1 = makeNodeAnnouncement(priv1, "", Color(0, 0, 0), List()) - val nodeAnnouncement_2 = makeNodeAnnouncement(priv2, "", Color(0, 0, 0), List()) + val nodeAnnouncement_1 = makeNodeAnnouncement(priv1, "a", Color(0, 0, 0), List(), hex"0200") + val nodeAnnouncement_2 = makeNodeAnnouncement(priv2, "b", Color(0, 0, 0), List(), hex"00") val publicChannel = PublicChannel(channelAnn_12, ByteVector32.Zeroes, Satoshi(0), Some(channelUpdate_12), Some(channelUpdate_21)) (publicChannel, nodeAnnouncement_1, nodeAnnouncement_2) } @@ -336,7 +336,7 @@ object RoutingSyncSpec { def makeFakeNodeAnnouncement(nodeId: PublicKey): NodeAnnouncement = { val priv = pub2priv(nodeId) - makeNodeAnnouncement(priv, "", Color(0, 0, 0), List()) + makeNodeAnnouncement(priv, "", Color(0, 0, 0), List(), hex"00") } } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/FailureMessageCodecsSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/FailureMessageCodecsSpec.scala index 30ea734322..3925b55044 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/wire/FailureMessageCodecsSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/wire/FailureMessageCodecsSpec.scala @@ -19,7 +19,7 @@ package fr.acinq.eclair.wire import fr.acinq.bitcoin.{Block, ByteVector32, ByteVector64} import fr.acinq.eclair.crypto.Hmac256 import fr.acinq.eclair.wire.FailureMessageCodecs._ -import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, LongToBtcAmount, MilliSatoshi, ShortChannelId, randomBytes32, randomBytes64} +import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, LongToBtcAmount, MilliSatoshi, ShortChannelId, UInt64, randomBytes32, randomBytes64} import org.scalatest.FunSuite import scodec.bits._ @@ -44,10 +44,10 @@ class FailureMessageCodecsSpec extends FunSuite { test("encode/decode all failure messages") { val msgs: List[FailureMessage] = InvalidRealm :: TemporaryNodeFailure :: PermanentNodeFailure :: RequiredNodeFeatureMissing :: - InvalidOnionVersion(randomBytes32) :: InvalidOnionHmac(randomBytes32) :: InvalidOnionKey(randomBytes32) :: InvalidOnionPayload(randomBytes32) :: + InvalidOnionVersion(randomBytes32) :: InvalidOnionHmac(randomBytes32) :: InvalidOnionKey(randomBytes32) :: TemporaryChannelFailure(channelUpdate) :: PermanentChannelFailure :: RequiredChannelFeatureMissing :: UnknownNextPeer :: AmountBelowMinimum(123456 msat, channelUpdate) :: FeeInsufficient(546463 msat, channelUpdate) :: IncorrectCltvExpiry(CltvExpiry(1211), channelUpdate) :: ExpiryTooSoon(channelUpdate) :: - IncorrectOrUnknownPaymentDetails(123456 msat, 1105) :: FinalIncorrectCltvExpiry(CltvExpiry(1234)) :: ChannelDisabled(0, 1, channelUpdate) :: ExpiryTooFar :: Nil + IncorrectOrUnknownPaymentDetails(123456 msat, 1105) :: FinalIncorrectCltvExpiry(CltvExpiry(1234)) :: ChannelDisabled(0, 1, channelUpdate) :: ExpiryTooFar :: InvalidOnionPayload(UInt64(561), 1105) :: Nil msgs.foreach { msg => { @@ -83,8 +83,7 @@ class FailureMessageCodecsSpec extends FunSuite { val msgs = Map( (BADONION | PERM | 4) -> InvalidOnionVersion(randomBytes32), (BADONION | PERM | 5) -> InvalidOnionHmac(randomBytes32), - (BADONION | PERM | 6) -> InvalidOnionKey(randomBytes32), - (BADONION | PERM) -> InvalidOnionPayload(randomBytes32) + (BADONION | PERM | 6) -> InvalidOnionKey(randomBytes32) ) for ((code, message) <- msgs) { diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/OnionCodecsSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/OnionCodecsSpec.scala index 272b85fc26..c2077d0ebb 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/wire/OnionCodecsSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/wire/OnionCodecsSpec.scala @@ -17,9 +17,13 @@ package fr.acinq.eclair.wire import fr.acinq.bitcoin.ByteVector32 +import fr.acinq.eclair.UInt64.Conversions._ +import fr.acinq.eclair.wire.Onion.{FinalLegacyPayload, FinalTlvPayload, RelayLegacyPayload, RelayTlvPayload} import fr.acinq.eclair.wire.OnionCodecs._ -import fr.acinq.eclair.{CltvExpiry, LongToBtcAmount, ShortChannelId} +import fr.acinq.eclair.wire.OnionTlv._ +import fr.acinq.eclair.{CltvExpiry, LongToBtcAmount, ShortChannelId, UInt64} import org.scalatest.FunSuite +import scodec.Attempt import scodec.bits.HexStringSyntax /** @@ -39,18 +43,35 @@ class OnionCodecsSpec extends FunSuite { assert(encoded.toByteVector === bin) } - test("encode/decode per-hop payload") { - val payload = PerHopPayload(shortChannelId = ShortChannelId(42), amtToForward = 142000 msat, outgoingCltvValue = CltvExpiry(500000)) - val bin = perHopPayloadCodec.encode(payload).require - assert(bin.toByteVector.size === 33) - val payload1 = perHopPayloadCodec.decode(bin).require.value - assert(payload === payload1) - - // realm (the first byte) should be 0 - val bin1 = bin.toByteVector.update(0, 1) - intercept[IllegalArgumentException] { - val payload2 = perHopPayloadCodec.decode(bin1.bits).require.value - assert(payload2 === payload1) + test("encode/decode fixed-size (legacy) relay per-hop payload") { + val testCases = Map( + RelayLegacyPayload(ShortChannelId(0), 0 msat, CltvExpiry(0)) -> hex"00 0000000000000000 0000000000000000 00000000 000000000000000000000000", + RelayLegacyPayload(ShortChannelId(42), 142000 msat, CltvExpiry(500000)) -> hex"00 000000000000002a 0000000000022ab0 0007a120 000000000000000000000000", + RelayLegacyPayload(ShortChannelId(561), 1105 msat, CltvExpiry(1729)) -> hex"00 0000000000000231 0000000000000451 000006c1 000000000000000000000000" + ) + + for ((expected, bin) <- testCases) { + val decoded = relayPerHopPayloadCodec.decode(bin.bits).require.value + assert(decoded === expected) + + val encoded = relayPerHopPayloadCodec.encode(expected).require.bytes + assert(encoded === bin) + } + } + + test("encode/decode fixed-size (legacy) final per-hop payload") { + val testCases = Map( + FinalLegacyPayload(0 msat, CltvExpiry(0)) -> hex"00 0000000000000000 0000000000000000 00000000 000000000000000000000000", + FinalLegacyPayload(142000 msat, CltvExpiry(500000)) -> hex"00 0000000000000000 0000000000022ab0 0007a120 000000000000000000000000", + FinalLegacyPayload(1105 msat, CltvExpiry(1729)) -> hex"00 0000000000000000 0000000000000451 000006c1 000000000000000000000000" + ) + + for ((expected, bin) <- testCases) { + val decoded = finalPerHopPayloadCodec.decode(bin.bits).require.value + assert(decoded === expected) + + val encoded = finalPerHopPayloadCodec.encode(expected).require.bytes + assert(encoded === bin) } } @@ -71,4 +92,88 @@ class OnionCodecsSpec extends FunSuite { } } + test("encode/decode variable-length (tlv) relay per-hop payload") { + val testCases = Map( + TlvStream[OnionTlv](AmountToForward(561 msat), OutgoingCltv(CltvExpiry(42)), OutgoingChannelId(ShortChannelId(1105))) -> hex"11 02020231 04012a 06080000000000000451", + TlvStream[OnionTlv](Seq(AmountToForward(561 msat), OutgoingCltv(CltvExpiry(42)), OutgoingChannelId(ShortChannelId(1105))), Seq(GenericTlv(65535, hex"06c1"))) -> hex"17 02020231 04012a 06080000000000000451 fdffff0206c1" + ) + + for ((expected, bin) <- testCases) { + val decoded = relayPerHopPayloadCodec.decode(bin.bits).require.value + assert(decoded === RelayTlvPayload(expected)) + assert(decoded.amountToForward === 561.msat) + assert(decoded.outgoingCltv === CltvExpiry(42)) + assert(decoded.outgoingChannelId === ShortChannelId(1105)) + + val encoded = relayPerHopPayloadCodec.encode(RelayTlvPayload(expected)).require.bytes + assert(encoded === bin) + } + } + + test("encode/decode variable-length (tlv) final per-hop payload") { + val testCases = Map( + TlvStream[OnionTlv](AmountToForward(561 msat), OutgoingCltv(CltvExpiry(42))) -> hex"07 02020231 04012a", + TlvStream[OnionTlv](AmountToForward(561 msat), OutgoingCltv(CltvExpiry(42)), OutgoingChannelId(ShortChannelId(1105))) -> hex"11 02020231 04012a 06080000000000000451", + TlvStream[OnionTlv](Seq(AmountToForward(561 msat), OutgoingCltv(CltvExpiry(42))), Seq(GenericTlv(65535, hex"06c1"))) -> hex"0d 02020231 04012a fdffff0206c1" + ) + + for ((expected, bin) <- testCases) { + val decoded = finalPerHopPayloadCodec.decode(bin.bits).require.value + assert(decoded === FinalTlvPayload(expected)) + assert(decoded.amount === 561.msat) + assert(decoded.expiry === CltvExpiry(42)) + + val encoded = finalPerHopPayloadCodec.encode(FinalTlvPayload(expected)).require.bytes + assert(encoded === bin) + } + } + + test("decode variable-length (tlv) relay per-hop payload missing information") { + val testCases = Seq( + (InvalidOnionPayload(UInt64(2), 0), hex"0d 04012a 06080000000000000451"), // missing amount + (InvalidOnionPayload(UInt64(4), 0), hex"0e 02020231 06080000000000000451"), // missing cltv + (InvalidOnionPayload(UInt64(6), 0), hex"07 02020231 04012a") // missing channel id + ) + + for ((expectedErr, bin) <- testCases) { + val decoded = relayPerHopPayloadCodec.decode(bin.bits) + assert(decoded.isFailure) + val Attempt.Failure(err: MissingRequiredTlv) = decoded + assert(err.failureMessage === expectedErr) + } + } + + test("decode variable-length (tlv) final per-hop payload missing information") { + val testCases = Seq( + (InvalidOnionPayload(UInt64(2), 0), hex"03 04012a"), // missing amount + (InvalidOnionPayload(UInt64(4), 0), hex"04 02020231") // missing cltv + ) + + for ((expectedErr, bin) <- testCases) { + val decoded = finalPerHopPayloadCodec.decode(bin.bits) + assert(decoded.isFailure) + val Attempt.Failure(err: MissingRequiredTlv) = decoded + assert(err.failureMessage === expectedErr) + } + } + + test("decode invalid per-hop payload") { + val testCases = Seq( + // Invalid fixed-size (legacy) payload. + hex"00 000000000000002a 000000000000002a", // invalid length + // Invalid variable-length (tlv) payload. + hex"00", // empty payload is missing required information + hex"01", // invalid length + hex"01 0000", // invalid length + hex"04 0000 2a00", // unknown even types + hex"04 0000 0000", // duplicate types + hex"04 0100 0000" // unordered types + ) + + for (testCase <- testCases) { + assert(relayPerHopPayloadCodec.decode(testCase.bits).isFailure) + assert(finalPerHopPayloadCodec.decode(testCase.bits).isFailure) + } + } + } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/TlvCodecsSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/TlvCodecsSpec.scala index c21033085e..aaf5e4d1aa 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/wire/TlvCodecsSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/wire/TlvCodecsSpec.scala @@ -116,6 +116,17 @@ class TlvCodecsSpec extends FunSuite { } } + test("encode/decode truncated uint64 overflow") { + assert(tu64overflow.encode(Long.MaxValue).require.toByteVector === hex"087fffffffffffffff") + assert(tu64overflow.decode(hex"087fffffffffffffff".bits).require.value === Long.MaxValue) + + assert(tu64overflow.encode(42L).require.toByteVector === hex"012a") + assert(tu64overflow.decode(hex"012a".bits).require.value === 42L) + + assert(tu64overflow.encode(-1L).isFailure) + assert(tu64overflow.decode(hex"088000000000000000".bits).isFailure) + } + test("decode invalid truncated integers") { val testCases = Seq( (tu16, hex"01 00"), // not minimal diff --git a/eclair-node-gui/src/main/scala/fr/acinq/eclair/gui/Handlers.scala b/eclair-node-gui/src/main/scala/fr/acinq/eclair/gui/Handlers.scala index e0bb1daf59..b11aeb5888 100644 --- a/eclair-node-gui/src/main/scala/fr/acinq/eclair/gui/Handlers.scala +++ b/eclair-node-gui/src/main/scala/fr/acinq/eclair/gui/Handlers.scala @@ -20,11 +20,11 @@ import java.util.UUID import akka.pattern.{AskTimeoutException, ask} import akka.util.Timeout -import fr.acinq.eclair.MilliSatoshi -import fr.acinq.eclair._ +import fr.acinq.eclair.{MilliSatoshi, _} import fr.acinq.eclair.gui.controllers._ import fr.acinq.eclair.io.{NodeURI, Peer} -import fr.acinq.eclair.payment.PaymentLifecycle.{PaymentResult, ReceivePayment, SendPayment} +import fr.acinq.eclair.payment.PaymentInitiator.SendPaymentRequest +import fr.acinq.eclair.payment.PaymentLifecycle.ReceivePayment import fr.acinq.eclair.payment._ import grizzled.slf4j.Logging @@ -33,26 +33,23 @@ import scala.concurrent.{ExecutionContext, Future} import scala.util.{Failure, Success} /** - * Created by PM on 16/08/2016. - */ + * Created by PM on 16/08/2016. + */ class Handlers(fKit: Future[Kit])(implicit ec: ExecutionContext = ExecutionContext.Implicits.global) extends Logging { implicit val timeout = Timeout(60 seconds) private var notifsController: Option[NotificationsController] = None - def initNotifications(controller: NotificationsController) = { + def initNotifications(controller: NotificationsController): Unit = { notifsController = Option(controller) } /** - * Opens a connection to a node. If the channel option exists this will also open a channel with the node, with a - * `fundingSatoshis` capacity and `pushMsat` amount. - * - * @param nodeUri - * @param channel - */ - def open(nodeUri: NodeURI, channel: Option[Peer.OpenChannel]) = { + * Opens a connection to a node. If the channel option exists this will also open a channel with the node, with a + * `fundingSatoshis` capacity and `pushMsat` amount. + */ + def open(nodeUri: NodeURI, channel: Option[Peer.OpenChannel]): Unit = { logger.info(s"opening a connection to nodeUri=$nodeUri") (for { kit <- fKit @@ -88,8 +85,8 @@ class Handlers(fKit: Future[Kit])(implicit ec: ExecutionContext = ExecutionConte (for { kit <- fKit sendPayment = req.minFinalCltvExpiryDelta match { - case None => SendPayment(MilliSatoshi(amountMsat), req.paymentHash, req.nodeId, req.routingInfo, maxAttempts = kit.nodeParams.maxPaymentAttempts) - case Some(minFinalCltvExpiry) => SendPayment(MilliSatoshi(amountMsat), req.paymentHash, req.nodeId, req.routingInfo, finalCltvExpiryDelta = minFinalCltvExpiry, maxAttempts = kit.nodeParams.maxPaymentAttempts) + case None => SendPaymentRequest(MilliSatoshi(amountMsat), req.paymentHash, req.nodeId, kit.nodeParams.maxPaymentAttempts, assistedRoutes = req.routingInfo) + case Some(minFinalCltvExpiry) => SendPaymentRequest(MilliSatoshi(amountMsat), req.paymentHash, req.nodeId, kit.nodeParams.maxPaymentAttempts, assistedRoutes = req.routingInfo, finalExpiryDelta = minFinalCltvExpiry) } res <- (kit.paymentInitiator ? sendPayment).mapTo[UUID] } yield res).recover { @@ -108,14 +105,14 @@ class Handlers(fKit: Future[Kit])(implicit ec: ExecutionContext = ExecutionConte } yield res /** - * Displays a system notification if the system supports it. - * - * @param title Title of the notification - * @param message main message of the notification, will not wrap - * @param notificationType type of the message, default to NONE - * @param showAppName true if you want the notification title to be preceded by "Eclair - ". True by default - */ - def notification(title: String, message: String, notificationType: NotificationType = NOTIFICATION_NONE, showAppName: Boolean = true) = { + * Displays a system notification if the system supports it. + * + * @param title Title of the notification + * @param message main message of the notification, will not wrap + * @param notificationType type of the message, default to NONE + * @param showAppName true if you want the notification title to be preceded by "Eclair - ". True by default + */ + def notification(title: String, message: String, notificationType: NotificationType = NOTIFICATION_NONE, showAppName: Boolean = true): Unit = { notifsController.foreach(_.addNotification(if (showAppName) s"Eclair - $title" else title, message, notificationType)) } }