diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala b/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala index fc8e7356c5..c235c9fbec 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala @@ -46,9 +46,13 @@ case class AuditResponse(sent: Seq[PaymentSent], received: Seq[PaymentReceived], case class TimestampQueryFilters(from: Long, to: Long) object TimestampQueryFilters { + /** We use this in the context of timestamp filtering, when we don't need an upper bound. */ + val MaxEpochMilliseconds = Duration.fromNanos(Long.MaxValue).toMillis + def getDefaultTimestampFilters(from_opt: Option[Long], to_opt: Option[Long]) = { - val from = from_opt.getOrElse(0L) - val to = to_opt.getOrElse(MaxEpochSeconds) + // NB: we expect callers to use seconds, but internally we use milli-seconds everywhere. + val from = from_opt.getOrElse(0L).seconds.toMillis + val to = to_opt.map(_.seconds.toMillis).getOrElse(MaxEpochMilliseconds) TimestampQueryFilters(from, to) } @@ -78,13 +82,13 @@ trait Eclair { def receivedInfo(paymentHash: ByteVector32)(implicit timeout: Timeout): Future[Option[IncomingPayment]] - def send(recipientNodeId: PublicKey, amount: MilliSatoshi, paymentHash: ByteVector32, invoice_opt: Option[PaymentRequest] = None, maxAttempts_opt: Option[Int] = None, feeThresholdSat_opt: Option[Satoshi] = None, maxFeePct_opt: Option[Double] = None)(implicit timeout: Timeout): Future[UUID] + def send(externalId_opt: Option[String], recipientNodeId: PublicKey, amount: MilliSatoshi, paymentHash: ByteVector32, invoice_opt: Option[PaymentRequest] = None, maxAttempts_opt: Option[Int] = None, feeThresholdSat_opt: Option[Satoshi] = None, maxFeePct_opt: Option[Double] = None)(implicit timeout: Timeout): Future[UUID] def sentInfo(id: Either[UUID, ByteVector32])(implicit timeout: Timeout): Future[Seq[OutgoingPayment]] def findRoute(targetNodeId: PublicKey, amount: MilliSatoshi, assistedRoutes: Seq[Seq[PaymentRequest.ExtraHop]] = Seq.empty)(implicit timeout: Timeout): Future[RouteResponse] - def sendToRoute(route: Seq[PublicKey], amount: MilliSatoshi, paymentHash: ByteVector32, finalCltvExpiryDelta: CltvExpiryDelta)(implicit timeout: Timeout): Future[UUID] + def sendToRoute(externalId_opt: Option[String], route: Seq[PublicKey], amount: MilliSatoshi, paymentHash: ByteVector32, finalCltvExpiryDelta: CltvExpiryDelta)(implicit timeout: Timeout): Future[UUID] def audit(from_opt: Option[Long], to_opt: Option[Long])(implicit timeout: Timeout): Future[AuditResponse] @@ -113,6 +117,9 @@ class EclairImpl(appKit: Kit) extends Eclair { implicit val ec: ExecutionContext = appKit.system.dispatcher + // We constrain external identifiers. This allows uuid, long and pubkey to be used. + private val externalIdMaxLength = 66 + override def connect(target: Either[NodeURI, PublicKey])(implicit timeout: Timeout): Future[String] = target match { case Left(uri) => (appKit.switchboard ? Peer.Connect(uri)).mapTo[String] case Right(pubKey) => (appKit.switchboard ? Peer.Connect(pubKey, None)).mapTo[String] @@ -186,37 +193,42 @@ class EclairImpl(appKit: Kit) extends Eclair { (appKit.router ? RouteRequest(appKit.nodeParams.nodeId, targetNodeId, amount, assistedRoutes)).mapTo[RouteResponse] } - override def sendToRoute(route: Seq[PublicKey], amount: MilliSatoshi, paymentHash: ByteVector32, finalCltvExpiryDelta: CltvExpiryDelta)(implicit timeout: Timeout): Future[UUID] = { - (appKit.paymentInitiator ? SendPaymentRequest(amount, paymentHash, route.last, 1, finalCltvExpiryDelta, route)).mapTo[UUID] + override def sendToRoute(externalId_opt: Option[String], route: Seq[PublicKey], amount: MilliSatoshi, paymentHash: ByteVector32, finalCltvExpiryDelta: CltvExpiryDelta)(implicit timeout: Timeout): Future[UUID] = { + externalId_opt match { + case Some(externalId) if externalId.length > externalIdMaxLength => Future.failed(new IllegalArgumentException("externalId is too long: cannot exceed 66 characters")) + case _ => (appKit.paymentInitiator ? SendPaymentRequest(amount, paymentHash, route.last, 1, finalCltvExpiryDelta, None, externalId_opt, route)).mapTo[UUID] + } } - override def send(recipientNodeId: PublicKey, amount: MilliSatoshi, paymentHash: ByteVector32, invoice_opt: Option[PaymentRequest], maxAttempts_opt: Option[Int], feeThreshold_opt: Option[Satoshi], maxFeePct_opt: Option[Double])(implicit timeout: Timeout): Future[UUID] = { + override def send(externalId_opt: Option[String], recipientNodeId: PublicKey, amount: MilliSatoshi, paymentHash: ByteVector32, invoice_opt: Option[PaymentRequest], maxAttempts_opt: Option[Int], feeThreshold_opt: Option[Satoshi], maxFeePct_opt: Option[Double])(implicit timeout: Timeout): Future[UUID] = { val maxAttempts = maxAttempts_opt.getOrElse(appKit.nodeParams.maxPaymentAttempts) - val defaultRouteParams = Router.getDefaultRouteParams(appKit.nodeParams.routerConf) val routeParams = defaultRouteParams.copy( maxFeePct = maxFeePct_opt.getOrElse(defaultRouteParams.maxFeePct), maxFeeBase = feeThreshold_opt.map(_.toMilliSatoshi).getOrElse(defaultRouteParams.maxFeeBase) ) - invoice_opt match { - case Some(invoice) if invoice.isExpired => Future.failed(new IllegalArgumentException("invoice has expired")) - case Some(invoice) => - val sendPayment = invoice.minFinalCltvExpiryDelta match { - case Some(minFinalCltvExpiryDelta) => SendPaymentRequest(amount, paymentHash, recipientNodeId, maxAttempts, minFinalCltvExpiryDelta, assistedRoutes = invoice.routingInfo, routeParams = Some(routeParams)) - case None => SendPaymentRequest(amount, paymentHash, recipientNodeId, maxAttempts, assistedRoutes = invoice.routingInfo, routeParams = Some(routeParams)) - } - (appKit.paymentInitiator ? sendPayment).mapTo[UUID] - case None => - val sendPayment = SendPaymentRequest(amount, paymentHash, recipientNodeId, maxAttempts = maxAttempts, routeParams = Some(routeParams)) - (appKit.paymentInitiator ? sendPayment).mapTo[UUID] + externalId_opt match { + case Some(externalId) if externalId.length > externalIdMaxLength => Future.failed(new IllegalArgumentException("externalId is too long: cannot exceed 66 characters")) + case _ => invoice_opt match { + case Some(invoice) if invoice.isExpired => Future.failed(new IllegalArgumentException("invoice has expired")) + case Some(invoice) => + val sendPayment = invoice.minFinalCltvExpiryDelta match { + case Some(minFinalCltvExpiryDelta) => SendPaymentRequest(amount, paymentHash, recipientNodeId, maxAttempts, minFinalCltvExpiryDelta, invoice_opt, externalId_opt, assistedRoutes = invoice.routingInfo, routeParams = Some(routeParams)) + case None => SendPaymentRequest(amount, paymentHash, recipientNodeId, maxAttempts, paymentRequest = invoice_opt, externalId = externalId_opt, assistedRoutes = invoice.routingInfo, routeParams = Some(routeParams)) + } + (appKit.paymentInitiator ? sendPayment).mapTo[UUID] + case None => + val sendPayment = SendPaymentRequest(amount, paymentHash, recipientNodeId, maxAttempts = maxAttempts, externalId = externalId_opt, routeParams = Some(routeParams)) + (appKit.paymentInitiator ? sendPayment).mapTo[UUID] + } } } override def sentInfo(id: Either[UUID, ByteVector32])(implicit timeout: Timeout): Future[Seq[OutgoingPayment]] = Future { id match { case Left(uuid) => appKit.nodeParams.db.payments.getOutgoingPayment(uuid).toSeq - case Right(paymentHash) => appKit.nodeParams.db.payments.getOutgoingPayments(paymentHash) + case Right(paymentHash) => appKit.nodeParams.db.payments.listOutgoingPayments(paymentHash) } } @@ -245,17 +257,17 @@ class EclairImpl(appKit: Kit) extends Eclair { override def allInvoices(from_opt: Option[Long], to_opt: Option[Long])(implicit timeout: Timeout): Future[Seq[PaymentRequest]] = Future { val filter = getDefaultTimestampFilters(from_opt, to_opt) - appKit.nodeParams.db.payments.listPaymentRequests(filter.from, filter.to) + appKit.nodeParams.db.payments.listIncomingPayments(filter.from, filter.to).map(_.paymentRequest) } override def pendingInvoices(from_opt: Option[Long], to_opt: Option[Long])(implicit timeout: Timeout): Future[Seq[PaymentRequest]] = Future { val filter = getDefaultTimestampFilters(from_opt, to_opt) - appKit.nodeParams.db.payments.listPendingPaymentRequests(filter.from, filter.to) + appKit.nodeParams.db.payments.listPendingIncomingPayments(filter.from, filter.to).map(_.paymentRequest) } override def getInvoice(paymentHash: ByteVector32)(implicit timeout: Timeout): Future[Option[PaymentRequest]] = Future { - appKit.nodeParams.db.payments.getPaymentRequest(paymentHash) + appKit.nodeParams.db.payments.getIncomingPayment(paymentHash).map(_.paymentRequest) } /** diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/blockchain/electrum/db/sqlite/SqliteWalletDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/blockchain/electrum/db/sqlite/SqliteWalletDb.scala index f817dc3e21..e0611d517d 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/blockchain/electrum/db/sqlite/SqliteWalletDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/blockchain/electrum/db/sqlite/SqliteWalletDb.scala @@ -46,7 +46,7 @@ class SqliteWalletDb(sqlite: Connection) extends WalletDb { } override def addHeaders(startHeight: Int, headers: Seq[BlockHeader]): Unit = { - using(sqlite.prepareStatement("INSERT OR IGNORE INTO headers VALUES (?, ?, ?)"), disableAutoCommit = true) { statement => + using(sqlite.prepareStatement("INSERT OR IGNORE INTO headers VALUES (?, ?, ?)"), inTransaction = true) { statement => var height = startHeight headers.foreach(header => { statement.setInt(1, height) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/channel/Channel.scala b/eclair-core/src/main/scala/fr/acinq/eclair/channel/Channel.scala index 616c3b2b39..7b726270f9 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/channel/Channel.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/channel/Channel.scala @@ -1732,7 +1732,7 @@ class Channel(val nodeParams: NodeParams, val wallet: EclairWallet, remoteNodeId case _: ChannelException => () case _ => log.error(cause, s"msg=$cmd stateData=$stateData ") } - context.system.eventStream.publish(ChannelErrorOccured(self, Helpers.getChannelId(stateData), remoteNodeId, stateData, LocalError(cause), isFatal = false)) + context.system.eventStream.publish(ChannelErrorOccurred(self, Helpers.getChannelId(stateData), remoteNodeId, stateData, LocalError(cause), isFatal = false)) stay replying Status.Failure(cause) } @@ -1785,7 +1785,7 @@ class Channel(val nodeParams: NodeParams, val wallet: EclairWallet, remoteNodeId val error = Error(d.channelId, exc.getMessage) // NB: we don't use the handleLocalError handler because it would result in the commit tx being published, which we don't want: // implementation *guarantees* that in case of BITCOIN_FUNDING_PUBLISH_FAILED, the funding tx hasn't and will never be published, so we can close the channel right away - context.system.eventStream.publish(ChannelErrorOccured(self, Helpers.getChannelId(stateData), remoteNodeId, stateData, LocalError(exc), isFatal = true)) + context.system.eventStream.publish(ChannelErrorOccurred(self, Helpers.getChannelId(stateData), remoteNodeId, stateData, LocalError(exc), isFatal = true)) goto(CLOSED) sending error } @@ -1793,7 +1793,7 @@ class Channel(val nodeParams: NodeParams, val wallet: EclairWallet, remoteNodeId log.warning(s"funding tx hasn't been confirmed in time, cancelling channel delay=$FUNDING_TIMEOUT_FUNDEE") val exc = FundingTxTimedout(d.channelId) val error = Error(d.channelId, exc.getMessage) - context.system.eventStream.publish(ChannelErrorOccured(self, Helpers.getChannelId(stateData), remoteNodeId, stateData, LocalError(exc), isFatal = true)) + context.system.eventStream.publish(ChannelErrorOccurred(self, Helpers.getChannelId(stateData), remoteNodeId, stateData, LocalError(exc), isFatal = true)) goto(CLOSED) sending error } @@ -1863,7 +1863,7 @@ class Channel(val nodeParams: NodeParams, val wallet: EclairWallet, remoteNodeId case _ => log.error(cause, s"msg=${msg.getOrElse("n/a")} stateData=$stateData ") } val error = Error(Helpers.getChannelId(d), cause.getMessage) - context.system.eventStream.publish(ChannelErrorOccured(self, Helpers.getChannelId(stateData), remoteNodeId, stateData, LocalError(cause), isFatal = true)) + context.system.eventStream.publish(ChannelErrorOccurred(self, Helpers.getChannelId(stateData), remoteNodeId, stateData, LocalError(cause), isFatal = true)) d match { case dd: HasCommitments if Closing.nothingAtStake(dd) => goto(CLOSED) @@ -1879,7 +1879,7 @@ class Channel(val nodeParams: NodeParams, val wallet: EclairWallet, remoteNodeId def handleRemoteError(e: Error, d: Data) = { // see BOLT 1: only print out data verbatim if is composed of printable ASCII characters log.error(s"peer sent error: ascii='${e.toAscii}' bin=${e.data.toHex}") - context.system.eventStream.publish(ChannelErrorOccured(self, Helpers.getChannelId(stateData), remoteNodeId, stateData, RemoteError(e), isFatal = true)) + context.system.eventStream.publish(ChannelErrorOccurred(self, Helpers.getChannelId(stateData), remoteNodeId, stateData, RemoteError(e), isFatal = true)) d match { case _: DATA_CLOSING => stay // nothing to do, there is already a spending tx published diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/channel/ChannelEvents.scala b/eclair-core/src/main/scala/fr/acinq/eclair/channel/ChannelEvents.scala index c3b3d79979..bf6441b0d8 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/channel/ChannelEvents.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/channel/ChannelEvents.scala @@ -48,7 +48,7 @@ case class ChannelSignatureSent(channel: ActorRef, commitments: Commitments) ext case class ChannelSignatureReceived(channel: ActorRef, commitments: Commitments) extends ChannelEvent -case class ChannelErrorOccured(channel: ActorRef, channelId: ByteVector32, remoteNodeId: PublicKey, data: Data, error: ChannelError, isFatal: Boolean) extends ChannelEvent +case class ChannelErrorOccurred(channel: ActorRef, channelId: ByteVector32, remoteNodeId: PublicKey, data: Data, error: ChannelError, isFatal: Boolean) extends ChannelEvent case class NetworkFeePaid(channel: ActorRef, remoteNodeId: PublicKey, channelId: ByteVector32, tx: Transaction, fee: Satoshi, txType: String) extends ChannelEvent diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/AuditDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/AuditDb.scala index 5fccd441b1..c271604353 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/AuditDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/AuditDb.scala @@ -16,8 +16,8 @@ package fr.acinq.eclair.db -import fr.acinq.bitcoin.{ByteVector32, Satoshi} import fr.acinq.bitcoin.Crypto.PublicKey +import fr.acinq.bitcoin.{ByteVector32, Satoshi} import fr.acinq.eclair.channel._ import fr.acinq.eclair.payment.{PaymentReceived, PaymentRelayed, PaymentSent} @@ -35,7 +35,7 @@ trait AuditDb { def add(networkFeePaid: NetworkFeePaid) - def add(channelErrorOccured: ChannelErrorOccured) + def add(channelErrorOccurred: ChannelErrorOccurred) def listSent(from: Long, to: Long): Seq[PaymentSent] @@ -47,7 +47,7 @@ trait AuditDb { def stats: Seq[Stats] - def close: Unit + def close(): Unit } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala index b9f4aa838f..e7113eb37d 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala @@ -19,70 +19,174 @@ package fr.acinq.eclair.db import java.util.UUID import fr.acinq.bitcoin.ByteVector32 -import fr.acinq.eclair.MilliSatoshi -import fr.acinq.eclair.payment.PaymentRequest +import fr.acinq.bitcoin.Crypto.PublicKey +import fr.acinq.eclair.payment._ +import fr.acinq.eclair.router.Hop +import fr.acinq.eclair.{MilliSatoshi, ShortChannelId} + +import scala.compat.Platform trait PaymentsDb { - // creates a record for a non yet finalized outgoing payment - def addOutgoingPayment(outgoingPayment: OutgoingPayment) + /** Create a record for a non yet finalized outgoing payment. */ + def addOutgoingPayment(outgoingPayment: OutgoingPayment): Unit + + /** Update the status of the payment in case of success. */ + def updateOutgoingPayment(paymentResult: PaymentSent): Unit - // updates the status of the payment, if the newStatus is SUCCEEDED you must supply a preimage - def updateOutgoingPayment(id: UUID, newStatus: OutgoingPaymentStatus.Value, preimage: Option[ByteVector32] = None) + /** Update the status of the payment in case of failure. */ + def updateOutgoingPayment(paymentResult: PaymentFailed): Unit + /** Get an outgoing payment attempt. */ def getOutgoingPayment(id: UUID): Option[OutgoingPayment] - // all the outgoing payment (attempts) to pay the given paymentHash - def getOutgoingPayments(paymentHash: ByteVector32): Seq[OutgoingPayment] + /** List all the outgoing payment attempts that are children of the given id. */ + def listOutgoingPayments(parentId: UUID): Seq[OutgoingPayment] - def listOutgoingPayments(): Seq[OutgoingPayment] + /** List all the outgoing payment attempts that tried to pay the given payment hash. */ + def listOutgoingPayments(paymentHash: ByteVector32): Seq[OutgoingPayment] - def addPaymentRequest(pr: PaymentRequest, preimage: ByteVector32) + /** List all the outgoing payment attempts in the given time range (milli-seconds). */ + def listOutgoingPayments(from: Long, to: Long): Seq[OutgoingPayment] - def getPaymentRequest(paymentHash: ByteVector32): Option[PaymentRequest] + /** Add a new expected incoming payment (not yet received). */ + def addIncomingPayment(pr: PaymentRequest, preimage: ByteVector32): Unit - // returns non paid payment request - def getPendingPaymentRequestAndPreimage(paymentHash: ByteVector32): Option[(ByteVector32, PaymentRequest)] + /** + * Mark an incoming payment as received (paid). The received amount may exceed the payment request amount. + * Note that this function assumes that there is a matching payment request in the DB. + */ + def receiveIncomingPayment(paymentHash: ByteVector32, amount: MilliSatoshi, receivedAt: Long = Platform.currentTime): Unit - def listPaymentRequests(from: Long, to: Long): Seq[PaymentRequest] + /** Get information about the incoming payment (paid or not) for the given payment hash, if any. */ + def getIncomingPayment(paymentHash: ByteVector32): Option[IncomingPayment] - // returns non paid, non expired payment requests - def listPendingPaymentRequests(from: Long, to: Long): Seq[PaymentRequest] + /** List all incoming payments (pending, expired and succeeded) in the given time range (milli-seconds). */ + def listIncomingPayments(from: Long, to: Long): Seq[IncomingPayment] - // assumes there is already a payment request for it (the record for the given payment hash) - def addIncomingPayment(payment: IncomingPayment) + /** List all pending (not paid, not expired) incoming payments in the given time range (milli-seconds). */ + def listPendingIncomingPayments(from: Long, to: Long): Seq[IncomingPayment] - def getIncomingPayment(paymentHash: ByteVector32): Option[IncomingPayment] + /** List all expired (not paid) incoming payments in the given time range (milli-seconds). */ + def listExpiredIncomingPayments(from: Long, to: Long): Seq[IncomingPayment] - def listIncomingPayments(): Seq[IncomingPayment] + /** List all received (paid) incoming payments in the given time range (milli-seconds). */ + def listReceivedIncomingPayments(from: Long, to: Long): Seq[IncomingPayment] } /** - * Incoming payment object stored in DB. - * - * @param paymentHash identifier of the payment - * @param amount amount of the payment, in milli-satoshis - * @param receivedAt absolute time in seconds since UNIX epoch when the payment was received. - */ -case class IncomingPayment(paymentHash: ByteVector32, amount: MilliSatoshi, receivedAt: Long) + * An incoming payment received by this node. + * At first it is in a pending state once the payment request has been generated, then will become either a success (if + * we receive a valid HTLC) or a failure (if the payment request expires). + * + * @param paymentRequest Bolt 11 payment request. + * @param paymentPreimage pre-image associated with the payment request's payment_hash. + * @param createdAt absolute time in milli-seconds since UNIX epoch when the payment request was generated. + * @param status current status of the payment. + */ +case class IncomingPayment(paymentRequest: PaymentRequest, + paymentPreimage: ByteVector32, + createdAt: Long, + status: IncomingPaymentStatus) + +sealed trait IncomingPaymentStatus + +object IncomingPaymentStatus { + + /** Payment is pending (waiting to receive). */ + case object Pending extends IncomingPaymentStatus + + /** Payment has expired. */ + case object Expired extends IncomingPaymentStatus + + /** + * Payment has been successfully received. + * + * @param amount amount of the payment received, in milli-satoshis (may exceed the payment request amount). + * @param receivedAt absolute time in milli-seconds since UNIX epoch when the payment was received. + */ + case class Received(amount: MilliSatoshi, receivedAt: Long) extends IncomingPaymentStatus + +} /** - * Sent payment is every payment that is sent by this node, they may not be finalized and - * when is final it can be failed or successful. - * - * @param id internal payment identifier - * @param paymentHash payment_hash - * @param preimage the preimage of the payment_hash, known if the outgoing payment was successful - * @param amount amount of the payment, in milli-satoshis - * @param createdAt absolute time in seconds since UNIX epoch when the payment was created. - * @param completedAt absolute time in seconds since UNIX epoch when the payment succeeded. - * @param status current status of the payment. - */ -case class OutgoingPayment(id: UUID, paymentHash: ByteVector32, preimage:Option[ByteVector32], amount: MilliSatoshi, createdAt: Long, completedAt: Option[Long], status: OutgoingPaymentStatus.Value) - -object OutgoingPaymentStatus extends Enumeration { - val PENDING = Value(1, "PENDING") - val SUCCEEDED = Value(2, "SUCCEEDED") - val FAILED = Value(3, "FAILED") + * An outgoing payment sent by this node. + * At first it is in a pending state, then will become either a success or a failure. + * + * @param id internal payment identifier. + * @param parentId internal identifier of a parent payment, or [[id]] if single-part payment. + * @param externalId external payment identifier: lets lightning applications reconcile payments with their own db. + * @param paymentHash payment_hash. + * @param amount amount of the payment, in milli-satoshis. + * @param targetNodeId node ID of the payment recipient. + * @param createdAt absolute time in milli-seconds since UNIX epoch when the payment was created. + * @param paymentRequest Bolt 11 payment request (if paying from an invoice). + * @param status current status of the payment. + */ +case class OutgoingPayment(id: UUID, + parentId: UUID, + externalId: Option[String], + paymentHash: ByteVector32, + amount: MilliSatoshi, + targetNodeId: PublicKey, + createdAt: Long, + paymentRequest: Option[PaymentRequest], + status: OutgoingPaymentStatus) + +sealed trait OutgoingPaymentStatus + +object OutgoingPaymentStatus { + + /** Payment is pending (waiting for the recipient to release the pre-image). */ + case object Pending extends OutgoingPaymentStatus + + /** + * Payment has been successfully sent and the recipient released the pre-image. + * We now have a valid proof-of-payment. + * + * @param paymentPreimage the preimage of the payment_hash. + * @param feesPaid total amount of fees paid to intermediate routing nodes. + * @param route payment route. + * @param completedAt absolute time in milli-seconds since UNIX epoch when the payment was completed. + */ + case class Succeeded(paymentPreimage: ByteVector32, feesPaid: MilliSatoshi, route: Seq[HopSummary], completedAt: Long) extends OutgoingPaymentStatus + + /** + * Payment has failed and may be retried. + * + * @param failures failed payment attempts. + * @param completedAt absolute time in milli-seconds since UNIX epoch when the payment was completed. + */ + case class Failed(failures: Seq[FailureSummary], completedAt: Long) extends OutgoingPaymentStatus + +} + +/** A minimal representation of a hop in a payment route (suitable to store in a database). */ +case class HopSummary(nodeId: PublicKey, nextNodeId: PublicKey, shortChannelId: Option[ShortChannelId] = None) { + override def toString = shortChannelId match { + case Some(shortChannelId) => s"$nodeId->$nextNodeId ($shortChannelId)" + case None => s"$nodeId->$nextNodeId" + } +} + +object HopSummary { + def apply(h: Hop): HopSummary = HopSummary(h.nodeId, h.nextNodeId, Some(h.lastUpdate.shortChannelId)) +} + +/** A minimal representation of a payment failure (suitable to store in a database). */ +case class FailureSummary(failureType: FailureType.Value, failureMessage: String, failedRoute: List[HopSummary]) + +object FailureType extends Enumeration { + val LOCAL = Value(1, "Local") + val REMOTE = Value(2, "Remote") + val UNREADABLE_REMOTE = Value(3, "UnreadableRemote") +} + +object FailureSummary { + def apply(f: PaymentFailure): FailureSummary = f match { + case LocalFailure(t) => FailureSummary(FailureType.LOCAL, t.getMessage, Nil) + case RemoteFailure(route, e) => FailureSummary(FailureType.REMOTE, e.failureMessage.message, route.map(h => HopSummary(h)).toList) + case UnreadableRemoteFailure(route) => FailureSummary(FailureType.UNREADABLE_REMOTE, "could not decrypt failure onion", route.map(h => HopSummary(h)).toList) + } } \ No newline at end of file diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteAuditDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteAuditDb.scala index c03bec2a84..ad49eca9e4 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteAuditDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteAuditDb.scala @@ -18,17 +18,18 @@ package fr.acinq.eclair.db.sqlite import java.sql.{Connection, Statement} import java.util.UUID + import fr.acinq.bitcoin.Crypto.PublicKey import fr.acinq.bitcoin.Satoshi import fr.acinq.eclair.MilliSatoshi -import fr.acinq.eclair.channel.{AvailableBalanceChanged, Channel, ChannelErrorOccured, NetworkFeePaid} -import fr.acinq.eclair.db.{AuditDb, ChannelLifecycleEvent, NetworkFee, Stats} -import fr.acinq.eclair.payment.{PaymentReceived, PaymentRelayed, PaymentSent} +import fr.acinq.eclair.channel.{AvailableBalanceChanged, Channel, ChannelErrorOccurred, NetworkFeePaid} +import fr.acinq.eclair.db._ +import fr.acinq.eclair.payment._ import fr.acinq.eclair.wire.ChannelCodecs import grizzled.slf4j.Logging + import scala.collection.immutable.Queue import scala.compat.Platform -import concurrent.duration._ class SqliteAuditDb(sqlite: Connection) extends AuditDb with Logging { @@ -38,7 +39,7 @@ class SqliteAuditDb(sqlite: Connection) extends AuditDb with Logging { val DB_NAME = "audit" val CURRENT_VERSION = 3 - using(sqlite.createStatement()) { statement => + using(sqlite.createStatement(), inTransaction = true) { statement => def migration12(statement: Statement) = { statement.executeUpdate(s"ALTER TABLE sent ADD id BLOB DEFAULT '${ChannelCodecs.UNKNOWN_UUID.toString}' NOT NULL") @@ -75,7 +76,6 @@ class SqliteAuditDb(sqlite: Connection) extends AuditDb with Logging { statement.executeUpdate("CREATE INDEX IF NOT EXISTS network_fees_timestamp_idx ON network_fees(timestamp)") statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_events_timestamp_idx ON channel_events(timestamp)") statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_errors_timestamp_idx ON channel_errors(timestamp)") - case unknownVersion => throw new RuntimeException(s"Unknown version of DB $DB_NAME found, version=$unknownVersion") } } @@ -105,24 +105,29 @@ class SqliteAuditDb(sqlite: Connection) extends AuditDb with Logging { override def add(e: PaymentSent): Unit = using(sqlite.prepareStatement("INSERT INTO sent VALUES (?, ?, ?, ?, ?, ?, ?)")) { statement => - statement.setLong(1, e.amount.toLong) - statement.setLong(2, e.feesPaid.toLong) - statement.setBytes(3, e.paymentHash.toArray) - statement.setBytes(4, e.paymentPreimage.toArray) - statement.setBytes(5, e.toChannelId.toArray) - statement.setLong(6, e.timestamp) - statement.setBytes(7, e.id.toString.getBytes) - - statement.executeUpdate() + e.parts.foreach(p => { + statement.setLong(1, p.amount.toLong) + statement.setLong(2, p.feesPaid.toLong) + statement.setBytes(3, e.paymentHash.toArray) + statement.setBytes(4, e.paymentPreimage.toArray) + statement.setBytes(5, p.toChannelId.toArray) + statement.setLong(6, p.timestamp) + statement.setBytes(7, p.id.toString.getBytes) + statement.addBatch() + }) + statement.executeBatch() } override def add(e: PaymentReceived): Unit = using(sqlite.prepareStatement("INSERT INTO received VALUES (?, ?, ?, ?)")) { statement => - statement.setLong(1, e.amount.toLong) - statement.setBytes(2, e.paymentHash.toArray) - statement.setBytes(3, e.fromChannelId.toArray) - statement.setLong(4, e.timestamp) - statement.executeUpdate() + e.parts.foreach(p => { + statement.setLong(1, p.amount.toLong) + statement.setBytes(2, e.paymentHash.toArray) + statement.setBytes(3, p.fromChannelId.toArray) + statement.setLong(4, p.timestamp) + statement.addBatch() + }) + statement.executeBatch() } override def add(e: PaymentRelayed): Unit = @@ -147,7 +152,7 @@ class SqliteAuditDb(sqlite: Connection) extends AuditDb with Logging { statement.executeUpdate() } - override def add(e: ChannelErrorOccured): Unit = + override def add(e: ChannelErrorOccurred): Unit = using(sqlite.prepareStatement("INSERT INTO channel_errors VALUES (?, ?, ?, ?, ?, ?)")) { statement => val (errorName, errorMessage) = e.error match { case Channel.LocalError(t) => (t.getClass.getSimpleName, t.getMessage) @@ -163,44 +168,49 @@ class SqliteAuditDb(sqlite: Connection) extends AuditDb with Logging { } override def listSent(from: Long, to: Long): Seq[PaymentSent] = - using(sqlite.prepareStatement("SELECT * FROM sent WHERE timestamp >= ? AND timestamp < ?")) { statement => - statement.setLong(1, from.seconds.toMillis) - statement.setLong(2, to.seconds.toMillis) + using(sqlite.prepareStatement("SELECT * FROM sent WHERE timestamp >= ? AND timestamp < ? ORDER BY timestamp")) { statement => + statement.setLong(1, from) + statement.setLong(2, to) val rs = statement.executeQuery() var q: Queue[PaymentSent] = Queue() while (rs.next()) { q = q :+ PaymentSent( - id = UUID.fromString(rs.getString("id")), - amount = MilliSatoshi(rs.getLong("amount_msat")), - feesPaid = MilliSatoshi(rs.getLong("fees_msat")), - paymentHash = rs.getByteVector32("payment_hash"), - paymentPreimage = rs.getByteVector32("payment_preimage"), - toChannelId = rs.getByteVector32("to_channel_id"), - timestamp = rs.getLong("timestamp")) + UUID.fromString(rs.getString("id")), + rs.getByteVector32("payment_hash"), + rs.getByteVector32("payment_preimage"), + Seq(PaymentSent.PartialPayment( + UUID.fromString(rs.getString("id")), + MilliSatoshi(rs.getLong("amount_msat")), + MilliSatoshi(rs.getLong("fees_msat")), + rs.getByteVector32("to_channel_id"), + None, // we don't store the route + rs.getLong("timestamp")))) } q } override def listReceived(from: Long, to: Long): Seq[PaymentReceived] = - using(sqlite.prepareStatement("SELECT * FROM received WHERE timestamp >= ? AND timestamp < ?")) { statement => - statement.setLong(1, from.seconds.toMillis) - statement.setLong(2, to.seconds.toMillis) + using(sqlite.prepareStatement("SELECT * FROM received WHERE timestamp >= ? AND timestamp < ? ORDER BY timestamp")) { statement => + statement.setLong(1, from) + statement.setLong(2, to) val rs = statement.executeQuery() var q: Queue[PaymentReceived] = Queue() while (rs.next()) { q = q :+ PaymentReceived( - amount = MilliSatoshi(rs.getLong("amount_msat")), - paymentHash = rs.getByteVector32("payment_hash"), - fromChannelId = rs.getByteVector32("from_channel_id"), - timestamp = rs.getLong("timestamp")) + rs.getByteVector32("payment_hash"), + Seq(PaymentReceived.PartialPayment( + MilliSatoshi(rs.getLong("amount_msat")), + rs.getByteVector32("from_channel_id"), + rs.getLong("timestamp") + ))) } q } override def listRelayed(from: Long, to: Long): Seq[PaymentRelayed] = - using(sqlite.prepareStatement("SELECT * FROM relayed WHERE timestamp >= ? AND timestamp < ?")) { statement => - statement.setLong(1, from.seconds.toMillis) - statement.setLong(2, to.seconds.toMillis) + using(sqlite.prepareStatement("SELECT * FROM relayed WHERE timestamp >= ? AND timestamp < ? ORDER BY timestamp")) { statement => + statement.setLong(1, from) + statement.setLong(2, to) val rs = statement.executeQuery() var q: Queue[PaymentRelayed] = Queue() while (rs.next()) { @@ -216,9 +226,9 @@ class SqliteAuditDb(sqlite: Connection) extends AuditDb with Logging { } override def listNetworkFees(from: Long, to: Long): Seq[NetworkFee] = - using(sqlite.prepareStatement("SELECT * FROM network_fees WHERE timestamp >= ? AND timestamp < ?")) { statement => - statement.setLong(1, from.seconds.toMillis) - statement.setLong(2, to.seconds.toMillis) + using(sqlite.prepareStatement("SELECT * FROM network_fees WHERE timestamp >= ? AND timestamp < ? ORDER BY timestamp")) { statement => + statement.setLong(1, from) + statement.setLong(2, to) val rs = statement.executeQuery() var q: Queue[NetworkFee] = Queue() while (rs.next()) { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteChannelsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteChannelsDb.scala index 9cadd6ab1e..1641a73790 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteChannelsDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteChannelsDb.scala @@ -35,24 +35,31 @@ class SqliteChannelsDb(sqlite: Connection) extends ChannelsDb with Logging { val DB_NAME = "channels" val CURRENT_VERSION = 2 - private def migration12(statement: Statement) = { - statement.executeUpdate("ALTER TABLE local_channels ADD COLUMN is_closed BOOLEAN NOT NULL DEFAULT 0") + // The SQLite documentation states that "It is not possible to enable or disable foreign key constraints in the middle + // of a multi-statement transaction (when SQLite is not in autocommit mode).". + // So we need to set foreign keys before we initialize tables / migrations (which is done inside a transaction). + using(sqlite.createStatement()) { statement => + statement.execute("PRAGMA foreign_keys = ON") } - using(sqlite.createStatement()) { statement => + using(sqlite.createStatement(), inTransaction = true) { statement => + + def migration12(statement: Statement) = { + statement.executeUpdate("ALTER TABLE local_channels ADD COLUMN is_closed BOOLEAN NOT NULL DEFAULT 0") + } + getVersion(statement, DB_NAME, CURRENT_VERSION) match { case 1 => logger.warn(s"migrating db $DB_NAME, found version=1 current=$CURRENT_VERSION") migration12(statement) setVersion(statement, DB_NAME, CURRENT_VERSION) case CURRENT_VERSION => - statement.execute("PRAGMA foreign_keys = ON") statement.executeUpdate("CREATE TABLE IF NOT EXISTS local_channels (channel_id BLOB NOT NULL PRIMARY KEY, data BLOB NOT NULL, is_closed BOOLEAN NOT NULL DEFAULT 0)") statement.executeUpdate("CREATE TABLE IF NOT EXISTS htlc_infos (channel_id BLOB NOT NULL, commitment_number BLOB NOT NULL, payment_hash BLOB NOT NULL, cltv_expiry INTEGER NOT NULL, FOREIGN KEY(channel_id) REFERENCES local_channels(channel_id))") statement.executeUpdate("CREATE INDEX IF NOT EXISTS htlc_infos_idx ON htlc_infos(channel_id, commitment_number)") - case unknownVersion => throw new RuntimeException(s"Unknown version of DB $DB_NAME found, version=$unknownVersion") } + } override def addOrUpdateChannel(state: HasCommitments): Unit = { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteNetworkDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteNetworkDb.scala index cc2e8c1fac..0295d55b39 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteNetworkDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteNetworkDb.scala @@ -36,7 +36,7 @@ class SqliteNetworkDb(sqlite: Connection) extends NetworkDb with Logging { val DB_NAME = "network" val CURRENT_VERSION = 2 - using(sqlite.createStatement()) { statement => + using(sqlite.createStatement(), inTransaction = true) { statement => getVersion(statement, DB_NAME, CURRENT_VERSION) match { case 1 => // channel_update are cheap to retrieve, so let's just wipe them out and they'll get resynced @@ -142,7 +142,7 @@ class SqliteNetworkDb(sqlite: Connection) extends NetworkDb with Logging { } override def addToPruned(shortChannelIds: Iterable[ShortChannelId]): Unit = { - using(sqlite.prepareStatement("INSERT OR IGNORE INTO pruned VALUES (?)"), disableAutoCommit = true) { statement => + using(sqlite.prepareStatement("INSERT OR IGNORE INTO pruned VALUES (?)"), inTransaction = true) { statement => shortChannelIds.foreach(shortChannelId => { statement.setLong(1, shortChannelId.toLong) statement.addBatch() diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala index 127ec6be2e..0cdb65c15a 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala @@ -16,214 +16,303 @@ package fr.acinq.eclair.db.sqlite -import java.sql.Connection +import java.sql.{Connection, ResultSet, Statement} import java.util.UUID + import fr.acinq.bitcoin.ByteVector32 +import fr.acinq.bitcoin.Crypto.{PrivateKey, PublicKey} +import fr.acinq.eclair.MilliSatoshi +import fr.acinq.eclair.db._ import fr.acinq.eclair.db.sqlite.SqliteUtils._ -import fr.acinq.eclair.db.{IncomingPayment, OutgoingPayment, OutgoingPaymentStatus, PaymentsDb} -import fr.acinq.eclair.payment.PaymentRequest +import fr.acinq.eclair.payment.{PaymentFailed, PaymentRequest, PaymentSent} +import fr.acinq.eclair.wire.CommonCodecs import grizzled.slf4j.Logging +import scodec.Attempt +import scodec.codecs._ + import scala.collection.immutable.Queue -import OutgoingPaymentStatus._ -import fr.acinq.eclair.MilliSatoshi -import concurrent.duration._ import scala.compat.Platform +import scala.concurrent.duration._ class SqlitePaymentsDb(sqlite: Connection) extends PaymentsDb with Logging { import SqliteUtils.ExtendedResultSet._ val DB_NAME = "payments" - val CURRENT_VERSION = 2 - - using(sqlite.createStatement()) { statement => - require(getVersion(statement, DB_NAME, CURRENT_VERSION) <= CURRENT_VERSION, s"incompatible version of $DB_NAME DB found") // version 2 is "backward compatible" in the sense that it uses separate tables from version 1. There is no migration though - statement.executeUpdate("CREATE TABLE IF NOT EXISTS received_payments (payment_hash BLOB NOT NULL PRIMARY KEY, preimage BLOB NOT NULL, payment_request TEXT NOT NULL, received_msat INTEGER, created_at INTEGER NOT NULL, expire_at INTEGER, received_at INTEGER)") - statement.executeUpdate("CREATE TABLE IF NOT EXISTS sent_payments (id TEXT NOT NULL PRIMARY KEY, payment_hash BLOB NOT NULL, preimage BLOB, amount_msat INTEGER NOT NULL, created_at INTEGER NOT NULL, completed_at INTEGER, status VARCHAR NOT NULL)") - statement.executeUpdate("CREATE INDEX IF NOT EXISTS payment_hash_idx ON sent_payments(payment_hash)") - setVersion(statement, DB_NAME, CURRENT_VERSION) + val CURRENT_VERSION = 3 + + private val hopSummaryCodec = (("node_id" | CommonCodecs.publicKey) :: ("next_node_id" | CommonCodecs.publicKey) :: ("short_channel_id" | optional(bool, CommonCodecs.shortchannelid))).as[HopSummary] + private val paymentRouteCodec = discriminated[List[HopSummary]].by(byte) + .typecase(0x01, listOfN(uint8, hopSummaryCodec)) + private val failureSummaryCodec = (("type" | enumerated(uint8, FailureType)) :: ("message" | ascii32) :: paymentRouteCodec).as[FailureSummary] + private val paymentFailuresCodec = discriminated[List[FailureSummary]].by(byte) + .typecase(0x01, listOfN(uint8, failureSummaryCodec)) + + using(sqlite.createStatement(), inTransaction = true) { statement => + + def migration12(statement: Statement) = { + // Version 2 is "backwards compatible" in the sense that it uses separate tables from version 1 (which used a single "payments" table). + statement.executeUpdate("CREATE TABLE IF NOT EXISTS received_payments (payment_hash BLOB NOT NULL PRIMARY KEY, preimage BLOB NOT NULL, payment_request TEXT NOT NULL, received_msat INTEGER, created_at INTEGER NOT NULL, expire_at INTEGER, received_at INTEGER)") + statement.executeUpdate("CREATE TABLE IF NOT EXISTS sent_payments (id TEXT NOT NULL PRIMARY KEY, payment_hash BLOB NOT NULL, preimage BLOB, amount_msat INTEGER NOT NULL, created_at INTEGER NOT NULL, completed_at INTEGER, status VARCHAR NOT NULL)") + statement.executeUpdate("CREATE INDEX IF NOT EXISTS payment_hash_idx ON sent_payments(payment_hash)") + } + + def migration23(statement: Statement) = { + // We add many more columns to the sent_payments table. + statement.executeUpdate("DROP index payment_hash_idx") + statement.executeUpdate("ALTER TABLE sent_payments RENAME TO _sent_payments_old") + statement.executeUpdate("CREATE TABLE sent_payments (id TEXT NOT NULL PRIMARY KEY, parent_id TEXT NOT NULL, external_id TEXT, payment_hash BLOB NOT NULL, amount_msat INTEGER NOT NULL, target_node_id BLOB NOT NULL, created_at INTEGER NOT NULL, payment_request TEXT, completed_at INTEGER, payment_preimage BLOB, fees_msat INTEGER, payment_route BLOB, failures BLOB)") + // Old rows will be missing a target node id, so we use an easy-to-spot default value. + val defaultTargetNodeId = PrivateKey(ByteVector32.One).publicKey + statement.executeUpdate(s"INSERT INTO sent_payments (id, parent_id, payment_hash, amount_msat, target_node_id, created_at, completed_at, payment_preimage) SELECT id, id, payment_hash, amount_msat, X'${defaultTargetNodeId.toString}', created_at, completed_at, preimage FROM _sent_payments_old") + statement.executeUpdate("DROP table _sent_payments_old") + + statement.executeUpdate("ALTER TABLE received_payments RENAME TO _received_payments_old") + // We make payment request expiration not null in the received_payments table. + // When it was previously set to NULL the default expiry should apply. + statement.executeUpdate(s"UPDATE _received_payments_old SET expire_at = created_at + ${PaymentRequest.DEFAULT_EXPIRY_SECONDS} WHERE expire_at IS NULL") + statement.executeUpdate("CREATE TABLE received_payments (payment_hash BLOB NOT NULL PRIMARY KEY, payment_preimage BLOB NOT NULL, payment_request TEXT NOT NULL, received_msat INTEGER, created_at INTEGER NOT NULL, expire_at INTEGER NOT NULL, received_at INTEGER)") + statement.executeUpdate("INSERT INTO received_payments (payment_hash, payment_preimage, payment_request, received_msat, created_at, expire_at, received_at) SELECT payment_hash, preimage, payment_request, received_msat, created_at, expire_at, received_at FROM _received_payments_old") + statement.executeUpdate("DROP table _received_payments_old") + + statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_parent_id_idx ON sent_payments(parent_id)") + statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_payment_hash_idx ON sent_payments(payment_hash)") + statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_created_idx ON sent_payments(created_at)") + statement.executeUpdate("CREATE INDEX IF NOT EXISTS received_created_idx ON received_payments(created_at)") + } + + getVersion(statement, DB_NAME, CURRENT_VERSION) match { + case 1 => + logger.warn(s"migrating db $DB_NAME, found version=1 current=$CURRENT_VERSION") + migration12(statement) + migration23(statement) + setVersion(statement, DB_NAME, CURRENT_VERSION) + case 2 => + logger.warn(s"migrating db $DB_NAME, found version=2 current=$CURRENT_VERSION") + migration23(statement) + setVersion(statement, DB_NAME, CURRENT_VERSION) + case CURRENT_VERSION => + statement.executeUpdate("CREATE TABLE IF NOT EXISTS received_payments (payment_hash BLOB NOT NULL PRIMARY KEY, payment_preimage BLOB NOT NULL, payment_request TEXT NOT NULL, received_msat INTEGER, created_at INTEGER NOT NULL, expire_at INTEGER NOT NULL, received_at INTEGER)") + statement.executeUpdate("CREATE TABLE IF NOT EXISTS sent_payments (id TEXT NOT NULL PRIMARY KEY, parent_id TEXT NOT NULL, external_id TEXT, payment_hash BLOB NOT NULL, amount_msat INTEGER NOT NULL, target_node_id BLOB NOT NULL, created_at INTEGER NOT NULL, payment_request TEXT, completed_at INTEGER, payment_preimage BLOB, fees_msat INTEGER, payment_route BLOB, failures BLOB)") + + statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_parent_id_idx ON sent_payments(parent_id)") + statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_payment_hash_idx ON sent_payments(payment_hash)") + statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_created_idx ON sent_payments(created_at)") + statement.executeUpdate("CREATE INDEX IF NOT EXISTS received_created_idx ON received_payments(created_at)") + case unknownVersion => throw new RuntimeException(s"Unknown version of DB $DB_NAME found, version=$unknownVersion") + } + } override def addOutgoingPayment(sent: OutgoingPayment): Unit = { - using(sqlite.prepareStatement("INSERT INTO sent_payments (id, payment_hash, amount_msat, created_at, status) VALUES (?, ?, ?, ?, ?)")) { statement => + require(sent.status == OutgoingPaymentStatus.Pending, s"outgoing payment isn't pending (${sent.status.getClass.getSimpleName})") + using(sqlite.prepareStatement("INSERT INTO sent_payments (id, parent_id, external_id, payment_hash, amount_msat, target_node_id, created_at, payment_request) VALUES (?, ?, ?, ?, ?, ?, ?, ?)")) { statement => statement.setString(1, sent.id.toString) - statement.setBytes(2, sent.paymentHash.toArray) - statement.setLong(3, sent.amount.toLong) - statement.setLong(4, sent.createdAt) - statement.setString(5, sent.status.toString) - val res = statement.executeUpdate() - logger.debug(s"inserted $res payment=${sent.paymentHash} into payment DB") + statement.setString(2, sent.parentId.toString) + statement.setString(3, sent.externalId.orNull) + statement.setBytes(4, sent.paymentHash.toArray) + statement.setLong(5, sent.amount.toLong) + statement.setBytes(6, sent.targetNodeId.value.toArray) + statement.setLong(7, sent.createdAt) + statement.setString(8, sent.paymentRequest.map(PaymentRequest.write).orNull) + statement.executeUpdate() } } - override def updateOutgoingPayment(id: UUID, newStatus: OutgoingPaymentStatus.Value, preimage: Option[ByteVector32] = None) = { - require((newStatus == SUCCEEDED && preimage.isDefined) || (newStatus == FAILED && preimage.isEmpty), "Wrong combination of state/preimage") + override def updateOutgoingPayment(paymentResult: PaymentSent): Unit = + using(sqlite.prepareStatement("UPDATE sent_payments SET (completed_at, payment_preimage, fees_msat, payment_route) = (?, ?, ?, ?) WHERE id = ? AND completed_at IS NULL")) { statement => + paymentResult.parts.foreach(p => { + statement.setLong(1, p.timestamp) + statement.setBytes(2, paymentResult.paymentPreimage.toArray) + statement.setLong(3, p.feesPaid.toLong) + statement.setBytes(4, paymentRouteCodec.encode(p.route.getOrElse(Nil).map(h => HopSummary(h)).toList).require.toByteArray) + statement.setString(5, p.id.toString) + statement.addBatch() + }) + if (statement.executeBatch().contains(0)) throw new IllegalArgumentException(s"Tried to mark an outgoing payment as succeeded but already in final status (id=${paymentResult.id})") + } - using(sqlite.prepareStatement("UPDATE sent_payments SET (completed_at, preimage, status) = (?, ?, ?) WHERE id = ? AND completed_at IS NULL")) { statement => - statement.setLong(1, Platform.currentTime) - statement.setBytes(2, if (preimage.isEmpty) null else preimage.get.toArray) - statement.setString(3, newStatus.toString) - statement.setString(4, id.toString) - if (statement.executeUpdate() == 0) throw new IllegalArgumentException(s"Tried to update an outgoing payment (id=$id) already in final status with=$newStatus") + override def updateOutgoingPayment(paymentResult: PaymentFailed): Unit = + using(sqlite.prepareStatement("UPDATE sent_payments SET (completed_at, failures) = (?, ?) WHERE id = ? AND completed_at IS NULL")) { statement => + statement.setLong(1, paymentResult.timestamp) + statement.setBytes(2, paymentFailuresCodec.encode(paymentResult.failures.map(f => FailureSummary(f)).toList).require.toByteArray) + statement.setString(3, paymentResult.id.toString) + if (statement.executeUpdate() == 0) throw new IllegalArgumentException(s"Tried to mark an outgoing payment as failed but already in final status (id=${paymentResult.id})") + } + + private def parseOutgoingPayment(rs: ResultSet): OutgoingPayment = { + val result = OutgoingPayment( + UUID.fromString(rs.getString("id")), + UUID.fromString(rs.getString("parent_id")), + rs.getStringNullable("external_id"), + rs.getByteVector32("payment_hash"), + MilliSatoshi(rs.getLong("amount_msat")), + PublicKey(rs.getByteVector("target_node_id")), + rs.getLong("created_at"), + rs.getStringNullable("payment_request").map(PaymentRequest.read), + OutgoingPaymentStatus.Pending + ) + // If we have a pre-image, the payment succeeded. + rs.getByteVector32Nullable("payment_preimage") match { + case Some(paymentPreimage) => result.copy(status = OutgoingPaymentStatus.Succeeded( + paymentPreimage, + MilliSatoshi(rs.getLong("fees_msat")), + rs.getBitVectorOpt("payment_route").map(b => paymentRouteCodec.decode(b) match { + case Attempt.Successful(route) => route.value + case Attempt.Failure(_) => Nil + }).getOrElse(Nil), + rs.getLong("completed_at") + )) + case None => getNullableLong(rs, "completed_at") match { + // Otherwise if the payment was marked completed, it's a failure. + case Some(completedAt) => result.copy(status = OutgoingPaymentStatus.Failed( + rs.getBitVectorOpt("failures").map(b => paymentFailuresCodec.decode(b) match { + case Attempt.Successful(failures) => failures.value + case Attempt.Failure(_) => Nil + }).getOrElse(Nil), + completedAt + )) + // Else it's still pending. + case _ => result + } } } - override def getOutgoingPayment(id: UUID): Option[OutgoingPayment] = { - using(sqlite.prepareStatement("SELECT id, payment_hash, preimage, amount_msat, created_at, completed_at, status FROM sent_payments WHERE id = ?")) { statement => + override def getOutgoingPayment(id: UUID): Option[OutgoingPayment] = + using(sqlite.prepareStatement("SELECT * FROM sent_payments WHERE id = ?")) { statement => statement.setString(1, id.toString) val rs = statement.executeQuery() if (rs.next()) { - Some(OutgoingPayment( - UUID.fromString(rs.getString("id")), - rs.getByteVector32("payment_hash"), - rs.getByteVector32Nullable("preimage"), - MilliSatoshi(rs.getLong("amount_msat")), - rs.getLong("created_at"), - getNullableLong(rs, "completed_at"), - OutgoingPaymentStatus.withName(rs.getString("status")) - )) + Some(parseOutgoingPayment(rs)) } else { None } } - } - override def getOutgoingPayments(paymentHash: ByteVector32): Seq[OutgoingPayment] = { - using(sqlite.prepareStatement("SELECT id, payment_hash, preimage, amount_msat, created_at, completed_at, status FROM sent_payments WHERE payment_hash = ?")) { statement => - statement.setBytes(1, paymentHash.toArray) + override def listOutgoingPayments(parentId: UUID): Seq[OutgoingPayment] = + using(sqlite.prepareStatement("SELECT * FROM sent_payments WHERE parent_id = ? ORDER BY created_at")) { statement => + statement.setString(1, parentId.toString) val rs = statement.executeQuery() var q: Queue[OutgoingPayment] = Queue() while (rs.next()) { - q = q :+ OutgoingPayment( - UUID.fromString(rs.getString("id")), - rs.getByteVector32("payment_hash"), - rs.getByteVector32Nullable("preimage"), - MilliSatoshi(rs.getLong("amount_msat")), - rs.getLong("created_at"), - getNullableLong(rs, "completed_at"), - OutgoingPaymentStatus.withName(rs.getString("status")) - ) + q = q :+ parseOutgoingPayment(rs) } q } - } - override def listOutgoingPayments(): Seq[OutgoingPayment] = { - using(sqlite.createStatement()) { statement => - val rs = statement.executeQuery("SELECT id, payment_hash, preimage, amount_msat, created_at, completed_at, status FROM sent_payments") + override def listOutgoingPayments(paymentHash: ByteVector32): Seq[OutgoingPayment] = + using(sqlite.prepareStatement("SELECT * FROM sent_payments WHERE payment_hash = ? ORDER BY created_at")) { statement => + statement.setBytes(1, paymentHash.toArray) + val rs = statement.executeQuery() var q: Queue[OutgoingPayment] = Queue() while (rs.next()) { - q = q :+ OutgoingPayment( - UUID.fromString(rs.getString("id")), - rs.getByteVector32("payment_hash"), - rs.getByteVector32Nullable("preimage"), - MilliSatoshi(rs.getLong("amount_msat")), - rs.getLong("created_at"), - getNullableLong(rs, "completed_at"), - OutgoingPaymentStatus.withName(rs.getString("status")) - ) + q = q :+ parseOutgoingPayment(rs) } q } - } - override def addPaymentRequest(pr: PaymentRequest, preimage: ByteVector32): Unit = { - val insertStmt = pr.expiry match { - case Some(_) => "INSERT INTO received_payments (payment_hash, preimage, payment_request, created_at, expire_at) VALUES (?, ?, ?, ?, ?)" - case None => "INSERT INTO received_payments (payment_hash, preimage, payment_request, created_at) VALUES (?, ?, ?, ?)" + override def listOutgoingPayments(from: Long, to: Long): Seq[OutgoingPayment] = + using(sqlite.prepareStatement("SELECT * FROM sent_payments WHERE created_at >= ? AND created_at < ? ORDER BY created_at")) { statement => + statement.setLong(1, from) + statement.setLong(2, to) + val rs = statement.executeQuery() + var q: Queue[OutgoingPayment] = Queue() + while (rs.next()) { + q = q :+ parseOutgoingPayment(rs) + } + q } - using(sqlite.prepareStatement(insertStmt)) { statement => + override def addIncomingPayment(pr: PaymentRequest, preimage: ByteVector32): Unit = + using(sqlite.prepareStatement("INSERT INTO received_payments (payment_hash, payment_preimage, payment_request, created_at, expire_at) VALUES (?, ?, ?, ?, ?)")) { statement => statement.setBytes(1, pr.paymentHash.toArray) statement.setBytes(2, preimage.toArray) statement.setString(3, PaymentRequest.write(pr)) statement.setLong(4, pr.timestamp.seconds.toMillis) // BOLT11 timestamp is in seconds - pr.expiry.foreach { ex => statement.setLong(5, pr.timestamp.seconds.toMillis + ex.seconds.toMillis) } // we store "when" the invoice will expire, in milliseconds + statement.setLong(5, (pr.timestamp + pr.expiry.getOrElse(PaymentRequest.DEFAULT_EXPIRY_SECONDS.toLong)).seconds.toMillis) statement.executeUpdate() } - } - override def getPaymentRequest(paymentHash: ByteVector32): Option[PaymentRequest] = { - using(sqlite.prepareStatement("SELECT payment_request FROM received_payments WHERE payment_hash = ?")) { statement => - statement.setBytes(1, paymentHash.toArray) - val rs = statement.executeQuery() - if (rs.next()) { - Some(PaymentRequest.read(rs.getString("payment_request"))) - } else { - None - } + override def receiveIncomingPayment(paymentHash: ByteVector32, amount: MilliSatoshi, receivedAt: Long): Unit = + using(sqlite.prepareStatement("UPDATE received_payments SET (received_msat, received_at) = (?, ?) WHERE payment_hash = ?")) { statement => + statement.setLong(1, amount.toLong) + statement.setLong(2, receivedAt) + statement.setBytes(3, paymentHash.toArray) + val res = statement.executeUpdate() + if (res == 0) throw new IllegalArgumentException("Inserted a received payment without having an invoice") + } + + private def parseIncomingPayment(rs: ResultSet): IncomingPayment = { + val paymentRequest = PaymentRequest.read(rs.getString("payment_request")) + val paymentPreimage = rs.getByteVector32("payment_preimage") + val createdAt = rs.getLong("created_at") + val received = getNullableLong(rs, "received_msat").map(MilliSatoshi(_)) + received match { + case Some(amount) => IncomingPayment(paymentRequest, paymentPreimage, createdAt, IncomingPaymentStatus.Received(amount, rs.getLong("received_at"))) + case None if paymentRequest.isExpired => IncomingPayment(paymentRequest, paymentPreimage, createdAt, IncomingPaymentStatus.Expired) + case None => IncomingPayment(paymentRequest, paymentPreimage, createdAt, IncomingPaymentStatus.Pending) } } - override def getPendingPaymentRequestAndPreimage(paymentHash: ByteVector32): Option[(ByteVector32, PaymentRequest)] = { - using(sqlite.prepareStatement("SELECT payment_request, preimage FROM received_payments WHERE payment_hash = ? AND received_at IS NULL")) { statement => + override def getIncomingPayment(paymentHash: ByteVector32): Option[IncomingPayment] = + using(sqlite.prepareStatement("SELECT * FROM received_payments WHERE payment_hash = ?")) { statement => statement.setBytes(1, paymentHash.toArray) val rs = statement.executeQuery() if (rs.next()) { - val preimage = rs.getByteVector32("preimage") - val pr = PaymentRequest.read(rs.getString("payment_request")) - Some(preimage, pr) + Some(parseIncomingPayment(rs)) } else { None } } - } - - override def listPaymentRequests(from: Long, to: Long): Seq[PaymentRequest] = listPaymentRequests(from, to, pendingOnly = false) - - override def listPendingPaymentRequests(from: Long, to: Long): Seq[PaymentRequest] = listPaymentRequests(from, to, pendingOnly = true) - - def listPaymentRequests(from: Long, to: Long, pendingOnly: Boolean): Seq[PaymentRequest] = { - val queryStmt = pendingOnly match { - case true => "SELECT payment_request FROM received_payments WHERE created_at > ? AND created_at < ? AND (expire_at > ? OR expire_at IS NULL) AND received_msat IS NULL ORDER BY created_at DESC" - case false => "SELECT payment_request FROM received_payments WHERE created_at > ? AND created_at < ? ORDER BY created_at DESC" - } - - using(sqlite.prepareStatement(queryStmt)) { statement => - statement.setLong(1, from.seconds.toMillis) - statement.setLong(2, to.seconds.toMillis) - if (pendingOnly) statement.setLong(3, Platform.currentTime) + override def listIncomingPayments(from: Long, to: Long): Seq[IncomingPayment] = + using(sqlite.prepareStatement("SELECT * FROM received_payments WHERE created_at > ? AND created_at < ? ORDER BY created_at")) { statement => + statement.setLong(1, from) + statement.setLong(2, to) val rs = statement.executeQuery() - var q: Queue[PaymentRequest] = Queue() + var q: Queue[IncomingPayment] = Queue() while (rs.next()) { - q = q :+ PaymentRequest.read(rs.getString("payment_request")) + q = q :+ parseIncomingPayment(rs) } q } - } - override def addIncomingPayment(payment: IncomingPayment): Unit = { - using(sqlite.prepareStatement("UPDATE received_payments SET (received_msat, received_at) = (?, ?) WHERE payment_hash = ?")) { statement => - statement.setLong(1, payment.amount.toLong) - statement.setLong(2, payment.receivedAt) - statement.setBytes(3, payment.paymentHash.toArray) - val res = statement.executeUpdate() - if (res == 0) throw new IllegalArgumentException("Inserted a received payment without having an invoice") + override def listReceivedIncomingPayments(from: Long, to: Long): Seq[IncomingPayment] = + using(sqlite.prepareStatement("SELECT * FROM received_payments WHERE received_msat > 0 AND created_at > ? AND created_at < ? ORDER BY created_at")) { statement => + statement.setLong(1, from) + statement.setLong(2, to) + val rs = statement.executeQuery() + var q: Queue[IncomingPayment] = Queue() + while (rs.next()) { + q = q :+ parseIncomingPayment(rs) + } + q } - } - override def getIncomingPayment(paymentHash: ByteVector32): Option[IncomingPayment] = { - using(sqlite.prepareStatement("SELECT payment_hash, received_msat, received_at FROM received_payments WHERE payment_hash = ? AND received_msat > 0")) { statement => - statement.setBytes(1, paymentHash.toArray) + override def listPendingIncomingPayments(from: Long, to: Long): Seq[IncomingPayment] = + using(sqlite.prepareStatement("SELECT * FROM received_payments WHERE received_msat IS NULL AND created_at > ? AND created_at < ? AND expire_at > ? ORDER BY created_at")) { statement => + statement.setLong(1, from) + statement.setLong(2, to) + statement.setLong(3, Platform.currentTime) val rs = statement.executeQuery() - if (rs.next()) { - Some(IncomingPayment(rs.getByteVector32("payment_hash"), MilliSatoshi(rs.getLong("received_msat")), rs.getLong("received_at"))) - } else { - None + var q: Queue[IncomingPayment] = Queue() + while (rs.next()) { + q = q :+ parseIncomingPayment(rs) } + q } - } - override def listIncomingPayments(): Seq[IncomingPayment] = { - using(sqlite.createStatement()) { statement => - val rs = statement.executeQuery("SELECT payment_hash, received_msat, received_at FROM received_payments WHERE received_msat > 0") + override def listExpiredIncomingPayments(from: Long, to: Long): Seq[IncomingPayment] = + using(sqlite.prepareStatement("SELECT * FROM received_payments WHERE received_msat IS NULL AND created_at > ? AND created_at < ? AND expire_at < ? ORDER BY created_at")) { statement => + statement.setLong(1, from) + statement.setLong(2, to) + statement.setLong(3, Platform.currentTime) + val rs = statement.executeQuery() var q: Queue[IncomingPayment] = Queue() while (rs.next()) { - q = q :+ IncomingPayment(rs.getByteVector32("payment_hash"), MilliSatoshi(rs.getLong("received_msat")), rs.getLong("received_at")) + q = q :+ parseIncomingPayment(rs) } q } - } } \ No newline at end of file diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePeersDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePeersDb.scala index 8d9e828ba3..11725d94fb 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePeersDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePeersDb.scala @@ -32,7 +32,7 @@ import SqliteUtils.ExtendedResultSet._ val DB_NAME = "peers" val CURRENT_VERSION = 1 - using(sqlite.createStatement()) { statement => + using(sqlite.createStatement(), inTransaction = true) { statement => require(getVersion(statement, DB_NAME, CURRENT_VERSION) == CURRENT_VERSION, s"incompatible version of $DB_NAME DB found") // there is only one version currently deployed statement.executeUpdate("CREATE TABLE IF NOT EXISTS peers (node_id BLOB NOT NULL PRIMARY KEY, data BLOB NOT NULL)") } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePendingRelayDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePendingRelayDb.scala index b0621ac5e2..888e7388fd 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePendingRelayDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePendingRelayDb.scala @@ -33,7 +33,7 @@ class SqlitePendingRelayDb(sqlite: Connection) extends PendingRelayDb { val DB_NAME = "pending_relay" val CURRENT_VERSION = 1 - using(sqlite.createStatement()) { statement => + using(sqlite.createStatement(), inTransaction = true) { statement => require(getVersion(statement, DB_NAME, CURRENT_VERSION) == CURRENT_VERSION, s"incompatible version of $DB_NAME DB found") // there is only one version currently deployed // note: should we use a foreign key to local_channels table here? statement.executeUpdate("CREATE TABLE IF NOT EXISTS pending_relay (channel_id BLOB NOT NULL, htlc_id INTEGER NOT NULL, data BLOB NOT NULL, PRIMARY KEY(channel_id, htlc_id))") diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteUtils.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteUtils.scala index 35af8fc8dd..4b5b957e1a 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteUtils.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteUtils.scala @@ -27,31 +27,31 @@ import scala.collection.immutable.Queue object SqliteUtils { /** - * Manages closing of statement - * - * @param statement - * @param block - */ - def using[T <: Statement, U](statement: T, disableAutoCommit: Boolean = false)(block: T => U): U = { + * This helper makes sure statements are correctly closed. + * + * @param inTransaction if set to true, all updates in the block will be run in a transaction. + */ + def using[T <: Statement, U](statement: T, inTransaction: Boolean = false)(block: T => U): U = { try { - if (disableAutoCommit) statement.getConnection.setAutoCommit(false) - block(statement) + if (inTransaction) statement.getConnection.setAutoCommit(false) + val res = block(statement) + if (inTransaction) statement.getConnection.commit() + res + } catch { + case t: Exception => + if (inTransaction) statement.getConnection.rollback() + throw t } finally { - if (disableAutoCommit) statement.getConnection.setAutoCommit(true) + if (inTransaction) statement.getConnection.setAutoCommit(true) if (statement != null) statement.close() } } /** - * Several logical databases (channels, network, peers) may be stored in the same physical sqlite database. - * We keep track of their respective version using a dedicated table. The version entry will be created if - * there is none but will never be updated here (use setVersion to do that). - * - * @param statement - * @param db_name - * @param currentVersion - * @return - */ + * Several logical databases (channels, network, peers) may be stored in the same physical sqlite database. + * We keep track of their respective version using a dedicated table. The version entry will be created if + * there is none but will never be updated here (use setVersion to do that). + */ def getVersion(statement: Statement, db_name: String, currentVersion: Int): Int = { statement.executeUpdate("CREATE TABLE IF NOT EXISTS versions (db_name TEXT NOT NULL PRIMARY KEY, version INTEGER NOT NULL)") // if there was no version for the current db, then insert the current version @@ -62,12 +62,8 @@ object SqliteUtils { } /** - * Updates the version for a particular logical database, it will overwrite the previous version. - * @param statement - * @param db_name - * @param newVersion - * @return - */ + * Updates the version for a particular logical database, it will overwrite the previous version. + */ def setVersion(statement: Statement, db_name: String, newVersion: Int) = { statement.executeUpdate("CREATE TABLE IF NOT EXISTS versions (db_name TEXT NOT NULL PRIMARY KEY, version INTEGER NOT NULL)") // overwrite the existing version @@ -75,15 +71,10 @@ object SqliteUtils { } /** - * This helper assumes that there is a "data" column available, decodable with the provided codec - * - * TODO: we should use an scala.Iterator instead - * - * @param rs - * @param codec - * @tparam T - * @return - */ + * This helper assumes that there is a "data" column available, decodable with the provided codec + * + * TODO: we should use an scala.Iterator instead + */ def codecSequence[T](rs: ResultSet, codec: Codec[T]): Seq[T] = { var q: Queue[T] = Queue() while (rs.next()) { @@ -93,27 +84,22 @@ object SqliteUtils { } /** - * This helper retrieves the value from a nullable integer column and interprets it as an option. This is needed - * because `rs.getLong` would return `0` for a null value. - * It is used on Android only - * - * @param label - * @return - */ - def getNullableLong(rs: ResultSet, label: String) : Option[Long] = { + * This helper retrieves the value from a nullable integer column and interprets it as an option. This is needed + * because `rs.getLong` would return `0` for a null value. + * It is used on Android only + */ + def getNullableLong(rs: ResultSet, label: String): Option[Long] = { val result = rs.getLong(label) if (rs.wasNull()) None else Some(result) } /** - * Obtain an exclusive lock on a sqlite database. This is useful when we want to make sure that only one process - * accesses the database file (see https://www.sqlite.org/pragma.html). - * - * The lock will be kept until the database is closed, or if the locking mode is explicitly reset. - * - * @param sqlite - */ - def obtainExclusiveLock(sqlite: Connection){ + * Obtain an exclusive lock on a sqlite database. This is useful when we want to make sure that only one process + * accesses the database file (see https://www.sqlite.org/pragma.html). + * + * The lock will be kept until the database is closed, or if the locking mode is explicitly reset. + */ + def obtainExclusiveLock(sqlite: Connection) { val statement = sqlite.createStatement() statement.execute("PRAGMA locking_mode = EXCLUSIVE") // we have to make a write to actually obtain the lock @@ -127,15 +113,27 @@ object SqliteUtils { def getByteVector(columnLabel: String): ByteVector = ByteVector(rs.getBytes(columnLabel)) + def getByteVectorNullable(columnLabel: String): ByteVector = { + val result = rs.getBytes(columnLabel) + if (rs.wasNull()) ByteVector.empty else ByteVector(result) + } + def getByteVector32(columnLabel: String): ByteVector32 = ByteVector32(ByteVector(rs.getBytes(columnLabel))) def getByteVector32Nullable(columnLabel: String): Option[ByteVector32] = { val bytes = rs.getBytes(columnLabel) - if(rs.wasNull()) None else Some(ByteVector32(ByteVector(bytes))) + if (rs.wasNull()) None else Some(ByteVector32(ByteVector(bytes))) + } + + def getStringNullable(columnLabel: String): Option[String] = { + val result = rs.getString(columnLabel) + if (rs.wasNull()) None else Some(result) } + } object ExtendedResultSet { implicit def conv(rs: ResultSet): ExtendedResultSet = ExtendedResultSet(rs) } + } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/package.scala b/eclair-core/src/main/scala/fr/acinq/eclair/package.scala index 003f7764fb..f14482fccf 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/package.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/package.scala @@ -22,7 +22,7 @@ import fr.acinq.bitcoin.Crypto.PrivateKey import fr.acinq.bitcoin._ import scodec.Attempt import scodec.bits.{BitVector, ByteVector} -import scala.concurrent.duration.Duration + import scala.util.{Failure, Success, Try} package object eclair { @@ -152,11 +152,6 @@ package object eclair { } } - /** - * We use this in the context of timestamp filtering, when we don't need an upper bound. - */ - val MaxEpochSeconds = Duration.fromNanos(Long.MaxValue).toSeconds - implicit class LongToBtcAmount(l: Long) { // @formatter:off def msat: MilliSatoshi = MilliSatoshi(l) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Auditor.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Auditor.scala index 6a8eaa82bd..1036925221 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Auditor.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Auditor.scala @@ -35,7 +35,7 @@ class Auditor(nodeParams: NodeParams) extends Actor with ActorLogging { context.system.eventStream.subscribe(self, classOf[PaymentEvent]) context.system.eventStream.subscribe(self, classOf[NetworkFeePaid]) context.system.eventStream.subscribe(self, classOf[AvailableBalanceChanged]) - context.system.eventStream.subscribe(self, classOf[ChannelErrorOccured]) + context.system.eventStream.subscribe(self, classOf[ChannelErrorOccurred]) context.system.eventStream.subscribe(self, classOf[ChannelStateChanged]) context.system.eventStream.subscribe(self, classOf[ChannelClosed]) @@ -74,7 +74,7 @@ class Auditor(nodeParams: NodeParams) extends Actor with ActorLogging { case e: AvailableBalanceChanged => balanceEventThrottler ! e - case e: ChannelErrorOccured => + case e: ChannelErrorOccurred => val metric = Kamon.counter("channels.errors") e.error match { case LocalError(_) if e.isFatal => metric.withTag("origin", "local").withTag("fatal", "yes").increment() @@ -113,8 +113,8 @@ class Auditor(nodeParams: NodeParams) extends Actor with ActorLogging { } /** - * We don't want to log every tiny payment, and we don't want to log probing events. - */ + * We don't want to log every tiny payment, and we don't want to log probing events. + */ class BalanceEventThrottler(db: AuditDb) extends Actor with ActorLogging { import ExecutionContext.Implicits.global @@ -135,11 +135,11 @@ class BalanceEventThrottler(db: AuditDb) extends Actor with ActorLogging { // we delay the processing of the event in order to smooth variations log.info(s"will log balance event in $delay for channelId=${e.channelId}") context.system.scheduler.scheduleOnce(delay, self, ProcessEvent(e.channelId)) - context.become(run(pending + (e.channelId -> (BalanceUpdate(e, e))))) + context.become(run(pending + (e.channelId -> BalanceUpdate(e, e)))) case Some(BalanceUpdate(first, _)) => // we already are about to log a balance event, let's update the data we have log.info(s"updating balance data for channelId=${e.channelId}") - context.become(run(pending + (e.channelId -> (BalanceUpdate(first, e))))) + context.become(run(pending + (e.channelId -> BalanceUpdate(first, e)))) } case ProcessEvent(channelId) => diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Autoprobe.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Autoprobe.scala index 40ca336c44..fc738ac934 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Autoprobe.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Autoprobe.scala @@ -20,7 +20,6 @@ import akka.actor.{Actor, ActorLogging, ActorRef, Props} import fr.acinq.bitcoin.Crypto.PublicKey import fr.acinq.eclair.crypto.Sphinx.DecryptedFailurePacket import fr.acinq.eclair.payment.PaymentInitiator.SendPaymentRequest -import fr.acinq.eclair.payment.PaymentLifecycle.{PaymentFailed, PaymentResult, RemoteFailure} import fr.acinq.eclair.router.{Announcements, Data, PublicChannel} import fr.acinq.eclair.wire.IncorrectOrUnknownPaymentDetails import fr.acinq.eclair.{LongToBtcAmount, NodeParams, randomBytes32, secureRandom} @@ -61,9 +60,9 @@ class Autoprobe(nodeParams: NodeParams, router: ActorRef, paymentInitiator: Acto scheduleProbe() } - case paymentResult: PaymentResult => + case paymentResult: PaymentEvent => paymentResult match { - case PaymentFailed(_, _, _ :+ RemoteFailure(_, DecryptedFailurePacket(targetNodeId, IncorrectOrUnknownPaymentDetails(_, _)))) => + case PaymentFailed(_, _, _ :+ RemoteFailure(_, DecryptedFailurePacket(targetNodeId, IncorrectOrUnknownPaymentDetails(_, _))), _) => log.info(s"payment probe successful to node=$targetNodeId") case _ => log.info(s"payment probe failed with paymentResult=$paymentResult") diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/LocalPaymentHandler.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/LocalPaymentHandler.scala index 9d3cea6cbc..d5d399f932 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/LocalPaymentHandler.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/LocalPaymentHandler.scala @@ -19,12 +19,11 @@ package fr.acinq.eclair.payment import akka.actor.{Actor, ActorLogging, Props, Status} import fr.acinq.bitcoin.Crypto import fr.acinq.eclair.channel.{CMD_FAIL_HTLC, CMD_FULFILL_HTLC, Channel} -import fr.acinq.eclair.db.IncomingPayment +import fr.acinq.eclair.db.{IncomingPayment, IncomingPaymentStatus} import fr.acinq.eclair.payment.PaymentLifecycle.ReceivePayment import fr.acinq.eclair.wire._ import fr.acinq.eclair.{NodeParams, randomBytes32} -import scala.compat.Platform import scala.concurrent.ExecutionContext import scala.util.{Failure, Success, Try} @@ -49,7 +48,7 @@ class LocalPaymentHandler(nodeParams: NodeParams) extends Actor with ActorLoggin val expirySeconds = expirySeconds_opt.getOrElse(nodeParams.paymentRequestExpiry.toSeconds) val paymentRequest = PaymentRequest(nodeParams.chainHash, amount_opt, paymentHash, nodeParams.privateKey, desc, fallbackAddress_opt, expirySeconds = Some(expirySeconds), extraHops = extraHops) log.debug(s"generated payment request={} from amount={}", PaymentRequest.write(paymentRequest), amount_opt) - paymentDb.addPaymentRequest(paymentRequest, paymentPreimage) + paymentDb.addIncomingPayment(paymentRequest, paymentPreimage) paymentRequest } match { case Success(paymentRequest) => sender ! paymentRequest @@ -57,8 +56,11 @@ class LocalPaymentHandler(nodeParams: NodeParams) extends Actor with ActorLoggin } case htlc: UpdateAddHtlc => - paymentDb.getPendingPaymentRequestAndPreimage(htlc.paymentHash) match { - case Some((paymentPreimage, paymentRequest)) => + paymentDb.getIncomingPayment(htlc.paymentHash) match { + case Some(IncomingPayment(_, _, _, status)) if status.isInstanceOf[IncomingPaymentStatus.Received] => + log.warning(s"ignoring incoming payment for paymentHash=${htlc.paymentHash} which has already been paid") + sender ! CMD_FAIL_HTLC(htlc.id, Right(IncorrectOrUnknownPaymentDetails(htlc.amountMsat, nodeParams.currentBlockHeight)), commit = true) + case Some(IncomingPayment(paymentRequest, paymentPreimage, _, _)) => val minFinalExpiry = paymentRequest.minFinalCltvExpiryDelta.getOrElse(Channel.MIN_CLTV_EXPIRY_DELTA).toCltvExpiry(nodeParams.currentBlockHeight) // The htlc amount must be equal or greater than the requested amount. A slight overpaying is permitted, however // it must not be greater than two times the requested amount. @@ -77,10 +79,9 @@ class LocalPaymentHandler(nodeParams: NodeParams) extends Actor with ActorLoggin case _ => log.info(s"received payment for paymentHash=${htlc.paymentHash} amount=${htlc.amountMsat}") // amount is correct or was not specified in the payment request - nodeParams.db.payments.addIncomingPayment(IncomingPayment(htlc.paymentHash, htlc.amountMsat, Platform.currentTime)) + nodeParams.db.payments.receiveIncomingPayment(htlc.paymentHash, htlc.amountMsat) sender ! CMD_FULFILL_HTLC(htlc.id, paymentPreimage, commit = true) - context.system.eventStream.publish(PaymentReceived(htlc.amountMsat, htlc.paymentHash, htlc.channelId)) - + context.system.eventStream.publish(PaymentReceived(htlc.paymentHash, PaymentReceived.PartialPayment(htlc.amountMsat, htlc.channelId) :: Nil)) } case None => sender ! CMD_FAIL_HTLC(htlc.id, Right(IncorrectOrUnknownPaymentDetails(htlc.amountMsat, nodeParams.currentBlockHeight)), commit = true) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentEvents.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentEvents.scala index 7b0f1b55c2..8dafe453d7 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentEvents.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentEvents.scala @@ -20,20 +20,99 @@ import java.util.UUID import fr.acinq.bitcoin.ByteVector32 import fr.acinq.eclair.MilliSatoshi +import fr.acinq.eclair.crypto.Sphinx +import fr.acinq.eclair.router.Hop import scala.compat.Platform /** - * Created by PM on 01/02/2017. - */ + * Created by PM on 01/02/2017. + */ + sealed trait PaymentEvent { val paymentHash: ByteVector32 + val timestamp: Long +} + +case class PaymentSent(id: UUID, paymentHash: ByteVector32, paymentPreimage: ByteVector32, parts: Seq[PaymentSent.PartialPayment]) extends PaymentEvent { + require(parts.nonEmpty, "must have at least one subpayment") + val amount: MilliSatoshi = parts.map(_.amount).sum + val feesPaid: MilliSatoshi = parts.map(_.feesPaid).sum + val timestamp: Long = parts.map(_.timestamp).min // we use min here because we receive the proof of payment as soon as the first partial payment is fulfilled } -case class PaymentSent(id: UUID, amount: MilliSatoshi, feesPaid: MilliSatoshi, paymentHash: ByteVector32, paymentPreimage: ByteVector32, toChannelId: ByteVector32, timestamp: Long = Platform.currentTime) extends PaymentEvent +object PaymentSent { + + case class PartialPayment(id: UUID, amount: MilliSatoshi, feesPaid: MilliSatoshi, toChannelId: ByteVector32, route: Option[Seq[Hop]], timestamp: Long = Platform.currentTime) { + require(route.isEmpty || route.get.nonEmpty, "route must be None or contain at least one hop") + } + +} + +case class PaymentFailed(id: UUID, paymentHash: ByteVector32, failures: Seq[PaymentFailure], timestamp: Long = Platform.currentTime) extends PaymentEvent case class PaymentRelayed(amountIn: MilliSatoshi, amountOut: MilliSatoshi, paymentHash: ByteVector32, fromChannelId: ByteVector32, toChannelId: ByteVector32, timestamp: Long = Platform.currentTime) extends PaymentEvent -case class PaymentReceived(amount: MilliSatoshi, paymentHash: ByteVector32, fromChannelId: ByteVector32, timestamp: Long = Platform.currentTime) extends PaymentEvent +case class PaymentReceived(paymentHash: ByteVector32, parts: Seq[PaymentReceived.PartialPayment]) extends PaymentEvent { + require(parts.nonEmpty, "must have at least one subpayment") + val amount: MilliSatoshi = parts.map(_.amount).sum + val timestamp: Long = parts.map(_.timestamp).max // we use max here because we fulfill the payment only once we received all the parts +} + +object PaymentReceived { + + case class PartialPayment(amount: MilliSatoshi, fromChannelId: ByteVector32, timestamp: Long = Platform.currentTime) + +} case class PaymentSettlingOnChain(id: UUID, amount: MilliSatoshi, paymentHash: ByteVector32, timestamp: Long = Platform.currentTime) extends PaymentEvent + +sealed trait PaymentFailure + +/** A failure happened locally, preventing the payment from being sent (e.g. no route found). */ +case class LocalFailure(t: Throwable) extends PaymentFailure + +/** A remote node failed the payment and we were able to decrypt the onion failure packet. */ +case class RemoteFailure(route: Seq[Hop], e: Sphinx.DecryptedFailurePacket) extends PaymentFailure + +/** A remote node failed the payment but we couldn't decrypt the failure (e.g. a malicious node tampered with the message). */ +case class UnreadableRemoteFailure(route: Seq[Hop]) extends PaymentFailure + +object PaymentFailure { + + import fr.acinq.bitcoin.Crypto.PublicKey + import fr.acinq.eclair.channel.AddHtlcFailed + import fr.acinq.eclair.router.RouteNotFound + import fr.acinq.eclair.wire.Update + + /** + * Rewrites a list of failures to retrieve the meaningful part. + * + * If a list of failures with many elements ends up with a LocalFailure RouteNotFound, this RouteNotFound failure + * should be removed. This last failure is irrelevant information. In such a case only the n-1 attempts were rejected + * with a **significant reason**; the final RouteNotFound error provides no meaningful insight. + * + * This method should be used by the user interface to provide a non-exhaustive but more useful feedback. + * + * @param failures a list of payment failures for a payment + */ + def transformForUser(failures: Seq[PaymentFailure]): Seq[PaymentFailure] = { + failures.map { + case LocalFailure(AddHtlcFailed(_, _, t, _, _, _)) => LocalFailure(t) // we're interested in the error which caused the add-htlc to fail + case other => other + } match { + case previousFailures :+ LocalFailure(RouteNotFound) if previousFailures.nonEmpty => previousFailures + case other => other + } + } + + /** + * This allows us to detect if a bad node always answers with a new update (e.g. with a slightly different expiry or fee) + * in order to mess with us. + */ + def hasAlreadyFailedOnce(nodeId: PublicKey, failures: Seq[PaymentFailure]): Boolean = + failures + .collectFirst { case RemoteFailure(_, Sphinx.DecryptedFailurePacket(origin, u: Update)) if origin == nodeId => u.update } + .isDefined + +} \ No newline at end of file diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentInitiator.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentInitiator.scala index f1b6acf1df..863ba579ce 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentInitiator.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentInitiator.scala @@ -22,7 +22,7 @@ import akka.actor.{Actor, ActorLogging, ActorRef, Props} import fr.acinq.bitcoin.ByteVector32 import fr.acinq.bitcoin.Crypto.PublicKey import fr.acinq.eclair.channel.Channel -import fr.acinq.eclair.payment.PaymentLifecycle.{SendPayment, SendPaymentToRoute} +import fr.acinq.eclair.payment.PaymentLifecycle.{DefaultPaymentProgressHandler, SendPayment, SendPaymentToRoute} import fr.acinq.eclair.payment.PaymentRequest.ExtraHop import fr.acinq.eclair.router.RouteParams import fr.acinq.eclair.wire.Onion.FinalLegacyPayload @@ -38,7 +38,7 @@ class PaymentInitiator(nodeParams: NodeParams, router: ActorRef, register: Actor val paymentId = UUID.randomUUID() // We add one block in order to not have our htlc fail when a new block has just been found. val finalExpiry = p.finalExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight + 1) - val payFsm = context.actorOf(PaymentLifecycle.props(nodeParams, paymentId, router, register)) + val payFsm = context.actorOf(PaymentLifecycle.props(nodeParams, DefaultPaymentProgressHandler(paymentId, p, nodeParams.db.payments), router, register)) // NB: we only generate legacy payment onions for now for maximum compatibility. p.predefinedRoute match { case Nil => payFsm forward SendPayment(p.paymentHash, p.targetNodeId, FinalLegacyPayload(p.amount, finalExpiry), p.maxAttempts, p.assistedRoutes, p.routeParams) @@ -58,6 +58,8 @@ object PaymentInitiator { targetNodeId: PublicKey, maxAttempts: Int, finalExpiryDelta: CltvExpiryDelta = Channel.MIN_CLTV_EXPIRY_DELTA, + paymentRequest: Option[PaymentRequest] = None, + externalId: Option[String] = None, predefinedRoute: Seq[PublicKey] = Nil, assistedRoutes: Seq[Seq[ExtraHop]] = Nil, routeParams: Option[RouteParams] = None) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentLifecycle.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentLifecycle.scala index d06566ca7e..68b5ecd4b5 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentLifecycle.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentLifecycle.scala @@ -18,15 +18,17 @@ package fr.acinq.eclair.payment import java.util.UUID -import akka.actor.{ActorRef, FSM, Props, Status} +import akka.actor.{ActorContext, ActorRef, FSM, Props, Status} import fr.acinq.bitcoin.ByteVector32 import fr.acinq.bitcoin.Crypto.PublicKey import fr.acinq.eclair._ -import fr.acinq.eclair.channel.{AddHtlcFailed, CMD_ADD_HTLC, Register} +import fr.acinq.eclair.channel.{CMD_ADD_HTLC, Register} import fr.acinq.eclair.crypto.{Sphinx, TransportHandler} -import fr.acinq.eclair.db.{OutgoingPayment, OutgoingPaymentStatus} +import fr.acinq.eclair.db.{OutgoingPayment, OutgoingPaymentStatus, PaymentsDb} +import fr.acinq.eclair.payment.PaymentInitiator.SendPaymentRequest import fr.acinq.eclair.payment.PaymentLifecycle._ import fr.acinq.eclair.payment.PaymentRequest.ExtraHop +import fr.acinq.eclair.payment.PaymentSent.PartialPayment import fr.acinq.eclair.router._ import fr.acinq.eclair.wire.Onion._ import fr.acinq.eclair.wire._ @@ -39,22 +41,23 @@ import scala.util.{Failure, Success} /** * Created by PM on 26/08/2016. */ -class PaymentLifecycle(nodeParams: NodeParams, id: UUID, router: ActorRef, register: ActorRef) extends FSM[PaymentLifecycle.State, PaymentLifecycle.Data] { - val paymentsDb = nodeParams.db.payments +class PaymentLifecycle(nodeParams: NodeParams, progressHandler: PaymentProgressHandler, router: ActorRef, register: ActorRef) extends FSM[PaymentLifecycle.State, PaymentLifecycle.Data] { + + val id = progressHandler.id startWith(WAITING_FOR_REQUEST, WaitingForRequest) when(WAITING_FOR_REQUEST) { case Event(c: SendPaymentToRoute, WaitingForRequest) => val send = SendPayment(c.paymentHash, c.hops.last, c.finalPayload, maxAttempts = 1) - paymentsDb.addOutgoingPayment(OutgoingPayment(id, c.paymentHash, None, c.finalPayload.amount, Platform.currentTime, None, OutgoingPaymentStatus.PENDING)) router ! FinalizeRoute(c.hops) + progressHandler.onSend() goto(WAITING_FOR_ROUTE) using WaitingForRoute(sender, send, failures = Nil) case Event(c: SendPayment, WaitingForRequest) => router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.finalPayload.amount, c.assistedRoutes, routeParams = c.routeParams) - paymentsDb.addOutgoingPayment(OutgoingPayment(id, c.paymentHash, None, c.finalPayload.amount, Platform.currentTime, None, OutgoingPaymentStatus.PENDING)) + progressHandler.onSend() goto(WAITING_FOR_ROUTE) using WaitingForRoute(sender, c, failures = Nil) } @@ -67,18 +70,16 @@ class PaymentLifecycle(nodeParams: NodeParams, id: UUID, router: ActorRef, regis goto(WAITING_FOR_PAYMENT_COMPLETE) using WaitingForComplete(s, c, cmd, failures, sharedSecrets, ignoreNodes, ignoreChannels, hops) case Event(Status.Failure(t), WaitingForRoute(s, c, failures)) => - reply(s, PaymentFailed(id, c.paymentHash, failures = failures :+ LocalFailure(t))) - paymentsDb.updateOutgoingPayment(id, OutgoingPaymentStatus.FAILED) + progressHandler.onFailure(s, PaymentFailed(id, c.paymentHash, failures :+ LocalFailure(t)))(context) stop(FSM.Normal) } when(WAITING_FOR_PAYMENT_COMPLETE) { - case Event("ok", _) => stay() + case Event("ok", _) => stay - case Event(fulfill: UpdateFulfillHtlc, WaitingForComplete(s, c, cmd, _, _, _, _, hops)) => - paymentsDb.updateOutgoingPayment(id, OutgoingPaymentStatus.SUCCEEDED, preimage = Some(fulfill.paymentPreimage)) - reply(s, PaymentSucceeded(id, cmd.amount, c.paymentHash, fulfill.paymentPreimage, hops)) - context.system.eventStream.publish(PaymentSent(id, c.finalPayload.amount, cmd.amount - c.finalPayload.amount, cmd.paymentHash, fulfill.paymentPreimage, fulfill.channelId)) + case Event(fulfill: UpdateFulfillHtlc, WaitingForComplete(s, c, cmd, _, _, _, _, route)) => + val p = PartialPayment(id, c.finalPayload.amount, cmd.amount - c.finalPayload.amount, fulfill.channelId, Some(route)) + progressHandler.onSuccess(s, PaymentSent(id, c.paymentHash, fulfill.paymentPreimage, p :: Nil))(context) stop(FSM.Normal) case Event(fail: UpdateFailHtlc, WaitingForComplete(s, c, _, failures, sharedSecrets, ignoreNodes, ignoreChannels, hops)) => @@ -86,8 +87,7 @@ class PaymentLifecycle(nodeParams: NodeParams, id: UUID, router: ActorRef, regis case Success(e@Sphinx.DecryptedFailurePacket(nodeId, failureMessage)) if nodeId == c.targetNodeId => // if destination node returns an error, we fail the payment immediately log.warning(s"received an error message from target nodeId=$nodeId, failing the payment (failure=$failureMessage)") - reply(s, PaymentFailed(id, c.paymentHash, failures = failures :+ RemoteFailure(hops, e))) - paymentsDb.updateOutgoingPayment(id, OutgoingPaymentStatus.FAILED) + progressHandler.onFailure(s, PaymentFailed(id, c.paymentHash, failures :+ RemoteFailure(hops, e)))(context) stop(FSM.Normal) case res if failures.size + 1 >= c.maxAttempts => // otherwise we never try more than maxAttempts, no matter the kind of error returned @@ -100,8 +100,7 @@ class PaymentLifecycle(nodeParams: NodeParams, id: UUID, router: ActorRef, regis UnreadableRemoteFailure(hops) } log.warning(s"too many failed attempts, failing the payment") - reply(s, PaymentFailed(id, c.paymentHash, failures = failures :+ failure)) - paymentsDb.updateOutgoingPayment(id, OutgoingPaymentStatus.FAILED) + progressHandler.onFailure(s, PaymentFailed(id, c.paymentHash, failures :+ failure))(context) stop(FSM.Normal) case Failure(t) => log.warning(s"cannot parse returned error: ${t.getMessage}") @@ -128,7 +127,7 @@ class PaymentLifecycle(nodeParams: NodeParams, id: UUID, router: ActorRef, regis log.info(s"received exact same update from nodeId=$nodeId, excluding the channel from futures routes") val nextNodeId = hops.find(_.nodeId == nodeId).get.nextNodeId router ! ExcludeChannel(ChannelDesc(u.shortChannelId, nodeId, nextNodeId)) - case Some(u) if hasAlreadyFailedOnce(nodeId, failures) => + case Some(u) if PaymentFailure.hasAlreadyFailedOnce(nodeId, failures) => // this node had already given us a new channel update and is still unhappy, it is probably messing with us, let's exclude it log.warning(s"it is the second time nodeId=$nodeId answers with a new update, excluding it: old=$u new=${failureMessage.update}") val nextNodeId = hops.find(_.nodeId == nodeId).get.nextNodeId @@ -166,8 +165,7 @@ class PaymentLifecycle(nodeParams: NodeParams, id: UUID, router: ActorRef, regis case Event(Status.Failure(t), WaitingForComplete(s, c, _, failures, _, ignoreNodes, ignoreChannels, hops)) => if (failures.size + 1 >= c.maxAttempts) { - paymentsDb.updateOutgoingPayment(id, OutgoingPaymentStatus.FAILED) - reply(s, PaymentFailed(id, c.paymentHash, failures :+ LocalFailure(t))) + progressHandler.onFailure(s, PaymentFailed(id, c.paymentHash, failures :+ LocalFailure(t)))(context) stop(FSM.Normal) } else { log.info(s"received an error message from local, trying to use a different channel (failure=${t.getMessage})") @@ -182,17 +180,44 @@ class PaymentLifecycle(nodeParams: NodeParams, id: UUID, router: ActorRef, regis case Event(_: TransportHandler.ReadAck, _) => stay // ignored, router replies with this when we forward a channel_update } - def reply(to: ActorRef, e: PaymentResult): Unit = { - to ! e - context.system.eventStream.publish(e) - } - initialize() } object PaymentLifecycle { - def props(nodeParams: NodeParams, id: UUID, router: ActorRef, register: ActorRef) = Props(classOf[PaymentLifecycle], nodeParams, id, router, register) + def props(nodeParams: NodeParams, progressHandler: PaymentProgressHandler, router: ActorRef, register: ActorRef) = Props(classOf[PaymentLifecycle], nodeParams, progressHandler, router, register) + + /** This handler notifies other components of payment progress. */ + trait PaymentProgressHandler { + val id: UUID + + // @formatter:off + def onSend(): Unit + def onSuccess(sender: ActorRef, result: PaymentSent)(ctx: ActorContext): Unit + def onFailure(sender: ActorRef, result: PaymentFailed)(ctx: ActorContext): Unit + // @formatter:on + } + + /** Normal payments are stored in the payments DB and emit payment events. */ + case class DefaultPaymentProgressHandler(id: UUID, r: SendPaymentRequest, db: PaymentsDb) extends PaymentProgressHandler { + + override def onSend(): Unit = { + db.addOutgoingPayment(OutgoingPayment(id, id, r.externalId, r.paymentHash, r.amount, r.targetNodeId, Platform.currentTime, r.paymentRequest, OutgoingPaymentStatus.Pending)) + } + + override def onSuccess(sender: ActorRef, result: PaymentSent)(ctx: ActorContext): Unit = { + db.updateOutgoingPayment(result) + sender ! result + ctx.system.eventStream.publish(result) + } + + override def onFailure(sender: ActorRef, result: PaymentFailed)(ctx: ActorContext): Unit = { + db.updateOutgoingPayment(result) + sender ! result + ctx.system.eventStream.publish(result) + } + + } // @formatter:off case class ReceivePayment(amount_opt: Option[MilliSatoshi], description: String, expirySeconds_opt: Option[Long] = None, extraHops: List[List[ExtraHop]] = Nil, fallbackAddress: Option[String] = None, paymentPreimage: Option[ByteVector32] = None) @@ -206,14 +231,6 @@ object PaymentLifecycle { require(finalPayload.amount > 0.msat, s"amount must be > 0") } - sealed trait PaymentResult - case class PaymentSucceeded(id: UUID, amount: MilliSatoshi, paymentHash: ByteVector32, paymentPreimage: ByteVector32, route: Seq[Hop]) extends PaymentResult // note: the amount includes fees - sealed trait PaymentFailure - case class LocalFailure(t: Throwable) extends PaymentFailure - case class RemoteFailure(route: Seq[Hop], e: Sphinx.DecryptedFailurePacket) extends PaymentFailure - case class UnreadableRemoteFailure(route: Seq[Hop]) extends PaymentFailure - case class PaymentFailed(id: UUID, paymentHash: ByteVector32, failures: Seq[PaymentFailure]) extends PaymentResult - sealed trait Data case object WaitingForRequest extends Data case class WaitingForRoute(sender: ActorRef, c: SendPayment, failures: Seq[PaymentFailure]) extends Data @@ -268,27 +285,6 @@ object PaymentLifecycle { CMD_ADD_HTLC(firstAmount, paymentHash, firstExpiry, onion.packet, upstream = Left(id), commit = true) -> onion.sharedSecrets } - /** - * Rewrites a list of failures to retrieve the meaningful part. - *

- * If a list of failures with many elements ends up with a LocalFailure RouteNotFound, this RouteNotFound failure - * should be removed. This last failure is irrelevant information. In such a case only the n-1 attempts were rejected - * with a **significant reason** ; the final RouteNotFound error provides no meaningful insight. - *

- * This method should be used by the user interface to provide a non-exhaustive but more useful feedback. - * - * @param failures a list of payment failures for a payment - */ - def transformForUser(failures: Seq[PaymentFailure]): Seq[PaymentFailure] = { - failures.map { - case LocalFailure(AddHtlcFailed(_, _, t, _, _, _)) => LocalFailure(t) // we're interested in the error which caused the add-htlc to fail - case other => other - } match { - case previousFailures :+ LocalFailure(RouteNotFound) if previousFailures.nonEmpty => previousFailures - case other => other - } - } - /** * This method retrieves the channel update that we used when we built a route. * It just iterates over the hops, but there are at most 20 of them. @@ -297,12 +293,4 @@ object PaymentLifecycle { */ def getChannelUpdateForNode(nodeId: PublicKey, hops: Seq[Hop]): Option[ChannelUpdate] = hops.find(_.nodeId == nodeId).map(_.lastUpdate) - /** - * This allows us to detect if a bad node always answers with a new update (e.g. with a slightly different expiry or fee) - * in order to mess with us. - */ - def hasAlreadyFailedOnce(nodeId: PublicKey, failures: Seq[PaymentFailure]): Boolean = - failures - .collectFirst { case RemoteFailure(_, Sphinx.DecryptedFailurePacket(origin, u: Update)) if origin == nodeId => u.update } - .isDefined } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Relayer.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Relayer.scala index 1ee391d05e..bb32da3e73 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Relayer.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Relayer.scala @@ -24,8 +24,6 @@ import fr.acinq.bitcoin.ByteVector32 import fr.acinq.bitcoin.Crypto.{PrivateKey, PublicKey} import fr.acinq.eclair.channel._ import fr.acinq.eclair.crypto.Sphinx -import fr.acinq.eclair.db.OutgoingPaymentStatus -import fr.acinq.eclair.payment.PaymentLifecycle.{PaymentFailed, PaymentSucceeded} import fr.acinq.eclair.router.Announcements import fr.acinq.eclair.wire._ import fr.acinq.eclair.{CltvExpiryDelta, Features, LongToBtcAmount, MilliSatoshi, NodeParams, ShortChannelId, UInt64, nodeFee} @@ -138,8 +136,9 @@ class Relayer(nodeParams: NodeParams, register: ActorRef, paymentHandler: ActorR case Local(id, None) => // we sent the payment, but we probably restarted and the reference to the original sender was lost, // we publish the failure on the event stream and update the status in paymentDb - nodeParams.db.payments.updateOutgoingPayment(id, OutgoingPaymentStatus.FAILED) - context.system.eventStream.publish(PaymentFailed(id, paymentHash, Nil)) + val result = PaymentFailed(id, paymentHash, Nil) + nodeParams.db.payments.updateOutgoingPayment(result) + context.system.eventStream.publish(result) case Local(_, Some(sender)) => sender ! Status.Failure(addFailed) case Relayed(originChannelId, originHtlcId, _, _) => @@ -159,12 +158,12 @@ class Relayer(nodeParams: NodeParams, register: ActorRef, paymentHandler: ActorR case ForwardFulfill(fulfill, to, add) => to match { case Local(id, None) => - val feesPaid = 0.msat - context.system.eventStream.publish(PaymentSent(id, add.amountMsat, feesPaid, add.paymentHash, fulfill.paymentPreimage, fulfill.channelId)) // we sent the payment, but we probably restarted and the reference to the original sender was lost, - // we publish the failure on the event stream and update the status in paymentDb - nodeParams.db.payments.updateOutgoingPayment(id, OutgoingPaymentStatus.SUCCEEDED, Some(fulfill.paymentPreimage)) - context.system.eventStream.publish(PaymentSucceeded(id, add.amountMsat, add.paymentHash, fulfill.paymentPreimage, Nil)) // + // we publish the success on the event stream and update the status in paymentDb + val feesPaid = 0.msat // fees are unknown since we lost the reference to the payment + val result = PaymentSent(id, add.paymentHash, fulfill.paymentPreimage, Seq(PaymentSent.PartialPayment(id, add.amountMsat, feesPaid, add.channelId, None))) + nodeParams.db.payments.updateOutgoingPayment(result) + context.system.eventStream.publish(result) case Local(_, Some(sender)) => sender ! fulfill case Relayed(originChannelId, originHtlcId, amountIn, amountOut) => @@ -178,8 +177,9 @@ class Relayer(nodeParams: NodeParams, register: ActorRef, paymentHandler: ActorR case Local(id, None) => // we sent the payment, but we probably restarted and the reference to the original sender was lost // we publish the failure on the event stream and update the status in paymentDb - nodeParams.db.payments.updateOutgoingPayment(id, OutgoingPaymentStatus.FAILED) - context.system.eventStream.publish(PaymentFailed(id, add.paymentHash, Nil)) + val result = PaymentFailed(id, add.paymentHash, Nil) + nodeParams.db.payments.updateOutgoingPayment(result) + context.system.eventStream.publish(result) case Local(_, Some(sender)) => sender ! fail case Relayed(originChannelId, originHtlcId, _, _) => @@ -192,8 +192,9 @@ class Relayer(nodeParams: NodeParams, register: ActorRef, paymentHandler: ActorR case Local(id, None) => // we sent the payment, but we probably restarted and the reference to the original sender was lost // we publish the failure on the event stream and update the status in paymentDb - nodeParams.db.payments.updateOutgoingPayment(id, OutgoingPaymentStatus.FAILED) - context.system.eventStream.publish(PaymentFailed(id, add.paymentHash, Nil)) + val result = PaymentFailed(id, add.paymentHash, Nil) + nodeParams.db.payments.updateOutgoingPayment(result) + context.system.eventStream.publish(result) case Local(_, Some(sender)) => sender ! fail case Relayed(originChannelId, originHtlcId, _, _) => diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/EclairImplSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/EclairImplSpec.scala index 32f4450896..f6e604ad23 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/EclairImplSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/EclairImplSpec.scala @@ -94,43 +94,55 @@ class EclairImplSpec extends TestKit(ActorSystem("test")) with fixture.FunSuiteL val eclair = new EclairImpl(kit) val nodeId = PublicKey(hex"030bb6a5e0c6b203c7e2180fb78c7ba4bdce46126761d8201b91ddac089cdecc87") - eclair.send(nodeId, 123 msat, ByteVector32.Zeroes, invoice_opt = None) + eclair.send(None, nodeId, 123 msat, ByteVector32.Zeroes, invoice_opt = None) val send = paymentInitiator.expectMsgType[SendPaymentRequest] - assert(send.targetNodeId == nodeId) - assert(send.amount == 123.msat) - assert(send.paymentHash == ByteVector32.Zeroes) - assert(send.assistedRoutes == Seq.empty) + assert(send.externalId === None) + assert(send.targetNodeId === nodeId) + assert(send.amount === 123.msat) + assert(send.paymentHash === ByteVector32.Zeroes) + assert(send.paymentRequest === None) + assert(send.assistedRoutes === Seq.empty) // with assisted routes + val externalId1 = "030bb6a5e0c6b203c7e2180fb78c7ba4bdce46126761d8201b91ddac089cdecc87" val hints = List(List(ExtraHop(Bob.nodeParams.nodeId, ShortChannelId("569178x2331x1"), feeBase = 10 msat, feeProportionalMillionths = 1, cltvExpiryDelta = CltvExpiryDelta(12)))) val invoice1 = PaymentRequest(Block.RegtestGenesisBlock.hash, Some(123 msat), ByteVector32.Zeroes, randomKey, "description", None, None, hints) - eclair.send(nodeId, 123 msat, ByteVector32.Zeroes, invoice_opt = Some(invoice1)) + eclair.send(Some(externalId1), nodeId, 123 msat, ByteVector32.Zeroes, invoice_opt = Some(invoice1)) val send1 = paymentInitiator.expectMsgType[SendPaymentRequest] - assert(send1.targetNodeId == nodeId) - assert(send1.amount == 123.msat) - assert(send1.paymentHash == ByteVector32.Zeroes) - assert(send1.assistedRoutes == hints) + assert(send1.externalId === Some(externalId1)) + assert(send1.targetNodeId === nodeId) + assert(send1.amount === 123.msat) + assert(send1.paymentHash === ByteVector32.Zeroes) + assert(send1.paymentRequest === Some(invoice1)) + assert(send1.assistedRoutes === hints) // with finalCltvExpiry + val externalId2 = "487da196-a4dc-4b1e-92b4-3e5e905e9f3f" val invoice2 = PaymentRequest("lntb", Some(123 msat), System.currentTimeMillis() / 1000L, nodeId, List(PaymentRequest.MinFinalCltvExpiry(96), PaymentRequest.PaymentHash(ByteVector32.Zeroes), PaymentRequest.Description("description")), ByteVector.empty) - eclair.send(nodeId, 123 msat, ByteVector32.Zeroes, invoice_opt = Some(invoice2)) + eclair.send(Some(externalId2), nodeId, 123 msat, ByteVector32.Zeroes, invoice_opt = Some(invoice2)) val send2 = paymentInitiator.expectMsgType[SendPaymentRequest] - assert(send2.targetNodeId == nodeId) - assert(send2.amount == 123.msat) - assert(send2.paymentHash == ByteVector32.Zeroes) - assert(send2.finalExpiryDelta == CltvExpiryDelta(96)) + assert(send2.externalId === Some(externalId2)) + assert(send2.targetNodeId === nodeId) + assert(send2.amount === 123.msat) + assert(send2.paymentHash === ByteVector32.Zeroes) + assert(send2.paymentRequest === Some(invoice2)) + assert(send2.finalExpiryDelta === CltvExpiryDelta(96)) // with custom route fees parameters - eclair.send(nodeId, 123 msat, ByteVector32.Zeroes, invoice_opt = None, feeThreshold_opt = Some(123 sat), maxFeePct_opt = Some(4.20)) + eclair.send(None, nodeId, 123 msat, ByteVector32.Zeroes, invoice_opt = None, feeThreshold_opt = Some(123 sat), maxFeePct_opt = Some(4.20)) val send3 = paymentInitiator.expectMsgType[SendPaymentRequest] - assert(send3.targetNodeId == nodeId) - assert(send3.amount == 123.msat) - assert(send3.paymentHash == ByteVector32.Zeroes) - assert(send3.routeParams.get.maxFeeBase == 123000.msat) // conversion sat -> msat - assert(send3.routeParams.get.maxFeePct == 4.20) + assert(send3.externalId === None) + assert(send3.targetNodeId === nodeId) + assert(send3.amount === 123.msat) + assert(send3.paymentHash === ByteVector32.Zeroes) + assert(send3.routeParams.get.maxFeeBase === 123000.msat) // conversion sat -> msat + assert(send3.routeParams.get.maxFeePct === 4.20) + + val invalidExternalId = "Robert'); DROP TABLE received_payments; DROP TABLE sent_payments; DROP TABLE payments;" + assertThrows[IllegalArgumentException](Await.result(eclair.send(Some(invalidExternalId), nodeId, 123 msat, ByteVector32.Zeroes), 50 millis)) val expiredInvoice = invoice2.copy(timestamp = 0L) - assertThrows[IllegalArgumentException](Await.result(eclair.send(nodeId, 123 msat, ByteVector32.Zeroes, invoice_opt = Some(expiredInvoice)), 50 millis)) + assertThrows[IllegalArgumentException](Await.result(eclair.send(None, nodeId, 123 msat, ByteVector32.Zeroes, invoice_opt = Some(expiredInvoice)), 50 millis)) } test("allupdates can filter by nodeId") { f => @@ -225,7 +237,7 @@ class EclairImplSpec extends TestKit(ActorSystem("test")) with fixture.FunSuiteL auditDb.listSent(anyLong, anyLong) returns Seq.empty auditDb.listReceived(anyLong, anyLong) returns Seq.empty auditDb.listRelayed(anyLong, anyLong) returns Seq.empty - paymentDb.listPaymentRequests(anyLong, anyLong) returns Seq.empty + paymentDb.listIncomingPayments(anyLong, anyLong) returns Seq.empty val databases = mock[Databases] databases.audit returns auditDb @@ -235,15 +247,15 @@ class EclairImplSpec extends TestKit(ActorSystem("test")) with fixture.FunSuiteL val eclair = new EclairImpl(kitWithMockAudit) Await.result(eclair.networkFees(None, None), 10 seconds) - auditDb.listNetworkFees(0, MaxEpochSeconds).wasCalled(once) // assert the call was made only once and with the specified params + auditDb.listNetworkFees(0, TimestampQueryFilters.MaxEpochMilliseconds).wasCalled(once) // assert the call was made only once and with the specified params Await.result(eclair.audit(None, None), 10 seconds) - auditDb.listRelayed(0, MaxEpochSeconds).wasCalled(once) - auditDb.listReceived(0, MaxEpochSeconds).wasCalled(once) - auditDb.listSent(0, MaxEpochSeconds).wasCalled(once) + auditDb.listRelayed(0, TimestampQueryFilters.MaxEpochMilliseconds).wasCalled(once) + auditDb.listReceived(0, TimestampQueryFilters.MaxEpochMilliseconds).wasCalled(once) + auditDb.listSent(0, TimestampQueryFilters.MaxEpochMilliseconds).wasCalled(once) Await.result(eclair.allInvoices(None, None), 10 seconds) - paymentDb.listPaymentRequests(0, MaxEpochSeconds).wasCalled(once) // assert the call was made only once and with the specified params + paymentDb.listIncomingPayments(0, TimestampQueryFilters.MaxEpochMilliseconds).wasCalled(once) // assert the call was made only once and with the specified params } test("sendtoroute should pass the parameters correctly") { f => @@ -251,13 +263,14 @@ class EclairImplSpec extends TestKit(ActorSystem("test")) with fixture.FunSuiteL val route = Seq(PublicKey(hex"030bb6a5e0c6b203c7e2180fb78c7ba4bdce46126761d8201b91ddac089cdecc87")) val eclair = new EclairImpl(kit) - eclair.sendToRoute(route, 1234 msat, ByteVector32.One, CltvExpiryDelta(123)) + eclair.sendToRoute(Some("42"), route, 1234 msat, ByteVector32.One, CltvExpiryDelta(123)) val send = paymentInitiator.expectMsgType[SendPaymentRequest] - assert(send.predefinedRoute == route) + assert(send.externalId === Some("42")) + assert(send.predefinedRoute === route) assert(send.amount === 1234.msat) assert(send.finalExpiryDelta === CltvExpiryDelta(123)) - assert(send.paymentHash == ByteVector32.One) + assert(send.paymentHash === ByteVector32.One) } } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/NormalStateSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/NormalStateSpec.scala index 10182a87db..581ffe4b14 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/NormalStateSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/NormalStateSpec.scala @@ -29,7 +29,7 @@ import fr.acinq.eclair.blockchain._ import fr.acinq.eclair.blockchain.fee.FeeratesPerKw import fr.acinq.eclair.channel.Channel._ import fr.acinq.eclair.channel.states.StateTestsHelperMethods -import fr.acinq.eclair.channel.{ChannelErrorOccured, _} +import fr.acinq.eclair.channel.{ChannelErrorOccurred, _} import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.io.Peer import fr.acinq.eclair.payment._ @@ -1710,7 +1710,7 @@ class NormalStateSpec extends TestkitBaseClass with StateTestsHelperMethods { crossSign(alice, bob, alice2bob, bob2alice) val listener = TestProbe() - system.eventStream.subscribe(listener.ref, classOf[ChannelErrorOccured]) + system.eventStream.subscribe(listener.ref, classOf[ChannelErrorOccurred]) // actual test begins: // * Bob receives the HTLC pre-image and wants to fulfill @@ -1726,7 +1726,7 @@ class NormalStateSpec extends TestkitBaseClass with StateTestsHelperMethods { bob2alice.expectMsgType[UpdateFulfillHtlc] sender.send(bob, CurrentBlockCount((htlc.cltvExpiry - Bob.nodeParams.fulfillSafetyBeforeTimeoutBlocks).toLong)) - val ChannelErrorOccured(_, _, _, _, LocalError(err), isFatal) = listener.expectMsgType[ChannelErrorOccured] + val ChannelErrorOccurred(_, _, _, _, LocalError(err), isFatal) = listener.expectMsgType[ChannelErrorOccurred] assert(isFatal) assert(err.isInstanceOf[HtlcWillTimeoutUpstream]) @@ -1745,7 +1745,7 @@ class NormalStateSpec extends TestkitBaseClass with StateTestsHelperMethods { crossSign(alice, bob, alice2bob, bob2alice) val listener = TestProbe() - system.eventStream.subscribe(listener.ref, classOf[ChannelErrorOccured]) + system.eventStream.subscribe(listener.ref, classOf[ChannelErrorOccurred]) // actual test begins: // * Bob receives the HTLC pre-image and wants to fulfill but doesn't sign @@ -1761,7 +1761,7 @@ class NormalStateSpec extends TestkitBaseClass with StateTestsHelperMethods { bob2alice.expectMsgType[UpdateFulfillHtlc] sender.send(bob, CurrentBlockCount((htlc.cltvExpiry - Bob.nodeParams.fulfillSafetyBeforeTimeoutBlocks).toLong)) - val ChannelErrorOccured(_, _, _, _, LocalError(err), isFatal) = listener.expectMsgType[ChannelErrorOccured] + val ChannelErrorOccurred(_, _, _, _, LocalError(err), isFatal) = listener.expectMsgType[ChannelErrorOccurred] assert(isFatal) assert(err.isInstanceOf[HtlcWillTimeoutUpstream]) @@ -1780,7 +1780,7 @@ class NormalStateSpec extends TestkitBaseClass with StateTestsHelperMethods { crossSign(alice, bob, alice2bob, bob2alice) val listener = TestProbe() - system.eventStream.subscribe(listener.ref, classOf[ChannelErrorOccured]) + system.eventStream.subscribe(listener.ref, classOf[ChannelErrorOccurred]) // actual test begins: // * Bob receives the HTLC pre-image and wants to fulfill @@ -1801,7 +1801,7 @@ class NormalStateSpec extends TestkitBaseClass with StateTestsHelperMethods { alice2bob.forward(bob) sender.send(bob, CurrentBlockCount((htlc.cltvExpiry - Bob.nodeParams.fulfillSafetyBeforeTimeoutBlocks).toLong)) - val ChannelErrorOccured(_, _, _, _, LocalError(err), isFatal) = listener.expectMsgType[ChannelErrorOccured] + val ChannelErrorOccurred(_, _, _, _, LocalError(err), isFatal) = listener.expectMsgType[ChannelErrorOccurred] assert(isFatal) assert(err.isInstanceOf[HtlcWillTimeoutUpstream]) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/OfflineStateSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/OfflineStateSpec.scala index e73e50140e..ca9fe9e34c 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/OfflineStateSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/OfflineStateSpec.scala @@ -405,7 +405,7 @@ class OfflineStateSpec extends TestkitBaseClass with StateTestsHelperMethods { crossSign(alice, bob, alice2bob, bob2alice) val listener = TestProbe() - system.eventStream.subscribe(listener.ref, classOf[ChannelErrorOccured]) + system.eventStream.subscribe(listener.ref, classOf[ChannelErrorOccurred]) val initialState = bob.stateData.asInstanceOf[DATA_NORMAL] val initialCommitTx = initialState.commitments.localCommit.publishableTxs.commitTx.tx @@ -421,7 +421,7 @@ class OfflineStateSpec extends TestkitBaseClass with StateTestsHelperMethods { sender.send(commandBuffer, CommandSend(htlc.channelId, htlc.id, CMD_FULFILL_HTLC(htlc.id, r, commit = true))) sender.send(bob, CurrentBlockCount((htlc.cltvExpiry - bob.underlyingActor.nodeParams.fulfillSafetyBeforeTimeoutBlocks).toLong)) - val ChannelErrorOccured(_, _, _, _, LocalError(err), isFatal) = listener.expectMsgType[ChannelErrorOccured] + val ChannelErrorOccurred(_, _, _, _, LocalError(err), isFatal) = listener.expectMsgType[ChannelErrorOccurred] assert(isFatal) assert(err.isInstanceOf[HtlcWillTimeoutUpstream]) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteAuditDbSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteAuditDbSpec.scala index 49dd366358..7eb9b8d1b5 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteAuditDbSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteAuditDbSpec.scala @@ -21,10 +21,10 @@ import java.util.UUID import fr.acinq.bitcoin.Transaction import fr.acinq.eclair._ import fr.acinq.eclair.channel.Channel.{LocalError, RemoteError} -import fr.acinq.eclair.channel.{AvailableBalanceChanged, ChannelErrorOccured, NetworkFeePaid} +import fr.acinq.eclair.channel.{AvailableBalanceChanged, ChannelErrorOccurred, NetworkFeePaid} import fr.acinq.eclair.db.sqlite.SqliteAuditDb import fr.acinq.eclair.db.sqlite.SqliteUtils.{getVersion, using} -import fr.acinq.eclair.payment.{PaymentReceived, PaymentRelayed, PaymentSent} +import fr.acinq.eclair.payment._ import fr.acinq.eclair.wire.{ChannelCodecs, ChannelCodecsSpec} import org.scalatest.FunSuite @@ -44,16 +44,21 @@ class SqliteAuditDbSpec extends FunSuite { val sqlite = TestConstants.sqliteInMemory() val db = new SqliteAuditDb(sqlite) - val e1 = PaymentSent(ChannelCodecs.UNKNOWN_UUID, 42000 msat, 1000 msat, randomBytes32, randomBytes32, randomBytes32) - val e2 = PaymentReceived(42000 msat, randomBytes32, randomBytes32) + val e1 = PaymentSent(ChannelCodecs.UNKNOWN_UUID, randomBytes32, randomBytes32, PaymentSent.PartialPayment(ChannelCodecs.UNKNOWN_UUID, 42000 msat, 1000 msat, randomBytes32, None) :: Nil) + val pp2a = PaymentReceived.PartialPayment(42000 msat, randomBytes32) + val pp2b = PaymentReceived.PartialPayment(42100 msat, randomBytes32) + val e2 = PaymentReceived(randomBytes32, pp2a :: pp2b :: Nil) val e3 = PaymentRelayed(42000 msat, 1000 msat, randomBytes32, randomBytes32, randomBytes32) val e4 = NetworkFeePaid(null, randomKey.publicKey, randomBytes32, Transaction(0, Seq.empty, Seq.empty, 0), 42 sat, "mutual") - val e5 = PaymentSent(ChannelCodecs.UNKNOWN_UUID, 42000 msat, 1000 msat, randomBytes32, randomBytes32, randomBytes32, timestamp = 0) - val e6 = PaymentSent(ChannelCodecs.UNKNOWN_UUID, 42000 msat, 1000 msat, randomBytes32, randomBytes32, randomBytes32, timestamp = (Platform.currentTime.milliseconds + 10.minutes).toMillis) + val pp5a = PaymentSent.PartialPayment(UUID.randomUUID(), 42000 msat, 1000 msat, randomBytes32, None, timestamp = 0) + val pp5b = PaymentSent.PartialPayment(UUID.randomUUID(), 42100 msat, 900 msat, randomBytes32, None, timestamp = 1) + val e5 = PaymentSent(ChannelCodecs.UNKNOWN_UUID, randomBytes32, randomBytes32, pp5a :: pp5b :: Nil) + val pp6 = PaymentSent.PartialPayment(UUID.randomUUID(), 42000 msat, 1000 msat, randomBytes32, None, timestamp = (Platform.currentTime.milliseconds + 10.minutes).toMillis) + val e6 = PaymentSent(ChannelCodecs.UNKNOWN_UUID, randomBytes32, randomBytes32, pp6 :: Nil) val e7 = AvailableBalanceChanged(null, randomBytes32, ShortChannelId(500000, 42, 1), 456123000 msat, ChannelCodecsSpec.commitments) - val e8 = ChannelLifecycleEvent(randomBytes32, randomKey.publicKey, 456123000 sat, true, false, "mutual") - val e9 = ChannelErrorOccured(null, randomBytes32, randomKey.publicKey, null, LocalError(new RuntimeException("oops")), true) - val e10 = ChannelErrorOccured(null, randomBytes32, randomKey.publicKey, null, RemoteError(wire.Error(randomBytes32, "remote oops")), true) + val e8 = ChannelLifecycleEvent(randomBytes32, randomKey.publicKey, 456123000 sat, isFunder = true, isPrivate = false, "mutual") + val e9 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, LocalError(new RuntimeException("oops")), isFatal = true) + val e10 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, RemoteError(wire.Error(randomBytes32, "remote oops")), isFatal = true) db.add(e1) db.add(e2) @@ -66,12 +71,12 @@ class SqliteAuditDbSpec extends FunSuite { db.add(e9) db.add(e10) - assert(db.listSent(from = 0L, to = (Platform.currentTime.milliseconds + 15.minute).toSeconds).toSet === Set(e1, e5, e6)) - assert(db.listSent(from = 100000L, to = (Platform.currentTime.milliseconds + 1.minute).toSeconds).toList === List(e1)) - assert(db.listReceived(from = 0L, to = (Platform.currentTime.milliseconds + 1.minute).toSeconds).toList === List(e2)) - assert(db.listRelayed(from = 0L, to = (Platform.currentTime.milliseconds + 1.minute).toSeconds).toList === List(e3)) - assert(db.listNetworkFees(from = 0L, to = (Platform.currentTime.milliseconds + 1.minute).toSeconds).size === 1) - assert(db.listNetworkFees(from = 0L, to = (Platform.currentTime.milliseconds + 1.minute).toSeconds).head.txType === "mutual") + assert(db.listSent(from = 0L, to = (Platform.currentTime.milliseconds + 15.minute).toMillis).toSet === Set(e1, e5.copy(id = pp5a.id, parts = pp5a :: Nil), e5.copy(id = pp5b.id, parts = pp5b :: Nil), e6.copy(id = pp6.id))) + assert(db.listSent(from = 100000L, to = (Platform.currentTime.milliseconds + 1.minute).toMillis).toList === List(e1)) + assert(db.listReceived(from = 0L, to = (Platform.currentTime.milliseconds + 1.minute).toMillis).toList === List(e2.copy(parts = pp2a :: Nil), e2.copy(parts = pp2b :: Nil))) + assert(db.listRelayed(from = 0L, to = (Platform.currentTime.milliseconds + 1.minute).toMillis).toList === List(e3)) + assert(db.listNetworkFees(from = 0L, to = (Platform.currentTime.milliseconds + 1.minute).toMillis).size === 1) + assert(db.listNetworkFees(from = 0L, to = (Platform.currentTime.milliseconds + 1.minute).toMillis).head.txType === "mutual") } test("stats") { @@ -129,11 +134,12 @@ class SqliteAuditDbSpec extends FunSuite { assert(getVersion(statement, "audit", 3) == 1) // we expect version 1 } - val ps = PaymentSent(UUID.randomUUID(), 42000 msat, 1000 msat, randomBytes32, randomBytes32, randomBytes32) - val ps1 = PaymentSent(UUID.randomUUID(), 42001 msat, 1001 msat, randomBytes32, randomBytes32, randomBytes32) - val ps2 = PaymentSent(UUID.randomUUID(), 42002 msat, 1002 msat, randomBytes32, randomBytes32, randomBytes32) - val e1 = ChannelErrorOccured(null, randomBytes32, randomKey.publicKey, null, LocalError(new RuntimeException("oops")), true) - val e2 = ChannelErrorOccured(null, randomBytes32, randomKey.publicKey, null, RemoteError(wire.Error(randomBytes32, "remote oops")), true) + val ps = PaymentSent(UUID.randomUUID(), randomBytes32, randomBytes32, PaymentSent.PartialPayment(UUID.randomUUID(), 42000 msat, 1000 msat, randomBytes32, None) :: Nil) + val pp1 = PaymentSent.PartialPayment(UUID.randomUUID(), 42001 msat, 1001 msat, randomBytes32, None) + val pp2 = PaymentSent.PartialPayment(UUID.randomUUID(), 42002 msat, 1002 msat, randomBytes32, None) + val ps1 = PaymentSent(UUID.randomUUID(), randomBytes32, randomBytes32, pp1 :: pp2 :: Nil) + val e1 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, LocalError(new RuntimeException("oops")), isFatal = true) + val e2 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, RemoteError(wire.Error(randomBytes32, "remote oops")), isFatal = true) // add a row (no ID on sent) using(connection.prepareStatement("INSERT INTO sent VALUES (?, ?, ?, ?, ?, ?)")) { statement => @@ -141,7 +147,7 @@ class SqliteAuditDbSpec extends FunSuite { statement.setLong(2, ps.feesPaid.toLong) statement.setBytes(3, ps.paymentHash.toArray) statement.setBytes(4, ps.paymentPreimage.toArray) - statement.setBytes(5, ps.toChannelId.toArray) + statement.setBytes(5, ps.parts.head.toChannelId.toArray) statement.setLong(6, ps.timestamp) statement.executeUpdate() } @@ -152,8 +158,8 @@ class SqliteAuditDbSpec extends FunSuite { assert(getVersion(statement, "audit", 3) == 3) // version changed from 1 -> 3 } - // existing rows will use 00000000-0000-0000-0000-000000000000 as default - assert(migratedDb.listSent(0, (Platform.currentTime.milliseconds + 1.minute).toSeconds) == Seq(ps.copy(id = ChannelCodecs.UNKNOWN_UUID))) + // existing rows in the 'sent' table will use id=00000000-0000-0000-0000-000000000000 as default + assert(migratedDb.listSent(0, (Platform.currentTime.milliseconds + 1.minute).toMillis) === Seq(ps.copy(id = ChannelCodecs.UNKNOWN_UUID, parts = Seq(ps.parts.head.copy(id = ChannelCodecs.UNKNOWN_UUID))))) val postMigrationDb = new SqliteAuditDb(connection) @@ -162,12 +168,14 @@ class SqliteAuditDbSpec extends FunSuite { } postMigrationDb.add(ps1) - postMigrationDb.add(ps2) postMigrationDb.add(e1) postMigrationDb.add(e2) // the old record will have the UNKNOWN_UUID but the new ones will have their actual id - assert(postMigrationDb.listSent(0, (Platform.currentTime.milliseconds + 1.minute).toSeconds) == Seq(ps.copy(id = ChannelCodecs.UNKNOWN_UUID), ps1, ps2)) + assert(postMigrationDb.listSent(0, (Platform.currentTime.milliseconds + 1.minute).toMillis) === Seq( + ps.copy(id = ChannelCodecs.UNKNOWN_UUID, parts = Seq(ps.parts.head.copy(id = ChannelCodecs.UNKNOWN_UUID))), + ps1.copy(id = pp1.id, parts = pp1 :: Nil), + ps1.copy(id = pp2.id, parts = pp2 :: Nil))) } test("handle migration version 2 -> 3") { @@ -196,8 +204,8 @@ class SqliteAuditDbSpec extends FunSuite { assert(getVersion(statement, "audit", 3) == 2) // version 2 is deployed now } - val e1 = ChannelErrorOccured(null, randomBytes32, randomKey.publicKey, null, LocalError(new RuntimeException("oops")), true) - val e2 = ChannelErrorOccured(null, randomBytes32, randomKey.publicKey, null, RemoteError(wire.Error(randomBytes32, "remote oops")), true) + val e1 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, LocalError(new RuntimeException("oops")), isFatal = true) + val e2 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, RemoteError(wire.Error(randomBytes32, "remote oops")), isFatal = true) val migratedDb = new SqliteAuditDb(connection) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteChannelsDbSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteChannelsDbSpec.scala index ac7d61ed37..993ebb4065 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteChannelsDbSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteChannelsDbSpec.scala @@ -26,7 +26,6 @@ import org.scalatest.FunSuite import org.sqlite.SQLiteException import scodec.bits.ByteVector - class SqliteChannelsDbSpec extends FunSuite { test("init sqlite 2 times in a row") { diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/SqlitePaymentsDbSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/SqlitePaymentsDbSpec.scala index ed8993cc77..f70d18aaf0 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/db/SqlitePaymentsDbSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/SqlitePaymentsDbSpec.scala @@ -18,29 +18,31 @@ package fr.acinq.eclair.db import java.util.UUID -import fr.acinq.bitcoin.{Block, ByteVector32} -import fr.acinq.eclair.TestConstants.Bob -import fr.acinq.eclair.db.OutgoingPaymentStatus._ +import fr.acinq.bitcoin.Crypto.PrivateKey +import fr.acinq.bitcoin.{Block, ByteVector32, Crypto} +import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.db.sqlite.SqlitePaymentsDb import fr.acinq.eclair.db.sqlite.SqliteUtils._ -import fr.acinq.eclair.payment.PaymentRequest -import fr.acinq.eclair.{LongToBtcAmount, TestConstants, randomBytes32} +import fr.acinq.eclair.payment._ +import fr.acinq.eclair.router.Hop +import fr.acinq.eclair.wire.{ChannelUpdate, UnknownNextPeer} +import fr.acinq.eclair.{CltvExpiryDelta, LongToBtcAmount, ShortChannelId, TestConstants, randomBytes32, randomBytes64, randomKey} import org.scalatest.FunSuite -import scodec.bits._ import scala.compat.Platform import scala.concurrent.duration._ class SqlitePaymentsDbSpec extends FunSuite { + import SqlitePaymentsDbSpec._ + test("init sqlite 2 times in a row") { val sqlite = TestConstants.sqliteInMemory() val db1 = new SqlitePaymentsDb(sqlite) val db2 = new SqlitePaymentsDb(sqlite) } - test("handle version migration 1->2") { - + test("handle version migration 1->3") { val connection = TestConstants.sqliteInMemory() using(connection.createStatement()) { statement => @@ -52,137 +54,281 @@ class SqlitePaymentsDbSpec extends FunSuite { assert(getVersion(statement, "payments", 1) == 1) // version 1 is deployed now } - val oldReceivedPayment = IncomingPayment(ByteVector32(hex"0f059ef9b55bb70cc09069ee4df854bf0fab650eee6f2b87ba26d1ad08ab114f"), 123 msat, 1233322) - - // insert old type record + // Changes between version 1 and 2: + // - the monolithic payments table has been replaced by two tables, received_payments and sent_payments + // - old records from the payments table are ignored (not migrated to the new tables) using(connection.prepareStatement("INSERT INTO payments VALUES (?, ?, ?)")) { statement => - statement.setBytes(1, oldReceivedPayment.paymentHash.toArray) - statement.setLong(2, oldReceivedPayment.amount.toLong) - statement.setLong(3, oldReceivedPayment.receivedAt) + statement.setBytes(1, paymentHash1.toArray) + statement.setLong(2, (123 msat).toLong) + statement.setLong(3, 1000) // received_at statement.executeUpdate() } val preMigrationDb = new SqlitePaymentsDb(connection) using(connection.createStatement()) { statement => - assert(getVersion(statement, "payments", 1) == 2) // version has changed from 1 to 2! + assert(getVersion(statement, "payments", 1) == 3) // version changed from 1 -> 3 } // the existing received payment can NOT be queried anymore - assert(preMigrationDb.getIncomingPayment(oldReceivedPayment.paymentHash).isEmpty) + assert(preMigrationDb.getIncomingPayment(paymentHash1).isEmpty) // add a few rows - val ps1 = OutgoingPayment(id = UUID.randomUUID(), paymentHash = ByteVector32(hex"0f059ef9b55bb70cc09069ee4df854bf0fab650eee6f2b87ba26d1ad08ab114f"), None, amount = 12345 msat, createdAt = 12345, None, PENDING) - val i1 = PaymentRequest.read("lnbc10u1pw2t4phpp5ezwm2gdccydhnphfyepklc0wjkxhz0r4tctg9paunh2lxgeqhcmsdqlxycrqvpqwdshgueqvfjhggr0dcsry7qcqzpgfa4ecv7447p9t5hkujy9qgrxvkkf396p9zar9p87rv2htmeuunkhydl40r64n5s2k0u7uelzc8twxmp37nkcch6m0wg5tvvx69yjz8qpk94qf3") - val pr1 = IncomingPayment(i1.paymentHash, 12345678 msat, 1513871928275L) + val ps1 = OutgoingPayment(UUID.randomUUID(), UUID.randomUUID(), None, paymentHash1, 12345 msat, alice, 1000, None, OutgoingPaymentStatus.Pending) + val i1 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(500 msat), paymentHash1, davePriv, "Some invoice", expirySeconds = None, timestamp = 1) + val pr1 = IncomingPayment(i1, preimage1, i1.timestamp.seconds.toMillis, IncomingPaymentStatus.Received(550 msat, 1100)) - preMigrationDb.addPaymentRequest(i1, ByteVector32.Zeroes) - preMigrationDb.addIncomingPayment(pr1) preMigrationDb.addOutgoingPayment(ps1) + preMigrationDb.addIncomingPayment(i1, preimage1) + preMigrationDb.receiveIncomingPayment(i1.paymentHash, 550 msat, 1100) - assert(preMigrationDb.listIncomingPayments() == Seq(pr1)) - assert(preMigrationDb.listOutgoingPayments() == Seq(ps1)) - assert(preMigrationDb.listPaymentRequests(0, (Platform.currentTime.milliseconds + 1.minute).toSeconds) == Seq(i1)) + assert(preMigrationDb.listIncomingPayments(1, 1500) === Seq(pr1)) + assert(preMigrationDb.listOutgoingPayments(1, 1500) === Seq(ps1)) val postMigrationDb = new SqlitePaymentsDb(connection) using(connection.createStatement()) { statement => - assert(getVersion(statement, "payments", 2) == 2) // version still to 2 + assert(getVersion(statement, "payments", 3) == 3) // version still to 3 } - assert(postMigrationDb.listIncomingPayments() == Seq(pr1)) - assert(postMigrationDb.listOutgoingPayments() == Seq(ps1)) - assert(preMigrationDb.listPaymentRequests(0, (Platform.currentTime.milliseconds + 1.minute).toSeconds) == Seq(i1)) + assert(postMigrationDb.listIncomingPayments(1, 1500) === Seq(pr1)) + assert(postMigrationDb.listOutgoingPayments(1, 1500) === Seq(ps1)) } - test("add/list received payments/find 1 payment that exists/find 1 payment that does not exist") { - val sqlite = TestConstants.sqliteInMemory() - val db = new SqlitePaymentsDb(sqlite) + test("handle version migration 2->3") { + val connection = TestConstants.sqliteInMemory() - // can't receive a payment without an invoice associated with it - assertThrows[IllegalArgumentException](db.addIncomingPayment(IncomingPayment(ByteVector32(hex"6e7e8018f05e169cf1d99e77dc22cb372d09f10b6a81f1eae410718c56cad188"), 12345678 msat, 1513871928275L))) - - val i1 = PaymentRequest.read("lnbc5450n1pw2t4qdpp5vcrf6ylgpettyng4ac3vujsk0zpc25cj0q3zp7l7w44zvxmpzh8qdzz2pshjmt9de6zqen0wgsr2dp4ypcxj7r9d3ejqct5ypekzar0wd5xjuewwpkxzcm99cxqzjccqp2rzjqtspxelp67qc5l56p6999wkatsexzhs826xmupyhk6j8lxl038t27z9tsqqqgpgqqqqqqqlgqqqqqzsqpcz8z8hmy8g3ecunle4n3edn3zg2rly8g4klsk5md736vaqqy3ktxs30ht34rkfkqaffzxmjphvd0637dk2lp6skah2hq09z6lrjna3xqp3d4vyd") - val i2 = PaymentRequest.read("lnbc10u1pw2t4phpp5ezwm2gdccydhnphfyepklc0wjkxhz0r4tctg9paunh2lxgeqhcmsdqlxycrqvpqwdshgueqvfjhggr0dcsry7qcqzpgfa4ecv7447p9t5hkujy9qgrxvkkf396p9zar9p87rv2htmeuunkhydl40r64n5s2k0u7uelzc8twxmp37nkcch6m0wg5tvvx69yjz8qpk94qf3") - - db.addPaymentRequest(i1, ByteVector32.Zeroes) - db.addPaymentRequest(i2, ByteVector32.Zeroes) - - val p1 = IncomingPayment(i1.paymentHash, 12345678 msat, 1513871928275L) - val p2 = IncomingPayment(i2.paymentHash, 12345678 msat, 1513871928275L) - assert(db.listIncomingPayments() === Nil) - db.addIncomingPayment(p1) - db.addIncomingPayment(p2) - assert(db.listIncomingPayments().toList === List(p1, p2)) - assert(db.getIncomingPayment(p1.paymentHash) === Some(p1)) - assert(db.getIncomingPayment(ByteVector32(hex"6e7e8018f05e169cf1d99e77dc22cb372d09f10b6a81f1eae410718c56cad187")) === None) - } + using(connection.createStatement()) { statement => + getVersion(statement, "payments", 2) + statement.executeUpdate("CREATE TABLE IF NOT EXISTS received_payments (payment_hash BLOB NOT NULL PRIMARY KEY, preimage BLOB NOT NULL, payment_request TEXT NOT NULL, received_msat INTEGER, created_at INTEGER NOT NULL, expire_at INTEGER, received_at INTEGER)") + statement.executeUpdate("CREATE TABLE IF NOT EXISTS sent_payments (id TEXT NOT NULL PRIMARY KEY, payment_hash BLOB NOT NULL, preimage BLOB, amount_msat INTEGER NOT NULL, created_at INTEGER NOT NULL, completed_at INTEGER, status VARCHAR NOT NULL)") + statement.executeUpdate("CREATE INDEX IF NOT EXISTS payment_hash_idx ON sent_payments(payment_hash)") + } - test("add/retrieve/update sent payments") { + using(connection.createStatement()) { statement => + assert(getVersion(statement, "payments", 2) == 2) // version 2 is deployed now + } - val db = new SqlitePaymentsDb(TestConstants.sqliteInMemory()) + // Insert a bunch of old version 2 rows. + val id1 = UUID.randomUUID() + val id2 = UUID.randomUUID() + val id3 = UUID.randomUUID() + val ps1 = OutgoingPayment(id1, id1, None, randomBytes32, 561 msat, PrivateKey(ByteVector32.One).publicKey, 1000, None, OutgoingPaymentStatus.Pending) + val ps2 = OutgoingPayment(id2, id2, None, randomBytes32, 1105 msat, PrivateKey(ByteVector32.One).publicKey, 1010, None, OutgoingPaymentStatus.Failed(Nil, 1050)) + val ps3 = OutgoingPayment(id3, id3, None, paymentHash1, 1729 msat, PrivateKey(ByteVector32.One).publicKey, 1040, None, OutgoingPaymentStatus.Succeeded(preimage1, 0 msat, Nil, 1060)) + val i1 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(12345678 msat), paymentHash1, davePriv, "Some invoice", expirySeconds = None, timestamp = 1) + val pr1 = IncomingPayment(i1, preimage1, i1.timestamp.seconds.toMillis, IncomingPaymentStatus.Received(12345678 msat, 1090)) + val i2 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(12345678 msat), paymentHash2, carolPriv, "Another invoice", expirySeconds = Some(30), timestamp = 1) + val pr2 = IncomingPayment(i2, preimage2, i2.timestamp.seconds.toMillis, IncomingPaymentStatus.Expired) + + // Changes between version 2 and 3 to sent_payments: + // - removed the status column + // - added optional payment failures + // - added optional payment success details (fees paid and route) + // - added optional payment request + // - added target node ID + // - added externalID and parentID + + using(connection.prepareStatement("INSERT INTO sent_payments (id, payment_hash, amount_msat, created_at, status) VALUES (?, ?, ?, ?, ?)")) { statement => + statement.setString(1, ps1.id.toString) + statement.setBytes(2, ps1.paymentHash.toArray) + statement.setLong(3, ps1.amount.toLong) + statement.setLong(4, ps1.createdAt) + statement.setString(5, "PENDING") + statement.executeUpdate() + } - val s1 = OutgoingPayment(id = UUID.randomUUID(), paymentHash = ByteVector32(hex"0f059ef9b55bb70cc09069ee4df854bf0fab650eee6f2b87ba26d1ad08ab114f"), None, amount = 12345 msat, createdAt = 12345, None, PENDING) - val s2 = OutgoingPayment(id = UUID.randomUUID(), paymentHash = ByteVector32(hex"08d47d5f7164d4b696e8f6b62a03094d4f1c65f16e9d7b11c4a98854707e55cf"), None, amount = 12345 msat, createdAt = 12345, None, PENDING) + using(connection.prepareStatement("INSERT INTO sent_payments (id, payment_hash, amount_msat, created_at, completed_at, status) VALUES (?, ?, ?, ?, ?, ?)")) { statement => + statement.setString(1, ps2.id.toString) + statement.setBytes(2, ps2.paymentHash.toArray) + statement.setLong(3, ps2.amount.toLong) + statement.setLong(4, ps2.createdAt) + statement.setLong(5, ps2.status.asInstanceOf[OutgoingPaymentStatus.Failed].completedAt) + statement.setString(6, "FAILED") + statement.executeUpdate() + } - assert(db.listOutgoingPayments().isEmpty) - db.addOutgoingPayment(s1) - db.addOutgoingPayment(s2) + using(connection.prepareStatement("INSERT INTO sent_payments (id, payment_hash, preimage, amount_msat, created_at, completed_at, status) VALUES (?, ?, ?, ?, ?, ?, ?)")) { statement => + statement.setString(1, ps3.id.toString) + statement.setBytes(2, ps3.paymentHash.toArray) + statement.setBytes(3, ps3.status.asInstanceOf[OutgoingPaymentStatus.Succeeded].paymentPreimage.toArray) + statement.setLong(4, ps3.amount.toLong) + statement.setLong(5, ps3.createdAt) + statement.setLong(6, ps3.status.asInstanceOf[OutgoingPaymentStatus.Succeeded].completedAt) + statement.setString(7, "SUCCEEDED") + statement.executeUpdate() + } - assert(db.listOutgoingPayments().toList == Seq(s1, s2)) - assert(db.getOutgoingPayment(s1.id) === Some(s1)) - assert(db.getOutgoingPayment(s1.id).get.completedAt.isEmpty) - assert(db.getOutgoingPayment(UUID.randomUUID()) === None) - assert(db.getOutgoingPayments(s2.paymentHash) === Seq(s2)) - assert(db.getOutgoingPayments(ByteVector32.Zeroes) === Seq.empty) + // Changes between version 2 and 3 to received_payments: + // - renamed the preimage column + // - made expire_at not null + + using(connection.prepareStatement("INSERT INTO received_payments (payment_hash, preimage, payment_request, received_msat, created_at, received_at) VALUES (?, ?, ?, ?, ?, ?)")) { statement => + statement.setBytes(1, i1.paymentHash.toArray) + statement.setBytes(2, pr1.paymentPreimage.toArray) + statement.setString(3, PaymentRequest.write(i1)) + statement.setLong(4, pr1.status.asInstanceOf[IncomingPaymentStatus.Received].amount.toLong) + statement.setLong(5, pr1.createdAt) + statement.setLong(6, pr1.status.asInstanceOf[IncomingPaymentStatus.Received].receivedAt) + statement.executeUpdate() + } - val s3 = s2.copy(id = UUID.randomUUID(), amount = 88776655 msat) - db.addOutgoingPayment(s3) + using(connection.prepareStatement("INSERT INTO received_payments (payment_hash, preimage, payment_request, created_at, expire_at) VALUES (?, ?, ?, ?, ?)")) { statement => + statement.setBytes(1, i2.paymentHash.toArray) + statement.setBytes(2, pr2.paymentPreimage.toArray) + statement.setString(3, PaymentRequest.write(i2)) + statement.setLong(4, pr2.createdAt) + statement.setLong(5, (i2.timestamp + i2.expiry.get).seconds.toMillis) + statement.executeUpdate() + } - db.updateOutgoingPayment(s3.id, FAILED) - assert(db.getOutgoingPayment(s3.id).get.status == FAILED) - assert(db.getOutgoingPayment(s3.id).get.preimage.isEmpty) // failed sent payments don't have a preimage - assert(db.getOutgoingPayment(s3.id).get.completedAt.isDefined) + val preMigrationDb = new SqlitePaymentsDb(connection) - // can't update again once it's in a final state - assertThrows[IllegalArgumentException](db.updateOutgoingPayment(s3.id, SUCCEEDED)) + using(connection.createStatement()) { statement => + assert(getVersion(statement, "payments", 2) == 3) // version changed from 2 -> 3 + } + + assert(preMigrationDb.getIncomingPayment(i1.paymentHash) === Some(pr1)) + assert(preMigrationDb.getIncomingPayment(i2.paymentHash) === Some(pr2)) + assert(preMigrationDb.listOutgoingPayments(1, 2000) === Seq(ps1, ps2, ps3)) - db.updateOutgoingPayment(s1.id, SUCCEEDED, Some(ByteVector32.One)) - assert(db.getOutgoingPayment(s1.id).get.preimage.isDefined) - assert(db.getOutgoingPayment(s1.id).get.completedAt.isDefined) + val postMigrationDb = new SqlitePaymentsDb(connection) + + using(connection.createStatement()) { statement => + assert(getVersion(statement, "payments", 3) == 3) // version still to 3 + } + + val i3 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(561 msat), paymentHash3, alicePriv, "invoice #3", expirySeconds = Some(30)) + val pr3 = IncomingPayment(i3, preimage3, i3.timestamp.seconds.toMillis, IncomingPaymentStatus.Pending) + postMigrationDb.addIncomingPayment(i3, pr3.paymentPreimage) + + val ps4 = OutgoingPayment(UUID.randomUUID(), UUID.randomUUID(), Some("1"), randomBytes32, 123 msat, alice, 1100, Some(i3), OutgoingPaymentStatus.Pending) + val ps5 = OutgoingPayment(UUID.randomUUID(), UUID.randomUUID(), Some("2"), randomBytes32, 456 msat, bob, 1150, Some(i2), OutgoingPaymentStatus.Succeeded(preimage1, 42 msat, Nil, 1180)) + val ps6 = OutgoingPayment(UUID.randomUUID(), UUID.randomUUID(), Some("3"), randomBytes32, 789 msat, bob, 1250, None, OutgoingPaymentStatus.Failed(Nil, 1300)) + postMigrationDb.addOutgoingPayment(ps4) + postMigrationDb.addOutgoingPayment(ps5.copy(status = OutgoingPaymentStatus.Pending)) + postMigrationDb.updateOutgoingPayment(PaymentSent(ps5.parentId, ps5.paymentHash, preimage1, Seq(PaymentSent.PartialPayment(ps5.id, ps5.amount, 42 msat, randomBytes32, None, 1180)))) + postMigrationDb.addOutgoingPayment(ps6.copy(status = OutgoingPaymentStatus.Pending)) + postMigrationDb.updateOutgoingPayment(PaymentFailed(ps6.id, ps6.paymentHash, Nil, 1300)) + + assert(postMigrationDb.listOutgoingPayments(1, 2000) === Seq(ps1, ps2, ps3, ps4, ps5, ps6)) + assert(postMigrationDb.listIncomingPayments(1, Platform.currentTime) === Seq(pr1, pr2, pr3)) + assert(postMigrationDb.listExpiredIncomingPayments(1, 2000) === Seq(pr2)) } - test("add/retrieve payment requests") { + test("add/retrieve/update incoming payments") { + val sqlite = TestConstants.sqliteInMemory() + val db = new SqlitePaymentsDb(sqlite) + + // can't receive a payment without an invoice associated with it + assertThrows[IllegalArgumentException](db.receiveIncomingPayment(randomBytes32, 12345678 msat)) + + val expiredInvoice1 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(561 msat), randomBytes32, alicePriv, "invoice #1", timestamp = 1) + val expiredInvoice2 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(1105 msat), randomBytes32, bobPriv, "invoice #2", timestamp = 2, expirySeconds = Some(30)) + val expiredPayment1 = IncomingPayment(expiredInvoice1, randomBytes32, expiredInvoice1.timestamp.seconds.toMillis, IncomingPaymentStatus.Expired) + val expiredPayment2 = IncomingPayment(expiredInvoice2, randomBytes32, expiredInvoice2.timestamp.seconds.toMillis, IncomingPaymentStatus.Expired) + + val pendingInvoice1 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(561 msat), randomBytes32, alicePriv, "invoice #3") + val pendingInvoice2 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(1105 msat), randomBytes32, bobPriv, "invoice #4", expirySeconds = Some(30)) + val pendingPayment1 = IncomingPayment(pendingInvoice1, randomBytes32, pendingInvoice1.timestamp.seconds.toMillis, IncomingPaymentStatus.Pending) + val pendingPayment2 = IncomingPayment(pendingInvoice2, randomBytes32, pendingInvoice2.timestamp.seconds.toMillis, IncomingPaymentStatus.Pending) + + val paidInvoice1 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(561 msat), randomBytes32, alicePriv, "invoice #5") + val paidInvoice2 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(1105 msat), randomBytes32, bobPriv, "invoice #6", expirySeconds = Some(60)) + val receivedAt1 = Platform.currentTime + 1 + val payment1 = IncomingPayment(paidInvoice1, randomBytes32, paidInvoice1.timestamp.seconds.toMillis, IncomingPaymentStatus.Received(561 msat, receivedAt1)) + val receivedAt2 = Platform.currentTime + 2 + val payment2 = IncomingPayment(paidInvoice2, randomBytes32, paidInvoice2.timestamp.seconds.toMillis, IncomingPaymentStatus.Received(1111 msat, receivedAt2)) + + db.addIncomingPayment(pendingInvoice1, pendingPayment1.paymentPreimage) + db.addIncomingPayment(pendingInvoice2, pendingPayment2.paymentPreimage) + db.addIncomingPayment(expiredInvoice1, expiredPayment1.paymentPreimage) + db.addIncomingPayment(expiredInvoice2, expiredPayment2.paymentPreimage) + db.addIncomingPayment(paidInvoice1, payment1.paymentPreimage) + db.addIncomingPayment(paidInvoice2, payment2.paymentPreimage) + + assert(db.getIncomingPayment(pendingInvoice1.paymentHash) === Some(pendingPayment1)) + assert(db.getIncomingPayment(expiredInvoice2.paymentHash) === Some(expiredPayment2)) + assert(db.getIncomingPayment(paidInvoice1.paymentHash) === Some(payment1.copy(status = IncomingPaymentStatus.Pending))) + + val now = Platform.currentTime + assert(db.listIncomingPayments(0, now) === Seq(expiredPayment1, expiredPayment2, pendingPayment1, pendingPayment2, payment1.copy(status = IncomingPaymentStatus.Pending), payment2.copy(status = IncomingPaymentStatus.Pending))) + assert(db.listExpiredIncomingPayments(0, now) === Seq(expiredPayment1, expiredPayment2)) + assert(db.listReceivedIncomingPayments(0, now) === Nil) + assert(db.listPendingIncomingPayments(0, now) === Seq(pendingPayment1, pendingPayment2, payment1.copy(status = IncomingPaymentStatus.Pending), payment2.copy(status = IncomingPaymentStatus.Pending))) + + db.receiveIncomingPayment(paidInvoice1.paymentHash, 561 msat, receivedAt1) + db.receiveIncomingPayment(paidInvoice2.paymentHash, 1111 msat, receivedAt2) + + assert(db.getIncomingPayment(paidInvoice1.paymentHash) === Some(payment1)) + assert(db.listIncomingPayments(0, now) === Seq(expiredPayment1, expiredPayment2, pendingPayment1, pendingPayment2, payment1, payment2)) + assert(db.listIncomingPayments(now - 60.seconds.toMillis, now) === Seq(pendingPayment1, pendingPayment2, payment1, payment2)) + assert(db.listPendingIncomingPayments(0, now) === Seq(pendingPayment1, pendingPayment2)) + assert(db.listReceivedIncomingPayments(0, now) === Seq(payment1, payment2)) + } - val someTimestamp = 12345 + test("add/retrieve/update outgoing payments") { val db = new SqlitePaymentsDb(TestConstants.sqliteInMemory()) - val bob = Bob.keyManager + val parentId = UUID.randomUUID() + val i1 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(123 msat), paymentHash1, davePriv, "Some invoice", expirySeconds = None, timestamp = 0) + val s1 = OutgoingPayment(UUID.randomUUID(), parentId, None, paymentHash1, 123 msat, alice, 100, Some(i1), OutgoingPaymentStatus.Pending) + val s2 = OutgoingPayment(UUID.randomUUID(), parentId, Some("1"), paymentHash1, 456 msat, bob, 200, None, OutgoingPaymentStatus.Pending) - val (paymentHash1, paymentHash2) = (randomBytes32, randomBytes32) + assert(db.listOutgoingPayments(0, Platform.currentTime).isEmpty) + db.addOutgoingPayment(s1) + db.addOutgoingPayment(s2) - val i1 = PaymentRequest(chainHash = Block.TestnetGenesisBlock.hash, amount = Some(123 msat), paymentHash = paymentHash1, privateKey = bob.nodeKey.privateKey, description = "Some invoice", expirySeconds = None, timestamp = someTimestamp) - val i2 = PaymentRequest(chainHash = Block.TestnetGenesisBlock.hash, amount = None, paymentHash = paymentHash2, privateKey = bob.nodeKey.privateKey, description = "Some invoice", expirySeconds = Some(123456), timestamp = Platform.currentTime.milliseconds.toSeconds) + // can't add an outgoing payment in non-pending state + assertThrows[IllegalArgumentException](db.addOutgoingPayment(s1.copy(status = OutgoingPaymentStatus.Succeeded(randomBytes32, 0 msat, Nil, 110)))) - // i2 doesn't expire - assert(i1.expiry.isEmpty && i2.expiry.isDefined) - assert(i1.amount.isDefined && i2.amount.isEmpty) + assert(db.listOutgoingPayments(1, 300).toList == Seq(s1, s2)) + assert(db.listOutgoingPayments(1, 150).toList == Seq(s1)) + assert(db.listOutgoingPayments(150, 250).toList == Seq(s2)) + assert(db.getOutgoingPayment(s1.id) === Some(s1)) + assert(db.getOutgoingPayment(UUID.randomUUID()) === None) + assert(db.listOutgoingPayments(s2.paymentHash) === Seq(s1, s2)) + assert(db.listOutgoingPayments(s1.id) === Nil) + assert(db.listOutgoingPayments(parentId) === Seq(s1, s2)) + assert(db.listOutgoingPayments(ByteVector32.Zeroes) === Nil) - db.addPaymentRequest(i1, ByteVector32.Zeroes) - db.addPaymentRequest(i2, ByteVector32.One) + val s3 = s2.copy(id = UUID.randomUUID(), amount = 789 msat, createdAt = 300) + val s4 = s2.copy(id = UUID.randomUUID(), createdAt = 300) + db.addOutgoingPayment(s3) + db.addOutgoingPayment(s4) - // order matters, i2 has a more recent timestamp than i1 - assert(db.listPaymentRequests(0, (Platform.currentTime.milliseconds + 1.minute).toSeconds) == Seq(i2, i1)) - assert(db.getPaymentRequest(i1.paymentHash) == Some(i1)) - assert(db.getPaymentRequest(i2.paymentHash) == Some(i2)) + db.updateOutgoingPayment(PaymentFailed(s3.id, s3.paymentHash, Nil, 310)) + val ss3 = s3.copy(status = OutgoingPaymentStatus.Failed(Nil, 310)) + assert(db.getOutgoingPayment(s3.id) === Some(ss3)) + db.updateOutgoingPayment(PaymentFailed(s4.id, s4.paymentHash, Seq(LocalFailure(new RuntimeException("woops")), RemoteFailure(Seq(hop_ab, hop_bc), Sphinx.DecryptedFailurePacket(carol, UnknownNextPeer))), 320)) + val ss4 = s4.copy(status = OutgoingPaymentStatus.Failed(Seq(FailureSummary(FailureType.LOCAL, "woops", Nil), FailureSummary(FailureType.REMOTE, "processing node does not know the next peer in the route", List(HopSummary(alice, bob, Some(ShortChannelId(42))), HopSummary(bob, carol, Some(ShortChannelId(43)))))), 320)) + assert(db.getOutgoingPayment(s4.id) === Some(ss4)) - assert(db.listPendingPaymentRequests(0, (Platform.currentTime.milliseconds + 1.minute).toSeconds) == Seq(i2, i1)) - assert(db.getPendingPaymentRequestAndPreimage(paymentHash1) == Some((ByteVector32.Zeroes, i1))) - assert(db.getPendingPaymentRequestAndPreimage(paymentHash2) == Some((ByteVector32.One, i2))) + // can't update again once it's in a final state + assertThrows[IllegalArgumentException](db.updateOutgoingPayment(PaymentSent(parentId, s3.paymentHash, preimage1, Seq(PaymentSent.PartialPayment(s3.id, s3.amount, 42 msat, randomBytes32, None))))) + + val paymentSent = PaymentSent(parentId, paymentHash1, preimage1, Seq( + PaymentSent.PartialPayment(s1.id, s1.amount, 15 msat, randomBytes32, None, 400), + PaymentSent.PartialPayment(s2.id, s2.amount, 20 msat, randomBytes32, Some(Seq(hop_ab, hop_bc)), 410) + )) + val ss1 = s1.copy(status = OutgoingPaymentStatus.Succeeded(preimage1, 15 msat, Nil, 400)) + val ss2 = s2.copy(status = OutgoingPaymentStatus.Succeeded(preimage1, 20 msat, Seq(HopSummary(alice, bob, Some(ShortChannelId(42))), HopSummary(bob, carol, Some(ShortChannelId(43)))), 410)) + db.updateOutgoingPayment(paymentSent) + assert(db.getOutgoingPayment(s1.id) === Some(ss1)) + assert(db.getOutgoingPayment(s2.id) === Some(ss2)) + assert(db.listOutgoingPayments(parentId) === Seq(ss1, ss2, ss3, ss4)) - val from = (someTimestamp - 100).seconds.toSeconds - val to = (someTimestamp + 100).seconds.toSeconds - assert(db.listPaymentRequests(from, to) == Seq(i1)) + // can't update again once it's in a final state + assertThrows[IllegalArgumentException](db.updateOutgoingPayment(PaymentFailed(s1.id, s1.paymentHash, Nil))) } } + +object SqlitePaymentsDbSpec { + val (alicePriv, bobPriv, carolPriv, davePriv) = (randomKey, randomKey, randomKey, randomKey) + val (alice, bob, carol, dave) = (alicePriv.publicKey, bobPriv.publicKey, carolPriv.publicKey, davePriv.publicKey) + val hop_ab = Hop(alice, bob, ChannelUpdate(randomBytes64, randomBytes32, ShortChannelId(42), 1, 0, 0, CltvExpiryDelta(12), 1 msat, 1 msat, 1, None)) + val hop_bc = Hop(bob, carol, ChannelUpdate(randomBytes64, randomBytes32, ShortChannelId(43), 1, 0, 0, CltvExpiryDelta(12), 1 msat, 1 msat, 1, None)) + val (preimage1, preimage2, preimage3, preimage4) = (randomBytes32, randomBytes32, randomBytes32, randomBytes32) + val (paymentHash1, paymentHash2, paymentHash3, paymentHash4) = (Crypto.sha256(preimage1), Crypto.sha256(preimage2), Crypto.sha256(preimage3), Crypto.sha256(preimage4)) +} diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteUtilsSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteUtilsSpec.scala new file mode 100644 index 0000000000..8bafb8fcb1 --- /dev/null +++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteUtilsSpec.scala @@ -0,0 +1,77 @@ +/* + * Copyright 2019 ACINQ SAS + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package fr.acinq.eclair.db + +import fr.acinq.eclair.TestConstants +import fr.acinq.eclair.db.sqlite.SqliteUtils.using +import org.scalatest.FunSuite +import org.sqlite.SQLiteException + +class SqliteUtilsSpec extends FunSuite { + + test("using with auto-commit disabled") { + val conn = TestConstants.sqliteInMemory() + + using(conn.createStatement()) { statement => + statement.executeUpdate("CREATE TABLE utils_test (id INTEGER NOT NULL PRIMARY KEY, updated_at INTEGER)") + statement.executeUpdate("INSERT INTO utils_test VALUES (1, 1)") + statement.executeUpdate("INSERT INTO utils_test VALUES (2, 2)") + } + + using(conn.createStatement()) { statement => + val results = statement.executeQuery("SELECT * FROM utils_test ORDER BY id") + assert(results.next()) + assert(results.getLong("id") === 1) + assert(results.next()) + assert(results.getLong("id") === 2) + assert(!results.next()) + } + + assertThrows[SQLiteException](using(conn.createStatement(), inTransaction = true) { statement => + statement.executeUpdate("INSERT INTO utils_test VALUES (3, 3)") + statement.executeUpdate("INSERT INTO utils_test VALUES (1, 3)") // should throw (primary key violation) + }) + + using(conn.createStatement()) { statement => + val results = statement.executeQuery("SELECT * FROM utils_test ORDER BY id") + assert(results.next()) + assert(results.getLong("id") === 1) + assert(results.next()) + assert(results.getLong("id") === 2) + assert(!results.next()) + } + + using(conn.createStatement(), inTransaction = true) { statement => + statement.executeUpdate("INSERT INTO utils_test VALUES (3, 3)") + statement.executeUpdate("INSERT INTO utils_test VALUES (4, 4)") + } + + using(conn.createStatement()) { statement => + val results = statement.executeQuery("SELECT * FROM utils_test ORDER BY id") + assert(results.next()) + assert(results.getLong("id") === 1) + assert(results.next()) + assert(results.getLong("id") === 2) + assert(results.next()) + assert(results.getLong("id") === 3) + assert(results.next()) + assert(results.getLong("id") === 4) + assert(!results.next()) + } + } + +} diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/integration/IntegrationSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/integration/IntegrationSpec.scala index 92dc0bae4f..4e9fe9ffed 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/integration/IntegrationSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/integration/IntegrationSpec.scala @@ -36,7 +36,7 @@ import fr.acinq.eclair.io.Peer import fr.acinq.eclair.io.Peer.{Disconnect, PeerRoutingMessage} import fr.acinq.eclair.payment.PaymentInitiator.SendPaymentRequest import fr.acinq.eclair.payment.PaymentLifecycle.{State => _, _} -import fr.acinq.eclair.payment.{LocalPaymentHandler, PaymentRequest} +import fr.acinq.eclair.payment._ import fr.acinq.eclair.router.Graph.WeightRatios import fr.acinq.eclair.router.Router.ROUTE_MAX_LENGTH import fr.acinq.eclair.router.{Announcements, AnnouncementsBatchValidationSpec, PublicChannel, RouteParams} @@ -51,11 +51,9 @@ import org.scalatest.{BeforeAndAfterAll, FunSuiteLike} import scodec.bits.ByteVector import scala.collection.JavaConversions._ -import scala.compat.Platform import scala.concurrent.Await import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.duration._ -import scala.util.Try /** * Created by PM on 15/03/2017. @@ -267,7 +265,7 @@ class IntegrationSpec extends TestKit(ActorSystem("test")) with BitcoindService // then we make the actual payment sender.send(nodes("A").paymentInitiator, SendPaymentRequest(amountMsat, pr.paymentHash, nodes("D").nodeParams.nodeId, routeParams = integrationTestRouteParams, maxAttempts = 1)) val paymentId = sender.expectMsgType[UUID](5 seconds) - val ps = sender.expectMsgType[PaymentSucceeded](5 seconds) + val ps = sender.expectMsgType[PaymentSent](5 seconds) assert(ps.id == paymentId) } @@ -293,7 +291,7 @@ class IntegrationSpec extends TestKit(ActorSystem("test")) with BitcoindService sender.send(nodes("A").paymentInitiator, sendReq) // A will receive an error from B that include the updated channel update, then will retry the payment val paymentId = sender.expectMsgType[UUID](5 seconds) - val ps = sender.expectMsgType[PaymentSucceeded](5 seconds) + val ps = sender.expectMsgType[PaymentSent](5 seconds) assert(ps.id == paymentId) def updateFor(n: PublicKey, pc: PublicChannel): Option[ChannelUpdate] = if (n == pc.ann.nodeId1) pc.update_1_opt else if (n == pc.ann.nodeId2) pc.update_2_opt else throw new IllegalArgumentException("this node is unrelated to this channel") @@ -333,7 +331,7 @@ class IntegrationSpec extends TestKit(ActorSystem("test")) with BitcoindService sender.send(nodes("A").paymentInitiator, sendReq) // A will first receive an error from C, then retry and route around C: A->B->E->C->D sender.expectMsgType[UUID](5 seconds) - sender.expectMsgType[PaymentSucceeded] // the payment FSM will also reply to the sender after the payment is completed + sender.expectMsgType[PaymentSent] // the payment FSM will also reply to the sender after the payment is completed } test("send an HTLC A->D with an unknown payment hash") { @@ -414,7 +412,7 @@ class IntegrationSpec extends TestKit(ActorSystem("test")) with BitcoindService val sendReq = SendPaymentRequest(amountMsat, pr.paymentHash, nodes("D").nodeParams.nodeId, routeParams = integrationTestRouteParams, maxAttempts = 5) sender.send(nodes("A").paymentInitiator, sendReq) sender.expectMsgType[UUID] - sender.expectMsgType[PaymentSucceeded] // the payment FSM will also reply to the sender after the payment is completed + sender.expectMsgType[PaymentSent] // the payment FSM will also reply to the sender after the payment is completed } } @@ -430,19 +428,16 @@ class IntegrationSpec extends TestKit(ActorSystem("test")) with BitcoindService sender.expectMsgType[UUID](max = 60 seconds) awaitCond({ - sender.expectMsgType[PaymentResult](10 seconds) match { - case PaymentFailed(_, _, failures) => failures == Seq.empty // if something went wrong fail with a hint - case PaymentSucceeded(_, _, _, _, route) => route.exists(_.nodeId == nodes("G").nodeParams.nodeId) + sender.expectMsgType[PaymentEvent](10 seconds) match { + case PaymentFailed(_, _, failures, _) => failures == Seq.empty // if something went wrong fail with a hint + case PaymentSent(_, _, _, part :: Nil) => part.route.get.exists(_.nodeId == nodes("G").nodeParams.nodeId) + case _ => false } }, max = 30 seconds, interval = 10 seconds) } - /** * We currently use p2pkh script Helpers.getFinalScriptPubKey - * - * @param scriptPubKey - * @return */ def scriptPubKeyToAddress(scriptPubKey: ByteVector) = Script.parse(scriptPubKey) match { case OP_DUP :: OP_HASH160 :: OP_PUSHDATA(pubKeyHash, _) :: OP_EQUALVERIFY :: OP_CHECKSIG :: Nil => @@ -508,7 +503,7 @@ class IntegrationSpec extends TestKit(ActorSystem("test")) with BitcoindService sender.send(bitcoincli, BitcoinReq("generate", 1)) sender.expectMsgType[JValue](10 seconds) // C will extract the preimage from the blockchain and fulfill the payment upstream - paymentSender.expectMsgType[PaymentSucceeded](30 seconds) + paymentSender.expectMsgType[PaymentSent](30 seconds) // at this point F should have 1 recv transactions: the redeemed htlc awaitCond({ sender.send(bitcoincli, BitcoinReq("listreceivedbyaddress", 0)) @@ -590,7 +585,7 @@ class IntegrationSpec extends TestKit(ActorSystem("test")) with BitcoindService sender.send(bitcoincli, BitcoinReq("generate", 1)) sender.expectMsgType[JValue](10 seconds) // C will extract the preimage from the blockchain and fulfill the payment upstream - paymentSender.expectMsgType[PaymentSucceeded](30 seconds) + paymentSender.expectMsgType[PaymentSent](30 seconds) // at this point F should have 1 recv transactions: the redeemed htlc // we then generate enough blocks so that F gets its htlc-success delayed output sender.send(bitcoincli, BitcoinReq("generate", 145)) @@ -773,7 +768,7 @@ class IntegrationSpec extends TestKit(ActorSystem("test")) with BitcoindService forwardHandlerF.forward(paymentHandlerF) sigListener.expectMsgType[ChannelSignatureReceived] sigListener.expectMsgType[ChannelSignatureReceived] - sender.expectMsgType[PaymentSucceeded].id === paymentId + sender.expectMsgType[PaymentSent].id === paymentId // we now send a few htlcs C->F and F->C in order to obtain a commitments with multiple htlcs def send(amountMsat: MilliSatoshi, paymentHandler: ActorRef, paymentInitiator: ActorRef) = { @@ -813,19 +808,19 @@ class IntegrationSpec extends TestKit(ActorSystem("test")) with BitcoindService buffer.expectMsgType[UpdateAddHtlc] buffer.forward(paymentHandlerF) sigListener.expectMsgType[ChannelSignatureReceived] - val preimage1 = sender.expectMsgType[PaymentSucceeded].paymentPreimage + val preimage1 = sender.expectMsgType[PaymentSent].paymentPreimage buffer.expectMsgType[UpdateAddHtlc] buffer.forward(paymentHandlerF) sigListener.expectMsgType[ChannelSignatureReceived] - sender.expectMsgType[PaymentSucceeded].paymentPreimage + sender.expectMsgType[PaymentSent].paymentPreimage buffer.expectMsgType[UpdateAddHtlc] buffer.forward(paymentHandlerC) sigListener.expectMsgType[ChannelSignatureReceived] - sender.expectMsgType[PaymentSucceeded].paymentPreimage + sender.expectMsgType[PaymentSent].paymentPreimage buffer.expectMsgType[UpdateAddHtlc] buffer.forward(paymentHandlerC) sigListener.expectMsgType[ChannelSignatureReceived] - sender.expectMsgType[PaymentSucceeded].paymentPreimage + sender.expectMsgType[PaymentSent].paymentPreimage // this also allows us to get the channel id val channelId = commitmentsF.channelId // we also retrieve C's default final address diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentHandlerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentHandlerSpec.scala index b142e521f5..a66733a12b 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentHandlerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentHandlerSpec.scala @@ -19,10 +19,12 @@ package fr.acinq.eclair.payment import akka.actor.ActorSystem import akka.actor.Status.Failure import akka.testkit.{TestActorRef, TestKit, TestProbe} -import fr.acinq.bitcoin.ByteVector32 +import fr.acinq.bitcoin.{ByteVector32, Crypto} import fr.acinq.eclair.TestConstants.Alice import fr.acinq.eclair.channel.{CMD_FAIL_HTLC, CMD_FULFILL_HTLC} +import fr.acinq.eclair.db.IncomingPaymentStatus import fr.acinq.eclair.payment.PaymentLifecycle.ReceivePayment +import fr.acinq.eclair.payment.PaymentReceived.PartialPayment import fr.acinq.eclair.payment.PaymentRequest.ExtraHop import fr.acinq.eclair.wire.{IncorrectOrUnknownPaymentDetails, UpdateAddHtlc} import fr.acinq.eclair.{CltvExpiryDelta, LongToBtcAmount, ShortChannelId, TestConstants, randomKey} @@ -50,42 +52,48 @@ class PaymentHandlerSpec extends TestKit(ActorSystem("test")) with FunSuiteLike { sender.send(handler, ReceivePayment(Some(amountMsat), "1 coffee")) val pr = sender.expectMsgType[PaymentRequest] - assert(nodeParams.db.payments.getIncomingPayment(pr.paymentHash).isEmpty) - assert(nodeParams.db.payments.getPendingPaymentRequestAndPreimage(pr.paymentHash).isDefined) - assert(!nodeParams.db.payments.getPendingPaymentRequestAndPreimage(pr.paymentHash).get._2.isExpired) + val incoming = nodeParams.db.payments.getIncomingPayment(pr.paymentHash) + assert(incoming.isDefined) + assert(incoming.get.status === IncomingPaymentStatus.Pending) + assert(!incoming.get.paymentRequest.isExpired) + assert(Crypto.sha256(incoming.get.paymentPreimage) === pr.paymentHash) val add = UpdateAddHtlc(ByteVector32(ByteVector.fill(32)(1)), 0, amountMsat, pr.paymentHash, expiry, TestConstants.emptyOnionPacket) sender.send(handler, add) sender.expectMsgType[CMD_FULFILL_HTLC] val paymentRelayed = eventListener.expectMsgType[PaymentReceived] - assert(paymentRelayed.copy(timestamp = 0) === PaymentReceived(amountMsat, add.paymentHash, add.channelId, timestamp = 0)) - assert(nodeParams.db.payments.getIncomingPayment(pr.paymentHash).exists(_.paymentHash == pr.paymentHash)) + assert(paymentRelayed.copy(parts = paymentRelayed.parts.map(_.copy(timestamp = 0))) === PaymentReceived(add.paymentHash, PartialPayment(amountMsat, add.channelId, timestamp = 0) :: Nil)) + val received = nodeParams.db.payments.getIncomingPayment(pr.paymentHash) + assert(received.isDefined && received.get.status.isInstanceOf[IncomingPaymentStatus.Received]) + assert(received.get.status.asInstanceOf[IncomingPaymentStatus.Received].copy(receivedAt = 0) === IncomingPaymentStatus.Received(amountMsat, 0)) } { sender.send(handler, ReceivePayment(Some(amountMsat), "another coffee")) val pr = sender.expectMsgType[PaymentRequest] - assert(nodeParams.db.payments.getIncomingPayment(pr.paymentHash).isEmpty) + assert(nodeParams.db.payments.getIncomingPayment(pr.paymentHash).get.status === IncomingPaymentStatus.Pending) val add = UpdateAddHtlc(ByteVector32(ByteVector.fill(32)(1)), 0, amountMsat, pr.paymentHash, expiry, TestConstants.emptyOnionPacket) sender.send(handler, add) sender.expectMsgType[CMD_FULFILL_HTLC] val paymentRelayed = eventListener.expectMsgType[PaymentReceived] - assert(paymentRelayed.copy(timestamp = 0) === PaymentReceived(amountMsat, add.paymentHash, add.channelId, timestamp = 0)) - assert(nodeParams.db.payments.getIncomingPayment(pr.paymentHash).exists(_.paymentHash == pr.paymentHash)) + assert(paymentRelayed.copy(parts = paymentRelayed.parts.map(_.copy(timestamp = 0))) === PaymentReceived(add.paymentHash, PartialPayment(amountMsat, add.channelId, timestamp = 0) :: Nil)) + val received = nodeParams.db.payments.getIncomingPayment(pr.paymentHash) + assert(received.isDefined && received.get.status.isInstanceOf[IncomingPaymentStatus.Received]) + assert(received.get.status.asInstanceOf[IncomingPaymentStatus.Received].copy(receivedAt = 0) === IncomingPaymentStatus.Received(amountMsat, 0)) } { sender.send(handler, ReceivePayment(Some(amountMsat), "bad expiry")) val pr = sender.expectMsgType[PaymentRequest] - assert(nodeParams.db.payments.getIncomingPayment(pr.paymentHash).isEmpty) + assert(nodeParams.db.payments.getIncomingPayment(pr.paymentHash).get.status === IncomingPaymentStatus.Pending) val add = UpdateAddHtlc(ByteVector32(ByteVector.fill(32)(1)), 0, amountMsat, pr.paymentHash, cltvExpiry = CltvExpiryDelta(3).toCltvExpiry(nodeParams.currentBlockHeight), TestConstants.emptyOnionPacket) sender.send(handler, add) assert(sender.expectMsgType[CMD_FAIL_HTLC].reason == Right(IncorrectOrUnknownPaymentDetails(amountMsat, nodeParams.currentBlockHeight))) eventListener.expectNoMsg(300 milliseconds) - assert(nodeParams.db.payments.getIncomingPayment(pr.paymentHash).isEmpty) + assert(nodeParams.db.payments.getIncomingPayment(pr.paymentHash).get.status === IncomingPaymentStatus.Pending) } } @@ -168,6 +176,7 @@ class PaymentHandlerSpec extends TestKit(ActorSystem("test")) with FunSuiteLike sender.send(handler, add) sender.expectMsgType[CMD_FAIL_HTLC] - assert(nodeParams.db.payments.getIncomingPayment(pr.paymentHash).isEmpty) + val Some(incoming) = nodeParams.db.payments.getIncomingPayment(pr.paymentHash) + assert(incoming.paymentRequest.isExpired && incoming.status === IncomingPaymentStatus.Expired) } } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala index 7dcd6417e0..66ee7fecc4 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala @@ -22,15 +22,17 @@ import akka.actor.FSM.{CurrentState, SubscribeTransitionCallBack, Transition} import akka.actor.Status import akka.testkit.{TestFSMRef, TestProbe} import fr.acinq.bitcoin.Script.{pay2wsh, write} -import fr.acinq.bitcoin.{Block, ByteVector32, Transaction, TxOut} +import fr.acinq.bitcoin.{Block, ByteVector32, Crypto, Transaction, TxOut} import fr.acinq.eclair._ import fr.acinq.eclair.blockchain.{UtxoStatus, ValidateRequest, ValidateResult, WatchSpentBasic} import fr.acinq.eclair.channel.Register.ForwardShortId import fr.acinq.eclair.channel.{AddHtlcFailed, Channel, ChannelUnavailable} import fr.acinq.eclair.crypto.Sphinx -import fr.acinq.eclair.db.OutgoingPaymentStatus +import fr.acinq.eclair.db.{OutgoingPayment, OutgoingPaymentStatus} import fr.acinq.eclair.io.Peer.PeerRoutingMessage +import fr.acinq.eclair.payment.PaymentInitiator.SendPaymentRequest import fr.acinq.eclair.payment.PaymentLifecycle._ +import fr.acinq.eclair.payment.PaymentSent.PartialPayment import fr.acinq.eclair.router.Announcements.{makeChannelUpdate, makeNodeAnnouncement} import fr.acinq.eclair.router._ import fr.acinq.eclair.transactions.Scripts @@ -46,14 +48,17 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val defaultAmountMsat = 142000000 msat val defaultExpiryDelta = Channel.MIN_CLTV_EXPIRY_DELTA + val defaultPaymentHash = randomBytes32 + val defaultExternalId = UUID.randomUUID().toString + val defaultPaymentRequest = SendPaymentRequest(defaultAmountMsat, defaultPaymentHash, d, 1, externalId = Some(defaultExternalId)) test("send to route") { fixture => import fixture._ - val defaultPaymentHash = randomBytes32 val nodeParams = TestConstants.Alice.nodeParams.copy(keyManager = testKeyManager) - val paymentDb = nodeParams.db.payments val id = UUID.randomUUID() - val paymentFSM = system.actorOf(PaymentLifecycle.props(nodeParams, id, router, TestProbe().ref)) + val paymentDb = nodeParams.db.payments + val progressHandler = PaymentLifecycle.DefaultPaymentProgressHandler(id, defaultPaymentRequest, paymentDb) + val paymentFSM = system.actorOf(PaymentLifecycle.props(nodeParams, progressHandler, router, TestProbe().ref)) val monitor = TestProbe() val sender = TestProbe() val eventListener = TestProbe() @@ -68,21 +73,23 @@ class PaymentLifecycleSpec extends BaseRouterSpec { sender.send(paymentFSM, request) val Transition(_, WAITING_FOR_REQUEST, WAITING_FOR_ROUTE) = monitor.expectMsgClass(classOf[Transition[_]]) val Transition(_, WAITING_FOR_ROUTE, WAITING_FOR_PAYMENT_COMPLETE) = monitor.expectMsgClass(classOf[Transition[_]]) - awaitCond(paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.PENDING)) + awaitCond(paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.Pending)) + val Some(outgoing) = paymentDb.getOutgoingPayment(id) + assert(outgoing.copy(createdAt = 0) === OutgoingPayment(id, id, Some(defaultExternalId), defaultPaymentHash, defaultAmountMsat, d, 0, None, OutgoingPaymentStatus.Pending)) sender.send(paymentFSM, UpdateFulfillHtlc(ByteVector32.Zeroes, 0, defaultPaymentHash)) - sender.expectMsgType[PaymentSucceeded] - awaitCond(paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.SUCCEEDED)) + sender.expectMsgType[PaymentSent] + awaitCond(paymentDb.getOutgoingPayment(id).exists(_.status.isInstanceOf[OutgoingPaymentStatus.Succeeded])) } test("payment failed (route not found)") { fixture => import fixture._ - val defaultPaymentHash = randomBytes32 val nodeParams = TestConstants.Alice.nodeParams.copy(keyManager = testKeyManager) val paymentDb = nodeParams.db.payments val id = UUID.randomUUID() + val progressHandler = PaymentLifecycle.DefaultPaymentProgressHandler(id, defaultPaymentRequest.copy(targetNodeId = f), paymentDb) val routerForwarder = TestProbe() - val paymentFSM = system.actorOf(PaymentLifecycle.props(nodeParams, id, routerForwarder.ref, TestProbe().ref)) + val paymentFSM = system.actorOf(PaymentLifecycle.props(nodeParams, progressHandler, routerForwarder.ref, TestProbe().ref)) val monitor = TestProbe() val sender = TestProbe() @@ -93,11 +100,11 @@ class PaymentLifecycleSpec extends BaseRouterSpec { sender.send(paymentFSM, request) val Transition(_, WAITING_FOR_REQUEST, WAITING_FOR_ROUTE) = monitor.expectMsgClass(classOf[Transition[_]]) val routeRequest = routerForwarder.expectMsgType[RouteRequest] - awaitCond(paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.PENDING)) + awaitCond(paymentDb.getOutgoingPayment(id).exists(_.status === OutgoingPaymentStatus.Pending)) routerForwarder.forward(router, routeRequest) - sender.expectMsg(PaymentFailed(id, request.paymentHash, LocalFailure(RouteNotFound) :: Nil)) - awaitCond(paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.FAILED)) + assert(sender.expectMsgType[PaymentFailed].failures === LocalFailure(RouteNotFound) :: Nil) + awaitCond(paymentDb.getOutgoingPayment(id).exists(_.status.isInstanceOf[OutgoingPaymentStatus.Failed])) } test("payment failed (route too expensive)") { fixture => @@ -105,30 +112,31 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val nodeParams = TestConstants.Alice.nodeParams.copy(keyManager = testKeyManager) val paymentDb = nodeParams.db.payments val id = UUID.randomUUID() - val paymentFSM = system.actorOf(PaymentLifecycle.props(nodeParams, id, router, TestProbe().ref)) + val progressHandler = PaymentLifecycle.DefaultPaymentProgressHandler(id, defaultPaymentRequest, paymentDb) + val paymentFSM = system.actorOf(PaymentLifecycle.props(nodeParams, progressHandler, router, TestProbe().ref)) val monitor = TestProbe() val sender = TestProbe() paymentFSM ! SubscribeTransitionCallBack(monitor.ref) val CurrentState(_, WAITING_FOR_REQUEST) = monitor.expectMsgClass(classOf[CurrentState[_]]) - val request = SendPayment(randomBytes32, d, FinalLegacyPayload(defaultAmountMsat, defaultExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight)), maxAttempts = 5, routeParams = Some(RouteParams(randomize = false, maxFeeBase = 100 msat, maxFeePct = 0.0, routeMaxLength = 20, routeMaxCltv = CltvExpiryDelta(2016), ratios = None))) + val request = SendPayment(defaultPaymentHash, d, FinalLegacyPayload(defaultAmountMsat, defaultExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight)), maxAttempts = 5, routeParams = Some(RouteParams(randomize = false, maxFeeBase = 100 msat, maxFeePct = 0.0, routeMaxLength = 20, routeMaxCltv = CltvExpiryDelta(2016), ratios = None))) sender.send(paymentFSM, request) val Transition(_, WAITING_FOR_REQUEST, WAITING_FOR_ROUTE) = monitor.expectMsgClass(classOf[Transition[_]]) val Seq(LocalFailure(RouteNotFound)) = sender.expectMsgType[PaymentFailed].failures - awaitCond(paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.FAILED)) + awaitCond(paymentDb.getOutgoingPayment(id).exists(_.status.isInstanceOf[OutgoingPaymentStatus.Failed])) } test("payment failed (unparsable failure)") { fixture => import fixture._ - val defaultPaymentHash = randomBytes32 val nodeParams = TestConstants.Alice.nodeParams.copy(keyManager = testKeyManager) val paymentDb = nodeParams.db.payments val relayer = TestProbe() val routerForwarder = TestProbe() val id = UUID.randomUUID() - val paymentFSM = TestFSMRef(new PaymentLifecycle(nodeParams, id, routerForwarder.ref, relayer.ref)) + val progressHandler = PaymentLifecycle.DefaultPaymentProgressHandler(id, defaultPaymentRequest, paymentDb) + val paymentFSM = TestFSMRef(new PaymentLifecycle(nodeParams, progressHandler, routerForwarder.ref, relayer.ref)) val monitor = TestProbe() val sender = TestProbe() @@ -137,7 +145,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val request = SendPayment(defaultPaymentHash, d, FinalLegacyPayload(defaultAmountMsat, defaultExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight)), maxAttempts = 2) sender.send(paymentFSM, request) - awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.PENDING)) + awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && paymentDb.getOutgoingPayment(id).exists(_.status === OutgoingPaymentStatus.Pending)) val WaitingForRoute(_, _, Nil) = paymentFSM.stateData routerForwarder.expectMsg(RouteRequest(a, d, defaultAmountMsat, ignoreNodes = Set.empty, ignoreChannels = Set.empty)) @@ -160,8 +168,8 @@ class PaymentLifecycleSpec extends BaseRouterSpec { sender.send(paymentFSM, UpdateFailHtlc(ByteVector32.Zeroes, 0, defaultPaymentHash)) // unparsable message // we allow 2 tries, so we send a 2nd request to the router - sender.expectMsg(PaymentFailed(id, request.paymentHash, UnreadableRemoteFailure(hops) :: UnreadableRemoteFailure(hops) :: Nil)) - awaitCond(paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.FAILED)) // after last attempt the payment is failed + assert(sender.expectMsgType[PaymentFailed].failures === UnreadableRemoteFailure(hops) :: UnreadableRemoteFailure(hops) :: Nil) + awaitCond(paymentDb.getOutgoingPayment(id).exists(_.status.isInstanceOf[OutgoingPaymentStatus.Failed])) // after last attempt the payment is failed } test("payment failed (local error)") { fixture => @@ -171,16 +179,17 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val relayer = TestProbe() val routerForwarder = TestProbe() val id = UUID.randomUUID() - val paymentFSM = TestFSMRef(new PaymentLifecycle(nodeParams, id, routerForwarder.ref, relayer.ref)) + val progressHandler = PaymentLifecycle.DefaultPaymentProgressHandler(id, defaultPaymentRequest, paymentDb) + val paymentFSM = TestFSMRef(new PaymentLifecycle(nodeParams, progressHandler, routerForwarder.ref, relayer.ref)) val monitor = TestProbe() val sender = TestProbe() paymentFSM ! SubscribeTransitionCallBack(monitor.ref) val CurrentState(_, WAITING_FOR_REQUEST) = monitor.expectMsgClass(classOf[CurrentState[_]]) - val request = SendPayment(randomBytes32, d, FinalLegacyPayload(defaultAmountMsat, defaultExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight)), maxAttempts = 2) + val request = SendPayment(defaultPaymentHash, d, FinalLegacyPayload(defaultAmountMsat, defaultExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight)), maxAttempts = 2) sender.send(paymentFSM, request) - awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.PENDING)) + awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && paymentDb.getOutgoingPayment(id).exists(_.status === OutgoingPaymentStatus.Pending)) val WaitingForRoute(_, _, Nil) = paymentFSM.stateData routerForwarder.expectMsg(RouteRequest(nodeParams.nodeId, d, defaultAmountMsat, assistedRoutes = Nil, ignoreNodes = Set.empty, ignoreChannels = Set.empty)) @@ -193,18 +202,18 @@ class PaymentLifecycleSpec extends BaseRouterSpec { // then the payment lifecycle will ask for a new route excluding the channel routerForwarder.expectMsg(RouteRequest(nodeParams.nodeId, d, defaultAmountMsat, assistedRoutes = Nil, ignoreNodes = Set.empty, ignoreChannels = Set(ChannelDesc(channelId_ab, a, b)))) - awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.PENDING)) // payment is still pending because the error is recoverable + awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && paymentDb.getOutgoingPayment(id).exists(_.status === OutgoingPaymentStatus.Pending)) // payment is still pending because the error is recoverable } test("payment failed (first hop returns an UpdateFailMalformedHtlc)") { fixture => import fixture._ - val defaultPaymentHash = randomBytes32 val nodeParams = TestConstants.Alice.nodeParams.copy(keyManager = testKeyManager) val paymentDb = nodeParams.db.payments val relayer = TestProbe() val routerForwarder = TestProbe() val id = UUID.randomUUID() - val paymentFSM = TestFSMRef(new PaymentLifecycle(nodeParams, id, routerForwarder.ref, relayer.ref)) + val progressHandler = PaymentLifecycle.DefaultPaymentProgressHandler(id, defaultPaymentRequest, paymentDb) + val paymentFSM = TestFSMRef(new PaymentLifecycle(nodeParams, progressHandler, routerForwarder.ref, relayer.ref)) val monitor = TestProbe() val sender = TestProbe() @@ -213,7 +222,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val request = SendPayment(defaultPaymentHash, d, FinalLegacyPayload(defaultAmountMsat, defaultExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight)), maxAttempts = 2) sender.send(paymentFSM, request) - awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.PENDING)) + awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && paymentDb.getOutgoingPayment(id).exists(_.status === OutgoingPaymentStatus.Pending)) val WaitingForRoute(_, _, Nil) = paymentFSM.stateData routerForwarder.expectMsg(RouteRequest(nodeParams.nodeId, d, defaultAmountMsat, assistedRoutes = Nil, ignoreNodes = Set.empty, ignoreChannels = Set.empty)) @@ -226,7 +235,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { // then the payment lifecycle will ask for a new route excluding the channel routerForwarder.expectMsg(RouteRequest(a, d, defaultAmountMsat, assistedRoutes = Nil, ignoreNodes = Set.empty, ignoreChannels = Set(ChannelDesc(channelId_ab, a, b)))) - awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.PENDING)) + awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && paymentDb.getOutgoingPayment(id).exists(_.status === OutgoingPaymentStatus.Pending)) } test("payment failed (TemporaryChannelFailure)") { fixture => @@ -235,14 +244,15 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val relayer = TestProbe() val routerForwarder = TestProbe() val id = UUID.randomUUID() - val paymentFSM = TestFSMRef(new PaymentLifecycle(nodeParams, id, routerForwarder.ref, relayer.ref)) + val progressHandler = PaymentLifecycle.DefaultPaymentProgressHandler(id, defaultPaymentRequest, nodeParams.db.payments) + val paymentFSM = TestFSMRef(new PaymentLifecycle(nodeParams, progressHandler, routerForwarder.ref, relayer.ref)) val monitor = TestProbe() val sender = TestProbe() paymentFSM ! SubscribeTransitionCallBack(monitor.ref) val CurrentState(_, WAITING_FOR_REQUEST) = monitor.expectMsgClass(classOf[CurrentState[_]]) - val request = SendPayment(randomBytes32, d, FinalLegacyPayload(defaultAmountMsat, defaultExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight)), maxAttempts = 2) + val request = SendPayment(defaultPaymentHash, d, FinalLegacyPayload(defaultAmountMsat, defaultExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight)), maxAttempts = 2) sender.send(paymentFSM, request) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE) val WaitingForRoute(_, _, Nil) = paymentFSM.stateData @@ -265,7 +275,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { routerForwarder.expectMsg(RouteRequest(a, d, defaultAmountMsat, assistedRoutes = Nil, ignoreNodes = Set.empty, ignoreChannels = Set.empty)) routerForwarder.forward(router) // we allow 2 tries, so we send a 2nd request to the router - sender.expectMsg(PaymentFailed(id, request.paymentHash, RemoteFailure(hops, Sphinx.DecryptedFailurePacket(b, failure)) :: LocalFailure(RouteNotFound) :: Nil)) + assert(sender.expectMsgType[PaymentFailed].failures === RemoteFailure(hops, Sphinx.DecryptedFailurePacket(b, failure)) :: LocalFailure(RouteNotFound) :: Nil) } test("payment failed (Update)") { fixture => @@ -275,16 +285,17 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val relayer = TestProbe() val routerForwarder = TestProbe() val id = UUID.randomUUID() - val paymentFSM = TestFSMRef(new PaymentLifecycle(nodeParams, id, routerForwarder.ref, relayer.ref)) + val progressHandler = PaymentLifecycle.DefaultPaymentProgressHandler(id, defaultPaymentRequest, paymentDb) + val paymentFSM = TestFSMRef(new PaymentLifecycle(nodeParams, progressHandler, routerForwarder.ref, relayer.ref)) val monitor = TestProbe() val sender = TestProbe() paymentFSM ! SubscribeTransitionCallBack(monitor.ref) val CurrentState(_, WAITING_FOR_REQUEST) = monitor.expectMsgClass(classOf[CurrentState[_]]) - val request = SendPayment(randomBytes32, d, FinalLegacyPayload(defaultAmountMsat, defaultExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight)), maxAttempts = 5) + val request = SendPayment(defaultPaymentHash, d, FinalLegacyPayload(defaultAmountMsat, defaultExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight)), maxAttempts = 5) sender.send(paymentFSM, request) - awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.PENDING)) + awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && paymentDb.getOutgoingPayment(id).exists(_.status === OutgoingPaymentStatus.Pending)) val WaitingForRoute(_, _, Nil) = paymentFSM.stateData routerForwarder.expectMsg(RouteRequest(nodeParams.nodeId, d, defaultAmountMsat, assistedRoutes = Nil, ignoreNodes = Set.empty, ignoreChannels = Set.empty)) @@ -301,7 +312,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { // payment lifecycle forwards the embedded channelUpdate to the router routerForwarder.expectMsg(channelUpdate_bc_modified) - awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.PENDING)) // 1 failure but not final, the payment is still PENDING + awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && paymentDb.getOutgoingPayment(id).exists(_.status === OutgoingPaymentStatus.Pending)) // 1 failure but not final, the payment is still PENDING routerForwarder.expectMsg(RouteRequest(nodeParams.nodeId, d, defaultAmountMsat, assistedRoutes = Nil, ignoreNodes = Set.empty, ignoreChannels = Set.empty)) routerForwarder.forward(router) @@ -326,8 +337,8 @@ class PaymentLifecycleSpec extends BaseRouterSpec { routerForwarder.forward(router) // this time the router can't find a route: game over - sender.expectMsg(PaymentFailed(id, request.paymentHash, RemoteFailure(hops, Sphinx.DecryptedFailurePacket(b, failure)) :: RemoteFailure(hops2, Sphinx.DecryptedFailurePacket(b, failure2)) :: LocalFailure(RouteNotFound) :: Nil)) - awaitCond(paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.FAILED)) + assert(sender.expectMsgType[PaymentFailed].failures === RemoteFailure(hops, Sphinx.DecryptedFailurePacket(b, failure)) :: RemoteFailure(hops2, Sphinx.DecryptedFailurePacket(b, failure2)) :: LocalFailure(RouteNotFound) :: Nil) + awaitCond(paymentDb.getOutgoingPayment(id).exists(_.status.isInstanceOf[OutgoingPaymentStatus.Failed])) } def testPermanentFailure(fixture: FixtureParam, failure: FailureMessage): Unit = { @@ -337,16 +348,17 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val relayer = TestProbe() val routerForwarder = TestProbe() val id = UUID.randomUUID() - val paymentFSM = TestFSMRef(new PaymentLifecycle(nodeParams, id, routerForwarder.ref, relayer.ref)) + val progressHandler = PaymentLifecycle.DefaultPaymentProgressHandler(id, defaultPaymentRequest, paymentDb) + val paymentFSM = TestFSMRef(new PaymentLifecycle(nodeParams, progressHandler, routerForwarder.ref, relayer.ref)) val monitor = TestProbe() val sender = TestProbe() paymentFSM ! SubscribeTransitionCallBack(monitor.ref) val CurrentState(_, WAITING_FOR_REQUEST) = monitor.expectMsgClass(classOf[CurrentState[_]]) - val request = SendPayment(randomBytes32, d, FinalLegacyPayload(defaultAmountMsat, defaultExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight)), maxAttempts = 2) + val request = SendPayment(defaultPaymentHash, d, FinalLegacyPayload(defaultAmountMsat, defaultExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight)), maxAttempts = 2) sender.send(paymentFSM, request) - awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.PENDING)) + awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && paymentDb.getOutgoingPayment(id).exists(_.status === OutgoingPaymentStatus.Pending)) val WaitingForRoute(_, _, Nil) = paymentFSM.stateData routerForwarder.expectMsg(RouteRequest(nodeParams.nodeId, d, defaultAmountMsat, assistedRoutes = Nil, ignoreNodes = Set.empty, ignoreChannels = Set.empty)) @@ -363,8 +375,8 @@ class PaymentLifecycleSpec extends BaseRouterSpec { routerForwarder.forward(router) // we allow 2 tries, so we send a 2nd request to the router, which won't find another route - sender.expectMsg(PaymentFailed(id, request.paymentHash, RemoteFailure(hops, Sphinx.DecryptedFailurePacket(b, failure)) :: LocalFailure(RouteNotFound) :: Nil)) - awaitCond(paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.FAILED)) + assert(sender.expectMsgType[PaymentFailed].failures === RemoteFailure(hops, Sphinx.DecryptedFailurePacket(b, failure)) :: LocalFailure(RouteNotFound) :: Nil) + awaitCond(paymentDb.getOutgoingPayment(id).exists(_.status.isInstanceOf[OutgoingPaymentStatus.Failed])) } test("payment failed (PermanentChannelFailure)") { fixture => @@ -379,11 +391,13 @@ class PaymentLifecycleSpec extends BaseRouterSpec { test("payment succeeded") { fixture => import fixture._ - val defaultPaymentHash = randomBytes32 + val paymentPreimage = randomBytes32 + val paymentHash = Crypto.sha256(paymentPreimage) val nodeParams = TestConstants.Alice.nodeParams.copy(keyManager = testKeyManager) val paymentDb = nodeParams.db.payments val id = UUID.randomUUID() - val paymentFSM = system.actorOf(PaymentLifecycle.props(nodeParams, id, router, TestProbe().ref)) + val progressHandler = PaymentLifecycle.DefaultPaymentProgressHandler(id, defaultPaymentRequest.copy(paymentHash = paymentHash), paymentDb) + val paymentFSM = system.actorOf(PaymentLifecycle.props(nodeParams, progressHandler, router, TestProbe().ref)) val monitor = TestProbe() val sender = TestProbe() val eventListener = TestProbe() @@ -392,24 +406,26 @@ class PaymentLifecycleSpec extends BaseRouterSpec { paymentFSM ! SubscribeTransitionCallBack(monitor.ref) val CurrentState(_, WAITING_FOR_REQUEST) = monitor.expectMsgClass(classOf[CurrentState[_]]) - val request = SendPayment(defaultPaymentHash, d, FinalLegacyPayload(defaultAmountMsat, defaultExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight)), maxAttempts = 5) + val request = SendPayment(paymentHash, d, FinalLegacyPayload(defaultAmountMsat, defaultExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight)), maxAttempts = 5) sender.send(paymentFSM, request) val Transition(_, WAITING_FOR_REQUEST, WAITING_FOR_ROUTE) = monitor.expectMsgClass(classOf[Transition[_]]) val Transition(_, WAITING_FOR_ROUTE, WAITING_FOR_PAYMENT_COMPLETE) = monitor.expectMsgClass(classOf[Transition[_]]) - awaitCond(paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.PENDING)) - sender.send(paymentFSM, UpdateFulfillHtlc(ByteVector32.Zeroes, 0, defaultPaymentHash)) - - val paymentOK = sender.expectMsgType[PaymentSucceeded] - val PaymentSent(_, request.finalPayload.amount, fee, request.paymentHash, paymentOK.paymentPreimage, _, _) = eventListener.expectMsgType[PaymentSent] - assert(fee > 0.msat) - assert(fee === paymentOK.amount - request.finalPayload.amount) - awaitCond(paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.SUCCEEDED)) + awaitCond(paymentDb.getOutgoingPayment(id).exists(_.status === OutgoingPaymentStatus.Pending)) + val Some(outgoing) = paymentDb.getOutgoingPayment(id) + assert(outgoing.copy(createdAt = 0) === OutgoingPayment(id, id, Some(defaultExternalId), paymentHash, defaultAmountMsat, d, 0, None, OutgoingPaymentStatus.Pending)) + sender.send(paymentFSM, UpdateFulfillHtlc(ByteVector32.Zeroes, 0, paymentPreimage)) + + val ps = eventListener.expectMsgType[PaymentSent] + assert(ps.feesPaid > 0.msat) + assert(ps.amount === defaultAmountMsat) + assert(ps.paymentHash === paymentHash) + assert(ps.paymentPreimage === paymentPreimage) + awaitCond(paymentDb.getOutgoingPayment(id).exists(_.status.isInstanceOf[OutgoingPaymentStatus.Succeeded])) } test("payment succeeded to a channel with fees=0") { fixture => import fixture._ import fr.acinq.eclair.randomKey - val defaultPaymentHash = randomBytes32 val nodeParams = TestConstants.Alice.nodeParams.copy(keyManager = testKeyManager) // the network will be a --(1)--> b ---(2)--> c --(3)--> d and e --(4)--> f (we are a) and b -> g has fees=0 // \ @@ -432,7 +448,8 @@ class PaymentLifecycleSpec extends BaseRouterSpec { watcher.expectMsgType[WatchSpentBasic] // actual test begins - val paymentFSM = system.actorOf(PaymentLifecycle.props(nodeParams, UUID.randomUUID(), router, TestProbe().ref)) + val progressHandler = PaymentLifecycle.DefaultPaymentProgressHandler(UUID.randomUUID(), defaultPaymentRequest.copy(targetNodeId = g), nodeParams.db.payments) + val paymentFSM = system.actorOf(PaymentLifecycle.props(nodeParams, progressHandler, router, TestProbe().ref)) val monitor = TestProbe() val sender = TestProbe() val eventListener = TestProbe() @@ -451,8 +468,8 @@ class PaymentLifecycleSpec extends BaseRouterSpec { sender.send(paymentFSM, UpdateFulfillHtlc(ByteVector32.Zeroes, 0, defaultPaymentHash)) - val paymentOK = sender.expectMsgType[PaymentSucceeded] - val PaymentSent(_, request.finalPayload.amount, fee, request.paymentHash, paymentOK.paymentPreimage, _, _) = eventListener.expectMsgType[PaymentSent] + val paymentOK = sender.expectMsgType[PaymentSent] + val PaymentSent(_, request.paymentHash, paymentOK.paymentPreimage, PartialPayment(_, request.finalPayload.amount, fee, ByteVector32.Zeroes, _, _) :: Nil) = eventListener.expectMsgType[PaymentSent] // during the route computation the fees were treated as if they were 1msat but when sending the onion we actually put zero // NB: A -> B doesn't pay fees because it's our direct neighbor @@ -463,7 +480,8 @@ class PaymentLifecycleSpec extends BaseRouterSpec { test("filter errors properly") { _ => val failures = LocalFailure(RouteNotFound) :: RemoteFailure(Hop(a, b, channelUpdate_ab) :: Nil, Sphinx.DecryptedFailurePacket(a, TemporaryNodeFailure)) :: LocalFailure(AddHtlcFailed(ByteVector32.Zeroes, ByteVector32.Zeroes, ChannelUnavailable(ByteVector32.Zeroes), Local(UUID.randomUUID(), None), None, None)) :: LocalFailure(RouteNotFound) :: Nil - val filtered = PaymentLifecycle.transformForUser(failures) + val filtered = PaymentFailure.transformForUser(failures) assert(filtered == LocalFailure(RouteNotFound) :: RemoteFailure(Hop(a, b, channelUpdate_ab) :: Nil, Sphinx.DecryptedFailurePacket(a, TemporaryNodeFailure)) :: LocalFailure(ChannelUnavailable(ByteVector32.Zeroes)) :: Nil) } + } diff --git a/eclair-node-gui/src/main/scala/fr/acinq/eclair/gui/FxApp.scala b/eclair-node-gui/src/main/scala/fr/acinq/eclair/gui/FxApp.scala index 163d8c14a9..4958bbc634 100644 --- a/eclair-node-gui/src/main/scala/fr/acinq/eclair/gui/FxApp.scala +++ b/eclair-node-gui/src/main/scala/fr/acinq/eclair/gui/FxApp.scala @@ -25,7 +25,6 @@ import fr.acinq.eclair.blockchain.electrum.ElectrumClient.ElectrumEvent import fr.acinq.eclair.channel.ChannelEvent import fr.acinq.eclair.gui.controllers.{MainController, NotificationsController} import fr.acinq.eclair.payment.PaymentEvent -import fr.acinq.eclair.payment.PaymentLifecycle.PaymentResult import fr.acinq.eclair.router.NetworkEvent import grizzled.slf4j.Logging import javafx.application.Preloader.ErrorNotification @@ -42,8 +41,8 @@ import scala.util.{Failure, Success, Try} /** - * Created by PM on 16/08/2016. - */ + * Created by PM on 16/08/2016. + */ class FxApp extends Application with Logging { override def init = { @@ -99,7 +98,6 @@ class FxApp extends Application with Logging { system.eventStream.subscribe(guiUpdater, classOf[ChannelEvent]) system.eventStream.subscribe(guiUpdater, classOf[NetworkEvent]) system.eventStream.subscribe(guiUpdater, classOf[PaymentEvent]) - system.eventStream.subscribe(guiUpdater, classOf[PaymentResult]) system.eventStream.subscribe(guiUpdater, classOf[ZMQEvent]) system.eventStream.subscribe(guiUpdater, classOf[ElectrumEvent]) pKit.completeWith(setup.bootstrap) @@ -137,11 +135,11 @@ class FxApp extends Application with Logging { } /** - * Initialize the notification stage and assign it to the handler class. - * - * @param owner stage owning the notification stage - * @param notifhandlers Handles the notifications - */ + * Initialize the notification stage and assign it to the handler class. + * + * @param owner stage owning the notification stage + * @param notifhandlers Handles the notifications + */ private def initNotificationStage(owner: Stage, notifhandlers: Handlers) = { // get fxml/controller val notifFXML = new FXMLLoader(getClass.getResource("/gui/main/notifications.fxml")) diff --git a/eclair-node-gui/src/main/scala/fr/acinq/eclair/gui/GUIUpdater.scala b/eclair-node-gui/src/main/scala/fr/acinq/eclair/gui/GUIUpdater.scala index 2f996aa569..a295326330 100644 --- a/eclair-node-gui/src/main/scala/fr/acinq/eclair/gui/GUIUpdater.scala +++ b/eclair-node-gui/src/main/scala/fr/acinq/eclair/gui/GUIUpdater.scala @@ -21,12 +21,11 @@ import java.time.LocalDateTime import akka.actor.{Actor, ActorLogging, ActorRef, Terminated} import fr.acinq.bitcoin.Crypto.PublicKey import fr.acinq.bitcoin._ -import fr.acinq.eclair.{CoinUtils, MilliSatoshi} +import fr.acinq.eclair.CoinUtils import fr.acinq.eclair.blockchain.bitcoind.zmq.ZMQActor.{ZMQConnected, ZMQDisconnected} import fr.acinq.eclair.blockchain.electrum.ElectrumClient.{ElectrumDisconnected, ElectrumReady} import fr.acinq.eclair.channel._ import fr.acinq.eclair.gui.controllers._ -import fr.acinq.eclair.payment.PaymentLifecycle.{LocalFailure, PaymentFailed, PaymentSucceeded, RemoteFailure} import fr.acinq.eclair.payment._ import fr.acinq.eclair.router.{NORMAL => _, _} import javafx.application.Platform @@ -35,21 +34,21 @@ import javafx.scene.layout.VBox import scala.collection.JavaConversions._ - /** - * Created by PM on 16/08/2016. - */ + * Created by PM on 16/08/2016. + */ + class GUIUpdater(mainController: MainController) extends Actor with ActorLogging { val STATE_MUTUAL_CLOSE = Set(WAIT_FOR_INIT_INTERNAL, WAIT_FOR_OPEN_CHANNEL, WAIT_FOR_ACCEPT_CHANNEL, WAIT_FOR_FUNDING_INTERNAL, WAIT_FOR_FUNDING_CREATED, WAIT_FOR_FUNDING_SIGNED, NORMAL) val STATE_FORCE_CLOSE = Set(WAIT_FOR_FUNDING_CONFIRMED, WAIT_FOR_FUNDING_LOCKED, NORMAL, SHUTDOWN, NEGOTIATING, OFFLINE, SYNCING) /** - * Needed to stop JavaFX complaining about updates from non GUI thread - */ + * Needed to stop JavaFX complaining about updates from non GUI thread + */ private def runInGuiThread(f: () => Unit): Unit = { Platform.runLater(new Runnable() { - @Override def run() = f() + @Override def run(): Unit = f() }) } @@ -95,7 +94,7 @@ class GUIUpdater(mainController: MainController) extends Actor with ActorLogging }) context.become(main(m1)) - case ShortChannelIdAssigned(channel, channelId, shortChannelId) if m.contains(channel) => + case ShortChannelIdAssigned(channel, _, shortChannelId) if m.contains(channel) => val channelPaneController = m(channel) runInGuiThread(() => channelPaneController.shortChannelId.setText(shortChannelId.toString)) @@ -103,12 +102,12 @@ class GUIUpdater(mainController: MainController) extends Actor with ActorLogging val channelPaneController = m(channel) runInGuiThread(() => channelPaneController.channelId.setText(channelId.toHex)) - case ChannelStateChanged(channel, _, remoteNodeId, _, currentState, currentData) if m.contains(channel) => + case ChannelStateChanged(channel, _, _, _, currentState, currentData) if m.contains(channel) => val channelPaneController = m(channel) runInGuiThread { () => (currentState, currentData) match { case (WAIT_FOR_FUNDING_CONFIRMED, d: HasCommitments) => channelPaneController.txId.setText(d.commitments.commitInput.outPoint.txid.toHex) - case _ => {} + case _ => } channelPaneController.close.setVisible(STATE_MUTUAL_CLOSE.contains(currentState)) channelPaneController.forceclose.setVisible(STATE_FORCE_CLOSE.contains(currentState)) @@ -137,8 +136,8 @@ class GUIUpdater(mainController: MainController) extends Actor with ActorLogging case NodesDiscovered(nodeAnnouncements) => runInGuiThread { () => - nodeAnnouncements.foreach { nodeAnnouncement => - log.debug(s"peer node discovered with node id={}", nodeAnnouncement.nodeId) + nodeAnnouncements.foreach { nodeAnnouncement => + log.debug(s"peer node discovered with node id={}", nodeAnnouncement.nodeId) if (!mainController.networkNodesMap.containsKey(nodeAnnouncement.nodeId)) { mainController.networkNodesMap.put(nodeAnnouncement.nodeId, nodeAnnouncement) m.foreach(f => if (nodeAnnouncement.nodeId.toString.equals(f._2.peerNodeId)) { @@ -149,7 +148,7 @@ class GUIUpdater(mainController: MainController) extends Actor with ActorLogging } case NodeLost(nodeId) => - log.debug(s"peer node lost with node id=${nodeId}") + log.debug(s"peer node lost with node id=$nodeId") runInGuiThread { () => mainController.networkNodesMap.remove(nodeId) } @@ -165,46 +164,42 @@ class GUIUpdater(mainController: MainController) extends Actor with ActorLogging case ChannelsDiscovered(channelsDiscovered) => runInGuiThread { () => - channelsDiscovered.foreach { case SingleChannelDiscovered(channelAnnouncement, capacity) => - log.debug(s"peer channel discovered with channel id={}", channelAnnouncement.shortChannelId) - if (!mainController.networkChannelsMap.containsKey(channelAnnouncement.shortChannelId)) { - mainController.networkChannelsMap.put(channelAnnouncement.shortChannelId, new ChannelInfo(channelAnnouncement, None, None, None, None, capacity, None, None)) - } - } + channelsDiscovered.foreach { case SingleChannelDiscovered(channelAnnouncement, capacity) => + log.debug(s"peer channel discovered with channel id={}", channelAnnouncement.shortChannelId) + if (!mainController.networkChannelsMap.containsKey(channelAnnouncement.shortChannelId)) { + mainController.networkChannelsMap.put(channelAnnouncement.shortChannelId, ChannelInfo(channelAnnouncement, None, None, None, None, capacity, None, None)) } + } + } case ChannelLost(shortChannelId) => - log.debug(s"peer channel lost with channel id=${shortChannelId}") + log.debug(s"peer channel lost with channel id=$shortChannelId") runInGuiThread { () => mainController.networkChannelsMap.remove(shortChannelId) } case ChannelUpdatesReceived(channelUpdates) => runInGuiThread { () => - channelUpdates.foreach { case channelUpdate => - log.debug(s"peer channel with id={} has been updated - flags: {} fees: {} {}", channelUpdate.shortChannelId, channelUpdate.channelFlags, channelUpdate.feeBaseMsat, channelUpdate.feeProportionalMillionths) - if (mainController.networkChannelsMap.containsKey(channelUpdate.shortChannelId)) { - val c = mainController.networkChannelsMap.get(channelUpdate.shortChannelId) - if (Announcements.isNode1(channelUpdate.channelFlags)) { - c.isNode1Enabled = Some(Announcements.isEnabled(channelUpdate.channelFlags)) - c.feeBaseMsatNode1_opt = Some(channelUpdate.feeBaseMsat.toLong) - c.feeProportionalMillionthsNode1_opt = Some(channelUpdate.feeProportionalMillionths) - } else { - c.isNode2Enabled = Some(Announcements.isEnabled(channelUpdate.channelFlags)) - c.feeBaseMsatNode2_opt = Some(channelUpdate.feeBaseMsat.toLong) - c.feeProportionalMillionthsNode2_opt = Some(channelUpdate.feeProportionalMillionths) - } - mainController.networkChannelsMap.put(channelUpdate.shortChannelId, c) - } + channelUpdates.foreach { channelUpdate => + log.debug(s"peer channel with id={} has been updated - flags: {} fees: {} {}", channelUpdate.shortChannelId, channelUpdate.channelFlags, channelUpdate.feeBaseMsat, channelUpdate.feeProportionalMillionths) + if (mainController.networkChannelsMap.containsKey(channelUpdate.shortChannelId)) { + val c = mainController.networkChannelsMap.get(channelUpdate.shortChannelId) + if (Announcements.isNode1(channelUpdate.channelFlags)) { + c.isNode1Enabled = Some(Announcements.isEnabled(channelUpdate.channelFlags)) + c.feeBaseMsatNode1_opt = Some(channelUpdate.feeBaseMsat.toLong) + c.feeProportionalMillionthsNode1_opt = Some(channelUpdate.feeProportionalMillionths) + } else { + c.isNode2Enabled = Some(Announcements.isEnabled(channelUpdate.channelFlags)) + c.feeBaseMsatNode2_opt = Some(channelUpdate.feeBaseMsat.toLong) + c.feeProportionalMillionthsNode2_opt = Some(channelUpdate.feeProportionalMillionths) } + mainController.networkChannelsMap.put(channelUpdate.shortChannelId, c) } - - case p: PaymentSucceeded => - val message = CoinUtils.formatAmountInUnit(p.amount, FxApp.getUnit, withUnit = true) - mainController.handlers.notification("Payment Sent", message, NOTIFICATION_SUCCESS) + } + } case p: PaymentFailed => - val distilledFailures = PaymentLifecycle.transformForUser(p.failures) + val distilledFailures = PaymentFailure.transformForUser(p.failures) val message = s"${distilledFailures.size} attempts:\n${ distilledFailures.map { case LocalFailure(t) => s"- (local) ${t.getMessage}" @@ -216,15 +211,17 @@ class GUIUpdater(mainController: MainController) extends Actor with ActorLogging case p: PaymentSent => log.debug(s"payment sent with h=${p.paymentHash}, amount=${p.amount}, fees=${p.feesPaid}") - runInGuiThread(() => mainController.paymentSentList.prepend(new PaymentSentRecord(p, LocalDateTime.now()))) + val message = CoinUtils.formatAmountInUnit(p.amount + p.feesPaid, FxApp.getUnit, withUnit = true) + mainController.handlers.notification("Payment Sent", message, NOTIFICATION_SUCCESS) + runInGuiThread(() => mainController.paymentSentList.prepend(PaymentSentRecord(p, LocalDateTime.now()))) case p: PaymentReceived => log.debug(s"payment received with h=${p.paymentHash}, amount=${p.amount}") - runInGuiThread(() => mainController.paymentReceivedList.prepend(new PaymentReceivedRecord(p, LocalDateTime.now()))) + runInGuiThread(() => mainController.paymentReceivedList.prepend(PaymentReceivedRecord(p, LocalDateTime.now()))) case p: PaymentRelayed => log.debug(s"payment relayed with h=${p.paymentHash}, amount=${p.amountIn}, feesEarned=${p.amountOut}") - runInGuiThread(() => mainController.paymentRelayedList.prepend(new PaymentRelayedRecord(p, LocalDateTime.now()))) + runInGuiThread(() => mainController.paymentRelayedList.prepend(PaymentRelayedRecord(p, LocalDateTime.now()))) case ZMQConnected => log.debug("ZMQ connection UP") diff --git a/eclair-node/src/main/scala/fr/acinq/eclair/api/JsonSerializers.scala b/eclair-node/src/main/scala/fr/acinq/eclair/api/JsonSerializers.scala index 73c85f0c82..f0cb3d4442 100644 --- a/eclair-node/src/main/scala/fr/acinq/eclair/api/JsonSerializers.scala +++ b/eclair-node/src/main/scala/fr/acinq/eclair/api/JsonSerializers.scala @@ -25,7 +25,6 @@ import fr.acinq.bitcoin.Crypto.{PrivateKey, PublicKey} import fr.acinq.bitcoin.{ByteVector32, ByteVector64, OutPoint, Satoshi, Transaction} import fr.acinq.eclair.channel.{ChannelVersion, State} import fr.acinq.eclair.crypto.ShaChain -import fr.acinq.eclair.db.OutgoingPaymentStatus import fr.acinq.eclair.payment.PaymentRequest import fr.acinq.eclair.router.RouteResponse import fr.acinq.eclair.transactions.Direction @@ -185,10 +184,6 @@ class JavaUUIDSerializer extends CustomSerializer[UUID](format => ({ null }, { case id: UUID => JString(id.toString) })) -class OutgoingPaymentStatusSerializer extends CustomSerializer[OutgoingPaymentStatus.Value](format => ({ null }, { - case el: OutgoingPaymentStatus.Value => JString(el.toString) -})) - object JsonSupport extends Json4sSupport { implicit val serialization = jackson.Serialization @@ -221,8 +216,7 @@ object JsonSupport extends Json4sSupport { new NodeAddressSerializer + new DirectionSerializer + new PaymentRequestSerializer + - new JavaUUIDSerializer + - new OutgoingPaymentStatusSerializer + new JavaUUIDSerializer case class CustomTypeHints(custom: Map[Class[_], String]) extends TypeHints { val reverse: Map[String, Class[_]] = custom.map(_.swap) diff --git a/eclair-node/src/main/scala/fr/acinq/eclair/api/Service.scala b/eclair-node/src/main/scala/fr/acinq/eclair/api/Service.scala index 20d9609419..23feee8fe6 100644 --- a/eclair-node/src/main/scala/fr/acinq/eclair/api/Service.scala +++ b/eclair-node/src/main/scala/fr/acinq/eclair/api/Service.scala @@ -36,8 +36,7 @@ import fr.acinq.bitcoin.{ByteVector32, Satoshi} import fr.acinq.eclair.api.FormParamExtractors._ import fr.acinq.eclair.api.JsonSupport.CustomTypeHints import fr.acinq.eclair.io.NodeURI -import fr.acinq.eclair.payment.PaymentLifecycle.PaymentFailed -import fr.acinq.eclair.payment.{PaymentReceived, PaymentRequest, _} +import fr.acinq.eclair.payment.{PaymentFailed, PaymentReceived, PaymentRequest, _} import fr.acinq.eclair.{CltvExpiryDelta, Eclair, MilliSatoshi} import grizzled.slf4j.Logging import scodec.bits.ByteVector @@ -117,7 +116,7 @@ trait Service extends ExtraDirectives with Logging { .map(TextMessage.apply) } - val timeoutResponse: HttpRequest => HttpResponse = { r => + val timeoutResponse: HttpRequest => HttpResponse = { _ => HttpResponse(StatusCodes.RequestTimeout).withEntity(ContentTypes.`application/json`, serialization.writePretty(ErrorResponse("request timed out"))) } @@ -134,7 +133,7 @@ trait Service extends ExtraDirectives with Logging { toStrictEntity(paramParsingTimeout) { formFields("timeoutSeconds".as[Timeout].?) { tm_opt => // this is the akka timeout - implicit val timeout = tm_opt.getOrElse(Timeout(30 seconds)) + implicit val timeout: Timeout = tm_opt.getOrElse(Timeout(30 seconds)) // we ensure that http timeout is greater than akka timeout withRequestTimeout(timeout.duration + 2.seconds) { withRequestTimeoutResponse(timeoutResponse) { @@ -224,22 +223,24 @@ trait Service extends ExtraDirectives with Logging { } } ~ path("payinvoice") { - formFields(invoiceFormParam, amountMsatFormParam.?, "maxAttempts".as[Int].?, "feeThresholdSat".as[Satoshi].?, "maxFeePct".as[Double].?) { - case (invoice@PaymentRequest(_, Some(amount), _, nodeId, _, _), None, maxAttempts, feeThresholdSat_opt, maxFeePct_opt) => - complete(eclairApi.send(nodeId, amount, invoice.paymentHash, Some(invoice), maxAttempts, feeThresholdSat_opt, maxFeePct_opt)) - case (invoice, Some(overrideAmount), maxAttempts, feeThresholdSat_opt, maxFeePct_opt) => - complete(eclairApi.send(invoice.nodeId, overrideAmount, invoice.paymentHash, Some(invoice), maxAttempts, feeThresholdSat_opt, maxFeePct_opt)) + formFields(invoiceFormParam, amountMsatFormParam.?, "maxAttempts".as[Int].?, "feeThresholdSat".as[Satoshi].?, "maxFeePct".as[Double].?, "externalId".?) { + case (invoice@PaymentRequest(_, Some(amount), _, nodeId, _, _), None, maxAttempts, feeThresholdSat_opt, maxFeePct_opt, externalId_opt) => + complete(eclairApi.send(externalId_opt, nodeId, amount, invoice.paymentHash, Some(invoice), maxAttempts, feeThresholdSat_opt, maxFeePct_opt)) + case (invoice, Some(overrideAmount), maxAttempts, feeThresholdSat_opt, maxFeePct_opt, externalId_opt) => + complete(eclairApi.send(externalId_opt, invoice.nodeId, overrideAmount, invoice.paymentHash, Some(invoice), maxAttempts, feeThresholdSat_opt, maxFeePct_opt)) case _ => reject(MalformedFormFieldRejection("invoice", "The invoice must have an amount or you need to specify one using the field 'amountMsat'")) } } ~ path("sendtonode") { - formFields(amountMsatFormParam, paymentHashFormParam, nodeIdFormParam, "maxAttempts".as[Int].?, "feeThresholdSat".as[Satoshi].?, "maxFeePct".as[Double].?) { (amountMsat, paymentHash, nodeId, maxAttempts_opt, feeThresholdSat_opt, maxFeePct_opt) => - complete(eclairApi.send(nodeId, amountMsat, paymentHash, maxAttempts_opt = maxAttempts_opt, feeThresholdSat_opt = feeThresholdSat_opt, maxFeePct_opt = maxFeePct_opt)) + formFields(amountMsatFormParam, paymentHashFormParam, nodeIdFormParam, "maxAttempts".as[Int].?, "feeThresholdSat".as[Satoshi].?, "maxFeePct".as[Double].?, "externalId".?) { + (amountMsat, paymentHash, nodeId, maxAttempts_opt, feeThresholdSat_opt, maxFeePct_opt, externalId_opt) => + complete(eclairApi.send(externalId_opt, nodeId, amountMsat, paymentHash, maxAttempts_opt = maxAttempts_opt, feeThresholdSat_opt = feeThresholdSat_opt, maxFeePct_opt = maxFeePct_opt)) } } ~ path("sendtoroute") { - formFields(amountMsatFormParam, paymentHashFormParam, "finalCltvExpiry".as[Int], "route".as[List[PublicKey]](pubkeyListUnmarshaller)) { (amountMsat, paymentHash, finalCltvExpiry, route) => - complete(eclairApi.sendToRoute(route, amountMsat, paymentHash, CltvExpiryDelta(finalCltvExpiry))) + formFields(amountMsatFormParam, paymentHashFormParam, "finalCltvExpiry".as[Int], "route".as[List[PublicKey]](pubkeyListUnmarshaller), "externalId".?) { + (amountMsat, paymentHash, finalCltvExpiry, route, externalId_opt) => + complete(eclairApi.sendToRoute(externalId_opt, route, amountMsat, paymentHash, CltvExpiryDelta(finalCltvExpiry))) } } ~ path("getsentinfo") { diff --git a/eclair-node/src/test/scala/fr/acinq/eclair/api/ApiServiceSpec.scala b/eclair-node/src/test/scala/fr/acinq/eclair/api/ApiServiceSpec.scala index 330b821fd0..83fcc74f0b 100644 --- a/eclair-node/src/test/scala/fr/acinq/eclair/api/ApiServiceSpec.scala +++ b/eclair-node/src/test/scala/fr/acinq/eclair/api/ApiServiceSpec.scala @@ -32,8 +32,7 @@ import fr.acinq.bitcoin.Crypto.PublicKey import fr.acinq.eclair._ import fr.acinq.eclair.io.NodeURI import fr.acinq.eclair.io.Peer.PeerInfo -import fr.acinq.eclair.payment.PaymentLifecycle.PaymentFailed -import fr.acinq.eclair.payment._ +import fr.acinq.eclair.payment.{PaymentFailed, _} import fr.acinq.eclair.wire.NodeAddress import org.mockito.scalatest.IdiomaticMockito import org.scalatest.{FunSuite, Matchers} @@ -250,7 +249,7 @@ class ApiServiceSpec extends FunSuite with ScalatestRouteTest with IdiomaticMock test("'send' method should handle payment failures") { val eclair = mock[Eclair] - eclair.send(any, any, any, any, any, any, any)(any[Timeout]) returns Future.failed(new IllegalArgumentException("invoice has expired")) + eclair.send(any, any, any, any, any, any, any, any)(any[Timeout]) returns Future.failed(new IllegalArgumentException("invoice has expired")) val mockService = new MockService(eclair) val invoice = "lnbc12580n1pw2ywztpp554ganw404sh4yjkwnysgn3wjcxfcq7gtx53gxczkjr9nlpc3hzvqdq2wpskwctddyxqr4rqrzjqwryaup9lh50kkranzgcdnn2fgvx390wgj5jd07rwr3vxeje0glc7z9rtvqqwngqqqqqqqlgqqqqqeqqjqrrt8smgjvfj7sg38dwtr9kc9gg3era9k3t2hvq3cup0jvsrtrxuplevqgfhd3rzvhulgcxj97yjuj8gdx8mllwj4wzjd8gdjhpz3lpqqvk2plh" @@ -263,7 +262,7 @@ class ApiServiceSpec extends FunSuite with ScalatestRouteTest with IdiomaticMock assert(status == BadRequest) val resp = entityAs[ErrorResponse](Json4sSupport.unmarshaller, ClassTag(classOf[ErrorResponse])) assert(resp.error == "invoice has expired") - eclair.send(any, 1258000 msat, any, any, any, any, any)(any[Timeout]).wasCalled(once) + eclair.send(None, any, 1258000 msat, any, any, any, any, any)(any[Timeout]).wasCalled(once) } } @@ -271,7 +270,7 @@ class ApiServiceSpec extends FunSuite with ScalatestRouteTest with IdiomaticMock val invoice = "lnbc12580n1pw2ywztpp554ganw404sh4yjkwnysgn3wjcxfcq7gtx53gxczkjr9nlpc3hzvqdq2wpskwctddyxqr4rqrzjqwryaup9lh50kkranzgcdnn2fgvx390wgj5jd07rwr3vxeje0glc7z9rtvqqwngqqqqqqqlgqqqqqeqqjqrrt8smgjvfj7sg38dwtr9kc9gg3era9k3t2hvq3cup0jvsrtrxuplevqgfhd3rzvhulgcxj97yjuj8gdx8mllwj4wzjd8gdjhpz3lpqqvk2plh" val eclair = mock[Eclair] - eclair.send(any, any, any, any, any, any, any)(any[Timeout]) returns Future.successful(UUID.randomUUID()) + eclair.send(any, any, any, any, any, any, any, any)(any[Timeout]) returns Future.successful(UUID.randomUUID()) val mockService = new MockService(eclair) Post("/payinvoice", FormData("invoice" -> invoice).toEntity) ~> @@ -280,18 +279,17 @@ class ApiServiceSpec extends FunSuite with ScalatestRouteTest with IdiomaticMock check { assert(handled) assert(status == OK) - eclair.send(any, 1258000 msat, any, any, any, any, any)(any[Timeout]).wasCalled(once) + eclair.send(None, any, 1258000 msat, any, any, any, any, any)(any[Timeout]).wasCalled(once) } - Post("/payinvoice", FormData("invoice" -> invoice, "amountMsat" -> "123", "feeThresholdSat" -> "112233", "maxFeePct" -> "2.34").toEntity) ~> + Post("/payinvoice", FormData("invoice" -> invoice, "amountMsat" -> "123", "feeThresholdSat" -> "112233", "maxFeePct" -> "2.34", "externalId" -> "42").toEntity) ~> addCredentials(BasicHttpCredentials("", mockService.password)) ~> Route.seal(mockService.route) ~> check { assert(handled) assert(status == OK) - eclair.send(any, 123 msat, any, any, any, Some(112233 sat), Some(2.34))(any[Timeout]).wasCalled(once) + eclair.send(Some("42"), any, 123 msat, any, any, any, Some(112233 sat), Some(2.34))(any[Timeout]).wasCalled(once) } - } test("'getreceivedinfo' method should respond HTTP 404 with a JSON encoded response if the element is not found") { @@ -314,15 +312,16 @@ class ApiServiceSpec extends FunSuite with ScalatestRouteTest with IdiomaticMock test("'sendtoroute' method should accept a both a json-encoded AND comma separaterd list of pubkeys") { val rawUUID = "487da196-a4dc-4b1e-92b4-3e5e905e9f3f" val paymentUUID = UUID.fromString(rawUUID) + val externalId = UUID.randomUUID().toString val expectedRoute = List(PublicKey(hex"0217eb8243c95f5a3b7d4c5682d10de354b7007eb59b6807ae407823963c7547a9"), PublicKey(hex"0242a4ae0c5bef18048fbecf995094b74bfb0f7391418d71ed394784373f41e4f3"), PublicKey(hex"026ac9fcd64fb1aa1c491fc490634dc33da41d4a17b554e0adf1b32fee88ee9f28")) val csvNodes = "0217eb8243c95f5a3b7d4c5682d10de354b7007eb59b6807ae407823963c7547a9, 0242a4ae0c5bef18048fbecf995094b74bfb0f7391418d71ed394784373f41e4f3, 026ac9fcd64fb1aa1c491fc490634dc33da41d4a17b554e0adf1b32fee88ee9f28" val jsonNodes = serialization.write(expectedRoute) val eclair = mock[Eclair] - eclair.sendToRoute(any[List[PublicKey]], any[MilliSatoshi], any[ByteVector32], any[CltvExpiryDelta])(any[Timeout]) returns Future.successful(paymentUUID) + eclair.sendToRoute(any[Option[String]], any[List[PublicKey]], any[MilliSatoshi], any[ByteVector32], any[CltvExpiryDelta])(any[Timeout]) returns Future.successful(paymentUUID) val mockService = new MockService(eclair) - Post("/sendtoroute", FormData("route" -> jsonNodes, "amountMsat" -> "1234", "paymentHash" -> ByteVector32.Zeroes.toHex, "finalCltvExpiry" -> "190").toEntity) ~> + Post("/sendtoroute", FormData("route" -> jsonNodes, "amountMsat" -> "1234", "paymentHash" -> ByteVector32.Zeroes.toHex, "finalCltvExpiry" -> "190", "externalId" -> externalId.toString).toEntity) ~> addCredentials(BasicHttpCredentials("", mockService.password)) ~> addHeader("Content-Type", "application/json") ~> Route.seal(mockService.route) ~> @@ -330,7 +329,7 @@ class ApiServiceSpec extends FunSuite with ScalatestRouteTest with IdiomaticMock assert(handled) assert(status == OK) assert(entityAs[String] == "\"" + rawUUID + "\"") - eclair.sendToRoute(expectedRoute, 1234 msat, ByteVector32.Zeroes, CltvExpiryDelta(190))(any[Timeout]).wasCalled(once) + eclair.sendToRoute(Some(externalId), expectedRoute, 1234 msat, ByteVector32.Zeroes, CltvExpiryDelta(190))(any[Timeout]).wasCalled(once) } // this test uses CSV encoded route @@ -342,7 +341,7 @@ class ApiServiceSpec extends FunSuite with ScalatestRouteTest with IdiomaticMock assert(handled) assert(status == OK) assert(entityAs[String] == "\"" + rawUUID + "\"") - eclair.sendToRoute(expectedRoute, 1234 msat, ByteVector32.One, CltvExpiryDelta(190))(any[Timeout]).wasCalled(once) + eclair.sendToRoute(None, expectedRoute, 1234 msat, ByteVector32.One, CltvExpiryDelta(190))(any[Timeout]).wasCalled(once) } } @@ -357,14 +356,14 @@ class ApiServiceSpec extends FunSuite with ScalatestRouteTest with IdiomaticMock mockService.route ~> check { - val pf = PaymentFailed(fixedUUID, ByteVector32.Zeroes, failures = Seq.empty) - val expectedSerializedPf = """{"type":"payment-failed","id":"487da196-a4dc-4b1e-92b4-3e5e905e9f3f","paymentHash":"0000000000000000000000000000000000000000000000000000000000000000","failures":[]}""" + val pf = PaymentFailed(fixedUUID, ByteVector32.Zeroes, failures = Seq.empty, timestamp = 1553784963659L) + val expectedSerializedPf = """{"type":"payment-failed","id":"487da196-a4dc-4b1e-92b4-3e5e905e9f3f","paymentHash":"0000000000000000000000000000000000000000000000000000000000000000","failures":[],"timestamp":1553784963659}""" serialization.write(pf)(mockService.formatsWithTypeHint) === expectedSerializedPf system.eventStream.publish(pf) wsClient.expectMessage(expectedSerializedPf) - val ps = PaymentSent(fixedUUID, amount = 21 msat, feesPaid = 1 msat, paymentHash = ByteVector32.Zeroes, paymentPreimage = ByteVector32.One, toChannelId = ByteVector32.Zeroes, timestamp = 1553784337711L) - val expectedSerializedPs = """{"type":"payment-sent","id":"487da196-a4dc-4b1e-92b4-3e5e905e9f3f","amount":21,"feesPaid":1,"paymentHash":"0000000000000000000000000000000000000000000000000000000000000000","paymentPreimage":"0100000000000000000000000000000000000000000000000000000000000000","toChannelId":"0000000000000000000000000000000000000000000000000000000000000000","timestamp":1553784337711}""" + val ps = PaymentSent(fixedUUID, ByteVector32.Zeroes, ByteVector32.One, Seq(PaymentSent.PartialPayment(fixedUUID, 21 msat, 1 msat, ByteVector32.Zeroes, None, 1553784337711L))) + val expectedSerializedPs = """{"type":"payment-sent","id":"487da196-a4dc-4b1e-92b4-3e5e905e9f3f","paymentHash":"0000000000000000000000000000000000000000000000000000000000000000","paymentPreimage":"0100000000000000000000000000000000000000000000000000000000000000","parts":[{"id":"487da196-a4dc-4b1e-92b4-3e5e905e9f3f","amount":21,"feesPaid":1,"toChannelId":"0000000000000000000000000000000000000000000000000000000000000000","timestamp":1553784337711}]}""" serialization.write(ps)(mockService.formatsWithTypeHint) === expectedSerializedPs system.eventStream.publish(ps) wsClient.expectMessage(expectedSerializedPs) @@ -375,8 +374,8 @@ class ApiServiceSpec extends FunSuite with ScalatestRouteTest with IdiomaticMock system.eventStream.publish(prel) wsClient.expectMessage(expectedSerializedPrel) - val precv = PaymentReceived(amount = 21 msat, paymentHash = ByteVector32.Zeroes, fromChannelId = ByteVector32.One, timestamp = 1553784963659L) - val expectedSerializedPrecv = """{"type":"payment-received","amount":21,"paymentHash":"0000000000000000000000000000000000000000000000000000000000000000","fromChannelId":"0100000000000000000000000000000000000000000000000000000000000000","timestamp":1553784963659}""" + val precv = PaymentReceived(ByteVector32.Zeroes, Seq(PaymentReceived.PartialPayment(21 msat, ByteVector32.Zeroes, 1553784963659L))) + val expectedSerializedPrecv = """{"type":"payment-received","paymentHash":"0000000000000000000000000000000000000000000000000000000000000000","parts":[{"amount":21,"fromChannelId":"0000000000000000000000000000000000000000000000000000000000000000","timestamp":1553784963659}]}""" serialization.write(precv)(mockService.formatsWithTypeHint) === expectedSerializedPrecv system.eventStream.publish(precv) wsClient.expectMessage(expectedSerializedPrecv)