From 96ca8b9b62a062645b7cd8eb510b71100360db46 Mon Sep 17 00:00:00 2001 From: Bastien Teinturier Date: Fri, 6 Sep 2019 13:43:07 +0200 Subject: [PATCH 01/14] Refactor PaymentLifecycle. Unify payment events. Factorize DB and eventStream interactions: this paves the way for sub-payments that shouldn't be stored in the DB nor emit events. --- .../fr/acinq/eclair/payment/Autoprobe.scala | 5 +- .../eclair/payment/LocalPaymentHandler.scala | 3 +- .../acinq/eclair/payment/PaymentEvents.scala | 64 +++++++++- .../eclair/payment/PaymentInitiator.scala | 4 +- .../eclair/payment/PaymentLifecycle.scala | 113 ++++++++---------- .../fr/acinq/eclair/payment/Relayer.scala | 8 +- .../eclair/integration/IntegrationSpec.scala | 35 +++--- .../eclair/payment/PaymentHandlerSpec.scala | 4 +- .../eclair/payment/PaymentLifecycleSpec.scala | 69 ++++++----- .../scala/fr/acinq/eclair/gui/FxApp.scala | 16 ++- .../fr/acinq/eclair/gui/GUIUpdater.scala | 85 +++++++------ .../scala/fr/acinq/eclair/api/Service.scala | 3 +- .../fr/acinq/eclair/api/ApiServiceSpec.scala | 14 +-- 13 files changed, 232 insertions(+), 191 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Autoprobe.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Autoprobe.scala index 40ca336c44..fc738ac934 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Autoprobe.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Autoprobe.scala @@ -20,7 +20,6 @@ import akka.actor.{Actor, ActorLogging, ActorRef, Props} import fr.acinq.bitcoin.Crypto.PublicKey import fr.acinq.eclair.crypto.Sphinx.DecryptedFailurePacket import fr.acinq.eclair.payment.PaymentInitiator.SendPaymentRequest -import fr.acinq.eclair.payment.PaymentLifecycle.{PaymentFailed, PaymentResult, RemoteFailure} import fr.acinq.eclair.router.{Announcements, Data, PublicChannel} import fr.acinq.eclair.wire.IncorrectOrUnknownPaymentDetails import fr.acinq.eclair.{LongToBtcAmount, NodeParams, randomBytes32, secureRandom} @@ -61,9 +60,9 @@ class Autoprobe(nodeParams: NodeParams, router: ActorRef, paymentInitiator: Acto scheduleProbe() } - case paymentResult: PaymentResult => + case paymentResult: PaymentEvent => paymentResult match { - case PaymentFailed(_, _, _ :+ RemoteFailure(_, DecryptedFailurePacket(targetNodeId, IncorrectOrUnknownPaymentDetails(_, _)))) => + case PaymentFailed(_, _, _ :+ RemoteFailure(_, DecryptedFailurePacket(targetNodeId, IncorrectOrUnknownPaymentDetails(_, _))), _) => log.info(s"payment probe successful to node=$targetNodeId") case _ => log.info(s"payment probe failed with paymentResult=$paymentResult") diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/LocalPaymentHandler.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/LocalPaymentHandler.scala index 9d3cea6cbc..fab34c5a21 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/LocalPaymentHandler.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/LocalPaymentHandler.scala @@ -79,8 +79,7 @@ class LocalPaymentHandler(nodeParams: NodeParams) extends Actor with ActorLoggin // amount is correct or was not specified in the payment request nodeParams.db.payments.addIncomingPayment(IncomingPayment(htlc.paymentHash, htlc.amountMsat, Platform.currentTime)) sender ! CMD_FULFILL_HTLC(htlc.id, paymentPreimage, commit = true) - context.system.eventStream.publish(PaymentReceived(htlc.amountMsat, htlc.paymentHash, htlc.channelId)) - + context.system.eventStream.publish(PaymentReceived(htlc.amountMsat, htlc.paymentHash)) } case None => sender ! CMD_FAIL_HTLC(htlc.id, Right(IncorrectOrUnknownPaymentDetails(htlc.amountMsat, nodeParams.currentBlockHeight)), commit = true) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentEvents.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentEvents.scala index 7b0f1b55c2..2bad5b064f 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentEvents.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentEvents.scala @@ -20,20 +20,76 @@ import java.util.UUID import fr.acinq.bitcoin.ByteVector32 import fr.acinq.eclair.MilliSatoshi +import fr.acinq.eclair.crypto.Sphinx +import fr.acinq.eclair.router.Hop import scala.compat.Platform /** - * Created by PM on 01/02/2017. - */ + * Created by PM on 01/02/2017. + */ + sealed trait PaymentEvent { val paymentHash: ByteVector32 + val timestamp: Long } -case class PaymentSent(id: UUID, amount: MilliSatoshi, feesPaid: MilliSatoshi, paymentHash: ByteVector32, paymentPreimage: ByteVector32, toChannelId: ByteVector32, timestamp: Long = Platform.currentTime) extends PaymentEvent +case class PaymentSent(id: UUID, amount: MilliSatoshi, feesPaid: MilliSatoshi, paymentHash: ByteVector32, paymentPreimage: ByteVector32, route: Seq[Hop], timestamp: Long = Platform.currentTime) extends PaymentEvent + +case class PaymentFailed(id: UUID, paymentHash: ByteVector32, failures: Seq[PaymentFailure], timestamp: Long = Platform.currentTime) extends PaymentEvent case class PaymentRelayed(amountIn: MilliSatoshi, amountOut: MilliSatoshi, paymentHash: ByteVector32, fromChannelId: ByteVector32, toChannelId: ByteVector32, timestamp: Long = Platform.currentTime) extends PaymentEvent -case class PaymentReceived(amount: MilliSatoshi, paymentHash: ByteVector32, fromChannelId: ByteVector32, timestamp: Long = Platform.currentTime) extends PaymentEvent +case class PaymentReceived(amount: MilliSatoshi, paymentHash: ByteVector32, timestamp: Long = Platform.currentTime) extends PaymentEvent case class PaymentSettlingOnChain(id: UUID, amount: MilliSatoshi, paymentHash: ByteVector32, timestamp: Long = Platform.currentTime) extends PaymentEvent + +sealed trait PaymentFailure + +/** A failure happened locally, preventing the payment from being sent (e.g. no route found). */ +case class LocalFailure(t: Throwable) extends PaymentFailure + +/** A remote node failed the payment and we were able to decrypt the onion failure packet. */ +case class RemoteFailure(route: Seq[Hop], e: Sphinx.DecryptedFailurePacket) extends PaymentFailure + +/** A remote node failed the payment but we couldn't decrypt the failure (e.g. a malicious node tampered with the message). */ +case class UnreadableRemoteFailure(route: Seq[Hop]) extends PaymentFailure + +object PaymentFailure { + + import fr.acinq.bitcoin.Crypto.PublicKey + import fr.acinq.eclair.channel.AddHtlcFailed + import fr.acinq.eclair.router.RouteNotFound + import fr.acinq.eclair.wire.Update + + /** + * Rewrites a list of failures to retrieve the meaningful part. + * + * If a list of failures with many elements ends up with a LocalFailure RouteNotFound, this RouteNotFound failure + * should be removed. This last failure is irrelevant information. In such a case only the n-1 attempts were rejected + * with a **significant reason**; the final RouteNotFound error provides no meaningful insight. + * + * This method should be used by the user interface to provide a non-exhaustive but more useful feedback. + * + * @param failures a list of payment failures for a payment + */ + def transformForUser(failures: Seq[PaymentFailure]): Seq[PaymentFailure] = { + failures.map { + case LocalFailure(AddHtlcFailed(_, _, t, _, _, _)) => LocalFailure(t) // we're interested in the error which caused the add-htlc to fail + case other => other + } match { + case previousFailures :+ LocalFailure(RouteNotFound) if previousFailures.nonEmpty => previousFailures + case other => other + } + } + + /** + * This allows us to detect if a bad node always answers with a new update (e.g. with a slightly different expiry or fee) + * in order to mess with us. + */ + def hasAlreadyFailedOnce(nodeId: PublicKey, failures: Seq[PaymentFailure]): Boolean = + failures + .collectFirst { case RemoteFailure(_, Sphinx.DecryptedFailurePacket(origin, u: Update)) if origin == nodeId => u.update } + .isDefined + +} \ No newline at end of file diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentInitiator.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentInitiator.scala index f1b6acf1df..fed4e1db4e 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentInitiator.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentInitiator.scala @@ -22,7 +22,7 @@ import akka.actor.{Actor, ActorLogging, ActorRef, Props} import fr.acinq.bitcoin.ByteVector32 import fr.acinq.bitcoin.Crypto.PublicKey import fr.acinq.eclair.channel.Channel -import fr.acinq.eclair.payment.PaymentLifecycle.{SendPayment, SendPaymentToRoute} +import fr.acinq.eclair.payment.PaymentLifecycle.{DefaultPaymentProgressHandler, SendPayment, SendPaymentToRoute} import fr.acinq.eclair.payment.PaymentRequest.ExtraHop import fr.acinq.eclair.router.RouteParams import fr.acinq.eclair.wire.Onion.FinalLegacyPayload @@ -38,7 +38,7 @@ class PaymentInitiator(nodeParams: NodeParams, router: ActorRef, register: Actor val paymentId = UUID.randomUUID() // We add one block in order to not have our htlc fail when a new block has just been found. val finalExpiry = p.finalExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight + 1) - val payFsm = context.actorOf(PaymentLifecycle.props(nodeParams, paymentId, router, register)) + val payFsm = context.actorOf(PaymentLifecycle.props(nodeParams, DefaultPaymentProgressHandler(paymentId, nodeParams.db.payments), router, register)) // NB: we only generate legacy payment onions for now for maximum compatibility. p.predefinedRoute match { case Nil => payFsm forward SendPayment(p.paymentHash, p.targetNodeId, FinalLegacyPayload(p.amount, finalExpiry), p.maxAttempts, p.assistedRoutes, p.routeParams) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentLifecycle.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentLifecycle.scala index d06566ca7e..5cc7e79051 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentLifecycle.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentLifecycle.scala @@ -18,13 +18,13 @@ package fr.acinq.eclair.payment import java.util.UUID -import akka.actor.{ActorRef, FSM, Props, Status} +import akka.actor.{ActorContext, ActorRef, FSM, Props, Status} import fr.acinq.bitcoin.ByteVector32 import fr.acinq.bitcoin.Crypto.PublicKey import fr.acinq.eclair._ -import fr.acinq.eclair.channel.{AddHtlcFailed, CMD_ADD_HTLC, Register} +import fr.acinq.eclair.channel.{CMD_ADD_HTLC, Register} import fr.acinq.eclair.crypto.{Sphinx, TransportHandler} -import fr.acinq.eclair.db.{OutgoingPayment, OutgoingPaymentStatus} +import fr.acinq.eclair.db.{OutgoingPayment, OutgoingPaymentStatus, PaymentsDb} import fr.acinq.eclair.payment.PaymentLifecycle._ import fr.acinq.eclair.payment.PaymentRequest.ExtraHop import fr.acinq.eclair.router._ @@ -39,22 +39,23 @@ import scala.util.{Failure, Success} /** * Created by PM on 26/08/2016. */ -class PaymentLifecycle(nodeParams: NodeParams, id: UUID, router: ActorRef, register: ActorRef) extends FSM[PaymentLifecycle.State, PaymentLifecycle.Data] { - val paymentsDb = nodeParams.db.payments +class PaymentLifecycle(nodeParams: NodeParams, progressHandler: PaymentProgressHandler, router: ActorRef, register: ActorRef) extends FSM[PaymentLifecycle.State, PaymentLifecycle.Data] { + + val id = progressHandler.id startWith(WAITING_FOR_REQUEST, WaitingForRequest) when(WAITING_FOR_REQUEST) { case Event(c: SendPaymentToRoute, WaitingForRequest) => val send = SendPayment(c.paymentHash, c.hops.last, c.finalPayload, maxAttempts = 1) - paymentsDb.addOutgoingPayment(OutgoingPayment(id, c.paymentHash, None, c.finalPayload.amount, Platform.currentTime, None, OutgoingPaymentStatus.PENDING)) router ! FinalizeRoute(c.hops) + progressHandler.onSend(c.paymentHash, c.finalPayload.amount) goto(WAITING_FOR_ROUTE) using WaitingForRoute(sender, send, failures = Nil) case Event(c: SendPayment, WaitingForRequest) => router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.finalPayload.amount, c.assistedRoutes, routeParams = c.routeParams) - paymentsDb.addOutgoingPayment(OutgoingPayment(id, c.paymentHash, None, c.finalPayload.amount, Platform.currentTime, None, OutgoingPaymentStatus.PENDING)) + progressHandler.onSend(c.paymentHash, c.finalPayload.amount) goto(WAITING_FOR_ROUTE) using WaitingForRoute(sender, c, failures = Nil) } @@ -67,18 +68,15 @@ class PaymentLifecycle(nodeParams: NodeParams, id: UUID, router: ActorRef, regis goto(WAITING_FOR_PAYMENT_COMPLETE) using WaitingForComplete(s, c, cmd, failures, sharedSecrets, ignoreNodes, ignoreChannels, hops) case Event(Status.Failure(t), WaitingForRoute(s, c, failures)) => - reply(s, PaymentFailed(id, c.paymentHash, failures = failures :+ LocalFailure(t))) - paymentsDb.updateOutgoingPayment(id, OutgoingPaymentStatus.FAILED) + progressHandler.onFail(s, PaymentFailed(id, c.paymentHash, failures :+ LocalFailure(t)))(context) stop(FSM.Normal) } when(WAITING_FOR_PAYMENT_COMPLETE) { - case Event("ok", _) => stay() + case Event("ok", _) => stay - case Event(fulfill: UpdateFulfillHtlc, WaitingForComplete(s, c, cmd, _, _, _, _, hops)) => - paymentsDb.updateOutgoingPayment(id, OutgoingPaymentStatus.SUCCEEDED, preimage = Some(fulfill.paymentPreimage)) - reply(s, PaymentSucceeded(id, cmd.amount, c.paymentHash, fulfill.paymentPreimage, hops)) - context.system.eventStream.publish(PaymentSent(id, c.finalPayload.amount, cmd.amount - c.finalPayload.amount, cmd.paymentHash, fulfill.paymentPreimage, fulfill.channelId)) + case Event(fulfill: UpdateFulfillHtlc, WaitingForComplete(s, c, cmd, _, _, _, _, route)) => + progressHandler.onSucceed(s, PaymentSent(id, c.finalPayload.amount, cmd.amount - c.finalPayload.amount, c.paymentHash, fulfill.paymentPreimage, route))(context) stop(FSM.Normal) case Event(fail: UpdateFailHtlc, WaitingForComplete(s, c, _, failures, sharedSecrets, ignoreNodes, ignoreChannels, hops)) => @@ -86,8 +84,7 @@ class PaymentLifecycle(nodeParams: NodeParams, id: UUID, router: ActorRef, regis case Success(e@Sphinx.DecryptedFailurePacket(nodeId, failureMessage)) if nodeId == c.targetNodeId => // if destination node returns an error, we fail the payment immediately log.warning(s"received an error message from target nodeId=$nodeId, failing the payment (failure=$failureMessage)") - reply(s, PaymentFailed(id, c.paymentHash, failures = failures :+ RemoteFailure(hops, e))) - paymentsDb.updateOutgoingPayment(id, OutgoingPaymentStatus.FAILED) + progressHandler.onFail(s, PaymentFailed(id, c.paymentHash, failures :+ RemoteFailure(hops, e)))(context) stop(FSM.Normal) case res if failures.size + 1 >= c.maxAttempts => // otherwise we never try more than maxAttempts, no matter the kind of error returned @@ -100,8 +97,7 @@ class PaymentLifecycle(nodeParams: NodeParams, id: UUID, router: ActorRef, regis UnreadableRemoteFailure(hops) } log.warning(s"too many failed attempts, failing the payment") - reply(s, PaymentFailed(id, c.paymentHash, failures = failures :+ failure)) - paymentsDb.updateOutgoingPayment(id, OutgoingPaymentStatus.FAILED) + progressHandler.onFail(s, PaymentFailed(id, c.paymentHash, failures :+ failure))(context) stop(FSM.Normal) case Failure(t) => log.warning(s"cannot parse returned error: ${t.getMessage}") @@ -128,7 +124,7 @@ class PaymentLifecycle(nodeParams: NodeParams, id: UUID, router: ActorRef, regis log.info(s"received exact same update from nodeId=$nodeId, excluding the channel from futures routes") val nextNodeId = hops.find(_.nodeId == nodeId).get.nextNodeId router ! ExcludeChannel(ChannelDesc(u.shortChannelId, nodeId, nextNodeId)) - case Some(u) if hasAlreadyFailedOnce(nodeId, failures) => + case Some(u) if PaymentFailure.hasAlreadyFailedOnce(nodeId, failures) => // this node had already given us a new channel update and is still unhappy, it is probably messing with us, let's exclude it log.warning(s"it is the second time nodeId=$nodeId answers with a new update, excluding it: old=$u new=${failureMessage.update}") val nextNodeId = hops.find(_.nodeId == nodeId).get.nextNodeId @@ -166,8 +162,7 @@ class PaymentLifecycle(nodeParams: NodeParams, id: UUID, router: ActorRef, regis case Event(Status.Failure(t), WaitingForComplete(s, c, _, failures, _, ignoreNodes, ignoreChannels, hops)) => if (failures.size + 1 >= c.maxAttempts) { - paymentsDb.updateOutgoingPayment(id, OutgoingPaymentStatus.FAILED) - reply(s, PaymentFailed(id, c.paymentHash, failures :+ LocalFailure(t))) + progressHandler.onFail(s, PaymentFailed(id, c.paymentHash, failures :+ LocalFailure(t)))(context) stop(FSM.Normal) } else { log.info(s"received an error message from local, trying to use a different channel (failure=${t.getMessage})") @@ -182,17 +177,44 @@ class PaymentLifecycle(nodeParams: NodeParams, id: UUID, router: ActorRef, regis case Event(_: TransportHandler.ReadAck, _) => stay // ignored, router replies with this when we forward a channel_update } - def reply(to: ActorRef, e: PaymentResult): Unit = { - to ! e - context.system.eventStream.publish(e) - } - initialize() } object PaymentLifecycle { - def props(nodeParams: NodeParams, id: UUID, router: ActorRef, register: ActorRef) = Props(classOf[PaymentLifecycle], nodeParams, id, router, register) + def props(nodeParams: NodeParams, progressHandler: PaymentProgressHandler, router: ActorRef, register: ActorRef) = Props(classOf[PaymentLifecycle], nodeParams, progressHandler, router, register) + + /** This handler notifies other components of payment progress. */ + trait PaymentProgressHandler { + val id: UUID + + // @formatter:off + def onSend(paymentHash: ByteVector32, finalAmount: MilliSatoshi): Unit + def onSucceed(sender: ActorRef, result: PaymentSent)(ctx: ActorContext): Unit + def onFail(sender: ActorRef, result: PaymentFailed)(ctx: ActorContext): Unit + // @formatter:on + } + + /** Normal payments are stored in the payments DB and emit payment events. */ + case class DefaultPaymentProgressHandler(id: UUID, db: PaymentsDb) extends PaymentProgressHandler { + + override def onSend(paymentHash: ByteVector32, finalAmount: MilliSatoshi): Unit = { + db.addOutgoingPayment(OutgoingPayment(id, paymentHash, None, finalAmount, Platform.currentTime, None, OutgoingPaymentStatus.PENDING)) + } + + override def onSucceed(sender: ActorRef, result: PaymentSent)(ctx: ActorContext): Unit = { + db.updateOutgoingPayment(result.id, OutgoingPaymentStatus.SUCCEEDED, preimage = Some(result.paymentPreimage)) + sender ! result + ctx.system.eventStream.publish(result) + } + + override def onFail(sender: ActorRef, result: PaymentFailed)(ctx: ActorContext): Unit = { + db.updateOutgoingPayment(result.id, OutgoingPaymentStatus.FAILED) + sender ! result + ctx.system.eventStream.publish(result) + } + + } // @formatter:off case class ReceivePayment(amount_opt: Option[MilliSatoshi], description: String, expirySeconds_opt: Option[Long] = None, extraHops: List[List[ExtraHop]] = Nil, fallbackAddress: Option[String] = None, paymentPreimage: Option[ByteVector32] = None) @@ -206,14 +228,6 @@ object PaymentLifecycle { require(finalPayload.amount > 0.msat, s"amount must be > 0") } - sealed trait PaymentResult - case class PaymentSucceeded(id: UUID, amount: MilliSatoshi, paymentHash: ByteVector32, paymentPreimage: ByteVector32, route: Seq[Hop]) extends PaymentResult // note: the amount includes fees - sealed trait PaymentFailure - case class LocalFailure(t: Throwable) extends PaymentFailure - case class RemoteFailure(route: Seq[Hop], e: Sphinx.DecryptedFailurePacket) extends PaymentFailure - case class UnreadableRemoteFailure(route: Seq[Hop]) extends PaymentFailure - case class PaymentFailed(id: UUID, paymentHash: ByteVector32, failures: Seq[PaymentFailure]) extends PaymentResult - sealed trait Data case object WaitingForRequest extends Data case class WaitingForRoute(sender: ActorRef, c: SendPayment, failures: Seq[PaymentFailure]) extends Data @@ -268,27 +282,6 @@ object PaymentLifecycle { CMD_ADD_HTLC(firstAmount, paymentHash, firstExpiry, onion.packet, upstream = Left(id), commit = true) -> onion.sharedSecrets } - /** - * Rewrites a list of failures to retrieve the meaningful part. - *

- * If a list of failures with many elements ends up with a LocalFailure RouteNotFound, this RouteNotFound failure - * should be removed. This last failure is irrelevant information. In such a case only the n-1 attempts were rejected - * with a **significant reason** ; the final RouteNotFound error provides no meaningful insight. - *

- * This method should be used by the user interface to provide a non-exhaustive but more useful feedback. - * - * @param failures a list of payment failures for a payment - */ - def transformForUser(failures: Seq[PaymentFailure]): Seq[PaymentFailure] = { - failures.map { - case LocalFailure(AddHtlcFailed(_, _, t, _, _, _)) => LocalFailure(t) // we're interested in the error which caused the add-htlc to fail - case other => other - } match { - case previousFailures :+ LocalFailure(RouteNotFound) if previousFailures.nonEmpty => previousFailures - case other => other - } - } - /** * This method retrieves the channel update that we used when we built a route. * It just iterates over the hops, but there are at most 20 of them. @@ -297,12 +290,4 @@ object PaymentLifecycle { */ def getChannelUpdateForNode(nodeId: PublicKey, hops: Seq[Hop]): Option[ChannelUpdate] = hops.find(_.nodeId == nodeId).map(_.lastUpdate) - /** - * This allows us to detect if a bad node always answers with a new update (e.g. with a slightly different expiry or fee) - * in order to mess with us. - */ - def hasAlreadyFailedOnce(nodeId: PublicKey, failures: Seq[PaymentFailure]): Boolean = - failures - .collectFirst { case RemoteFailure(_, Sphinx.DecryptedFailurePacket(origin, u: Update)) if origin == nodeId => u.update } - .isDefined } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Relayer.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Relayer.scala index 1ee391d05e..a009f310ce 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Relayer.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Relayer.scala @@ -25,7 +25,6 @@ import fr.acinq.bitcoin.Crypto.{PrivateKey, PublicKey} import fr.acinq.eclair.channel._ import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.db.OutgoingPaymentStatus -import fr.acinq.eclair.payment.PaymentLifecycle.{PaymentFailed, PaymentSucceeded} import fr.acinq.eclair.router.Announcements import fr.acinq.eclair.wire._ import fr.acinq.eclair.{CltvExpiryDelta, Features, LongToBtcAmount, MilliSatoshi, NodeParams, ShortChannelId, UInt64, nodeFee} @@ -159,12 +158,11 @@ class Relayer(nodeParams: NodeParams, register: ActorRef, paymentHandler: ActorR case ForwardFulfill(fulfill, to, add) => to match { case Local(id, None) => - val feesPaid = 0.msat - context.system.eventStream.publish(PaymentSent(id, add.amountMsat, feesPaid, add.paymentHash, fulfill.paymentPreimage, fulfill.channelId)) // we sent the payment, but we probably restarted and the reference to the original sender was lost, - // we publish the failure on the event stream and update the status in paymentDb + // we publish the success on the event stream and update the status in paymentDb + val feesPaid = 0.msat // fees are unknown since we lost the reference to the payment nodeParams.db.payments.updateOutgoingPayment(id, OutgoingPaymentStatus.SUCCEEDED, Some(fulfill.paymentPreimage)) - context.system.eventStream.publish(PaymentSucceeded(id, add.amountMsat, add.paymentHash, fulfill.paymentPreimage, Nil)) // + context.system.eventStream.publish(PaymentSent(id, add.amountMsat, feesPaid, add.paymentHash, fulfill.paymentPreimage, Nil)) case Local(_, Some(sender)) => sender ! fulfill case Relayed(originChannelId, originHtlcId, amountIn, amountOut) => diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/integration/IntegrationSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/integration/IntegrationSpec.scala index 92dc0bae4f..a703f2bf43 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/integration/IntegrationSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/integration/IntegrationSpec.scala @@ -36,7 +36,7 @@ import fr.acinq.eclair.io.Peer import fr.acinq.eclair.io.Peer.{Disconnect, PeerRoutingMessage} import fr.acinq.eclair.payment.PaymentInitiator.SendPaymentRequest import fr.acinq.eclair.payment.PaymentLifecycle.{State => _, _} -import fr.acinq.eclair.payment.{LocalPaymentHandler, PaymentRequest} +import fr.acinq.eclair.payment._ import fr.acinq.eclair.router.Graph.WeightRatios import fr.acinq.eclair.router.Router.ROUTE_MAX_LENGTH import fr.acinq.eclair.router.{Announcements, AnnouncementsBatchValidationSpec, PublicChannel, RouteParams} @@ -267,7 +267,7 @@ class IntegrationSpec extends TestKit(ActorSystem("test")) with BitcoindService // then we make the actual payment sender.send(nodes("A").paymentInitiator, SendPaymentRequest(amountMsat, pr.paymentHash, nodes("D").nodeParams.nodeId, routeParams = integrationTestRouteParams, maxAttempts = 1)) val paymentId = sender.expectMsgType[UUID](5 seconds) - val ps = sender.expectMsgType[PaymentSucceeded](5 seconds) + val ps = sender.expectMsgType[PaymentSent](5 seconds) assert(ps.id == paymentId) } @@ -293,7 +293,7 @@ class IntegrationSpec extends TestKit(ActorSystem("test")) with BitcoindService sender.send(nodes("A").paymentInitiator, sendReq) // A will receive an error from B that include the updated channel update, then will retry the payment val paymentId = sender.expectMsgType[UUID](5 seconds) - val ps = sender.expectMsgType[PaymentSucceeded](5 seconds) + val ps = sender.expectMsgType[PaymentSent](5 seconds) assert(ps.id == paymentId) def updateFor(n: PublicKey, pc: PublicChannel): Option[ChannelUpdate] = if (n == pc.ann.nodeId1) pc.update_1_opt else if (n == pc.ann.nodeId2) pc.update_2_opt else throw new IllegalArgumentException("this node is unrelated to this channel") @@ -333,7 +333,7 @@ class IntegrationSpec extends TestKit(ActorSystem("test")) with BitcoindService sender.send(nodes("A").paymentInitiator, sendReq) // A will first receive an error from C, then retry and route around C: A->B->E->C->D sender.expectMsgType[UUID](5 seconds) - sender.expectMsgType[PaymentSucceeded] // the payment FSM will also reply to the sender after the payment is completed + sender.expectMsgType[PaymentSent] // the payment FSM will also reply to the sender after the payment is completed } test("send an HTLC A->D with an unknown payment hash") { @@ -414,7 +414,7 @@ class IntegrationSpec extends TestKit(ActorSystem("test")) with BitcoindService val sendReq = SendPaymentRequest(amountMsat, pr.paymentHash, nodes("D").nodeParams.nodeId, routeParams = integrationTestRouteParams, maxAttempts = 5) sender.send(nodes("A").paymentInitiator, sendReq) sender.expectMsgType[UUID] - sender.expectMsgType[PaymentSucceeded] // the payment FSM will also reply to the sender after the payment is completed + sender.expectMsgType[PaymentSent] // the payment FSM will also reply to the sender after the payment is completed } } @@ -430,19 +430,16 @@ class IntegrationSpec extends TestKit(ActorSystem("test")) with BitcoindService sender.expectMsgType[UUID](max = 60 seconds) awaitCond({ - sender.expectMsgType[PaymentResult](10 seconds) match { - case PaymentFailed(_, _, failures) => failures == Seq.empty // if something went wrong fail with a hint - case PaymentSucceeded(_, _, _, _, route) => route.exists(_.nodeId == nodes("G").nodeParams.nodeId) + sender.expectMsgType[PaymentEvent](10 seconds) match { + case PaymentFailed(_, _, failures, _) => failures == Seq.empty // if something went wrong fail with a hint + case PaymentSent(_, _, _, _, _, route, _) => route.exists(_.nodeId == nodes("G").nodeParams.nodeId) + case _ => false } }, max = 30 seconds, interval = 10 seconds) } - /** * We currently use p2pkh script Helpers.getFinalScriptPubKey - * - * @param scriptPubKey - * @return */ def scriptPubKeyToAddress(scriptPubKey: ByteVector) = Script.parse(scriptPubKey) match { case OP_DUP :: OP_HASH160 :: OP_PUSHDATA(pubKeyHash, _) :: OP_EQUALVERIFY :: OP_CHECKSIG :: Nil => @@ -508,7 +505,7 @@ class IntegrationSpec extends TestKit(ActorSystem("test")) with BitcoindService sender.send(bitcoincli, BitcoinReq("generate", 1)) sender.expectMsgType[JValue](10 seconds) // C will extract the preimage from the blockchain and fulfill the payment upstream - paymentSender.expectMsgType[PaymentSucceeded](30 seconds) + paymentSender.expectMsgType[PaymentSent](30 seconds) // at this point F should have 1 recv transactions: the redeemed htlc awaitCond({ sender.send(bitcoincli, BitcoinReq("listreceivedbyaddress", 0)) @@ -590,7 +587,7 @@ class IntegrationSpec extends TestKit(ActorSystem("test")) with BitcoindService sender.send(bitcoincli, BitcoinReq("generate", 1)) sender.expectMsgType[JValue](10 seconds) // C will extract the preimage from the blockchain and fulfill the payment upstream - paymentSender.expectMsgType[PaymentSucceeded](30 seconds) + paymentSender.expectMsgType[PaymentSent](30 seconds) // at this point F should have 1 recv transactions: the redeemed htlc // we then generate enough blocks so that F gets its htlc-success delayed output sender.send(bitcoincli, BitcoinReq("generate", 145)) @@ -773,7 +770,7 @@ class IntegrationSpec extends TestKit(ActorSystem("test")) with BitcoindService forwardHandlerF.forward(paymentHandlerF) sigListener.expectMsgType[ChannelSignatureReceived] sigListener.expectMsgType[ChannelSignatureReceived] - sender.expectMsgType[PaymentSucceeded].id === paymentId + sender.expectMsgType[PaymentSent].id === paymentId // we now send a few htlcs C->F and F->C in order to obtain a commitments with multiple htlcs def send(amountMsat: MilliSatoshi, paymentHandler: ActorRef, paymentInitiator: ActorRef) = { @@ -813,19 +810,19 @@ class IntegrationSpec extends TestKit(ActorSystem("test")) with BitcoindService buffer.expectMsgType[UpdateAddHtlc] buffer.forward(paymentHandlerF) sigListener.expectMsgType[ChannelSignatureReceived] - val preimage1 = sender.expectMsgType[PaymentSucceeded].paymentPreimage + val preimage1 = sender.expectMsgType[PaymentSent].paymentPreimage buffer.expectMsgType[UpdateAddHtlc] buffer.forward(paymentHandlerF) sigListener.expectMsgType[ChannelSignatureReceived] - sender.expectMsgType[PaymentSucceeded].paymentPreimage + sender.expectMsgType[PaymentSent].paymentPreimage buffer.expectMsgType[UpdateAddHtlc] buffer.forward(paymentHandlerC) sigListener.expectMsgType[ChannelSignatureReceived] - sender.expectMsgType[PaymentSucceeded].paymentPreimage + sender.expectMsgType[PaymentSent].paymentPreimage buffer.expectMsgType[UpdateAddHtlc] buffer.forward(paymentHandlerC) sigListener.expectMsgType[ChannelSignatureReceived] - sender.expectMsgType[PaymentSucceeded].paymentPreimage + sender.expectMsgType[PaymentSent].paymentPreimage // this also allows us to get the channel id val channelId = commitmentsF.channelId // we also retrieve C's default final address diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentHandlerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentHandlerSpec.scala index b142e521f5..6e4ff42d31 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentHandlerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentHandlerSpec.scala @@ -59,7 +59,7 @@ class PaymentHandlerSpec extends TestKit(ActorSystem("test")) with FunSuiteLike sender.expectMsgType[CMD_FULFILL_HTLC] val paymentRelayed = eventListener.expectMsgType[PaymentReceived] - assert(paymentRelayed.copy(timestamp = 0) === PaymentReceived(amountMsat, add.paymentHash, add.channelId, timestamp = 0)) + assert(paymentRelayed.copy(timestamp = 0) === PaymentReceived(amountMsat, add.paymentHash, timestamp = 0)) assert(nodeParams.db.payments.getIncomingPayment(pr.paymentHash).exists(_.paymentHash == pr.paymentHash)) } @@ -72,7 +72,7 @@ class PaymentHandlerSpec extends TestKit(ActorSystem("test")) with FunSuiteLike sender.send(handler, add) sender.expectMsgType[CMD_FULFILL_HTLC] val paymentRelayed = eventListener.expectMsgType[PaymentReceived] - assert(paymentRelayed.copy(timestamp = 0) === PaymentReceived(amountMsat, add.paymentHash, add.channelId, timestamp = 0)) + assert(paymentRelayed.copy(timestamp = 0) === PaymentReceived(amountMsat, add.paymentHash, timestamp = 0)) assert(nodeParams.db.payments.getIncomingPayment(pr.paymentHash).exists(_.paymentHash == pr.paymentHash)) } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala index 7dcd6417e0..4e7c846457 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala @@ -22,7 +22,7 @@ import akka.actor.FSM.{CurrentState, SubscribeTransitionCallBack, Transition} import akka.actor.Status import akka.testkit.{TestFSMRef, TestProbe} import fr.acinq.bitcoin.Script.{pay2wsh, write} -import fr.acinq.bitcoin.{Block, ByteVector32, Transaction, TxOut} +import fr.acinq.bitcoin.{Block, ByteVector32, Crypto, Transaction, TxOut} import fr.acinq.eclair._ import fr.acinq.eclair.blockchain.{UtxoStatus, ValidateRequest, ValidateResult, WatchSpentBasic} import fr.acinq.eclair.channel.Register.ForwardShortId @@ -51,9 +51,10 @@ class PaymentLifecycleSpec extends BaseRouterSpec { import fixture._ val defaultPaymentHash = randomBytes32 val nodeParams = TestConstants.Alice.nodeParams.copy(keyManager = testKeyManager) - val paymentDb = nodeParams.db.payments val id = UUID.randomUUID() - val paymentFSM = system.actorOf(PaymentLifecycle.props(nodeParams, id, router, TestProbe().ref)) + val paymentDb = nodeParams.db.payments + val progressHandler = PaymentLifecycle.DefaultPaymentProgressHandler(id, paymentDb) + val paymentFSM = system.actorOf(PaymentLifecycle.props(nodeParams, progressHandler, router, TestProbe().ref)) val monitor = TestProbe() val sender = TestProbe() val eventListener = TestProbe() @@ -71,7 +72,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { awaitCond(paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.PENDING)) sender.send(paymentFSM, UpdateFulfillHtlc(ByteVector32.Zeroes, 0, defaultPaymentHash)) - sender.expectMsgType[PaymentSucceeded] + sender.expectMsgType[PaymentSent] awaitCond(paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.SUCCEEDED)) } @@ -81,8 +82,9 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val nodeParams = TestConstants.Alice.nodeParams.copy(keyManager = testKeyManager) val paymentDb = nodeParams.db.payments val id = UUID.randomUUID() + val progressHandler = PaymentLifecycle.DefaultPaymentProgressHandler(id, paymentDb) val routerForwarder = TestProbe() - val paymentFSM = system.actorOf(PaymentLifecycle.props(nodeParams, id, routerForwarder.ref, TestProbe().ref)) + val paymentFSM = system.actorOf(PaymentLifecycle.props(nodeParams, progressHandler, routerForwarder.ref, TestProbe().ref)) val monitor = TestProbe() val sender = TestProbe() @@ -96,7 +98,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { awaitCond(paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.PENDING)) routerForwarder.forward(router, routeRequest) - sender.expectMsg(PaymentFailed(id, request.paymentHash, LocalFailure(RouteNotFound) :: Nil)) + assert(sender.expectMsgType[PaymentFailed].failures === LocalFailure(RouteNotFound) :: Nil) awaitCond(paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.FAILED)) } @@ -105,7 +107,8 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val nodeParams = TestConstants.Alice.nodeParams.copy(keyManager = testKeyManager) val paymentDb = nodeParams.db.payments val id = UUID.randomUUID() - val paymentFSM = system.actorOf(PaymentLifecycle.props(nodeParams, id, router, TestProbe().ref)) + val progressHandler = PaymentLifecycle.DefaultPaymentProgressHandler(id, paymentDb) + val paymentFSM = system.actorOf(PaymentLifecycle.props(nodeParams, progressHandler, router, TestProbe().ref)) val monitor = TestProbe() val sender = TestProbe() @@ -128,7 +131,8 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val relayer = TestProbe() val routerForwarder = TestProbe() val id = UUID.randomUUID() - val paymentFSM = TestFSMRef(new PaymentLifecycle(nodeParams, id, routerForwarder.ref, relayer.ref)) + val progressHandler = PaymentLifecycle.DefaultPaymentProgressHandler(id, paymentDb) + val paymentFSM = TestFSMRef(new PaymentLifecycle(nodeParams, progressHandler, routerForwarder.ref, relayer.ref)) val monitor = TestProbe() val sender = TestProbe() @@ -160,7 +164,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { sender.send(paymentFSM, UpdateFailHtlc(ByteVector32.Zeroes, 0, defaultPaymentHash)) // unparsable message // we allow 2 tries, so we send a 2nd request to the router - sender.expectMsg(PaymentFailed(id, request.paymentHash, UnreadableRemoteFailure(hops) :: UnreadableRemoteFailure(hops) :: Nil)) + assert(sender.expectMsgType[PaymentFailed].failures === UnreadableRemoteFailure(hops) :: UnreadableRemoteFailure(hops) :: Nil) awaitCond(paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.FAILED)) // after last attempt the payment is failed } @@ -171,7 +175,8 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val relayer = TestProbe() val routerForwarder = TestProbe() val id = UUID.randomUUID() - val paymentFSM = TestFSMRef(new PaymentLifecycle(nodeParams, id, routerForwarder.ref, relayer.ref)) + val progressHandler = PaymentLifecycle.DefaultPaymentProgressHandler(id, paymentDb) + val paymentFSM = TestFSMRef(new PaymentLifecycle(nodeParams, progressHandler, routerForwarder.ref, relayer.ref)) val monitor = TestProbe() val sender = TestProbe() @@ -204,7 +209,8 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val relayer = TestProbe() val routerForwarder = TestProbe() val id = UUID.randomUUID() - val paymentFSM = TestFSMRef(new PaymentLifecycle(nodeParams, id, routerForwarder.ref, relayer.ref)) + val progressHandler = PaymentLifecycle.DefaultPaymentProgressHandler(id, paymentDb) + val paymentFSM = TestFSMRef(new PaymentLifecycle(nodeParams, progressHandler, routerForwarder.ref, relayer.ref)) val monitor = TestProbe() val sender = TestProbe() @@ -235,7 +241,8 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val relayer = TestProbe() val routerForwarder = TestProbe() val id = UUID.randomUUID() - val paymentFSM = TestFSMRef(new PaymentLifecycle(nodeParams, id, routerForwarder.ref, relayer.ref)) + val progressHandler = PaymentLifecycle.DefaultPaymentProgressHandler(id, nodeParams.db.payments) + val paymentFSM = TestFSMRef(new PaymentLifecycle(nodeParams, progressHandler, routerForwarder.ref, relayer.ref)) val monitor = TestProbe() val sender = TestProbe() @@ -265,7 +272,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { routerForwarder.expectMsg(RouteRequest(a, d, defaultAmountMsat, assistedRoutes = Nil, ignoreNodes = Set.empty, ignoreChannels = Set.empty)) routerForwarder.forward(router) // we allow 2 tries, so we send a 2nd request to the router - sender.expectMsg(PaymentFailed(id, request.paymentHash, RemoteFailure(hops, Sphinx.DecryptedFailurePacket(b, failure)) :: LocalFailure(RouteNotFound) :: Nil)) + assert(sender.expectMsgType[PaymentFailed].failures === RemoteFailure(hops, Sphinx.DecryptedFailurePacket(b, failure)) :: LocalFailure(RouteNotFound) :: Nil) } test("payment failed (Update)") { fixture => @@ -275,7 +282,8 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val relayer = TestProbe() val routerForwarder = TestProbe() val id = UUID.randomUUID() - val paymentFSM = TestFSMRef(new PaymentLifecycle(nodeParams, id, routerForwarder.ref, relayer.ref)) + val progressHandler = PaymentLifecycle.DefaultPaymentProgressHandler(id, paymentDb) + val paymentFSM = TestFSMRef(new PaymentLifecycle(nodeParams, progressHandler, routerForwarder.ref, relayer.ref)) val monitor = TestProbe() val sender = TestProbe() @@ -326,7 +334,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { routerForwarder.forward(router) // this time the router can't find a route: game over - sender.expectMsg(PaymentFailed(id, request.paymentHash, RemoteFailure(hops, Sphinx.DecryptedFailurePacket(b, failure)) :: RemoteFailure(hops2, Sphinx.DecryptedFailurePacket(b, failure2)) :: LocalFailure(RouteNotFound) :: Nil)) + assert(sender.expectMsgType[PaymentFailed].failures === RemoteFailure(hops, Sphinx.DecryptedFailurePacket(b, failure)) :: RemoteFailure(hops2, Sphinx.DecryptedFailurePacket(b, failure2)) :: LocalFailure(RouteNotFound) :: Nil) awaitCond(paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.FAILED)) } @@ -337,7 +345,8 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val relayer = TestProbe() val routerForwarder = TestProbe() val id = UUID.randomUUID() - val paymentFSM = TestFSMRef(new PaymentLifecycle(nodeParams, id, routerForwarder.ref, relayer.ref)) + val progressHandler = PaymentLifecycle.DefaultPaymentProgressHandler(id, paymentDb) + val paymentFSM = TestFSMRef(new PaymentLifecycle(nodeParams, progressHandler, routerForwarder.ref, relayer.ref)) val monitor = TestProbe() val sender = TestProbe() @@ -363,7 +372,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { routerForwarder.forward(router) // we allow 2 tries, so we send a 2nd request to the router, which won't find another route - sender.expectMsg(PaymentFailed(id, request.paymentHash, RemoteFailure(hops, Sphinx.DecryptedFailurePacket(b, failure)) :: LocalFailure(RouteNotFound) :: Nil)) + assert(sender.expectMsgType[PaymentFailed].failures === RemoteFailure(hops, Sphinx.DecryptedFailurePacket(b, failure)) :: LocalFailure(RouteNotFound) :: Nil) awaitCond(paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.FAILED)) } @@ -379,11 +388,13 @@ class PaymentLifecycleSpec extends BaseRouterSpec { test("payment succeeded") { fixture => import fixture._ - val defaultPaymentHash = randomBytes32 + val paymentPreimage = randomBytes32 + val paymentHash = Crypto.sha256(paymentPreimage) val nodeParams = TestConstants.Alice.nodeParams.copy(keyManager = testKeyManager) val paymentDb = nodeParams.db.payments val id = UUID.randomUUID() - val paymentFSM = system.actorOf(PaymentLifecycle.props(nodeParams, id, router, TestProbe().ref)) + val progressHandler = PaymentLifecycle.DefaultPaymentProgressHandler(id, paymentDb) + val paymentFSM = system.actorOf(PaymentLifecycle.props(nodeParams, progressHandler, router, TestProbe().ref)) val monitor = TestProbe() val sender = TestProbe() val eventListener = TestProbe() @@ -392,17 +403,18 @@ class PaymentLifecycleSpec extends BaseRouterSpec { paymentFSM ! SubscribeTransitionCallBack(monitor.ref) val CurrentState(_, WAITING_FOR_REQUEST) = monitor.expectMsgClass(classOf[CurrentState[_]]) - val request = SendPayment(defaultPaymentHash, d, FinalLegacyPayload(defaultAmountMsat, defaultExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight)), maxAttempts = 5) + val request = SendPayment(paymentHash, d, FinalLegacyPayload(defaultAmountMsat, defaultExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight)), maxAttempts = 5) sender.send(paymentFSM, request) val Transition(_, WAITING_FOR_REQUEST, WAITING_FOR_ROUTE) = monitor.expectMsgClass(classOf[Transition[_]]) val Transition(_, WAITING_FOR_ROUTE, WAITING_FOR_PAYMENT_COMPLETE) = monitor.expectMsgClass(classOf[Transition[_]]) awaitCond(paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.PENDING)) - sender.send(paymentFSM, UpdateFulfillHtlc(ByteVector32.Zeroes, 0, defaultPaymentHash)) + sender.send(paymentFSM, UpdateFulfillHtlc(ByteVector32.Zeroes, 0, paymentPreimage)) - val paymentOK = sender.expectMsgType[PaymentSucceeded] - val PaymentSent(_, request.finalPayload.amount, fee, request.paymentHash, paymentOK.paymentPreimage, _, _) = eventListener.expectMsgType[PaymentSent] - assert(fee > 0.msat) - assert(fee === paymentOK.amount - request.finalPayload.amount) + val ps = eventListener.expectMsgType[PaymentSent] + assert(ps.feesPaid > 0.msat) + assert(ps.amount === defaultAmountMsat) + assert(ps.paymentHash === paymentHash) + assert(ps.paymentPreimage === paymentPreimage) awaitCond(paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.SUCCEEDED)) } @@ -432,7 +444,8 @@ class PaymentLifecycleSpec extends BaseRouterSpec { watcher.expectMsgType[WatchSpentBasic] // actual test begins - val paymentFSM = system.actorOf(PaymentLifecycle.props(nodeParams, UUID.randomUUID(), router, TestProbe().ref)) + val progressHandler = PaymentLifecycle.DefaultPaymentProgressHandler(UUID.randomUUID(), nodeParams.db.payments) + val paymentFSM = system.actorOf(PaymentLifecycle.props(nodeParams, progressHandler, router, TestProbe().ref)) val monitor = TestProbe() val sender = TestProbe() val eventListener = TestProbe() @@ -451,7 +464,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { sender.send(paymentFSM, UpdateFulfillHtlc(ByteVector32.Zeroes, 0, defaultPaymentHash)) - val paymentOK = sender.expectMsgType[PaymentSucceeded] + val paymentOK = sender.expectMsgType[PaymentSent] val PaymentSent(_, request.finalPayload.amount, fee, request.paymentHash, paymentOK.paymentPreimage, _, _) = eventListener.expectMsgType[PaymentSent] // during the route computation the fees were treated as if they were 1msat but when sending the onion we actually put zero @@ -463,7 +476,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { test("filter errors properly") { _ => val failures = LocalFailure(RouteNotFound) :: RemoteFailure(Hop(a, b, channelUpdate_ab) :: Nil, Sphinx.DecryptedFailurePacket(a, TemporaryNodeFailure)) :: LocalFailure(AddHtlcFailed(ByteVector32.Zeroes, ByteVector32.Zeroes, ChannelUnavailable(ByteVector32.Zeroes), Local(UUID.randomUUID(), None), None, None)) :: LocalFailure(RouteNotFound) :: Nil - val filtered = PaymentLifecycle.transformForUser(failures) + val filtered = PaymentFailure.transformForUser(failures) assert(filtered == LocalFailure(RouteNotFound) :: RemoteFailure(Hop(a, b, channelUpdate_ab) :: Nil, Sphinx.DecryptedFailurePacket(a, TemporaryNodeFailure)) :: LocalFailure(ChannelUnavailable(ByteVector32.Zeroes)) :: Nil) } } diff --git a/eclair-node-gui/src/main/scala/fr/acinq/eclair/gui/FxApp.scala b/eclair-node-gui/src/main/scala/fr/acinq/eclair/gui/FxApp.scala index 163d8c14a9..4958bbc634 100644 --- a/eclair-node-gui/src/main/scala/fr/acinq/eclair/gui/FxApp.scala +++ b/eclair-node-gui/src/main/scala/fr/acinq/eclair/gui/FxApp.scala @@ -25,7 +25,6 @@ import fr.acinq.eclair.blockchain.electrum.ElectrumClient.ElectrumEvent import fr.acinq.eclair.channel.ChannelEvent import fr.acinq.eclair.gui.controllers.{MainController, NotificationsController} import fr.acinq.eclair.payment.PaymentEvent -import fr.acinq.eclair.payment.PaymentLifecycle.PaymentResult import fr.acinq.eclair.router.NetworkEvent import grizzled.slf4j.Logging import javafx.application.Preloader.ErrorNotification @@ -42,8 +41,8 @@ import scala.util.{Failure, Success, Try} /** - * Created by PM on 16/08/2016. - */ + * Created by PM on 16/08/2016. + */ class FxApp extends Application with Logging { override def init = { @@ -99,7 +98,6 @@ class FxApp extends Application with Logging { system.eventStream.subscribe(guiUpdater, classOf[ChannelEvent]) system.eventStream.subscribe(guiUpdater, classOf[NetworkEvent]) system.eventStream.subscribe(guiUpdater, classOf[PaymentEvent]) - system.eventStream.subscribe(guiUpdater, classOf[PaymentResult]) system.eventStream.subscribe(guiUpdater, classOf[ZMQEvent]) system.eventStream.subscribe(guiUpdater, classOf[ElectrumEvent]) pKit.completeWith(setup.bootstrap) @@ -137,11 +135,11 @@ class FxApp extends Application with Logging { } /** - * Initialize the notification stage and assign it to the handler class. - * - * @param owner stage owning the notification stage - * @param notifhandlers Handles the notifications - */ + * Initialize the notification stage and assign it to the handler class. + * + * @param owner stage owning the notification stage + * @param notifhandlers Handles the notifications + */ private def initNotificationStage(owner: Stage, notifhandlers: Handlers) = { // get fxml/controller val notifFXML = new FXMLLoader(getClass.getResource("/gui/main/notifications.fxml")) diff --git a/eclair-node-gui/src/main/scala/fr/acinq/eclair/gui/GUIUpdater.scala b/eclair-node-gui/src/main/scala/fr/acinq/eclair/gui/GUIUpdater.scala index 2f996aa569..a295326330 100644 --- a/eclair-node-gui/src/main/scala/fr/acinq/eclair/gui/GUIUpdater.scala +++ b/eclair-node-gui/src/main/scala/fr/acinq/eclair/gui/GUIUpdater.scala @@ -21,12 +21,11 @@ import java.time.LocalDateTime import akka.actor.{Actor, ActorLogging, ActorRef, Terminated} import fr.acinq.bitcoin.Crypto.PublicKey import fr.acinq.bitcoin._ -import fr.acinq.eclair.{CoinUtils, MilliSatoshi} +import fr.acinq.eclair.CoinUtils import fr.acinq.eclair.blockchain.bitcoind.zmq.ZMQActor.{ZMQConnected, ZMQDisconnected} import fr.acinq.eclair.blockchain.electrum.ElectrumClient.{ElectrumDisconnected, ElectrumReady} import fr.acinq.eclair.channel._ import fr.acinq.eclair.gui.controllers._ -import fr.acinq.eclair.payment.PaymentLifecycle.{LocalFailure, PaymentFailed, PaymentSucceeded, RemoteFailure} import fr.acinq.eclair.payment._ import fr.acinq.eclair.router.{NORMAL => _, _} import javafx.application.Platform @@ -35,21 +34,21 @@ import javafx.scene.layout.VBox import scala.collection.JavaConversions._ - /** - * Created by PM on 16/08/2016. - */ + * Created by PM on 16/08/2016. + */ + class GUIUpdater(mainController: MainController) extends Actor with ActorLogging { val STATE_MUTUAL_CLOSE = Set(WAIT_FOR_INIT_INTERNAL, WAIT_FOR_OPEN_CHANNEL, WAIT_FOR_ACCEPT_CHANNEL, WAIT_FOR_FUNDING_INTERNAL, WAIT_FOR_FUNDING_CREATED, WAIT_FOR_FUNDING_SIGNED, NORMAL) val STATE_FORCE_CLOSE = Set(WAIT_FOR_FUNDING_CONFIRMED, WAIT_FOR_FUNDING_LOCKED, NORMAL, SHUTDOWN, NEGOTIATING, OFFLINE, SYNCING) /** - * Needed to stop JavaFX complaining about updates from non GUI thread - */ + * Needed to stop JavaFX complaining about updates from non GUI thread + */ private def runInGuiThread(f: () => Unit): Unit = { Platform.runLater(new Runnable() { - @Override def run() = f() + @Override def run(): Unit = f() }) } @@ -95,7 +94,7 @@ class GUIUpdater(mainController: MainController) extends Actor with ActorLogging }) context.become(main(m1)) - case ShortChannelIdAssigned(channel, channelId, shortChannelId) if m.contains(channel) => + case ShortChannelIdAssigned(channel, _, shortChannelId) if m.contains(channel) => val channelPaneController = m(channel) runInGuiThread(() => channelPaneController.shortChannelId.setText(shortChannelId.toString)) @@ -103,12 +102,12 @@ class GUIUpdater(mainController: MainController) extends Actor with ActorLogging val channelPaneController = m(channel) runInGuiThread(() => channelPaneController.channelId.setText(channelId.toHex)) - case ChannelStateChanged(channel, _, remoteNodeId, _, currentState, currentData) if m.contains(channel) => + case ChannelStateChanged(channel, _, _, _, currentState, currentData) if m.contains(channel) => val channelPaneController = m(channel) runInGuiThread { () => (currentState, currentData) match { case (WAIT_FOR_FUNDING_CONFIRMED, d: HasCommitments) => channelPaneController.txId.setText(d.commitments.commitInput.outPoint.txid.toHex) - case _ => {} + case _ => } channelPaneController.close.setVisible(STATE_MUTUAL_CLOSE.contains(currentState)) channelPaneController.forceclose.setVisible(STATE_FORCE_CLOSE.contains(currentState)) @@ -137,8 +136,8 @@ class GUIUpdater(mainController: MainController) extends Actor with ActorLogging case NodesDiscovered(nodeAnnouncements) => runInGuiThread { () => - nodeAnnouncements.foreach { nodeAnnouncement => - log.debug(s"peer node discovered with node id={}", nodeAnnouncement.nodeId) + nodeAnnouncements.foreach { nodeAnnouncement => + log.debug(s"peer node discovered with node id={}", nodeAnnouncement.nodeId) if (!mainController.networkNodesMap.containsKey(nodeAnnouncement.nodeId)) { mainController.networkNodesMap.put(nodeAnnouncement.nodeId, nodeAnnouncement) m.foreach(f => if (nodeAnnouncement.nodeId.toString.equals(f._2.peerNodeId)) { @@ -149,7 +148,7 @@ class GUIUpdater(mainController: MainController) extends Actor with ActorLogging } case NodeLost(nodeId) => - log.debug(s"peer node lost with node id=${nodeId}") + log.debug(s"peer node lost with node id=$nodeId") runInGuiThread { () => mainController.networkNodesMap.remove(nodeId) } @@ -165,46 +164,42 @@ class GUIUpdater(mainController: MainController) extends Actor with ActorLogging case ChannelsDiscovered(channelsDiscovered) => runInGuiThread { () => - channelsDiscovered.foreach { case SingleChannelDiscovered(channelAnnouncement, capacity) => - log.debug(s"peer channel discovered with channel id={}", channelAnnouncement.shortChannelId) - if (!mainController.networkChannelsMap.containsKey(channelAnnouncement.shortChannelId)) { - mainController.networkChannelsMap.put(channelAnnouncement.shortChannelId, new ChannelInfo(channelAnnouncement, None, None, None, None, capacity, None, None)) - } - } + channelsDiscovered.foreach { case SingleChannelDiscovered(channelAnnouncement, capacity) => + log.debug(s"peer channel discovered with channel id={}", channelAnnouncement.shortChannelId) + if (!mainController.networkChannelsMap.containsKey(channelAnnouncement.shortChannelId)) { + mainController.networkChannelsMap.put(channelAnnouncement.shortChannelId, ChannelInfo(channelAnnouncement, None, None, None, None, capacity, None, None)) } + } + } case ChannelLost(shortChannelId) => - log.debug(s"peer channel lost with channel id=${shortChannelId}") + log.debug(s"peer channel lost with channel id=$shortChannelId") runInGuiThread { () => mainController.networkChannelsMap.remove(shortChannelId) } case ChannelUpdatesReceived(channelUpdates) => runInGuiThread { () => - channelUpdates.foreach { case channelUpdate => - log.debug(s"peer channel with id={} has been updated - flags: {} fees: {} {}", channelUpdate.shortChannelId, channelUpdate.channelFlags, channelUpdate.feeBaseMsat, channelUpdate.feeProportionalMillionths) - if (mainController.networkChannelsMap.containsKey(channelUpdate.shortChannelId)) { - val c = mainController.networkChannelsMap.get(channelUpdate.shortChannelId) - if (Announcements.isNode1(channelUpdate.channelFlags)) { - c.isNode1Enabled = Some(Announcements.isEnabled(channelUpdate.channelFlags)) - c.feeBaseMsatNode1_opt = Some(channelUpdate.feeBaseMsat.toLong) - c.feeProportionalMillionthsNode1_opt = Some(channelUpdate.feeProportionalMillionths) - } else { - c.isNode2Enabled = Some(Announcements.isEnabled(channelUpdate.channelFlags)) - c.feeBaseMsatNode2_opt = Some(channelUpdate.feeBaseMsat.toLong) - c.feeProportionalMillionthsNode2_opt = Some(channelUpdate.feeProportionalMillionths) - } - mainController.networkChannelsMap.put(channelUpdate.shortChannelId, c) - } + channelUpdates.foreach { channelUpdate => + log.debug(s"peer channel with id={} has been updated - flags: {} fees: {} {}", channelUpdate.shortChannelId, channelUpdate.channelFlags, channelUpdate.feeBaseMsat, channelUpdate.feeProportionalMillionths) + if (mainController.networkChannelsMap.containsKey(channelUpdate.shortChannelId)) { + val c = mainController.networkChannelsMap.get(channelUpdate.shortChannelId) + if (Announcements.isNode1(channelUpdate.channelFlags)) { + c.isNode1Enabled = Some(Announcements.isEnabled(channelUpdate.channelFlags)) + c.feeBaseMsatNode1_opt = Some(channelUpdate.feeBaseMsat.toLong) + c.feeProportionalMillionthsNode1_opt = Some(channelUpdate.feeProportionalMillionths) + } else { + c.isNode2Enabled = Some(Announcements.isEnabled(channelUpdate.channelFlags)) + c.feeBaseMsatNode2_opt = Some(channelUpdate.feeBaseMsat.toLong) + c.feeProportionalMillionthsNode2_opt = Some(channelUpdate.feeProportionalMillionths) } + mainController.networkChannelsMap.put(channelUpdate.shortChannelId, c) } - - case p: PaymentSucceeded => - val message = CoinUtils.formatAmountInUnit(p.amount, FxApp.getUnit, withUnit = true) - mainController.handlers.notification("Payment Sent", message, NOTIFICATION_SUCCESS) + } + } case p: PaymentFailed => - val distilledFailures = PaymentLifecycle.transformForUser(p.failures) + val distilledFailures = PaymentFailure.transformForUser(p.failures) val message = s"${distilledFailures.size} attempts:\n${ distilledFailures.map { case LocalFailure(t) => s"- (local) ${t.getMessage}" @@ -216,15 +211,17 @@ class GUIUpdater(mainController: MainController) extends Actor with ActorLogging case p: PaymentSent => log.debug(s"payment sent with h=${p.paymentHash}, amount=${p.amount}, fees=${p.feesPaid}") - runInGuiThread(() => mainController.paymentSentList.prepend(new PaymentSentRecord(p, LocalDateTime.now()))) + val message = CoinUtils.formatAmountInUnit(p.amount + p.feesPaid, FxApp.getUnit, withUnit = true) + mainController.handlers.notification("Payment Sent", message, NOTIFICATION_SUCCESS) + runInGuiThread(() => mainController.paymentSentList.prepend(PaymentSentRecord(p, LocalDateTime.now()))) case p: PaymentReceived => log.debug(s"payment received with h=${p.paymentHash}, amount=${p.amount}") - runInGuiThread(() => mainController.paymentReceivedList.prepend(new PaymentReceivedRecord(p, LocalDateTime.now()))) + runInGuiThread(() => mainController.paymentReceivedList.prepend(PaymentReceivedRecord(p, LocalDateTime.now()))) case p: PaymentRelayed => log.debug(s"payment relayed with h=${p.paymentHash}, amount=${p.amountIn}, feesEarned=${p.amountOut}") - runInGuiThread(() => mainController.paymentRelayedList.prepend(new PaymentRelayedRecord(p, LocalDateTime.now()))) + runInGuiThread(() => mainController.paymentRelayedList.prepend(PaymentRelayedRecord(p, LocalDateTime.now()))) case ZMQConnected => log.debug("ZMQ connection UP") diff --git a/eclair-node/src/main/scala/fr/acinq/eclair/api/Service.scala b/eclair-node/src/main/scala/fr/acinq/eclair/api/Service.scala index 20d9609419..e67797448f 100644 --- a/eclair-node/src/main/scala/fr/acinq/eclair/api/Service.scala +++ b/eclair-node/src/main/scala/fr/acinq/eclair/api/Service.scala @@ -36,8 +36,7 @@ import fr.acinq.bitcoin.{ByteVector32, Satoshi} import fr.acinq.eclair.api.FormParamExtractors._ import fr.acinq.eclair.api.JsonSupport.CustomTypeHints import fr.acinq.eclair.io.NodeURI -import fr.acinq.eclair.payment.PaymentLifecycle.PaymentFailed -import fr.acinq.eclair.payment.{PaymentReceived, PaymentRequest, _} +import fr.acinq.eclair.payment.{PaymentFailed, PaymentReceived, PaymentRequest, _} import fr.acinq.eclair.{CltvExpiryDelta, Eclair, MilliSatoshi} import grizzled.slf4j.Logging import scodec.bits.ByteVector diff --git a/eclair-node/src/test/scala/fr/acinq/eclair/api/ApiServiceSpec.scala b/eclair-node/src/test/scala/fr/acinq/eclair/api/ApiServiceSpec.scala index 330b821fd0..6bcccf64d8 100644 --- a/eclair-node/src/test/scala/fr/acinq/eclair/api/ApiServiceSpec.scala +++ b/eclair-node/src/test/scala/fr/acinq/eclair/api/ApiServiceSpec.scala @@ -32,7 +32,7 @@ import fr.acinq.bitcoin.Crypto.PublicKey import fr.acinq.eclair._ import fr.acinq.eclair.io.NodeURI import fr.acinq.eclair.io.Peer.PeerInfo -import fr.acinq.eclair.payment.PaymentLifecycle.PaymentFailed +import fr.acinq.eclair.payment.PaymentFailed import fr.acinq.eclair.payment._ import fr.acinq.eclair.wire.NodeAddress import org.mockito.scalatest.IdiomaticMockito @@ -357,14 +357,14 @@ class ApiServiceSpec extends FunSuite with ScalatestRouteTest with IdiomaticMock mockService.route ~> check { - val pf = PaymentFailed(fixedUUID, ByteVector32.Zeroes, failures = Seq.empty) - val expectedSerializedPf = """{"type":"payment-failed","id":"487da196-a4dc-4b1e-92b4-3e5e905e9f3f","paymentHash":"0000000000000000000000000000000000000000000000000000000000000000","failures":[]}""" + val pf = PaymentFailed(fixedUUID, ByteVector32.Zeroes, failures = Seq.empty, timestamp = 1553784963659L) + val expectedSerializedPf = """{"type":"payment-failed","id":"487da196-a4dc-4b1e-92b4-3e5e905e9f3f","paymentHash":"0000000000000000000000000000000000000000000000000000000000000000","failures":[],"timestamp":1553784963659}""" serialization.write(pf)(mockService.formatsWithTypeHint) === expectedSerializedPf system.eventStream.publish(pf) wsClient.expectMessage(expectedSerializedPf) - val ps = PaymentSent(fixedUUID, amount = 21 msat, feesPaid = 1 msat, paymentHash = ByteVector32.Zeroes, paymentPreimage = ByteVector32.One, toChannelId = ByteVector32.Zeroes, timestamp = 1553784337711L) - val expectedSerializedPs = """{"type":"payment-sent","id":"487da196-a4dc-4b1e-92b4-3e5e905e9f3f","amount":21,"feesPaid":1,"paymentHash":"0000000000000000000000000000000000000000000000000000000000000000","paymentPreimage":"0100000000000000000000000000000000000000000000000000000000000000","toChannelId":"0000000000000000000000000000000000000000000000000000000000000000","timestamp":1553784337711}""" + val ps = PaymentSent(fixedUUID, amount = 21 msat, feesPaid = 1 msat, paymentHash = ByteVector32.Zeroes, paymentPreimage = ByteVector32.One, route = Nil, timestamp = 1553784337711L) + val expectedSerializedPs = """{"type":"payment-sent","id":"487da196-a4dc-4b1e-92b4-3e5e905e9f3f","amount":21,"feesPaid":1,"paymentHash":"0000000000000000000000000000000000000000000000000000000000000000","paymentPreimage":"0100000000000000000000000000000000000000000000000000000000000000","route":[],"timestamp":1553784337711}""" serialization.write(ps)(mockService.formatsWithTypeHint) === expectedSerializedPs system.eventStream.publish(ps) wsClient.expectMessage(expectedSerializedPs) @@ -375,8 +375,8 @@ class ApiServiceSpec extends FunSuite with ScalatestRouteTest with IdiomaticMock system.eventStream.publish(prel) wsClient.expectMessage(expectedSerializedPrel) - val precv = PaymentReceived(amount = 21 msat, paymentHash = ByteVector32.Zeroes, fromChannelId = ByteVector32.One, timestamp = 1553784963659L) - val expectedSerializedPrecv = """{"type":"payment-received","amount":21,"paymentHash":"0000000000000000000000000000000000000000000000000000000000000000","fromChannelId":"0100000000000000000000000000000000000000000000000000000000000000","timestamp":1553784963659}""" + val precv = PaymentReceived(amount = 21 msat, paymentHash = ByteVector32.Zeroes, timestamp = 1553784963659L) + val expectedSerializedPrecv = """{"type":"payment-received","amount":21,"paymentHash":"0000000000000000000000000000000000000000000000000000000000000000","timestamp":1553784963659}""" serialization.write(precv)(mockService.formatsWithTypeHint) === expectedSerializedPrecv system.eventStream.publish(precv) wsClient.expectMessage(expectedSerializedPrecv) From 8c4d7f46b0c8bdc1778e4d406f3fbe21db4246a1 Mon Sep 17 00:00:00 2001 From: Bastien Teinturier Date: Mon, 16 Sep 2019 14:46:50 +0200 Subject: [PATCH 02/14] Update Audit DB for new payment events. ChannelId is removed (won't make sense for AMP). Fixed typo in ChannelErrorOccurred. --- .../fr/acinq/eclair/channel/Channel.scala | 10 +- .../acinq/eclair/channel/ChannelEvents.scala | 2 +- .../scala/fr/acinq/eclair/db/AuditDb.scala | 6 +- .../eclair/db/sqlite/SqliteAuditDb.scala | 61 ++++-- .../acinq/eclair/db/sqlite/SqliteUtils.scala | 76 +++----- .../fr/acinq/eclair/payment/Auditor.scala | 12 +- .../channel/states/e/NormalStateSpec.scala | 14 +- .../channel/states/e/OfflineStateSpec.scala | 4 +- .../acinq/eclair/db/SqliteAuditDbSpec.scala | 179 ++++++++++++++---- 9 files changed, 240 insertions(+), 124 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/channel/Channel.scala b/eclair-core/src/main/scala/fr/acinq/eclair/channel/Channel.scala index 616c3b2b39..7b726270f9 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/channel/Channel.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/channel/Channel.scala @@ -1732,7 +1732,7 @@ class Channel(val nodeParams: NodeParams, val wallet: EclairWallet, remoteNodeId case _: ChannelException => () case _ => log.error(cause, s"msg=$cmd stateData=$stateData ") } - context.system.eventStream.publish(ChannelErrorOccured(self, Helpers.getChannelId(stateData), remoteNodeId, stateData, LocalError(cause), isFatal = false)) + context.system.eventStream.publish(ChannelErrorOccurred(self, Helpers.getChannelId(stateData), remoteNodeId, stateData, LocalError(cause), isFatal = false)) stay replying Status.Failure(cause) } @@ -1785,7 +1785,7 @@ class Channel(val nodeParams: NodeParams, val wallet: EclairWallet, remoteNodeId val error = Error(d.channelId, exc.getMessage) // NB: we don't use the handleLocalError handler because it would result in the commit tx being published, which we don't want: // implementation *guarantees* that in case of BITCOIN_FUNDING_PUBLISH_FAILED, the funding tx hasn't and will never be published, so we can close the channel right away - context.system.eventStream.publish(ChannelErrorOccured(self, Helpers.getChannelId(stateData), remoteNodeId, stateData, LocalError(exc), isFatal = true)) + context.system.eventStream.publish(ChannelErrorOccurred(self, Helpers.getChannelId(stateData), remoteNodeId, stateData, LocalError(exc), isFatal = true)) goto(CLOSED) sending error } @@ -1793,7 +1793,7 @@ class Channel(val nodeParams: NodeParams, val wallet: EclairWallet, remoteNodeId log.warning(s"funding tx hasn't been confirmed in time, cancelling channel delay=$FUNDING_TIMEOUT_FUNDEE") val exc = FundingTxTimedout(d.channelId) val error = Error(d.channelId, exc.getMessage) - context.system.eventStream.publish(ChannelErrorOccured(self, Helpers.getChannelId(stateData), remoteNodeId, stateData, LocalError(exc), isFatal = true)) + context.system.eventStream.publish(ChannelErrorOccurred(self, Helpers.getChannelId(stateData), remoteNodeId, stateData, LocalError(exc), isFatal = true)) goto(CLOSED) sending error } @@ -1863,7 +1863,7 @@ class Channel(val nodeParams: NodeParams, val wallet: EclairWallet, remoteNodeId case _ => log.error(cause, s"msg=${msg.getOrElse("n/a")} stateData=$stateData ") } val error = Error(Helpers.getChannelId(d), cause.getMessage) - context.system.eventStream.publish(ChannelErrorOccured(self, Helpers.getChannelId(stateData), remoteNodeId, stateData, LocalError(cause), isFatal = true)) + context.system.eventStream.publish(ChannelErrorOccurred(self, Helpers.getChannelId(stateData), remoteNodeId, stateData, LocalError(cause), isFatal = true)) d match { case dd: HasCommitments if Closing.nothingAtStake(dd) => goto(CLOSED) @@ -1879,7 +1879,7 @@ class Channel(val nodeParams: NodeParams, val wallet: EclairWallet, remoteNodeId def handleRemoteError(e: Error, d: Data) = { // see BOLT 1: only print out data verbatim if is composed of printable ASCII characters log.error(s"peer sent error: ascii='${e.toAscii}' bin=${e.data.toHex}") - context.system.eventStream.publish(ChannelErrorOccured(self, Helpers.getChannelId(stateData), remoteNodeId, stateData, RemoteError(e), isFatal = true)) + context.system.eventStream.publish(ChannelErrorOccurred(self, Helpers.getChannelId(stateData), remoteNodeId, stateData, RemoteError(e), isFatal = true)) d match { case _: DATA_CLOSING => stay // nothing to do, there is already a spending tx published diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/channel/ChannelEvents.scala b/eclair-core/src/main/scala/fr/acinq/eclair/channel/ChannelEvents.scala index c3b3d79979..bf6441b0d8 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/channel/ChannelEvents.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/channel/ChannelEvents.scala @@ -48,7 +48,7 @@ case class ChannelSignatureSent(channel: ActorRef, commitments: Commitments) ext case class ChannelSignatureReceived(channel: ActorRef, commitments: Commitments) extends ChannelEvent -case class ChannelErrorOccured(channel: ActorRef, channelId: ByteVector32, remoteNodeId: PublicKey, data: Data, error: ChannelError, isFatal: Boolean) extends ChannelEvent +case class ChannelErrorOccurred(channel: ActorRef, channelId: ByteVector32, remoteNodeId: PublicKey, data: Data, error: ChannelError, isFatal: Boolean) extends ChannelEvent case class NetworkFeePaid(channel: ActorRef, remoteNodeId: PublicKey, channelId: ByteVector32, tx: Transaction, fee: Satoshi, txType: String) extends ChannelEvent diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/AuditDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/AuditDb.scala index 5fccd441b1..c271604353 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/AuditDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/AuditDb.scala @@ -16,8 +16,8 @@ package fr.acinq.eclair.db -import fr.acinq.bitcoin.{ByteVector32, Satoshi} import fr.acinq.bitcoin.Crypto.PublicKey +import fr.acinq.bitcoin.{ByteVector32, Satoshi} import fr.acinq.eclair.channel._ import fr.acinq.eclair.payment.{PaymentReceived, PaymentRelayed, PaymentSent} @@ -35,7 +35,7 @@ trait AuditDb { def add(networkFeePaid: NetworkFeePaid) - def add(channelErrorOccured: ChannelErrorOccured) + def add(channelErrorOccurred: ChannelErrorOccurred) def listSent(from: Long, to: Long): Seq[PaymentSent] @@ -47,7 +47,7 @@ trait AuditDb { def stats: Seq[Stats] - def close: Unit + def close(): Unit } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteAuditDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteAuditDb.scala index c03bec2a84..f0342e953a 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteAuditDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteAuditDb.scala @@ -18,17 +18,19 @@ package fr.acinq.eclair.db.sqlite import java.sql.{Connection, Statement} import java.util.UUID + import fr.acinq.bitcoin.Crypto.PublicKey import fr.acinq.bitcoin.Satoshi import fr.acinq.eclair.MilliSatoshi -import fr.acinq.eclair.channel.{AvailableBalanceChanged, Channel, ChannelErrorOccured, NetworkFeePaid} -import fr.acinq.eclair.db.{AuditDb, ChannelLifecycleEvent, NetworkFee, Stats} -import fr.acinq.eclair.payment.{PaymentReceived, PaymentRelayed, PaymentSent} +import fr.acinq.eclair.channel.{AvailableBalanceChanged, Channel, ChannelErrorOccurred, NetworkFeePaid} +import fr.acinq.eclair.db._ +import fr.acinq.eclair.payment._ import fr.acinq.eclair.wire.ChannelCodecs import grizzled.slf4j.Logging + import scala.collection.immutable.Queue import scala.compat.Platform -import concurrent.duration._ +import scala.concurrent.duration._ class SqliteAuditDb(sqlite: Connection) extends AuditDb with Logging { @@ -36,7 +38,7 @@ class SqliteAuditDb(sqlite: Connection) extends AuditDb with Logging { import ExtendedResultSet._ val DB_NAME = "audit" - val CURRENT_VERSION = 3 + val CURRENT_VERSION = 4 using(sqlite.createStatement()) { statement => @@ -49,20 +51,42 @@ class SqliteAuditDb(sqlite: Connection) extends AuditDb with Logging { statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_errors_timestamp_idx ON channel_errors(timestamp)") } + def migration34(statement: Statement) = { + statement.executeUpdate("DROP index sent_timestamp_idx") + statement.executeUpdate("ALTER TABLE sent RENAME TO _sent_old") + statement.executeUpdate("CREATE TABLE sent (id BLOB NOT NULL, amount_msat INTEGER NOT NULL, fees_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, payment_preimage BLOB NOT NULL, timestamp INTEGER NOT NULL)") + statement.executeUpdate("INSERT INTO sent (id, amount_msat, fees_msat, payment_hash, payment_preimage, timestamp) SELECT id, amount_msat, fees_msat, payment_hash, payment_preimage, timestamp FROM _sent_old") + statement.executeUpdate("DROP table _sent_old") + statement.executeUpdate("CREATE INDEX sent_timestamp_idx ON sent(timestamp)") + + statement.executeUpdate("DROP index received_timestamp_idx") + statement.executeUpdate("ALTER TABLE received RENAME TO _received_old") + statement.executeUpdate("CREATE TABLE received (amount_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, timestamp INTEGER NOT NULL)") + statement.executeUpdate("INSERT INTO received (amount_msat, payment_hash, timestamp) SELECT amount_msat, payment_hash, timestamp FROM _received_old") + statement.executeUpdate("DROP table _received_old") + statement.executeUpdate("CREATE INDEX received_timestamp_idx ON received(timestamp)") + } + getVersion(statement, DB_NAME, CURRENT_VERSION) match { case 1 => // previous version let's migrate logger.warn(s"migrating db $DB_NAME, found version=1 current=$CURRENT_VERSION") migration12(statement) migration23(statement) + migration34(statement) setVersion(statement, DB_NAME, CURRENT_VERSION) case 2 => logger.warn(s"migrating db $DB_NAME, found version=2 current=$CURRENT_VERSION") migration23(statement) + migration34(statement) + setVersion(statement, DB_NAME, CURRENT_VERSION) + case 3 => + logger.warn(s"migrating db $DB_NAME, found version=3 current=$CURRENT_VERSION") + migration34(statement) setVersion(statement, DB_NAME, CURRENT_VERSION) case CURRENT_VERSION => statement.executeUpdate("CREATE TABLE IF NOT EXISTS balance_updated (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, amount_msat INTEGER NOT NULL, capacity_sat INTEGER NOT NULL, reserve_sat INTEGER NOT NULL, timestamp INTEGER NOT NULL)") - statement.executeUpdate("CREATE TABLE IF NOT EXISTS sent (amount_msat INTEGER NOT NULL, fees_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, payment_preimage BLOB NOT NULL, to_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL, id BLOB NOT NULL)") - statement.executeUpdate("CREATE TABLE IF NOT EXISTS received (amount_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, from_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL)") + statement.executeUpdate("CREATE TABLE IF NOT EXISTS sent (id BLOB NOT NULL, amount_msat INTEGER NOT NULL, fees_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, payment_preimage BLOB NOT NULL, timestamp INTEGER NOT NULL)") + statement.executeUpdate("CREATE TABLE IF NOT EXISTS received (amount_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, timestamp INTEGER NOT NULL)") statement.executeUpdate("CREATE TABLE IF NOT EXISTS relayed (amount_in_msat INTEGER NOT NULL, amount_out_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, from_channel_id BLOB NOT NULL, to_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL)") statement.executeUpdate("CREATE TABLE IF NOT EXISTS network_fees (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, tx_id BLOB NOT NULL, fee_sat INTEGER NOT NULL, tx_type TEXT NOT NULL, timestamp INTEGER NOT NULL)") statement.executeUpdate("CREATE TABLE IF NOT EXISTS channel_events (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, capacity_sat INTEGER NOT NULL, is_funder BOOLEAN NOT NULL, is_private BOOLEAN NOT NULL, event STRING NOT NULL, timestamp INTEGER NOT NULL)") @@ -104,24 +128,22 @@ class SqliteAuditDb(sqlite: Connection) extends AuditDb with Logging { } override def add(e: PaymentSent): Unit = - using(sqlite.prepareStatement("INSERT INTO sent VALUES (?, ?, ?, ?, ?, ?, ?)")) { statement => - statement.setLong(1, e.amount.toLong) - statement.setLong(2, e.feesPaid.toLong) - statement.setBytes(3, e.paymentHash.toArray) - statement.setBytes(4, e.paymentPreimage.toArray) - statement.setBytes(5, e.toChannelId.toArray) + using(sqlite.prepareStatement("INSERT INTO sent VALUES (?, ?, ?, ?, ?, ?)")) { statement => + statement.setBytes(1, e.id.toString.getBytes) + statement.setLong(2, e.amount.toLong) + statement.setLong(3, e.feesPaid.toLong) + statement.setBytes(4, e.paymentHash.toArray) + statement.setBytes(5, e.paymentPreimage.toArray) statement.setLong(6, e.timestamp) - statement.setBytes(7, e.id.toString.getBytes) statement.executeUpdate() } override def add(e: PaymentReceived): Unit = - using(sqlite.prepareStatement("INSERT INTO received VALUES (?, ?, ?, ?)")) { statement => + using(sqlite.prepareStatement("INSERT INTO received VALUES (?, ?, ?)")) { statement => statement.setLong(1, e.amount.toLong) statement.setBytes(2, e.paymentHash.toArray) - statement.setBytes(3, e.fromChannelId.toArray) - statement.setLong(4, e.timestamp) + statement.setLong(3, e.timestamp) statement.executeUpdate() } @@ -147,7 +169,7 @@ class SqliteAuditDb(sqlite: Connection) extends AuditDb with Logging { statement.executeUpdate() } - override def add(e: ChannelErrorOccured): Unit = + override def add(e: ChannelErrorOccurred): Unit = using(sqlite.prepareStatement("INSERT INTO channel_errors VALUES (?, ?, ?, ?, ?, ?)")) { statement => val (errorName, errorMessage) = e.error match { case Channel.LocalError(t) => (t.getClass.getSimpleName, t.getMessage) @@ -175,7 +197,7 @@ class SqliteAuditDb(sqlite: Connection) extends AuditDb with Logging { feesPaid = MilliSatoshi(rs.getLong("fees_msat")), paymentHash = rs.getByteVector32("payment_hash"), paymentPreimage = rs.getByteVector32("payment_preimage"), - toChannelId = rs.getByteVector32("to_channel_id"), + route = Nil, timestamp = rs.getLong("timestamp")) } q @@ -191,7 +213,6 @@ class SqliteAuditDb(sqlite: Connection) extends AuditDb with Logging { q = q :+ PaymentReceived( amount = MilliSatoshi(rs.getLong("amount_msat")), paymentHash = rs.getByteVector32("payment_hash"), - fromChannelId = rs.getByteVector32("from_channel_id"), timestamp = rs.getLong("timestamp")) } q diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteUtils.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteUtils.scala index 35af8fc8dd..edef34db43 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteUtils.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteUtils.scala @@ -27,11 +27,8 @@ import scala.collection.immutable.Queue object SqliteUtils { /** - * Manages closing of statement - * - * @param statement - * @param block - */ + * Manages closing of statement + */ def using[T <: Statement, U](statement: T, disableAutoCommit: Boolean = false)(block: T => U): U = { try { if (disableAutoCommit) statement.getConnection.setAutoCommit(false) @@ -43,15 +40,10 @@ object SqliteUtils { } /** - * Several logical databases (channels, network, peers) may be stored in the same physical sqlite database. - * We keep track of their respective version using a dedicated table. The version entry will be created if - * there is none but will never be updated here (use setVersion to do that). - * - * @param statement - * @param db_name - * @param currentVersion - * @return - */ + * Several logical databases (channels, network, peers) may be stored in the same physical sqlite database. + * We keep track of their respective version using a dedicated table. The version entry will be created if + * there is none but will never be updated here (use setVersion to do that). + */ def getVersion(statement: Statement, db_name: String, currentVersion: Int): Int = { statement.executeUpdate("CREATE TABLE IF NOT EXISTS versions (db_name TEXT NOT NULL PRIMARY KEY, version INTEGER NOT NULL)") // if there was no version for the current db, then insert the current version @@ -62,12 +54,8 @@ object SqliteUtils { } /** - * Updates the version for a particular logical database, it will overwrite the previous version. - * @param statement - * @param db_name - * @param newVersion - * @return - */ + * Updates the version for a particular logical database, it will overwrite the previous version. + */ def setVersion(statement: Statement, db_name: String, newVersion: Int) = { statement.executeUpdate("CREATE TABLE IF NOT EXISTS versions (db_name TEXT NOT NULL PRIMARY KEY, version INTEGER NOT NULL)") // overwrite the existing version @@ -75,15 +63,10 @@ object SqliteUtils { } /** - * This helper assumes that there is a "data" column available, decodable with the provided codec - * - * TODO: we should use an scala.Iterator instead - * - * @param rs - * @param codec - * @tparam T - * @return - */ + * This helper assumes that there is a "data" column available, decodable with the provided codec + * + * TODO: we should use an scala.Iterator instead + */ def codecSequence[T](rs: ResultSet, codec: Codec[T]): Seq[T] = { var q: Queue[T] = Queue() while (rs.next()) { @@ -93,27 +76,22 @@ object SqliteUtils { } /** - * This helper retrieves the value from a nullable integer column and interprets it as an option. This is needed - * because `rs.getLong` would return `0` for a null value. - * It is used on Android only - * - * @param label - * @return - */ - def getNullableLong(rs: ResultSet, label: String) : Option[Long] = { + * This helper retrieves the value from a nullable integer column and interprets it as an option. This is needed + * because `rs.getLong` would return `0` for a null value. + * It is used on Android only + */ + def getNullableLong(rs: ResultSet, label: String): Option[Long] = { val result = rs.getLong(label) if (rs.wasNull()) None else Some(result) } /** - * Obtain an exclusive lock on a sqlite database. This is useful when we want to make sure that only one process - * accesses the database file (see https://www.sqlite.org/pragma.html). - * - * The lock will be kept until the database is closed, or if the locking mode is explicitly reset. - * - * @param sqlite - */ - def obtainExclusiveLock(sqlite: Connection){ + * Obtain an exclusive lock on a sqlite database. This is useful when we want to make sure that only one process + * accesses the database file (see https://www.sqlite.org/pragma.html). + * + * The lock will be kept until the database is closed, or if the locking mode is explicitly reset. + */ + def obtainExclusiveLock(sqlite: Connection) { val statement = sqlite.createStatement() statement.execute("PRAGMA locking_mode = EXCLUSIVE") // we have to make a write to actually obtain the lock @@ -127,15 +105,21 @@ object SqliteUtils { def getByteVector(columnLabel: String): ByteVector = ByteVector(rs.getBytes(columnLabel)) + def getByteVectorNullable(columnLabel: String): ByteVector = { + val result = rs.getBytes(columnLabel) + if (rs.wasNull()) ByteVector.empty else ByteVector(result) + } + def getByteVector32(columnLabel: String): ByteVector32 = ByteVector32(ByteVector(rs.getBytes(columnLabel))) def getByteVector32Nullable(columnLabel: String): Option[ByteVector32] = { val bytes = rs.getBytes(columnLabel) - if(rs.wasNull()) None else Some(ByteVector32(ByteVector(bytes))) + if (rs.wasNull()) None else Some(ByteVector32(ByteVector(bytes))) } } object ExtendedResultSet { implicit def conv(rs: ResultSet): ExtendedResultSet = ExtendedResultSet(rs) } + } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Auditor.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Auditor.scala index 6a8eaa82bd..1036925221 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Auditor.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Auditor.scala @@ -35,7 +35,7 @@ class Auditor(nodeParams: NodeParams) extends Actor with ActorLogging { context.system.eventStream.subscribe(self, classOf[PaymentEvent]) context.system.eventStream.subscribe(self, classOf[NetworkFeePaid]) context.system.eventStream.subscribe(self, classOf[AvailableBalanceChanged]) - context.system.eventStream.subscribe(self, classOf[ChannelErrorOccured]) + context.system.eventStream.subscribe(self, classOf[ChannelErrorOccurred]) context.system.eventStream.subscribe(self, classOf[ChannelStateChanged]) context.system.eventStream.subscribe(self, classOf[ChannelClosed]) @@ -74,7 +74,7 @@ class Auditor(nodeParams: NodeParams) extends Actor with ActorLogging { case e: AvailableBalanceChanged => balanceEventThrottler ! e - case e: ChannelErrorOccured => + case e: ChannelErrorOccurred => val metric = Kamon.counter("channels.errors") e.error match { case LocalError(_) if e.isFatal => metric.withTag("origin", "local").withTag("fatal", "yes").increment() @@ -113,8 +113,8 @@ class Auditor(nodeParams: NodeParams) extends Actor with ActorLogging { } /** - * We don't want to log every tiny payment, and we don't want to log probing events. - */ + * We don't want to log every tiny payment, and we don't want to log probing events. + */ class BalanceEventThrottler(db: AuditDb) extends Actor with ActorLogging { import ExecutionContext.Implicits.global @@ -135,11 +135,11 @@ class BalanceEventThrottler(db: AuditDb) extends Actor with ActorLogging { // we delay the processing of the event in order to smooth variations log.info(s"will log balance event in $delay for channelId=${e.channelId}") context.system.scheduler.scheduleOnce(delay, self, ProcessEvent(e.channelId)) - context.become(run(pending + (e.channelId -> (BalanceUpdate(e, e))))) + context.become(run(pending + (e.channelId -> BalanceUpdate(e, e)))) case Some(BalanceUpdate(first, _)) => // we already are about to log a balance event, let's update the data we have log.info(s"updating balance data for channelId=${e.channelId}") - context.become(run(pending + (e.channelId -> (BalanceUpdate(first, e))))) + context.become(run(pending + (e.channelId -> BalanceUpdate(first, e)))) } case ProcessEvent(channelId) => diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/NormalStateSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/NormalStateSpec.scala index 10182a87db..581ffe4b14 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/NormalStateSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/NormalStateSpec.scala @@ -29,7 +29,7 @@ import fr.acinq.eclair.blockchain._ import fr.acinq.eclair.blockchain.fee.FeeratesPerKw import fr.acinq.eclair.channel.Channel._ import fr.acinq.eclair.channel.states.StateTestsHelperMethods -import fr.acinq.eclair.channel.{ChannelErrorOccured, _} +import fr.acinq.eclair.channel.{ChannelErrorOccurred, _} import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.io.Peer import fr.acinq.eclair.payment._ @@ -1710,7 +1710,7 @@ class NormalStateSpec extends TestkitBaseClass with StateTestsHelperMethods { crossSign(alice, bob, alice2bob, bob2alice) val listener = TestProbe() - system.eventStream.subscribe(listener.ref, classOf[ChannelErrorOccured]) + system.eventStream.subscribe(listener.ref, classOf[ChannelErrorOccurred]) // actual test begins: // * Bob receives the HTLC pre-image and wants to fulfill @@ -1726,7 +1726,7 @@ class NormalStateSpec extends TestkitBaseClass with StateTestsHelperMethods { bob2alice.expectMsgType[UpdateFulfillHtlc] sender.send(bob, CurrentBlockCount((htlc.cltvExpiry - Bob.nodeParams.fulfillSafetyBeforeTimeoutBlocks).toLong)) - val ChannelErrorOccured(_, _, _, _, LocalError(err), isFatal) = listener.expectMsgType[ChannelErrorOccured] + val ChannelErrorOccurred(_, _, _, _, LocalError(err), isFatal) = listener.expectMsgType[ChannelErrorOccurred] assert(isFatal) assert(err.isInstanceOf[HtlcWillTimeoutUpstream]) @@ -1745,7 +1745,7 @@ class NormalStateSpec extends TestkitBaseClass with StateTestsHelperMethods { crossSign(alice, bob, alice2bob, bob2alice) val listener = TestProbe() - system.eventStream.subscribe(listener.ref, classOf[ChannelErrorOccured]) + system.eventStream.subscribe(listener.ref, classOf[ChannelErrorOccurred]) // actual test begins: // * Bob receives the HTLC pre-image and wants to fulfill but doesn't sign @@ -1761,7 +1761,7 @@ class NormalStateSpec extends TestkitBaseClass with StateTestsHelperMethods { bob2alice.expectMsgType[UpdateFulfillHtlc] sender.send(bob, CurrentBlockCount((htlc.cltvExpiry - Bob.nodeParams.fulfillSafetyBeforeTimeoutBlocks).toLong)) - val ChannelErrorOccured(_, _, _, _, LocalError(err), isFatal) = listener.expectMsgType[ChannelErrorOccured] + val ChannelErrorOccurred(_, _, _, _, LocalError(err), isFatal) = listener.expectMsgType[ChannelErrorOccurred] assert(isFatal) assert(err.isInstanceOf[HtlcWillTimeoutUpstream]) @@ -1780,7 +1780,7 @@ class NormalStateSpec extends TestkitBaseClass with StateTestsHelperMethods { crossSign(alice, bob, alice2bob, bob2alice) val listener = TestProbe() - system.eventStream.subscribe(listener.ref, classOf[ChannelErrorOccured]) + system.eventStream.subscribe(listener.ref, classOf[ChannelErrorOccurred]) // actual test begins: // * Bob receives the HTLC pre-image and wants to fulfill @@ -1801,7 +1801,7 @@ class NormalStateSpec extends TestkitBaseClass with StateTestsHelperMethods { alice2bob.forward(bob) sender.send(bob, CurrentBlockCount((htlc.cltvExpiry - Bob.nodeParams.fulfillSafetyBeforeTimeoutBlocks).toLong)) - val ChannelErrorOccured(_, _, _, _, LocalError(err), isFatal) = listener.expectMsgType[ChannelErrorOccured] + val ChannelErrorOccurred(_, _, _, _, LocalError(err), isFatal) = listener.expectMsgType[ChannelErrorOccurred] assert(isFatal) assert(err.isInstanceOf[HtlcWillTimeoutUpstream]) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/OfflineStateSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/OfflineStateSpec.scala index e73e50140e..ca9fe9e34c 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/OfflineStateSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/OfflineStateSpec.scala @@ -405,7 +405,7 @@ class OfflineStateSpec extends TestkitBaseClass with StateTestsHelperMethods { crossSign(alice, bob, alice2bob, bob2alice) val listener = TestProbe() - system.eventStream.subscribe(listener.ref, classOf[ChannelErrorOccured]) + system.eventStream.subscribe(listener.ref, classOf[ChannelErrorOccurred]) val initialState = bob.stateData.asInstanceOf[DATA_NORMAL] val initialCommitTx = initialState.commitments.localCommit.publishableTxs.commitTx.tx @@ -421,7 +421,7 @@ class OfflineStateSpec extends TestkitBaseClass with StateTestsHelperMethods { sender.send(commandBuffer, CommandSend(htlc.channelId, htlc.id, CMD_FULFILL_HTLC(htlc.id, r, commit = true))) sender.send(bob, CurrentBlockCount((htlc.cltvExpiry - bob.underlyingActor.nodeParams.fulfillSafetyBeforeTimeoutBlocks).toLong)) - val ChannelErrorOccured(_, _, _, _, LocalError(err), isFatal) = listener.expectMsgType[ChannelErrorOccured] + val ChannelErrorOccurred(_, _, _, _, LocalError(err), isFatal) = listener.expectMsgType[ChannelErrorOccurred] assert(isFatal) assert(err.isInstanceOf[HtlcWillTimeoutUpstream]) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteAuditDbSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteAuditDbSpec.scala index 49dd366358..898783e0b1 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteAuditDbSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteAuditDbSpec.scala @@ -21,11 +21,12 @@ import java.util.UUID import fr.acinq.bitcoin.Transaction import fr.acinq.eclair._ import fr.acinq.eclair.channel.Channel.{LocalError, RemoteError} -import fr.acinq.eclair.channel.{AvailableBalanceChanged, ChannelErrorOccured, NetworkFeePaid} +import fr.acinq.eclair.channel.{AvailableBalanceChanged, ChannelErrorOccurred, NetworkFeePaid} import fr.acinq.eclair.db.sqlite.SqliteAuditDb import fr.acinq.eclair.db.sqlite.SqliteUtils.{getVersion, using} -import fr.acinq.eclair.payment.{PaymentReceived, PaymentRelayed, PaymentSent} -import fr.acinq.eclair.wire.{ChannelCodecs, ChannelCodecsSpec} +import fr.acinq.eclair.payment._ +import fr.acinq.eclair.router.Hop +import fr.acinq.eclair.wire.{ChannelCodecs, ChannelCodecsSpec, ChannelUpdate} import org.scalatest.FunSuite import scala.compat.Platform @@ -34,6 +35,8 @@ import scala.concurrent.duration._ class SqliteAuditDbSpec extends FunSuite { + import SqliteAuditDbSpec._ + test("init sqlite 2 times in a row") { val sqlite = TestConstants.sqliteInMemory() val db1 = new SqliteAuditDb(sqlite) @@ -44,16 +47,16 @@ class SqliteAuditDbSpec extends FunSuite { val sqlite = TestConstants.sqliteInMemory() val db = new SqliteAuditDb(sqlite) - val e1 = PaymentSent(ChannelCodecs.UNKNOWN_UUID, 42000 msat, 1000 msat, randomBytes32, randomBytes32, randomBytes32) - val e2 = PaymentReceived(42000 msat, randomBytes32, randomBytes32) + val e1 = PaymentSent(ChannelCodecs.UNKNOWN_UUID, 42000 msat, 1000 msat, randomBytes32, randomBytes32, Seq(Hop(carol, dave, channelUpdate2))) + val e2 = PaymentReceived(42000 msat, randomBytes32) val e3 = PaymentRelayed(42000 msat, 1000 msat, randomBytes32, randomBytes32, randomBytes32) val e4 = NetworkFeePaid(null, randomKey.publicKey, randomBytes32, Transaction(0, Seq.empty, Seq.empty, 0), 42 sat, "mutual") - val e5 = PaymentSent(ChannelCodecs.UNKNOWN_UUID, 42000 msat, 1000 msat, randomBytes32, randomBytes32, randomBytes32, timestamp = 0) - val e6 = PaymentSent(ChannelCodecs.UNKNOWN_UUID, 42000 msat, 1000 msat, randomBytes32, randomBytes32, randomBytes32, timestamp = (Platform.currentTime.milliseconds + 10.minutes).toMillis) + val e5 = PaymentSent(ChannelCodecs.UNKNOWN_UUID, 42000 msat, 1000 msat, randomBytes32, randomBytes32, Seq(Hop(alice, bob, channelUpdate1), Hop(bob, carol, channelUpdate2)), timestamp = 0) + val e6 = PaymentSent(ChannelCodecs.UNKNOWN_UUID, 42000 msat, 1000 msat, randomBytes32, randomBytes32, Nil, timestamp = (Platform.currentTime.milliseconds + 10.minutes).toMillis) val e7 = AvailableBalanceChanged(null, randomBytes32, ShortChannelId(500000, 42, 1), 456123000 msat, ChannelCodecsSpec.commitments) - val e8 = ChannelLifecycleEvent(randomBytes32, randomKey.publicKey, 456123000 sat, true, false, "mutual") - val e9 = ChannelErrorOccured(null, randomBytes32, randomKey.publicKey, null, LocalError(new RuntimeException("oops")), true) - val e10 = ChannelErrorOccured(null, randomBytes32, randomKey.publicKey, null, RemoteError(wire.Error(randomBytes32, "remote oops")), true) + val e8 = ChannelLifecycleEvent(randomBytes32, randomKey.publicKey, 456123000 sat, isFunder = true, isPrivate = false, "mutual") + val e9 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, LocalError(new RuntimeException("oops")), isFatal = true) + val e10 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, RemoteError(wire.Error(randomBytes32, "remote oops")), isFatal = true) db.add(e1) db.add(e2) @@ -66,8 +69,8 @@ class SqliteAuditDbSpec extends FunSuite { db.add(e9) db.add(e10) - assert(db.listSent(from = 0L, to = (Platform.currentTime.milliseconds + 15.minute).toSeconds).toSet === Set(e1, e5, e6)) - assert(db.listSent(from = 100000L, to = (Platform.currentTime.milliseconds + 1.minute).toSeconds).toList === List(e1)) + assert(db.listSent(from = 0L, to = (Platform.currentTime.milliseconds + 15.minute).toSeconds).toSet === Set(e1.copy(route = Nil), e5.copy(route = Nil), e6)) + assert(db.listSent(from = 100000L, to = (Platform.currentTime.milliseconds + 1.minute).toSeconds).toList === List(e1.copy(route = Nil))) assert(db.listReceived(from = 0L, to = (Platform.currentTime.milliseconds + 1.minute).toSeconds).toList === List(e2)) assert(db.listRelayed(from = 0L, to = (Platform.currentTime.milliseconds + 1.minute).toSeconds).toList === List(e3)) assert(db.listNetworkFees(from = 0L, to = (Platform.currentTime.milliseconds + 1.minute).toSeconds).size === 1) @@ -103,7 +106,7 @@ class SqliteAuditDbSpec extends FunSuite { )) } - test("handle migration version 1 -> 3") { + test("handle migration version 1 -> 4") { val connection = TestConstants.sqliteInMemory() @@ -126,51 +129,69 @@ class SqliteAuditDbSpec extends FunSuite { } using(connection.createStatement()) { statement => - assert(getVersion(statement, "audit", 3) == 1) // we expect version 1 + assert(getVersion(statement, "audit", 4) == 1) // we expect version 1 } - val ps = PaymentSent(UUID.randomUUID(), 42000 msat, 1000 msat, randomBytes32, randomBytes32, randomBytes32) - val ps1 = PaymentSent(UUID.randomUUID(), 42001 msat, 1001 msat, randomBytes32, randomBytes32, randomBytes32) - val ps2 = PaymentSent(UUID.randomUUID(), 42002 msat, 1002 msat, randomBytes32, randomBytes32, randomBytes32) - val e1 = ChannelErrorOccured(null, randomBytes32, randomKey.publicKey, null, LocalError(new RuntimeException("oops")), true) - val e2 = ChannelErrorOccured(null, randomBytes32, randomKey.publicKey, null, RemoteError(wire.Error(randomBytes32, "remote oops")), true) - - // add a row (no ID on sent) + val ps = PaymentSent(UUID.randomUUID(), 42000 msat, 1000 msat, randomBytes32, randomBytes32, Seq(Hop(alice, bob, channelUpdate1))) + val ps1 = PaymentSent(UUID.randomUUID(), 42001 msat, 1001 msat, randomBytes32, randomBytes32, Seq(Hop(alice, bob, channelUpdate1), Hop(bob, carol, channelUpdate2))) + val ps2 = PaymentSent(UUID.randomUUID(), 42002 msat, 1002 msat, randomBytes32, randomBytes32, Nil) + val pr = PaymentReceived(561 msat, randomBytes32) + val pr1 = PaymentReceived(1105 msat, randomBytes32) + val e1 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, LocalError(new RuntimeException("oops")), isFatal = true) + val e2 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, RemoteError(wire.Error(randomBytes32, "remote oops")), isFatal = true) + + // Changes to the 'sent' table between versions 1 and 4: + // - the 'id' column was added + // - the 'toChannelId' column was removed using(connection.prepareStatement("INSERT INTO sent VALUES (?, ?, ?, ?, ?, ?)")) { statement => statement.setLong(1, ps.amount.toLong) statement.setLong(2, ps.feesPaid.toLong) statement.setBytes(3, ps.paymentHash.toArray) statement.setBytes(4, ps.paymentPreimage.toArray) - statement.setBytes(5, ps.toChannelId.toArray) + statement.setBytes(5, randomBytes32.toArray) // toChannelId statement.setLong(6, ps.timestamp) statement.executeUpdate() } + // Changes to the 'received' table between versions 1 and 4: + // - the 'fromChannelId' column was removed + using(connection.prepareStatement("INSERT INTO received VALUES (?, ?, ?, ?)")) { statement => + statement.setLong(1, pr.amount.toLong) + statement.setBytes(2, pr.paymentHash.toArray) + statement.setBytes(3, randomBytes32.toArray) // fromChannelId + statement.setLong(4, pr.timestamp) + statement.executeUpdate() + } + val migratedDb = new SqliteAuditDb(connection) using(connection.createStatement()) { statement => - assert(getVersion(statement, "audit", 3) == 3) // version changed from 1 -> 3 + assert(getVersion(statement, "audit", 4) == 4) // version changed from 1 -> 4 } - // existing rows will use 00000000-0000-0000-0000-000000000000 as default - assert(migratedDb.listSent(0, (Platform.currentTime.milliseconds + 1.minute).toSeconds) == Seq(ps.copy(id = ChannelCodecs.UNKNOWN_UUID))) + // existing rows in the 'sent' table will use id=00000000-0000-0000-0000-000000000000 as default + assert(migratedDb.listSent(0, (Platform.currentTime.milliseconds + 1.minute).toSeconds) === Seq(ps.copy(id = ChannelCodecs.UNKNOWN_UUID, route = Nil))) + // existing rows in the 'received' table will not contain a fromChannelId anymore + assert(migratedDb.listReceived(0, (Platform.currentTime.milliseconds + 1.minute).toSeconds) === Seq(pr)) val postMigrationDb = new SqliteAuditDb(connection) using(connection.createStatement()) { statement => - assert(getVersion(statement, "audit", 3) == 3) // version 3 + assert(getVersion(statement, "audit", 4) == 4) // version 4 } postMigrationDb.add(ps1) postMigrationDb.add(ps2) postMigrationDb.add(e1) postMigrationDb.add(e2) + postMigrationDb.add(pr1) - // the old record will have the UNKNOWN_UUID but the new ones will have their actual id - assert(postMigrationDb.listSent(0, (Platform.currentTime.milliseconds + 1.minute).toSeconds) == Seq(ps.copy(id = ChannelCodecs.UNKNOWN_UUID), ps1, ps2)) + // the old 'sent' record will have the UNKNOWN_UUID and an empty route but the new ones will have their actual id + assert(postMigrationDb.listSent(0, (Platform.currentTime.milliseconds + 1.minute).toSeconds) === Seq(ps.copy(id = ChannelCodecs.UNKNOWN_UUID, route = Nil), ps1.copy(route = Nil), ps2.copy(route = Nil))) + assert(postMigrationDb.listReceived(0, (Platform.currentTime.milliseconds + 1.minute).toSeconds) === Seq(pr, pr1)) } - test("handle migration version 2 -> 3") { + test("handle migration version 2 -> 4") { val connection = TestConstants.sqliteInMemory() @@ -193,16 +214,16 @@ class SqliteAuditDbSpec extends FunSuite { } using(connection.createStatement()) { statement => - assert(getVersion(statement, "audit", 3) == 2) // version 2 is deployed now + assert(getVersion(statement, "audit", 4) == 2) // version 2 is deployed now } - val e1 = ChannelErrorOccured(null, randomBytes32, randomKey.publicKey, null, LocalError(new RuntimeException("oops")), true) - val e2 = ChannelErrorOccured(null, randomBytes32, randomKey.publicKey, null, RemoteError(wire.Error(randomBytes32, "remote oops")), true) + val e1 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, LocalError(new RuntimeException("oops")), isFatal = true) + val e2 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, RemoteError(wire.Error(randomBytes32, "remote oops")), isFatal = true) val migratedDb = new SqliteAuditDb(connection) using(connection.createStatement()) { statement => - assert(getVersion(statement, "audit", 3) == 3) // version changed from 2 -> 3 + assert(getVersion(statement, "audit", 4) == 4) // version changed from 2 -> 4 } migratedDb.add(e1) @@ -210,10 +231,100 @@ class SqliteAuditDbSpec extends FunSuite { val postMigrationDb = new SqliteAuditDb(connection) using(connection.createStatement()) { statement => - assert(getVersion(statement, "audit", 3) == 3) // version 3 + assert(getVersion(statement, "audit", 4) == 4) // version 4 } postMigrationDb.add(e2) } + test("handle migration version 3 -> 4") { + val connection = TestConstants.sqliteInMemory() + + // simulate existing previous version db + using(connection.createStatement()) { statement => + getVersion(statement, "audit", 3) + statement.executeUpdate("CREATE TABLE IF NOT EXISTS balance_updated (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, amount_msat INTEGER NOT NULL, capacity_sat INTEGER NOT NULL, reserve_sat INTEGER NOT NULL, timestamp INTEGER NOT NULL)") + statement.executeUpdate("CREATE TABLE IF NOT EXISTS sent (amount_msat INTEGER NOT NULL, fees_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, payment_preimage BLOB NOT NULL, to_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL, id BLOB NOT NULL)") + statement.executeUpdate("CREATE TABLE IF NOT EXISTS received (amount_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, from_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL)") + statement.executeUpdate("CREATE TABLE IF NOT EXISTS relayed (amount_in_msat INTEGER NOT NULL, amount_out_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, from_channel_id BLOB NOT NULL, to_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL)") + statement.executeUpdate("CREATE TABLE IF NOT EXISTS network_fees (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, tx_id BLOB NOT NULL, fee_sat INTEGER NOT NULL, tx_type TEXT NOT NULL, timestamp INTEGER NOT NULL)") + statement.executeUpdate("CREATE TABLE IF NOT EXISTS channel_events (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, capacity_sat INTEGER NOT NULL, is_funder BOOLEAN NOT NULL, is_private BOOLEAN NOT NULL, event STRING NOT NULL, timestamp INTEGER NOT NULL)") + statement.executeUpdate("CREATE TABLE IF NOT EXISTS channel_errors (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, error_name STRING NOT NULL, error_message STRING NOT NULL, is_fatal INTEGER NOT NULL, timestamp INTEGER NOT NULL)") + + statement.executeUpdate("CREATE INDEX IF NOT EXISTS balance_updated_idx ON balance_updated(timestamp)") + statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_timestamp_idx ON sent(timestamp)") + statement.executeUpdate("CREATE INDEX IF NOT EXISTS received_timestamp_idx ON received(timestamp)") + statement.executeUpdate("CREATE INDEX IF NOT EXISTS relayed_timestamp_idx ON relayed(timestamp)") + statement.executeUpdate("CREATE INDEX IF NOT EXISTS network_fees_timestamp_idx ON network_fees(timestamp)") + statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_events_timestamp_idx ON channel_events(timestamp)") + statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_errors_timestamp_idx ON channel_errors(timestamp)") + } + + using(connection.createStatement()) { statement => + assert(getVersion(statement, "audit", 4) == 3) // version 3 is deployed now + } + + val ps = PaymentSent(UUID.randomUUID(), 42000 msat, 1000 msat, randomBytes32, randomBytes32, Seq(Hop(alice, bob, channelUpdate1))) + val ps1 = PaymentSent(UUID.randomUUID(), 42001 msat, 1001 msat, randomBytes32, randomBytes32, Seq(Hop(alice, bob, channelUpdate1), Hop(bob, carol, channelUpdate2))) + val ps2 = PaymentSent(UUID.randomUUID(), 42002 msat, 1002 msat, randomBytes32, randomBytes32, Nil) + val pr = PaymentReceived(561 msat, randomBytes32) + val pr1 = PaymentReceived(1105 msat, randomBytes32) + + // Changes to the 'sent' table between versions 3 and 4: + // - the 'toChannelId' column was removed + using(connection.prepareStatement("INSERT INTO sent VALUES (?, ?, ?, ?, ?, ?, ?)")) { statement => + statement.setLong(1, ps.amount.toLong) + statement.setLong(2, ps.feesPaid.toLong) + statement.setBytes(3, ps.paymentHash.toArray) + statement.setBytes(4, ps.paymentPreimage.toArray) + statement.setBytes(5, randomBytes32.toArray) // toChannelId + statement.setLong(6, ps.timestamp) + statement.setBytes(7, ps.id.toString.getBytes) + statement.executeUpdate() + } + + // Changes to the 'received' table between versions 3 and 4: + // - the 'fromChannelId' column was removed + using(connection.prepareStatement("INSERT INTO received VALUES (?, ?, ?, ?)")) { statement => + statement.setLong(1, pr.amount.toLong) + statement.setBytes(2, pr.paymentHash.toArray) + statement.setBytes(3, randomBytes32.toArray) // fromChannelId + statement.setLong(4, pr.timestamp) + statement.executeUpdate() + } + + val migratedDb = new SqliteAuditDb(connection) + + using(connection.createStatement()) { statement => + assert(getVersion(statement, "audit", 4) == 4) // version changed from 3 -> 4 + } + + // existing rows in the 'sent' table will use route=NULL as default + assert(migratedDb.listSent(0, (Platform.currentTime.milliseconds + 1.minute).toSeconds) === Seq(ps.copy(route = Nil))) + // existing rows in the 'received' table will not contain a fromChannelId anymore + assert(migratedDb.listReceived(0, (Platform.currentTime.milliseconds + 1.minute).toSeconds) === Seq(pr)) + + val postMigrationDb = new SqliteAuditDb(connection) + + using(connection.createStatement()) { statement => + assert(getVersion(statement, "audit", 4) == 4) // version 4 + } + + postMigrationDb.add(ps1) + postMigrationDb.add(ps2) + postMigrationDb.add(pr1) + + // the old 'sent' record will have the UNKNOWN_UUID and an empty route but the new ones will have their actual id + assert(postMigrationDb.listSent(0, (Platform.currentTime.milliseconds + 1.minute).toSeconds) === Seq(ps.copy(route = Nil), ps1.copy(route = Nil), ps2.copy(route = Nil))) + assert(postMigrationDb.listReceived(0, (Platform.currentTime.milliseconds + 1.minute).toSeconds) === Seq(pr, pr1)) + } + } + +object SqliteAuditDbSpec { + + val (alice, bob, carol, dave) = (randomKey.publicKey, randomKey.publicKey, randomKey.publicKey, randomKey.publicKey) + val channelUpdate1 = ChannelUpdate(randomBytes64, randomBytes32, ShortChannelId(561), 0, 0, 0, CltvExpiryDelta(144), 100 msat, 10 msat, 1000, None) + val channelUpdate2 = ChannelUpdate(randomBytes64, randomBytes32, ShortChannelId(1105), 0, 0, 0, CltvExpiryDelta(9), 1000 msat, 15 msat, 100, None) + +} \ No newline at end of file From 7f728e1f7dbab11bef62beb54b917db4c2477504 Mon Sep 17 00:00:00 2001 From: Bastien Teinturier Date: Tue, 17 Sep 2019 14:30:46 +0200 Subject: [PATCH 03/14] Add more fields to the payments DB: * bolt 11 invoice * external id * parent id (AMP) * target node id * fees * route * failures --- .../main/scala/fr/acinq/eclair/Eclair.scala | 16 +- .../scala/fr/acinq/eclair/db/PaymentsDb.scala | 119 ++++++--- .../eclair/db/sqlite/SqlitePaymentsDb.scala | 212 ++++++++++------ .../acinq/eclair/db/sqlite/SqliteUtils.scala | 6 + .../eclair/payment/PaymentInitiator.scala | 4 +- .../eclair/payment/PaymentLifecycle.scala | 17 +- .../fr/acinq/eclair/payment/Relayer.scala | 21 +- .../fr/acinq/eclair/EclairImplSpec.scala | 62 +++-- .../eclair/db/SqlitePaymentsDbSpec.scala | 226 ++++++++++++++---- .../eclair/payment/PaymentLifecycleSpec.scala | 48 ++-- .../scala/fr/acinq/eclair/api/Service.scala | 24 +- .../fr/acinq/eclair/api/ApiServiceSpec.scala | 26 +- 12 files changed, 534 insertions(+), 247 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala b/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala index fc8e7356c5..4c9c51a5c2 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala @@ -78,13 +78,13 @@ trait Eclair { def receivedInfo(paymentHash: ByteVector32)(implicit timeout: Timeout): Future[Option[IncomingPayment]] - def send(recipientNodeId: PublicKey, amount: MilliSatoshi, paymentHash: ByteVector32, invoice_opt: Option[PaymentRequest] = None, maxAttempts_opt: Option[Int] = None, feeThresholdSat_opt: Option[Satoshi] = None, maxFeePct_opt: Option[Double] = None)(implicit timeout: Timeout): Future[UUID] + def send(externalId: Option[UUID], recipientNodeId: PublicKey, amount: MilliSatoshi, paymentHash: ByteVector32, invoice_opt: Option[PaymentRequest] = None, maxAttempts_opt: Option[Int] = None, feeThresholdSat_opt: Option[Satoshi] = None, maxFeePct_opt: Option[Double] = None)(implicit timeout: Timeout): Future[UUID] def sentInfo(id: Either[UUID, ByteVector32])(implicit timeout: Timeout): Future[Seq[OutgoingPayment]] def findRoute(targetNodeId: PublicKey, amount: MilliSatoshi, assistedRoutes: Seq[Seq[PaymentRequest.ExtraHop]] = Seq.empty)(implicit timeout: Timeout): Future[RouteResponse] - def sendToRoute(route: Seq[PublicKey], amount: MilliSatoshi, paymentHash: ByteVector32, finalCltvExpiryDelta: CltvExpiryDelta)(implicit timeout: Timeout): Future[UUID] + def sendToRoute(externalId: Option[UUID], route: Seq[PublicKey], amount: MilliSatoshi, paymentHash: ByteVector32, finalCltvExpiryDelta: CltvExpiryDelta)(implicit timeout: Timeout): Future[UUID] def audit(from_opt: Option[Long], to_opt: Option[Long])(implicit timeout: Timeout): Future[AuditResponse] @@ -186,11 +186,11 @@ class EclairImpl(appKit: Kit) extends Eclair { (appKit.router ? RouteRequest(appKit.nodeParams.nodeId, targetNodeId, amount, assistedRoutes)).mapTo[RouteResponse] } - override def sendToRoute(route: Seq[PublicKey], amount: MilliSatoshi, paymentHash: ByteVector32, finalCltvExpiryDelta: CltvExpiryDelta)(implicit timeout: Timeout): Future[UUID] = { - (appKit.paymentInitiator ? SendPaymentRequest(amount, paymentHash, route.last, 1, finalCltvExpiryDelta, route)).mapTo[UUID] + override def sendToRoute(externalId: Option[UUID], route: Seq[PublicKey], amount: MilliSatoshi, paymentHash: ByteVector32, finalCltvExpiryDelta: CltvExpiryDelta)(implicit timeout: Timeout): Future[UUID] = { + (appKit.paymentInitiator ? SendPaymentRequest(amount, paymentHash, route.last, 1, finalCltvExpiryDelta, None, externalId, route)).mapTo[UUID] } - override def send(recipientNodeId: PublicKey, amount: MilliSatoshi, paymentHash: ByteVector32, invoice_opt: Option[PaymentRequest], maxAttempts_opt: Option[Int], feeThreshold_opt: Option[Satoshi], maxFeePct_opt: Option[Double])(implicit timeout: Timeout): Future[UUID] = { + override def send(externalId: Option[UUID], recipientNodeId: PublicKey, amount: MilliSatoshi, paymentHash: ByteVector32, invoice_opt: Option[PaymentRequest], maxAttempts_opt: Option[Int], feeThreshold_opt: Option[Satoshi], maxFeePct_opt: Option[Double])(implicit timeout: Timeout): Future[UUID] = { val maxAttempts = maxAttempts_opt.getOrElse(appKit.nodeParams.maxPaymentAttempts) val defaultRouteParams = Router.getDefaultRouteParams(appKit.nodeParams.routerConf) @@ -203,12 +203,12 @@ class EclairImpl(appKit: Kit) extends Eclair { case Some(invoice) if invoice.isExpired => Future.failed(new IllegalArgumentException("invoice has expired")) case Some(invoice) => val sendPayment = invoice.minFinalCltvExpiryDelta match { - case Some(minFinalCltvExpiryDelta) => SendPaymentRequest(amount, paymentHash, recipientNodeId, maxAttempts, minFinalCltvExpiryDelta, assistedRoutes = invoice.routingInfo, routeParams = Some(routeParams)) - case None => SendPaymentRequest(amount, paymentHash, recipientNodeId, maxAttempts, assistedRoutes = invoice.routingInfo, routeParams = Some(routeParams)) + case Some(minFinalCltvExpiryDelta) => SendPaymentRequest(amount, paymentHash, recipientNodeId, maxAttempts, minFinalCltvExpiryDelta, invoice_opt, externalId, assistedRoutes = invoice.routingInfo, routeParams = Some(routeParams)) + case None => SendPaymentRequest(amount, paymentHash, recipientNodeId, maxAttempts, paymentRequest = invoice_opt, externalId = externalId, assistedRoutes = invoice.routingInfo, routeParams = Some(routeParams)) } (appKit.paymentInitiator ? sendPayment).mapTo[UUID] case None => - val sendPayment = SendPaymentRequest(amount, paymentHash, recipientNodeId, maxAttempts = maxAttempts, routeParams = Some(routeParams)) + val sendPayment = SendPaymentRequest(amount, paymentHash, recipientNodeId, maxAttempts = maxAttempts, externalId = externalId, routeParams = Some(routeParams)) (appKit.paymentInitiator ? sendPayment).mapTo[UUID] } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala index b9f4aa838f..64de9c7ae5 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala @@ -19,70 +19,131 @@ package fr.acinq.eclair.db import java.util.UUID import fr.acinq.bitcoin.ByteVector32 -import fr.acinq.eclair.MilliSatoshi -import fr.acinq.eclair.payment.PaymentRequest +import fr.acinq.bitcoin.Crypto.PublicKey +import fr.acinq.eclair.payment._ +import fr.acinq.eclair.router.Hop +import fr.acinq.eclair.{MilliSatoshi, ShortChannelId} trait PaymentsDb { - // creates a record for a non yet finalized outgoing payment + /** Create a record for a non yet finalized outgoing payment. */ def addOutgoingPayment(outgoingPayment: OutgoingPayment) - // updates the status of the payment, if the newStatus is SUCCEEDED you must supply a preimage - def updateOutgoingPayment(id: UUID, newStatus: OutgoingPaymentStatus.Value, preimage: Option[ByteVector32] = None) + /** Update the status of the payment in case of success. */ + def updateOutgoingPayment(paymentResult: PaymentSent) + /** Update the status of the payment in case of failure. */ + def updateOutgoingPayment(paymentResult: PaymentFailed) + + /** Get an outgoing payment attempt. */ def getOutgoingPayment(id: UUID): Option[OutgoingPayment] - // all the outgoing payment (attempts) to pay the given paymentHash + /** Get all the outgoing payment attempts to pay the given paymentHash. */ def getOutgoingPayments(paymentHash: ByteVector32): Seq[OutgoingPayment] - def listOutgoingPayments(): Seq[OutgoingPayment] + /** Get all the outgoing payment attempts in the given time range. */ + def listOutgoingPayments(from: Long, to: Long): Seq[OutgoingPayment] + /** Add a new payment request (to receive a payment). */ def addPaymentRequest(pr: PaymentRequest, preimage: ByteVector32) + /** Get the payment request for the given payment hash, if any. */ def getPaymentRequest(paymentHash: ByteVector32): Option[PaymentRequest] - // returns non paid payment request + /** Get the currently pending payment request for the given payment hash, if any. */ def getPendingPaymentRequestAndPreimage(paymentHash: ByteVector32): Option[(ByteVector32, PaymentRequest)] + /** Get all payment requests (pending, expired and fulfilled) in the given time range. */ def listPaymentRequests(from: Long, to: Long): Seq[PaymentRequest] - // returns non paid, non expired payment requests + /** Get pending, non expired payment requests in the given time range. */ def listPendingPaymentRequests(from: Long, to: Long): Seq[PaymentRequest] - // assumes there is already a payment request for it (the record for the given payment hash) + /** Add a received payment (assumes there is already a payment request for the given payment hash). */ def addIncomingPayment(payment: IncomingPayment) + /** Get the received payment associated with a given payment hash, if any. */ def getIncomingPayment(paymentHash: ByteVector32): Option[IncomingPayment] - def listIncomingPayments(): Seq[IncomingPayment] + /** Get all payments received in the given time range. */ + def listIncomingPayments(from: Long, to: Long): Seq[IncomingPayment] } /** - * Incoming payment object stored in DB. - * - * @param paymentHash identifier of the payment - * @param amount amount of the payment, in milli-satoshis - * @param receivedAt absolute time in seconds since UNIX epoch when the payment was received. - */ + * Incoming payment object stored in DB. + * + * @param paymentHash identifier of the payment. + * @param amount amount of the payment, in milli-satoshis. + * @param receivedAt absolute time in seconds since UNIX epoch when the payment was received. + */ case class IncomingPayment(paymentHash: ByteVector32, amount: MilliSatoshi, receivedAt: Long) /** - * Sent payment is every payment that is sent by this node, they may not be finalized and - * when is final it can be failed or successful. - * - * @param id internal payment identifier - * @param paymentHash payment_hash - * @param preimage the preimage of the payment_hash, known if the outgoing payment was successful - * @param amount amount of the payment, in milli-satoshis - * @param createdAt absolute time in seconds since UNIX epoch when the payment was created. - * @param completedAt absolute time in seconds since UNIX epoch when the payment succeeded. - * @param status current status of the payment. - */ -case class OutgoingPayment(id: UUID, paymentHash: ByteVector32, preimage:Option[ByteVector32], amount: MilliSatoshi, createdAt: Long, completedAt: Option[Long], status: OutgoingPaymentStatus.Value) + * An outgoing payment sent by this node. + * At first it is in a pending state, then will become either a success or a failure. + * + * @param id internal payment identifier. + * @param parentId internal identifier of a parent payment, if any. + * @param externalId external payment identifier: lets lightning applications reconcile payments with their own db. + * @param paymentHash payment_hash. + * @param amount amount of the payment, in milli-satoshis. + * @param targetNodeId node ID of the payment recipient. + * @param createdAt absolute time in seconds since UNIX epoch when the payment was created. + * @param status current status of the payment. + * @param paymentRequest_opt Bolt 11 payment request (if paying from an invoice). + * @param completedAt absolute time in seconds since UNIX epoch when the payment completed (success of failure). + * @param successSummary summary of the payment success (if status == "SUCCEEDED"). + * @param failureSummary summary of the payment failure (if status == "FAILED"). + */ +case class OutgoingPayment(id: UUID, + parentId: Option[UUID], + externalId: Option[UUID], + paymentHash: ByteVector32, + amount: MilliSatoshi, + targetNodeId: PublicKey, + createdAt: Long, + status: OutgoingPaymentStatus.Value, + paymentRequest_opt: Option[PaymentRequest], + completedAt: Option[Long] = None, + successSummary: Option[PaymentSuccessSummary] = None, + failureSummary: Option[PaymentFailureSummary] = None) object OutgoingPaymentStatus extends Enumeration { val PENDING = Value(1, "PENDING") val SUCCEEDED = Value(2, "SUCCEEDED") val FAILED = Value(3, "FAILED") +} + +case class PaymentSuccessSummary(paymentPreimage: ByteVector32, feesPaid: MilliSatoshi, route: Seq[HopSummary]) + +case class PaymentFailureSummary(failures: Seq[FailureSummary]) + +/** A minimal representation of a hop in a payment route (suitable to store in a database). */ +case class HopSummary(nodeId: PublicKey, nextNodeId: PublicKey, shortChannelId: Option[ShortChannelId] = None) { + override def toString = shortChannelId match { + case Some(shortChannelId) => s"$nodeId->$nextNodeId ($shortChannelId)" + case None => s"$nodeId->$nextNodeId" + } +} + +object HopSummary { + def apply(h: Hop): HopSummary = HopSummary(h.nodeId, h.nextNodeId, Some(h.lastUpdate.shortChannelId)) +} + +/** A minimal representation of a payment failure (suitable to store in a database). */ +case class FailureSummary(failureType: FailureType.Value, failureMessage: String, failedRoute: List[HopSummary]) + +object FailureType extends Enumeration { + val LOCAL = Value(1, "Local") + val REMOTE = Value(2, "Remote") + val UNREADABLE_REMOTE = Value(3, "UnreadableRemote") +} + +object FailureSummary { + def apply(f: PaymentFailure): FailureSummary = f match { + case LocalFailure(t) => FailureSummary(FailureType.LOCAL, t.getMessage, Nil) + case RemoteFailure(route, e) => FailureSummary(FailureType.REMOTE, e.failureMessage.message, route.map(h => HopSummary(h)).toList) + case UnreadableRemoteFailure(route) => FailureSummary(FailureType.UNREADABLE_REMOTE, "could not decrypt failure onion", route.map(h => HopSummary(h)).toList) + } } \ No newline at end of file diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala index 127ec6be2e..22aa6639ca 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala @@ -16,116 +16,187 @@ package fr.acinq.eclair.db.sqlite -import java.sql.Connection +import java.sql.{Connection, ResultSet, Statement} import java.util.UUID + import fr.acinq.bitcoin.ByteVector32 +import fr.acinq.bitcoin.Crypto.{PrivateKey, PublicKey} +import fr.acinq.eclair.MilliSatoshi +import fr.acinq.eclair.db._ import fr.acinq.eclair.db.sqlite.SqliteUtils._ -import fr.acinq.eclair.db.{IncomingPayment, OutgoingPayment, OutgoingPaymentStatus, PaymentsDb} -import fr.acinq.eclair.payment.PaymentRequest +import fr.acinq.eclair.payment.{PaymentFailed, PaymentRequest, PaymentSent} +import fr.acinq.eclair.wire.CommonCodecs import grizzled.slf4j.Logging +import scodec.Attempt +import scodec.codecs._ + import scala.collection.immutable.Queue -import OutgoingPaymentStatus._ -import fr.acinq.eclair.MilliSatoshi -import concurrent.duration._ import scala.compat.Platform +import scala.concurrent.duration._ class SqlitePaymentsDb(sqlite: Connection) extends PaymentsDb with Logging { import SqliteUtils.ExtendedResultSet._ val DB_NAME = "payments" - val CURRENT_VERSION = 2 + val CURRENT_VERSION = 3 + + private val hopSummaryCodec = (("node_id" | CommonCodecs.publicKey) :: ("next_node_id" | CommonCodecs.publicKey) :: ("short_channel_id" | optional(bool, CommonCodecs.shortchannelid))).as[HopSummary] + private val paymentRouteCodec = discriminated[List[HopSummary]].by(byte) + .typecase(0x01, listOfN(uint8, hopSummaryCodec)) + private val failureSummaryCodec = (("type" | enumerated(uint8, FailureType)) :: ("message" | ascii32) :: paymentRouteCodec).as[FailureSummary] + private val paymentFailuresCodec = discriminated[List[FailureSummary]].by(byte) + .typecase(0x01, listOfN(uint8, failureSummaryCodec)) using(sqlite.createStatement()) { statement => - require(getVersion(statement, DB_NAME, CURRENT_VERSION) <= CURRENT_VERSION, s"incompatible version of $DB_NAME DB found") // version 2 is "backward compatible" in the sense that it uses separate tables from version 1. There is no migration though - statement.executeUpdate("CREATE TABLE IF NOT EXISTS received_payments (payment_hash BLOB NOT NULL PRIMARY KEY, preimage BLOB NOT NULL, payment_request TEXT NOT NULL, received_msat INTEGER, created_at INTEGER NOT NULL, expire_at INTEGER, received_at INTEGER)") - statement.executeUpdate("CREATE TABLE IF NOT EXISTS sent_payments (id TEXT NOT NULL PRIMARY KEY, payment_hash BLOB NOT NULL, preimage BLOB, amount_msat INTEGER NOT NULL, created_at INTEGER NOT NULL, completed_at INTEGER, status VARCHAR NOT NULL)") - statement.executeUpdate("CREATE INDEX IF NOT EXISTS payment_hash_idx ON sent_payments(payment_hash)") - setVersion(statement, DB_NAME, CURRENT_VERSION) + + def migration12(statement: Statement) = { + // Version 2 is "backwards compatible" in the sense that it uses separate tables from version 1 (which used a single "payments" table). + statement.executeUpdate("CREATE TABLE IF NOT EXISTS received_payments (payment_hash BLOB NOT NULL PRIMARY KEY, preimage BLOB NOT NULL, payment_request TEXT NOT NULL, received_msat INTEGER, created_at INTEGER NOT NULL, expire_at INTEGER, received_at INTEGER)") + statement.executeUpdate("CREATE TABLE IF NOT EXISTS sent_payments (id TEXT NOT NULL PRIMARY KEY, payment_hash BLOB NOT NULL, preimage BLOB, amount_msat INTEGER NOT NULL, created_at INTEGER NOT NULL, completed_at INTEGER, status VARCHAR NOT NULL)") + statement.executeUpdate("CREATE INDEX IF NOT EXISTS payment_hash_idx ON sent_payments(payment_hash)") + } + + def migration23(statement: Statement) = { + // Nothing changes in the received_payments table, but the sent_payments table changes a lot. + statement.executeUpdate("DROP index payment_hash_idx") + statement.executeUpdate("ALTER TABLE sent_payments RENAME TO _sent_payments_old") + statement.executeUpdate("CREATE TABLE sent_payments (id TEXT NOT NULL PRIMARY KEY, parent_id TEXT, external_id TEXT, payment_hash BLOB NOT NULL, amount_msat INTEGER NOT NULL, target_node_id BLOB NOT NULL, created_at INTEGER NOT NULL, status VARCHAR NOT NULL, payment_request TEXT, completed_at INTEGER, payment_preimage BLOB, fees_msat INTEGER, payment_route BLOB, failures BLOB)") + // Old rows will be missing a target node id, so we use an easy-to-spot default value. + val defaultTargetNodeId = PrivateKey(ByteVector32.One).publicKey + statement.executeUpdate(s"INSERT INTO sent_payments (id, payment_hash, amount_msat, target_node_id, created_at, status, completed_at, payment_preimage) SELECT id, payment_hash, amount_msat, X'${defaultTargetNodeId.toString}', created_at, status, completed_at, preimage FROM _sent_payments_old") + statement.executeUpdate("DROP table _sent_payments_old") + + statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_parent_id_idx ON sent_payments(parent_id)") + statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_payment_hash_idx ON sent_payments(payment_hash)") + statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_created_idx ON sent_payments(created_at)") + statement.executeUpdate("CREATE INDEX IF NOT EXISTS received_created_idx ON received_payments(created_at)") + statement.executeUpdate("CREATE INDEX IF NOT EXISTS received_timestamp_idx ON received_payments(received_at)") + } + + getVersion(statement, DB_NAME, CURRENT_VERSION) match { + case 1 => + logger.warn(s"migrating db $DB_NAME, found version=1 current=$CURRENT_VERSION") + migration12(statement) + migration23(statement) + setVersion(statement, DB_NAME, CURRENT_VERSION) + case 2 => + logger.warn(s"migrating db $DB_NAME, found version=2 current=$CURRENT_VERSION") + migration23(statement) + setVersion(statement, DB_NAME, CURRENT_VERSION) + case CURRENT_VERSION => + statement.executeUpdate("CREATE TABLE IF NOT EXISTS received_payments (payment_hash BLOB NOT NULL PRIMARY KEY, preimage BLOB NOT NULL, payment_request TEXT NOT NULL, received_msat INTEGER, created_at INTEGER NOT NULL, expire_at INTEGER, received_at INTEGER)") + statement.executeUpdate("CREATE TABLE IF NOT EXISTS sent_payments (id TEXT NOT NULL PRIMARY KEY, parent_id TEXT, external_id TEXT, payment_hash BLOB NOT NULL, amount_msat INTEGER NOT NULL, target_node_id BLOB NOT NULL, created_at INTEGER NOT NULL, status VARCHAR NOT NULL, payment_request TEXT, completed_at INTEGER, payment_preimage BLOB, fees_msat INTEGER, payment_route BLOB, failures BLOB)") + + statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_parent_id_idx ON sent_payments(parent_id)") + statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_payment_hash_idx ON sent_payments(payment_hash)") + statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_created_idx ON sent_payments(created_at)") + statement.executeUpdate("CREATE INDEX IF NOT EXISTS received_created_idx ON received_payments(created_at)") + statement.executeUpdate("CREATE INDEX IF NOT EXISTS received_timestamp_idx ON received_payments(received_at)") + case unknownVersion => throw new RuntimeException(s"Unknown version of DB $DB_NAME found, version=$unknownVersion") + } + } - override def addOutgoingPayment(sent: OutgoingPayment): Unit = { - using(sqlite.prepareStatement("INSERT INTO sent_payments (id, payment_hash, amount_msat, created_at, status) VALUES (?, ?, ?, ?, ?)")) { statement => + override def addOutgoingPayment(sent: OutgoingPayment): Unit = + using(sqlite.prepareStatement("INSERT INTO sent_payments (id, parent_id, external_id, payment_hash, amount_msat, target_node_id, created_at, status, payment_request) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)")) { statement => statement.setString(1, sent.id.toString) - statement.setBytes(2, sent.paymentHash.toArray) - statement.setLong(3, sent.amount.toLong) - statement.setLong(4, sent.createdAt) - statement.setString(5, sent.status.toString) - val res = statement.executeUpdate() - logger.debug(s"inserted $res payment=${sent.paymentHash} into payment DB") + statement.setString(2, sent.parentId.map(_.toString).orNull) + statement.setString(3, sent.externalId.map(_.toString).orNull) + statement.setBytes(4, sent.paymentHash.toArray) + statement.setLong(5, sent.amount.toLong) + statement.setBytes(6, sent.targetNodeId.value.toArray) + statement.setLong(7, sent.createdAt) + statement.setString(8, sent.status.toString) + statement.setString(9, sent.paymentRequest_opt.map(PaymentRequest.write).orNull) + statement.executeUpdate() + } + + override def updateOutgoingPayment(paymentResult: PaymentSent): Unit = + using(sqlite.prepareStatement("UPDATE sent_payments SET (completed_at, status, payment_preimage, fees_msat, payment_route) = (?, ?, ?, ?, ?) WHERE id = ? AND completed_at IS NULL")) { statement => + statement.setLong(1, paymentResult.timestamp) + statement.setString(2, OutgoingPaymentStatus.SUCCEEDED.toString) + statement.setBytes(3, paymentResult.paymentPreimage.toArray) + statement.setLong(4, paymentResult.feesPaid.toLong) + statement.setBytes(5, paymentRouteCodec.encode(paymentResult.route.map(h => HopSummary(h)).toList).require.toByteArray) + statement.setString(6, paymentResult.id.toString) + if (statement.executeUpdate() == 0) throw new IllegalArgumentException(s"Tried to mark an outgoing payment as succeeded but already in final status (id=${paymentResult.id})") } - } - override def updateOutgoingPayment(id: UUID, newStatus: OutgoingPaymentStatus.Value, preimage: Option[ByteVector32] = None) = { - require((newStatus == SUCCEEDED && preimage.isDefined) || (newStatus == FAILED && preimage.isEmpty), "Wrong combination of state/preimage") + override def updateOutgoingPayment(paymentResult: PaymentFailed): Unit = + using(sqlite.prepareStatement("UPDATE sent_payments SET (completed_at, status, failures) = (?, ?, ?) WHERE id = ? AND completed_at IS NULL")) { statement => + statement.setLong(1, paymentResult.timestamp) + statement.setString(2, OutgoingPaymentStatus.FAILED.toString) + statement.setBytes(3, paymentFailuresCodec.encode(paymentResult.failures.map(f => FailureSummary(f)).toList).require.toByteArray) + statement.setString(4, paymentResult.id.toString) + if (statement.executeUpdate() == 0) throw new IllegalArgumentException(s"Tried to mark an outgoing payment as failed but already in final status (id=${paymentResult.id})") + } - using(sqlite.prepareStatement("UPDATE sent_payments SET (completed_at, preimage, status) = (?, ?, ?) WHERE id = ? AND completed_at IS NULL")) { statement => - statement.setLong(1, Platform.currentTime) - statement.setBytes(2, if (preimage.isEmpty) null else preimage.get.toArray) - statement.setString(3, newStatus.toString) - statement.setString(4, id.toString) - if (statement.executeUpdate() == 0) throw new IllegalArgumentException(s"Tried to update an outgoing payment (id=$id) already in final status with=$newStatus") + private def parseOutgoingPayment(rs: ResultSet): OutgoingPayment = { + val result = OutgoingPayment( + UUID.fromString(rs.getString("id")), + rs.getStringNullable("parent_id").map(UUID.fromString), + rs.getStringNullable("external_id").map(UUID.fromString), + rs.getByteVector32("payment_hash"), + MilliSatoshi(rs.getLong("amount_msat")), + PublicKey(rs.getByteVector("target_node_id")), + rs.getLong("created_at"), + OutgoingPaymentStatus.withName(rs.getString("status")), + rs.getStringNullable("payment_request").map(PaymentRequest.read), + getNullableLong(rs, "completed_at") + ) + result.status match { + case OutgoingPaymentStatus.SUCCEEDED => result.copy(successSummary = Some(PaymentSuccessSummary( + rs.getByteVector32("payment_preimage"), + MilliSatoshi(rs.getLong("fees_msat")), + rs.getBitVectorOpt("payment_route").map(b => paymentRouteCodec.decode(b) match { + case Attempt.Successful(route) => route.value + case Attempt.Failure(_) => Nil + }).getOrElse(Nil) + ))) + case OutgoingPaymentStatus.FAILED => result.copy(failureSummary = Some(PaymentFailureSummary( + rs.getBitVectorOpt("failures").map(b => paymentFailuresCodec.decode(b) match { + case Attempt.Successful(failures) => failures.value + case Attempt.Failure(_) => Nil + }).getOrElse(Nil) + ))) + case _ => result } } - override def getOutgoingPayment(id: UUID): Option[OutgoingPayment] = { - using(sqlite.prepareStatement("SELECT id, payment_hash, preimage, amount_msat, created_at, completed_at, status FROM sent_payments WHERE id = ?")) { statement => + override def getOutgoingPayment(id: UUID): Option[OutgoingPayment] = + using(sqlite.prepareStatement("SELECT id, parent_id, external_id, payment_hash, amount_msat, target_node_id, created_at, status, payment_request, completed_at, payment_preimage, fees_msat, payment_route, failures FROM sent_payments WHERE id = ?")) { statement => statement.setString(1, id.toString) val rs = statement.executeQuery() if (rs.next()) { - Some(OutgoingPayment( - UUID.fromString(rs.getString("id")), - rs.getByteVector32("payment_hash"), - rs.getByteVector32Nullable("preimage"), - MilliSatoshi(rs.getLong("amount_msat")), - rs.getLong("created_at"), - getNullableLong(rs, "completed_at"), - OutgoingPaymentStatus.withName(rs.getString("status")) - )) + Some(parseOutgoingPayment(rs)) } else { None } } - } - override def getOutgoingPayments(paymentHash: ByteVector32): Seq[OutgoingPayment] = { - using(sqlite.prepareStatement("SELECT id, payment_hash, preimage, amount_msat, created_at, completed_at, status FROM sent_payments WHERE payment_hash = ?")) { statement => + override def getOutgoingPayments(paymentHash: ByteVector32): Seq[OutgoingPayment] = + using(sqlite.prepareStatement("SELECT id, parent_id, external_id, payment_hash, amount_msat, target_node_id, created_at, status, payment_request, completed_at, payment_preimage, fees_msat, payment_route, failures FROM sent_payments WHERE payment_hash = ?")) { statement => statement.setBytes(1, paymentHash.toArray) val rs = statement.executeQuery() var q: Queue[OutgoingPayment] = Queue() while (rs.next()) { - q = q :+ OutgoingPayment( - UUID.fromString(rs.getString("id")), - rs.getByteVector32("payment_hash"), - rs.getByteVector32Nullable("preimage"), - MilliSatoshi(rs.getLong("amount_msat")), - rs.getLong("created_at"), - getNullableLong(rs, "completed_at"), - OutgoingPaymentStatus.withName(rs.getString("status")) - ) + q = q :+ parseOutgoingPayment(rs) } q } - } - override def listOutgoingPayments(): Seq[OutgoingPayment] = { - using(sqlite.createStatement()) { statement => - val rs = statement.executeQuery("SELECT id, payment_hash, preimage, amount_msat, created_at, completed_at, status FROM sent_payments") + override def listOutgoingPayments(from: Long, to: Long): Seq[OutgoingPayment] = + using(sqlite.prepareStatement("SELECT id, parent_id, external_id, payment_hash, amount_msat, target_node_id, created_at, status, payment_request, completed_at, payment_preimage, fees_msat, payment_route, failures FROM sent_payments WHERE created_at >= ? AND created_at < ?")) { statement => + statement.setLong(1, from.seconds.toMillis) + statement.setLong(2, to.seconds.toMillis) + val rs = statement.executeQuery() var q: Queue[OutgoingPayment] = Queue() while (rs.next()) { - q = q :+ OutgoingPayment( - UUID.fromString(rs.getString("id")), - rs.getByteVector32("payment_hash"), - rs.getByteVector32Nullable("preimage"), - MilliSatoshi(rs.getLong("amount_msat")), - rs.getLong("created_at"), - getNullableLong(rs, "completed_at"), - OutgoingPaymentStatus.withName(rs.getString("status")) - ) + q = q :+ parseOutgoingPayment(rs) } q } - } override def addPaymentRequest(pr: PaymentRequest, preimage: ByteVector32): Unit = { val insertStmt = pr.expiry match { @@ -215,15 +286,16 @@ class SqlitePaymentsDb(sqlite: Connection) extends PaymentsDb with Logging { } } - override def listIncomingPayments(): Seq[IncomingPayment] = { - using(sqlite.createStatement()) { statement => - val rs = statement.executeQuery("SELECT payment_hash, received_msat, received_at FROM received_payments WHERE received_msat > 0") + override def listIncomingPayments(from: Long, to: Long): Seq[IncomingPayment] = + using(sqlite.prepareStatement("SELECT payment_hash, received_msat, received_at FROM received_payments WHERE received_msat > 0 AND received_at >= ? AND received_at < ?")) { statement => + statement.setLong(1, from.seconds.toMillis) + statement.setLong(2, to.seconds.toMillis) + val rs = statement.executeQuery() var q: Queue[IncomingPayment] = Queue() while (rs.next()) { q = q :+ IncomingPayment(rs.getByteVector32("payment_hash"), MilliSatoshi(rs.getLong("received_msat")), rs.getLong("received_at")) } q } - } } \ No newline at end of file diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteUtils.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteUtils.scala index edef34db43..d8d6875162 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteUtils.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteUtils.scala @@ -116,6 +116,12 @@ object SqliteUtils { val bytes = rs.getBytes(columnLabel) if (rs.wasNull()) None else Some(ByteVector32(ByteVector(bytes))) } + + def getStringNullable(columnLabel: String): Option[String] = { + val result = rs.getString(columnLabel) + if (rs.wasNull()) None else Some(result) + } + } object ExtendedResultSet { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentInitiator.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentInitiator.scala index fed4e1db4e..b86ee963e9 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentInitiator.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentInitiator.scala @@ -38,7 +38,7 @@ class PaymentInitiator(nodeParams: NodeParams, router: ActorRef, register: Actor val paymentId = UUID.randomUUID() // We add one block in order to not have our htlc fail when a new block has just been found. val finalExpiry = p.finalExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight + 1) - val payFsm = context.actorOf(PaymentLifecycle.props(nodeParams, DefaultPaymentProgressHandler(paymentId, nodeParams.db.payments), router, register)) + val payFsm = context.actorOf(PaymentLifecycle.props(nodeParams, DefaultPaymentProgressHandler(paymentId, p, nodeParams.db.payments), router, register)) // NB: we only generate legacy payment onions for now for maximum compatibility. p.predefinedRoute match { case Nil => payFsm forward SendPayment(p.paymentHash, p.targetNodeId, FinalLegacyPayload(p.amount, finalExpiry), p.maxAttempts, p.assistedRoutes, p.routeParams) @@ -58,6 +58,8 @@ object PaymentInitiator { targetNodeId: PublicKey, maxAttempts: Int, finalExpiryDelta: CltvExpiryDelta = Channel.MIN_CLTV_EXPIRY_DELTA, + paymentRequest: Option[PaymentRequest] = None, + externalId: Option[UUID] = None, predefinedRoute: Seq[PublicKey] = Nil, assistedRoutes: Seq[Seq[ExtraHop]] = Nil, routeParams: Option[RouteParams] = None) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentLifecycle.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentLifecycle.scala index 5cc7e79051..65e7905a1f 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentLifecycle.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentLifecycle.scala @@ -25,6 +25,7 @@ import fr.acinq.eclair._ import fr.acinq.eclair.channel.{CMD_ADD_HTLC, Register} import fr.acinq.eclair.crypto.{Sphinx, TransportHandler} import fr.acinq.eclair.db.{OutgoingPayment, OutgoingPaymentStatus, PaymentsDb} +import fr.acinq.eclair.payment.PaymentInitiator.SendPaymentRequest import fr.acinq.eclair.payment.PaymentLifecycle._ import fr.acinq.eclair.payment.PaymentRequest.ExtraHop import fr.acinq.eclair.router._ @@ -50,12 +51,12 @@ class PaymentLifecycle(nodeParams: NodeParams, progressHandler: PaymentProgressH case Event(c: SendPaymentToRoute, WaitingForRequest) => val send = SendPayment(c.paymentHash, c.hops.last, c.finalPayload, maxAttempts = 1) router ! FinalizeRoute(c.hops) - progressHandler.onSend(c.paymentHash, c.finalPayload.amount) + progressHandler.onSend() goto(WAITING_FOR_ROUTE) using WaitingForRoute(sender, send, failures = Nil) case Event(c: SendPayment, WaitingForRequest) => router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.finalPayload.amount, c.assistedRoutes, routeParams = c.routeParams) - progressHandler.onSend(c.paymentHash, c.finalPayload.amount) + progressHandler.onSend() goto(WAITING_FOR_ROUTE) using WaitingForRoute(sender, c, failures = Nil) } @@ -189,27 +190,27 @@ object PaymentLifecycle { val id: UUID // @formatter:off - def onSend(paymentHash: ByteVector32, finalAmount: MilliSatoshi): Unit + def onSend(): Unit def onSucceed(sender: ActorRef, result: PaymentSent)(ctx: ActorContext): Unit def onFail(sender: ActorRef, result: PaymentFailed)(ctx: ActorContext): Unit // @formatter:on } /** Normal payments are stored in the payments DB and emit payment events. */ - case class DefaultPaymentProgressHandler(id: UUID, db: PaymentsDb) extends PaymentProgressHandler { + case class DefaultPaymentProgressHandler(id: UUID, r: SendPaymentRequest, db: PaymentsDb) extends PaymentProgressHandler { - override def onSend(paymentHash: ByteVector32, finalAmount: MilliSatoshi): Unit = { - db.addOutgoingPayment(OutgoingPayment(id, paymentHash, None, finalAmount, Platform.currentTime, None, OutgoingPaymentStatus.PENDING)) + override def onSend(): Unit = { + db.addOutgoingPayment(OutgoingPayment(id, None, r.externalId, r.paymentHash, r.amount, r.targetNodeId, Platform.currentTime, OutgoingPaymentStatus.PENDING, r.paymentRequest)) } override def onSucceed(sender: ActorRef, result: PaymentSent)(ctx: ActorContext): Unit = { - db.updateOutgoingPayment(result.id, OutgoingPaymentStatus.SUCCEEDED, preimage = Some(result.paymentPreimage)) + db.updateOutgoingPayment(result) sender ! result ctx.system.eventStream.publish(result) } override def onFail(sender: ActorRef, result: PaymentFailed)(ctx: ActorContext): Unit = { - db.updateOutgoingPayment(result.id, OutgoingPaymentStatus.FAILED) + db.updateOutgoingPayment(result) sender ! result ctx.system.eventStream.publish(result) } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Relayer.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Relayer.scala index a009f310ce..192d26bcbd 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Relayer.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Relayer.scala @@ -24,7 +24,6 @@ import fr.acinq.bitcoin.ByteVector32 import fr.acinq.bitcoin.Crypto.{PrivateKey, PublicKey} import fr.acinq.eclair.channel._ import fr.acinq.eclair.crypto.Sphinx -import fr.acinq.eclair.db.OutgoingPaymentStatus import fr.acinq.eclair.router.Announcements import fr.acinq.eclair.wire._ import fr.acinq.eclair.{CltvExpiryDelta, Features, LongToBtcAmount, MilliSatoshi, NodeParams, ShortChannelId, UInt64, nodeFee} @@ -137,8 +136,9 @@ class Relayer(nodeParams: NodeParams, register: ActorRef, paymentHandler: ActorR case Local(id, None) => // we sent the payment, but we probably restarted and the reference to the original sender was lost, // we publish the failure on the event stream and update the status in paymentDb - nodeParams.db.payments.updateOutgoingPayment(id, OutgoingPaymentStatus.FAILED) - context.system.eventStream.publish(PaymentFailed(id, paymentHash, Nil)) + val result = PaymentFailed(id, paymentHash, Nil) + nodeParams.db.payments.updateOutgoingPayment(result) + context.system.eventStream.publish(result) case Local(_, Some(sender)) => sender ! Status.Failure(addFailed) case Relayed(originChannelId, originHtlcId, _, _) => @@ -161,8 +161,9 @@ class Relayer(nodeParams: NodeParams, register: ActorRef, paymentHandler: ActorR // we sent the payment, but we probably restarted and the reference to the original sender was lost, // we publish the success on the event stream and update the status in paymentDb val feesPaid = 0.msat // fees are unknown since we lost the reference to the payment - nodeParams.db.payments.updateOutgoingPayment(id, OutgoingPaymentStatus.SUCCEEDED, Some(fulfill.paymentPreimage)) - context.system.eventStream.publish(PaymentSent(id, add.amountMsat, feesPaid, add.paymentHash, fulfill.paymentPreimage, Nil)) + val result = PaymentSent(id, add.amountMsat, feesPaid, add.paymentHash, fulfill.paymentPreimage, Nil) + nodeParams.db.payments.updateOutgoingPayment(result) + context.system.eventStream.publish(result) case Local(_, Some(sender)) => sender ! fulfill case Relayed(originChannelId, originHtlcId, amountIn, amountOut) => @@ -176,8 +177,9 @@ class Relayer(nodeParams: NodeParams, register: ActorRef, paymentHandler: ActorR case Local(id, None) => // we sent the payment, but we probably restarted and the reference to the original sender was lost // we publish the failure on the event stream and update the status in paymentDb - nodeParams.db.payments.updateOutgoingPayment(id, OutgoingPaymentStatus.FAILED) - context.system.eventStream.publish(PaymentFailed(id, add.paymentHash, Nil)) + val result = PaymentFailed(id, add.paymentHash, Nil) + nodeParams.db.payments.updateOutgoingPayment(result) + context.system.eventStream.publish(result) case Local(_, Some(sender)) => sender ! fail case Relayed(originChannelId, originHtlcId, _, _) => @@ -190,8 +192,9 @@ class Relayer(nodeParams: NodeParams, register: ActorRef, paymentHandler: ActorR case Local(id, None) => // we sent the payment, but we probably restarted and the reference to the original sender was lost // we publish the failure on the event stream and update the status in paymentDb - nodeParams.db.payments.updateOutgoingPayment(id, OutgoingPaymentStatus.FAILED) - context.system.eventStream.publish(PaymentFailed(id, add.paymentHash, Nil)) + val result = PaymentFailed(id, add.paymentHash, Nil) + nodeParams.db.payments.updateOutgoingPayment(result) + context.system.eventStream.publish(result) case Local(_, Some(sender)) => sender ! fail case Relayed(originChannelId, originHtlcId, _, _) => diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/EclairImplSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/EclairImplSpec.scala index 32f4450896..44d9e7f748 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/EclairImplSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/EclairImplSpec.scala @@ -16,6 +16,8 @@ package fr.acinq.eclair +import java.util.UUID + import akka.actor.ActorSystem import akka.testkit.{TestKit, TestProbe} import akka.util.Timeout @@ -94,43 +96,51 @@ class EclairImplSpec extends TestKit(ActorSystem("test")) with fixture.FunSuiteL val eclair = new EclairImpl(kit) val nodeId = PublicKey(hex"030bb6a5e0c6b203c7e2180fb78c7ba4bdce46126761d8201b91ddac089cdecc87") - eclair.send(nodeId, 123 msat, ByteVector32.Zeroes, invoice_opt = None) + eclair.send(None, nodeId, 123 msat, ByteVector32.Zeroes, invoice_opt = None) val send = paymentInitiator.expectMsgType[SendPaymentRequest] - assert(send.targetNodeId == nodeId) - assert(send.amount == 123.msat) - assert(send.paymentHash == ByteVector32.Zeroes) - assert(send.assistedRoutes == Seq.empty) + assert(send.externalId === None) + assert(send.targetNodeId === nodeId) + assert(send.amount === 123.msat) + assert(send.paymentHash === ByteVector32.Zeroes) + assert(send.paymentRequest === None) + assert(send.assistedRoutes === Seq.empty) // with assisted routes + val externalId1 = UUID.randomUUID() val hints = List(List(ExtraHop(Bob.nodeParams.nodeId, ShortChannelId("569178x2331x1"), feeBase = 10 msat, feeProportionalMillionths = 1, cltvExpiryDelta = CltvExpiryDelta(12)))) val invoice1 = PaymentRequest(Block.RegtestGenesisBlock.hash, Some(123 msat), ByteVector32.Zeroes, randomKey, "description", None, None, hints) - eclair.send(nodeId, 123 msat, ByteVector32.Zeroes, invoice_opt = Some(invoice1)) + eclair.send(Some(externalId1), nodeId, 123 msat, ByteVector32.Zeroes, invoice_opt = Some(invoice1)) val send1 = paymentInitiator.expectMsgType[SendPaymentRequest] - assert(send1.targetNodeId == nodeId) - assert(send1.amount == 123.msat) - assert(send1.paymentHash == ByteVector32.Zeroes) - assert(send1.assistedRoutes == hints) + assert(send1.externalId === Some(externalId1)) + assert(send1.targetNodeId === nodeId) + assert(send1.amount === 123.msat) + assert(send1.paymentHash === ByteVector32.Zeroes) + assert(send1.paymentRequest === Some(invoice1)) + assert(send1.assistedRoutes === hints) // with finalCltvExpiry val invoice2 = PaymentRequest("lntb", Some(123 msat), System.currentTimeMillis() / 1000L, nodeId, List(PaymentRequest.MinFinalCltvExpiry(96), PaymentRequest.PaymentHash(ByteVector32.Zeroes), PaymentRequest.Description("description")), ByteVector.empty) - eclair.send(nodeId, 123 msat, ByteVector32.Zeroes, invoice_opt = Some(invoice2)) + eclair.send(Some(externalId1), nodeId, 123 msat, ByteVector32.Zeroes, invoice_opt = Some(invoice2)) val send2 = paymentInitiator.expectMsgType[SendPaymentRequest] - assert(send2.targetNodeId == nodeId) - assert(send2.amount == 123.msat) - assert(send2.paymentHash == ByteVector32.Zeroes) - assert(send2.finalExpiryDelta == CltvExpiryDelta(96)) + assert(send2.externalId === Some(externalId1)) + assert(send2.targetNodeId === nodeId) + assert(send2.amount === 123.msat) + assert(send2.paymentHash === ByteVector32.Zeroes) + assert(send2.paymentRequest === Some(invoice2)) + assert(send2.finalExpiryDelta === CltvExpiryDelta(96)) // with custom route fees parameters - eclair.send(nodeId, 123 msat, ByteVector32.Zeroes, invoice_opt = None, feeThreshold_opt = Some(123 sat), maxFeePct_opt = Some(4.20)) + eclair.send(None, nodeId, 123 msat, ByteVector32.Zeroes, invoice_opt = None, feeThreshold_opt = Some(123 sat), maxFeePct_opt = Some(4.20)) val send3 = paymentInitiator.expectMsgType[SendPaymentRequest] - assert(send3.targetNodeId == nodeId) - assert(send3.amount == 123.msat) - assert(send3.paymentHash == ByteVector32.Zeroes) - assert(send3.routeParams.get.maxFeeBase == 123000.msat) // conversion sat -> msat - assert(send3.routeParams.get.maxFeePct == 4.20) + assert(send3.externalId === None) + assert(send3.targetNodeId === nodeId) + assert(send3.amount === 123.msat) + assert(send3.paymentHash === ByteVector32.Zeroes) + assert(send3.routeParams.get.maxFeeBase === 123000.msat) // conversion sat -> msat + assert(send3.routeParams.get.maxFeePct === 4.20) val expiredInvoice = invoice2.copy(timestamp = 0L) - assertThrows[IllegalArgumentException](Await.result(eclair.send(nodeId, 123 msat, ByteVector32.Zeroes, invoice_opt = Some(expiredInvoice)), 50 millis)) + assertThrows[IllegalArgumentException](Await.result(eclair.send(None, nodeId, 123 msat, ByteVector32.Zeroes, invoice_opt = Some(expiredInvoice)), 50 millis)) } test("allupdates can filter by nodeId") { f => @@ -249,15 +259,17 @@ class EclairImplSpec extends TestKit(ActorSystem("test")) with fixture.FunSuiteL test("sendtoroute should pass the parameters correctly") { f => import f._ + val externalId = UUID.randomUUID() val route = Seq(PublicKey(hex"030bb6a5e0c6b203c7e2180fb78c7ba4bdce46126761d8201b91ddac089cdecc87")) val eclair = new EclairImpl(kit) - eclair.sendToRoute(route, 1234 msat, ByteVector32.One, CltvExpiryDelta(123)) + eclair.sendToRoute(Some(externalId), route, 1234 msat, ByteVector32.One, CltvExpiryDelta(123)) val send = paymentInitiator.expectMsgType[SendPaymentRequest] - assert(send.predefinedRoute == route) + assert(send.externalId === Some(externalId)) + assert(send.predefinedRoute === route) assert(send.amount === 1234.msat) assert(send.finalExpiryDelta === CltvExpiryDelta(123)) - assert(send.paymentHash == ByteVector32.One) + assert(send.paymentHash === ByteVector32.One) } } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/SqlitePaymentsDbSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/SqlitePaymentsDbSpec.scala index ed8993cc77..8a1d2ee211 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/db/SqlitePaymentsDbSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/SqlitePaymentsDbSpec.scala @@ -18,29 +18,32 @@ package fr.acinq.eclair.db import java.util.UUID -import fr.acinq.bitcoin.{Block, ByteVector32} -import fr.acinq.eclair.TestConstants.Bob +import fr.acinq.bitcoin.Crypto.PrivateKey +import fr.acinq.bitcoin.{Block, ByteVector32, Crypto} +import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.db.OutgoingPaymentStatus._ import fr.acinq.eclair.db.sqlite.SqlitePaymentsDb import fr.acinq.eclair.db.sqlite.SqliteUtils._ -import fr.acinq.eclair.payment.PaymentRequest -import fr.acinq.eclair.{LongToBtcAmount, TestConstants, randomBytes32} +import fr.acinq.eclair.payment._ +import fr.acinq.eclair.router.Hop +import fr.acinq.eclair.wire.{ChannelUpdate, UnknownNextPeer} +import fr.acinq.eclair.{CltvExpiryDelta, LongToBtcAmount, ShortChannelId, TestConstants, randomBytes32, randomBytes64, randomKey} import org.scalatest.FunSuite -import scodec.bits._ import scala.compat.Platform import scala.concurrent.duration._ class SqlitePaymentsDbSpec extends FunSuite { + import SqlitePaymentsDbSpec._ + test("init sqlite 2 times in a row") { val sqlite = TestConstants.sqliteInMemory() val db1 = new SqlitePaymentsDb(sqlite) val db2 = new SqlitePaymentsDb(sqlite) } - test("handle version migration 1->2") { - + test("handle version migration 1->3") { val connection = TestConstants.sqliteInMemory() using(connection.createStatement()) { statement => @@ -52,9 +55,11 @@ class SqlitePaymentsDbSpec extends FunSuite { assert(getVersion(statement, "payments", 1) == 1) // version 1 is deployed now } - val oldReceivedPayment = IncomingPayment(ByteVector32(hex"0f059ef9b55bb70cc09069ee4df854bf0fab650eee6f2b87ba26d1ad08ab114f"), 123 msat, 1233322) + val oldReceivedPayment = IncomingPayment(randomBytes32, 123 msat, 1233322) - // insert old type record + // Changes between version 1 and 2: + // - the monolithic payments table has been replaced by two tables, received_payments and sent_payments + // - old records from the payments table are ignored (not migrated to the new tables) using(connection.prepareStatement("INSERT INTO payments VALUES (?, ?, ?)")) { statement => statement.setBytes(1, oldReceivedPayment.paymentHash.toArray) statement.setLong(2, oldReceivedPayment.amount.toLong) @@ -65,14 +70,14 @@ class SqlitePaymentsDbSpec extends FunSuite { val preMigrationDb = new SqlitePaymentsDb(connection) using(connection.createStatement()) { statement => - assert(getVersion(statement, "payments", 1) == 2) // version has changed from 1 to 2! + assert(getVersion(statement, "payments", 1) == 3) // version changed from 1 -> 3 } // the existing received payment can NOT be queried anymore assert(preMigrationDb.getIncomingPayment(oldReceivedPayment.paymentHash).isEmpty) // add a few rows - val ps1 = OutgoingPayment(id = UUID.randomUUID(), paymentHash = ByteVector32(hex"0f059ef9b55bb70cc09069ee4df854bf0fab650eee6f2b87ba26d1ad08ab114f"), None, amount = 12345 msat, createdAt = 12345, None, PENDING) + val ps1 = OutgoingPayment(UUID.randomUUID(), None, None, oldReceivedPayment.paymentHash, 12345 msat, alice, 12345, PENDING, None) val i1 = PaymentRequest.read("lnbc10u1pw2t4phpp5ezwm2gdccydhnphfyepklc0wjkxhz0r4tctg9paunh2lxgeqhcmsdqlxycrqvpqwdshgueqvfjhggr0dcsry7qcqzpgfa4ecv7447p9t5hkujy9qgrxvkkf396p9zar9p87rv2htmeuunkhydl40r64n5s2k0u7uelzc8twxmp37nkcch6m0wg5tvvx69yjz8qpk94qf3") val pr1 = IncomingPayment(i1.paymentHash, 12345678 msat, 1513871928275L) @@ -80,19 +85,126 @@ class SqlitePaymentsDbSpec extends FunSuite { preMigrationDb.addIncomingPayment(pr1) preMigrationDb.addOutgoingPayment(ps1) - assert(preMigrationDb.listIncomingPayments() == Seq(pr1)) - assert(preMigrationDb.listOutgoingPayments() == Seq(ps1)) - assert(preMigrationDb.listPaymentRequests(0, (Platform.currentTime.milliseconds + 1.minute).toSeconds) == Seq(i1)) + val now = Platform.currentTime.milliseconds.toSeconds + assert(preMigrationDb.listIncomingPayments(0, now) == Seq(pr1)) + assert(preMigrationDb.listOutgoingPayments(0, now) == Seq(ps1)) + assert(preMigrationDb.listPaymentRequests(0, now) == Seq(i1)) + + val postMigrationDb = new SqlitePaymentsDb(connection) + + using(connection.createStatement()) { statement => + assert(getVersion(statement, "payments", 3) == 3) // version still to 3 + } + + assert(postMigrationDb.listIncomingPayments(0, now) == Seq(pr1)) + assert(postMigrationDb.listOutgoingPayments(0, now) == Seq(ps1)) + assert(preMigrationDb.listPaymentRequests(0, now) == Seq(i1)) + } + + test("handle version migration 2->3") { + val connection = TestConstants.sqliteInMemory() + + using(connection.createStatement()) { statement => + getVersion(statement, "payments", 2) + statement.executeUpdate("CREATE TABLE IF NOT EXISTS received_payments (payment_hash BLOB NOT NULL PRIMARY KEY, preimage BLOB NOT NULL, payment_request TEXT NOT NULL, received_msat INTEGER, created_at INTEGER NOT NULL, expire_at INTEGER, received_at INTEGER)") + statement.executeUpdate("CREATE TABLE IF NOT EXISTS sent_payments (id TEXT NOT NULL PRIMARY KEY, payment_hash BLOB NOT NULL, preimage BLOB, amount_msat INTEGER NOT NULL, created_at INTEGER NOT NULL, completed_at INTEGER, status VARCHAR NOT NULL)") + statement.executeUpdate("CREATE INDEX IF NOT EXISTS payment_hash_idx ON sent_payments(payment_hash)") + } + + using(connection.createStatement()) { statement => + assert(getVersion(statement, "payments", 2) == 2) // version 2 is deployed now + } + + // Insert a bunch of old version 2 rows. + val ps1 = OutgoingPayment(UUID.randomUUID(), None, None, randomBytes32, 561 msat, PrivateKey(ByteVector32.One).publicKey, 0, PENDING, None) + val ps2 = OutgoingPayment(UUID.randomUUID(), None, None, randomBytes32, 1105 msat, PrivateKey(ByteVector32.One).publicKey, 1, FAILED, None, Some(2), None, Some(PaymentFailureSummary(Nil))) + val ps3 = OutgoingPayment(UUID.randomUUID(), None, None, defaultPaymentHash, 1729 msat, PrivateKey(ByteVector32.One).publicKey, 4, SUCCEEDED, None, Some(5), Some(PaymentSuccessSummary(defaultPreimage, 0 msat, Nil))) + val i1 = PaymentRequest.read("lnbc10u1pw2t4phpp5ezwm2gdccydhnphfyepklc0wjkxhz0r4tctg9paunh2lxgeqhcmsdqlxycrqvpqwdshgueqvfjhggr0dcsry7qcqzpgfa4ecv7447p9t5hkujy9qgrxvkkf396p9zar9p87rv2htmeuunkhydl40r64n5s2k0u7uelzc8twxmp37nkcch6m0wg5tvvx69yjz8qpk94qf3") + val pr1 = IncomingPayment(i1.paymentHash, 12345678 msat, 9) + val i2 = PaymentRequest.read("lnbc1pvjluezpp5qqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqypqdpl2pkx2ctnv5sxxmmwwd5kgetjypeh2ursdae8g6twvus8g6rfwvs8qun0dfjkxaq8rkx3yf5tcsyz3d73gafnh3cax9rn449d9p5uxz9ezhhypd0elx87sjle52x86fux2ypatgddc6k63n7erqz25le42c4u4ecky03ylcqca784w") + + // Changes between version 2 and 3 to sent_payments: + // - added optional payment failures + // - added optional payment success details (fees paid and route) + // - added optional payment request + // - added target node ID + // - added externalID and parentID + + using(connection.prepareStatement("INSERT INTO sent_payments VALUES (?, ?, NULL, ?, ?, NULL, ?)")) { statement => + statement.setString(1, ps1.id.toString) + statement.setBytes(2, ps1.paymentHash.toArray) + statement.setLong(3, ps1.amount.toLong) + statement.setLong(4, ps1.createdAt) + statement.setString(5, ps1.status.toString) + statement.executeUpdate() + } + + for (ps <- Seq(ps2, ps3)) { + using(connection.prepareStatement("INSERT INTO sent_payments VALUES (?, ?, ?, ?, ?, ?, ?)")) { statement => + statement.setString(1, ps.id.toString) + statement.setBytes(2, ps.paymentHash.toArray) + statement.setBytes(3, ps.successSummary.map(_.paymentPreimage.toArray).orNull) + statement.setLong(4, ps.amount.toLong) + statement.setLong(5, ps.createdAt) + statement.setLong(6, ps.completedAt.get) + statement.setString(7, ps.status.toString) + statement.executeUpdate() + } + } + + using(connection.prepareStatement("INSERT INTO received_payments VALUES (?, ?, ?, ?, ?, NULL, ?)")) { statement => + statement.setBytes(1, i1.paymentHash.toArray) + statement.setBytes(2, defaultPreimage.toArray) + statement.setString(3, PaymentRequest.write(i1)) + statement.setLong(4, pr1.amount.toLong) + statement.setLong(5, 0) // created_at + statement.setLong(6, pr1.receivedAt) + statement.executeUpdate() + } + + using(connection.prepareStatement("INSERT INTO received_payments VALUES (?, ?, ?, NULL, ?, NULL, NULL)")) { statement => + statement.setBytes(1, i2.paymentHash.toArray) + statement.setBytes(2, defaultPreimage.toArray) + statement.setString(3, PaymentRequest.write(i2)) + statement.setLong(4, 0) // created_at + statement.executeUpdate() + } + + val preMigrationDb = new SqlitePaymentsDb(connection) + + using(connection.createStatement()) { statement => + assert(getVersion(statement, "payments", 2) == 3) // version changed from 2 -> 3 + } + + assert(preMigrationDb.getPaymentRequest(i1.paymentHash) === Some(i1)) + assert(preMigrationDb.getPaymentRequest(i2.paymentHash) === Some(i2)) + assert(preMigrationDb.getIncomingPayment(i1.paymentHash) === Some(pr1)) + assert(preMigrationDb.getIncomingPayment(i2.paymentHash) === None) + assert(preMigrationDb.listOutgoingPayments(0, 5).toSet === Set(ps1, ps2, ps3)) val postMigrationDb = new SqlitePaymentsDb(connection) using(connection.createStatement()) { statement => - assert(getVersion(statement, "payments", 2) == 2) // version still to 2 + assert(getVersion(statement, "payments", 3) == 3) // version still to 3 } - assert(postMigrationDb.listIncomingPayments() == Seq(pr1)) - assert(postMigrationDb.listOutgoingPayments() == Seq(ps1)) - assert(preMigrationDb.listPaymentRequests(0, (Platform.currentTime.milliseconds + 1.minute).toSeconds) == Seq(i1)) + val i3 = PaymentRequest.read("lnbc1500n1pdl686hpp5y7mz3lgvrfccqnk9es6trumjgqdpjwcecycpkdggnx7h6cuup90sdpa2fjkzep6ypqkymm4wssycnjzf9rjqurjda4x2cm5ypskuepqv93x7at5ypek7cqzysxqr23s5e864m06fcfp3axsefy276d77tzp0xzzzdfl6p46wvstkeqhu50khm9yxea2d9efp7lvthrta0ktmhsv52hf3tvxm0unsauhmfmp27cqqx4xxe") + postMigrationDb.addPaymentRequest(i3, randomBytes32) + + val ps4 = OutgoingPayment(UUID.randomUUID(), Some(UUID.randomUUID()), Some(UUID.randomUUID()), randomBytes32, 123 msat, alice, 6, PENDING, Some(i3)) + val ps5 = OutgoingPayment(UUID.randomUUID(), Some(UUID.randomUUID()), Some(UUID.randomUUID()), randomBytes32, 456 msat, bob, 7, SUCCEEDED, Some(i2), Some(8), Some(PaymentSuccessSummary(randomBytes32, 42 msat, Nil))) + val ps6 = OutgoingPayment(UUID.randomUUID(), Some(UUID.randomUUID()), Some(UUID.randomUUID()), randomBytes32, 789 msat, bob, 8, FAILED, None, Some(9), None, Some(PaymentFailureSummary(Nil))) + postMigrationDb.addOutgoingPayment(ps4) + postMigrationDb.addOutgoingPayment(ps5) + postMigrationDb.updateOutgoingPayment(PaymentSent(ps5.id, ps5.amount, ps5.successSummary.get.feesPaid, ps5.paymentHash, ps5.successSummary.get.paymentPreimage, Nil, ps5.completedAt.get)) + postMigrationDb.addOutgoingPayment(ps6) + postMigrationDb.updateOutgoingPayment(PaymentFailed(ps6.id, ps6.paymentHash, Nil, ps6.completedAt.get)) + + assert(postMigrationDb.listOutgoingPayments(0, 10).toSet === Set(ps1, ps2, ps3, ps4, ps5, ps6)) + assert(postMigrationDb.listIncomingPayments(0, 10).toSet === Set(pr1)) + assert(postMigrationDb.getPaymentRequest(i1.paymentHash) === Some(i1)) + assert(postMigrationDb.getPaymentRequest(i2.paymentHash) === Some(i2)) + assert(postMigrationDb.getPaymentRequest(i3.paymentHash) === Some(i3)) } test("add/list received payments/find 1 payment that exists/find 1 payment that does not exist") { @@ -100,69 +212,69 @@ class SqlitePaymentsDbSpec extends FunSuite { val db = new SqlitePaymentsDb(sqlite) // can't receive a payment without an invoice associated with it - assertThrows[IllegalArgumentException](db.addIncomingPayment(IncomingPayment(ByteVector32(hex"6e7e8018f05e169cf1d99e77dc22cb372d09f10b6a81f1eae410718c56cad188"), 12345678 msat, 1513871928275L))) + assertThrows[IllegalArgumentException](db.addIncomingPayment(IncomingPayment(randomBytes32, 12345678 msat, 1513871928275L))) val i1 = PaymentRequest.read("lnbc5450n1pw2t4qdpp5vcrf6ylgpettyng4ac3vujsk0zpc25cj0q3zp7l7w44zvxmpzh8qdzz2pshjmt9de6zqen0wgsr2dp4ypcxj7r9d3ejqct5ypekzar0wd5xjuewwpkxzcm99cxqzjccqp2rzjqtspxelp67qc5l56p6999wkatsexzhs826xmupyhk6j8lxl038t27z9tsqqqgpgqqqqqqqlgqqqqqzsqpcz8z8hmy8g3ecunle4n3edn3zg2rly8g4klsk5md736vaqqy3ktxs30ht34rkfkqaffzxmjphvd0637dk2lp6skah2hq09z6lrjna3xqp3d4vyd") val i2 = PaymentRequest.read("lnbc10u1pw2t4phpp5ezwm2gdccydhnphfyepklc0wjkxhz0r4tctg9paunh2lxgeqhcmsdqlxycrqvpqwdshgueqvfjhggr0dcsry7qcqzpgfa4ecv7447p9t5hkujy9qgrxvkkf396p9zar9p87rv2htmeuunkhydl40r64n5s2k0u7uelzc8twxmp37nkcch6m0wg5tvvx69yjz8qpk94qf3") db.addPaymentRequest(i1, ByteVector32.Zeroes) - db.addPaymentRequest(i2, ByteVector32.Zeroes) + db.addPaymentRequest(i2, ByteVector32.One) val p1 = IncomingPayment(i1.paymentHash, 12345678 msat, 1513871928275L) val p2 = IncomingPayment(i2.paymentHash, 12345678 msat, 1513871928275L) - assert(db.listIncomingPayments() === Nil) + assert(db.listIncomingPayments(0, Platform.currentTime.milliseconds.toSeconds) === Nil) db.addIncomingPayment(p1) db.addIncomingPayment(p2) - assert(db.listIncomingPayments().toList === List(p1, p2)) + assert(db.listIncomingPayments(0, Platform.currentTime.milliseconds.toSeconds).toList === List(p1, p2)) assert(db.getIncomingPayment(p1.paymentHash) === Some(p1)) - assert(db.getIncomingPayment(ByteVector32(hex"6e7e8018f05e169cf1d99e77dc22cb372d09f10b6a81f1eae410718c56cad187")) === None) + assert(db.getIncomingPayment(randomBytes32) === None) } test("add/retrieve/update sent payments") { - val db = new SqlitePaymentsDb(TestConstants.sqliteInMemory()) - val s1 = OutgoingPayment(id = UUID.randomUUID(), paymentHash = ByteVector32(hex"0f059ef9b55bb70cc09069ee4df854bf0fab650eee6f2b87ba26d1ad08ab114f"), None, amount = 12345 msat, createdAt = 12345, None, PENDING) - val s2 = OutgoingPayment(id = UUID.randomUUID(), paymentHash = ByteVector32(hex"08d47d5f7164d4b696e8f6b62a03094d4f1c65f16e9d7b11c4a98854707e55cf"), None, amount = 12345 msat, createdAt = 12345, None, PENDING) + val i1 = PaymentRequest(chainHash = Block.TestnetGenesisBlock.hash, amount = Some(123 msat), paymentHash = randomBytes32, privateKey = davePriv, description = "Some invoice", expirySeconds = None, timestamp = 1) + val s1 = OutgoingPayment(UUID.randomUUID(), Some(UUID.randomUUID()), None, i1.paymentHash, 123 msat, alice, 1, PENDING, Some(i1)) + val s2 = OutgoingPayment(UUID.randomUUID(), None, Some(UUID.randomUUID()), randomBytes32, 456 msat, bob, 2, PENDING, None) - assert(db.listOutgoingPayments().isEmpty) + assert(db.listOutgoingPayments(0, Platform.currentTime.milliseconds.toSeconds).isEmpty) db.addOutgoingPayment(s1) db.addOutgoingPayment(s2) - assert(db.listOutgoingPayments().toList == Seq(s1, s2)) + assert(db.listOutgoingPayments(0, Platform.currentTime.milliseconds.toSeconds).toList == Seq(s1, s2)) assert(db.getOutgoingPayment(s1.id) === Some(s1)) - assert(db.getOutgoingPayment(s1.id).get.completedAt.isEmpty) assert(db.getOutgoingPayment(UUID.randomUUID()) === None) assert(db.getOutgoingPayments(s2.paymentHash) === Seq(s2)) - assert(db.getOutgoingPayments(ByteVector32.Zeroes) === Seq.empty) + assert(db.getOutgoingPayments(ByteVector32.Zeroes) === Nil) - val s3 = s2.copy(id = UUID.randomUUID(), amount = 88776655 msat) + val s3 = s2.copy(id = UUID.randomUUID(), amount = 789 msat) + val s4 = s2.copy(id = UUID.randomUUID()) db.addOutgoingPayment(s3) + db.addOutgoingPayment(s4) - db.updateOutgoingPayment(s3.id, FAILED) - assert(db.getOutgoingPayment(s3.id).get.status == FAILED) - assert(db.getOutgoingPayment(s3.id).get.preimage.isEmpty) // failed sent payments don't have a preimage - assert(db.getOutgoingPayment(s3.id).get.completedAt.isDefined) + db.updateOutgoingPayment(PaymentFailed(s3.id, s3.paymentHash, Nil, 10)) + assert(db.getOutgoingPayment(s3.id) === Some(s3.copy(status = FAILED, completedAt = Some(10), failureSummary = Some(PaymentFailureSummary(Nil))))) + db.updateOutgoingPayment(PaymentFailed(s4.id, s4.paymentHash, Seq(LocalFailure(new RuntimeException("woops")), RemoteFailure(Seq(hop_ab, hop_bc), Sphinx.DecryptedFailurePacket(carol, UnknownNextPeer))), 11)) + assert(db.getOutgoingPayment(s4.id) === Some(s4.copy(status = FAILED, completedAt = Some(11), failureSummary = Some(PaymentFailureSummary(Seq(FailureSummary(FailureType.LOCAL, "woops", Nil), FailureSummary(FailureType.REMOTE, "processing node does not know the next peer in the route", List(HopSummary(alice, bob, Some(ShortChannelId(42))), HopSummary(bob, carol, Some(ShortChannelId(43))))))))))) // can't update again once it's in a final state - assertThrows[IllegalArgumentException](db.updateOutgoingPayment(s3.id, SUCCEEDED)) + assertThrows[IllegalArgumentException](db.updateOutgoingPayment(PaymentSent(s3.id, s3.amount, 42 msat, s3.paymentHash, defaultPreimage, Nil))) + + db.updateOutgoingPayment(PaymentSent(s1.id, s1.amount, 15 msat, s1.paymentHash, defaultPreimage, Nil, 10)) + assert(db.getOutgoingPayment(s1.id) === Some(s1.copy(status = SUCCEEDED, completedAt = Some(10), successSummary = Some(PaymentSuccessSummary(defaultPreimage, 15 msat, Nil))))) + db.updateOutgoingPayment(PaymentSent(s2.id, s2.amount, 15 msat, s2.paymentHash, defaultPreimage, Seq(hop_ab, hop_bc), 11)) + assert(db.getOutgoingPayment(s2.id) === Some(s2.copy(status = SUCCEEDED, completedAt = Some(11), successSummary = Some(PaymentSuccessSummary(defaultPreimage, 15 msat, Seq(HopSummary(alice, bob, Some(ShortChannelId(42))), HopSummary(bob, carol, Some(ShortChannelId(43))))))))) - db.updateOutgoingPayment(s1.id, SUCCEEDED, Some(ByteVector32.One)) - assert(db.getOutgoingPayment(s1.id).get.preimage.isDefined) - assert(db.getOutgoingPayment(s1.id).get.completedAt.isDefined) + // can't update again once it's in a final state + assertThrows[IllegalArgumentException](db.updateOutgoingPayment(PaymentFailed(s1.id, s1.paymentHash, Nil))) } test("add/retrieve payment requests") { - - val someTimestamp = 12345 val db = new SqlitePaymentsDb(TestConstants.sqliteInMemory()) - - val bob = Bob.keyManager - + val someTimestamp = 12345 val (paymentHash1, paymentHash2) = (randomBytes32, randomBytes32) - - val i1 = PaymentRequest(chainHash = Block.TestnetGenesisBlock.hash, amount = Some(123 msat), paymentHash = paymentHash1, privateKey = bob.nodeKey.privateKey, description = "Some invoice", expirySeconds = None, timestamp = someTimestamp) - val i2 = PaymentRequest(chainHash = Block.TestnetGenesisBlock.hash, amount = None, paymentHash = paymentHash2, privateKey = bob.nodeKey.privateKey, description = "Some invoice", expirySeconds = Some(123456), timestamp = Platform.currentTime.milliseconds.toSeconds) + val i1 = PaymentRequest(chainHash = Block.TestnetGenesisBlock.hash, amount = Some(123 msat), paymentHash = paymentHash1, privateKey = bobPriv, description = "Some invoice", expirySeconds = None, timestamp = someTimestamp) + val i2 = PaymentRequest(chainHash = Block.TestnetGenesisBlock.hash, amount = None, paymentHash = paymentHash2, privateKey = bobPriv, description = "Some invoice", expirySeconds = Some(123456), timestamp = Platform.currentTime.milliseconds.toSeconds) // i2 doesn't expire assert(i1.expiry.isEmpty && i2.expiry.isDefined) @@ -173,16 +285,28 @@ class SqlitePaymentsDbSpec extends FunSuite { // order matters, i2 has a more recent timestamp than i1 assert(db.listPaymentRequests(0, (Platform.currentTime.milliseconds + 1.minute).toSeconds) == Seq(i2, i1)) - assert(db.getPaymentRequest(i1.paymentHash) == Some(i1)) - assert(db.getPaymentRequest(i2.paymentHash) == Some(i2)) + assert(db.getPaymentRequest(i1.paymentHash) === Some(i1)) + assert(db.getPaymentRequest(i2.paymentHash) === Some(i2)) assert(db.listPendingPaymentRequests(0, (Platform.currentTime.milliseconds + 1.minute).toSeconds) == Seq(i2, i1)) - assert(db.getPendingPaymentRequestAndPreimage(paymentHash1) == Some((ByteVector32.Zeroes, i1))) - assert(db.getPendingPaymentRequestAndPreimage(paymentHash2) == Some((ByteVector32.One, i2))) + assert(db.getPendingPaymentRequestAndPreimage(paymentHash1) === Some((ByteVector32.Zeroes, i1))) + assert(db.getPendingPaymentRequestAndPreimage(paymentHash2) === Some((ByteVector32.One, i2))) val from = (someTimestamp - 100).seconds.toSeconds val to = (someTimestamp + 100).seconds.toSeconds assert(db.listPaymentRequests(from, to) == Seq(i1)) + + db.addIncomingPayment(IncomingPayment(i2.paymentHash, 42 msat, someTimestamp)) + assert(db.listPendingPaymentRequests(0, (Platform.currentTime.milliseconds + 1.minute).toSeconds) == Seq(i1)) } } + +object SqlitePaymentsDbSpec { + val (alicePriv, bobPriv, carolPriv, davePriv) = (randomKey, randomKey, randomKey, randomKey) + val (alice, bob, carol, dave) = (alicePriv.publicKey, bobPriv.publicKey, carolPriv.publicKey, davePriv.publicKey) + val hop_ab = Hop(alice, bob, ChannelUpdate(randomBytes64, randomBytes32, ShortChannelId(42), 1, 0, 0, CltvExpiryDelta(12), 1 msat, 1 msat, 1, None)) + val hop_bc = Hop(bob, carol, ChannelUpdate(randomBytes64, randomBytes32, ShortChannelId(43), 1, 0, 0, CltvExpiryDelta(12), 1 msat, 1 msat, 1, None)) + val defaultPreimage = randomBytes32 + val defaultPaymentHash = Crypto.sha256(defaultPreimage) +} diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala index 4e7c846457..b9ab59418e 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala @@ -28,8 +28,9 @@ import fr.acinq.eclair.blockchain.{UtxoStatus, ValidateRequest, ValidateResult, import fr.acinq.eclair.channel.Register.ForwardShortId import fr.acinq.eclair.channel.{AddHtlcFailed, Channel, ChannelUnavailable} import fr.acinq.eclair.crypto.Sphinx -import fr.acinq.eclair.db.OutgoingPaymentStatus +import fr.acinq.eclair.db.{OutgoingPayment, OutgoingPaymentStatus} import fr.acinq.eclair.io.Peer.PeerRoutingMessage +import fr.acinq.eclair.payment.PaymentInitiator.SendPaymentRequest import fr.acinq.eclair.payment.PaymentLifecycle._ import fr.acinq.eclair.router.Announcements.{makeChannelUpdate, makeNodeAnnouncement} import fr.acinq.eclair.router._ @@ -46,14 +47,16 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val defaultAmountMsat = 142000000 msat val defaultExpiryDelta = Channel.MIN_CLTV_EXPIRY_DELTA + val defaultPaymentHash = randomBytes32 + val defaultExternalId = UUID.randomUUID() + val defaultPaymentRequest = SendPaymentRequest(defaultAmountMsat, defaultPaymentHash, d, 1, externalId = Some(defaultExternalId)) test("send to route") { fixture => import fixture._ - val defaultPaymentHash = randomBytes32 val nodeParams = TestConstants.Alice.nodeParams.copy(keyManager = testKeyManager) val id = UUID.randomUUID() val paymentDb = nodeParams.db.payments - val progressHandler = PaymentLifecycle.DefaultPaymentProgressHandler(id, paymentDb) + val progressHandler = PaymentLifecycle.DefaultPaymentProgressHandler(id, defaultPaymentRequest, paymentDb) val paymentFSM = system.actorOf(PaymentLifecycle.props(nodeParams, progressHandler, router, TestProbe().ref)) val monitor = TestProbe() val sender = TestProbe() @@ -70,6 +73,8 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val Transition(_, WAITING_FOR_REQUEST, WAITING_FOR_ROUTE) = monitor.expectMsgClass(classOf[Transition[_]]) val Transition(_, WAITING_FOR_ROUTE, WAITING_FOR_PAYMENT_COMPLETE) = monitor.expectMsgClass(classOf[Transition[_]]) awaitCond(paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.PENDING)) + val Some(outgoing) = paymentDb.getOutgoingPayment(id) + assert(outgoing.copy(createdAt = 0) === OutgoingPayment(id, None, Some(defaultExternalId), defaultPaymentHash, defaultAmountMsat, d, 0, OutgoingPaymentStatus.PENDING, None)) sender.send(paymentFSM, UpdateFulfillHtlc(ByteVector32.Zeroes, 0, defaultPaymentHash)) sender.expectMsgType[PaymentSent] @@ -78,11 +83,10 @@ class PaymentLifecycleSpec extends BaseRouterSpec { test("payment failed (route not found)") { fixture => import fixture._ - val defaultPaymentHash = randomBytes32 val nodeParams = TestConstants.Alice.nodeParams.copy(keyManager = testKeyManager) val paymentDb = nodeParams.db.payments val id = UUID.randomUUID() - val progressHandler = PaymentLifecycle.DefaultPaymentProgressHandler(id, paymentDb) + val progressHandler = PaymentLifecycle.DefaultPaymentProgressHandler(id, defaultPaymentRequest.copy(targetNodeId = f), paymentDb) val routerForwarder = TestProbe() val paymentFSM = system.actorOf(PaymentLifecycle.props(nodeParams, progressHandler, routerForwarder.ref, TestProbe().ref)) val monitor = TestProbe() @@ -107,7 +111,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val nodeParams = TestConstants.Alice.nodeParams.copy(keyManager = testKeyManager) val paymentDb = nodeParams.db.payments val id = UUID.randomUUID() - val progressHandler = PaymentLifecycle.DefaultPaymentProgressHandler(id, paymentDb) + val progressHandler = PaymentLifecycle.DefaultPaymentProgressHandler(id, defaultPaymentRequest, paymentDb) val paymentFSM = system.actorOf(PaymentLifecycle.props(nodeParams, progressHandler, router, TestProbe().ref)) val monitor = TestProbe() val sender = TestProbe() @@ -115,7 +119,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { paymentFSM ! SubscribeTransitionCallBack(monitor.ref) val CurrentState(_, WAITING_FOR_REQUEST) = monitor.expectMsgClass(classOf[CurrentState[_]]) - val request = SendPayment(randomBytes32, d, FinalLegacyPayload(defaultAmountMsat, defaultExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight)), maxAttempts = 5, routeParams = Some(RouteParams(randomize = false, maxFeeBase = 100 msat, maxFeePct = 0.0, routeMaxLength = 20, routeMaxCltv = CltvExpiryDelta(2016), ratios = None))) + val request = SendPayment(defaultPaymentHash, d, FinalLegacyPayload(defaultAmountMsat, defaultExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight)), maxAttempts = 5, routeParams = Some(RouteParams(randomize = false, maxFeeBase = 100 msat, maxFeePct = 0.0, routeMaxLength = 20, routeMaxCltv = CltvExpiryDelta(2016), ratios = None))) sender.send(paymentFSM, request) val Transition(_, WAITING_FOR_REQUEST, WAITING_FOR_ROUTE) = monitor.expectMsgClass(classOf[Transition[_]]) @@ -125,13 +129,12 @@ class PaymentLifecycleSpec extends BaseRouterSpec { test("payment failed (unparsable failure)") { fixture => import fixture._ - val defaultPaymentHash = randomBytes32 val nodeParams = TestConstants.Alice.nodeParams.copy(keyManager = testKeyManager) val paymentDb = nodeParams.db.payments val relayer = TestProbe() val routerForwarder = TestProbe() val id = UUID.randomUUID() - val progressHandler = PaymentLifecycle.DefaultPaymentProgressHandler(id, paymentDb) + val progressHandler = PaymentLifecycle.DefaultPaymentProgressHandler(id, defaultPaymentRequest, paymentDb) val paymentFSM = TestFSMRef(new PaymentLifecycle(nodeParams, progressHandler, routerForwarder.ref, relayer.ref)) val monitor = TestProbe() val sender = TestProbe() @@ -175,7 +178,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val relayer = TestProbe() val routerForwarder = TestProbe() val id = UUID.randomUUID() - val progressHandler = PaymentLifecycle.DefaultPaymentProgressHandler(id, paymentDb) + val progressHandler = PaymentLifecycle.DefaultPaymentProgressHandler(id, defaultPaymentRequest, paymentDb) val paymentFSM = TestFSMRef(new PaymentLifecycle(nodeParams, progressHandler, routerForwarder.ref, relayer.ref)) val monitor = TestProbe() val sender = TestProbe() @@ -183,7 +186,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { paymentFSM ! SubscribeTransitionCallBack(monitor.ref) val CurrentState(_, WAITING_FOR_REQUEST) = monitor.expectMsgClass(classOf[CurrentState[_]]) - val request = SendPayment(randomBytes32, d, FinalLegacyPayload(defaultAmountMsat, defaultExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight)), maxAttempts = 2) + val request = SendPayment(defaultPaymentHash, d, FinalLegacyPayload(defaultAmountMsat, defaultExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight)), maxAttempts = 2) sender.send(paymentFSM, request) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.PENDING)) @@ -203,13 +206,12 @@ class PaymentLifecycleSpec extends BaseRouterSpec { test("payment failed (first hop returns an UpdateFailMalformedHtlc)") { fixture => import fixture._ - val defaultPaymentHash = randomBytes32 val nodeParams = TestConstants.Alice.nodeParams.copy(keyManager = testKeyManager) val paymentDb = nodeParams.db.payments val relayer = TestProbe() val routerForwarder = TestProbe() val id = UUID.randomUUID() - val progressHandler = PaymentLifecycle.DefaultPaymentProgressHandler(id, paymentDb) + val progressHandler = PaymentLifecycle.DefaultPaymentProgressHandler(id, defaultPaymentRequest, paymentDb) val paymentFSM = TestFSMRef(new PaymentLifecycle(nodeParams, progressHandler, routerForwarder.ref, relayer.ref)) val monitor = TestProbe() val sender = TestProbe() @@ -241,7 +243,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val relayer = TestProbe() val routerForwarder = TestProbe() val id = UUID.randomUUID() - val progressHandler = PaymentLifecycle.DefaultPaymentProgressHandler(id, nodeParams.db.payments) + val progressHandler = PaymentLifecycle.DefaultPaymentProgressHandler(id, defaultPaymentRequest, nodeParams.db.payments) val paymentFSM = TestFSMRef(new PaymentLifecycle(nodeParams, progressHandler, routerForwarder.ref, relayer.ref)) val monitor = TestProbe() val sender = TestProbe() @@ -249,7 +251,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { paymentFSM ! SubscribeTransitionCallBack(monitor.ref) val CurrentState(_, WAITING_FOR_REQUEST) = monitor.expectMsgClass(classOf[CurrentState[_]]) - val request = SendPayment(randomBytes32, d, FinalLegacyPayload(defaultAmountMsat, defaultExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight)), maxAttempts = 2) + val request = SendPayment(defaultPaymentHash, d, FinalLegacyPayload(defaultAmountMsat, defaultExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight)), maxAttempts = 2) sender.send(paymentFSM, request) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE) val WaitingForRoute(_, _, Nil) = paymentFSM.stateData @@ -282,7 +284,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val relayer = TestProbe() val routerForwarder = TestProbe() val id = UUID.randomUUID() - val progressHandler = PaymentLifecycle.DefaultPaymentProgressHandler(id, paymentDb) + val progressHandler = PaymentLifecycle.DefaultPaymentProgressHandler(id, defaultPaymentRequest, paymentDb) val paymentFSM = TestFSMRef(new PaymentLifecycle(nodeParams, progressHandler, routerForwarder.ref, relayer.ref)) val monitor = TestProbe() val sender = TestProbe() @@ -290,7 +292,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { paymentFSM ! SubscribeTransitionCallBack(monitor.ref) val CurrentState(_, WAITING_FOR_REQUEST) = monitor.expectMsgClass(classOf[CurrentState[_]]) - val request = SendPayment(randomBytes32, d, FinalLegacyPayload(defaultAmountMsat, defaultExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight)), maxAttempts = 5) + val request = SendPayment(defaultPaymentHash, d, FinalLegacyPayload(defaultAmountMsat, defaultExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight)), maxAttempts = 5) sender.send(paymentFSM, request) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.PENDING)) @@ -345,7 +347,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val relayer = TestProbe() val routerForwarder = TestProbe() val id = UUID.randomUUID() - val progressHandler = PaymentLifecycle.DefaultPaymentProgressHandler(id, paymentDb) + val progressHandler = PaymentLifecycle.DefaultPaymentProgressHandler(id, defaultPaymentRequest, paymentDb) val paymentFSM = TestFSMRef(new PaymentLifecycle(nodeParams, progressHandler, routerForwarder.ref, relayer.ref)) val monitor = TestProbe() val sender = TestProbe() @@ -353,7 +355,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { paymentFSM ! SubscribeTransitionCallBack(monitor.ref) val CurrentState(_, WAITING_FOR_REQUEST) = monitor.expectMsgClass(classOf[CurrentState[_]]) - val request = SendPayment(randomBytes32, d, FinalLegacyPayload(defaultAmountMsat, defaultExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight)), maxAttempts = 2) + val request = SendPayment(defaultPaymentHash, d, FinalLegacyPayload(defaultAmountMsat, defaultExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight)), maxAttempts = 2) sender.send(paymentFSM, request) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.PENDING)) @@ -393,7 +395,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val nodeParams = TestConstants.Alice.nodeParams.copy(keyManager = testKeyManager) val paymentDb = nodeParams.db.payments val id = UUID.randomUUID() - val progressHandler = PaymentLifecycle.DefaultPaymentProgressHandler(id, paymentDb) + val progressHandler = PaymentLifecycle.DefaultPaymentProgressHandler(id, defaultPaymentRequest.copy(paymentHash = paymentHash), paymentDb) val paymentFSM = system.actorOf(PaymentLifecycle.props(nodeParams, progressHandler, router, TestProbe().ref)) val monitor = TestProbe() val sender = TestProbe() @@ -408,6 +410,8 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val Transition(_, WAITING_FOR_REQUEST, WAITING_FOR_ROUTE) = monitor.expectMsgClass(classOf[Transition[_]]) val Transition(_, WAITING_FOR_ROUTE, WAITING_FOR_PAYMENT_COMPLETE) = monitor.expectMsgClass(classOf[Transition[_]]) awaitCond(paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.PENDING)) + val Some(outgoing) = paymentDb.getOutgoingPayment(id) + assert(outgoing.copy(createdAt = 0) === OutgoingPayment(id, None, Some(defaultExternalId), paymentHash, defaultAmountMsat, d, 0, OutgoingPaymentStatus.PENDING, None)) sender.send(paymentFSM, UpdateFulfillHtlc(ByteVector32.Zeroes, 0, paymentPreimage)) val ps = eventListener.expectMsgType[PaymentSent] @@ -421,7 +425,6 @@ class PaymentLifecycleSpec extends BaseRouterSpec { test("payment succeeded to a channel with fees=0") { fixture => import fixture._ import fr.acinq.eclair.randomKey - val defaultPaymentHash = randomBytes32 val nodeParams = TestConstants.Alice.nodeParams.copy(keyManager = testKeyManager) // the network will be a --(1)--> b ---(2)--> c --(3)--> d and e --(4)--> f (we are a) and b -> g has fees=0 // \ @@ -444,7 +447,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { watcher.expectMsgType[WatchSpentBasic] // actual test begins - val progressHandler = PaymentLifecycle.DefaultPaymentProgressHandler(UUID.randomUUID(), nodeParams.db.payments) + val progressHandler = PaymentLifecycle.DefaultPaymentProgressHandler(UUID.randomUUID(), defaultPaymentRequest.copy(targetNodeId = g), nodeParams.db.payments) val paymentFSM = system.actorOf(PaymentLifecycle.props(nodeParams, progressHandler, router, TestProbe().ref)) val monitor = TestProbe() val sender = TestProbe() @@ -479,4 +482,5 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val filtered = PaymentFailure.transformForUser(failures) assert(filtered == LocalFailure(RouteNotFound) :: RemoteFailure(Hop(a, b, channelUpdate_ab) :: Nil, Sphinx.DecryptedFailurePacket(a, TemporaryNodeFailure)) :: LocalFailure(ChannelUnavailable(ByteVector32.Zeroes)) :: Nil) } + } diff --git a/eclair-node/src/main/scala/fr/acinq/eclair/api/Service.scala b/eclair-node/src/main/scala/fr/acinq/eclair/api/Service.scala index e67797448f..6ade383e10 100644 --- a/eclair-node/src/main/scala/fr/acinq/eclair/api/Service.scala +++ b/eclair-node/src/main/scala/fr/acinq/eclair/api/Service.scala @@ -116,7 +116,7 @@ trait Service extends ExtraDirectives with Logging { .map(TextMessage.apply) } - val timeoutResponse: HttpRequest => HttpResponse = { r => + val timeoutResponse: HttpRequest => HttpResponse = { _ => HttpResponse(StatusCodes.RequestTimeout).withEntity(ContentTypes.`application/json`, serialization.writePretty(ErrorResponse("request timed out"))) } @@ -133,7 +133,7 @@ trait Service extends ExtraDirectives with Logging { toStrictEntity(paramParsingTimeout) { formFields("timeoutSeconds".as[Timeout].?) { tm_opt => // this is the akka timeout - implicit val timeout = tm_opt.getOrElse(Timeout(30 seconds)) + implicit val timeout: Timeout = tm_opt.getOrElse(Timeout(30 seconds)) // we ensure that http timeout is greater than akka timeout withRequestTimeout(timeout.duration + 2.seconds) { withRequestTimeoutResponse(timeoutResponse) { @@ -223,22 +223,24 @@ trait Service extends ExtraDirectives with Logging { } } ~ path("payinvoice") { - formFields(invoiceFormParam, amountMsatFormParam.?, "maxAttempts".as[Int].?, "feeThresholdSat".as[Satoshi].?, "maxFeePct".as[Double].?) { - case (invoice@PaymentRequest(_, Some(amount), _, nodeId, _, _), None, maxAttempts, feeThresholdSat_opt, maxFeePct_opt) => - complete(eclairApi.send(nodeId, amount, invoice.paymentHash, Some(invoice), maxAttempts, feeThresholdSat_opt, maxFeePct_opt)) - case (invoice, Some(overrideAmount), maxAttempts, feeThresholdSat_opt, maxFeePct_opt) => - complete(eclairApi.send(invoice.nodeId, overrideAmount, invoice.paymentHash, Some(invoice), maxAttempts, feeThresholdSat_opt, maxFeePct_opt)) + formFields(invoiceFormParam, amountMsatFormParam.?, "maxAttempts".as[Int].?, "feeThresholdSat".as[Satoshi].?, "maxFeePct".as[Double].?, "externalId".as[UUID].?) { + case (invoice@PaymentRequest(_, Some(amount), _, nodeId, _, _), None, maxAttempts, feeThresholdSat_opt, maxFeePct_opt, externalId_opt) => + complete(eclairApi.send(externalId_opt, nodeId, amount, invoice.paymentHash, Some(invoice), maxAttempts, feeThresholdSat_opt, maxFeePct_opt)) + case (invoice, Some(overrideAmount), maxAttempts, feeThresholdSat_opt, maxFeePct_opt, externalId_opt) => + complete(eclairApi.send(externalId_opt, invoice.nodeId, overrideAmount, invoice.paymentHash, Some(invoice), maxAttempts, feeThresholdSat_opt, maxFeePct_opt)) case _ => reject(MalformedFormFieldRejection("invoice", "The invoice must have an amount or you need to specify one using the field 'amountMsat'")) } } ~ path("sendtonode") { - formFields(amountMsatFormParam, paymentHashFormParam, nodeIdFormParam, "maxAttempts".as[Int].?, "feeThresholdSat".as[Satoshi].?, "maxFeePct".as[Double].?) { (amountMsat, paymentHash, nodeId, maxAttempts_opt, feeThresholdSat_opt, maxFeePct_opt) => - complete(eclairApi.send(nodeId, amountMsat, paymentHash, maxAttempts_opt = maxAttempts_opt, feeThresholdSat_opt = feeThresholdSat_opt, maxFeePct_opt = maxFeePct_opt)) + formFields(amountMsatFormParam, paymentHashFormParam, nodeIdFormParam, "maxAttempts".as[Int].?, "feeThresholdSat".as[Satoshi].?, "maxFeePct".as[Double].?, "externalId".as[UUID].?) { + (amountMsat, paymentHash, nodeId, maxAttempts_opt, feeThresholdSat_opt, maxFeePct_opt, externalId_opt) => + complete(eclairApi.send(externalId_opt, nodeId, amountMsat, paymentHash, maxAttempts_opt = maxAttempts_opt, feeThresholdSat_opt = feeThresholdSat_opt, maxFeePct_opt = maxFeePct_opt)) } } ~ path("sendtoroute") { - formFields(amountMsatFormParam, paymentHashFormParam, "finalCltvExpiry".as[Int], "route".as[List[PublicKey]](pubkeyListUnmarshaller)) { (amountMsat, paymentHash, finalCltvExpiry, route) => - complete(eclairApi.sendToRoute(route, amountMsat, paymentHash, CltvExpiryDelta(finalCltvExpiry))) + formFields(amountMsatFormParam, paymentHashFormParam, "finalCltvExpiry".as[Int], "route".as[List[PublicKey]](pubkeyListUnmarshaller), "externalId".as[UUID].?) { + (amountMsat, paymentHash, finalCltvExpiry, route, externalId_opt) => + complete(eclairApi.sendToRoute(externalId_opt, route, amountMsat, paymentHash, CltvExpiryDelta(finalCltvExpiry))) } } ~ path("getsentinfo") { diff --git a/eclair-node/src/test/scala/fr/acinq/eclair/api/ApiServiceSpec.scala b/eclair-node/src/test/scala/fr/acinq/eclair/api/ApiServiceSpec.scala index 6bcccf64d8..fe69098ddf 100644 --- a/eclair-node/src/test/scala/fr/acinq/eclair/api/ApiServiceSpec.scala +++ b/eclair-node/src/test/scala/fr/acinq/eclair/api/ApiServiceSpec.scala @@ -32,8 +32,7 @@ import fr.acinq.bitcoin.Crypto.PublicKey import fr.acinq.eclair._ import fr.acinq.eclair.io.NodeURI import fr.acinq.eclair.io.Peer.PeerInfo -import fr.acinq.eclair.payment.PaymentFailed -import fr.acinq.eclair.payment._ +import fr.acinq.eclair.payment.{PaymentFailed, _} import fr.acinq.eclair.wire.NodeAddress import org.mockito.scalatest.IdiomaticMockito import org.scalatest.{FunSuite, Matchers} @@ -250,7 +249,7 @@ class ApiServiceSpec extends FunSuite with ScalatestRouteTest with IdiomaticMock test("'send' method should handle payment failures") { val eclair = mock[Eclair] - eclair.send(any, any, any, any, any, any, any)(any[Timeout]) returns Future.failed(new IllegalArgumentException("invoice has expired")) + eclair.send(any, any, any, any, any, any, any, any)(any[Timeout]) returns Future.failed(new IllegalArgumentException("invoice has expired")) val mockService = new MockService(eclair) val invoice = "lnbc12580n1pw2ywztpp554ganw404sh4yjkwnysgn3wjcxfcq7gtx53gxczkjr9nlpc3hzvqdq2wpskwctddyxqr4rqrzjqwryaup9lh50kkranzgcdnn2fgvx390wgj5jd07rwr3vxeje0glc7z9rtvqqwngqqqqqqqlgqqqqqeqqjqrrt8smgjvfj7sg38dwtr9kc9gg3era9k3t2hvq3cup0jvsrtrxuplevqgfhd3rzvhulgcxj97yjuj8gdx8mllwj4wzjd8gdjhpz3lpqqvk2plh" @@ -263,7 +262,7 @@ class ApiServiceSpec extends FunSuite with ScalatestRouteTest with IdiomaticMock assert(status == BadRequest) val resp = entityAs[ErrorResponse](Json4sSupport.unmarshaller, ClassTag(classOf[ErrorResponse])) assert(resp.error == "invoice has expired") - eclair.send(any, 1258000 msat, any, any, any, any, any)(any[Timeout]).wasCalled(once) + eclair.send(None, any, 1258000 msat, any, any, any, any, any)(any[Timeout]).wasCalled(once) } } @@ -271,7 +270,7 @@ class ApiServiceSpec extends FunSuite with ScalatestRouteTest with IdiomaticMock val invoice = "lnbc12580n1pw2ywztpp554ganw404sh4yjkwnysgn3wjcxfcq7gtx53gxczkjr9nlpc3hzvqdq2wpskwctddyxqr4rqrzjqwryaup9lh50kkranzgcdnn2fgvx390wgj5jd07rwr3vxeje0glc7z9rtvqqwngqqqqqqqlgqqqqqeqqjqrrt8smgjvfj7sg38dwtr9kc9gg3era9k3t2hvq3cup0jvsrtrxuplevqgfhd3rzvhulgcxj97yjuj8gdx8mllwj4wzjd8gdjhpz3lpqqvk2plh" val eclair = mock[Eclair] - eclair.send(any, any, any, any, any, any, any)(any[Timeout]) returns Future.successful(UUID.randomUUID()) + eclair.send(any, any, any, any, any, any, any, any)(any[Timeout]) returns Future.successful(UUID.randomUUID()) val mockService = new MockService(eclair) Post("/payinvoice", FormData("invoice" -> invoice).toEntity) ~> @@ -280,18 +279,18 @@ class ApiServiceSpec extends FunSuite with ScalatestRouteTest with IdiomaticMock check { assert(handled) assert(status == OK) - eclair.send(any, 1258000 msat, any, any, any, any, any)(any[Timeout]).wasCalled(once) + eclair.send(None, any, 1258000 msat, any, any, any, any, any)(any[Timeout]).wasCalled(once) } - Post("/payinvoice", FormData("invoice" -> invoice, "amountMsat" -> "123", "feeThresholdSat" -> "112233", "maxFeePct" -> "2.34").toEntity) ~> + val externalId = UUID.randomUUID() + Post("/payinvoice", FormData("invoice" -> invoice, "amountMsat" -> "123", "feeThresholdSat" -> "112233", "maxFeePct" -> "2.34", "externalId" -> externalId.toString).toEntity) ~> addCredentials(BasicHttpCredentials("", mockService.password)) ~> Route.seal(mockService.route) ~> check { assert(handled) assert(status == OK) - eclair.send(any, 123 msat, any, any, any, Some(112233 sat), Some(2.34))(any[Timeout]).wasCalled(once) + eclair.send(Some(externalId), any, 123 msat, any, any, any, Some(112233 sat), Some(2.34))(any[Timeout]).wasCalled(once) } - } test("'getreceivedinfo' method should respond HTTP 404 with a JSON encoded response if the element is not found") { @@ -314,15 +313,16 @@ class ApiServiceSpec extends FunSuite with ScalatestRouteTest with IdiomaticMock test("'sendtoroute' method should accept a both a json-encoded AND comma separaterd list of pubkeys") { val rawUUID = "487da196-a4dc-4b1e-92b4-3e5e905e9f3f" val paymentUUID = UUID.fromString(rawUUID) + val externalId = UUID.randomUUID() val expectedRoute = List(PublicKey(hex"0217eb8243c95f5a3b7d4c5682d10de354b7007eb59b6807ae407823963c7547a9"), PublicKey(hex"0242a4ae0c5bef18048fbecf995094b74bfb0f7391418d71ed394784373f41e4f3"), PublicKey(hex"026ac9fcd64fb1aa1c491fc490634dc33da41d4a17b554e0adf1b32fee88ee9f28")) val csvNodes = "0217eb8243c95f5a3b7d4c5682d10de354b7007eb59b6807ae407823963c7547a9, 0242a4ae0c5bef18048fbecf995094b74bfb0f7391418d71ed394784373f41e4f3, 026ac9fcd64fb1aa1c491fc490634dc33da41d4a17b554e0adf1b32fee88ee9f28" val jsonNodes = serialization.write(expectedRoute) val eclair = mock[Eclair] - eclair.sendToRoute(any[List[PublicKey]], any[MilliSatoshi], any[ByteVector32], any[CltvExpiryDelta])(any[Timeout]) returns Future.successful(paymentUUID) + eclair.sendToRoute(any[Option[UUID]], any[List[PublicKey]], any[MilliSatoshi], any[ByteVector32], any[CltvExpiryDelta])(any[Timeout]) returns Future.successful(paymentUUID) val mockService = new MockService(eclair) - Post("/sendtoroute", FormData("route" -> jsonNodes, "amountMsat" -> "1234", "paymentHash" -> ByteVector32.Zeroes.toHex, "finalCltvExpiry" -> "190").toEntity) ~> + Post("/sendtoroute", FormData("route" -> jsonNodes, "amountMsat" -> "1234", "paymentHash" -> ByteVector32.Zeroes.toHex, "finalCltvExpiry" -> "190", "externalId" -> externalId.toString).toEntity) ~> addCredentials(BasicHttpCredentials("", mockService.password)) ~> addHeader("Content-Type", "application/json") ~> Route.seal(mockService.route) ~> @@ -330,7 +330,7 @@ class ApiServiceSpec extends FunSuite with ScalatestRouteTest with IdiomaticMock assert(handled) assert(status == OK) assert(entityAs[String] == "\"" + rawUUID + "\"") - eclair.sendToRoute(expectedRoute, 1234 msat, ByteVector32.Zeroes, CltvExpiryDelta(190))(any[Timeout]).wasCalled(once) + eclair.sendToRoute(Some(externalId), expectedRoute, 1234 msat, ByteVector32.Zeroes, CltvExpiryDelta(190))(any[Timeout]).wasCalled(once) } // this test uses CSV encoded route @@ -342,7 +342,7 @@ class ApiServiceSpec extends FunSuite with ScalatestRouteTest with IdiomaticMock assert(handled) assert(status == OK) assert(entityAs[String] == "\"" + rawUUID + "\"") - eclair.sendToRoute(expectedRoute, 1234 msat, ByteVector32.One, CltvExpiryDelta(190))(any[Timeout]).wasCalled(once) + eclair.sendToRoute(None, expectedRoute, 1234 msat, ByteVector32.One, CltvExpiryDelta(190))(any[Timeout]).wasCalled(once) } } From 274bbbfd3fb60d23b6db532cdaed50189c1e1a20 Mon Sep 17 00:00:00 2001 From: Bastien Teinturier Date: Tue, 17 Sep 2019 17:02:10 +0200 Subject: [PATCH 04/14] PaymentLifecycle: rename onSuccess / onFailure --- .../eclair/payment/PaymentLifecycle.scala | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentLifecycle.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentLifecycle.scala index 65e7905a1f..20e605dd34 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentLifecycle.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentLifecycle.scala @@ -69,7 +69,7 @@ class PaymentLifecycle(nodeParams: NodeParams, progressHandler: PaymentProgressH goto(WAITING_FOR_PAYMENT_COMPLETE) using WaitingForComplete(s, c, cmd, failures, sharedSecrets, ignoreNodes, ignoreChannels, hops) case Event(Status.Failure(t), WaitingForRoute(s, c, failures)) => - progressHandler.onFail(s, PaymentFailed(id, c.paymentHash, failures :+ LocalFailure(t)))(context) + progressHandler.onFailure(s, PaymentFailed(id, c.paymentHash, failures :+ LocalFailure(t)))(context) stop(FSM.Normal) } @@ -77,7 +77,7 @@ class PaymentLifecycle(nodeParams: NodeParams, progressHandler: PaymentProgressH case Event("ok", _) => stay case Event(fulfill: UpdateFulfillHtlc, WaitingForComplete(s, c, cmd, _, _, _, _, route)) => - progressHandler.onSucceed(s, PaymentSent(id, c.finalPayload.amount, cmd.amount - c.finalPayload.amount, c.paymentHash, fulfill.paymentPreimage, route))(context) + progressHandler.onSuccess(s, PaymentSent(id, c.finalPayload.amount, cmd.amount - c.finalPayload.amount, c.paymentHash, fulfill.paymentPreimage, route))(context) stop(FSM.Normal) case Event(fail: UpdateFailHtlc, WaitingForComplete(s, c, _, failures, sharedSecrets, ignoreNodes, ignoreChannels, hops)) => @@ -85,7 +85,7 @@ class PaymentLifecycle(nodeParams: NodeParams, progressHandler: PaymentProgressH case Success(e@Sphinx.DecryptedFailurePacket(nodeId, failureMessage)) if nodeId == c.targetNodeId => // if destination node returns an error, we fail the payment immediately log.warning(s"received an error message from target nodeId=$nodeId, failing the payment (failure=$failureMessage)") - progressHandler.onFail(s, PaymentFailed(id, c.paymentHash, failures :+ RemoteFailure(hops, e)))(context) + progressHandler.onFailure(s, PaymentFailed(id, c.paymentHash, failures :+ RemoteFailure(hops, e)))(context) stop(FSM.Normal) case res if failures.size + 1 >= c.maxAttempts => // otherwise we never try more than maxAttempts, no matter the kind of error returned @@ -98,7 +98,7 @@ class PaymentLifecycle(nodeParams: NodeParams, progressHandler: PaymentProgressH UnreadableRemoteFailure(hops) } log.warning(s"too many failed attempts, failing the payment") - progressHandler.onFail(s, PaymentFailed(id, c.paymentHash, failures :+ failure))(context) + progressHandler.onFailure(s, PaymentFailed(id, c.paymentHash, failures :+ failure))(context) stop(FSM.Normal) case Failure(t) => log.warning(s"cannot parse returned error: ${t.getMessage}") @@ -163,7 +163,7 @@ class PaymentLifecycle(nodeParams: NodeParams, progressHandler: PaymentProgressH case Event(Status.Failure(t), WaitingForComplete(s, c, _, failures, _, ignoreNodes, ignoreChannels, hops)) => if (failures.size + 1 >= c.maxAttempts) { - progressHandler.onFail(s, PaymentFailed(id, c.paymentHash, failures :+ LocalFailure(t)))(context) + progressHandler.onFailure(s, PaymentFailed(id, c.paymentHash, failures :+ LocalFailure(t)))(context) stop(FSM.Normal) } else { log.info(s"received an error message from local, trying to use a different channel (failure=${t.getMessage})") @@ -191,8 +191,8 @@ object PaymentLifecycle { // @formatter:off def onSend(): Unit - def onSucceed(sender: ActorRef, result: PaymentSent)(ctx: ActorContext): Unit - def onFail(sender: ActorRef, result: PaymentFailed)(ctx: ActorContext): Unit + def onSuccess(sender: ActorRef, result: PaymentSent)(ctx: ActorContext): Unit + def onFailure(sender: ActorRef, result: PaymentFailed)(ctx: ActorContext): Unit // @formatter:on } @@ -203,13 +203,13 @@ object PaymentLifecycle { db.addOutgoingPayment(OutgoingPayment(id, None, r.externalId, r.paymentHash, r.amount, r.targetNodeId, Platform.currentTime, OutgoingPaymentStatus.PENDING, r.paymentRequest)) } - override def onSucceed(sender: ActorRef, result: PaymentSent)(ctx: ActorContext): Unit = { + override def onSuccess(sender: ActorRef, result: PaymentSent)(ctx: ActorContext): Unit = { db.updateOutgoingPayment(result) sender ! result ctx.system.eventStream.publish(result) } - override def onFail(sender: ActorRef, result: PaymentFailed)(ctx: ActorContext): Unit = { + override def onFailure(sender: ActorRef, result: PaymentFailed)(ctx: ActorContext): Unit = { db.updateOutgoingPayment(result) sender ! result ctx.system.eventStream.publish(result) From c620cd109eae5afa1e38796bb8f45fdb6b12d546 Mon Sep 17 00:00:00 2001 From: Bastien Teinturier Date: Tue, 17 Sep 2019 17:33:44 +0200 Subject: [PATCH 05/14] Change externalId to be a String --- .../main/scala/fr/acinq/eclair/Eclair.scala | 42 +++++++++++-------- .../scala/fr/acinq/eclair/db/PaymentsDb.scala | 28 ++++++------- .../eclair/db/sqlite/SqlitePaymentsDb.scala | 6 +-- .../eclair/payment/PaymentInitiator.scala | 2 +- .../fr/acinq/eclair/EclairImplSpec.scala | 17 ++++---- .../eclair/db/SqlitePaymentsDbSpec.scala | 8 ++-- .../eclair/payment/PaymentLifecycleSpec.scala | 2 +- .../scala/fr/acinq/eclair/api/Service.scala | 6 +-- .../fr/acinq/eclair/api/ApiServiceSpec.scala | 9 ++-- 9 files changed, 64 insertions(+), 56 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala b/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala index 4c9c51a5c2..3deb328e54 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala @@ -78,13 +78,13 @@ trait Eclair { def receivedInfo(paymentHash: ByteVector32)(implicit timeout: Timeout): Future[Option[IncomingPayment]] - def send(externalId: Option[UUID], recipientNodeId: PublicKey, amount: MilliSatoshi, paymentHash: ByteVector32, invoice_opt: Option[PaymentRequest] = None, maxAttempts_opt: Option[Int] = None, feeThresholdSat_opt: Option[Satoshi] = None, maxFeePct_opt: Option[Double] = None)(implicit timeout: Timeout): Future[UUID] + def send(externalId_opt: Option[String], recipientNodeId: PublicKey, amount: MilliSatoshi, paymentHash: ByteVector32, invoice_opt: Option[PaymentRequest] = None, maxAttempts_opt: Option[Int] = None, feeThresholdSat_opt: Option[Satoshi] = None, maxFeePct_opt: Option[Double] = None)(implicit timeout: Timeout): Future[UUID] def sentInfo(id: Either[UUID, ByteVector32])(implicit timeout: Timeout): Future[Seq[OutgoingPayment]] def findRoute(targetNodeId: PublicKey, amount: MilliSatoshi, assistedRoutes: Seq[Seq[PaymentRequest.ExtraHop]] = Seq.empty)(implicit timeout: Timeout): Future[RouteResponse] - def sendToRoute(externalId: Option[UUID], route: Seq[PublicKey], amount: MilliSatoshi, paymentHash: ByteVector32, finalCltvExpiryDelta: CltvExpiryDelta)(implicit timeout: Timeout): Future[UUID] + def sendToRoute(externalId_opt: Option[String], route: Seq[PublicKey], amount: MilliSatoshi, paymentHash: ByteVector32, finalCltvExpiryDelta: CltvExpiryDelta)(implicit timeout: Timeout): Future[UUID] def audit(from_opt: Option[Long], to_opt: Option[Long])(implicit timeout: Timeout): Future[AuditResponse] @@ -113,6 +113,9 @@ class EclairImpl(appKit: Kit) extends Eclair { implicit val ec: ExecutionContext = appKit.system.dispatcher + // We constrain external identifiers. This allows uuid, long and pubkey to be used. + private val externalIdMaxLength = 66 + override def connect(target: Either[NodeURI, PublicKey])(implicit timeout: Timeout): Future[String] = target match { case Left(uri) => (appKit.switchboard ? Peer.Connect(uri)).mapTo[String] case Right(pubKey) => (appKit.switchboard ? Peer.Connect(pubKey, None)).mapTo[String] @@ -186,30 +189,35 @@ class EclairImpl(appKit: Kit) extends Eclair { (appKit.router ? RouteRequest(appKit.nodeParams.nodeId, targetNodeId, amount, assistedRoutes)).mapTo[RouteResponse] } - override def sendToRoute(externalId: Option[UUID], route: Seq[PublicKey], amount: MilliSatoshi, paymentHash: ByteVector32, finalCltvExpiryDelta: CltvExpiryDelta)(implicit timeout: Timeout): Future[UUID] = { - (appKit.paymentInitiator ? SendPaymentRequest(amount, paymentHash, route.last, 1, finalCltvExpiryDelta, None, externalId, route)).mapTo[UUID] + override def sendToRoute(externalId_opt: Option[String], route: Seq[PublicKey], amount: MilliSatoshi, paymentHash: ByteVector32, finalCltvExpiryDelta: CltvExpiryDelta)(implicit timeout: Timeout): Future[UUID] = { + externalId_opt match { + case Some(externalId) if externalId.length > externalIdMaxLength => Future.failed(new IllegalArgumentException("externalId is too long: cannot exceed 66 characters")) + case _ => (appKit.paymentInitiator ? SendPaymentRequest(amount, paymentHash, route.last, 1, finalCltvExpiryDelta, None, externalId_opt, route)).mapTo[UUID] + } } - override def send(externalId: Option[UUID], recipientNodeId: PublicKey, amount: MilliSatoshi, paymentHash: ByteVector32, invoice_opt: Option[PaymentRequest], maxAttempts_opt: Option[Int], feeThreshold_opt: Option[Satoshi], maxFeePct_opt: Option[Double])(implicit timeout: Timeout): Future[UUID] = { + override def send(externalId_opt: Option[String], recipientNodeId: PublicKey, amount: MilliSatoshi, paymentHash: ByteVector32, invoice_opt: Option[PaymentRequest], maxAttempts_opt: Option[Int], feeThreshold_opt: Option[Satoshi], maxFeePct_opt: Option[Double])(implicit timeout: Timeout): Future[UUID] = { val maxAttempts = maxAttempts_opt.getOrElse(appKit.nodeParams.maxPaymentAttempts) - val defaultRouteParams = Router.getDefaultRouteParams(appKit.nodeParams.routerConf) val routeParams = defaultRouteParams.copy( maxFeePct = maxFeePct_opt.getOrElse(defaultRouteParams.maxFeePct), maxFeeBase = feeThreshold_opt.map(_.toMilliSatoshi).getOrElse(defaultRouteParams.maxFeeBase) ) - invoice_opt match { - case Some(invoice) if invoice.isExpired => Future.failed(new IllegalArgumentException("invoice has expired")) - case Some(invoice) => - val sendPayment = invoice.minFinalCltvExpiryDelta match { - case Some(minFinalCltvExpiryDelta) => SendPaymentRequest(amount, paymentHash, recipientNodeId, maxAttempts, minFinalCltvExpiryDelta, invoice_opt, externalId, assistedRoutes = invoice.routingInfo, routeParams = Some(routeParams)) - case None => SendPaymentRequest(amount, paymentHash, recipientNodeId, maxAttempts, paymentRequest = invoice_opt, externalId = externalId, assistedRoutes = invoice.routingInfo, routeParams = Some(routeParams)) - } - (appKit.paymentInitiator ? sendPayment).mapTo[UUID] - case None => - val sendPayment = SendPaymentRequest(amount, paymentHash, recipientNodeId, maxAttempts = maxAttempts, externalId = externalId, routeParams = Some(routeParams)) - (appKit.paymentInitiator ? sendPayment).mapTo[UUID] + externalId_opt match { + case Some(externalId) if externalId.length > externalIdMaxLength => Future.failed(new IllegalArgumentException("externalId is too long: cannot exceed 66 characters")) + case _ => invoice_opt match { + case Some(invoice) if invoice.isExpired => Future.failed(new IllegalArgumentException("invoice has expired")) + case Some(invoice) => + val sendPayment = invoice.minFinalCltvExpiryDelta match { + case Some(minFinalCltvExpiryDelta) => SendPaymentRequest(amount, paymentHash, recipientNodeId, maxAttempts, minFinalCltvExpiryDelta, invoice_opt, externalId_opt, assistedRoutes = invoice.routingInfo, routeParams = Some(routeParams)) + case None => SendPaymentRequest(amount, paymentHash, recipientNodeId, maxAttempts, paymentRequest = invoice_opt, externalId = externalId_opt, assistedRoutes = invoice.routingInfo, routeParams = Some(routeParams)) + } + (appKit.paymentInitiator ? sendPayment).mapTo[UUID] + case None => + val sendPayment = SendPaymentRequest(amount, paymentHash, recipientNodeId, maxAttempts = maxAttempts, externalId = externalId_opt, routeParams = Some(routeParams)) + (appKit.paymentInitiator ? sendPayment).mapTo[UUID] + } } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala index 64de9c7ae5..af538c5d1b 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala @@ -83,28 +83,28 @@ case class IncomingPayment(paymentHash: ByteVector32, amount: MilliSatoshi, rece * An outgoing payment sent by this node. * At first it is in a pending state, then will become either a success or a failure. * - * @param id internal payment identifier. - * @param parentId internal identifier of a parent payment, if any. - * @param externalId external payment identifier: lets lightning applications reconcile payments with their own db. - * @param paymentHash payment_hash. - * @param amount amount of the payment, in milli-satoshis. - * @param targetNodeId node ID of the payment recipient. - * @param createdAt absolute time in seconds since UNIX epoch when the payment was created. - * @param status current status of the payment. - * @param paymentRequest_opt Bolt 11 payment request (if paying from an invoice). - * @param completedAt absolute time in seconds since UNIX epoch when the payment completed (success of failure). - * @param successSummary summary of the payment success (if status == "SUCCEEDED"). - * @param failureSummary summary of the payment failure (if status == "FAILED"). + * @param id internal payment identifier. + * @param parentId internal identifier of a parent payment, if any. + * @param externalId external payment identifier: lets lightning applications reconcile payments with their own db. + * @param paymentHash payment_hash. + * @param amount amount of the payment, in milli-satoshis. + * @param targetNodeId node ID of the payment recipient. + * @param createdAt absolute time in seconds since UNIX epoch when the payment was created. + * @param status current status of the payment. + * @param paymentRequest Bolt 11 payment request (if paying from an invoice). + * @param completedAt absolute time in seconds since UNIX epoch when the payment completed (success of failure). + * @param successSummary summary of the payment success (if status == "SUCCEEDED"). + * @param failureSummary summary of the payment failure (if status == "FAILED"). */ case class OutgoingPayment(id: UUID, parentId: Option[UUID], - externalId: Option[UUID], + externalId: Option[String], paymentHash: ByteVector32, amount: MilliSatoshi, targetNodeId: PublicKey, createdAt: Long, status: OutgoingPaymentStatus.Value, - paymentRequest_opt: Option[PaymentRequest], + paymentRequest: Option[PaymentRequest], completedAt: Option[Long] = None, successSummary: Option[PaymentSuccessSummary] = None, failureSummary: Option[PaymentFailureSummary] = None) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala index 22aa6639ca..17f308b3df 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala @@ -102,13 +102,13 @@ class SqlitePaymentsDb(sqlite: Connection) extends PaymentsDb with Logging { using(sqlite.prepareStatement("INSERT INTO sent_payments (id, parent_id, external_id, payment_hash, amount_msat, target_node_id, created_at, status, payment_request) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)")) { statement => statement.setString(1, sent.id.toString) statement.setString(2, sent.parentId.map(_.toString).orNull) - statement.setString(3, sent.externalId.map(_.toString).orNull) + statement.setString(3, sent.externalId.orNull) statement.setBytes(4, sent.paymentHash.toArray) statement.setLong(5, sent.amount.toLong) statement.setBytes(6, sent.targetNodeId.value.toArray) statement.setLong(7, sent.createdAt) statement.setString(8, sent.status.toString) - statement.setString(9, sent.paymentRequest_opt.map(PaymentRequest.write).orNull) + statement.setString(9, sent.paymentRequest.map(PaymentRequest.write).orNull) statement.executeUpdate() } @@ -136,7 +136,7 @@ class SqlitePaymentsDb(sqlite: Connection) extends PaymentsDb with Logging { val result = OutgoingPayment( UUID.fromString(rs.getString("id")), rs.getStringNullable("parent_id").map(UUID.fromString), - rs.getStringNullable("external_id").map(UUID.fromString), + rs.getStringNullable("external_id"), rs.getByteVector32("payment_hash"), MilliSatoshi(rs.getLong("amount_msat")), PublicKey(rs.getByteVector("target_node_id")), diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentInitiator.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentInitiator.scala index b86ee963e9..863ba579ce 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentInitiator.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentInitiator.scala @@ -59,7 +59,7 @@ object PaymentInitiator { maxAttempts: Int, finalExpiryDelta: CltvExpiryDelta = Channel.MIN_CLTV_EXPIRY_DELTA, paymentRequest: Option[PaymentRequest] = None, - externalId: Option[UUID] = None, + externalId: Option[String] = None, predefinedRoute: Seq[PublicKey] = Nil, assistedRoutes: Seq[Seq[ExtraHop]] = Nil, routeParams: Option[RouteParams] = None) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/EclairImplSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/EclairImplSpec.scala index 44d9e7f748..7a78f791ac 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/EclairImplSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/EclairImplSpec.scala @@ -16,8 +16,6 @@ package fr.acinq.eclair -import java.util.UUID - import akka.actor.ActorSystem import akka.testkit.{TestKit, TestProbe} import akka.util.Timeout @@ -106,7 +104,7 @@ class EclairImplSpec extends TestKit(ActorSystem("test")) with fixture.FunSuiteL assert(send.assistedRoutes === Seq.empty) // with assisted routes - val externalId1 = UUID.randomUUID() + val externalId1 = "030bb6a5e0c6b203c7e2180fb78c7ba4bdce46126761d8201b91ddac089cdecc87" val hints = List(List(ExtraHop(Bob.nodeParams.nodeId, ShortChannelId("569178x2331x1"), feeBase = 10 msat, feeProportionalMillionths = 1, cltvExpiryDelta = CltvExpiryDelta(12)))) val invoice1 = PaymentRequest(Block.RegtestGenesisBlock.hash, Some(123 msat), ByteVector32.Zeroes, randomKey, "description", None, None, hints) eclair.send(Some(externalId1), nodeId, 123 msat, ByteVector32.Zeroes, invoice_opt = Some(invoice1)) @@ -119,10 +117,11 @@ class EclairImplSpec extends TestKit(ActorSystem("test")) with fixture.FunSuiteL assert(send1.assistedRoutes === hints) // with finalCltvExpiry + val externalId2 = "487da196-a4dc-4b1e-92b4-3e5e905e9f3f" val invoice2 = PaymentRequest("lntb", Some(123 msat), System.currentTimeMillis() / 1000L, nodeId, List(PaymentRequest.MinFinalCltvExpiry(96), PaymentRequest.PaymentHash(ByteVector32.Zeroes), PaymentRequest.Description("description")), ByteVector.empty) - eclair.send(Some(externalId1), nodeId, 123 msat, ByteVector32.Zeroes, invoice_opt = Some(invoice2)) + eclair.send(Some(externalId2), nodeId, 123 msat, ByteVector32.Zeroes, invoice_opt = Some(invoice2)) val send2 = paymentInitiator.expectMsgType[SendPaymentRequest] - assert(send2.externalId === Some(externalId1)) + assert(send2.externalId === Some(externalId2)) assert(send2.targetNodeId === nodeId) assert(send2.amount === 123.msat) assert(send2.paymentHash === ByteVector32.Zeroes) @@ -139,6 +138,9 @@ class EclairImplSpec extends TestKit(ActorSystem("test")) with fixture.FunSuiteL assert(send3.routeParams.get.maxFeeBase === 123000.msat) // conversion sat -> msat assert(send3.routeParams.get.maxFeePct === 4.20) + val invalidExternalId = "Robert'); DROP TABLE received_payments; DROP TABLE sent_payments; DROP TABLE payments;" + assertThrows[IllegalArgumentException](Await.result(eclair.send(Some(invalidExternalId), nodeId, 123 msat, ByteVector32.Zeroes), 50 millis)) + val expiredInvoice = invoice2.copy(timestamp = 0L) assertThrows[IllegalArgumentException](Await.result(eclair.send(None, nodeId, 123 msat, ByteVector32.Zeroes, invoice_opt = Some(expiredInvoice)), 50 millis)) } @@ -259,13 +261,12 @@ class EclairImplSpec extends TestKit(ActorSystem("test")) with fixture.FunSuiteL test("sendtoroute should pass the parameters correctly") { f => import f._ - val externalId = UUID.randomUUID() val route = Seq(PublicKey(hex"030bb6a5e0c6b203c7e2180fb78c7ba4bdce46126761d8201b91ddac089cdecc87")) val eclair = new EclairImpl(kit) - eclair.sendToRoute(Some(externalId), route, 1234 msat, ByteVector32.One, CltvExpiryDelta(123)) + eclair.sendToRoute(Some("42"), route, 1234 msat, ByteVector32.One, CltvExpiryDelta(123)) val send = paymentInitiator.expectMsgType[SendPaymentRequest] - assert(send.externalId === Some(externalId)) + assert(send.externalId === Some("42")) assert(send.predefinedRoute === route) assert(send.amount === 1234.msat) assert(send.finalExpiryDelta === CltvExpiryDelta(123)) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/SqlitePaymentsDbSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/SqlitePaymentsDbSpec.scala index 8a1d2ee211..c391b86e68 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/db/SqlitePaymentsDbSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/SqlitePaymentsDbSpec.scala @@ -191,9 +191,9 @@ class SqlitePaymentsDbSpec extends FunSuite { val i3 = PaymentRequest.read("lnbc1500n1pdl686hpp5y7mz3lgvrfccqnk9es6trumjgqdpjwcecycpkdggnx7h6cuup90sdpa2fjkzep6ypqkymm4wssycnjzf9rjqurjda4x2cm5ypskuepqv93x7at5ypek7cqzysxqr23s5e864m06fcfp3axsefy276d77tzp0xzzzdfl6p46wvstkeqhu50khm9yxea2d9efp7lvthrta0ktmhsv52hf3tvxm0unsauhmfmp27cqqx4xxe") postMigrationDb.addPaymentRequest(i3, randomBytes32) - val ps4 = OutgoingPayment(UUID.randomUUID(), Some(UUID.randomUUID()), Some(UUID.randomUUID()), randomBytes32, 123 msat, alice, 6, PENDING, Some(i3)) - val ps5 = OutgoingPayment(UUID.randomUUID(), Some(UUID.randomUUID()), Some(UUID.randomUUID()), randomBytes32, 456 msat, bob, 7, SUCCEEDED, Some(i2), Some(8), Some(PaymentSuccessSummary(randomBytes32, 42 msat, Nil))) - val ps6 = OutgoingPayment(UUID.randomUUID(), Some(UUID.randomUUID()), Some(UUID.randomUUID()), randomBytes32, 789 msat, bob, 8, FAILED, None, Some(9), None, Some(PaymentFailureSummary(Nil))) + val ps4 = OutgoingPayment(UUID.randomUUID(), Some(UUID.randomUUID()), Some("1"), randomBytes32, 123 msat, alice, 6, PENDING, Some(i3)) + val ps5 = OutgoingPayment(UUID.randomUUID(), Some(UUID.randomUUID()), Some("2"), randomBytes32, 456 msat, bob, 7, SUCCEEDED, Some(i2), Some(8), Some(PaymentSuccessSummary(randomBytes32, 42 msat, Nil))) + val ps6 = OutgoingPayment(UUID.randomUUID(), Some(UUID.randomUUID()), Some("3"), randomBytes32, 789 msat, bob, 8, FAILED, None, Some(9), None, Some(PaymentFailureSummary(Nil))) postMigrationDb.addOutgoingPayment(ps4) postMigrationDb.addOutgoingPayment(ps5) postMigrationDb.updateOutgoingPayment(PaymentSent(ps5.id, ps5.amount, ps5.successSummary.get.feesPaid, ps5.paymentHash, ps5.successSummary.get.paymentPreimage, Nil, ps5.completedAt.get)) @@ -235,7 +235,7 @@ class SqlitePaymentsDbSpec extends FunSuite { val i1 = PaymentRequest(chainHash = Block.TestnetGenesisBlock.hash, amount = Some(123 msat), paymentHash = randomBytes32, privateKey = davePriv, description = "Some invoice", expirySeconds = None, timestamp = 1) val s1 = OutgoingPayment(UUID.randomUUID(), Some(UUID.randomUUID()), None, i1.paymentHash, 123 msat, alice, 1, PENDING, Some(i1)) - val s2 = OutgoingPayment(UUID.randomUUID(), None, Some(UUID.randomUUID()), randomBytes32, 456 msat, bob, 2, PENDING, None) + val s2 = OutgoingPayment(UUID.randomUUID(), None, Some("1"), randomBytes32, 456 msat, bob, 2, PENDING, None) assert(db.listOutgoingPayments(0, Platform.currentTime.milliseconds.toSeconds).isEmpty) db.addOutgoingPayment(s1) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala index b9ab59418e..48694e7553 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala @@ -48,7 +48,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val defaultAmountMsat = 142000000 msat val defaultExpiryDelta = Channel.MIN_CLTV_EXPIRY_DELTA val defaultPaymentHash = randomBytes32 - val defaultExternalId = UUID.randomUUID() + val defaultExternalId = UUID.randomUUID().toString val defaultPaymentRequest = SendPaymentRequest(defaultAmountMsat, defaultPaymentHash, d, 1, externalId = Some(defaultExternalId)) test("send to route") { fixture => diff --git a/eclair-node/src/main/scala/fr/acinq/eclair/api/Service.scala b/eclair-node/src/main/scala/fr/acinq/eclair/api/Service.scala index 6ade383e10..23feee8fe6 100644 --- a/eclair-node/src/main/scala/fr/acinq/eclair/api/Service.scala +++ b/eclair-node/src/main/scala/fr/acinq/eclair/api/Service.scala @@ -223,7 +223,7 @@ trait Service extends ExtraDirectives with Logging { } } ~ path("payinvoice") { - formFields(invoiceFormParam, amountMsatFormParam.?, "maxAttempts".as[Int].?, "feeThresholdSat".as[Satoshi].?, "maxFeePct".as[Double].?, "externalId".as[UUID].?) { + formFields(invoiceFormParam, amountMsatFormParam.?, "maxAttempts".as[Int].?, "feeThresholdSat".as[Satoshi].?, "maxFeePct".as[Double].?, "externalId".?) { case (invoice@PaymentRequest(_, Some(amount), _, nodeId, _, _), None, maxAttempts, feeThresholdSat_opt, maxFeePct_opt, externalId_opt) => complete(eclairApi.send(externalId_opt, nodeId, amount, invoice.paymentHash, Some(invoice), maxAttempts, feeThresholdSat_opt, maxFeePct_opt)) case (invoice, Some(overrideAmount), maxAttempts, feeThresholdSat_opt, maxFeePct_opt, externalId_opt) => @@ -232,13 +232,13 @@ trait Service extends ExtraDirectives with Logging { } } ~ path("sendtonode") { - formFields(amountMsatFormParam, paymentHashFormParam, nodeIdFormParam, "maxAttempts".as[Int].?, "feeThresholdSat".as[Satoshi].?, "maxFeePct".as[Double].?, "externalId".as[UUID].?) { + formFields(amountMsatFormParam, paymentHashFormParam, nodeIdFormParam, "maxAttempts".as[Int].?, "feeThresholdSat".as[Satoshi].?, "maxFeePct".as[Double].?, "externalId".?) { (amountMsat, paymentHash, nodeId, maxAttempts_opt, feeThresholdSat_opt, maxFeePct_opt, externalId_opt) => complete(eclairApi.send(externalId_opt, nodeId, amountMsat, paymentHash, maxAttempts_opt = maxAttempts_opt, feeThresholdSat_opt = feeThresholdSat_opt, maxFeePct_opt = maxFeePct_opt)) } } ~ path("sendtoroute") { - formFields(amountMsatFormParam, paymentHashFormParam, "finalCltvExpiry".as[Int], "route".as[List[PublicKey]](pubkeyListUnmarshaller), "externalId".as[UUID].?) { + formFields(amountMsatFormParam, paymentHashFormParam, "finalCltvExpiry".as[Int], "route".as[List[PublicKey]](pubkeyListUnmarshaller), "externalId".?) { (amountMsat, paymentHash, finalCltvExpiry, route, externalId_opt) => complete(eclairApi.sendToRoute(externalId_opt, route, amountMsat, paymentHash, CltvExpiryDelta(finalCltvExpiry))) } diff --git a/eclair-node/src/test/scala/fr/acinq/eclair/api/ApiServiceSpec.scala b/eclair-node/src/test/scala/fr/acinq/eclair/api/ApiServiceSpec.scala index fe69098ddf..01077eeaea 100644 --- a/eclair-node/src/test/scala/fr/acinq/eclair/api/ApiServiceSpec.scala +++ b/eclair-node/src/test/scala/fr/acinq/eclair/api/ApiServiceSpec.scala @@ -282,14 +282,13 @@ class ApiServiceSpec extends FunSuite with ScalatestRouteTest with IdiomaticMock eclair.send(None, any, 1258000 msat, any, any, any, any, any)(any[Timeout]).wasCalled(once) } - val externalId = UUID.randomUUID() - Post("/payinvoice", FormData("invoice" -> invoice, "amountMsat" -> "123", "feeThresholdSat" -> "112233", "maxFeePct" -> "2.34", "externalId" -> externalId.toString).toEntity) ~> + Post("/payinvoice", FormData("invoice" -> invoice, "amountMsat" -> "123", "feeThresholdSat" -> "112233", "maxFeePct" -> "2.34", "externalId" -> "42").toEntity) ~> addCredentials(BasicHttpCredentials("", mockService.password)) ~> Route.seal(mockService.route) ~> check { assert(handled) assert(status == OK) - eclair.send(Some(externalId), any, 123 msat, any, any, any, Some(112233 sat), Some(2.34))(any[Timeout]).wasCalled(once) + eclair.send(Some("42"), any, 123 msat, any, any, any, Some(112233 sat), Some(2.34))(any[Timeout]).wasCalled(once) } } @@ -313,13 +312,13 @@ class ApiServiceSpec extends FunSuite with ScalatestRouteTest with IdiomaticMock test("'sendtoroute' method should accept a both a json-encoded AND comma separaterd list of pubkeys") { val rawUUID = "487da196-a4dc-4b1e-92b4-3e5e905e9f3f" val paymentUUID = UUID.fromString(rawUUID) - val externalId = UUID.randomUUID() + val externalId = UUID.randomUUID().toString val expectedRoute = List(PublicKey(hex"0217eb8243c95f5a3b7d4c5682d10de354b7007eb59b6807ae407823963c7547a9"), PublicKey(hex"0242a4ae0c5bef18048fbecf995094b74bfb0f7391418d71ed394784373f41e4f3"), PublicKey(hex"026ac9fcd64fb1aa1c491fc490634dc33da41d4a17b554e0adf1b32fee88ee9f28")) val csvNodes = "0217eb8243c95f5a3b7d4c5682d10de354b7007eb59b6807ae407823963c7547a9, 0242a4ae0c5bef18048fbecf995094b74bfb0f7391418d71ed394784373f41e4f3, 026ac9fcd64fb1aa1c491fc490634dc33da41d4a17b554e0adf1b32fee88ee9f28" val jsonNodes = serialization.write(expectedRoute) val eclair = mock[Eclair] - eclair.sendToRoute(any[Option[UUID]], any[List[PublicKey]], any[MilliSatoshi], any[ByteVector32], any[CltvExpiryDelta])(any[Timeout]) returns Future.successful(paymentUUID) + eclair.sendToRoute(any[Option[String]], any[List[PublicKey]], any[MilliSatoshi], any[ByteVector32], any[CltvExpiryDelta])(any[Timeout]) returns Future.successful(paymentUUID) val mockService = new MockService(eclair) Post("/sendtoroute", FormData("route" -> jsonNodes, "amountMsat" -> "1234", "paymentHash" -> ByteVector32.Zeroes.toHex, "finalCltvExpiry" -> "190", "externalId" -> externalId.toString).toEntity) ~> From e7a545539e6a8260340a74ba225236cce7ae6ba0 Mon Sep 17 00:00:00 2001 From: Bastien Teinturier Date: Wed, 18 Sep 2019 17:52:16 +0200 Subject: [PATCH 06/14] Re-work the PaymentsDb interface and the Incoming/Outgoing payment structures. Clarify use of seconds / milliseconds -> we use milliseconds everywhere except at the Eclair API level (probably because it's easier from bash to get a unix timestamp in seconds than in milliseconds). --- .../main/scala/fr/acinq/eclair/Eclair.scala | 16 +- .../scala/fr/acinq/eclair/db/PaymentsDb.scala | 132 ++++++--- .../eclair/db/sqlite/SqliteAuditDb.scala | 25 +- .../eclair/db/sqlite/SqlitePaymentsDb.scala | 211 ++++++------- .../main/scala/fr/acinq/eclair/package.scala | 7 +- .../eclair/payment/LocalPaymentHandler.scala | 14 +- .../eclair/payment/PaymentLifecycle.scala | 2 +- .../fr/acinq/eclair/EclairImplSpec.scala | 12 +- .../acinq/eclair/db/SqliteAuditDbSpec.scala | 28 +- .../eclair/db/SqlitePaymentsDbSpec.scala | 276 +++++++++--------- .../eclair/payment/PaymentHandlerSpec.scala | 28 +- .../eclair/payment/PaymentLifecycleSpec.scala | 40 +-- 12 files changed, 425 insertions(+), 366 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala b/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala index 3deb328e54..c235c9fbec 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala @@ -46,9 +46,13 @@ case class AuditResponse(sent: Seq[PaymentSent], received: Seq[PaymentReceived], case class TimestampQueryFilters(from: Long, to: Long) object TimestampQueryFilters { + /** We use this in the context of timestamp filtering, when we don't need an upper bound. */ + val MaxEpochMilliseconds = Duration.fromNanos(Long.MaxValue).toMillis + def getDefaultTimestampFilters(from_opt: Option[Long], to_opt: Option[Long]) = { - val from = from_opt.getOrElse(0L) - val to = to_opt.getOrElse(MaxEpochSeconds) + // NB: we expect callers to use seconds, but internally we use milli-seconds everywhere. + val from = from_opt.getOrElse(0L).seconds.toMillis + val to = to_opt.map(_.seconds.toMillis).getOrElse(MaxEpochMilliseconds) TimestampQueryFilters(from, to) } @@ -224,7 +228,7 @@ class EclairImpl(appKit: Kit) extends Eclair { override def sentInfo(id: Either[UUID, ByteVector32])(implicit timeout: Timeout): Future[Seq[OutgoingPayment]] = Future { id match { case Left(uuid) => appKit.nodeParams.db.payments.getOutgoingPayment(uuid).toSeq - case Right(paymentHash) => appKit.nodeParams.db.payments.getOutgoingPayments(paymentHash) + case Right(paymentHash) => appKit.nodeParams.db.payments.listOutgoingPayments(paymentHash) } } @@ -253,17 +257,17 @@ class EclairImpl(appKit: Kit) extends Eclair { override def allInvoices(from_opt: Option[Long], to_opt: Option[Long])(implicit timeout: Timeout): Future[Seq[PaymentRequest]] = Future { val filter = getDefaultTimestampFilters(from_opt, to_opt) - appKit.nodeParams.db.payments.listPaymentRequests(filter.from, filter.to) + appKit.nodeParams.db.payments.listIncomingPayments(filter.from, filter.to).map(_.paymentRequest) } override def pendingInvoices(from_opt: Option[Long], to_opt: Option[Long])(implicit timeout: Timeout): Future[Seq[PaymentRequest]] = Future { val filter = getDefaultTimestampFilters(from_opt, to_opt) - appKit.nodeParams.db.payments.listPendingPaymentRequests(filter.from, filter.to) + appKit.nodeParams.db.payments.listPendingIncomingPayments(filter.from, filter.to).map(_.paymentRequest) } override def getInvoice(paymentHash: ByteVector32)(implicit timeout: Timeout): Future[Option[PaymentRequest]] = Future { - appKit.nodeParams.db.payments.getPaymentRequest(paymentHash) + appKit.nodeParams.db.payments.getIncomingPayment(paymentHash).map(_.paymentRequest) } /** diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala index af538c5d1b..e64124a8af 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala @@ -24,60 +24,88 @@ import fr.acinq.eclair.payment._ import fr.acinq.eclair.router.Hop import fr.acinq.eclair.{MilliSatoshi, ShortChannelId} +import scala.compat.Platform + trait PaymentsDb { /** Create a record for a non yet finalized outgoing payment. */ - def addOutgoingPayment(outgoingPayment: OutgoingPayment) + def addOutgoingPayment(outgoingPayment: OutgoingPayment): Unit /** Update the status of the payment in case of success. */ - def updateOutgoingPayment(paymentResult: PaymentSent) + def updateOutgoingPayment(paymentResult: PaymentSent): Unit /** Update the status of the payment in case of failure. */ - def updateOutgoingPayment(paymentResult: PaymentFailed) + def updateOutgoingPayment(paymentResult: PaymentFailed): Unit /** Get an outgoing payment attempt. */ def getOutgoingPayment(id: UUID): Option[OutgoingPayment] - /** Get all the outgoing payment attempts to pay the given paymentHash. */ - def getOutgoingPayments(paymentHash: ByteVector32): Seq[OutgoingPayment] + /** List all the outgoing payment attempts that tried to pay the given payment hash. */ + def listOutgoingPayments(paymentHash: ByteVector32): Seq[OutgoingPayment] - /** Get all the outgoing payment attempts in the given time range. */ + /** List all the outgoing payment attempts in the given time range (milli-seconds). */ def listOutgoingPayments(from: Long, to: Long): Seq[OutgoingPayment] - /** Add a new payment request (to receive a payment). */ - def addPaymentRequest(pr: PaymentRequest, preimage: ByteVector32) - - /** Get the payment request for the given payment hash, if any. */ - def getPaymentRequest(paymentHash: ByteVector32): Option[PaymentRequest] + /** Add a new expected incoming payment (not yet received). */ + def addIncomingPayment(pr: PaymentRequest, preimage: ByteVector32): Unit - /** Get the currently pending payment request for the given payment hash, if any. */ - def getPendingPaymentRequestAndPreimage(paymentHash: ByteVector32): Option[(ByteVector32, PaymentRequest)] + /** + * Mark an incoming payment as received (paid). The received amount may exceed the payment request amount. + * Note that this function assumes that there is a matching payment request in the DB. + */ + def receiveIncomingPayment(paymentHash: ByteVector32, amount: MilliSatoshi, receivedAt: Long = Platform.currentTime): Unit - /** Get all payment requests (pending, expired and fulfilled) in the given time range. */ - def listPaymentRequests(from: Long, to: Long): Seq[PaymentRequest] + /** Get information about the incoming payment (paid or not) for the given payment hash, if any. */ + def getIncomingPayment(paymentHash: ByteVector32): Option[IncomingPayment] - /** Get pending, non expired payment requests in the given time range. */ - def listPendingPaymentRequests(from: Long, to: Long): Seq[PaymentRequest] + /** List all incoming payments (pending, expired and succeeded) in the given time range (milli-seconds). */ + def listIncomingPayments(from: Long, to: Long): Seq[IncomingPayment] - /** Add a received payment (assumes there is already a payment request for the given payment hash). */ - def addIncomingPayment(payment: IncomingPayment) + /** List all pending (not paid, not expired) incoming payments in the given time range (milli-seconds). */ + def listPendingIncomingPayments(from: Long, to: Long): Seq[IncomingPayment] - /** Get the received payment associated with a given payment hash, if any. */ - def getIncomingPayment(paymentHash: ByteVector32): Option[IncomingPayment] + /** List all expired (not paid) incoming payments in the given time range (milli-seconds). */ + def listExpiredIncomingPayments(from: Long, to: Long): Seq[IncomingPayment] - /** Get all payments received in the given time range. */ - def listIncomingPayments(from: Long, to: Long): Seq[IncomingPayment] + /** List all received (paid) incoming payments in the given time range (milli-seconds). */ + def listReceivedIncomingPayments(from: Long, to: Long): Seq[IncomingPayment] } /** - * Incoming payment object stored in DB. + * An incoming payment received by this node. + * At first it is in a pending state once the payment request has been generated, then will become either a success (if + * we receive a valid HTLC) or a failure (if the payment request expires). * - * @param paymentHash identifier of the payment. - * @param amount amount of the payment, in milli-satoshis. - * @param receivedAt absolute time in seconds since UNIX epoch when the payment was received. + * @param paymentRequest Bolt 11 payment request. + * @param paymentPreimage pre-image associated with the payment request's payment_hash. + * @param createdAt absolute time in milli-seconds since UNIX epoch when the payment request was generated. + * @param status current status of the payment. */ -case class IncomingPayment(paymentHash: ByteVector32, amount: MilliSatoshi, receivedAt: Long) +case class IncomingPayment(paymentRequest: PaymentRequest, + paymentPreimage: ByteVector32, + createdAt: Long, + status: IncomingPaymentStatus) + +sealed trait IncomingPaymentStatus + +object IncomingPaymentStatus { + + /** Payment is pending (waiting to receive). */ + case object Pending extends IncomingPaymentStatus + + /** Payment has expired. */ + case object Expired extends IncomingPaymentStatus + + /** + * Payment has been successfully received. + * + * @param amount amount of the payment received, in milli-satoshis (may exceed the payment request amount). + * @param receivedAt absolute time in milli-seconds since UNIX epoch when the payment was received. + */ + case class Received(amount: MilliSatoshi, receivedAt: Long) extends IncomingPaymentStatus + +} /** * An outgoing payment sent by this node. @@ -89,12 +117,9 @@ case class IncomingPayment(paymentHash: ByteVector32, amount: MilliSatoshi, rece * @param paymentHash payment_hash. * @param amount amount of the payment, in milli-satoshis. * @param targetNodeId node ID of the payment recipient. - * @param createdAt absolute time in seconds since UNIX epoch when the payment was created. - * @param status current status of the payment. + * @param createdAt absolute time in milli-seconds since UNIX epoch when the payment was created. * @param paymentRequest Bolt 11 payment request (if paying from an invoice). - * @param completedAt absolute time in seconds since UNIX epoch when the payment completed (success of failure). - * @param successSummary summary of the payment success (if status == "SUCCEEDED"). - * @param failureSummary summary of the payment failure (if status == "FAILED"). + * @param status current status of the payment. */ case class OutgoingPayment(id: UUID, parentId: Option[UUID], @@ -103,21 +128,36 @@ case class OutgoingPayment(id: UUID, amount: MilliSatoshi, targetNodeId: PublicKey, createdAt: Long, - status: OutgoingPaymentStatus.Value, paymentRequest: Option[PaymentRequest], - completedAt: Option[Long] = None, - successSummary: Option[PaymentSuccessSummary] = None, - failureSummary: Option[PaymentFailureSummary] = None) - -object OutgoingPaymentStatus extends Enumeration { - val PENDING = Value(1, "PENDING") - val SUCCEEDED = Value(2, "SUCCEEDED") - val FAILED = Value(3, "FAILED") -} + status: OutgoingPaymentStatus) + +sealed trait OutgoingPaymentStatus + +object OutgoingPaymentStatus { + + /** Payment is pending (waiting for the recipient to release the pre-image). */ + case object Pending extends OutgoingPaymentStatus + + /** + * Payment has been successfully sent and the recipient released the pre-image. + * We now have a valid proof-of-payment. + * + * @param paymentPreimage the preimage of the payment_hash. + * @param feesPaid total amount of fees paid to intermediate routing nodes. + * @param route payment route. + * @param completedAt absolute time in milli-seconds since UNIX epoch when the payment was completed. + */ + case class Succeeded(paymentPreimage: ByteVector32, feesPaid: MilliSatoshi, route: Seq[HopSummary], completedAt: Long) extends OutgoingPaymentStatus + + /** + * Payment has failed and may be retried. + * + * @param failures failed payment attempts. + * @param completedAt absolute time in milli-seconds since UNIX epoch when the payment was completed. + */ + case class Failed(failures: Seq[FailureSummary], completedAt: Long) extends OutgoingPaymentStatus -case class PaymentSuccessSummary(paymentPreimage: ByteVector32, feesPaid: MilliSatoshi, route: Seq[HopSummary]) - -case class PaymentFailureSummary(failures: Seq[FailureSummary]) +} /** A minimal representation of a hop in a payment route (suitable to store in a database). */ case class HopSummary(nodeId: PublicKey, nextNodeId: PublicKey, shortChannelId: Option[ShortChannelId] = None) { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteAuditDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteAuditDb.scala index f0342e953a..31b6e3d296 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteAuditDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteAuditDb.scala @@ -30,7 +30,6 @@ import grizzled.slf4j.Logging import scala.collection.immutable.Queue import scala.compat.Platform -import scala.concurrent.duration._ class SqliteAuditDb(sqlite: Connection) extends AuditDb with Logging { @@ -185,9 +184,9 @@ class SqliteAuditDb(sqlite: Connection) extends AuditDb with Logging { } override def listSent(from: Long, to: Long): Seq[PaymentSent] = - using(sqlite.prepareStatement("SELECT * FROM sent WHERE timestamp >= ? AND timestamp < ?")) { statement => - statement.setLong(1, from.seconds.toMillis) - statement.setLong(2, to.seconds.toMillis) + using(sqlite.prepareStatement("SELECT * FROM sent WHERE timestamp >= ? AND timestamp < ? ORDER BY timestamp")) { statement => + statement.setLong(1, from) + statement.setLong(2, to) val rs = statement.executeQuery() var q: Queue[PaymentSent] = Queue() while (rs.next()) { @@ -204,9 +203,9 @@ class SqliteAuditDb(sqlite: Connection) extends AuditDb with Logging { } override def listReceived(from: Long, to: Long): Seq[PaymentReceived] = - using(sqlite.prepareStatement("SELECT * FROM received WHERE timestamp >= ? AND timestamp < ?")) { statement => - statement.setLong(1, from.seconds.toMillis) - statement.setLong(2, to.seconds.toMillis) + using(sqlite.prepareStatement("SELECT * FROM received WHERE timestamp >= ? AND timestamp < ? ORDER BY timestamp")) { statement => + statement.setLong(1, from) + statement.setLong(2, to) val rs = statement.executeQuery() var q: Queue[PaymentReceived] = Queue() while (rs.next()) { @@ -219,9 +218,9 @@ class SqliteAuditDb(sqlite: Connection) extends AuditDb with Logging { } override def listRelayed(from: Long, to: Long): Seq[PaymentRelayed] = - using(sqlite.prepareStatement("SELECT * FROM relayed WHERE timestamp >= ? AND timestamp < ?")) { statement => - statement.setLong(1, from.seconds.toMillis) - statement.setLong(2, to.seconds.toMillis) + using(sqlite.prepareStatement("SELECT * FROM relayed WHERE timestamp >= ? AND timestamp < ? ORDER BY timestamp")) { statement => + statement.setLong(1, from) + statement.setLong(2, to) val rs = statement.executeQuery() var q: Queue[PaymentRelayed] = Queue() while (rs.next()) { @@ -237,9 +236,9 @@ class SqliteAuditDb(sqlite: Connection) extends AuditDb with Logging { } override def listNetworkFees(from: Long, to: Long): Seq[NetworkFee] = - using(sqlite.prepareStatement("SELECT * FROM network_fees WHERE timestamp >= ? AND timestamp < ?")) { statement => - statement.setLong(1, from.seconds.toMillis) - statement.setLong(2, to.seconds.toMillis) + using(sqlite.prepareStatement("SELECT * FROM network_fees WHERE timestamp >= ? AND timestamp < ? ORDER BY timestamp")) { statement => + statement.setLong(1, from) + statement.setLong(2, to) val rs = statement.executeQuery() var q: Queue[NetworkFee] = Queue() while (rs.next()) { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala index 17f308b3df..20334ad4e4 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala @@ -58,20 +58,27 @@ class SqlitePaymentsDb(sqlite: Connection) extends PaymentsDb with Logging { } def migration23(statement: Statement) = { - // Nothing changes in the received_payments table, but the sent_payments table changes a lot. + // We add many more columns to the sent_payments table. statement.executeUpdate("DROP index payment_hash_idx") statement.executeUpdate("ALTER TABLE sent_payments RENAME TO _sent_payments_old") - statement.executeUpdate("CREATE TABLE sent_payments (id TEXT NOT NULL PRIMARY KEY, parent_id TEXT, external_id TEXT, payment_hash BLOB NOT NULL, amount_msat INTEGER NOT NULL, target_node_id BLOB NOT NULL, created_at INTEGER NOT NULL, status VARCHAR NOT NULL, payment_request TEXT, completed_at INTEGER, payment_preimage BLOB, fees_msat INTEGER, payment_route BLOB, failures BLOB)") + statement.executeUpdate("CREATE TABLE sent_payments (id TEXT NOT NULL PRIMARY KEY, parent_id TEXT, external_id TEXT, payment_hash BLOB NOT NULL, amount_msat INTEGER NOT NULL, target_node_id BLOB NOT NULL, created_at INTEGER NOT NULL, payment_request TEXT, completed_at INTEGER, payment_preimage BLOB, fees_msat INTEGER, payment_route BLOB, failures BLOB)") // Old rows will be missing a target node id, so we use an easy-to-spot default value. val defaultTargetNodeId = PrivateKey(ByteVector32.One).publicKey - statement.executeUpdate(s"INSERT INTO sent_payments (id, payment_hash, amount_msat, target_node_id, created_at, status, completed_at, payment_preimage) SELECT id, payment_hash, amount_msat, X'${defaultTargetNodeId.toString}', created_at, status, completed_at, preimage FROM _sent_payments_old") + statement.executeUpdate(s"INSERT INTO sent_payments (id, payment_hash, amount_msat, target_node_id, created_at, completed_at, payment_preimage) SELECT id, payment_hash, amount_msat, X'${defaultTargetNodeId.toString}', created_at, completed_at, preimage FROM _sent_payments_old") statement.executeUpdate("DROP table _sent_payments_old") + statement.executeUpdate("ALTER TABLE received_payments RENAME TO _received_payments_old") + // We make payment request expiration not null in the received_payments table. + // When it was previously set to NULL the default expiry should apply. + statement.executeUpdate(s"UPDATE _received_payments_old SET expire_at = created_at + ${PaymentRequest.DEFAULT_EXPIRY_SECONDS} WHERE expire_at IS NULL") + statement.executeUpdate("CREATE TABLE received_payments (payment_hash BLOB NOT NULL PRIMARY KEY, payment_preimage BLOB NOT NULL, payment_request TEXT NOT NULL, received_msat INTEGER, created_at INTEGER NOT NULL, expire_at INTEGER NOT NULL, received_at INTEGER)") + statement.executeUpdate("INSERT INTO received_payments (payment_hash, payment_preimage, payment_request, received_msat, created_at, expire_at, received_at) SELECT payment_hash, preimage, payment_request, received_msat, created_at, expire_at, received_at FROM _received_payments_old") + statement.executeUpdate("DROP table _received_payments_old") + statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_parent_id_idx ON sent_payments(parent_id)") statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_payment_hash_idx ON sent_payments(payment_hash)") statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_created_idx ON sent_payments(created_at)") statement.executeUpdate("CREATE INDEX IF NOT EXISTS received_created_idx ON received_payments(created_at)") - statement.executeUpdate("CREATE INDEX IF NOT EXISTS received_timestamp_idx ON received_payments(received_at)") } getVersion(statement, DB_NAME, CURRENT_VERSION) match { @@ -85,21 +92,21 @@ class SqlitePaymentsDb(sqlite: Connection) extends PaymentsDb with Logging { migration23(statement) setVersion(statement, DB_NAME, CURRENT_VERSION) case CURRENT_VERSION => - statement.executeUpdate("CREATE TABLE IF NOT EXISTS received_payments (payment_hash BLOB NOT NULL PRIMARY KEY, preimage BLOB NOT NULL, payment_request TEXT NOT NULL, received_msat INTEGER, created_at INTEGER NOT NULL, expire_at INTEGER, received_at INTEGER)") - statement.executeUpdate("CREATE TABLE IF NOT EXISTS sent_payments (id TEXT NOT NULL PRIMARY KEY, parent_id TEXT, external_id TEXT, payment_hash BLOB NOT NULL, amount_msat INTEGER NOT NULL, target_node_id BLOB NOT NULL, created_at INTEGER NOT NULL, status VARCHAR NOT NULL, payment_request TEXT, completed_at INTEGER, payment_preimage BLOB, fees_msat INTEGER, payment_route BLOB, failures BLOB)") + statement.executeUpdate("CREATE TABLE IF NOT EXISTS received_payments (payment_hash BLOB NOT NULL PRIMARY KEY, payment_preimage BLOB NOT NULL, payment_request TEXT NOT NULL, received_msat INTEGER, created_at INTEGER NOT NULL, expire_at INTEGER NOT NULL, received_at INTEGER)") + statement.executeUpdate("CREATE TABLE IF NOT EXISTS sent_payments (id TEXT NOT NULL PRIMARY KEY, parent_id TEXT, external_id TEXT, payment_hash BLOB NOT NULL, amount_msat INTEGER NOT NULL, target_node_id BLOB NOT NULL, created_at INTEGER NOT NULL, payment_request TEXT, completed_at INTEGER, payment_preimage BLOB, fees_msat INTEGER, payment_route BLOB, failures BLOB)") statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_parent_id_idx ON sent_payments(parent_id)") statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_payment_hash_idx ON sent_payments(payment_hash)") statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_created_idx ON sent_payments(created_at)") statement.executeUpdate("CREATE INDEX IF NOT EXISTS received_created_idx ON received_payments(created_at)") - statement.executeUpdate("CREATE INDEX IF NOT EXISTS received_timestamp_idx ON received_payments(received_at)") case unknownVersion => throw new RuntimeException(s"Unknown version of DB $DB_NAME found, version=$unknownVersion") } } - override def addOutgoingPayment(sent: OutgoingPayment): Unit = - using(sqlite.prepareStatement("INSERT INTO sent_payments (id, parent_id, external_id, payment_hash, amount_msat, target_node_id, created_at, status, payment_request) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)")) { statement => + override def addOutgoingPayment(sent: OutgoingPayment): Unit = { + require(sent.status == OutgoingPaymentStatus.Pending, s"outgoing payment isn't pending (${sent.status.getClass.getSimpleName})") + using(sqlite.prepareStatement("INSERT INTO sent_payments (id, parent_id, external_id, payment_hash, amount_msat, target_node_id, created_at, payment_request) VALUES (?, ?, ?, ?, ?, ?, ?, ?)")) { statement => statement.setString(1, sent.id.toString) statement.setString(2, sent.parentId.map(_.toString).orNull) statement.setString(3, sent.externalId.orNull) @@ -107,28 +114,26 @@ class SqlitePaymentsDb(sqlite: Connection) extends PaymentsDb with Logging { statement.setLong(5, sent.amount.toLong) statement.setBytes(6, sent.targetNodeId.value.toArray) statement.setLong(7, sent.createdAt) - statement.setString(8, sent.status.toString) - statement.setString(9, sent.paymentRequest.map(PaymentRequest.write).orNull) + statement.setString(8, sent.paymentRequest.map(PaymentRequest.write).orNull) statement.executeUpdate() } + } override def updateOutgoingPayment(paymentResult: PaymentSent): Unit = - using(sqlite.prepareStatement("UPDATE sent_payments SET (completed_at, status, payment_preimage, fees_msat, payment_route) = (?, ?, ?, ?, ?) WHERE id = ? AND completed_at IS NULL")) { statement => + using(sqlite.prepareStatement("UPDATE sent_payments SET (completed_at, payment_preimage, fees_msat, payment_route) = (?, ?, ?, ?) WHERE id = ? AND completed_at IS NULL")) { statement => statement.setLong(1, paymentResult.timestamp) - statement.setString(2, OutgoingPaymentStatus.SUCCEEDED.toString) - statement.setBytes(3, paymentResult.paymentPreimage.toArray) - statement.setLong(4, paymentResult.feesPaid.toLong) - statement.setBytes(5, paymentRouteCodec.encode(paymentResult.route.map(h => HopSummary(h)).toList).require.toByteArray) - statement.setString(6, paymentResult.id.toString) + statement.setBytes(2, paymentResult.paymentPreimage.toArray) + statement.setLong(3, paymentResult.feesPaid.toLong) + statement.setBytes(4, paymentRouteCodec.encode(paymentResult.route.map(h => HopSummary(h)).toList).require.toByteArray) + statement.setString(5, paymentResult.id.toString) if (statement.executeUpdate() == 0) throw new IllegalArgumentException(s"Tried to mark an outgoing payment as succeeded but already in final status (id=${paymentResult.id})") } override def updateOutgoingPayment(paymentResult: PaymentFailed): Unit = - using(sqlite.prepareStatement("UPDATE sent_payments SET (completed_at, status, failures) = (?, ?, ?) WHERE id = ? AND completed_at IS NULL")) { statement => + using(sqlite.prepareStatement("UPDATE sent_payments SET (completed_at, failures) = (?, ?) WHERE id = ? AND completed_at IS NULL")) { statement => statement.setLong(1, paymentResult.timestamp) - statement.setString(2, OutgoingPaymentStatus.FAILED.toString) - statement.setBytes(3, paymentFailuresCodec.encode(paymentResult.failures.map(f => FailureSummary(f)).toList).require.toByteArray) - statement.setString(4, paymentResult.id.toString) + statement.setBytes(2, paymentFailuresCodec.encode(paymentResult.failures.map(f => FailureSummary(f)).toList).require.toByteArray) + statement.setString(3, paymentResult.id.toString) if (statement.executeUpdate() == 0) throw new IllegalArgumentException(s"Tried to mark an outgoing payment as failed but already in final status (id=${paymentResult.id})") } @@ -141,31 +146,37 @@ class SqlitePaymentsDb(sqlite: Connection) extends PaymentsDb with Logging { MilliSatoshi(rs.getLong("amount_msat")), PublicKey(rs.getByteVector("target_node_id")), rs.getLong("created_at"), - OutgoingPaymentStatus.withName(rs.getString("status")), rs.getStringNullable("payment_request").map(PaymentRequest.read), - getNullableLong(rs, "completed_at") + OutgoingPaymentStatus.Pending ) - result.status match { - case OutgoingPaymentStatus.SUCCEEDED => result.copy(successSummary = Some(PaymentSuccessSummary( - rs.getByteVector32("payment_preimage"), + // If we have a pre-image, the payment succeeded. + rs.getByteVector32Nullable("payment_preimage") match { + case Some(paymentPreimage) => result.copy(status = OutgoingPaymentStatus.Succeeded( + paymentPreimage, MilliSatoshi(rs.getLong("fees_msat")), rs.getBitVectorOpt("payment_route").map(b => paymentRouteCodec.decode(b) match { case Attempt.Successful(route) => route.value case Attempt.Failure(_) => Nil - }).getOrElse(Nil) - ))) - case OutgoingPaymentStatus.FAILED => result.copy(failureSummary = Some(PaymentFailureSummary( - rs.getBitVectorOpt("failures").map(b => paymentFailuresCodec.decode(b) match { - case Attempt.Successful(failures) => failures.value - case Attempt.Failure(_) => Nil - }).getOrElse(Nil) - ))) - case _ => result + }).getOrElse(Nil), + rs.getLong("completed_at") + )) + case None => getNullableLong(rs, "completed_at") match { + // Otherwise if the payment was marked completed, it's a failure. + case Some(completedAt) => result.copy(status = OutgoingPaymentStatus.Failed( + rs.getBitVectorOpt("failures").map(b => paymentFailuresCodec.decode(b) match { + case Attempt.Successful(failures) => failures.value + case Attempt.Failure(_) => Nil + }).getOrElse(Nil), + completedAt + )) + // Else it's still pending. + case _ => result + } } } override def getOutgoingPayment(id: UUID): Option[OutgoingPayment] = - using(sqlite.prepareStatement("SELECT id, parent_id, external_id, payment_hash, amount_msat, target_node_id, created_at, status, payment_request, completed_at, payment_preimage, fees_msat, payment_route, failures FROM sent_payments WHERE id = ?")) { statement => + using(sqlite.prepareStatement("SELECT id, parent_id, external_id, payment_hash, amount_msat, target_node_id, created_at, payment_request, completed_at, payment_preimage, fees_msat, payment_route, failures FROM sent_payments WHERE id = ?")) { statement => statement.setString(1, id.toString) val rs = statement.executeQuery() if (rs.next()) { @@ -175,8 +186,8 @@ class SqlitePaymentsDb(sqlite: Connection) extends PaymentsDb with Logging { } } - override def getOutgoingPayments(paymentHash: ByteVector32): Seq[OutgoingPayment] = - using(sqlite.prepareStatement("SELECT id, parent_id, external_id, payment_hash, amount_msat, target_node_id, created_at, status, payment_request, completed_at, payment_preimage, fees_msat, payment_route, failures FROM sent_payments WHERE payment_hash = ?")) { statement => + override def listOutgoingPayments(paymentHash: ByteVector32): Seq[OutgoingPayment] = + using(sqlite.prepareStatement("SELECT id, parent_id, external_id, payment_hash, amount_msat, target_node_id, created_at, payment_request, completed_at, payment_preimage, fees_msat, payment_route, failures FROM sent_payments WHERE payment_hash = ? ORDER BY created_at")) { statement => statement.setBytes(1, paymentHash.toArray) val rs = statement.executeQuery() var q: Queue[OutgoingPayment] = Queue() @@ -187,9 +198,9 @@ class SqlitePaymentsDb(sqlite: Connection) extends PaymentsDb with Logging { } override def listOutgoingPayments(from: Long, to: Long): Seq[OutgoingPayment] = - using(sqlite.prepareStatement("SELECT id, parent_id, external_id, payment_hash, amount_msat, target_node_id, created_at, status, payment_request, completed_at, payment_preimage, fees_msat, payment_route, failures FROM sent_payments WHERE created_at >= ? AND created_at < ?")) { statement => - statement.setLong(1, from.seconds.toMillis) - statement.setLong(2, to.seconds.toMillis) + using(sqlite.prepareStatement("SELECT id, parent_id, external_id, payment_hash, amount_msat, target_node_id, created_at, payment_request, completed_at, payment_preimage, fees_msat, payment_route, failures FROM sent_payments WHERE created_at >= ? AND created_at < ? ORDER BY created_at")) { statement => + statement.setLong(1, from) + statement.setLong(2, to) val rs = statement.executeQuery() var q: Queue[OutgoingPayment] = Queue() while (rs.next()) { @@ -198,102 +209,94 @@ class SqlitePaymentsDb(sqlite: Connection) extends PaymentsDb with Logging { q } - override def addPaymentRequest(pr: PaymentRequest, preimage: ByteVector32): Unit = { - val insertStmt = pr.expiry match { - case Some(_) => "INSERT INTO received_payments (payment_hash, preimage, payment_request, created_at, expire_at) VALUES (?, ?, ?, ?, ?)" - case None => "INSERT INTO received_payments (payment_hash, preimage, payment_request, created_at) VALUES (?, ?, ?, ?)" - } - - using(sqlite.prepareStatement(insertStmt)) { statement => + override def addIncomingPayment(pr: PaymentRequest, preimage: ByteVector32): Unit = + using(sqlite.prepareStatement("INSERT INTO received_payments (payment_hash, payment_preimage, payment_request, created_at, expire_at) VALUES (?, ?, ?, ?, ?)")) { statement => statement.setBytes(1, pr.paymentHash.toArray) statement.setBytes(2, preimage.toArray) statement.setString(3, PaymentRequest.write(pr)) statement.setLong(4, pr.timestamp.seconds.toMillis) // BOLT11 timestamp is in seconds - pr.expiry.foreach { ex => statement.setLong(5, pr.timestamp.seconds.toMillis + ex.seconds.toMillis) } // we store "when" the invoice will expire, in milliseconds + statement.setLong(5, (pr.timestamp + pr.expiry.getOrElse(PaymentRequest.DEFAULT_EXPIRY_SECONDS.toLong)).seconds.toMillis) statement.executeUpdate() } - } - override def getPaymentRequest(paymentHash: ByteVector32): Option[PaymentRequest] = { - using(sqlite.prepareStatement("SELECT payment_request FROM received_payments WHERE payment_hash = ?")) { statement => - statement.setBytes(1, paymentHash.toArray) - val rs = statement.executeQuery() - if (rs.next()) { - Some(PaymentRequest.read(rs.getString("payment_request"))) - } else { - None - } + override def receiveIncomingPayment(paymentHash: ByteVector32, amount: MilliSatoshi, receivedAt: Long): Unit = + using(sqlite.prepareStatement("UPDATE received_payments SET (received_msat, received_at) = (?, ?) WHERE payment_hash = ?")) { statement => + statement.setLong(1, amount.toLong) + statement.setLong(2, receivedAt) + statement.setBytes(3, paymentHash.toArray) + val res = statement.executeUpdate() + if (res == 0) throw new IllegalArgumentException("Inserted a received payment without having an invoice") + } + + private def parseIncomingPayment(rs: ResultSet): IncomingPayment = { + val paymentRequest = PaymentRequest.read(rs.getString("payment_request")) + val paymentPreimage = rs.getByteVector32("payment_preimage") + val createdAt = rs.getLong("created_at") + val received = getNullableLong(rs, "received_msat").map(MilliSatoshi(_)) + received match { + case Some(amount) => IncomingPayment(paymentRequest, paymentPreimage, createdAt, IncomingPaymentStatus.Received(amount, rs.getLong("received_at"))) + case None if paymentRequest.isExpired => IncomingPayment(paymentRequest, paymentPreimage, createdAt, IncomingPaymentStatus.Expired) + case None => IncomingPayment(paymentRequest, paymentPreimage, createdAt, IncomingPaymentStatus.Pending) } } - override def getPendingPaymentRequestAndPreimage(paymentHash: ByteVector32): Option[(ByteVector32, PaymentRequest)] = { - using(sqlite.prepareStatement("SELECT payment_request, preimage FROM received_payments WHERE payment_hash = ? AND received_at IS NULL")) { statement => + override def getIncomingPayment(paymentHash: ByteVector32): Option[IncomingPayment] = + using(sqlite.prepareStatement("SELECT payment_hash, payment_preimage, payment_request, received_msat, created_at, received_at FROM received_payments WHERE payment_hash = ?")) { statement => statement.setBytes(1, paymentHash.toArray) val rs = statement.executeQuery() if (rs.next()) { - val preimage = rs.getByteVector32("preimage") - val pr = PaymentRequest.read(rs.getString("payment_request")) - Some(preimage, pr) + Some(parseIncomingPayment(rs)) } else { None } } - } - - override def listPaymentRequests(from: Long, to: Long): Seq[PaymentRequest] = listPaymentRequests(from, to, pendingOnly = false) - - override def listPendingPaymentRequests(from: Long, to: Long): Seq[PaymentRequest] = listPaymentRequests(from, to, pendingOnly = true) - - def listPaymentRequests(from: Long, to: Long, pendingOnly: Boolean): Seq[PaymentRequest] = { - val queryStmt = pendingOnly match { - case true => "SELECT payment_request FROM received_payments WHERE created_at > ? AND created_at < ? AND (expire_at > ? OR expire_at IS NULL) AND received_msat IS NULL ORDER BY created_at DESC" - case false => "SELECT payment_request FROM received_payments WHERE created_at > ? AND created_at < ? ORDER BY created_at DESC" - } - - using(sqlite.prepareStatement(queryStmt)) { statement => - statement.setLong(1, from.seconds.toMillis) - statement.setLong(2, to.seconds.toMillis) - if (pendingOnly) statement.setLong(3, Platform.currentTime) + override def listIncomingPayments(from: Long, to: Long): Seq[IncomingPayment] = + using(sqlite.prepareStatement("SELECT payment_hash, payment_preimage, payment_request, received_msat, created_at, received_at FROM received_payments WHERE created_at > ? AND created_at < ? ORDER BY created_at")) { statement => + statement.setLong(1, from) + statement.setLong(2, to) val rs = statement.executeQuery() - var q: Queue[PaymentRequest] = Queue() + var q: Queue[IncomingPayment] = Queue() while (rs.next()) { - q = q :+ PaymentRequest.read(rs.getString("payment_request")) + q = q :+ parseIncomingPayment(rs) } q } - } - override def addIncomingPayment(payment: IncomingPayment): Unit = { - using(sqlite.prepareStatement("UPDATE received_payments SET (received_msat, received_at) = (?, ?) WHERE payment_hash = ?")) { statement => - statement.setLong(1, payment.amount.toLong) - statement.setLong(2, payment.receivedAt) - statement.setBytes(3, payment.paymentHash.toArray) - val res = statement.executeUpdate() - if (res == 0) throw new IllegalArgumentException("Inserted a received payment without having an invoice") + override def listReceivedIncomingPayments(from: Long, to: Long): Seq[IncomingPayment] = + using(sqlite.prepareStatement("SELECT payment_hash, payment_preimage, payment_request, received_msat, created_at, received_at FROM received_payments WHERE received_msat > 0 AND created_at > ? AND created_at < ? ORDER BY created_at")) { statement => + statement.setLong(1, from) + statement.setLong(2, to) + val rs = statement.executeQuery() + var q: Queue[IncomingPayment] = Queue() + while (rs.next()) { + q = q :+ parseIncomingPayment(rs) + } + q } - } - override def getIncomingPayment(paymentHash: ByteVector32): Option[IncomingPayment] = { - using(sqlite.prepareStatement("SELECT payment_hash, received_msat, received_at FROM received_payments WHERE payment_hash = ? AND received_msat > 0")) { statement => - statement.setBytes(1, paymentHash.toArray) + override def listPendingIncomingPayments(from: Long, to: Long): Seq[IncomingPayment] = + using(sqlite.prepareStatement("SELECT payment_hash, payment_preimage, payment_request, received_msat, created_at, received_at FROM received_payments WHERE received_msat IS NULL AND created_at > ? AND created_at < ? AND expire_at > ? ORDER BY created_at")) { statement => + statement.setLong(1, from) + statement.setLong(2, to) + statement.setLong(3, Platform.currentTime) val rs = statement.executeQuery() - if (rs.next()) { - Some(IncomingPayment(rs.getByteVector32("payment_hash"), MilliSatoshi(rs.getLong("received_msat")), rs.getLong("received_at"))) - } else { - None + var q: Queue[IncomingPayment] = Queue() + while (rs.next()) { + q = q :+ parseIncomingPayment(rs) } + q } - } - override def listIncomingPayments(from: Long, to: Long): Seq[IncomingPayment] = - using(sqlite.prepareStatement("SELECT payment_hash, received_msat, received_at FROM received_payments WHERE received_msat > 0 AND received_at >= ? AND received_at < ?")) { statement => - statement.setLong(1, from.seconds.toMillis) - statement.setLong(2, to.seconds.toMillis) + override def listExpiredIncomingPayments(from: Long, to: Long): Seq[IncomingPayment] = + using(sqlite.prepareStatement("SELECT payment_hash, payment_preimage, payment_request, received_msat, created_at, received_at FROM received_payments WHERE received_msat IS NULL AND created_at > ? AND created_at < ? AND expire_at < ? ORDER BY created_at")) { statement => + statement.setLong(1, from) + statement.setLong(2, to) + statement.setLong(3, Platform.currentTime) val rs = statement.executeQuery() var q: Queue[IncomingPayment] = Queue() while (rs.next()) { - q = q :+ IncomingPayment(rs.getByteVector32("payment_hash"), MilliSatoshi(rs.getLong("received_msat")), rs.getLong("received_at")) + q = q :+ parseIncomingPayment(rs) } q } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/package.scala b/eclair-core/src/main/scala/fr/acinq/eclair/package.scala index 003f7764fb..f14482fccf 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/package.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/package.scala @@ -22,7 +22,7 @@ import fr.acinq.bitcoin.Crypto.PrivateKey import fr.acinq.bitcoin._ import scodec.Attempt import scodec.bits.{BitVector, ByteVector} -import scala.concurrent.duration.Duration + import scala.util.{Failure, Success, Try} package object eclair { @@ -152,11 +152,6 @@ package object eclair { } } - /** - * We use this in the context of timestamp filtering, when we don't need an upper bound. - */ - val MaxEpochSeconds = Duration.fromNanos(Long.MaxValue).toSeconds - implicit class LongToBtcAmount(l: Long) { // @formatter:off def msat: MilliSatoshi = MilliSatoshi(l) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/LocalPaymentHandler.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/LocalPaymentHandler.scala index fab34c5a21..f21705ca7f 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/LocalPaymentHandler.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/LocalPaymentHandler.scala @@ -19,12 +19,11 @@ package fr.acinq.eclair.payment import akka.actor.{Actor, ActorLogging, Props, Status} import fr.acinq.bitcoin.Crypto import fr.acinq.eclair.channel.{CMD_FAIL_HTLC, CMD_FULFILL_HTLC, Channel} -import fr.acinq.eclair.db.IncomingPayment +import fr.acinq.eclair.db.{IncomingPayment, IncomingPaymentStatus} import fr.acinq.eclair.payment.PaymentLifecycle.ReceivePayment import fr.acinq.eclair.wire._ import fr.acinq.eclair.{NodeParams, randomBytes32} -import scala.compat.Platform import scala.concurrent.ExecutionContext import scala.util.{Failure, Success, Try} @@ -49,7 +48,7 @@ class LocalPaymentHandler(nodeParams: NodeParams) extends Actor with ActorLoggin val expirySeconds = expirySeconds_opt.getOrElse(nodeParams.paymentRequestExpiry.toSeconds) val paymentRequest = PaymentRequest(nodeParams.chainHash, amount_opt, paymentHash, nodeParams.privateKey, desc, fallbackAddress_opt, expirySeconds = Some(expirySeconds), extraHops = extraHops) log.debug(s"generated payment request={} from amount={}", PaymentRequest.write(paymentRequest), amount_opt) - paymentDb.addPaymentRequest(paymentRequest, paymentPreimage) + paymentDb.addIncomingPayment(paymentRequest, paymentPreimage) paymentRequest } match { case Success(paymentRequest) => sender ! paymentRequest @@ -57,8 +56,11 @@ class LocalPaymentHandler(nodeParams: NodeParams) extends Actor with ActorLoggin } case htlc: UpdateAddHtlc => - paymentDb.getPendingPaymentRequestAndPreimage(htlc.paymentHash) match { - case Some((paymentPreimage, paymentRequest)) => + paymentDb.getIncomingPayment(htlc.paymentHash) match { + case Some(IncomingPayment(_, _, _, status)) if status.isInstanceOf[IncomingPaymentStatus.Received] => + log.warning(s"ignoring incoming payment for paymentHash=${htlc.paymentHash} which has already been paid") + sender ! CMD_FAIL_HTLC(htlc.id, Right(IncorrectOrUnknownPaymentDetails(htlc.amountMsat, nodeParams.currentBlockHeight)), commit = true) + case Some(IncomingPayment(paymentRequest, paymentPreimage, _, _)) => val minFinalExpiry = paymentRequest.minFinalCltvExpiryDelta.getOrElse(Channel.MIN_CLTV_EXPIRY_DELTA).toCltvExpiry(nodeParams.currentBlockHeight) // The htlc amount must be equal or greater than the requested amount. A slight overpaying is permitted, however // it must not be greater than two times the requested amount. @@ -77,7 +79,7 @@ class LocalPaymentHandler(nodeParams: NodeParams) extends Actor with ActorLoggin case _ => log.info(s"received payment for paymentHash=${htlc.paymentHash} amount=${htlc.amountMsat}") // amount is correct or was not specified in the payment request - nodeParams.db.payments.addIncomingPayment(IncomingPayment(htlc.paymentHash, htlc.amountMsat, Platform.currentTime)) + nodeParams.db.payments.receiveIncomingPayment(htlc.paymentHash, htlc.amountMsat) sender ! CMD_FULFILL_HTLC(htlc.id, paymentPreimage, commit = true) context.system.eventStream.publish(PaymentReceived(htlc.amountMsat, htlc.paymentHash)) } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentLifecycle.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentLifecycle.scala index 20e605dd34..c8c52529bd 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentLifecycle.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentLifecycle.scala @@ -200,7 +200,7 @@ object PaymentLifecycle { case class DefaultPaymentProgressHandler(id: UUID, r: SendPaymentRequest, db: PaymentsDb) extends PaymentProgressHandler { override def onSend(): Unit = { - db.addOutgoingPayment(OutgoingPayment(id, None, r.externalId, r.paymentHash, r.amount, r.targetNodeId, Platform.currentTime, OutgoingPaymentStatus.PENDING, r.paymentRequest)) + db.addOutgoingPayment(OutgoingPayment(id, None, r.externalId, r.paymentHash, r.amount, r.targetNodeId, Platform.currentTime, r.paymentRequest, OutgoingPaymentStatus.Pending)) } override def onSuccess(sender: ActorRef, result: PaymentSent)(ctx: ActorContext): Unit = { diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/EclairImplSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/EclairImplSpec.scala index 7a78f791ac..f6e604ad23 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/EclairImplSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/EclairImplSpec.scala @@ -237,7 +237,7 @@ class EclairImplSpec extends TestKit(ActorSystem("test")) with fixture.FunSuiteL auditDb.listSent(anyLong, anyLong) returns Seq.empty auditDb.listReceived(anyLong, anyLong) returns Seq.empty auditDb.listRelayed(anyLong, anyLong) returns Seq.empty - paymentDb.listPaymentRequests(anyLong, anyLong) returns Seq.empty + paymentDb.listIncomingPayments(anyLong, anyLong) returns Seq.empty val databases = mock[Databases] databases.audit returns auditDb @@ -247,15 +247,15 @@ class EclairImplSpec extends TestKit(ActorSystem("test")) with fixture.FunSuiteL val eclair = new EclairImpl(kitWithMockAudit) Await.result(eclair.networkFees(None, None), 10 seconds) - auditDb.listNetworkFees(0, MaxEpochSeconds).wasCalled(once) // assert the call was made only once and with the specified params + auditDb.listNetworkFees(0, TimestampQueryFilters.MaxEpochMilliseconds).wasCalled(once) // assert the call was made only once and with the specified params Await.result(eclair.audit(None, None), 10 seconds) - auditDb.listRelayed(0, MaxEpochSeconds).wasCalled(once) - auditDb.listReceived(0, MaxEpochSeconds).wasCalled(once) - auditDb.listSent(0, MaxEpochSeconds).wasCalled(once) + auditDb.listRelayed(0, TimestampQueryFilters.MaxEpochMilliseconds).wasCalled(once) + auditDb.listReceived(0, TimestampQueryFilters.MaxEpochMilliseconds).wasCalled(once) + auditDb.listSent(0, TimestampQueryFilters.MaxEpochMilliseconds).wasCalled(once) Await.result(eclair.allInvoices(None, None), 10 seconds) - paymentDb.listPaymentRequests(0, MaxEpochSeconds).wasCalled(once) // assert the call was made only once and with the specified params + paymentDb.listIncomingPayments(0, TimestampQueryFilters.MaxEpochMilliseconds).wasCalled(once) // assert the call was made only once and with the specified params } test("sendtoroute should pass the parameters correctly") { f => diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteAuditDbSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteAuditDbSpec.scala index 898783e0b1..ee77c29bff 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteAuditDbSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteAuditDbSpec.scala @@ -69,12 +69,12 @@ class SqliteAuditDbSpec extends FunSuite { db.add(e9) db.add(e10) - assert(db.listSent(from = 0L, to = (Platform.currentTime.milliseconds + 15.minute).toSeconds).toSet === Set(e1.copy(route = Nil), e5.copy(route = Nil), e6)) - assert(db.listSent(from = 100000L, to = (Platform.currentTime.milliseconds + 1.minute).toSeconds).toList === List(e1.copy(route = Nil))) - assert(db.listReceived(from = 0L, to = (Platform.currentTime.milliseconds + 1.minute).toSeconds).toList === List(e2)) - assert(db.listRelayed(from = 0L, to = (Platform.currentTime.milliseconds + 1.minute).toSeconds).toList === List(e3)) - assert(db.listNetworkFees(from = 0L, to = (Platform.currentTime.milliseconds + 1.minute).toSeconds).size === 1) - assert(db.listNetworkFees(from = 0L, to = (Platform.currentTime.milliseconds + 1.minute).toSeconds).head.txType === "mutual") + assert(db.listSent(from = 0L, to = (Platform.currentTime.milliseconds + 15.minute).toMillis).toSet === Set(e1.copy(route = Nil), e5.copy(route = Nil), e6)) + assert(db.listSent(from = 100000L, to = (Platform.currentTime.milliseconds + 1.minute).toMillis).toList === List(e1.copy(route = Nil))) + assert(db.listReceived(from = 0L, to = (Platform.currentTime.milliseconds + 1.minute).toMillis).toList === List(e2)) + assert(db.listRelayed(from = 0L, to = (Platform.currentTime.milliseconds + 1.minute).toMillis).toList === List(e3)) + assert(db.listNetworkFees(from = 0L, to = (Platform.currentTime.milliseconds + 1.minute).toMillis).size === 1) + assert(db.listNetworkFees(from = 0L, to = (Platform.currentTime.milliseconds + 1.minute).toMillis).head.txType === "mutual") } test("stats") { @@ -170,9 +170,9 @@ class SqliteAuditDbSpec extends FunSuite { } // existing rows in the 'sent' table will use id=00000000-0000-0000-0000-000000000000 as default - assert(migratedDb.listSent(0, (Platform.currentTime.milliseconds + 1.minute).toSeconds) === Seq(ps.copy(id = ChannelCodecs.UNKNOWN_UUID, route = Nil))) + assert(migratedDb.listSent(0, (Platform.currentTime.milliseconds + 1.minute).toMillis) === Seq(ps.copy(id = ChannelCodecs.UNKNOWN_UUID, route = Nil))) // existing rows in the 'received' table will not contain a fromChannelId anymore - assert(migratedDb.listReceived(0, (Platform.currentTime.milliseconds + 1.minute).toSeconds) === Seq(pr)) + assert(migratedDb.listReceived(0, (Platform.currentTime.milliseconds + 1.minute).toMillis) === Seq(pr)) val postMigrationDb = new SqliteAuditDb(connection) @@ -187,8 +187,8 @@ class SqliteAuditDbSpec extends FunSuite { postMigrationDb.add(pr1) // the old 'sent' record will have the UNKNOWN_UUID and an empty route but the new ones will have their actual id - assert(postMigrationDb.listSent(0, (Platform.currentTime.milliseconds + 1.minute).toSeconds) === Seq(ps.copy(id = ChannelCodecs.UNKNOWN_UUID, route = Nil), ps1.copy(route = Nil), ps2.copy(route = Nil))) - assert(postMigrationDb.listReceived(0, (Platform.currentTime.milliseconds + 1.minute).toSeconds) === Seq(pr, pr1)) + assert(postMigrationDb.listSent(0, (Platform.currentTime.milliseconds + 1.minute).toMillis) === Seq(ps.copy(id = ChannelCodecs.UNKNOWN_UUID, route = Nil), ps1.copy(route = Nil), ps2.copy(route = Nil))) + assert(postMigrationDb.listReceived(0, (Platform.currentTime.milliseconds + 1.minute).toMillis) === Seq(pr, pr1)) } test("handle migration version 2 -> 4") { @@ -300,9 +300,9 @@ class SqliteAuditDbSpec extends FunSuite { } // existing rows in the 'sent' table will use route=NULL as default - assert(migratedDb.listSent(0, (Platform.currentTime.milliseconds + 1.minute).toSeconds) === Seq(ps.copy(route = Nil))) + assert(migratedDb.listSent(0, (Platform.currentTime.milliseconds + 1.minute).toMillis) === Seq(ps.copy(route = Nil))) // existing rows in the 'received' table will not contain a fromChannelId anymore - assert(migratedDb.listReceived(0, (Platform.currentTime.milliseconds + 1.minute).toSeconds) === Seq(pr)) + assert(migratedDb.listReceived(0, (Platform.currentTime.milliseconds + 1.minute).toMillis) === Seq(pr)) val postMigrationDb = new SqliteAuditDb(connection) @@ -315,8 +315,8 @@ class SqliteAuditDbSpec extends FunSuite { postMigrationDb.add(pr1) // the old 'sent' record will have the UNKNOWN_UUID and an empty route but the new ones will have their actual id - assert(postMigrationDb.listSent(0, (Platform.currentTime.milliseconds + 1.minute).toSeconds) === Seq(ps.copy(route = Nil), ps1.copy(route = Nil), ps2.copy(route = Nil))) - assert(postMigrationDb.listReceived(0, (Platform.currentTime.milliseconds + 1.minute).toSeconds) === Seq(pr, pr1)) + assert(postMigrationDb.listSent(0, (Platform.currentTime.milliseconds + 1.minute).toMillis) === Seq(ps.copy(route = Nil), ps1.copy(route = Nil), ps2.copy(route = Nil))) + assert(postMigrationDb.listReceived(0, (Platform.currentTime.milliseconds + 1.minute).toMillis) === Seq(pr, pr1)) } } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/SqlitePaymentsDbSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/SqlitePaymentsDbSpec.scala index c391b86e68..37d92e3724 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/db/SqlitePaymentsDbSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/SqlitePaymentsDbSpec.scala @@ -21,7 +21,6 @@ import java.util.UUID import fr.acinq.bitcoin.Crypto.PrivateKey import fr.acinq.bitcoin.{Block, ByteVector32, Crypto} import fr.acinq.eclair.crypto.Sphinx -import fr.acinq.eclair.db.OutgoingPaymentStatus._ import fr.acinq.eclair.db.sqlite.SqlitePaymentsDb import fr.acinq.eclair.db.sqlite.SqliteUtils._ import fr.acinq.eclair.payment._ @@ -55,15 +54,13 @@ class SqlitePaymentsDbSpec extends FunSuite { assert(getVersion(statement, "payments", 1) == 1) // version 1 is deployed now } - val oldReceivedPayment = IncomingPayment(randomBytes32, 123 msat, 1233322) - // Changes between version 1 and 2: // - the monolithic payments table has been replaced by two tables, received_payments and sent_payments // - old records from the payments table are ignored (not migrated to the new tables) using(connection.prepareStatement("INSERT INTO payments VALUES (?, ?, ?)")) { statement => - statement.setBytes(1, oldReceivedPayment.paymentHash.toArray) - statement.setLong(2, oldReceivedPayment.amount.toLong) - statement.setLong(3, oldReceivedPayment.receivedAt) + statement.setBytes(1, paymentHash1.toArray) + statement.setLong(2, (123 msat).toLong) + statement.setLong(3, 1000) // received_at statement.executeUpdate() } @@ -74,21 +71,19 @@ class SqlitePaymentsDbSpec extends FunSuite { } // the existing received payment can NOT be queried anymore - assert(preMigrationDb.getIncomingPayment(oldReceivedPayment.paymentHash).isEmpty) + assert(preMigrationDb.getIncomingPayment(paymentHash1).isEmpty) // add a few rows - val ps1 = OutgoingPayment(UUID.randomUUID(), None, None, oldReceivedPayment.paymentHash, 12345 msat, alice, 12345, PENDING, None) - val i1 = PaymentRequest.read("lnbc10u1pw2t4phpp5ezwm2gdccydhnphfyepklc0wjkxhz0r4tctg9paunh2lxgeqhcmsdqlxycrqvpqwdshgueqvfjhggr0dcsry7qcqzpgfa4ecv7447p9t5hkujy9qgrxvkkf396p9zar9p87rv2htmeuunkhydl40r64n5s2k0u7uelzc8twxmp37nkcch6m0wg5tvvx69yjz8qpk94qf3") - val pr1 = IncomingPayment(i1.paymentHash, 12345678 msat, 1513871928275L) + val ps1 = OutgoingPayment(UUID.randomUUID(), None, None, paymentHash1, 12345 msat, alice, 1000, None, OutgoingPaymentStatus.Pending) + val i1 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(500 msat), paymentHash1, davePriv, "Some invoice", expirySeconds = None, timestamp = 1) + val pr1 = IncomingPayment(i1, preimage1, i1.timestamp.seconds.toMillis, IncomingPaymentStatus.Received(550 msat, 1100)) - preMigrationDb.addPaymentRequest(i1, ByteVector32.Zeroes) - preMigrationDb.addIncomingPayment(pr1) preMigrationDb.addOutgoingPayment(ps1) + preMigrationDb.addIncomingPayment(i1, preimage1) + preMigrationDb.receiveIncomingPayment(i1.paymentHash, 550 msat, 1100) - val now = Platform.currentTime.milliseconds.toSeconds - assert(preMigrationDb.listIncomingPayments(0, now) == Seq(pr1)) - assert(preMigrationDb.listOutgoingPayments(0, now) == Seq(ps1)) - assert(preMigrationDb.listPaymentRequests(0, now) == Seq(i1)) + assert(preMigrationDb.listIncomingPayments(1, 1500) === Seq(pr1)) + assert(preMigrationDb.listOutgoingPayments(1, 1500) === Seq(ps1)) val postMigrationDb = new SqlitePaymentsDb(connection) @@ -96,9 +91,8 @@ class SqlitePaymentsDbSpec extends FunSuite { assert(getVersion(statement, "payments", 3) == 3) // version still to 3 } - assert(postMigrationDb.listIncomingPayments(0, now) == Seq(pr1)) - assert(postMigrationDb.listOutgoingPayments(0, now) == Seq(ps1)) - assert(preMigrationDb.listPaymentRequests(0, now) == Seq(i1)) + assert(postMigrationDb.listIncomingPayments(1, 1500) === Seq(pr1)) + assert(postMigrationDb.listOutgoingPayments(1, 1500) === Seq(ps1)) } test("handle version migration 2->3") { @@ -116,57 +110,72 @@ class SqlitePaymentsDbSpec extends FunSuite { } // Insert a bunch of old version 2 rows. - val ps1 = OutgoingPayment(UUID.randomUUID(), None, None, randomBytes32, 561 msat, PrivateKey(ByteVector32.One).publicKey, 0, PENDING, None) - val ps2 = OutgoingPayment(UUID.randomUUID(), None, None, randomBytes32, 1105 msat, PrivateKey(ByteVector32.One).publicKey, 1, FAILED, None, Some(2), None, Some(PaymentFailureSummary(Nil))) - val ps3 = OutgoingPayment(UUID.randomUUID(), None, None, defaultPaymentHash, 1729 msat, PrivateKey(ByteVector32.One).publicKey, 4, SUCCEEDED, None, Some(5), Some(PaymentSuccessSummary(defaultPreimage, 0 msat, Nil))) - val i1 = PaymentRequest.read("lnbc10u1pw2t4phpp5ezwm2gdccydhnphfyepklc0wjkxhz0r4tctg9paunh2lxgeqhcmsdqlxycrqvpqwdshgueqvfjhggr0dcsry7qcqzpgfa4ecv7447p9t5hkujy9qgrxvkkf396p9zar9p87rv2htmeuunkhydl40r64n5s2k0u7uelzc8twxmp37nkcch6m0wg5tvvx69yjz8qpk94qf3") - val pr1 = IncomingPayment(i1.paymentHash, 12345678 msat, 9) - val i2 = PaymentRequest.read("lnbc1pvjluezpp5qqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqypqdpl2pkx2ctnv5sxxmmwwd5kgetjypeh2ursdae8g6twvus8g6rfwvs8qun0dfjkxaq8rkx3yf5tcsyz3d73gafnh3cax9rn449d9p5uxz9ezhhypd0elx87sjle52x86fux2ypatgddc6k63n7erqz25le42c4u4ecky03ylcqca784w") + val ps1 = OutgoingPayment(UUID.randomUUID(), None, None, randomBytes32, 561 msat, PrivateKey(ByteVector32.One).publicKey, 1000, None, OutgoingPaymentStatus.Pending) + val ps2 = OutgoingPayment(UUID.randomUUID(), None, None, randomBytes32, 1105 msat, PrivateKey(ByteVector32.One).publicKey, 1010, None, OutgoingPaymentStatus.Failed(Nil, 1050)) + val ps3 = OutgoingPayment(UUID.randomUUID(), None, None, paymentHash1, 1729 msat, PrivateKey(ByteVector32.One).publicKey, 1040, None, OutgoingPaymentStatus.Succeeded(preimage1, 0 msat, Nil, 1060)) + val i1 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(12345678 msat), paymentHash1, davePriv, "Some invoice", expirySeconds = None, timestamp = 1) + val pr1 = IncomingPayment(i1, preimage1, i1.timestamp.seconds.toMillis, IncomingPaymentStatus.Received(12345678 msat, 1090)) + val i2 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(12345678 msat), paymentHash2, carolPriv, "Another invoice", expirySeconds = Some(30), timestamp = 1) + val pr2 = IncomingPayment(i2, preimage2, i2.timestamp.seconds.toMillis, IncomingPaymentStatus.Expired) // Changes between version 2 and 3 to sent_payments: + // - removed the status column // - added optional payment failures // - added optional payment success details (fees paid and route) // - added optional payment request // - added target node ID // - added externalID and parentID - using(connection.prepareStatement("INSERT INTO sent_payments VALUES (?, ?, NULL, ?, ?, NULL, ?)")) { statement => + using(connection.prepareStatement("INSERT INTO sent_payments (id, payment_hash, amount_msat, created_at, status) VALUES (?, ?, ?, ?, ?)")) { statement => statement.setString(1, ps1.id.toString) statement.setBytes(2, ps1.paymentHash.toArray) statement.setLong(3, ps1.amount.toLong) statement.setLong(4, ps1.createdAt) - statement.setString(5, ps1.status.toString) + statement.setString(5, "PENDING") + statement.executeUpdate() + } + + using(connection.prepareStatement("INSERT INTO sent_payments (id, payment_hash, amount_msat, created_at, completed_at, status) VALUES (?, ?, ?, ?, ?, ?)")) { statement => + statement.setString(1, ps2.id.toString) + statement.setBytes(2, ps2.paymentHash.toArray) + statement.setLong(3, ps2.amount.toLong) + statement.setLong(4, ps2.createdAt) + statement.setLong(5, ps2.status.asInstanceOf[OutgoingPaymentStatus.Failed].completedAt) + statement.setString(6, "FAILED") statement.executeUpdate() } - for (ps <- Seq(ps2, ps3)) { - using(connection.prepareStatement("INSERT INTO sent_payments VALUES (?, ?, ?, ?, ?, ?, ?)")) { statement => - statement.setString(1, ps.id.toString) - statement.setBytes(2, ps.paymentHash.toArray) - statement.setBytes(3, ps.successSummary.map(_.paymentPreimage.toArray).orNull) - statement.setLong(4, ps.amount.toLong) - statement.setLong(5, ps.createdAt) - statement.setLong(6, ps.completedAt.get) - statement.setString(7, ps.status.toString) - statement.executeUpdate() - } + using(connection.prepareStatement("INSERT INTO sent_payments (id, payment_hash, preimage, amount_msat, created_at, completed_at, status) VALUES (?, ?, ?, ?, ?, ?, ?)")) { statement => + statement.setString(1, ps3.id.toString) + statement.setBytes(2, ps3.paymentHash.toArray) + statement.setBytes(3, ps3.status.asInstanceOf[OutgoingPaymentStatus.Succeeded].paymentPreimage.toArray) + statement.setLong(4, ps3.amount.toLong) + statement.setLong(5, ps3.createdAt) + statement.setLong(6, ps3.status.asInstanceOf[OutgoingPaymentStatus.Succeeded].completedAt) + statement.setString(7, "SUCCEEDED") + statement.executeUpdate() } - using(connection.prepareStatement("INSERT INTO received_payments VALUES (?, ?, ?, ?, ?, NULL, ?)")) { statement => + // Changes between version 2 and 3 to received_payments: + // - renamed the preimage column + // - made expire_at not null + + using(connection.prepareStatement("INSERT INTO received_payments (payment_hash, preimage, payment_request, received_msat, created_at, received_at) VALUES (?, ?, ?, ?, ?, ?)")) { statement => statement.setBytes(1, i1.paymentHash.toArray) - statement.setBytes(2, defaultPreimage.toArray) + statement.setBytes(2, pr1.paymentPreimage.toArray) statement.setString(3, PaymentRequest.write(i1)) - statement.setLong(4, pr1.amount.toLong) - statement.setLong(5, 0) // created_at - statement.setLong(6, pr1.receivedAt) + statement.setLong(4, pr1.status.asInstanceOf[IncomingPaymentStatus.Received].amount.toLong) + statement.setLong(5, pr1.createdAt) + statement.setLong(6, pr1.status.asInstanceOf[IncomingPaymentStatus.Received].receivedAt) statement.executeUpdate() } - using(connection.prepareStatement("INSERT INTO received_payments VALUES (?, ?, ?, NULL, ?, NULL, NULL)")) { statement => + using(connection.prepareStatement("INSERT INTO received_payments (payment_hash, preimage, payment_request, created_at, expire_at) VALUES (?, ?, ?, ?, ?)")) { statement => statement.setBytes(1, i2.paymentHash.toArray) - statement.setBytes(2, defaultPreimage.toArray) + statement.setBytes(2, pr2.paymentPreimage.toArray) statement.setString(3, PaymentRequest.write(i2)) - statement.setLong(4, 0) // created_at + statement.setLong(4, pr2.createdAt) + statement.setLong(5, (i2.timestamp + i2.expiry.get).seconds.toMillis) statement.executeUpdate() } @@ -176,11 +185,9 @@ class SqlitePaymentsDbSpec extends FunSuite { assert(getVersion(statement, "payments", 2) == 3) // version changed from 2 -> 3 } - assert(preMigrationDb.getPaymentRequest(i1.paymentHash) === Some(i1)) - assert(preMigrationDb.getPaymentRequest(i2.paymentHash) === Some(i2)) assert(preMigrationDb.getIncomingPayment(i1.paymentHash) === Some(pr1)) - assert(preMigrationDb.getIncomingPayment(i2.paymentHash) === None) - assert(preMigrationDb.listOutgoingPayments(0, 5).toSet === Set(ps1, ps2, ps3)) + assert(preMigrationDb.getIncomingPayment(i2.paymentHash) === Some(pr2)) + assert(preMigrationDb.listOutgoingPayments(1, 2000) === Seq(ps1, ps2, ps3)) val postMigrationDb = new SqlitePaymentsDb(connection) @@ -188,118 +195,119 @@ class SqlitePaymentsDbSpec extends FunSuite { assert(getVersion(statement, "payments", 3) == 3) // version still to 3 } - val i3 = PaymentRequest.read("lnbc1500n1pdl686hpp5y7mz3lgvrfccqnk9es6trumjgqdpjwcecycpkdggnx7h6cuup90sdpa2fjkzep6ypqkymm4wssycnjzf9rjqurjda4x2cm5ypskuepqv93x7at5ypek7cqzysxqr23s5e864m06fcfp3axsefy276d77tzp0xzzzdfl6p46wvstkeqhu50khm9yxea2d9efp7lvthrta0ktmhsv52hf3tvxm0unsauhmfmp27cqqx4xxe") - postMigrationDb.addPaymentRequest(i3, randomBytes32) + val i3 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(561 msat), paymentHash3, alicePriv, "invoice #3", expirySeconds = Some(30)) + val pr3 = IncomingPayment(i3, preimage3, i3.timestamp.seconds.toMillis, IncomingPaymentStatus.Pending) + postMigrationDb.addIncomingPayment(i3, pr3.paymentPreimage) - val ps4 = OutgoingPayment(UUID.randomUUID(), Some(UUID.randomUUID()), Some("1"), randomBytes32, 123 msat, alice, 6, PENDING, Some(i3)) - val ps5 = OutgoingPayment(UUID.randomUUID(), Some(UUID.randomUUID()), Some("2"), randomBytes32, 456 msat, bob, 7, SUCCEEDED, Some(i2), Some(8), Some(PaymentSuccessSummary(randomBytes32, 42 msat, Nil))) - val ps6 = OutgoingPayment(UUID.randomUUID(), Some(UUID.randomUUID()), Some("3"), randomBytes32, 789 msat, bob, 8, FAILED, None, Some(9), None, Some(PaymentFailureSummary(Nil))) + val ps4 = OutgoingPayment(UUID.randomUUID(), Some(UUID.randomUUID()), Some("1"), randomBytes32, 123 msat, alice, 1100, Some(i3), OutgoingPaymentStatus.Pending) + val ps5 = OutgoingPayment(UUID.randomUUID(), Some(UUID.randomUUID()), Some("2"), randomBytes32, 456 msat, bob, 1150, Some(i2), OutgoingPaymentStatus.Succeeded(preimage1, 42 msat, Nil, 1180)) + val ps6 = OutgoingPayment(UUID.randomUUID(), Some(UUID.randomUUID()), Some("3"), randomBytes32, 789 msat, bob, 1250, None, OutgoingPaymentStatus.Failed(Nil, 1300)) postMigrationDb.addOutgoingPayment(ps4) - postMigrationDb.addOutgoingPayment(ps5) - postMigrationDb.updateOutgoingPayment(PaymentSent(ps5.id, ps5.amount, ps5.successSummary.get.feesPaid, ps5.paymentHash, ps5.successSummary.get.paymentPreimage, Nil, ps5.completedAt.get)) - postMigrationDb.addOutgoingPayment(ps6) - postMigrationDb.updateOutgoingPayment(PaymentFailed(ps6.id, ps6.paymentHash, Nil, ps6.completedAt.get)) - - assert(postMigrationDb.listOutgoingPayments(0, 10).toSet === Set(ps1, ps2, ps3, ps4, ps5, ps6)) - assert(postMigrationDb.listIncomingPayments(0, 10).toSet === Set(pr1)) - assert(postMigrationDb.getPaymentRequest(i1.paymentHash) === Some(i1)) - assert(postMigrationDb.getPaymentRequest(i2.paymentHash) === Some(i2)) - assert(postMigrationDb.getPaymentRequest(i3.paymentHash) === Some(i3)) + postMigrationDb.addOutgoingPayment(ps5.copy(status = OutgoingPaymentStatus.Pending)) + postMigrationDb.updateOutgoingPayment(PaymentSent(ps5.id, ps5.amount, 42 msat, ps5.paymentHash, preimage1, Nil, 1180)) + postMigrationDb.addOutgoingPayment(ps6.copy(status = OutgoingPaymentStatus.Pending)) + postMigrationDb.updateOutgoingPayment(PaymentFailed(ps6.id, ps6.paymentHash, Nil, 1300)) + + assert(postMigrationDb.listOutgoingPayments(1, 2000) === Seq(ps1, ps2, ps3, ps4, ps5, ps6)) + assert(postMigrationDb.listIncomingPayments(1, Platform.currentTime) === Seq(pr1, pr2, pr3)) + assert(postMigrationDb.listExpiredIncomingPayments(1, 2000) === Seq(pr2)) } - test("add/list received payments/find 1 payment that exists/find 1 payment that does not exist") { + test("add/retrieve/update incoming payments") { val sqlite = TestConstants.sqliteInMemory() val db = new SqlitePaymentsDb(sqlite) // can't receive a payment without an invoice associated with it - assertThrows[IllegalArgumentException](db.addIncomingPayment(IncomingPayment(randomBytes32, 12345678 msat, 1513871928275L))) - - val i1 = PaymentRequest.read("lnbc5450n1pw2t4qdpp5vcrf6ylgpettyng4ac3vujsk0zpc25cj0q3zp7l7w44zvxmpzh8qdzz2pshjmt9de6zqen0wgsr2dp4ypcxj7r9d3ejqct5ypekzar0wd5xjuewwpkxzcm99cxqzjccqp2rzjqtspxelp67qc5l56p6999wkatsexzhs826xmupyhk6j8lxl038t27z9tsqqqgpgqqqqqqqlgqqqqqzsqpcz8z8hmy8g3ecunle4n3edn3zg2rly8g4klsk5md736vaqqy3ktxs30ht34rkfkqaffzxmjphvd0637dk2lp6skah2hq09z6lrjna3xqp3d4vyd") - val i2 = PaymentRequest.read("lnbc10u1pw2t4phpp5ezwm2gdccydhnphfyepklc0wjkxhz0r4tctg9paunh2lxgeqhcmsdqlxycrqvpqwdshgueqvfjhggr0dcsry7qcqzpgfa4ecv7447p9t5hkujy9qgrxvkkf396p9zar9p87rv2htmeuunkhydl40r64n5s2k0u7uelzc8twxmp37nkcch6m0wg5tvvx69yjz8qpk94qf3") - - db.addPaymentRequest(i1, ByteVector32.Zeroes) - db.addPaymentRequest(i2, ByteVector32.One) - - val p1 = IncomingPayment(i1.paymentHash, 12345678 msat, 1513871928275L) - val p2 = IncomingPayment(i2.paymentHash, 12345678 msat, 1513871928275L) - assert(db.listIncomingPayments(0, Platform.currentTime.milliseconds.toSeconds) === Nil) - db.addIncomingPayment(p1) - db.addIncomingPayment(p2) - assert(db.listIncomingPayments(0, Platform.currentTime.milliseconds.toSeconds).toList === List(p1, p2)) - assert(db.getIncomingPayment(p1.paymentHash) === Some(p1)) - assert(db.getIncomingPayment(randomBytes32) === None) + assertThrows[IllegalArgumentException](db.receiveIncomingPayment(randomBytes32, 12345678 msat)) + + val expiredInvoice1 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(561 msat), randomBytes32, alicePriv, "invoice #1", timestamp = 1) + val expiredInvoice2 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(1105 msat), randomBytes32, bobPriv, "invoice #2", timestamp = 2, expirySeconds = Some(30)) + val expiredPayment1 = IncomingPayment(expiredInvoice1, randomBytes32, expiredInvoice1.timestamp.seconds.toMillis, IncomingPaymentStatus.Expired) + val expiredPayment2 = IncomingPayment(expiredInvoice2, randomBytes32, expiredInvoice2.timestamp.seconds.toMillis, IncomingPaymentStatus.Expired) + + val pendingInvoice1 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(561 msat), randomBytes32, alicePriv, "invoice #3") + val pendingInvoice2 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(1105 msat), randomBytes32, bobPriv, "invoice #4", expirySeconds = Some(30)) + val pendingPayment1 = IncomingPayment(pendingInvoice1, randomBytes32, pendingInvoice1.timestamp.seconds.toMillis, IncomingPaymentStatus.Pending) + val pendingPayment2 = IncomingPayment(pendingInvoice2, randomBytes32, pendingInvoice2.timestamp.seconds.toMillis, IncomingPaymentStatus.Pending) + + val paidInvoice1 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(561 msat), randomBytes32, alicePriv, "invoice #5") + val paidInvoice2 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(1105 msat), randomBytes32, bobPriv, "invoice #6", expirySeconds = Some(60)) + val receivedAt1 = Platform.currentTime + 1 + val payment1 = IncomingPayment(paidInvoice1, randomBytes32, paidInvoice1.timestamp.seconds.toMillis, IncomingPaymentStatus.Received(561 msat, receivedAt1)) + val receivedAt2 = Platform.currentTime + 2 + val payment2 = IncomingPayment(paidInvoice2, randomBytes32, paidInvoice2.timestamp.seconds.toMillis, IncomingPaymentStatus.Received(1111 msat, receivedAt2)) + + db.addIncomingPayment(pendingInvoice1, pendingPayment1.paymentPreimage) + db.addIncomingPayment(pendingInvoice2, pendingPayment2.paymentPreimage) + db.addIncomingPayment(expiredInvoice1, expiredPayment1.paymentPreimage) + db.addIncomingPayment(expiredInvoice2, expiredPayment2.paymentPreimage) + db.addIncomingPayment(paidInvoice1, payment1.paymentPreimage) + db.addIncomingPayment(paidInvoice2, payment2.paymentPreimage) + + assert(db.getIncomingPayment(pendingInvoice1.paymentHash) === Some(pendingPayment1)) + assert(db.getIncomingPayment(expiredInvoice2.paymentHash) === Some(expiredPayment2)) + assert(db.getIncomingPayment(paidInvoice1.paymentHash) === Some(payment1.copy(status = IncomingPaymentStatus.Pending))) + + val now = Platform.currentTime + assert(db.listIncomingPayments(0, now) === Seq(expiredPayment1, expiredPayment2, pendingPayment1, pendingPayment2, payment1.copy(status = IncomingPaymentStatus.Pending), payment2.copy(status = IncomingPaymentStatus.Pending))) + assert(db.listExpiredIncomingPayments(0, now) === Seq(expiredPayment1, expiredPayment2)) + assert(db.listReceivedIncomingPayments(0, now) === Nil) + assert(db.listPendingIncomingPayments(0, now) === Seq(pendingPayment1, pendingPayment2, payment1.copy(status = IncomingPaymentStatus.Pending), payment2.copy(status = IncomingPaymentStatus.Pending))) + + db.receiveIncomingPayment(paidInvoice1.paymentHash, 561 msat, receivedAt1) + db.receiveIncomingPayment(paidInvoice2.paymentHash, 1111 msat, receivedAt2) + + assert(db.getIncomingPayment(paidInvoice1.paymentHash) === Some(payment1)) + assert(db.listIncomingPayments(0, now) === Seq(expiredPayment1, expiredPayment2, pendingPayment1, pendingPayment2, payment1, payment2)) + assert(db.listIncomingPayments(now - 60.seconds.toMillis, now) === Seq(pendingPayment1, pendingPayment2, payment1, payment2)) + assert(db.listPendingIncomingPayments(0, now) === Seq(pendingPayment1, pendingPayment2)) + assert(db.listReceivedIncomingPayments(0, now) === Seq(payment1, payment2)) } - test("add/retrieve/update sent payments") { + test("add/retrieve/update outgoing payments") { val db = new SqlitePaymentsDb(TestConstants.sqliteInMemory()) - val i1 = PaymentRequest(chainHash = Block.TestnetGenesisBlock.hash, amount = Some(123 msat), paymentHash = randomBytes32, privateKey = davePriv, description = "Some invoice", expirySeconds = None, timestamp = 1) - val s1 = OutgoingPayment(UUID.randomUUID(), Some(UUID.randomUUID()), None, i1.paymentHash, 123 msat, alice, 1, PENDING, Some(i1)) - val s2 = OutgoingPayment(UUID.randomUUID(), None, Some("1"), randomBytes32, 456 msat, bob, 2, PENDING, None) + val i1 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(123 msat), paymentHash1, davePriv, "Some invoice", expirySeconds = None, timestamp = 0) + val s1 = OutgoingPayment(UUID.randomUUID(), Some(UUID.randomUUID()), None, i1.paymentHash, 123 msat, alice, 100, Some(i1), OutgoingPaymentStatus.Pending) + val s2 = OutgoingPayment(UUID.randomUUID(), None, Some("1"), paymentHash2, 456 msat, bob, 200, None, OutgoingPaymentStatus.Pending) - assert(db.listOutgoingPayments(0, Platform.currentTime.milliseconds.toSeconds).isEmpty) + assert(db.listOutgoingPayments(0, Platform.currentTime).isEmpty) db.addOutgoingPayment(s1) db.addOutgoingPayment(s2) - assert(db.listOutgoingPayments(0, Platform.currentTime.milliseconds.toSeconds).toList == Seq(s1, s2)) + // can't add an outgoing payment in non-pending state + assertThrows[IllegalArgumentException](db.addOutgoingPayment(s1.copy(status = OutgoingPaymentStatus.Succeeded(randomBytes32, 0 msat, Nil, 110)))) + + assert(db.listOutgoingPayments(1, 300).toList == Seq(s1, s2)) + assert(db.listOutgoingPayments(1, 150).toList == Seq(s1)) + assert(db.listOutgoingPayments(150, 250).toList == Seq(s2)) assert(db.getOutgoingPayment(s1.id) === Some(s1)) assert(db.getOutgoingPayment(UUID.randomUUID()) === None) - assert(db.getOutgoingPayments(s2.paymentHash) === Seq(s2)) - assert(db.getOutgoingPayments(ByteVector32.Zeroes) === Nil) + assert(db.listOutgoingPayments(s2.paymentHash) === Seq(s2)) + assert(db.listOutgoingPayments(ByteVector32.Zeroes) === Nil) - val s3 = s2.copy(id = UUID.randomUUID(), amount = 789 msat) - val s4 = s2.copy(id = UUID.randomUUID()) + val s3 = s2.copy(id = UUID.randomUUID(), amount = 789 msat, createdAt = 300) + val s4 = s2.copy(id = UUID.randomUUID(), createdAt = 300) db.addOutgoingPayment(s3) db.addOutgoingPayment(s4) - db.updateOutgoingPayment(PaymentFailed(s3.id, s3.paymentHash, Nil, 10)) - assert(db.getOutgoingPayment(s3.id) === Some(s3.copy(status = FAILED, completedAt = Some(10), failureSummary = Some(PaymentFailureSummary(Nil))))) - db.updateOutgoingPayment(PaymentFailed(s4.id, s4.paymentHash, Seq(LocalFailure(new RuntimeException("woops")), RemoteFailure(Seq(hop_ab, hop_bc), Sphinx.DecryptedFailurePacket(carol, UnknownNextPeer))), 11)) - assert(db.getOutgoingPayment(s4.id) === Some(s4.copy(status = FAILED, completedAt = Some(11), failureSummary = Some(PaymentFailureSummary(Seq(FailureSummary(FailureType.LOCAL, "woops", Nil), FailureSummary(FailureType.REMOTE, "processing node does not know the next peer in the route", List(HopSummary(alice, bob, Some(ShortChannelId(42))), HopSummary(bob, carol, Some(ShortChannelId(43))))))))))) + db.updateOutgoingPayment(PaymentFailed(s3.id, s3.paymentHash, Nil, 310)) + assert(db.getOutgoingPayment(s3.id) === Some(s3.copy(status = OutgoingPaymentStatus.Failed(Nil, 310)))) + db.updateOutgoingPayment(PaymentFailed(s4.id, s4.paymentHash, Seq(LocalFailure(new RuntimeException("woops")), RemoteFailure(Seq(hop_ab, hop_bc), Sphinx.DecryptedFailurePacket(carol, UnknownNextPeer))), 320)) + assert(db.getOutgoingPayment(s4.id) === Some(s4.copy(status = OutgoingPaymentStatus.Failed(Seq(FailureSummary(FailureType.LOCAL, "woops", Nil), FailureSummary(FailureType.REMOTE, "processing node does not know the next peer in the route", List(HopSummary(alice, bob, Some(ShortChannelId(42))), HopSummary(bob, carol, Some(ShortChannelId(43)))))), 320)))) // can't update again once it's in a final state - assertThrows[IllegalArgumentException](db.updateOutgoingPayment(PaymentSent(s3.id, s3.amount, 42 msat, s3.paymentHash, defaultPreimage, Nil))) + assertThrows[IllegalArgumentException](db.updateOutgoingPayment(PaymentSent(s3.id, s3.amount, 42 msat, s3.paymentHash, preimage1, Nil))) - db.updateOutgoingPayment(PaymentSent(s1.id, s1.amount, 15 msat, s1.paymentHash, defaultPreimage, Nil, 10)) - assert(db.getOutgoingPayment(s1.id) === Some(s1.copy(status = SUCCEEDED, completedAt = Some(10), successSummary = Some(PaymentSuccessSummary(defaultPreimage, 15 msat, Nil))))) - db.updateOutgoingPayment(PaymentSent(s2.id, s2.amount, 15 msat, s2.paymentHash, defaultPreimage, Seq(hop_ab, hop_bc), 11)) - assert(db.getOutgoingPayment(s2.id) === Some(s2.copy(status = SUCCEEDED, completedAt = Some(11), successSummary = Some(PaymentSuccessSummary(defaultPreimage, 15 msat, Seq(HopSummary(alice, bob, Some(ShortChannelId(42))), HopSummary(bob, carol, Some(ShortChannelId(43))))))))) + db.updateOutgoingPayment(PaymentSent(s1.id, s1.amount, 15 msat, s1.paymentHash, preimage1, Nil, 400)) + assert(db.getOutgoingPayment(s1.id) === Some(s1.copy(status = OutgoingPaymentStatus.Succeeded(preimage1, 15 msat, Nil, 400)))) + db.updateOutgoingPayment(PaymentSent(s2.id, s2.amount, 15 msat, s2.paymentHash, preimage2, Seq(hop_ab, hop_bc), 410)) + assert(db.getOutgoingPayment(s2.id) === Some(s2.copy(status = OutgoingPaymentStatus.Succeeded(preimage2, 15 msat, Seq(HopSummary(alice, bob, Some(ShortChannelId(42))), HopSummary(bob, carol, Some(ShortChannelId(43)))), 410)))) // can't update again once it's in a final state assertThrows[IllegalArgumentException](db.updateOutgoingPayment(PaymentFailed(s1.id, s1.paymentHash, Nil))) } - test("add/retrieve payment requests") { - val db = new SqlitePaymentsDb(TestConstants.sqliteInMemory()) - val someTimestamp = 12345 - val (paymentHash1, paymentHash2) = (randomBytes32, randomBytes32) - val i1 = PaymentRequest(chainHash = Block.TestnetGenesisBlock.hash, amount = Some(123 msat), paymentHash = paymentHash1, privateKey = bobPriv, description = "Some invoice", expirySeconds = None, timestamp = someTimestamp) - val i2 = PaymentRequest(chainHash = Block.TestnetGenesisBlock.hash, amount = None, paymentHash = paymentHash2, privateKey = bobPriv, description = "Some invoice", expirySeconds = Some(123456), timestamp = Platform.currentTime.milliseconds.toSeconds) - - // i2 doesn't expire - assert(i1.expiry.isEmpty && i2.expiry.isDefined) - assert(i1.amount.isDefined && i2.amount.isEmpty) - - db.addPaymentRequest(i1, ByteVector32.Zeroes) - db.addPaymentRequest(i2, ByteVector32.One) - - // order matters, i2 has a more recent timestamp than i1 - assert(db.listPaymentRequests(0, (Platform.currentTime.milliseconds + 1.minute).toSeconds) == Seq(i2, i1)) - assert(db.getPaymentRequest(i1.paymentHash) === Some(i1)) - assert(db.getPaymentRequest(i2.paymentHash) === Some(i2)) - - assert(db.listPendingPaymentRequests(0, (Platform.currentTime.milliseconds + 1.minute).toSeconds) == Seq(i2, i1)) - assert(db.getPendingPaymentRequestAndPreimage(paymentHash1) === Some((ByteVector32.Zeroes, i1))) - assert(db.getPendingPaymentRequestAndPreimage(paymentHash2) === Some((ByteVector32.One, i2))) - - val from = (someTimestamp - 100).seconds.toSeconds - val to = (someTimestamp + 100).seconds.toSeconds - assert(db.listPaymentRequests(from, to) == Seq(i1)) - - db.addIncomingPayment(IncomingPayment(i2.paymentHash, 42 msat, someTimestamp)) - assert(db.listPendingPaymentRequests(0, (Platform.currentTime.milliseconds + 1.minute).toSeconds) == Seq(i1)) - } - } object SqlitePaymentsDbSpec { @@ -307,6 +315,6 @@ object SqlitePaymentsDbSpec { val (alice, bob, carol, dave) = (alicePriv.publicKey, bobPriv.publicKey, carolPriv.publicKey, davePriv.publicKey) val hop_ab = Hop(alice, bob, ChannelUpdate(randomBytes64, randomBytes32, ShortChannelId(42), 1, 0, 0, CltvExpiryDelta(12), 1 msat, 1 msat, 1, None)) val hop_bc = Hop(bob, carol, ChannelUpdate(randomBytes64, randomBytes32, ShortChannelId(43), 1, 0, 0, CltvExpiryDelta(12), 1 msat, 1 msat, 1, None)) - val defaultPreimage = randomBytes32 - val defaultPaymentHash = Crypto.sha256(defaultPreimage) + val (preimage1, preimage2, preimage3, preimage4) = (randomBytes32, randomBytes32, randomBytes32, randomBytes32) + val (paymentHash1, paymentHash2, paymentHash3, paymentHash4) = (Crypto.sha256(preimage1), Crypto.sha256(preimage2), Crypto.sha256(preimage3), Crypto.sha256(preimage4)) } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentHandlerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentHandlerSpec.scala index 6e4ff42d31..34c5f8c557 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentHandlerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentHandlerSpec.scala @@ -19,9 +19,10 @@ package fr.acinq.eclair.payment import akka.actor.ActorSystem import akka.actor.Status.Failure import akka.testkit.{TestActorRef, TestKit, TestProbe} -import fr.acinq.bitcoin.ByteVector32 +import fr.acinq.bitcoin.{ByteVector32, Crypto} import fr.acinq.eclair.TestConstants.Alice import fr.acinq.eclair.channel.{CMD_FAIL_HTLC, CMD_FULFILL_HTLC} +import fr.acinq.eclair.db.IncomingPaymentStatus import fr.acinq.eclair.payment.PaymentLifecycle.ReceivePayment import fr.acinq.eclair.payment.PaymentRequest.ExtraHop import fr.acinq.eclair.wire.{IncorrectOrUnknownPaymentDetails, UpdateAddHtlc} @@ -50,9 +51,11 @@ class PaymentHandlerSpec extends TestKit(ActorSystem("test")) with FunSuiteLike { sender.send(handler, ReceivePayment(Some(amountMsat), "1 coffee")) val pr = sender.expectMsgType[PaymentRequest] - assert(nodeParams.db.payments.getIncomingPayment(pr.paymentHash).isEmpty) - assert(nodeParams.db.payments.getPendingPaymentRequestAndPreimage(pr.paymentHash).isDefined) - assert(!nodeParams.db.payments.getPendingPaymentRequestAndPreimage(pr.paymentHash).get._2.isExpired) + val incoming = nodeParams.db.payments.getIncomingPayment(pr.paymentHash) + assert(incoming.isDefined) + assert(incoming.get.status === IncomingPaymentStatus.Pending) + assert(!incoming.get.paymentRequest.isExpired) + assert(Crypto.sha256(incoming.get.paymentPreimage) === pr.paymentHash) val add = UpdateAddHtlc(ByteVector32(ByteVector.fill(32)(1)), 0, amountMsat, pr.paymentHash, expiry, TestConstants.emptyOnionPacket) sender.send(handler, add) @@ -60,32 +63,36 @@ class PaymentHandlerSpec extends TestKit(ActorSystem("test")) with FunSuiteLike val paymentRelayed = eventListener.expectMsgType[PaymentReceived] assert(paymentRelayed.copy(timestamp = 0) === PaymentReceived(amountMsat, add.paymentHash, timestamp = 0)) - assert(nodeParams.db.payments.getIncomingPayment(pr.paymentHash).exists(_.paymentHash == pr.paymentHash)) + val received = nodeParams.db.payments.getIncomingPayment(pr.paymentHash) + assert(received.isDefined && received.get.status.isInstanceOf[IncomingPaymentStatus.Received]) + assert(received.get.status.asInstanceOf[IncomingPaymentStatus.Received].copy(receivedAt = 0) === IncomingPaymentStatus.Received(amountMsat, 0)) } { sender.send(handler, ReceivePayment(Some(amountMsat), "another coffee")) val pr = sender.expectMsgType[PaymentRequest] - assert(nodeParams.db.payments.getIncomingPayment(pr.paymentHash).isEmpty) + assert(nodeParams.db.payments.getIncomingPayment(pr.paymentHash).get.status === IncomingPaymentStatus.Pending) val add = UpdateAddHtlc(ByteVector32(ByteVector.fill(32)(1)), 0, amountMsat, pr.paymentHash, expiry, TestConstants.emptyOnionPacket) sender.send(handler, add) sender.expectMsgType[CMD_FULFILL_HTLC] val paymentRelayed = eventListener.expectMsgType[PaymentReceived] assert(paymentRelayed.copy(timestamp = 0) === PaymentReceived(amountMsat, add.paymentHash, timestamp = 0)) - assert(nodeParams.db.payments.getIncomingPayment(pr.paymentHash).exists(_.paymentHash == pr.paymentHash)) + val received = nodeParams.db.payments.getIncomingPayment(pr.paymentHash) + assert(received.isDefined && received.get.status.isInstanceOf[IncomingPaymentStatus.Received]) + assert(received.get.status.asInstanceOf[IncomingPaymentStatus.Received].copy(receivedAt = 0) === IncomingPaymentStatus.Received(amountMsat, 0)) } { sender.send(handler, ReceivePayment(Some(amountMsat), "bad expiry")) val pr = sender.expectMsgType[PaymentRequest] - assert(nodeParams.db.payments.getIncomingPayment(pr.paymentHash).isEmpty) + assert(nodeParams.db.payments.getIncomingPayment(pr.paymentHash).get.status === IncomingPaymentStatus.Pending) val add = UpdateAddHtlc(ByteVector32(ByteVector.fill(32)(1)), 0, amountMsat, pr.paymentHash, cltvExpiry = CltvExpiryDelta(3).toCltvExpiry(nodeParams.currentBlockHeight), TestConstants.emptyOnionPacket) sender.send(handler, add) assert(sender.expectMsgType[CMD_FAIL_HTLC].reason == Right(IncorrectOrUnknownPaymentDetails(amountMsat, nodeParams.currentBlockHeight))) eventListener.expectNoMsg(300 milliseconds) - assert(nodeParams.db.payments.getIncomingPayment(pr.paymentHash).isEmpty) + assert(nodeParams.db.payments.getIncomingPayment(pr.paymentHash).get.status === IncomingPaymentStatus.Pending) } } @@ -168,6 +175,7 @@ class PaymentHandlerSpec extends TestKit(ActorSystem("test")) with FunSuiteLike sender.send(handler, add) sender.expectMsgType[CMD_FAIL_HTLC] - assert(nodeParams.db.payments.getIncomingPayment(pr.paymentHash).isEmpty) + val Some(incoming) = nodeParams.db.payments.getIncomingPayment(pr.paymentHash) + assert(incoming.paymentRequest.isExpired && incoming.status === IncomingPaymentStatus.Expired) } } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala index 48694e7553..283745977c 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala @@ -72,13 +72,13 @@ class PaymentLifecycleSpec extends BaseRouterSpec { sender.send(paymentFSM, request) val Transition(_, WAITING_FOR_REQUEST, WAITING_FOR_ROUTE) = monitor.expectMsgClass(classOf[Transition[_]]) val Transition(_, WAITING_FOR_ROUTE, WAITING_FOR_PAYMENT_COMPLETE) = monitor.expectMsgClass(classOf[Transition[_]]) - awaitCond(paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.PENDING)) + awaitCond(paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.Pending)) val Some(outgoing) = paymentDb.getOutgoingPayment(id) - assert(outgoing.copy(createdAt = 0) === OutgoingPayment(id, None, Some(defaultExternalId), defaultPaymentHash, defaultAmountMsat, d, 0, OutgoingPaymentStatus.PENDING, None)) + assert(outgoing.copy(createdAt = 0) === OutgoingPayment(id, None, Some(defaultExternalId), defaultPaymentHash, defaultAmountMsat, d, 0, None, OutgoingPaymentStatus.Pending)) sender.send(paymentFSM, UpdateFulfillHtlc(ByteVector32.Zeroes, 0, defaultPaymentHash)) sender.expectMsgType[PaymentSent] - awaitCond(paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.SUCCEEDED)) + awaitCond(paymentDb.getOutgoingPayment(id).exists(_.status.isInstanceOf[OutgoingPaymentStatus.Succeeded])) } test("payment failed (route not found)") { fixture => @@ -99,11 +99,11 @@ class PaymentLifecycleSpec extends BaseRouterSpec { sender.send(paymentFSM, request) val Transition(_, WAITING_FOR_REQUEST, WAITING_FOR_ROUTE) = monitor.expectMsgClass(classOf[Transition[_]]) val routeRequest = routerForwarder.expectMsgType[RouteRequest] - awaitCond(paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.PENDING)) + awaitCond(paymentDb.getOutgoingPayment(id).exists(_.status === OutgoingPaymentStatus.Pending)) routerForwarder.forward(router, routeRequest) assert(sender.expectMsgType[PaymentFailed].failures === LocalFailure(RouteNotFound) :: Nil) - awaitCond(paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.FAILED)) + awaitCond(paymentDb.getOutgoingPayment(id).exists(_.status.isInstanceOf[OutgoingPaymentStatus.Failed])) } test("payment failed (route too expensive)") { fixture => @@ -124,7 +124,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val Transition(_, WAITING_FOR_REQUEST, WAITING_FOR_ROUTE) = monitor.expectMsgClass(classOf[Transition[_]]) val Seq(LocalFailure(RouteNotFound)) = sender.expectMsgType[PaymentFailed].failures - awaitCond(paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.FAILED)) + awaitCond(paymentDb.getOutgoingPayment(id).exists(_.status.isInstanceOf[OutgoingPaymentStatus.Failed])) } test("payment failed (unparsable failure)") { fixture => @@ -144,7 +144,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val request = SendPayment(defaultPaymentHash, d, FinalLegacyPayload(defaultAmountMsat, defaultExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight)), maxAttempts = 2) sender.send(paymentFSM, request) - awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.PENDING)) + awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && paymentDb.getOutgoingPayment(id).exists(_.status === OutgoingPaymentStatus.Pending)) val WaitingForRoute(_, _, Nil) = paymentFSM.stateData routerForwarder.expectMsg(RouteRequest(a, d, defaultAmountMsat, ignoreNodes = Set.empty, ignoreChannels = Set.empty)) @@ -168,7 +168,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { // we allow 2 tries, so we send a 2nd request to the router assert(sender.expectMsgType[PaymentFailed].failures === UnreadableRemoteFailure(hops) :: UnreadableRemoteFailure(hops) :: Nil) - awaitCond(paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.FAILED)) // after last attempt the payment is failed + awaitCond(paymentDb.getOutgoingPayment(id).exists(_.status.isInstanceOf[OutgoingPaymentStatus.Failed])) // after last attempt the payment is failed } test("payment failed (local error)") { fixture => @@ -188,7 +188,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val request = SendPayment(defaultPaymentHash, d, FinalLegacyPayload(defaultAmountMsat, defaultExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight)), maxAttempts = 2) sender.send(paymentFSM, request) - awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.PENDING)) + awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && paymentDb.getOutgoingPayment(id).exists(_.status === OutgoingPaymentStatus.Pending)) val WaitingForRoute(_, _, Nil) = paymentFSM.stateData routerForwarder.expectMsg(RouteRequest(nodeParams.nodeId, d, defaultAmountMsat, assistedRoutes = Nil, ignoreNodes = Set.empty, ignoreChannels = Set.empty)) @@ -201,7 +201,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { // then the payment lifecycle will ask for a new route excluding the channel routerForwarder.expectMsg(RouteRequest(nodeParams.nodeId, d, defaultAmountMsat, assistedRoutes = Nil, ignoreNodes = Set.empty, ignoreChannels = Set(ChannelDesc(channelId_ab, a, b)))) - awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.PENDING)) // payment is still pending because the error is recoverable + awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && paymentDb.getOutgoingPayment(id).exists(_.status === OutgoingPaymentStatus.Pending)) // payment is still pending because the error is recoverable } test("payment failed (first hop returns an UpdateFailMalformedHtlc)") { fixture => @@ -221,7 +221,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val request = SendPayment(defaultPaymentHash, d, FinalLegacyPayload(defaultAmountMsat, defaultExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight)), maxAttempts = 2) sender.send(paymentFSM, request) - awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.PENDING)) + awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && paymentDb.getOutgoingPayment(id).exists(_.status === OutgoingPaymentStatus.Pending)) val WaitingForRoute(_, _, Nil) = paymentFSM.stateData routerForwarder.expectMsg(RouteRequest(nodeParams.nodeId, d, defaultAmountMsat, assistedRoutes = Nil, ignoreNodes = Set.empty, ignoreChannels = Set.empty)) @@ -234,7 +234,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { // then the payment lifecycle will ask for a new route excluding the channel routerForwarder.expectMsg(RouteRequest(a, d, defaultAmountMsat, assistedRoutes = Nil, ignoreNodes = Set.empty, ignoreChannels = Set(ChannelDesc(channelId_ab, a, b)))) - awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.PENDING)) + awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && paymentDb.getOutgoingPayment(id).exists(_.status === OutgoingPaymentStatus.Pending)) } test("payment failed (TemporaryChannelFailure)") { fixture => @@ -294,7 +294,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val request = SendPayment(defaultPaymentHash, d, FinalLegacyPayload(defaultAmountMsat, defaultExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight)), maxAttempts = 5) sender.send(paymentFSM, request) - awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.PENDING)) + awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && paymentDb.getOutgoingPayment(id).exists(_.status === OutgoingPaymentStatus.Pending)) val WaitingForRoute(_, _, Nil) = paymentFSM.stateData routerForwarder.expectMsg(RouteRequest(nodeParams.nodeId, d, defaultAmountMsat, assistedRoutes = Nil, ignoreNodes = Set.empty, ignoreChannels = Set.empty)) @@ -311,7 +311,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { // payment lifecycle forwards the embedded channelUpdate to the router routerForwarder.expectMsg(channelUpdate_bc_modified) - awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.PENDING)) // 1 failure but not final, the payment is still PENDING + awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && paymentDb.getOutgoingPayment(id).exists(_.status === OutgoingPaymentStatus.Pending)) // 1 failure but not final, the payment is still PENDING routerForwarder.expectMsg(RouteRequest(nodeParams.nodeId, d, defaultAmountMsat, assistedRoutes = Nil, ignoreNodes = Set.empty, ignoreChannels = Set.empty)) routerForwarder.forward(router) @@ -337,7 +337,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { // this time the router can't find a route: game over assert(sender.expectMsgType[PaymentFailed].failures === RemoteFailure(hops, Sphinx.DecryptedFailurePacket(b, failure)) :: RemoteFailure(hops2, Sphinx.DecryptedFailurePacket(b, failure2)) :: LocalFailure(RouteNotFound) :: Nil) - awaitCond(paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.FAILED)) + awaitCond(paymentDb.getOutgoingPayment(id).exists(_.status.isInstanceOf[OutgoingPaymentStatus.Failed])) } def testPermanentFailure(fixture: FixtureParam, failure: FailureMessage): Unit = { @@ -357,7 +357,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val request = SendPayment(defaultPaymentHash, d, FinalLegacyPayload(defaultAmountMsat, defaultExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight)), maxAttempts = 2) sender.send(paymentFSM, request) - awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.PENDING)) + awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && paymentDb.getOutgoingPayment(id).exists(_.status === OutgoingPaymentStatus.Pending)) val WaitingForRoute(_, _, Nil) = paymentFSM.stateData routerForwarder.expectMsg(RouteRequest(nodeParams.nodeId, d, defaultAmountMsat, assistedRoutes = Nil, ignoreNodes = Set.empty, ignoreChannels = Set.empty)) @@ -375,7 +375,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { // we allow 2 tries, so we send a 2nd request to the router, which won't find another route assert(sender.expectMsgType[PaymentFailed].failures === RemoteFailure(hops, Sphinx.DecryptedFailurePacket(b, failure)) :: LocalFailure(RouteNotFound) :: Nil) - awaitCond(paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.FAILED)) + awaitCond(paymentDb.getOutgoingPayment(id).exists(_.status.isInstanceOf[OutgoingPaymentStatus.Failed])) } test("payment failed (PermanentChannelFailure)") { fixture => @@ -409,9 +409,9 @@ class PaymentLifecycleSpec extends BaseRouterSpec { sender.send(paymentFSM, request) val Transition(_, WAITING_FOR_REQUEST, WAITING_FOR_ROUTE) = monitor.expectMsgClass(classOf[Transition[_]]) val Transition(_, WAITING_FOR_ROUTE, WAITING_FOR_PAYMENT_COMPLETE) = monitor.expectMsgClass(classOf[Transition[_]]) - awaitCond(paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.PENDING)) + awaitCond(paymentDb.getOutgoingPayment(id).exists(_.status === OutgoingPaymentStatus.Pending)) val Some(outgoing) = paymentDb.getOutgoingPayment(id) - assert(outgoing.copy(createdAt = 0) === OutgoingPayment(id, None, Some(defaultExternalId), paymentHash, defaultAmountMsat, d, 0, OutgoingPaymentStatus.PENDING, None)) + assert(outgoing.copy(createdAt = 0) === OutgoingPayment(id, None, Some(defaultExternalId), paymentHash, defaultAmountMsat, d, 0, None, OutgoingPaymentStatus.Pending)) sender.send(paymentFSM, UpdateFulfillHtlc(ByteVector32.Zeroes, 0, paymentPreimage)) val ps = eventListener.expectMsgType[PaymentSent] @@ -419,7 +419,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { assert(ps.amount === defaultAmountMsat) assert(ps.paymentHash === paymentHash) assert(ps.paymentPreimage === paymentPreimage) - awaitCond(paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.SUCCEEDED)) + awaitCond(paymentDb.getOutgoingPayment(id).exists(_.status.isInstanceOf[OutgoingPaymentStatus.Succeeded])) } test("payment succeeded to a channel with fees=0") { fixture => From 010ffa15725290a2f7fc530940b5f286998c54ca Mon Sep 17 00:00:00 2001 From: Bastien Teinturier Date: Wed, 18 Sep 2019 18:08:11 +0200 Subject: [PATCH 07/14] fixup! Re-work the PaymentsDb interface and the Incoming/Outgoing payment structures. Clarify use of seconds / milliseconds -> we use milliseconds everywhere except at the Eclair API level (probably because it's easier from bash to get a unix timestamp in seconds than in milliseconds). --- .../main/scala/fr/acinq/eclair/api/JsonSerializers.scala | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/eclair-node/src/main/scala/fr/acinq/eclair/api/JsonSerializers.scala b/eclair-node/src/main/scala/fr/acinq/eclair/api/JsonSerializers.scala index 73c85f0c82..f0cb3d4442 100644 --- a/eclair-node/src/main/scala/fr/acinq/eclair/api/JsonSerializers.scala +++ b/eclair-node/src/main/scala/fr/acinq/eclair/api/JsonSerializers.scala @@ -25,7 +25,6 @@ import fr.acinq.bitcoin.Crypto.{PrivateKey, PublicKey} import fr.acinq.bitcoin.{ByteVector32, ByteVector64, OutPoint, Satoshi, Transaction} import fr.acinq.eclair.channel.{ChannelVersion, State} import fr.acinq.eclair.crypto.ShaChain -import fr.acinq.eclair.db.OutgoingPaymentStatus import fr.acinq.eclair.payment.PaymentRequest import fr.acinq.eclair.router.RouteResponse import fr.acinq.eclair.transactions.Direction @@ -185,10 +184,6 @@ class JavaUUIDSerializer extends CustomSerializer[UUID](format => ({ null }, { case id: UUID => JString(id.toString) })) -class OutgoingPaymentStatusSerializer extends CustomSerializer[OutgoingPaymentStatus.Value](format => ({ null }, { - case el: OutgoingPaymentStatus.Value => JString(el.toString) -})) - object JsonSupport extends Json4sSupport { implicit val serialization = jackson.Serialization @@ -221,8 +216,7 @@ object JsonSupport extends Json4sSupport { new NodeAddressSerializer + new DirectionSerializer + new PaymentRequestSerializer + - new JavaUUIDSerializer + - new OutgoingPaymentStatusSerializer + new JavaUUIDSerializer case class CustomTypeHints(custom: Map[Class[_], String]) extends TypeHints { val reverse: Map[String, Class[_]] = custom.map(_.swap) From e4a6677a45e0a66b8d45b334d6a3777a8f57decc Mon Sep 17 00:00:00 2001 From: Bastien Teinturier Date: Thu, 19 Sep 2019 09:14:27 +0200 Subject: [PATCH 08/14] DB: run migrations / init inside a SQL transaction --- .../eclair/db/sqlite/SqliteAuditDb.scala | 3 +- .../eclair/db/sqlite/SqliteChannelsDb.scala | 17 ++-- .../eclair/db/sqlite/SqliteNetworkDb.scala | 2 +- .../eclair/db/sqlite/SqlitePaymentsDb.scala | 2 +- .../eclair/db/sqlite/SqlitePeersDb.scala | 2 +- .../db/sqlite/SqlitePendingRelayDb.scala | 2 +- .../acinq/eclair/db/sqlite/SqliteUtils.scala | 12 ++- .../eclair/db/SqliteChannelsDbSpec.scala | 1 - .../fr/acinq/eclair/db/SqliteUtilsSpec.scala | 77 +++++++++++++++++++ 9 files changed, 104 insertions(+), 14 deletions(-) create mode 100644 eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteUtilsSpec.scala diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteAuditDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteAuditDb.scala index 31b6e3d296..f17b0c2346 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteAuditDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteAuditDb.scala @@ -39,7 +39,7 @@ class SqliteAuditDb(sqlite: Connection) extends AuditDb with Logging { val DB_NAME = "audit" val CURRENT_VERSION = 4 - using(sqlite.createStatement()) { statement => + using(sqlite.createStatement(), disableAutoCommit = true) { statement => def migration12(statement: Statement) = { statement.executeUpdate(s"ALTER TABLE sent ADD id BLOB DEFAULT '${ChannelCodecs.UNKNOWN_UUID.toString}' NOT NULL") @@ -98,7 +98,6 @@ class SqliteAuditDb(sqlite: Connection) extends AuditDb with Logging { statement.executeUpdate("CREATE INDEX IF NOT EXISTS network_fees_timestamp_idx ON network_fees(timestamp)") statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_events_timestamp_idx ON channel_events(timestamp)") statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_errors_timestamp_idx ON channel_errors(timestamp)") - case unknownVersion => throw new RuntimeException(s"Unknown version of DB $DB_NAME found, version=$unknownVersion") } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteChannelsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteChannelsDb.scala index 9cadd6ab1e..4144916a21 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteChannelsDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteChannelsDb.scala @@ -35,24 +35,31 @@ class SqliteChannelsDb(sqlite: Connection) extends ChannelsDb with Logging { val DB_NAME = "channels" val CURRENT_VERSION = 2 - private def migration12(statement: Statement) = { - statement.executeUpdate("ALTER TABLE local_channels ADD COLUMN is_closed BOOLEAN NOT NULL DEFAULT 0") + // The SQLite documentation states that "It is not possible to enable or disable foreign key constraints in the middle + // of a multi-statement transaction (when SQLite is not in autocommit mode).". + // So we need to set foreign keys before we initialize tables / migrations (which is done inside a transaction). + using(sqlite.createStatement()) { statement => + statement.execute("PRAGMA foreign_keys = ON") } - using(sqlite.createStatement()) { statement => + using(sqlite.createStatement(), disableAutoCommit = true) { statement => + + def migration12(statement: Statement) = { + statement.executeUpdate("ALTER TABLE local_channels ADD COLUMN is_closed BOOLEAN NOT NULL DEFAULT 0") + } + getVersion(statement, DB_NAME, CURRENT_VERSION) match { case 1 => logger.warn(s"migrating db $DB_NAME, found version=1 current=$CURRENT_VERSION") migration12(statement) setVersion(statement, DB_NAME, CURRENT_VERSION) case CURRENT_VERSION => - statement.execute("PRAGMA foreign_keys = ON") statement.executeUpdate("CREATE TABLE IF NOT EXISTS local_channels (channel_id BLOB NOT NULL PRIMARY KEY, data BLOB NOT NULL, is_closed BOOLEAN NOT NULL DEFAULT 0)") statement.executeUpdate("CREATE TABLE IF NOT EXISTS htlc_infos (channel_id BLOB NOT NULL, commitment_number BLOB NOT NULL, payment_hash BLOB NOT NULL, cltv_expiry INTEGER NOT NULL, FOREIGN KEY(channel_id) REFERENCES local_channels(channel_id))") statement.executeUpdate("CREATE INDEX IF NOT EXISTS htlc_infos_idx ON htlc_infos(channel_id, commitment_number)") - case unknownVersion => throw new RuntimeException(s"Unknown version of DB $DB_NAME found, version=$unknownVersion") } + } override def addOrUpdateChannel(state: HasCommitments): Unit = { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteNetworkDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteNetworkDb.scala index cc2e8c1fac..55a0ffb4e8 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteNetworkDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteNetworkDb.scala @@ -36,7 +36,7 @@ class SqliteNetworkDb(sqlite: Connection) extends NetworkDb with Logging { val DB_NAME = "network" val CURRENT_VERSION = 2 - using(sqlite.createStatement()) { statement => + using(sqlite.createStatement(), disableAutoCommit = true) { statement => getVersion(statement, DB_NAME, CURRENT_VERSION) match { case 1 => // channel_update are cheap to retrieve, so let's just wipe them out and they'll get resynced diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala index 20334ad4e4..312459177f 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala @@ -48,7 +48,7 @@ class SqlitePaymentsDb(sqlite: Connection) extends PaymentsDb with Logging { private val paymentFailuresCodec = discriminated[List[FailureSummary]].by(byte) .typecase(0x01, listOfN(uint8, failureSummaryCodec)) - using(sqlite.createStatement()) { statement => + using(sqlite.createStatement(), disableAutoCommit = true) { statement => def migration12(statement: Statement) = { // Version 2 is "backwards compatible" in the sense that it uses separate tables from version 1 (which used a single "payments" table). diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePeersDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePeersDb.scala index 8d9e828ba3..98df768201 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePeersDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePeersDb.scala @@ -32,7 +32,7 @@ import SqliteUtils.ExtendedResultSet._ val DB_NAME = "peers" val CURRENT_VERSION = 1 - using(sqlite.createStatement()) { statement => + using(sqlite.createStatement(), disableAutoCommit = true) { statement => require(getVersion(statement, DB_NAME, CURRENT_VERSION) == CURRENT_VERSION, s"incompatible version of $DB_NAME DB found") // there is only one version currently deployed statement.executeUpdate("CREATE TABLE IF NOT EXISTS peers (node_id BLOB NOT NULL PRIMARY KEY, data BLOB NOT NULL)") } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePendingRelayDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePendingRelayDb.scala index b0621ac5e2..a1aec7a5cc 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePendingRelayDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePendingRelayDb.scala @@ -33,7 +33,7 @@ class SqlitePendingRelayDb(sqlite: Connection) extends PendingRelayDb { val DB_NAME = "pending_relay" val CURRENT_VERSION = 1 - using(sqlite.createStatement()) { statement => + using(sqlite.createStatement(), disableAutoCommit = true) { statement => require(getVersion(statement, DB_NAME, CURRENT_VERSION) == CURRENT_VERSION, s"incompatible version of $DB_NAME DB found") // there is only one version currently deployed // note: should we use a foreign key to local_channels table here? statement.executeUpdate("CREATE TABLE IF NOT EXISTS pending_relay (channel_id BLOB NOT NULL, htlc_id INTEGER NOT NULL, data BLOB NOT NULL, PRIMARY KEY(channel_id, htlc_id))") diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteUtils.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteUtils.scala index d8d6875162..2f63cf969a 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteUtils.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteUtils.scala @@ -27,12 +27,20 @@ import scala.collection.immutable.Queue object SqliteUtils { /** - * Manages closing of statement + * This helper makes sure statements are closed. + * + * @param disableAutoCommit if set to true, all updates in the block will be run in a transaction. */ def using[T <: Statement, U](statement: T, disableAutoCommit: Boolean = false)(block: T => U): U = { try { if (disableAutoCommit) statement.getConnection.setAutoCommit(false) - block(statement) + val res = block(statement) + if (disableAutoCommit) statement.getConnection.commit() + res + } catch { + case t: Exception => + if (disableAutoCommit) statement.getConnection.rollback() + throw t } finally { if (disableAutoCommit) statement.getConnection.setAutoCommit(true) if (statement != null) statement.close() diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteChannelsDbSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteChannelsDbSpec.scala index ac7d61ed37..993ebb4065 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteChannelsDbSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteChannelsDbSpec.scala @@ -26,7 +26,6 @@ import org.scalatest.FunSuite import org.sqlite.SQLiteException import scodec.bits.ByteVector - class SqliteChannelsDbSpec extends FunSuite { test("init sqlite 2 times in a row") { diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteUtilsSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteUtilsSpec.scala new file mode 100644 index 0000000000..da83bbca3e --- /dev/null +++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteUtilsSpec.scala @@ -0,0 +1,77 @@ +/* + * Copyright 2019 ACINQ SAS + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package fr.acinq.eclair.db + +import fr.acinq.eclair.TestConstants +import fr.acinq.eclair.db.sqlite.SqliteUtils.using +import org.scalatest.FunSuite +import org.sqlite.SQLiteException + +class SqliteUtilsSpec extends FunSuite { + + test("using with auto-commit disabled") { + val conn = TestConstants.sqliteInMemory() + + using(conn.createStatement()) { statement => + statement.executeUpdate("CREATE TABLE utils_test (id INTEGER NOT NULL PRIMARY KEY, updated_at INTEGER)") + statement.executeUpdate("INSERT INTO utils_test VALUES (1, 1)") + statement.executeUpdate("INSERT INTO utils_test VALUES (2, 2)") + } + + using(conn.createStatement()) { statement => + val results = statement.executeQuery("SELECT * FROM utils_test ORDER BY id") + assert(results.next()) + assert(results.getLong("id") === 1) + assert(results.next()) + assert(results.getLong("id") === 2) + assert(!results.next()) + } + + assertThrows[SQLiteException](using(conn.createStatement(), disableAutoCommit = true) { statement => + statement.executeUpdate("INSERT INTO utils_test VALUES (3, 3)") + statement.executeUpdate("INSERT INTO utils_test VALUES (1, 3)") // should throw (primary key violation) + }) + + using(conn.createStatement()) { statement => + val results = statement.executeQuery("SELECT * FROM utils_test ORDER BY id") + assert(results.next()) + assert(results.getLong("id") === 1) + assert(results.next()) + assert(results.getLong("id") === 2) + assert(!results.next()) + } + + using(conn.createStatement(), disableAutoCommit = true) { statement => + statement.executeUpdate("INSERT INTO utils_test VALUES (3, 3)") + statement.executeUpdate("INSERT INTO utils_test VALUES (4, 4)") + } + + using(conn.createStatement()) { statement => + val results = statement.executeQuery("SELECT * FROM utils_test ORDER BY id") + assert(results.next()) + assert(results.getLong("id") === 1) + assert(results.next()) + assert(results.getLong("id") === 2) + assert(results.next()) + assert(results.getLong("id") === 3) + assert(results.next()) + assert(results.getLong("id") === 4) + assert(!results.next()) + } + } + +} From 6c02243dd672d3802a3c494f431a07b32de3d76a Mon Sep 17 00:00:00 2001 From: Bastien Teinturier Date: Thu, 19 Sep 2019 17:06:35 +0200 Subject: [PATCH 09/14] AuditDb stores channelId for sent and received payments. This effectively reverts a previous commit and removes the need for a DB migration. Payments are now always represented as a list of partial payments. --- .../scala/fr/acinq/eclair/db/PaymentsDb.scala | 7 +- .../eclair/db/sqlite/SqliteAuditDb.scala | 89 +++++---- .../eclair/db/sqlite/SqlitePaymentsDb.scala | 52 ++++-- .../eclair/payment/LocalPaymentHandler.scala | 2 +- .../acinq/eclair/payment/PaymentEvents.scala | 25 ++- .../eclair/payment/PaymentLifecycle.scala | 6 +- .../fr/acinq/eclair/payment/Relayer.scala | 2 +- .../acinq/eclair/db/SqliteAuditDbSpec.scala | 169 ++++-------------- .../eclair/db/SqlitePaymentsDbSpec.scala | 52 ++++-- .../eclair/integration/IntegrationSpec.scala | 2 +- .../eclair/payment/PaymentHandlerSpec.scala | 5 +- .../eclair/payment/PaymentLifecycleSpec.scala | 7 +- 12 files changed, 181 insertions(+), 237 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala index e64124a8af..e7113eb37d 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala @@ -40,6 +40,9 @@ trait PaymentsDb { /** Get an outgoing payment attempt. */ def getOutgoingPayment(id: UUID): Option[OutgoingPayment] + /** List all the outgoing payment attempts that are children of the given id. */ + def listOutgoingPayments(parentId: UUID): Seq[OutgoingPayment] + /** List all the outgoing payment attempts that tried to pay the given payment hash. */ def listOutgoingPayments(paymentHash: ByteVector32): Seq[OutgoingPayment] @@ -112,7 +115,7 @@ object IncomingPaymentStatus { * At first it is in a pending state, then will become either a success or a failure. * * @param id internal payment identifier. - * @param parentId internal identifier of a parent payment, if any. + * @param parentId internal identifier of a parent payment, or [[id]] if single-part payment. * @param externalId external payment identifier: lets lightning applications reconcile payments with their own db. * @param paymentHash payment_hash. * @param amount amount of the payment, in milli-satoshis. @@ -122,7 +125,7 @@ object IncomingPaymentStatus { * @param status current status of the payment. */ case class OutgoingPayment(id: UUID, - parentId: Option[UUID], + parentId: UUID, externalId: Option[String], paymentHash: ByteVector32, amount: MilliSatoshi, diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteAuditDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteAuditDb.scala index f17b0c2346..c99789ab84 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteAuditDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteAuditDb.scala @@ -37,7 +37,7 @@ class SqliteAuditDb(sqlite: Connection) extends AuditDb with Logging { import ExtendedResultSet._ val DB_NAME = "audit" - val CURRENT_VERSION = 4 + val CURRENT_VERSION = 3 using(sqlite.createStatement(), disableAutoCommit = true) { statement => @@ -50,42 +50,20 @@ class SqliteAuditDb(sqlite: Connection) extends AuditDb with Logging { statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_errors_timestamp_idx ON channel_errors(timestamp)") } - def migration34(statement: Statement) = { - statement.executeUpdate("DROP index sent_timestamp_idx") - statement.executeUpdate("ALTER TABLE sent RENAME TO _sent_old") - statement.executeUpdate("CREATE TABLE sent (id BLOB NOT NULL, amount_msat INTEGER NOT NULL, fees_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, payment_preimage BLOB NOT NULL, timestamp INTEGER NOT NULL)") - statement.executeUpdate("INSERT INTO sent (id, amount_msat, fees_msat, payment_hash, payment_preimage, timestamp) SELECT id, amount_msat, fees_msat, payment_hash, payment_preimage, timestamp FROM _sent_old") - statement.executeUpdate("DROP table _sent_old") - statement.executeUpdate("CREATE INDEX sent_timestamp_idx ON sent(timestamp)") - - statement.executeUpdate("DROP index received_timestamp_idx") - statement.executeUpdate("ALTER TABLE received RENAME TO _received_old") - statement.executeUpdate("CREATE TABLE received (amount_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, timestamp INTEGER NOT NULL)") - statement.executeUpdate("INSERT INTO received (amount_msat, payment_hash, timestamp) SELECT amount_msat, payment_hash, timestamp FROM _received_old") - statement.executeUpdate("DROP table _received_old") - statement.executeUpdate("CREATE INDEX received_timestamp_idx ON received(timestamp)") - } - getVersion(statement, DB_NAME, CURRENT_VERSION) match { case 1 => // previous version let's migrate logger.warn(s"migrating db $DB_NAME, found version=1 current=$CURRENT_VERSION") migration12(statement) migration23(statement) - migration34(statement) setVersion(statement, DB_NAME, CURRENT_VERSION) case 2 => logger.warn(s"migrating db $DB_NAME, found version=2 current=$CURRENT_VERSION") migration23(statement) - migration34(statement) - setVersion(statement, DB_NAME, CURRENT_VERSION) - case 3 => - logger.warn(s"migrating db $DB_NAME, found version=3 current=$CURRENT_VERSION") - migration34(statement) setVersion(statement, DB_NAME, CURRENT_VERSION) case CURRENT_VERSION => statement.executeUpdate("CREATE TABLE IF NOT EXISTS balance_updated (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, amount_msat INTEGER NOT NULL, capacity_sat INTEGER NOT NULL, reserve_sat INTEGER NOT NULL, timestamp INTEGER NOT NULL)") - statement.executeUpdate("CREATE TABLE IF NOT EXISTS sent (id BLOB NOT NULL, amount_msat INTEGER NOT NULL, fees_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, payment_preimage BLOB NOT NULL, timestamp INTEGER NOT NULL)") - statement.executeUpdate("CREATE TABLE IF NOT EXISTS received (amount_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, timestamp INTEGER NOT NULL)") + statement.executeUpdate("CREATE TABLE IF NOT EXISTS sent (amount_msat INTEGER NOT NULL, fees_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, payment_preimage BLOB NOT NULL, to_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL, id BLOB NOT NULL)") + statement.executeUpdate("CREATE TABLE IF NOT EXISTS received (amount_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, from_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL)") statement.executeUpdate("CREATE TABLE IF NOT EXISTS relayed (amount_in_msat INTEGER NOT NULL, amount_out_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, from_channel_id BLOB NOT NULL, to_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL)") statement.executeUpdate("CREATE TABLE IF NOT EXISTS network_fees (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, tx_id BLOB NOT NULL, fee_sat INTEGER NOT NULL, tx_type TEXT NOT NULL, timestamp INTEGER NOT NULL)") statement.executeUpdate("CREATE TABLE IF NOT EXISTS channel_events (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, capacity_sat INTEGER NOT NULL, is_funder BOOLEAN NOT NULL, is_private BOOLEAN NOT NULL, event STRING NOT NULL, timestamp INTEGER NOT NULL)") @@ -126,23 +104,30 @@ class SqliteAuditDb(sqlite: Connection) extends AuditDb with Logging { } override def add(e: PaymentSent): Unit = - using(sqlite.prepareStatement("INSERT INTO sent VALUES (?, ?, ?, ?, ?, ?)")) { statement => - statement.setBytes(1, e.id.toString.getBytes) - statement.setLong(2, e.amount.toLong) - statement.setLong(3, e.feesPaid.toLong) - statement.setBytes(4, e.paymentHash.toArray) - statement.setBytes(5, e.paymentPreimage.toArray) - statement.setLong(6, e.timestamp) - - statement.executeUpdate() + using(sqlite.prepareStatement("INSERT INTO sent VALUES (?, ?, ?, ?, ?, ?, ?)")) { statement => + e.parts.foreach(p => { + statement.setLong(1, p.amount.toLong) + statement.setLong(2, p.feesPaid.toLong) + statement.setBytes(3, e.paymentHash.toArray) + statement.setBytes(4, e.paymentPreimage.toArray) + statement.setBytes(5, p.toChannelId.toArray) + statement.setLong(6, p.timestamp) + statement.setBytes(7, p.id.toString.getBytes) + statement.addBatch() + }) + statement.executeBatch() } override def add(e: PaymentReceived): Unit = - using(sqlite.prepareStatement("INSERT INTO received VALUES (?, ?, ?)")) { statement => - statement.setLong(1, e.amount.toLong) - statement.setBytes(2, e.paymentHash.toArray) - statement.setLong(3, e.timestamp) - statement.executeUpdate() + using(sqlite.prepareStatement("INSERT INTO received VALUES (?, ?, ?, ?)")) { statement => + e.parts.foreach(p => { + statement.setLong(1, p.amount.toLong) + statement.setBytes(2, e.paymentHash.toArray) + statement.setBytes(3, p.fromChannelId.toArray) + statement.setLong(4, p.timestamp) + statement.addBatch() + }) + statement.executeBatch() } override def add(e: PaymentRelayed): Unit = @@ -190,13 +175,16 @@ class SqliteAuditDb(sqlite: Connection) extends AuditDb with Logging { var q: Queue[PaymentSent] = Queue() while (rs.next()) { q = q :+ PaymentSent( - id = UUID.fromString(rs.getString("id")), - amount = MilliSatoshi(rs.getLong("amount_msat")), - feesPaid = MilliSatoshi(rs.getLong("fees_msat")), - paymentHash = rs.getByteVector32("payment_hash"), - paymentPreimage = rs.getByteVector32("payment_preimage"), - route = Nil, - timestamp = rs.getLong("timestamp")) + UUID.fromString(rs.getString("id")), + rs.getByteVector32("payment_hash"), + rs.getByteVector32("payment_preimage"), + Seq(PaymentSent.PartialPayment( + UUID.fromString(rs.getString("id")), + MilliSatoshi(rs.getLong("amount_msat")), + MilliSatoshi(rs.getLong("fees_msat")), + rs.getByteVector32("to_channel_id"), + Nil, // we don't store the route + rs.getLong("timestamp")))) } q } @@ -209,9 +197,12 @@ class SqliteAuditDb(sqlite: Connection) extends AuditDb with Logging { var q: Queue[PaymentReceived] = Queue() while (rs.next()) { q = q :+ PaymentReceived( - amount = MilliSatoshi(rs.getLong("amount_msat")), - paymentHash = rs.getByteVector32("payment_hash"), - timestamp = rs.getLong("timestamp")) + rs.getByteVector32("payment_hash"), + Seq(PaymentReceived.PartialPayment( + MilliSatoshi(rs.getLong("amount_msat")), + rs.getByteVector32("from_channel_id"), + rs.getLong("timestamp") + ))) } q } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala index 312459177f..32304ee9b1 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala @@ -61,10 +61,10 @@ class SqlitePaymentsDb(sqlite: Connection) extends PaymentsDb with Logging { // We add many more columns to the sent_payments table. statement.executeUpdate("DROP index payment_hash_idx") statement.executeUpdate("ALTER TABLE sent_payments RENAME TO _sent_payments_old") - statement.executeUpdate("CREATE TABLE sent_payments (id TEXT NOT NULL PRIMARY KEY, parent_id TEXT, external_id TEXT, payment_hash BLOB NOT NULL, amount_msat INTEGER NOT NULL, target_node_id BLOB NOT NULL, created_at INTEGER NOT NULL, payment_request TEXT, completed_at INTEGER, payment_preimage BLOB, fees_msat INTEGER, payment_route BLOB, failures BLOB)") + statement.executeUpdate("CREATE TABLE sent_payments (id TEXT NOT NULL PRIMARY KEY, parent_id TEXT NOT NULL, external_id TEXT, payment_hash BLOB NOT NULL, amount_msat INTEGER NOT NULL, target_node_id BLOB NOT NULL, created_at INTEGER NOT NULL, payment_request TEXT, completed_at INTEGER, payment_preimage BLOB, fees_msat INTEGER, payment_route BLOB, failures BLOB)") // Old rows will be missing a target node id, so we use an easy-to-spot default value. val defaultTargetNodeId = PrivateKey(ByteVector32.One).publicKey - statement.executeUpdate(s"INSERT INTO sent_payments (id, payment_hash, amount_msat, target_node_id, created_at, completed_at, payment_preimage) SELECT id, payment_hash, amount_msat, X'${defaultTargetNodeId.toString}', created_at, completed_at, preimage FROM _sent_payments_old") + statement.executeUpdate(s"INSERT INTO sent_payments (id, parent_id, payment_hash, amount_msat, target_node_id, created_at, completed_at, payment_preimage) SELECT id, id, payment_hash, amount_msat, X'${defaultTargetNodeId.toString}', created_at, completed_at, preimage FROM _sent_payments_old") statement.executeUpdate("DROP table _sent_payments_old") statement.executeUpdate("ALTER TABLE received_payments RENAME TO _received_payments_old") @@ -93,7 +93,7 @@ class SqlitePaymentsDb(sqlite: Connection) extends PaymentsDb with Logging { setVersion(statement, DB_NAME, CURRENT_VERSION) case CURRENT_VERSION => statement.executeUpdate("CREATE TABLE IF NOT EXISTS received_payments (payment_hash BLOB NOT NULL PRIMARY KEY, payment_preimage BLOB NOT NULL, payment_request TEXT NOT NULL, received_msat INTEGER, created_at INTEGER NOT NULL, expire_at INTEGER NOT NULL, received_at INTEGER)") - statement.executeUpdate("CREATE TABLE IF NOT EXISTS sent_payments (id TEXT NOT NULL PRIMARY KEY, parent_id TEXT, external_id TEXT, payment_hash BLOB NOT NULL, amount_msat INTEGER NOT NULL, target_node_id BLOB NOT NULL, created_at INTEGER NOT NULL, payment_request TEXT, completed_at INTEGER, payment_preimage BLOB, fees_msat INTEGER, payment_route BLOB, failures BLOB)") + statement.executeUpdate("CREATE TABLE IF NOT EXISTS sent_payments (id TEXT NOT NULL PRIMARY KEY, parent_id TEXT NOT NULL, external_id TEXT, payment_hash BLOB NOT NULL, amount_msat INTEGER NOT NULL, target_node_id BLOB NOT NULL, created_at INTEGER NOT NULL, payment_request TEXT, completed_at INTEGER, payment_preimage BLOB, fees_msat INTEGER, payment_route BLOB, failures BLOB)") statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_parent_id_idx ON sent_payments(parent_id)") statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_payment_hash_idx ON sent_payments(payment_hash)") @@ -108,7 +108,7 @@ class SqlitePaymentsDb(sqlite: Connection) extends PaymentsDb with Logging { require(sent.status == OutgoingPaymentStatus.Pending, s"outgoing payment isn't pending (${sent.status.getClass.getSimpleName})") using(sqlite.prepareStatement("INSERT INTO sent_payments (id, parent_id, external_id, payment_hash, amount_msat, target_node_id, created_at, payment_request) VALUES (?, ?, ?, ?, ?, ?, ?, ?)")) { statement => statement.setString(1, sent.id.toString) - statement.setString(2, sent.parentId.map(_.toString).orNull) + statement.setString(2, sent.parentId.toString) statement.setString(3, sent.externalId.orNull) statement.setBytes(4, sent.paymentHash.toArray) statement.setLong(5, sent.amount.toLong) @@ -121,12 +121,15 @@ class SqlitePaymentsDb(sqlite: Connection) extends PaymentsDb with Logging { override def updateOutgoingPayment(paymentResult: PaymentSent): Unit = using(sqlite.prepareStatement("UPDATE sent_payments SET (completed_at, payment_preimage, fees_msat, payment_route) = (?, ?, ?, ?) WHERE id = ? AND completed_at IS NULL")) { statement => - statement.setLong(1, paymentResult.timestamp) - statement.setBytes(2, paymentResult.paymentPreimage.toArray) - statement.setLong(3, paymentResult.feesPaid.toLong) - statement.setBytes(4, paymentRouteCodec.encode(paymentResult.route.map(h => HopSummary(h)).toList).require.toByteArray) - statement.setString(5, paymentResult.id.toString) - if (statement.executeUpdate() == 0) throw new IllegalArgumentException(s"Tried to mark an outgoing payment as succeeded but already in final status (id=${paymentResult.id})") + paymentResult.parts.foreach(p => { + statement.setLong(1, p.timestamp) + statement.setBytes(2, paymentResult.paymentPreimage.toArray) + statement.setLong(3, p.feesPaid.toLong) + statement.setBytes(4, paymentRouteCodec.encode(p.route.map(h => HopSummary(h)).toList).require.toByteArray) + statement.setString(5, p.id.toString) + statement.addBatch() + }) + if (statement.executeBatch().contains(0)) throw new IllegalArgumentException(s"Tried to mark an outgoing payment as succeeded but already in final status (id=${paymentResult.id})") } override def updateOutgoingPayment(paymentResult: PaymentFailed): Unit = @@ -140,7 +143,7 @@ class SqlitePaymentsDb(sqlite: Connection) extends PaymentsDb with Logging { private def parseOutgoingPayment(rs: ResultSet): OutgoingPayment = { val result = OutgoingPayment( UUID.fromString(rs.getString("id")), - rs.getStringNullable("parent_id").map(UUID.fromString), + UUID.fromString(rs.getString("parent_id")), rs.getStringNullable("external_id"), rs.getByteVector32("payment_hash"), MilliSatoshi(rs.getLong("amount_msat")), @@ -176,7 +179,7 @@ class SqlitePaymentsDb(sqlite: Connection) extends PaymentsDb with Logging { } override def getOutgoingPayment(id: UUID): Option[OutgoingPayment] = - using(sqlite.prepareStatement("SELECT id, parent_id, external_id, payment_hash, amount_msat, target_node_id, created_at, payment_request, completed_at, payment_preimage, fees_msat, payment_route, failures FROM sent_payments WHERE id = ?")) { statement => + using(sqlite.prepareStatement("SELECT * FROM sent_payments WHERE id = ?")) { statement => statement.setString(1, id.toString) val rs = statement.executeQuery() if (rs.next()) { @@ -186,8 +189,19 @@ class SqlitePaymentsDb(sqlite: Connection) extends PaymentsDb with Logging { } } + override def listOutgoingPayments(parentId: UUID): Seq[OutgoingPayment] = + using(sqlite.prepareStatement("SELECT * FROM sent_payments WHERE parent_id = ? ORDER BY created_at")) { statement => + statement.setString(1, parentId.toString) + val rs = statement.executeQuery() + var q: Queue[OutgoingPayment] = Queue() + while (rs.next()) { + q = q :+ parseOutgoingPayment(rs) + } + q + } + override def listOutgoingPayments(paymentHash: ByteVector32): Seq[OutgoingPayment] = - using(sqlite.prepareStatement("SELECT id, parent_id, external_id, payment_hash, amount_msat, target_node_id, created_at, payment_request, completed_at, payment_preimage, fees_msat, payment_route, failures FROM sent_payments WHERE payment_hash = ? ORDER BY created_at")) { statement => + using(sqlite.prepareStatement("SELECT * FROM sent_payments WHERE payment_hash = ? ORDER BY created_at")) { statement => statement.setBytes(1, paymentHash.toArray) val rs = statement.executeQuery() var q: Queue[OutgoingPayment] = Queue() @@ -198,7 +212,7 @@ class SqlitePaymentsDb(sqlite: Connection) extends PaymentsDb with Logging { } override def listOutgoingPayments(from: Long, to: Long): Seq[OutgoingPayment] = - using(sqlite.prepareStatement("SELECT id, parent_id, external_id, payment_hash, amount_msat, target_node_id, created_at, payment_request, completed_at, payment_preimage, fees_msat, payment_route, failures FROM sent_payments WHERE created_at >= ? AND created_at < ? ORDER BY created_at")) { statement => + using(sqlite.prepareStatement("SELECT * FROM sent_payments WHERE created_at >= ? AND created_at < ? ORDER BY created_at")) { statement => statement.setLong(1, from) statement.setLong(2, to) val rs = statement.executeQuery() @@ -241,7 +255,7 @@ class SqlitePaymentsDb(sqlite: Connection) extends PaymentsDb with Logging { } override def getIncomingPayment(paymentHash: ByteVector32): Option[IncomingPayment] = - using(sqlite.prepareStatement("SELECT payment_hash, payment_preimage, payment_request, received_msat, created_at, received_at FROM received_payments WHERE payment_hash = ?")) { statement => + using(sqlite.prepareStatement("SELECT * FROM received_payments WHERE payment_hash = ?")) { statement => statement.setBytes(1, paymentHash.toArray) val rs = statement.executeQuery() if (rs.next()) { @@ -252,7 +266,7 @@ class SqlitePaymentsDb(sqlite: Connection) extends PaymentsDb with Logging { } override def listIncomingPayments(from: Long, to: Long): Seq[IncomingPayment] = - using(sqlite.prepareStatement("SELECT payment_hash, payment_preimage, payment_request, received_msat, created_at, received_at FROM received_payments WHERE created_at > ? AND created_at < ? ORDER BY created_at")) { statement => + using(sqlite.prepareStatement("SELECT * FROM received_payments WHERE created_at > ? AND created_at < ? ORDER BY created_at")) { statement => statement.setLong(1, from) statement.setLong(2, to) val rs = statement.executeQuery() @@ -264,7 +278,7 @@ class SqlitePaymentsDb(sqlite: Connection) extends PaymentsDb with Logging { } override def listReceivedIncomingPayments(from: Long, to: Long): Seq[IncomingPayment] = - using(sqlite.prepareStatement("SELECT payment_hash, payment_preimage, payment_request, received_msat, created_at, received_at FROM received_payments WHERE received_msat > 0 AND created_at > ? AND created_at < ? ORDER BY created_at")) { statement => + using(sqlite.prepareStatement("SELECT * FROM received_payments WHERE received_msat > 0 AND created_at > ? AND created_at < ? ORDER BY created_at")) { statement => statement.setLong(1, from) statement.setLong(2, to) val rs = statement.executeQuery() @@ -276,7 +290,7 @@ class SqlitePaymentsDb(sqlite: Connection) extends PaymentsDb with Logging { } override def listPendingIncomingPayments(from: Long, to: Long): Seq[IncomingPayment] = - using(sqlite.prepareStatement("SELECT payment_hash, payment_preimage, payment_request, received_msat, created_at, received_at FROM received_payments WHERE received_msat IS NULL AND created_at > ? AND created_at < ? AND expire_at > ? ORDER BY created_at")) { statement => + using(sqlite.prepareStatement("SELECT * FROM received_payments WHERE received_msat IS NULL AND created_at > ? AND created_at < ? AND expire_at > ? ORDER BY created_at")) { statement => statement.setLong(1, from) statement.setLong(2, to) statement.setLong(3, Platform.currentTime) @@ -289,7 +303,7 @@ class SqlitePaymentsDb(sqlite: Connection) extends PaymentsDb with Logging { } override def listExpiredIncomingPayments(from: Long, to: Long): Seq[IncomingPayment] = - using(sqlite.prepareStatement("SELECT payment_hash, payment_preimage, payment_request, received_msat, created_at, received_at FROM received_payments WHERE received_msat IS NULL AND created_at > ? AND created_at < ? AND expire_at < ? ORDER BY created_at")) { statement => + using(sqlite.prepareStatement("SELECT * FROM received_payments WHERE received_msat IS NULL AND created_at > ? AND created_at < ? AND expire_at < ? ORDER BY created_at")) { statement => statement.setLong(1, from) statement.setLong(2, to) statement.setLong(3, Platform.currentTime) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/LocalPaymentHandler.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/LocalPaymentHandler.scala index f21705ca7f..d5d399f932 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/LocalPaymentHandler.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/LocalPaymentHandler.scala @@ -81,7 +81,7 @@ class LocalPaymentHandler(nodeParams: NodeParams) extends Actor with ActorLoggin // amount is correct or was not specified in the payment request nodeParams.db.payments.receiveIncomingPayment(htlc.paymentHash, htlc.amountMsat) sender ! CMD_FULFILL_HTLC(htlc.id, paymentPreimage, commit = true) - context.system.eventStream.publish(PaymentReceived(htlc.amountMsat, htlc.paymentHash)) + context.system.eventStream.publish(PaymentReceived(htlc.paymentHash, PaymentReceived.PartialPayment(htlc.amountMsat, htlc.channelId) :: Nil)) } case None => sender ! CMD_FAIL_HTLC(htlc.id, Right(IncorrectOrUnknownPaymentDetails(htlc.amountMsat, nodeParams.currentBlockHeight)), commit = true) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentEvents.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentEvents.scala index 2bad5b064f..340d38475b 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentEvents.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentEvents.scala @@ -34,13 +34,34 @@ sealed trait PaymentEvent { val timestamp: Long } -case class PaymentSent(id: UUID, amount: MilliSatoshi, feesPaid: MilliSatoshi, paymentHash: ByteVector32, paymentPreimage: ByteVector32, route: Seq[Hop], timestamp: Long = Platform.currentTime) extends PaymentEvent +case class PaymentSent(id: UUID, paymentHash: ByteVector32, paymentPreimage: ByteVector32, parts: Seq[PaymentSent.PartialPayment]) extends PaymentEvent { + require(parts.nonEmpty, "sent payment is empty") + val amount: MilliSatoshi = parts.map(_.amount).sum + val feesPaid: MilliSatoshi = parts.map(_.feesPaid).sum + val timestamp: Long = parts.map(_.timestamp).min +} + +object PaymentSent { + + case class PartialPayment(id: UUID, amount: MilliSatoshi, feesPaid: MilliSatoshi, toChannelId: ByteVector32, route: Seq[Hop], timestamp: Long = Platform.currentTime) + +} case class PaymentFailed(id: UUID, paymentHash: ByteVector32, failures: Seq[PaymentFailure], timestamp: Long = Platform.currentTime) extends PaymentEvent case class PaymentRelayed(amountIn: MilliSatoshi, amountOut: MilliSatoshi, paymentHash: ByteVector32, fromChannelId: ByteVector32, toChannelId: ByteVector32, timestamp: Long = Platform.currentTime) extends PaymentEvent -case class PaymentReceived(amount: MilliSatoshi, paymentHash: ByteVector32, timestamp: Long = Platform.currentTime) extends PaymentEvent +case class PaymentReceived(paymentHash: ByteVector32, parts: Seq[PaymentReceived.PartialPayment]) extends PaymentEvent { + require(parts.nonEmpty, "received payment is empty") + val amount: MilliSatoshi = parts.map(_.amount).sum + val timestamp: Long = parts.map(_.timestamp).max +} + +object PaymentReceived { + + case class PartialPayment(amount: MilliSatoshi, fromChannelId: ByteVector32, timestamp: Long = Platform.currentTime) + +} case class PaymentSettlingOnChain(id: UUID, amount: MilliSatoshi, paymentHash: ByteVector32, timestamp: Long = Platform.currentTime) extends PaymentEvent diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentLifecycle.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentLifecycle.scala index c8c52529bd..51f5cbeb21 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentLifecycle.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentLifecycle.scala @@ -28,6 +28,7 @@ import fr.acinq.eclair.db.{OutgoingPayment, OutgoingPaymentStatus, PaymentsDb} import fr.acinq.eclair.payment.PaymentInitiator.SendPaymentRequest import fr.acinq.eclair.payment.PaymentLifecycle._ import fr.acinq.eclair.payment.PaymentRequest.ExtraHop +import fr.acinq.eclair.payment.PaymentSent.PartialPayment import fr.acinq.eclair.router._ import fr.acinq.eclair.wire.Onion._ import fr.acinq.eclair.wire._ @@ -77,7 +78,8 @@ class PaymentLifecycle(nodeParams: NodeParams, progressHandler: PaymentProgressH case Event("ok", _) => stay case Event(fulfill: UpdateFulfillHtlc, WaitingForComplete(s, c, cmd, _, _, _, _, route)) => - progressHandler.onSuccess(s, PaymentSent(id, c.finalPayload.amount, cmd.amount - c.finalPayload.amount, c.paymentHash, fulfill.paymentPreimage, route))(context) + val p = PartialPayment(id, c.finalPayload.amount, cmd.amount - c.finalPayload.amount, fulfill.channelId, route) + progressHandler.onSuccess(s, PaymentSent(id, c.paymentHash, fulfill.paymentPreimage, p :: Nil))(context) stop(FSM.Normal) case Event(fail: UpdateFailHtlc, WaitingForComplete(s, c, _, failures, sharedSecrets, ignoreNodes, ignoreChannels, hops)) => @@ -200,7 +202,7 @@ object PaymentLifecycle { case class DefaultPaymentProgressHandler(id: UUID, r: SendPaymentRequest, db: PaymentsDb) extends PaymentProgressHandler { override def onSend(): Unit = { - db.addOutgoingPayment(OutgoingPayment(id, None, r.externalId, r.paymentHash, r.amount, r.targetNodeId, Platform.currentTime, r.paymentRequest, OutgoingPaymentStatus.Pending)) + db.addOutgoingPayment(OutgoingPayment(id, id, r.externalId, r.paymentHash, r.amount, r.targetNodeId, Platform.currentTime, r.paymentRequest, OutgoingPaymentStatus.Pending)) } override def onSuccess(sender: ActorRef, result: PaymentSent)(ctx: ActorContext): Unit = { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Relayer.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Relayer.scala index 192d26bcbd..0b63d77285 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Relayer.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Relayer.scala @@ -161,7 +161,7 @@ class Relayer(nodeParams: NodeParams, register: ActorRef, paymentHandler: ActorR // we sent the payment, but we probably restarted and the reference to the original sender was lost, // we publish the success on the event stream and update the status in paymentDb val feesPaid = 0.msat // fees are unknown since we lost the reference to the payment - val result = PaymentSent(id, add.amountMsat, feesPaid, add.paymentHash, fulfill.paymentPreimage, Nil) + val result = PaymentSent(id, add.paymentHash, fulfill.paymentPreimage, Seq(PaymentSent.PartialPayment(id, add.amountMsat, feesPaid, add.channelId, Nil))) nodeParams.db.payments.updateOutgoingPayment(result) context.system.eventStream.publish(result) case Local(_, Some(sender)) => diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteAuditDbSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteAuditDbSpec.scala index ee77c29bff..2dc8c16625 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteAuditDbSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteAuditDbSpec.scala @@ -25,8 +25,7 @@ import fr.acinq.eclair.channel.{AvailableBalanceChanged, ChannelErrorOccurred, N import fr.acinq.eclair.db.sqlite.SqliteAuditDb import fr.acinq.eclair.db.sqlite.SqliteUtils.{getVersion, using} import fr.acinq.eclair.payment._ -import fr.acinq.eclair.router.Hop -import fr.acinq.eclair.wire.{ChannelCodecs, ChannelCodecsSpec, ChannelUpdate} +import fr.acinq.eclair.wire.{ChannelCodecs, ChannelCodecsSpec} import org.scalatest.FunSuite import scala.compat.Platform @@ -35,8 +34,6 @@ import scala.concurrent.duration._ class SqliteAuditDbSpec extends FunSuite { - import SqliteAuditDbSpec._ - test("init sqlite 2 times in a row") { val sqlite = TestConstants.sqliteInMemory() val db1 = new SqliteAuditDb(sqlite) @@ -47,12 +44,17 @@ class SqliteAuditDbSpec extends FunSuite { val sqlite = TestConstants.sqliteInMemory() val db = new SqliteAuditDb(sqlite) - val e1 = PaymentSent(ChannelCodecs.UNKNOWN_UUID, 42000 msat, 1000 msat, randomBytes32, randomBytes32, Seq(Hop(carol, dave, channelUpdate2))) - val e2 = PaymentReceived(42000 msat, randomBytes32) + val e1 = PaymentSent(ChannelCodecs.UNKNOWN_UUID, randomBytes32, randomBytes32, PaymentSent.PartialPayment(ChannelCodecs.UNKNOWN_UUID, 42000 msat, 1000 msat, randomBytes32, Nil) :: Nil) + val pp2a = PaymentReceived.PartialPayment(42000 msat, randomBytes32) + val pp2b = PaymentReceived.PartialPayment(42100 msat, randomBytes32) + val e2 = PaymentReceived(randomBytes32, pp2a :: pp2b :: Nil) val e3 = PaymentRelayed(42000 msat, 1000 msat, randomBytes32, randomBytes32, randomBytes32) val e4 = NetworkFeePaid(null, randomKey.publicKey, randomBytes32, Transaction(0, Seq.empty, Seq.empty, 0), 42 sat, "mutual") - val e5 = PaymentSent(ChannelCodecs.UNKNOWN_UUID, 42000 msat, 1000 msat, randomBytes32, randomBytes32, Seq(Hop(alice, bob, channelUpdate1), Hop(bob, carol, channelUpdate2)), timestamp = 0) - val e6 = PaymentSent(ChannelCodecs.UNKNOWN_UUID, 42000 msat, 1000 msat, randomBytes32, randomBytes32, Nil, timestamp = (Platform.currentTime.milliseconds + 10.minutes).toMillis) + val pp5a = PaymentSent.PartialPayment(UUID.randomUUID(), 42000 msat, 1000 msat, randomBytes32, Nil, timestamp = 0) + val pp5b = PaymentSent.PartialPayment(UUID.randomUUID(), 42100 msat, 900 msat, randomBytes32, Nil, timestamp = 1) + val e5 = PaymentSent(ChannelCodecs.UNKNOWN_UUID, randomBytes32, randomBytes32, pp5a :: pp5b :: Nil) + val pp6 = PaymentSent.PartialPayment(UUID.randomUUID(), 42000 msat, 1000 msat, randomBytes32, Nil, timestamp = (Platform.currentTime.milliseconds + 10.minutes).toMillis) + val e6 = PaymentSent(ChannelCodecs.UNKNOWN_UUID, randomBytes32, randomBytes32, pp6 :: Nil) val e7 = AvailableBalanceChanged(null, randomBytes32, ShortChannelId(500000, 42, 1), 456123000 msat, ChannelCodecsSpec.commitments) val e8 = ChannelLifecycleEvent(randomBytes32, randomKey.publicKey, 456123000 sat, isFunder = true, isPrivate = false, "mutual") val e9 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, LocalError(new RuntimeException("oops")), isFatal = true) @@ -69,9 +71,9 @@ class SqliteAuditDbSpec extends FunSuite { db.add(e9) db.add(e10) - assert(db.listSent(from = 0L, to = (Platform.currentTime.milliseconds + 15.minute).toMillis).toSet === Set(e1.copy(route = Nil), e5.copy(route = Nil), e6)) - assert(db.listSent(from = 100000L, to = (Platform.currentTime.milliseconds + 1.minute).toMillis).toList === List(e1.copy(route = Nil))) - assert(db.listReceived(from = 0L, to = (Platform.currentTime.milliseconds + 1.minute).toMillis).toList === List(e2)) + assert(db.listSent(from = 0L, to = (Platform.currentTime.milliseconds + 15.minute).toMillis).toSet === Set(e1, e5.copy(id = pp5a.id, parts = pp5a :: Nil), e5.copy(id = pp5b.id, parts = pp5b :: Nil), e6.copy(id = pp6.id))) + assert(db.listSent(from = 100000L, to = (Platform.currentTime.milliseconds + 1.minute).toMillis).toList === List(e1)) + assert(db.listReceived(from = 0L, to = (Platform.currentTime.milliseconds + 1.minute).toMillis).toList === List(e2.copy(parts = pp2a :: Nil), e2.copy(parts = pp2b :: Nil))) assert(db.listRelayed(from = 0L, to = (Platform.currentTime.milliseconds + 1.minute).toMillis).toList === List(e3)) assert(db.listNetworkFees(from = 0L, to = (Platform.currentTime.milliseconds + 1.minute).toMillis).size === 1) assert(db.listNetworkFees(from = 0L, to = (Platform.currentTime.milliseconds + 1.minute).toMillis).head.txType === "mutual") @@ -106,7 +108,7 @@ class SqliteAuditDbSpec extends FunSuite { )) } - test("handle migration version 1 -> 4") { + test("handle migration version 1 -> 3") { val connection = TestConstants.sqliteInMemory() @@ -129,69 +131,54 @@ class SqliteAuditDbSpec extends FunSuite { } using(connection.createStatement()) { statement => - assert(getVersion(statement, "audit", 4) == 1) // we expect version 1 + assert(getVersion(statement, "audit", 3) == 1) // we expect version 1 } - val ps = PaymentSent(UUID.randomUUID(), 42000 msat, 1000 msat, randomBytes32, randomBytes32, Seq(Hop(alice, bob, channelUpdate1))) - val ps1 = PaymentSent(UUID.randomUUID(), 42001 msat, 1001 msat, randomBytes32, randomBytes32, Seq(Hop(alice, bob, channelUpdate1), Hop(bob, carol, channelUpdate2))) - val ps2 = PaymentSent(UUID.randomUUID(), 42002 msat, 1002 msat, randomBytes32, randomBytes32, Nil) - val pr = PaymentReceived(561 msat, randomBytes32) - val pr1 = PaymentReceived(1105 msat, randomBytes32) + val ps = PaymentSent(UUID.randomUUID(), randomBytes32, randomBytes32, PaymentSent.PartialPayment(UUID.randomUUID(), 42000 msat, 1000 msat, randomBytes32, Nil) :: Nil) + val pp1 = PaymentSent.PartialPayment(UUID.randomUUID(), 42001 msat, 1001 msat, randomBytes32, Nil) + val pp2 = PaymentSent.PartialPayment(UUID.randomUUID(), 42002 msat, 1002 msat, randomBytes32, Nil) + val ps1 = PaymentSent(UUID.randomUUID(), randomBytes32, randomBytes32, pp1 :: pp2 :: Nil) val e1 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, LocalError(new RuntimeException("oops")), isFatal = true) val e2 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, RemoteError(wire.Error(randomBytes32, "remote oops")), isFatal = true) - // Changes to the 'sent' table between versions 1 and 4: - // - the 'id' column was added - // - the 'toChannelId' column was removed + // add a row (no ID on sent) using(connection.prepareStatement("INSERT INTO sent VALUES (?, ?, ?, ?, ?, ?)")) { statement => statement.setLong(1, ps.amount.toLong) statement.setLong(2, ps.feesPaid.toLong) statement.setBytes(3, ps.paymentHash.toArray) statement.setBytes(4, ps.paymentPreimage.toArray) - statement.setBytes(5, randomBytes32.toArray) // toChannelId + statement.setBytes(5, ps.parts.head.toChannelId.toArray) statement.setLong(6, ps.timestamp) statement.executeUpdate() } - // Changes to the 'received' table between versions 1 and 4: - // - the 'fromChannelId' column was removed - using(connection.prepareStatement("INSERT INTO received VALUES (?, ?, ?, ?)")) { statement => - statement.setLong(1, pr.amount.toLong) - statement.setBytes(2, pr.paymentHash.toArray) - statement.setBytes(3, randomBytes32.toArray) // fromChannelId - statement.setLong(4, pr.timestamp) - statement.executeUpdate() - } - val migratedDb = new SqliteAuditDb(connection) using(connection.createStatement()) { statement => - assert(getVersion(statement, "audit", 4) == 4) // version changed from 1 -> 4 + assert(getVersion(statement, "audit", 3) == 3) // version changed from 1 -> 3 } // existing rows in the 'sent' table will use id=00000000-0000-0000-0000-000000000000 as default - assert(migratedDb.listSent(0, (Platform.currentTime.milliseconds + 1.minute).toMillis) === Seq(ps.copy(id = ChannelCodecs.UNKNOWN_UUID, route = Nil))) - // existing rows in the 'received' table will not contain a fromChannelId anymore - assert(migratedDb.listReceived(0, (Platform.currentTime.milliseconds + 1.minute).toMillis) === Seq(pr)) + assert(migratedDb.listSent(0, (Platform.currentTime.milliseconds + 1.minute).toMillis) === Seq(ps.copy(id = ChannelCodecs.UNKNOWN_UUID, parts = Seq(ps.parts.head.copy(id = ChannelCodecs.UNKNOWN_UUID))))) val postMigrationDb = new SqliteAuditDb(connection) using(connection.createStatement()) { statement => - assert(getVersion(statement, "audit", 4) == 4) // version 4 + assert(getVersion(statement, "audit", 3) == 3) // version 3 } postMigrationDb.add(ps1) - postMigrationDb.add(ps2) postMigrationDb.add(e1) postMigrationDb.add(e2) - postMigrationDb.add(pr1) - // the old 'sent' record will have the UNKNOWN_UUID and an empty route but the new ones will have their actual id - assert(postMigrationDb.listSent(0, (Platform.currentTime.milliseconds + 1.minute).toMillis) === Seq(ps.copy(id = ChannelCodecs.UNKNOWN_UUID, route = Nil), ps1.copy(route = Nil), ps2.copy(route = Nil))) - assert(postMigrationDb.listReceived(0, (Platform.currentTime.milliseconds + 1.minute).toMillis) === Seq(pr, pr1)) + // the old record will have the UNKNOWN_UUID but the new ones will have their actual id + assert(postMigrationDb.listSent(0, (Platform.currentTime.milliseconds + 1.minute).toMillis) === Seq( + ps.copy(id = ChannelCodecs.UNKNOWN_UUID, parts = Seq(ps.parts.head.copy(id = ChannelCodecs.UNKNOWN_UUID))), + ps1.copy(id = pp1.id, parts = pp1 :: Nil), + ps1.copy(id = pp2.id, parts = pp2 :: Nil))) } - test("handle migration version 2 -> 4") { + test("handle migration version 2 -> 3") { val connection = TestConstants.sqliteInMemory() @@ -214,7 +201,7 @@ class SqliteAuditDbSpec extends FunSuite { } using(connection.createStatement()) { statement => - assert(getVersion(statement, "audit", 4) == 2) // version 2 is deployed now + assert(getVersion(statement, "audit", 3) == 2) // version 2 is deployed now } val e1 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, LocalError(new RuntimeException("oops")), isFatal = true) @@ -223,7 +210,7 @@ class SqliteAuditDbSpec extends FunSuite { val migratedDb = new SqliteAuditDb(connection) using(connection.createStatement()) { statement => - assert(getVersion(statement, "audit", 4) == 4) // version changed from 2 -> 4 + assert(getVersion(statement, "audit", 3) == 3) // version changed from 2 -> 3 } migratedDb.add(e1) @@ -231,100 +218,10 @@ class SqliteAuditDbSpec extends FunSuite { val postMigrationDb = new SqliteAuditDb(connection) using(connection.createStatement()) { statement => - assert(getVersion(statement, "audit", 4) == 4) // version 4 + assert(getVersion(statement, "audit", 3) == 3) // version 3 } postMigrationDb.add(e2) } - test("handle migration version 3 -> 4") { - val connection = TestConstants.sqliteInMemory() - - // simulate existing previous version db - using(connection.createStatement()) { statement => - getVersion(statement, "audit", 3) - statement.executeUpdate("CREATE TABLE IF NOT EXISTS balance_updated (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, amount_msat INTEGER NOT NULL, capacity_sat INTEGER NOT NULL, reserve_sat INTEGER NOT NULL, timestamp INTEGER NOT NULL)") - statement.executeUpdate("CREATE TABLE IF NOT EXISTS sent (amount_msat INTEGER NOT NULL, fees_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, payment_preimage BLOB NOT NULL, to_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL, id BLOB NOT NULL)") - statement.executeUpdate("CREATE TABLE IF NOT EXISTS received (amount_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, from_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL)") - statement.executeUpdate("CREATE TABLE IF NOT EXISTS relayed (amount_in_msat INTEGER NOT NULL, amount_out_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, from_channel_id BLOB NOT NULL, to_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL)") - statement.executeUpdate("CREATE TABLE IF NOT EXISTS network_fees (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, tx_id BLOB NOT NULL, fee_sat INTEGER NOT NULL, tx_type TEXT NOT NULL, timestamp INTEGER NOT NULL)") - statement.executeUpdate("CREATE TABLE IF NOT EXISTS channel_events (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, capacity_sat INTEGER NOT NULL, is_funder BOOLEAN NOT NULL, is_private BOOLEAN NOT NULL, event STRING NOT NULL, timestamp INTEGER NOT NULL)") - statement.executeUpdate("CREATE TABLE IF NOT EXISTS channel_errors (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, error_name STRING NOT NULL, error_message STRING NOT NULL, is_fatal INTEGER NOT NULL, timestamp INTEGER NOT NULL)") - - statement.executeUpdate("CREATE INDEX IF NOT EXISTS balance_updated_idx ON balance_updated(timestamp)") - statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_timestamp_idx ON sent(timestamp)") - statement.executeUpdate("CREATE INDEX IF NOT EXISTS received_timestamp_idx ON received(timestamp)") - statement.executeUpdate("CREATE INDEX IF NOT EXISTS relayed_timestamp_idx ON relayed(timestamp)") - statement.executeUpdate("CREATE INDEX IF NOT EXISTS network_fees_timestamp_idx ON network_fees(timestamp)") - statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_events_timestamp_idx ON channel_events(timestamp)") - statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_errors_timestamp_idx ON channel_errors(timestamp)") - } - - using(connection.createStatement()) { statement => - assert(getVersion(statement, "audit", 4) == 3) // version 3 is deployed now - } - - val ps = PaymentSent(UUID.randomUUID(), 42000 msat, 1000 msat, randomBytes32, randomBytes32, Seq(Hop(alice, bob, channelUpdate1))) - val ps1 = PaymentSent(UUID.randomUUID(), 42001 msat, 1001 msat, randomBytes32, randomBytes32, Seq(Hop(alice, bob, channelUpdate1), Hop(bob, carol, channelUpdate2))) - val ps2 = PaymentSent(UUID.randomUUID(), 42002 msat, 1002 msat, randomBytes32, randomBytes32, Nil) - val pr = PaymentReceived(561 msat, randomBytes32) - val pr1 = PaymentReceived(1105 msat, randomBytes32) - - // Changes to the 'sent' table between versions 3 and 4: - // - the 'toChannelId' column was removed - using(connection.prepareStatement("INSERT INTO sent VALUES (?, ?, ?, ?, ?, ?, ?)")) { statement => - statement.setLong(1, ps.amount.toLong) - statement.setLong(2, ps.feesPaid.toLong) - statement.setBytes(3, ps.paymentHash.toArray) - statement.setBytes(4, ps.paymentPreimage.toArray) - statement.setBytes(5, randomBytes32.toArray) // toChannelId - statement.setLong(6, ps.timestamp) - statement.setBytes(7, ps.id.toString.getBytes) - statement.executeUpdate() - } - - // Changes to the 'received' table between versions 3 and 4: - // - the 'fromChannelId' column was removed - using(connection.prepareStatement("INSERT INTO received VALUES (?, ?, ?, ?)")) { statement => - statement.setLong(1, pr.amount.toLong) - statement.setBytes(2, pr.paymentHash.toArray) - statement.setBytes(3, randomBytes32.toArray) // fromChannelId - statement.setLong(4, pr.timestamp) - statement.executeUpdate() - } - - val migratedDb = new SqliteAuditDb(connection) - - using(connection.createStatement()) { statement => - assert(getVersion(statement, "audit", 4) == 4) // version changed from 3 -> 4 - } - - // existing rows in the 'sent' table will use route=NULL as default - assert(migratedDb.listSent(0, (Platform.currentTime.milliseconds + 1.minute).toMillis) === Seq(ps.copy(route = Nil))) - // existing rows in the 'received' table will not contain a fromChannelId anymore - assert(migratedDb.listReceived(0, (Platform.currentTime.milliseconds + 1.minute).toMillis) === Seq(pr)) - - val postMigrationDb = new SqliteAuditDb(connection) - - using(connection.createStatement()) { statement => - assert(getVersion(statement, "audit", 4) == 4) // version 4 - } - - postMigrationDb.add(ps1) - postMigrationDb.add(ps2) - postMigrationDb.add(pr1) - - // the old 'sent' record will have the UNKNOWN_UUID and an empty route but the new ones will have their actual id - assert(postMigrationDb.listSent(0, (Platform.currentTime.milliseconds + 1.minute).toMillis) === Seq(ps.copy(route = Nil), ps1.copy(route = Nil), ps2.copy(route = Nil))) - assert(postMigrationDb.listReceived(0, (Platform.currentTime.milliseconds + 1.minute).toMillis) === Seq(pr, pr1)) - } - } - -object SqliteAuditDbSpec { - - val (alice, bob, carol, dave) = (randomKey.publicKey, randomKey.publicKey, randomKey.publicKey, randomKey.publicKey) - val channelUpdate1 = ChannelUpdate(randomBytes64, randomBytes32, ShortChannelId(561), 0, 0, 0, CltvExpiryDelta(144), 100 msat, 10 msat, 1000, None) - val channelUpdate2 = ChannelUpdate(randomBytes64, randomBytes32, ShortChannelId(1105), 0, 0, 0, CltvExpiryDelta(9), 1000 msat, 15 msat, 100, None) - -} \ No newline at end of file diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/SqlitePaymentsDbSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/SqlitePaymentsDbSpec.scala index 37d92e3724..2159fb6036 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/db/SqlitePaymentsDbSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/SqlitePaymentsDbSpec.scala @@ -74,7 +74,7 @@ class SqlitePaymentsDbSpec extends FunSuite { assert(preMigrationDb.getIncomingPayment(paymentHash1).isEmpty) // add a few rows - val ps1 = OutgoingPayment(UUID.randomUUID(), None, None, paymentHash1, 12345 msat, alice, 1000, None, OutgoingPaymentStatus.Pending) + val ps1 = OutgoingPayment(UUID.randomUUID(), UUID.randomUUID(), None, paymentHash1, 12345 msat, alice, 1000, None, OutgoingPaymentStatus.Pending) val i1 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(500 msat), paymentHash1, davePriv, "Some invoice", expirySeconds = None, timestamp = 1) val pr1 = IncomingPayment(i1, preimage1, i1.timestamp.seconds.toMillis, IncomingPaymentStatus.Received(550 msat, 1100)) @@ -110,9 +110,12 @@ class SqlitePaymentsDbSpec extends FunSuite { } // Insert a bunch of old version 2 rows. - val ps1 = OutgoingPayment(UUID.randomUUID(), None, None, randomBytes32, 561 msat, PrivateKey(ByteVector32.One).publicKey, 1000, None, OutgoingPaymentStatus.Pending) - val ps2 = OutgoingPayment(UUID.randomUUID(), None, None, randomBytes32, 1105 msat, PrivateKey(ByteVector32.One).publicKey, 1010, None, OutgoingPaymentStatus.Failed(Nil, 1050)) - val ps3 = OutgoingPayment(UUID.randomUUID(), None, None, paymentHash1, 1729 msat, PrivateKey(ByteVector32.One).publicKey, 1040, None, OutgoingPaymentStatus.Succeeded(preimage1, 0 msat, Nil, 1060)) + val id1 = UUID.randomUUID() + val id2 = UUID.randomUUID() + val id3 = UUID.randomUUID() + val ps1 = OutgoingPayment(id1, id1, None, randomBytes32, 561 msat, PrivateKey(ByteVector32.One).publicKey, 1000, None, OutgoingPaymentStatus.Pending) + val ps2 = OutgoingPayment(id2, id2, None, randomBytes32, 1105 msat, PrivateKey(ByteVector32.One).publicKey, 1010, None, OutgoingPaymentStatus.Failed(Nil, 1050)) + val ps3 = OutgoingPayment(id3, id3, None, paymentHash1, 1729 msat, PrivateKey(ByteVector32.One).publicKey, 1040, None, OutgoingPaymentStatus.Succeeded(preimage1, 0 msat, Nil, 1060)) val i1 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(12345678 msat), paymentHash1, davePriv, "Some invoice", expirySeconds = None, timestamp = 1) val pr1 = IncomingPayment(i1, preimage1, i1.timestamp.seconds.toMillis, IncomingPaymentStatus.Received(12345678 msat, 1090)) val i2 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(12345678 msat), paymentHash2, carolPriv, "Another invoice", expirySeconds = Some(30), timestamp = 1) @@ -199,12 +202,12 @@ class SqlitePaymentsDbSpec extends FunSuite { val pr3 = IncomingPayment(i3, preimage3, i3.timestamp.seconds.toMillis, IncomingPaymentStatus.Pending) postMigrationDb.addIncomingPayment(i3, pr3.paymentPreimage) - val ps4 = OutgoingPayment(UUID.randomUUID(), Some(UUID.randomUUID()), Some("1"), randomBytes32, 123 msat, alice, 1100, Some(i3), OutgoingPaymentStatus.Pending) - val ps5 = OutgoingPayment(UUID.randomUUID(), Some(UUID.randomUUID()), Some("2"), randomBytes32, 456 msat, bob, 1150, Some(i2), OutgoingPaymentStatus.Succeeded(preimage1, 42 msat, Nil, 1180)) - val ps6 = OutgoingPayment(UUID.randomUUID(), Some(UUID.randomUUID()), Some("3"), randomBytes32, 789 msat, bob, 1250, None, OutgoingPaymentStatus.Failed(Nil, 1300)) + val ps4 = OutgoingPayment(UUID.randomUUID(), UUID.randomUUID(), Some("1"), randomBytes32, 123 msat, alice, 1100, Some(i3), OutgoingPaymentStatus.Pending) + val ps5 = OutgoingPayment(UUID.randomUUID(), UUID.randomUUID(), Some("2"), randomBytes32, 456 msat, bob, 1150, Some(i2), OutgoingPaymentStatus.Succeeded(preimage1, 42 msat, Nil, 1180)) + val ps6 = OutgoingPayment(UUID.randomUUID(), UUID.randomUUID(), Some("3"), randomBytes32, 789 msat, bob, 1250, None, OutgoingPaymentStatus.Failed(Nil, 1300)) postMigrationDb.addOutgoingPayment(ps4) postMigrationDb.addOutgoingPayment(ps5.copy(status = OutgoingPaymentStatus.Pending)) - postMigrationDb.updateOutgoingPayment(PaymentSent(ps5.id, ps5.amount, 42 msat, ps5.paymentHash, preimage1, Nil, 1180)) + postMigrationDb.updateOutgoingPayment(PaymentSent(ps5.parentId, ps5.paymentHash, preimage1, Seq(PaymentSent.PartialPayment(ps5.id, ps5.amount, 42 msat, randomBytes32, Nil, 1180)))) postMigrationDb.addOutgoingPayment(ps6.copy(status = OutgoingPaymentStatus.Pending)) postMigrationDb.updateOutgoingPayment(PaymentFailed(ps6.id, ps6.paymentHash, Nil, 1300)) @@ -267,9 +270,10 @@ class SqlitePaymentsDbSpec extends FunSuite { test("add/retrieve/update outgoing payments") { val db = new SqlitePaymentsDb(TestConstants.sqliteInMemory()) + val parentId = UUID.randomUUID() val i1 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(123 msat), paymentHash1, davePriv, "Some invoice", expirySeconds = None, timestamp = 0) - val s1 = OutgoingPayment(UUID.randomUUID(), Some(UUID.randomUUID()), None, i1.paymentHash, 123 msat, alice, 100, Some(i1), OutgoingPaymentStatus.Pending) - val s2 = OutgoingPayment(UUID.randomUUID(), None, Some("1"), paymentHash2, 456 msat, bob, 200, None, OutgoingPaymentStatus.Pending) + val s1 = OutgoingPayment(UUID.randomUUID(), parentId, None, paymentHash1, 123 msat, alice, 100, Some(i1), OutgoingPaymentStatus.Pending) + val s2 = OutgoingPayment(UUID.randomUUID(), parentId, Some("1"), paymentHash1, 456 msat, bob, 200, None, OutgoingPaymentStatus.Pending) assert(db.listOutgoingPayments(0, Platform.currentTime).isEmpty) db.addOutgoingPayment(s1) @@ -283,7 +287,9 @@ class SqlitePaymentsDbSpec extends FunSuite { assert(db.listOutgoingPayments(150, 250).toList == Seq(s2)) assert(db.getOutgoingPayment(s1.id) === Some(s1)) assert(db.getOutgoingPayment(UUID.randomUUID()) === None) - assert(db.listOutgoingPayments(s2.paymentHash) === Seq(s2)) + assert(db.listOutgoingPayments(s2.paymentHash) === Seq(s1, s2)) + assert(db.listOutgoingPayments(s1.id) === Nil) + assert(db.listOutgoingPayments(parentId) === Seq(s1, s2)) assert(db.listOutgoingPayments(ByteVector32.Zeroes) === Nil) val s3 = s2.copy(id = UUID.randomUUID(), amount = 789 msat, createdAt = 300) @@ -292,17 +298,25 @@ class SqlitePaymentsDbSpec extends FunSuite { db.addOutgoingPayment(s4) db.updateOutgoingPayment(PaymentFailed(s3.id, s3.paymentHash, Nil, 310)) - assert(db.getOutgoingPayment(s3.id) === Some(s3.copy(status = OutgoingPaymentStatus.Failed(Nil, 310)))) + val ss3 = s3.copy(status = OutgoingPaymentStatus.Failed(Nil, 310)) + assert(db.getOutgoingPayment(s3.id) === Some(ss3)) db.updateOutgoingPayment(PaymentFailed(s4.id, s4.paymentHash, Seq(LocalFailure(new RuntimeException("woops")), RemoteFailure(Seq(hop_ab, hop_bc), Sphinx.DecryptedFailurePacket(carol, UnknownNextPeer))), 320)) - assert(db.getOutgoingPayment(s4.id) === Some(s4.copy(status = OutgoingPaymentStatus.Failed(Seq(FailureSummary(FailureType.LOCAL, "woops", Nil), FailureSummary(FailureType.REMOTE, "processing node does not know the next peer in the route", List(HopSummary(alice, bob, Some(ShortChannelId(42))), HopSummary(bob, carol, Some(ShortChannelId(43)))))), 320)))) + val ss4 = s4.copy(status = OutgoingPaymentStatus.Failed(Seq(FailureSummary(FailureType.LOCAL, "woops", Nil), FailureSummary(FailureType.REMOTE, "processing node does not know the next peer in the route", List(HopSummary(alice, bob, Some(ShortChannelId(42))), HopSummary(bob, carol, Some(ShortChannelId(43)))))), 320)) + assert(db.getOutgoingPayment(s4.id) === Some(ss4)) // can't update again once it's in a final state - assertThrows[IllegalArgumentException](db.updateOutgoingPayment(PaymentSent(s3.id, s3.amount, 42 msat, s3.paymentHash, preimage1, Nil))) - - db.updateOutgoingPayment(PaymentSent(s1.id, s1.amount, 15 msat, s1.paymentHash, preimage1, Nil, 400)) - assert(db.getOutgoingPayment(s1.id) === Some(s1.copy(status = OutgoingPaymentStatus.Succeeded(preimage1, 15 msat, Nil, 400)))) - db.updateOutgoingPayment(PaymentSent(s2.id, s2.amount, 15 msat, s2.paymentHash, preimage2, Seq(hop_ab, hop_bc), 410)) - assert(db.getOutgoingPayment(s2.id) === Some(s2.copy(status = OutgoingPaymentStatus.Succeeded(preimage2, 15 msat, Seq(HopSummary(alice, bob, Some(ShortChannelId(42))), HopSummary(bob, carol, Some(ShortChannelId(43)))), 410)))) + assertThrows[IllegalArgumentException](db.updateOutgoingPayment(PaymentSent(parentId, s3.paymentHash, preimage1, Seq(PaymentSent.PartialPayment(s3.id, s3.amount, 42 msat, randomBytes32, Nil))))) + + val paymentSent = PaymentSent(parentId, paymentHash1, preimage1, Seq( + PaymentSent.PartialPayment(s1.id, s1.amount, 15 msat, randomBytes32, Nil, 400), + PaymentSent.PartialPayment(s2.id, s2.amount, 20 msat, randomBytes32, Seq(hop_ab, hop_bc), 410) + )) + val ss1 = s1.copy(status = OutgoingPaymentStatus.Succeeded(preimage1, 15 msat, Nil, 400)) + val ss2 = s2.copy(status = OutgoingPaymentStatus.Succeeded(preimage1, 20 msat, Seq(HopSummary(alice, bob, Some(ShortChannelId(42))), HopSummary(bob, carol, Some(ShortChannelId(43)))), 410)) + db.updateOutgoingPayment(paymentSent) + assert(db.getOutgoingPayment(s1.id) === Some(ss1)) + assert(db.getOutgoingPayment(s2.id) === Some(ss2)) + assert(db.listOutgoingPayments(parentId) === Seq(ss1, ss2, ss3, ss4)) // can't update again once it's in a final state assertThrows[IllegalArgumentException](db.updateOutgoingPayment(PaymentFailed(s1.id, s1.paymentHash, Nil))) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/integration/IntegrationSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/integration/IntegrationSpec.scala index a703f2bf43..8b7ac29b98 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/integration/IntegrationSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/integration/IntegrationSpec.scala @@ -432,7 +432,7 @@ class IntegrationSpec extends TestKit(ActorSystem("test")) with BitcoindService awaitCond({ sender.expectMsgType[PaymentEvent](10 seconds) match { case PaymentFailed(_, _, failures, _) => failures == Seq.empty // if something went wrong fail with a hint - case PaymentSent(_, _, _, _, _, route, _) => route.exists(_.nodeId == nodes("G").nodeParams.nodeId) + case PaymentSent(_, _, _, part :: Nil) => part.route.exists(_.nodeId == nodes("G").nodeParams.nodeId) case _ => false } }, max = 30 seconds, interval = 10 seconds) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentHandlerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentHandlerSpec.scala index 34c5f8c557..a66733a12b 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentHandlerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentHandlerSpec.scala @@ -24,6 +24,7 @@ import fr.acinq.eclair.TestConstants.Alice import fr.acinq.eclair.channel.{CMD_FAIL_HTLC, CMD_FULFILL_HTLC} import fr.acinq.eclair.db.IncomingPaymentStatus import fr.acinq.eclair.payment.PaymentLifecycle.ReceivePayment +import fr.acinq.eclair.payment.PaymentReceived.PartialPayment import fr.acinq.eclair.payment.PaymentRequest.ExtraHop import fr.acinq.eclair.wire.{IncorrectOrUnknownPaymentDetails, UpdateAddHtlc} import fr.acinq.eclair.{CltvExpiryDelta, LongToBtcAmount, ShortChannelId, TestConstants, randomKey} @@ -62,7 +63,7 @@ class PaymentHandlerSpec extends TestKit(ActorSystem("test")) with FunSuiteLike sender.expectMsgType[CMD_FULFILL_HTLC] val paymentRelayed = eventListener.expectMsgType[PaymentReceived] - assert(paymentRelayed.copy(timestamp = 0) === PaymentReceived(amountMsat, add.paymentHash, timestamp = 0)) + assert(paymentRelayed.copy(parts = paymentRelayed.parts.map(_.copy(timestamp = 0))) === PaymentReceived(add.paymentHash, PartialPayment(amountMsat, add.channelId, timestamp = 0) :: Nil)) val received = nodeParams.db.payments.getIncomingPayment(pr.paymentHash) assert(received.isDefined && received.get.status.isInstanceOf[IncomingPaymentStatus.Received]) assert(received.get.status.asInstanceOf[IncomingPaymentStatus.Received].copy(receivedAt = 0) === IncomingPaymentStatus.Received(amountMsat, 0)) @@ -77,7 +78,7 @@ class PaymentHandlerSpec extends TestKit(ActorSystem("test")) with FunSuiteLike sender.send(handler, add) sender.expectMsgType[CMD_FULFILL_HTLC] val paymentRelayed = eventListener.expectMsgType[PaymentReceived] - assert(paymentRelayed.copy(timestamp = 0) === PaymentReceived(amountMsat, add.paymentHash, timestamp = 0)) + assert(paymentRelayed.copy(parts = paymentRelayed.parts.map(_.copy(timestamp = 0))) === PaymentReceived(add.paymentHash, PartialPayment(amountMsat, add.channelId, timestamp = 0) :: Nil)) val received = nodeParams.db.payments.getIncomingPayment(pr.paymentHash) assert(received.isDefined && received.get.status.isInstanceOf[IncomingPaymentStatus.Received]) assert(received.get.status.asInstanceOf[IncomingPaymentStatus.Received].copy(receivedAt = 0) === IncomingPaymentStatus.Received(amountMsat, 0)) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala index 283745977c..66ee7fecc4 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala @@ -32,6 +32,7 @@ import fr.acinq.eclair.db.{OutgoingPayment, OutgoingPaymentStatus} import fr.acinq.eclair.io.Peer.PeerRoutingMessage import fr.acinq.eclair.payment.PaymentInitiator.SendPaymentRequest import fr.acinq.eclair.payment.PaymentLifecycle._ +import fr.acinq.eclair.payment.PaymentSent.PartialPayment import fr.acinq.eclair.router.Announcements.{makeChannelUpdate, makeNodeAnnouncement} import fr.acinq.eclair.router._ import fr.acinq.eclair.transactions.Scripts @@ -74,7 +75,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val Transition(_, WAITING_FOR_ROUTE, WAITING_FOR_PAYMENT_COMPLETE) = monitor.expectMsgClass(classOf[Transition[_]]) awaitCond(paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.Pending)) val Some(outgoing) = paymentDb.getOutgoingPayment(id) - assert(outgoing.copy(createdAt = 0) === OutgoingPayment(id, None, Some(defaultExternalId), defaultPaymentHash, defaultAmountMsat, d, 0, None, OutgoingPaymentStatus.Pending)) + assert(outgoing.copy(createdAt = 0) === OutgoingPayment(id, id, Some(defaultExternalId), defaultPaymentHash, defaultAmountMsat, d, 0, None, OutgoingPaymentStatus.Pending)) sender.send(paymentFSM, UpdateFulfillHtlc(ByteVector32.Zeroes, 0, defaultPaymentHash)) sender.expectMsgType[PaymentSent] @@ -411,7 +412,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val Transition(_, WAITING_FOR_ROUTE, WAITING_FOR_PAYMENT_COMPLETE) = monitor.expectMsgClass(classOf[Transition[_]]) awaitCond(paymentDb.getOutgoingPayment(id).exists(_.status === OutgoingPaymentStatus.Pending)) val Some(outgoing) = paymentDb.getOutgoingPayment(id) - assert(outgoing.copy(createdAt = 0) === OutgoingPayment(id, None, Some(defaultExternalId), paymentHash, defaultAmountMsat, d, 0, None, OutgoingPaymentStatus.Pending)) + assert(outgoing.copy(createdAt = 0) === OutgoingPayment(id, id, Some(defaultExternalId), paymentHash, defaultAmountMsat, d, 0, None, OutgoingPaymentStatus.Pending)) sender.send(paymentFSM, UpdateFulfillHtlc(ByteVector32.Zeroes, 0, paymentPreimage)) val ps = eventListener.expectMsgType[PaymentSent] @@ -468,7 +469,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { sender.send(paymentFSM, UpdateFulfillHtlc(ByteVector32.Zeroes, 0, defaultPaymentHash)) val paymentOK = sender.expectMsgType[PaymentSent] - val PaymentSent(_, request.finalPayload.amount, fee, request.paymentHash, paymentOK.paymentPreimage, _, _) = eventListener.expectMsgType[PaymentSent] + val PaymentSent(_, request.paymentHash, paymentOK.paymentPreimage, PartialPayment(_, request.finalPayload.amount, fee, ByteVector32.Zeroes, _, _) :: Nil) = eventListener.expectMsgType[PaymentSent] // during the route computation the fees were treated as if they were 1msat but when sending the onion we actually put zero // NB: A -> B doesn't pay fees because it's our direct neighbor From a9f34722558ad46290463b2f97f1bfa092d38f5c Mon Sep 17 00:00:00 2001 From: Bastien Teinturier Date: Thu, 19 Sep 2019 19:00:30 +0200 Subject: [PATCH 10/14] fixup! AuditDb stores channelId for sent and received payments. This effectively reverts a previous commit and removes the need for a DB migration. Payments are now always represented as a list of partial payments. --- .../test/scala/fr/acinq/eclair/api/ApiServiceSpec.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/eclair-node/src/test/scala/fr/acinq/eclair/api/ApiServiceSpec.scala b/eclair-node/src/test/scala/fr/acinq/eclair/api/ApiServiceSpec.scala index 01077eeaea..d9be8cba66 100644 --- a/eclair-node/src/test/scala/fr/acinq/eclair/api/ApiServiceSpec.scala +++ b/eclair-node/src/test/scala/fr/acinq/eclair/api/ApiServiceSpec.scala @@ -362,8 +362,8 @@ class ApiServiceSpec extends FunSuite with ScalatestRouteTest with IdiomaticMock system.eventStream.publish(pf) wsClient.expectMessage(expectedSerializedPf) - val ps = PaymentSent(fixedUUID, amount = 21 msat, feesPaid = 1 msat, paymentHash = ByteVector32.Zeroes, paymentPreimage = ByteVector32.One, route = Nil, timestamp = 1553784337711L) - val expectedSerializedPs = """{"type":"payment-sent","id":"487da196-a4dc-4b1e-92b4-3e5e905e9f3f","amount":21,"feesPaid":1,"paymentHash":"0000000000000000000000000000000000000000000000000000000000000000","paymentPreimage":"0100000000000000000000000000000000000000000000000000000000000000","route":[],"timestamp":1553784337711}""" + val ps = PaymentSent(fixedUUID, ByteVector32.Zeroes, ByteVector32.One, Seq(PaymentSent.PartialPayment(fixedUUID, 21 msat, 1 msat, ByteVector32.Zeroes, Nil, 1553784337711L))) + val expectedSerializedPs = """{"type":"payment-sent","id":"487da196-a4dc-4b1e-92b4-3e5e905e9f3f","paymentHash":"0000000000000000000000000000000000000000000000000000000000000000","paymentPreimage":"0100000000000000000000000000000000000000000000000000000000000000","parts":[{"id":"487da196-a4dc-4b1e-92b4-3e5e905e9f3f","amount":21,"feesPaid":1,"toChannelId":"0000000000000000000000000000000000000000000000000000000000000000","route":[],"timestamp":1553784337711}]}""" serialization.write(ps)(mockService.formatsWithTypeHint) === expectedSerializedPs system.eventStream.publish(ps) wsClient.expectMessage(expectedSerializedPs) @@ -374,8 +374,8 @@ class ApiServiceSpec extends FunSuite with ScalatestRouteTest with IdiomaticMock system.eventStream.publish(prel) wsClient.expectMessage(expectedSerializedPrel) - val precv = PaymentReceived(amount = 21 msat, paymentHash = ByteVector32.Zeroes, timestamp = 1553784963659L) - val expectedSerializedPrecv = """{"type":"payment-received","amount":21,"paymentHash":"0000000000000000000000000000000000000000000000000000000000000000","timestamp":1553784963659}""" + val precv = PaymentReceived(ByteVector32.Zeroes, Seq(PaymentReceived.PartialPayment(21 msat, ByteVector32.Zeroes, 1553784963659L))) + val expectedSerializedPrecv = """{"type":"payment-received","paymentHash":"0000000000000000000000000000000000000000000000000000000000000000","parts":[{"amount":21,"fromChannelId":"0000000000000000000000000000000000000000000000000000000000000000","timestamp":1553784963659}]}""" serialization.write(precv)(mockService.formatsWithTypeHint) === expectedSerializedPrecv system.eventStream.publish(precv) wsClient.expectMessage(expectedSerializedPrecv) From cb83a55e94cb87a322775c6ec0758c2af68d7dce Mon Sep 17 00:00:00 2001 From: Bastien Teinturier Date: Fri, 20 Sep 2019 10:32:03 +0200 Subject: [PATCH 11/14] Rename disableAutoCommit to inTransaction --- .../electrum/db/sqlite/SqliteWalletDb.scala | 2 +- .../fr/acinq/eclair/db/sqlite/SqliteAuditDb.scala | 2 +- .../acinq/eclair/db/sqlite/SqliteChannelsDb.scala | 2 +- .../acinq/eclair/db/sqlite/SqliteNetworkDb.scala | 4 ++-- .../acinq/eclair/db/sqlite/SqlitePaymentsDb.scala | 2 +- .../fr/acinq/eclair/db/sqlite/SqlitePeersDb.scala | 2 +- .../eclair/db/sqlite/SqlitePendingRelayDb.scala | 2 +- .../fr/acinq/eclair/db/sqlite/SqliteUtils.scala | 14 +++++++------- .../scala/fr/acinq/eclair/db/SqliteUtilsSpec.scala | 4 ++-- 9 files changed, 17 insertions(+), 17 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/blockchain/electrum/db/sqlite/SqliteWalletDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/blockchain/electrum/db/sqlite/SqliteWalletDb.scala index f817dc3e21..e0611d517d 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/blockchain/electrum/db/sqlite/SqliteWalletDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/blockchain/electrum/db/sqlite/SqliteWalletDb.scala @@ -46,7 +46,7 @@ class SqliteWalletDb(sqlite: Connection) extends WalletDb { } override def addHeaders(startHeight: Int, headers: Seq[BlockHeader]): Unit = { - using(sqlite.prepareStatement("INSERT OR IGNORE INTO headers VALUES (?, ?, ?)"), disableAutoCommit = true) { statement => + using(sqlite.prepareStatement("INSERT OR IGNORE INTO headers VALUES (?, ?, ?)"), inTransaction = true) { statement => var height = startHeight headers.foreach(header => { statement.setInt(1, height) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteAuditDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteAuditDb.scala index c99789ab84..662d1227c9 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteAuditDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteAuditDb.scala @@ -39,7 +39,7 @@ class SqliteAuditDb(sqlite: Connection) extends AuditDb with Logging { val DB_NAME = "audit" val CURRENT_VERSION = 3 - using(sqlite.createStatement(), disableAutoCommit = true) { statement => + using(sqlite.createStatement(), inTransaction = true) { statement => def migration12(statement: Statement) = { statement.executeUpdate(s"ALTER TABLE sent ADD id BLOB DEFAULT '${ChannelCodecs.UNKNOWN_UUID.toString}' NOT NULL") diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteChannelsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteChannelsDb.scala index 4144916a21..1641a73790 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteChannelsDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteChannelsDb.scala @@ -42,7 +42,7 @@ class SqliteChannelsDb(sqlite: Connection) extends ChannelsDb with Logging { statement.execute("PRAGMA foreign_keys = ON") } - using(sqlite.createStatement(), disableAutoCommit = true) { statement => + using(sqlite.createStatement(), inTransaction = true) { statement => def migration12(statement: Statement) = { statement.executeUpdate("ALTER TABLE local_channels ADD COLUMN is_closed BOOLEAN NOT NULL DEFAULT 0") diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteNetworkDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteNetworkDb.scala index 55a0ffb4e8..0295d55b39 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteNetworkDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteNetworkDb.scala @@ -36,7 +36,7 @@ class SqliteNetworkDb(sqlite: Connection) extends NetworkDb with Logging { val DB_NAME = "network" val CURRENT_VERSION = 2 - using(sqlite.createStatement(), disableAutoCommit = true) { statement => + using(sqlite.createStatement(), inTransaction = true) { statement => getVersion(statement, DB_NAME, CURRENT_VERSION) match { case 1 => // channel_update are cheap to retrieve, so let's just wipe them out and they'll get resynced @@ -142,7 +142,7 @@ class SqliteNetworkDb(sqlite: Connection) extends NetworkDb with Logging { } override def addToPruned(shortChannelIds: Iterable[ShortChannelId]): Unit = { - using(sqlite.prepareStatement("INSERT OR IGNORE INTO pruned VALUES (?)"), disableAutoCommit = true) { statement => + using(sqlite.prepareStatement("INSERT OR IGNORE INTO pruned VALUES (?)"), inTransaction = true) { statement => shortChannelIds.foreach(shortChannelId => { statement.setLong(1, shortChannelId.toLong) statement.addBatch() diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala index 32304ee9b1..a196285941 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala @@ -48,7 +48,7 @@ class SqlitePaymentsDb(sqlite: Connection) extends PaymentsDb with Logging { private val paymentFailuresCodec = discriminated[List[FailureSummary]].by(byte) .typecase(0x01, listOfN(uint8, failureSummaryCodec)) - using(sqlite.createStatement(), disableAutoCommit = true) { statement => + using(sqlite.createStatement(), inTransaction = true) { statement => def migration12(statement: Statement) = { // Version 2 is "backwards compatible" in the sense that it uses separate tables from version 1 (which used a single "payments" table). diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePeersDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePeersDb.scala index 98df768201..11725d94fb 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePeersDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePeersDb.scala @@ -32,7 +32,7 @@ import SqliteUtils.ExtendedResultSet._ val DB_NAME = "peers" val CURRENT_VERSION = 1 - using(sqlite.createStatement(), disableAutoCommit = true) { statement => + using(sqlite.createStatement(), inTransaction = true) { statement => require(getVersion(statement, DB_NAME, CURRENT_VERSION) == CURRENT_VERSION, s"incompatible version of $DB_NAME DB found") // there is only one version currently deployed statement.executeUpdate("CREATE TABLE IF NOT EXISTS peers (node_id BLOB NOT NULL PRIMARY KEY, data BLOB NOT NULL)") } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePendingRelayDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePendingRelayDb.scala index a1aec7a5cc..888e7388fd 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePendingRelayDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePendingRelayDb.scala @@ -33,7 +33,7 @@ class SqlitePendingRelayDb(sqlite: Connection) extends PendingRelayDb { val DB_NAME = "pending_relay" val CURRENT_VERSION = 1 - using(sqlite.createStatement(), disableAutoCommit = true) { statement => + using(sqlite.createStatement(), inTransaction = true) { statement => require(getVersion(statement, DB_NAME, CURRENT_VERSION) == CURRENT_VERSION, s"incompatible version of $DB_NAME DB found") // there is only one version currently deployed // note: should we use a foreign key to local_channels table here? statement.executeUpdate("CREATE TABLE IF NOT EXISTS pending_relay (channel_id BLOB NOT NULL, htlc_id INTEGER NOT NULL, data BLOB NOT NULL, PRIMARY KEY(channel_id, htlc_id))") diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteUtils.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteUtils.scala index 2f63cf969a..4b5b957e1a 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteUtils.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteUtils.scala @@ -27,22 +27,22 @@ import scala.collection.immutable.Queue object SqliteUtils { /** - * This helper makes sure statements are closed. + * This helper makes sure statements are correctly closed. * - * @param disableAutoCommit if set to true, all updates in the block will be run in a transaction. + * @param inTransaction if set to true, all updates in the block will be run in a transaction. */ - def using[T <: Statement, U](statement: T, disableAutoCommit: Boolean = false)(block: T => U): U = { + def using[T <: Statement, U](statement: T, inTransaction: Boolean = false)(block: T => U): U = { try { - if (disableAutoCommit) statement.getConnection.setAutoCommit(false) + if (inTransaction) statement.getConnection.setAutoCommit(false) val res = block(statement) - if (disableAutoCommit) statement.getConnection.commit() + if (inTransaction) statement.getConnection.commit() res } catch { case t: Exception => - if (disableAutoCommit) statement.getConnection.rollback() + if (inTransaction) statement.getConnection.rollback() throw t } finally { - if (disableAutoCommit) statement.getConnection.setAutoCommit(true) + if (inTransaction) statement.getConnection.setAutoCommit(true) if (statement != null) statement.close() } } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteUtilsSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteUtilsSpec.scala index da83bbca3e..8bafb8fcb1 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteUtilsSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteUtilsSpec.scala @@ -41,7 +41,7 @@ class SqliteUtilsSpec extends FunSuite { assert(!results.next()) } - assertThrows[SQLiteException](using(conn.createStatement(), disableAutoCommit = true) { statement => + assertThrows[SQLiteException](using(conn.createStatement(), inTransaction = true) { statement => statement.executeUpdate("INSERT INTO utils_test VALUES (3, 3)") statement.executeUpdate("INSERT INTO utils_test VALUES (1, 3)") // should throw (primary key violation) }) @@ -55,7 +55,7 @@ class SqliteUtilsSpec extends FunSuite { assert(!results.next()) } - using(conn.createStatement(), disableAutoCommit = true) { statement => + using(conn.createStatement(), inTransaction = true) { statement => statement.executeUpdate("INSERT INTO utils_test VALUES (3, 3)") statement.executeUpdate("INSERT INTO utils_test VALUES (4, 4)") } From 27d9a04e30a62af2f57e08fea5d36b547d73699a Mon Sep 17 00:00:00 2001 From: Bastien Teinturier Date: Fri, 20 Sep 2019 11:56:46 +0200 Subject: [PATCH 12/14] Add comments to PaymentEvents. Make route optional. --- .../fr/acinq/eclair/db/sqlite/SqliteAuditDb.scala | 2 +- .../acinq/eclair/db/sqlite/SqlitePaymentsDb.scala | 2 +- .../fr/acinq/eclair/payment/PaymentEvents.scala | 12 +++++++----- .../fr/acinq/eclair/payment/PaymentLifecycle.scala | 2 +- .../scala/fr/acinq/eclair/payment/Relayer.scala | 2 +- .../fr/acinq/eclair/db/SqliteAuditDbSpec.scala | 14 +++++++------- .../fr/acinq/eclair/db/SqlitePaymentsDbSpec.scala | 8 ++++---- .../acinq/eclair/integration/IntegrationSpec.scala | 4 +--- 8 files changed, 23 insertions(+), 23 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteAuditDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteAuditDb.scala index 662d1227c9..ad49eca9e4 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteAuditDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteAuditDb.scala @@ -183,7 +183,7 @@ class SqliteAuditDb(sqlite: Connection) extends AuditDb with Logging { MilliSatoshi(rs.getLong("amount_msat")), MilliSatoshi(rs.getLong("fees_msat")), rs.getByteVector32("to_channel_id"), - Nil, // we don't store the route + None, // we don't store the route rs.getLong("timestamp")))) } q diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala index a196285941..0cdb65c15a 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala @@ -125,7 +125,7 @@ class SqlitePaymentsDb(sqlite: Connection) extends PaymentsDb with Logging { statement.setLong(1, p.timestamp) statement.setBytes(2, paymentResult.paymentPreimage.toArray) statement.setLong(3, p.feesPaid.toLong) - statement.setBytes(4, paymentRouteCodec.encode(p.route.map(h => HopSummary(h)).toList).require.toByteArray) + statement.setBytes(4, paymentRouteCodec.encode(p.route.getOrElse(Nil).map(h => HopSummary(h)).toList).require.toByteArray) statement.setString(5, p.id.toString) statement.addBatch() }) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentEvents.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentEvents.scala index 340d38475b..8dafe453d7 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentEvents.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentEvents.scala @@ -35,15 +35,17 @@ sealed trait PaymentEvent { } case class PaymentSent(id: UUID, paymentHash: ByteVector32, paymentPreimage: ByteVector32, parts: Seq[PaymentSent.PartialPayment]) extends PaymentEvent { - require(parts.nonEmpty, "sent payment is empty") + require(parts.nonEmpty, "must have at least one subpayment") val amount: MilliSatoshi = parts.map(_.amount).sum val feesPaid: MilliSatoshi = parts.map(_.feesPaid).sum - val timestamp: Long = parts.map(_.timestamp).min + val timestamp: Long = parts.map(_.timestamp).min // we use min here because we receive the proof of payment as soon as the first partial payment is fulfilled } object PaymentSent { - case class PartialPayment(id: UUID, amount: MilliSatoshi, feesPaid: MilliSatoshi, toChannelId: ByteVector32, route: Seq[Hop], timestamp: Long = Platform.currentTime) + case class PartialPayment(id: UUID, amount: MilliSatoshi, feesPaid: MilliSatoshi, toChannelId: ByteVector32, route: Option[Seq[Hop]], timestamp: Long = Platform.currentTime) { + require(route.isEmpty || route.get.nonEmpty, "route must be None or contain at least one hop") + } } @@ -52,9 +54,9 @@ case class PaymentFailed(id: UUID, paymentHash: ByteVector32, failures: Seq[Paym case class PaymentRelayed(amountIn: MilliSatoshi, amountOut: MilliSatoshi, paymentHash: ByteVector32, fromChannelId: ByteVector32, toChannelId: ByteVector32, timestamp: Long = Platform.currentTime) extends PaymentEvent case class PaymentReceived(paymentHash: ByteVector32, parts: Seq[PaymentReceived.PartialPayment]) extends PaymentEvent { - require(parts.nonEmpty, "received payment is empty") + require(parts.nonEmpty, "must have at least one subpayment") val amount: MilliSatoshi = parts.map(_.amount).sum - val timestamp: Long = parts.map(_.timestamp).max + val timestamp: Long = parts.map(_.timestamp).max // we use max here because we fulfill the payment only once we received all the parts } object PaymentReceived { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentLifecycle.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentLifecycle.scala index 51f5cbeb21..68b5ecd4b5 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentLifecycle.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentLifecycle.scala @@ -78,7 +78,7 @@ class PaymentLifecycle(nodeParams: NodeParams, progressHandler: PaymentProgressH case Event("ok", _) => stay case Event(fulfill: UpdateFulfillHtlc, WaitingForComplete(s, c, cmd, _, _, _, _, route)) => - val p = PartialPayment(id, c.finalPayload.amount, cmd.amount - c.finalPayload.amount, fulfill.channelId, route) + val p = PartialPayment(id, c.finalPayload.amount, cmd.amount - c.finalPayload.amount, fulfill.channelId, Some(route)) progressHandler.onSuccess(s, PaymentSent(id, c.paymentHash, fulfill.paymentPreimage, p :: Nil))(context) stop(FSM.Normal) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Relayer.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Relayer.scala index 0b63d77285..bb32da3e73 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Relayer.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Relayer.scala @@ -161,7 +161,7 @@ class Relayer(nodeParams: NodeParams, register: ActorRef, paymentHandler: ActorR // we sent the payment, but we probably restarted and the reference to the original sender was lost, // we publish the success on the event stream and update the status in paymentDb val feesPaid = 0.msat // fees are unknown since we lost the reference to the payment - val result = PaymentSent(id, add.paymentHash, fulfill.paymentPreimage, Seq(PaymentSent.PartialPayment(id, add.amountMsat, feesPaid, add.channelId, Nil))) + val result = PaymentSent(id, add.paymentHash, fulfill.paymentPreimage, Seq(PaymentSent.PartialPayment(id, add.amountMsat, feesPaid, add.channelId, None))) nodeParams.db.payments.updateOutgoingPayment(result) context.system.eventStream.publish(result) case Local(_, Some(sender)) => diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteAuditDbSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteAuditDbSpec.scala index 2dc8c16625..7eb9b8d1b5 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteAuditDbSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteAuditDbSpec.scala @@ -44,16 +44,16 @@ class SqliteAuditDbSpec extends FunSuite { val sqlite = TestConstants.sqliteInMemory() val db = new SqliteAuditDb(sqlite) - val e1 = PaymentSent(ChannelCodecs.UNKNOWN_UUID, randomBytes32, randomBytes32, PaymentSent.PartialPayment(ChannelCodecs.UNKNOWN_UUID, 42000 msat, 1000 msat, randomBytes32, Nil) :: Nil) + val e1 = PaymentSent(ChannelCodecs.UNKNOWN_UUID, randomBytes32, randomBytes32, PaymentSent.PartialPayment(ChannelCodecs.UNKNOWN_UUID, 42000 msat, 1000 msat, randomBytes32, None) :: Nil) val pp2a = PaymentReceived.PartialPayment(42000 msat, randomBytes32) val pp2b = PaymentReceived.PartialPayment(42100 msat, randomBytes32) val e2 = PaymentReceived(randomBytes32, pp2a :: pp2b :: Nil) val e3 = PaymentRelayed(42000 msat, 1000 msat, randomBytes32, randomBytes32, randomBytes32) val e4 = NetworkFeePaid(null, randomKey.publicKey, randomBytes32, Transaction(0, Seq.empty, Seq.empty, 0), 42 sat, "mutual") - val pp5a = PaymentSent.PartialPayment(UUID.randomUUID(), 42000 msat, 1000 msat, randomBytes32, Nil, timestamp = 0) - val pp5b = PaymentSent.PartialPayment(UUID.randomUUID(), 42100 msat, 900 msat, randomBytes32, Nil, timestamp = 1) + val pp5a = PaymentSent.PartialPayment(UUID.randomUUID(), 42000 msat, 1000 msat, randomBytes32, None, timestamp = 0) + val pp5b = PaymentSent.PartialPayment(UUID.randomUUID(), 42100 msat, 900 msat, randomBytes32, None, timestamp = 1) val e5 = PaymentSent(ChannelCodecs.UNKNOWN_UUID, randomBytes32, randomBytes32, pp5a :: pp5b :: Nil) - val pp6 = PaymentSent.PartialPayment(UUID.randomUUID(), 42000 msat, 1000 msat, randomBytes32, Nil, timestamp = (Platform.currentTime.milliseconds + 10.minutes).toMillis) + val pp6 = PaymentSent.PartialPayment(UUID.randomUUID(), 42000 msat, 1000 msat, randomBytes32, None, timestamp = (Platform.currentTime.milliseconds + 10.minutes).toMillis) val e6 = PaymentSent(ChannelCodecs.UNKNOWN_UUID, randomBytes32, randomBytes32, pp6 :: Nil) val e7 = AvailableBalanceChanged(null, randomBytes32, ShortChannelId(500000, 42, 1), 456123000 msat, ChannelCodecsSpec.commitments) val e8 = ChannelLifecycleEvent(randomBytes32, randomKey.publicKey, 456123000 sat, isFunder = true, isPrivate = false, "mutual") @@ -134,9 +134,9 @@ class SqliteAuditDbSpec extends FunSuite { assert(getVersion(statement, "audit", 3) == 1) // we expect version 1 } - val ps = PaymentSent(UUID.randomUUID(), randomBytes32, randomBytes32, PaymentSent.PartialPayment(UUID.randomUUID(), 42000 msat, 1000 msat, randomBytes32, Nil) :: Nil) - val pp1 = PaymentSent.PartialPayment(UUID.randomUUID(), 42001 msat, 1001 msat, randomBytes32, Nil) - val pp2 = PaymentSent.PartialPayment(UUID.randomUUID(), 42002 msat, 1002 msat, randomBytes32, Nil) + val ps = PaymentSent(UUID.randomUUID(), randomBytes32, randomBytes32, PaymentSent.PartialPayment(UUID.randomUUID(), 42000 msat, 1000 msat, randomBytes32, None) :: Nil) + val pp1 = PaymentSent.PartialPayment(UUID.randomUUID(), 42001 msat, 1001 msat, randomBytes32, None) + val pp2 = PaymentSent.PartialPayment(UUID.randomUUID(), 42002 msat, 1002 msat, randomBytes32, None) val ps1 = PaymentSent(UUID.randomUUID(), randomBytes32, randomBytes32, pp1 :: pp2 :: Nil) val e1 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, LocalError(new RuntimeException("oops")), isFatal = true) val e2 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, RemoteError(wire.Error(randomBytes32, "remote oops")), isFatal = true) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/SqlitePaymentsDbSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/SqlitePaymentsDbSpec.scala index 2159fb6036..f70d18aaf0 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/db/SqlitePaymentsDbSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/SqlitePaymentsDbSpec.scala @@ -207,7 +207,7 @@ class SqlitePaymentsDbSpec extends FunSuite { val ps6 = OutgoingPayment(UUID.randomUUID(), UUID.randomUUID(), Some("3"), randomBytes32, 789 msat, bob, 1250, None, OutgoingPaymentStatus.Failed(Nil, 1300)) postMigrationDb.addOutgoingPayment(ps4) postMigrationDb.addOutgoingPayment(ps5.copy(status = OutgoingPaymentStatus.Pending)) - postMigrationDb.updateOutgoingPayment(PaymentSent(ps5.parentId, ps5.paymentHash, preimage1, Seq(PaymentSent.PartialPayment(ps5.id, ps5.amount, 42 msat, randomBytes32, Nil, 1180)))) + postMigrationDb.updateOutgoingPayment(PaymentSent(ps5.parentId, ps5.paymentHash, preimage1, Seq(PaymentSent.PartialPayment(ps5.id, ps5.amount, 42 msat, randomBytes32, None, 1180)))) postMigrationDb.addOutgoingPayment(ps6.copy(status = OutgoingPaymentStatus.Pending)) postMigrationDb.updateOutgoingPayment(PaymentFailed(ps6.id, ps6.paymentHash, Nil, 1300)) @@ -305,11 +305,11 @@ class SqlitePaymentsDbSpec extends FunSuite { assert(db.getOutgoingPayment(s4.id) === Some(ss4)) // can't update again once it's in a final state - assertThrows[IllegalArgumentException](db.updateOutgoingPayment(PaymentSent(parentId, s3.paymentHash, preimage1, Seq(PaymentSent.PartialPayment(s3.id, s3.amount, 42 msat, randomBytes32, Nil))))) + assertThrows[IllegalArgumentException](db.updateOutgoingPayment(PaymentSent(parentId, s3.paymentHash, preimage1, Seq(PaymentSent.PartialPayment(s3.id, s3.amount, 42 msat, randomBytes32, None))))) val paymentSent = PaymentSent(parentId, paymentHash1, preimage1, Seq( - PaymentSent.PartialPayment(s1.id, s1.amount, 15 msat, randomBytes32, Nil, 400), - PaymentSent.PartialPayment(s2.id, s2.amount, 20 msat, randomBytes32, Seq(hop_ab, hop_bc), 410) + PaymentSent.PartialPayment(s1.id, s1.amount, 15 msat, randomBytes32, None, 400), + PaymentSent.PartialPayment(s2.id, s2.amount, 20 msat, randomBytes32, Some(Seq(hop_ab, hop_bc)), 410) )) val ss1 = s1.copy(status = OutgoingPaymentStatus.Succeeded(preimage1, 15 msat, Nil, 400)) val ss2 = s2.copy(status = OutgoingPaymentStatus.Succeeded(preimage1, 20 msat, Seq(HopSummary(alice, bob, Some(ShortChannelId(42))), HopSummary(bob, carol, Some(ShortChannelId(43)))), 410)) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/integration/IntegrationSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/integration/IntegrationSpec.scala index 8b7ac29b98..4e9fe9ffed 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/integration/IntegrationSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/integration/IntegrationSpec.scala @@ -51,11 +51,9 @@ import org.scalatest.{BeforeAndAfterAll, FunSuiteLike} import scodec.bits.ByteVector import scala.collection.JavaConversions._ -import scala.compat.Platform import scala.concurrent.Await import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.duration._ -import scala.util.Try /** * Created by PM on 15/03/2017. @@ -432,7 +430,7 @@ class IntegrationSpec extends TestKit(ActorSystem("test")) with BitcoindService awaitCond({ sender.expectMsgType[PaymentEvent](10 seconds) match { case PaymentFailed(_, _, failures, _) => failures == Seq.empty // if something went wrong fail with a hint - case PaymentSent(_, _, _, part :: Nil) => part.route.exists(_.nodeId == nodes("G").nodeParams.nodeId) + case PaymentSent(_, _, _, part :: Nil) => part.route.get.exists(_.nodeId == nodes("G").nodeParams.nodeId) case _ => false } }, max = 30 seconds, interval = 10 seconds) From 46878260947cbeb699dd16d49cd7b988b85ca8b3 Mon Sep 17 00:00:00 2001 From: Bastien Teinturier Date: Fri, 20 Sep 2019 14:09:28 +0200 Subject: [PATCH 13/14] fixup! Add comments to PaymentEvents. Make route optional. --- .../src/test/scala/fr/acinq/eclair/api/ApiServiceSpec.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/eclair-node/src/test/scala/fr/acinq/eclair/api/ApiServiceSpec.scala b/eclair-node/src/test/scala/fr/acinq/eclair/api/ApiServiceSpec.scala index d9be8cba66..a80a8e0ab4 100644 --- a/eclair-node/src/test/scala/fr/acinq/eclair/api/ApiServiceSpec.scala +++ b/eclair-node/src/test/scala/fr/acinq/eclair/api/ApiServiceSpec.scala @@ -362,7 +362,7 @@ class ApiServiceSpec extends FunSuite with ScalatestRouteTest with IdiomaticMock system.eventStream.publish(pf) wsClient.expectMessage(expectedSerializedPf) - val ps = PaymentSent(fixedUUID, ByteVector32.Zeroes, ByteVector32.One, Seq(PaymentSent.PartialPayment(fixedUUID, 21 msat, 1 msat, ByteVector32.Zeroes, Nil, 1553784337711L))) + val ps = PaymentSent(fixedUUID, ByteVector32.Zeroes, ByteVector32.One, Seq(PaymentSent.PartialPayment(fixedUUID, 21 msat, 1 msat, ByteVector32.Zeroes, None, 1553784337711L))) val expectedSerializedPs = """{"type":"payment-sent","id":"487da196-a4dc-4b1e-92b4-3e5e905e9f3f","paymentHash":"0000000000000000000000000000000000000000000000000000000000000000","paymentPreimage":"0100000000000000000000000000000000000000000000000000000000000000","parts":[{"id":"487da196-a4dc-4b1e-92b4-3e5e905e9f3f","amount":21,"feesPaid":1,"toChannelId":"0000000000000000000000000000000000000000000000000000000000000000","route":[],"timestamp":1553784337711}]}""" serialization.write(ps)(mockService.formatsWithTypeHint) === expectedSerializedPs system.eventStream.publish(ps) From 7d1fe8f4cbec3ad8c9d8b9a786d6c0e858b9c5e5 Mon Sep 17 00:00:00 2001 From: Bastien Teinturier Date: Fri, 20 Sep 2019 14:16:39 +0200 Subject: [PATCH 14/14] fixup! Add comments to PaymentEvents. Make route optional. --- .../src/test/scala/fr/acinq/eclair/api/ApiServiceSpec.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/eclair-node/src/test/scala/fr/acinq/eclair/api/ApiServiceSpec.scala b/eclair-node/src/test/scala/fr/acinq/eclair/api/ApiServiceSpec.scala index a80a8e0ab4..83fcc74f0b 100644 --- a/eclair-node/src/test/scala/fr/acinq/eclair/api/ApiServiceSpec.scala +++ b/eclair-node/src/test/scala/fr/acinq/eclair/api/ApiServiceSpec.scala @@ -363,7 +363,7 @@ class ApiServiceSpec extends FunSuite with ScalatestRouteTest with IdiomaticMock wsClient.expectMessage(expectedSerializedPf) val ps = PaymentSent(fixedUUID, ByteVector32.Zeroes, ByteVector32.One, Seq(PaymentSent.PartialPayment(fixedUUID, 21 msat, 1 msat, ByteVector32.Zeroes, None, 1553784337711L))) - val expectedSerializedPs = """{"type":"payment-sent","id":"487da196-a4dc-4b1e-92b4-3e5e905e9f3f","paymentHash":"0000000000000000000000000000000000000000000000000000000000000000","paymentPreimage":"0100000000000000000000000000000000000000000000000000000000000000","parts":[{"id":"487da196-a4dc-4b1e-92b4-3e5e905e9f3f","amount":21,"feesPaid":1,"toChannelId":"0000000000000000000000000000000000000000000000000000000000000000","route":[],"timestamp":1553784337711}]}""" + val expectedSerializedPs = """{"type":"payment-sent","id":"487da196-a4dc-4b1e-92b4-3e5e905e9f3f","paymentHash":"0000000000000000000000000000000000000000000000000000000000000000","paymentPreimage":"0100000000000000000000000000000000000000000000000000000000000000","parts":[{"id":"487da196-a4dc-4b1e-92b4-3e5e905e9f3f","amount":21,"feesPaid":1,"toChannelId":"0000000000000000000000000000000000000000000000000000000000000000","timestamp":1553784337711}]}""" serialization.write(ps)(mockService.formatsWithTypeHint) === expectedSerializedPs system.eventStream.publish(ps) wsClient.expectMessage(expectedSerializedPs)