From 84ec63da9114afd3522680b26ea61bf09c474898 Mon Sep 17 00:00:00 2001 From: Bastien Teinturier Date: Fri, 12 Jul 2019 11:46:08 +0200 Subject: [PATCH 01/11] Add onion tlv types. Add onion payload decoder for both legacy and tlv. Rename fields for clarity. --- .../eclair/payment/PaymentLifecycle.scala | 10 +-- .../fr/acinq/eclair/payment/Relayer.scala | 10 +-- .../scala/fr/acinq/eclair/wire/Onion.scala | 78 +++++++++++++++---- .../fr/acinq/eclair/wire/TlvCodecs.scala | 8 ++ .../scala/fr/acinq/eclair/wire/TlvTypes.scala | 3 - .../eclair/payment/ChannelSelectionSpec.scala | 6 +- .../eclair/payment/HtlcGenerationSpec.scala | 28 +++---- .../acinq/eclair/wire/OnionCodecsSpec.scala | 62 +++++++++++---- .../fr/acinq/eclair/wire/TlvCodecsSpec.scala | 11 +++ 9 files changed, 159 insertions(+), 57 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentLifecycle.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentLifecycle.scala index 839859bf2b..132884d7fe 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentLifecycle.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentLifecycle.scala @@ -231,11 +231,11 @@ object PaymentLifecycle { // @formatter:on - def buildOnion(nodes: Seq[PublicKey], payloads: Seq[PerHopPayload], associatedData: ByteVector32): Sphinx.PacketAndSecrets = { + def buildOnion(nodes: Seq[PublicKey], payloads: Seq[OnionForwardInfo], associatedData: ByteVector32): Sphinx.PacketAndSecrets = { require(nodes.size == payloads.size) val sessionKey = randomKey val payloadsbin: Seq[ByteVector] = payloads - .map(OnionCodecs.perHopPayloadCodec.encode) + .map(OnionCodecs.legacyPerHopPayloadCodec.encode) .map { case Attempt.Successful(bitVector) => bitVector.toByteVector case Attempt.Failure(cause) => throw new RuntimeException(s"serialization error: $cause") @@ -253,11 +253,11 @@ object PaymentLifecycle { * - firstExpiry is the cltv expiry for the first htlc in the route * - a sequence of payloads that will be used to build the onion */ - def buildPayloads(finalAmount: MilliSatoshi, finalExpiry: CltvExpiry, hops: Seq[Hop]): (MilliSatoshi, CltvExpiry, Seq[PerHopPayload]) = - hops.reverse.foldLeft((finalAmount, finalExpiry, PerHopPayload(ShortChannelId(0L), finalAmount, finalExpiry) :: Nil)) { + def buildPayloads(finalAmount: MilliSatoshi, finalExpiry: CltvExpiry, hops: Seq[Hop]): (MilliSatoshi, CltvExpiry, Seq[OnionForwardInfo]) = + hops.reverse.foldLeft((finalAmount, finalExpiry, OnionForwardInfo(ShortChannelId(0L), finalAmount, finalExpiry) :: Nil)) { case ((msat, expiry, payloads), hop) => val nextFee = nodeFee(hop.lastUpdate.feeBaseMsat, hop.lastUpdate.feeProportionalMillionths, msat) - (msat + nextFee, expiry + hop.lastUpdate.cltvExpiryDelta, PerHopPayload(hop.lastUpdate.shortChannelId, msat, expiry) +: payloads) + (msat + nextFee, expiry + hop.lastUpdate.cltvExpiryDelta, OnionForwardInfo(hop.lastUpdate.shortChannelId, msat, expiry) +: payloads) } def buildCommand(id: UUID, finalAmount: MilliSatoshi, finalExpiry: CltvExpiry, paymentHash: ByteVector32, hops: Seq[Hop]): (CMD_ADD_HTLC, Seq[(ByteVector32, PublicKey)]) = { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Relayer.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Relayer.scala index 673da84e34..0b58ab74ab 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Relayer.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Relayer.scala @@ -213,8 +213,8 @@ object Relayer extends Logging { // @formatter:off sealed trait NextPayload - case class FinalPayload(add: UpdateAddHtlc, payload: PerHopPayload) extends NextPayload - case class RelayPayload(add: UpdateAddHtlc, payload: PerHopPayload, nextPacket: OnionRoutingPacket) extends NextPayload { + case class FinalPayload(add: UpdateAddHtlc, payload: OnionForwardInfo) extends NextPayload + case class RelayPayload(add: UpdateAddHtlc, payload: OnionForwardInfo, nextPacket: OnionRoutingPacket) extends NextPayload { val relayFeeMsat: MilliSatoshi = add.amountMsat - payload.amtToForward val expiryDelta: CltvExpiryDelta = add.cltvExpiry - payload.outgoingCltvValue } @@ -231,7 +231,7 @@ object Relayer extends Logging { def decryptPacket(add: UpdateAddHtlc, privateKey: PrivateKey): Either[BadOnion, NextPayload] = Sphinx.PaymentPacket.peel(privateKey, add.paymentHash, add.onionRoutingPacket) match { case Right(p@Sphinx.DecryptedPacket(payload, nextPacket, _)) => - OnionCodecs.perHopPayloadCodec.decode(payload.bits) match { + OnionCodecs.legacyPerHopPayloadCodec.decode(payload.bits) match { case Attempt.Successful(DecodeResult(perHopPayload, remainder)) => if (remainder.nonEmpty) { logger.warn(s"${remainder.length} bits remaining after per-hop payload decoding: there might be an issue with the onion codec") @@ -259,9 +259,9 @@ object Relayer extends Logging { def handleFinal(finalPayload: FinalPayload): Either[CMD_FAIL_HTLC, UpdateAddHtlc] = { import finalPayload.add finalPayload.payload match { - case PerHopPayload(_, finalAmountToForward, _) if finalAmountToForward > add.amountMsat => + case OnionForwardInfo(_, finalAmountToForward, _) if finalAmountToForward > add.amountMsat => Left(CMD_FAIL_HTLC(add.id, Right(FinalIncorrectHtlcAmount(add.amountMsat)), commit = true)) - case PerHopPayload(_, _, finalOutgoingCltvValue) if finalOutgoingCltvValue != add.cltvExpiry => + case OnionForwardInfo(_, _, finalOutgoingCltvValue) if finalOutgoingCltvValue != add.cltvExpiry => Left(CMD_FAIL_HTLC(add.id, Right(FinalIncorrectCltvExpiry(add.cltvExpiry)), commit = true)) case _ => Right(add) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/Onion.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/Onion.scala index f1737eb606..0d4fa62d44 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/Onion.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/Onion.scala @@ -18,8 +18,11 @@ package fr.acinq.eclair.wire import fr.acinq.bitcoin.ByteVector32 import fr.acinq.eclair.crypto.Sphinx -import fr.acinq.eclair.{CltvExpiry, MilliSatoshi, ShortChannelId} -import scodec.bits.{BitVector, ByteVector} +import fr.acinq.eclair.wire.CommonCodecs._ +import fr.acinq.eclair.wire.OnionTlv._ +import fr.acinq.eclair.wire.TlvCodecs._ +import fr.acinq.eclair.{CltvExpiry, MilliSatoshi, ShortChannelId, UInt64} +import scodec.bits.{BitVector, ByteVector, HexStringSyntax} import scodec.codecs._ import scodec.{Codec, DecodeResult, Decoder} @@ -32,9 +35,38 @@ case class OnionRoutingPacket(version: Int, payload: ByteVector, hmac: ByteVector32) -case class PerHopPayload(shortChannelId: ShortChannelId, - amtToForward: MilliSatoshi, - outgoingCltvValue: CltvExpiry) +case class OnionForwardInfo(shortChannelId: ShortChannelId, + amtToForward: MilliSatoshi, + outgoingCltvValue: CltvExpiry) + +/** + * Tlv types used inside onion messages. + */ +sealed trait OnionTlv extends Tlv + +object OnionTlv { + + /** + * If this record is present in an onion payload, the current node is the final destination of the onion message. + */ + case class Destination() extends OnionTlv + + /** + * Amount to forward to the next node. + */ + case class AmountToForward(amount: MilliSatoshi) extends OnionTlv + + /** + * CLTV value to use for the HTLC offered to the next node. + */ + case class OutgoingCltv(cltv: CltvExpiry) extends OnionTlv + + /** + * Id of the channel to use to forward a payment to the next node. + */ + case class OutgoingChannelId(shortChannelId: ShortChannelId) extends OnionTlv + +} object OnionCodecs { @@ -42,17 +74,10 @@ object OnionCodecs { ("version" | uint8) :: ("publicKey" | bytes(33)) :: ("onionPayload" | bytes(payloadLength)) :: - ("hmac" | CommonCodecs.bytes32)).as[OnionRoutingPacket] + ("hmac" | bytes32)).as[OnionRoutingPacket] val paymentOnionPacketCodec: Codec[OnionRoutingPacket] = onionRoutingPacketCodec(Sphinx.PaymentPacket.PayloadLength) - val perHopPayloadCodec: Codec[PerHopPayload] = ( - ("realm" | constant(ByteVector.fromByte(0))) :: - ("short_channel_id" | CommonCodecs.shortchannelid) :: - ("amt_to_forward" | CommonCodecs.millisatoshi) :: - ("outgoing_cltv_value" | CommonCodecs.cltvExpiry) :: - ("unused_with_v0_version_on_header" | ignore(8 * 12))).as[PerHopPayload] - /** * The 1.1 BOLT spec changed the onion frame format to use variable-length per-hop payloads. * The first bytes contain a varint encoding the length of the payload data (not including the trailing mac). @@ -60,6 +85,31 @@ object OnionCodecs { * the varint prefix. */ val payloadLengthDecoder = Decoder[Long]((bits: BitVector) => - CommonCodecs.varintoverflow.decode(bits).map(d => DecodeResult(d.value + (bits.length - d.remainder.length) / 8, d.remainder))) + varintoverflow.decode(bits).map(d => DecodeResult(d.value + (bits.length - d.remainder.length) / 8, d.remainder))) + + private val destination: Codec[Destination] = ("length" | constant(hex"00")).xmap(_ => Destination(), _ => ()) + + private val amountToForward: Codec[AmountToForward] = ("amount_msat" | tu64overflow).xmap(amountMsat => AmountToForward(MilliSatoshi(amountMsat)), (a: AmountToForward) => a.amount.toLong) + + private val outgoingCltv: Codec[OutgoingCltv] = ("cltv" | tu32).xmap(cltv => OutgoingCltv(CltvExpiry(cltv)), (c: OutgoingCltv) => c.cltv.toLong) + + private val outgoingChannelId: Codec[OutgoingChannelId] = (("length" | constant(hex"08")) :: ("short_channel_id" | shortchannelid)).as[OutgoingChannelId] + + private val onionTlvCodec = discriminated[OnionTlv].by(varint) + .typecase(UInt64(0), destination) + .typecase(UInt64(2), amountToForward) + .typecase(UInt64(4), outgoingCltv) + .typecase(UInt64(6), outgoingChannelId) + + val tlvPerHopPayloadCodec: Codec[TlvStream[OnionTlv]] = TlvCodecs.lengthPrefixedTlvStream[OnionTlv](onionTlvCodec).complete + + val legacyPerHopPayloadCodec: Codec[OnionForwardInfo] = ( + ("realm" | constant(ByteVector.fromByte(0))) :: + ("short_channel_id" | shortchannelid) :: + ("amt_to_forward" | millisatoshi) :: + ("outgoing_cltv_value" | cltvExpiry) :: + ("unused_with_v0_version_on_header" | ignore(8 * 12))).as[OnionForwardInfo] + + val perHopPayloadCodec: Codec[Either[TlvStream[OnionTlv], OnionForwardInfo]] = fallback(tlvPerHopPayloadCodec, legacyPerHopPayloadCodec) } \ No newline at end of file diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvCodecs.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvCodecs.scala index 664f7182ca..0f56ba666d 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvCodecs.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvCodecs.scala @@ -44,6 +44,14 @@ object TlvCodecs { .\(0x07) { case i if i < 0x0100000000000000L => i }(variableSizeUInt64(7, 0x01000000000000L)) .\(0x08) { case i if i <= UInt64.MaxValue => i }(variableSizeUInt64(8, 0x0100000000000000L)) + /** + * Length-prefixed truncated long (1 to 9 bytes unsigned integer). + * This codec can be safely used for values < `2^63` and will fail otherwise. + */ + val tu64overflow: Codec[Long] = tu64.exmap( + u => if (u <= Long.MaxValue) Attempt.Successful(u.toBigInt.toLong) else Attempt.Failure(Err(s"overflow for value $u")), + l => if (l >= 0) Attempt.Successful(UInt64(l)) else Attempt.Failure(Err(s"uint64 must be positive (actual=$l)"))) + /** * Length-prefixed truncated uint32 (1 to 5 bytes unsigned integer). */ diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvTypes.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvTypes.scala index 54cb65c948..3ea3d85c21 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvTypes.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvTypes.scala @@ -25,10 +25,7 @@ import scala.reflect.ClassTag * Created by t-bast on 20/06/2019. */ -// @formatter:off trait Tlv -sealed trait OnionTlv extends Tlv -// @formatter:on /** * Generic tlv type we fallback to if we don't understand the incoming tlv. diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/ChannelSelectionSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/ChannelSelectionSpec.scala index 0599412083..4238a7008c 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/ChannelSelectionSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/ChannelSelectionSpec.scala @@ -39,7 +39,7 @@ class ChannelSelectionSpec extends FunSuite { test("convert to CMD_FAIL_HTLC/CMD_ADD_HTLC") { val relayPayload = RelayPayload( add = UpdateAddHtlc(randomBytes32, 42, 1000000 msat, randomBytes32, CltvExpiry(70), TestConstants.emptyOnionPacket), - payload = PerHopPayload(ShortChannelId(12345), amtToForward = 998900 msat, outgoingCltvValue = CltvExpiry(60)), + payload = OnionForwardInfo(ShortChannelId(12345), amtToForward = 998900 msat, outgoingCltvValue = CltvExpiry(60)), nextPacket = TestConstants.emptyOnionPacket // just a placeholder ) @@ -52,7 +52,7 @@ class ChannelSelectionSpec extends FunSuite { // no channel_update assert(Relayer.relayOrFail(relayPayload, channelUpdate_opt = None) === RelayFailure(CMD_FAIL_HTLC(relayPayload.add.id, Right(UnknownNextPeer), commit = true))) // channel disabled - val channelUpdate_disabled = channelUpdate.copy(channelFlags = Announcements.makeChannelFlags(true, enable = false)) + val channelUpdate_disabled = channelUpdate.copy(channelFlags = Announcements.makeChannelFlags(isNode1 = true, enable = false)) assert(Relayer.relayOrFail(relayPayload, Some(channelUpdate_disabled)) === RelayFailure(CMD_FAIL_HTLC(relayPayload.add.id, Right(ChannelDisabled(channelUpdate_disabled.messageFlags, channelUpdate_disabled.channelFlags, channelUpdate_disabled)), commit = true))) // amount too low val relayPayload_toolow = relayPayload.copy(payload = relayPayload.payload.copy(amtToForward = 99 msat)) @@ -72,7 +72,7 @@ class ChannelSelectionSpec extends FunSuite { val relayPayload = RelayPayload( add = UpdateAddHtlc(randomBytes32, 42, 1000000 msat, randomBytes32, CltvExpiry(70), TestConstants.emptyOnionPacket), - payload = PerHopPayload(ShortChannelId(12345), amtToForward = 998900 msat, outgoingCltvValue = CltvExpiry(60)), + payload = OnionForwardInfo(ShortChannelId(12345), amtToForward = 998900 msat, outgoingCltvValue = CltvExpiry(60)), nextPacket = TestConstants.emptyOnionPacket // just a placeholder ) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/HtlcGenerationSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/HtlcGenerationSpec.scala index 0a82253054..cf9e157137 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/HtlcGenerationSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/HtlcGenerationSpec.scala @@ -25,7 +25,7 @@ import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.crypto.Sphinx.{DecryptedPacket, PacketAndSecrets} import fr.acinq.eclair.payment.PaymentLifecycle._ import fr.acinq.eclair.router.Hop -import fr.acinq.eclair.wire.{ChannelUpdate, OnionCodecs, PerHopPayload} +import fr.acinq.eclair.wire.{ChannelUpdate, OnionCodecs, OnionForwardInfo} import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, LongToBtcAmount, MilliSatoshi, ShortChannelId, TestConstants, nodeFee, randomBytes32} import org.scalatest.FunSuite import scodec.bits.ByteVector @@ -55,10 +55,10 @@ class HtlcGenerationSpec extends FunSuite { assert(firstAmountMsat === amount_ab) assert(firstExpiry === expiry_ab) assert(payloads === - PerHopPayload(channelUpdate_bc.shortChannelId, amount_bc, expiry_bc) :: - PerHopPayload(channelUpdate_cd.shortChannelId, amount_cd, expiry_cd) :: - PerHopPayload(channelUpdate_de.shortChannelId, amount_de, expiry_de) :: - PerHopPayload(ShortChannelId(0L), finalAmountMsat, finalExpiry) :: Nil) + OnionForwardInfo(channelUpdate_bc.shortChannelId, amount_bc, expiry_bc) :: + OnionForwardInfo(channelUpdate_cd.shortChannelId, amount_cd, expiry_cd) :: + OnionForwardInfo(channelUpdate_de.shortChannelId, amount_de, expiry_de) :: + OnionForwardInfo(ShortChannelId(0L), finalAmountMsat, finalExpiry) :: Nil) } test("build onion") { @@ -70,25 +70,25 @@ class HtlcGenerationSpec extends FunSuite { // let's peel the onion val Right(DecryptedPacket(bin_b, packet_c, _)) = Sphinx.PaymentPacket.peel(priv_b.privateKey, paymentHash, packet_b) - val payload_b = OnionCodecs.perHopPayloadCodec.decode(bin_b.toBitVector).require.value + val payload_b = OnionCodecs.legacyPerHopPayloadCodec.decode(bin_b.toBitVector).require.value assert(packet_c.payload.length === Sphinx.PaymentPacket.PayloadLength) assert(payload_b.amtToForward === amount_bc) assert(payload_b.outgoingCltvValue === expiry_bc) val Right(DecryptedPacket(bin_c, packet_d, _)) = Sphinx.PaymentPacket.peel(priv_c.privateKey, paymentHash, packet_c) - val payload_c = OnionCodecs.perHopPayloadCodec.decode(bin_c.toBitVector).require.value + val payload_c = OnionCodecs.legacyPerHopPayloadCodec.decode(bin_c.toBitVector).require.value assert(packet_d.payload.length === Sphinx.PaymentPacket.PayloadLength) assert(payload_c.amtToForward === amount_cd) assert(payload_c.outgoingCltvValue === expiry_cd) val Right(DecryptedPacket(bin_d, packet_e, _)) = Sphinx.PaymentPacket.peel(priv_d.privateKey, paymentHash, packet_d) - val payload_d = OnionCodecs.perHopPayloadCodec.decode(bin_d.toBitVector).require.value + val payload_d = OnionCodecs.legacyPerHopPayloadCodec.decode(bin_d.toBitVector).require.value assert(packet_e.payload.length === Sphinx.PaymentPacket.PayloadLength) assert(payload_d.amtToForward === amount_de) assert(payload_d.outgoingCltvValue === expiry_de) val Right(DecryptedPacket(bin_e, packet_random, _)) = Sphinx.PaymentPacket.peel(priv_e.privateKey, paymentHash, packet_e) - val payload_e = OnionCodecs.perHopPayloadCodec.decode(bin_e.toBitVector).require.value + val payload_e = OnionCodecs.legacyPerHopPayloadCodec.decode(bin_e.toBitVector).require.value assert(packet_random.payload.length === Sphinx.PaymentPacket.PayloadLength) assert(payload_e.amtToForward === finalAmountMsat) assert(payload_e.outgoingCltvValue === finalExpiry) @@ -105,25 +105,25 @@ class HtlcGenerationSpec extends FunSuite { // let's peel the onion val Right(DecryptedPacket(bin_b, packet_c, _)) = Sphinx.PaymentPacket.peel(priv_b.privateKey, paymentHash, add.onion) - val payload_b = OnionCodecs.perHopPayloadCodec.decode(bin_b.toBitVector).require.value + val payload_b = OnionCodecs.legacyPerHopPayloadCodec.decode(bin_b.toBitVector).require.value assert(packet_c.payload.length === Sphinx.PaymentPacket.PayloadLength) assert(payload_b.amtToForward === amount_bc) assert(payload_b.outgoingCltvValue === expiry_bc) val Right(DecryptedPacket(bin_c, packet_d, _)) = Sphinx.PaymentPacket.peel(priv_c.privateKey, paymentHash, packet_c) - val payload_c = OnionCodecs.perHopPayloadCodec.decode(bin_c.toBitVector).require.value + val payload_c = OnionCodecs.legacyPerHopPayloadCodec.decode(bin_c.toBitVector).require.value assert(packet_d.payload.length === Sphinx.PaymentPacket.PayloadLength) assert(payload_c.amtToForward === amount_cd) assert(payload_c.outgoingCltvValue === expiry_cd) val Right(DecryptedPacket(bin_d, packet_e, _)) = Sphinx.PaymentPacket.peel(priv_d.privateKey, paymentHash, packet_d) - val payload_d = OnionCodecs.perHopPayloadCodec.decode(bin_d.toBitVector).require.value + val payload_d = OnionCodecs.legacyPerHopPayloadCodec.decode(bin_d.toBitVector).require.value assert(packet_e.payload.length === Sphinx.PaymentPacket.PayloadLength) assert(payload_d.amtToForward === amount_de) assert(payload_d.outgoingCltvValue === expiry_de) val Right(DecryptedPacket(bin_e, packet_random, _)) = Sphinx.PaymentPacket.peel(priv_e.privateKey, paymentHash, packet_e) - val payload_e = OnionCodecs.perHopPayloadCodec.decode(bin_e.toBitVector).require.value + val payload_e = OnionCodecs.legacyPerHopPayloadCodec.decode(bin_e.toBitVector).require.value assert(packet_random.payload.length === Sphinx.PaymentPacket.PayloadLength) assert(payload_e.amtToForward === finalAmountMsat) assert(payload_e.outgoingCltvValue === finalExpiry) @@ -139,7 +139,7 @@ class HtlcGenerationSpec extends FunSuite { // let's peel the onion val Right(DecryptedPacket(bin_b, packet_random, _)) = Sphinx.PaymentPacket.peel(priv_b.privateKey, paymentHash, add.onion) - val payload_b = OnionCodecs.perHopPayloadCodec.decode(bin_b.toBitVector).require.value + val payload_b = OnionCodecs.legacyPerHopPayloadCodec.decode(bin_b.toBitVector).require.value assert(packet_random.payload.length === Sphinx.PaymentPacket.PayloadLength) assert(payload_b.amtToForward === finalAmountMsat) assert(payload_b.outgoingCltvValue === finalExpiry) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/OnionCodecsSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/OnionCodecsSpec.scala index 272b85fc26..e885ff0cf8 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/wire/OnionCodecsSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/wire/OnionCodecsSpec.scala @@ -17,8 +17,10 @@ package fr.acinq.eclair.wire import fr.acinq.bitcoin.ByteVector32 +import fr.acinq.eclair.UInt64.Conversions._ import fr.acinq.eclair.wire.OnionCodecs._ -import fr.acinq.eclair.{CltvExpiry, LongToBtcAmount, ShortChannelId} +import fr.acinq.eclair.wire.OnionTlv._ +import fr.acinq.eclair.{CltvExpiry, LongToBtcAmount, MilliSatoshi, ShortChannelId} import org.scalatest.FunSuite import scodec.bits.HexStringSyntax @@ -39,18 +41,19 @@ class OnionCodecsSpec extends FunSuite { assert(encoded.toByteVector === bin) } - test("encode/decode per-hop payload") { - val payload = PerHopPayload(shortChannelId = ShortChannelId(42), amtToForward = 142000 msat, outgoingCltvValue = CltvExpiry(500000)) - val bin = perHopPayloadCodec.encode(payload).require - assert(bin.toByteVector.size === 33) - val payload1 = perHopPayloadCodec.decode(bin).require.value - assert(payload === payload1) - - // realm (the first byte) should be 0 - val bin1 = bin.toByteVector.update(0, 1) - intercept[IllegalArgumentException] { - val payload2 = perHopPayloadCodec.decode(bin1.bits).require.value - assert(payload2 === payload1) + test("encode/decode fixed-size (legacy) per-hop payload") { + val testCases = Map( + OnionForwardInfo(ShortChannelId(0), 0 msat, CltvExpiry(0)) -> hex"00 0000000000000000 0000000000000000 00000000 000000000000000000000000", + OnionForwardInfo(ShortChannelId(42), 142000 msat, CltvExpiry(500000)) -> hex"00 000000000000002a 0000000000022ab0 0007a120 000000000000000000000000", + OnionForwardInfo(ShortChannelId(561), 1105 msat, CltvExpiry(1729)) -> hex"00 0000000000000231 0000000000000451 000006c1 000000000000000000000000" + ) + + for ((expected, bin) <- testCases) { + val decoded = perHopPayloadCodec.decode(bin.bits).require.value + assert(decoded === Right(expected)) + + val encoded = perHopPayloadCodec.encode(Right(expected)).require.bytes + assert(encoded === bin) } } @@ -71,4 +74,37 @@ class OnionCodecsSpec extends FunSuite { } } + test("encode/decode variable-length (tlv) per-hop payload") { + val testCases = Map( + TlvStream[OnionTlv](Destination()) -> hex"02 0000", + TlvStream[OnionTlv](AmountToForward(561 msat), OutgoingCltv(CltvExpiry(42)), OutgoingChannelId(ShortChannelId(1105))) -> hex"11 02020231 04012a 06080000000000000451", + TlvStream[OnionTlv](Seq(Destination(), AmountToForward(561 msat), OutgoingCltv(CltvExpiry(42))), Seq(GenericTlv(65535, hex"06c1"))) -> hex"0f 0000 02020231 04012a fdffff0206c1" + ) + + for ((expected, bin) <- testCases) { + val decoded = perHopPayloadCodec.decode(bin.bits).require.value + assert(decoded === Left(expected)) + + val encoded = perHopPayloadCodec.encode(Left(expected)).require.bytes + assert(encoded === bin) + } + } + + test("decode invalid per-hop payload") { + val testCases = Seq( + // Invalid fixed-size (legacy) payload. + hex"00 000000000000002a 000000000000002a", // invalid length + // Invalid variable-length (tlv) payload. + hex"01", // invalid length + hex"01 0000", // invalid length + hex"04 0000 2a00", // unknown even types + hex"04 0000 0000", // duplicate types + hex"04 0100 0000" // unordered types + ) + + for (testCase <- testCases) { + assert(perHopPayloadCodec.decode(testCase.bits).isFailure) + } + } + } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/TlvCodecsSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/TlvCodecsSpec.scala index c21033085e..aaf5e4d1aa 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/wire/TlvCodecsSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/wire/TlvCodecsSpec.scala @@ -116,6 +116,17 @@ class TlvCodecsSpec extends FunSuite { } } + test("encode/decode truncated uint64 overflow") { + assert(tu64overflow.encode(Long.MaxValue).require.toByteVector === hex"087fffffffffffffff") + assert(tu64overflow.decode(hex"087fffffffffffffff".bits).require.value === Long.MaxValue) + + assert(tu64overflow.encode(42L).require.toByteVector === hex"012a") + assert(tu64overflow.decode(hex"012a".bits).require.value === 42L) + + assert(tu64overflow.encode(-1L).isFailure) + assert(tu64overflow.decode(hex"088000000000000000".bits).isFailure) + } + test("decode invalid truncated integers") { val testCases = Seq( (tu16, hex"01 00"), // not minimal From 01b6cde245b595e25b4c7a52fa7033f4d29b718a Mon Sep 17 00:00:00 2001 From: Bastien Teinturier Date: Tue, 23 Jul 2019 15:34:24 +0200 Subject: [PATCH 02/11] Relay onions using tlv payloads. Handle tlv payloads at the final node. --- .../eclair/payment/PaymentLifecycle.scala | 38 ++++--- .../fr/acinq/eclair/payment/Relayer.scala | 22 ++-- .../scala/fr/acinq/eclair/wire/Onion.scala | 50 ++++++--- .../eclair/payment/HtlcGenerationSpec.scala | 45 ++++++-- .../fr/acinq/eclair/payment/RelayerSpec.scala | 104 +++++++++++++++++- .../acinq/eclair/wire/OnionCodecsSpec.scala | 56 ++++++++-- 6 files changed, 262 insertions(+), 53 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentLifecycle.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentLifecycle.scala index 132884d7fe..eab6c11fb5 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentLifecycle.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentLifecycle.scala @@ -28,6 +28,7 @@ import fr.acinq.eclair.db.{OutgoingPayment, OutgoingPaymentStatus} import fr.acinq.eclair.payment.PaymentLifecycle._ import fr.acinq.eclair.payment.PaymentRequest.ExtraHop import fr.acinq.eclair.router._ +import fr.acinq.eclair.wire.OnionPerHopPayload._ import fr.acinq.eclair.wire._ import scodec.Attempt import scodec.bits.ByteVector @@ -218,6 +219,10 @@ object PaymentLifecycle { case class UnreadableRemoteFailure(route: Seq[Hop]) extends PaymentFailure case class PaymentFailed(id: UUID, paymentHash: ByteVector32, failures: Seq[PaymentFailure]) extends PaymentResult + sealed trait PaymentOptions + case object LegacyPayload extends PaymentOptions + case object TlvPayload extends PaymentOptions + sealed trait Data case object WaitingForRequest extends Data case class WaitingForRoute(sender: ActorRef, c: SendPayment, failures: Seq[PaymentFailure]) extends Data @@ -227,41 +232,48 @@ object PaymentLifecycle { case object WAITING_FOR_REQUEST extends State case object WAITING_FOR_ROUTE extends State case object WAITING_FOR_PAYMENT_COMPLETE extends State - // @formatter:on - - def buildOnion(nodes: Seq[PublicKey], payloads: Seq[OnionForwardInfo], associatedData: ByteVector32): Sphinx.PacketAndSecrets = { + def buildOnion(nodes: Seq[PublicKey], payloads: Seq[OnionPerHopPayload], associatedData: ByteVector32): Sphinx.PacketAndSecrets = { require(nodes.size == payloads.size) val sessionKey = randomKey - val payloadsbin: Seq[ByteVector] = payloads - .map(OnionCodecs.legacyPerHopPayloadCodec.encode) + val payloadsBin: Seq[ByteVector] = payloads + .map(OnionCodecs.perHopPayloadCodec.encode) .map { case Attempt.Successful(bitVector) => bitVector.toByteVector case Attempt.Failure(cause) => throw new RuntimeException(s"serialization error: $cause") } - Sphinx.PaymentPacket.create(sessionKey, nodes, payloadsbin, associatedData) + Sphinx.PaymentPacket.create(sessionKey, nodes, payloadsBin, associatedData) } /** + * Build the onion payloads for each hop. * * @param finalAmount the final htlc amount in millisatoshis * @param finalExpiry the final htlc expiry in number of blocks * @param hops the hops as computed by the router + extra routes from payment request + * @param opts options to help build each hop's payload * @return a (firstAmountMsat, firstExpiry, payloads) tuple where: * - firstAmountMsat is the amount for the first htlc in the route * - firstExpiry is the cltv expiry for the first htlc in the route * - a sequence of payloads that will be used to build the onion */ - def buildPayloads(finalAmount: MilliSatoshi, finalExpiry: CltvExpiry, hops: Seq[Hop]): (MilliSatoshi, CltvExpiry, Seq[OnionForwardInfo]) = - hops.reverse.foldLeft((finalAmount, finalExpiry, OnionForwardInfo(ShortChannelId(0L), finalAmount, finalExpiry) :: Nil)) { - case ((msat, expiry, payloads), hop) => - val nextFee = nodeFee(hop.lastUpdate.feeBaseMsat, hop.lastUpdate.feeProportionalMillionths, msat) - (msat + nextFee, expiry + hop.lastUpdate.cltvExpiryDelta, OnionForwardInfo(hop.lastUpdate.shortChannelId, msat, expiry) +: payloads) + def buildPayloads(finalAmount: MilliSatoshi, finalExpiry: CltvExpiry, hops: Seq[Hop], opts: PaymentOptions = LegacyPayload): (MilliSatoshi, CltvExpiry, Seq[OnionPerHopPayload]) = { + val finalPayload: Seq[OnionPerHopPayload] = opts match { + case LegacyPayload => OnionForwardInfo(ShortChannelId(0L), finalAmount, finalExpiry) :: Nil + case TlvPayload => TlvStream[OnionTlv](OnionTlv.AmountToForward(finalAmount), OnionTlv.OutgoingCltv(finalExpiry)) :: Nil } + hops.reverse.foldLeft((finalAmount, finalExpiry, finalPayload)) { + case ((amount, expiry, payloads), hop) => + val nextFee = nodeFee(hop.lastUpdate.feeBaseMsat, hop.lastUpdate.feeProportionalMillionths, amount) + // Since we don't have any scenario where we add tlv data for intermediate hops, we use legacy payloads. + val payload: OnionPerHopPayload = OnionForwardInfo(hop.lastUpdate.shortChannelId, amount, expiry) + (amount + nextFee, expiry + hop.lastUpdate.cltvExpiryDelta, payload +: payloads) + } + } - def buildCommand(id: UUID, finalAmount: MilliSatoshi, finalExpiry: CltvExpiry, paymentHash: ByteVector32, hops: Seq[Hop]): (CMD_ADD_HTLC, Seq[(ByteVector32, PublicKey)]) = { - val (firstAmount, firstExpiry, payloads) = buildPayloads(finalAmount, finalExpiry, hops.drop(1)) + def buildCommand(id: UUID, finalAmount: MilliSatoshi, finalExpiry: CltvExpiry, paymentHash: ByteVector32, hops: Seq[Hop], opts: PaymentOptions = LegacyPayload): (CMD_ADD_HTLC, Seq[(ByteVector32, PublicKey)]) = { + val (firstAmount, firstExpiry, payloads) = buildPayloads(finalAmount, finalExpiry, hops.drop(1), opts) val nodes = hops.map(_.nextNodeId) // BOLT 2 requires that associatedData == paymentHash val onion = buildOnion(nodes, payloads, paymentHash) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Relayer.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Relayer.scala index 0b58ab74ab..f5e392bfde 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Relayer.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Relayer.scala @@ -213,7 +213,7 @@ object Relayer extends Logging { // @formatter:off sealed trait NextPayload - case class FinalPayload(add: UpdateAddHtlc, payload: OnionForwardInfo) extends NextPayload + case class FinalPayload(add: UpdateAddHtlc, payload: OnionPerHopPayload) extends NextPayload case class RelayPayload(add: UpdateAddHtlc, payload: OnionForwardInfo, nextPacket: OnionRoutingPacket) extends NextPayload { val relayFeeMsat: MilliSatoshi = add.amountMsat - payload.amtToForward val expiryDelta: CltvExpiryDelta = add.cltvExpiry - payload.outgoingCltvValue @@ -231,15 +231,21 @@ object Relayer extends Logging { def decryptPacket(add: UpdateAddHtlc, privateKey: PrivateKey): Either[BadOnion, NextPayload] = Sphinx.PaymentPacket.peel(privateKey, add.paymentHash, add.onionRoutingPacket) match { case Right(p@Sphinx.DecryptedPacket(payload, nextPacket, _)) => - OnionCodecs.legacyPerHopPayloadCodec.decode(payload.bits) match { + OnionCodecs.perHopPayloadCodec.decode(payload.bits) match { case Attempt.Successful(DecodeResult(perHopPayload, remainder)) => if (remainder.nonEmpty) { logger.warn(s"${remainder.length} bits remaining after per-hop payload decoding: there might be an issue with the onion codec") } if (p.isLastPacket) { - Right(FinalPayload(add, perHopPayload)) + perHopPayload.paymentInfo match { + case Some(_) => Right(FinalPayload(add, perHopPayload)) + case None => Left(InvalidOnionPayload(Sphinx.PaymentPacket.hash(add.onionRoutingPacket))) + } } else { - Right(RelayPayload(add, perHopPayload, nextPacket)) + perHopPayload.forwardInfo match { + case Some(forwardInfo) => Right(RelayPayload(add, forwardInfo, nextPacket)) + case None => Left(InvalidOnionPayload(Sphinx.PaymentPacket.hash(add.onionRoutingPacket))) + } } case Attempt.Failure(_) => // Onion is correctly encrypted but the content of the per-hop payload couldn't be parsed. @@ -258,11 +264,13 @@ object Relayer extends Logging { */ def handleFinal(finalPayload: FinalPayload): Either[CMD_FAIL_HTLC, UpdateAddHtlc] = { import finalPayload.add - finalPayload.payload match { - case OnionForwardInfo(_, finalAmountToForward, _) if finalAmountToForward > add.amountMsat => + finalPayload.payload.paymentInfo match { + case Some(OnionPaymentInfo(amountMsat, _)) if amountMsat > add.amountMsat => Left(CMD_FAIL_HTLC(add.id, Right(FinalIncorrectHtlcAmount(add.amountMsat)), commit = true)) - case OnionForwardInfo(_, _, finalOutgoingCltvValue) if finalOutgoingCltvValue != add.cltvExpiry => + case Some(OnionPaymentInfo(_, cltvExpiry)) if cltvExpiry != add.cltvExpiry => Left(CMD_FAIL_HTLC(add.id, Right(FinalIncorrectCltvExpiry(add.cltvExpiry)), commit = true)) + case None => + Left(CMD_FAIL_HTLC(add.id, Right(InvalidOnionPayload(Sphinx.PaymentPacket.hash(add.onionRoutingPacket))), commit = true)) case _ => Right(add) } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/Onion.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/Onion.scala index 0d4fa62d44..7fad290594 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/Onion.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/Onion.scala @@ -30,6 +30,11 @@ import scodec.{Codec, DecodeResult, Decoder} * Created by t-bast on 05/07/2019. */ +/** + * Tlv types used inside onion messages. + */ +sealed trait OnionTlv extends Tlv + case class OnionRoutingPacket(version: Int, publicKey: ByteVector, payload: ByteVector, @@ -39,17 +44,39 @@ case class OnionForwardInfo(shortChannelId: ShortChannelId, amtToForward: MilliSatoshi, outgoingCltvValue: CltvExpiry) -/** - * Tlv types used inside onion messages. - */ -sealed trait OnionTlv extends Tlv +case class OnionPaymentInfo(amount: MilliSatoshi, cltvExpiry: CltvExpiry) -object OnionTlv { +case class OnionPerHopPayload(payload: Either[TlvStream[OnionTlv], OnionForwardInfo]) { - /** - * If this record is present in an onion payload, the current node is the final destination of the onion message. - */ - case class Destination() extends OnionTlv + lazy val paymentInfo: Option[OnionPaymentInfo] = payload match { + case Right(OnionForwardInfo(_, amount, cltv)) => Some(OnionPaymentInfo(amount, cltv)) + case Left(tlv) => for { + amount <- tlv.get[AmountToForward].map(_.amount) + cltv <- tlv.get[OutgoingCltv].map(_.cltv) + } yield OnionPaymentInfo(amount, cltv) + } + + lazy val forwardInfo: Option[OnionForwardInfo] = payload match { + case Right(onionForwardInfo) => Some(onionForwardInfo) + case Left(tlv) => for { + shortChannelId <- tlv.get[OutgoingChannelId].map(_.shortChannelId) + amount <- tlv.get[AmountToForward].map(_.amount) + cltv <- tlv.get[OutgoingCltv].map(_.cltv) + } yield OnionForwardInfo(shortChannelId, amount, cltv) + } + +} + +object OnionPerHopPayload { + + // @formatter:off + implicit def legacyToPerHopPayload(legacy: OnionForwardInfo): OnionPerHopPayload = OnionPerHopPayload(Right(legacy)) + implicit def tlvToPerHopPayload(tlv: TlvStream[OnionTlv]): OnionPerHopPayload = OnionPerHopPayload(Left(tlv)) + // @formatter:on + +} + +object OnionTlv { /** * Amount to forward to the next node. @@ -87,8 +114,6 @@ object OnionCodecs { val payloadLengthDecoder = Decoder[Long]((bits: BitVector) => varintoverflow.decode(bits).map(d => DecodeResult(d.value + (bits.length - d.remainder.length) / 8, d.remainder))) - private val destination: Codec[Destination] = ("length" | constant(hex"00")).xmap(_ => Destination(), _ => ()) - private val amountToForward: Codec[AmountToForward] = ("amount_msat" | tu64overflow).xmap(amountMsat => AmountToForward(MilliSatoshi(amountMsat)), (a: AmountToForward) => a.amount.toLong) private val outgoingCltv: Codec[OutgoingCltv] = ("cltv" | tu32).xmap(cltv => OutgoingCltv(CltvExpiry(cltv)), (c: OutgoingCltv) => c.cltv.toLong) @@ -96,7 +121,6 @@ object OnionCodecs { private val outgoingChannelId: Codec[OutgoingChannelId] = (("length" | constant(hex"08")) :: ("short_channel_id" | shortchannelid)).as[OutgoingChannelId] private val onionTlvCodec = discriminated[OnionTlv].by(varint) - .typecase(UInt64(0), destination) .typecase(UInt64(2), amountToForward) .typecase(UInt64(4), outgoingCltv) .typecase(UInt64(6), outgoingChannelId) @@ -110,6 +134,6 @@ object OnionCodecs { ("outgoing_cltv_value" | cltvExpiry) :: ("unused_with_v0_version_on_header" | ignore(8 * 12))).as[OnionForwardInfo] - val perHopPayloadCodec: Codec[Either[TlvStream[OnionTlv], OnionForwardInfo]] = fallback(tlvPerHopPayloadCodec, legacyPerHopPayloadCodec) + val perHopPayloadCodec: Codec[OnionPerHopPayload] = fallback(tlvPerHopPayloadCodec, legacyPerHopPayloadCodec).as[OnionPerHopPayload] } \ No newline at end of file diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/HtlcGenerationSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/HtlcGenerationSpec.scala index cf9e157137..4a44d531ea 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/HtlcGenerationSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/HtlcGenerationSpec.scala @@ -25,7 +25,7 @@ import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.crypto.Sphinx.{DecryptedPacket, PacketAndSecrets} import fr.acinq.eclair.payment.PaymentLifecycle._ import fr.acinq.eclair.router.Hop -import fr.acinq.eclair.wire.{ChannelUpdate, OnionCodecs, OnionForwardInfo} +import fr.acinq.eclair.wire._ import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, LongToBtcAmount, MilliSatoshi, ShortChannelId, TestConstants, nodeFee, randomBytes32} import org.scalatest.FunSuite import scodec.bits.ByteVector @@ -49,20 +49,19 @@ class HtlcGenerationSpec extends FunSuite { import HtlcGenerationSpec._ test("compute payloads with fees and expiry delta") { - val (firstAmountMsat, firstExpiry, payloads) = buildPayloads(finalAmountMsat, finalExpiry, hops.drop(1)) + val expectedPayloads = Seq[OnionPerHopPayload]( + OnionForwardInfo(channelUpdate_bc.shortChannelId, amount_bc, expiry_bc), + OnionForwardInfo(channelUpdate_cd.shortChannelId, amount_cd, expiry_cd), + OnionForwardInfo(channelUpdate_de.shortChannelId, amount_de, expiry_de), + OnionForwardInfo(ShortChannelId(0L), finalAmountMsat, finalExpiry)) assert(firstAmountMsat === amount_ab) assert(firstExpiry === expiry_ab) - assert(payloads === - OnionForwardInfo(channelUpdate_bc.shortChannelId, amount_bc, expiry_bc) :: - OnionForwardInfo(channelUpdate_cd.shortChannelId, amount_cd, expiry_cd) :: - OnionForwardInfo(channelUpdate_de.shortChannelId, amount_de, expiry_de) :: - OnionForwardInfo(ShortChannelId(0L), finalAmountMsat, finalExpiry) :: Nil) + assert(payloads === expectedPayloads) } test("build onion") { - val (_, _, payloads) = buildPayloads(finalAmountMsat, finalExpiry, hops.drop(1)) val nodes = hops.map(_.nextNodeId) val PacketAndSecrets(packet_b, _) = buildOnion(nodes, payloads, paymentHash) @@ -94,8 +93,36 @@ class HtlcGenerationSpec extends FunSuite { assert(payload_e.outgoingCltvValue === finalExpiry) } - test("build a command including the onion") { + test("build onion with final tlv payload") { + val (_, _, payloads) = buildPayloads(finalAmountMsat, finalExpiry, hops.drop(1), TlvPayload) + val nodes = hops.map(_.nextNodeId) + val PacketAndSecrets(packet_b, _) = buildOnion(nodes, payloads, paymentHash) + assert(packet_b.payload.length === Sphinx.PaymentPacket.PayloadLength) + + // let's peel the onion + val Right(DecryptedPacket(bin_b, packet_c, _)) = Sphinx.PaymentPacket.peel(priv_b.privateKey, paymentHash, packet_b) + val payload_b = OnionCodecs.legacyPerHopPayloadCodec.decode(bin_b.toBitVector).require.value + assert(packet_c.payload.length === Sphinx.PaymentPacket.PayloadLength) + assert(payload_b === OnionForwardInfo(channelUpdate_bc.shortChannelId, amount_bc, expiry_bc)) + + val Right(DecryptedPacket(bin_c, packet_d, _)) = Sphinx.PaymentPacket.peel(priv_c.privateKey, paymentHash, packet_c) + val payload_c = OnionCodecs.legacyPerHopPayloadCodec.decode(bin_c.toBitVector).require.value + assert(packet_d.payload.length === Sphinx.PaymentPacket.PayloadLength) + assert(payload_c === OnionForwardInfo(channelUpdate_cd.shortChannelId, amount_cd, expiry_cd)) + + val Right(DecryptedPacket(bin_d, packet_e, _)) = Sphinx.PaymentPacket.peel(priv_d.privateKey, paymentHash, packet_d) + val payload_d = OnionCodecs.legacyPerHopPayloadCodec.decode(bin_d.toBitVector).require.value + assert(packet_e.payload.length === Sphinx.PaymentPacket.PayloadLength) + assert(payload_d === OnionForwardInfo(channelUpdate_de.shortChannelId, amount_de, expiry_de)) + + val Right(DecryptedPacket(bin_e, packet_random, _)) = Sphinx.PaymentPacket.peel(priv_e.privateKey, paymentHash, packet_e) + val payload_e = OnionCodecs.tlvPerHopPayloadCodec.decode(bin_e.toBitVector).require.value + val paymentInfo = OnionPerHopPayload(Left(payload_e)).paymentInfo + assert(packet_random.payload.length === Sphinx.PaymentPacket.PayloadLength) + assert(paymentInfo === Some(OnionPaymentInfo(finalAmountMsat, finalExpiry))) + } + test("build a command including the onion") { val (add, _) = buildCommand(UUID.randomUUID, finalAmountMsat, finalExpiry, paymentHash, hops) assert(add.amount > finalAmountMsat) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/RelayerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/RelayerSpec.scala index 534a6101b2..63ff2998cc 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/RelayerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/RelayerSpec.scala @@ -23,10 +23,10 @@ import akka.testkit.TestProbe import fr.acinq.bitcoin.ByteVector32 import fr.acinq.eclair.channel._ import fr.acinq.eclair.crypto.Sphinx -import fr.acinq.eclair.payment.PaymentLifecycle.buildCommand +import fr.acinq.eclair.payment.PaymentLifecycle.{buildCommand, buildOnion} import fr.acinq.eclair.router.Announcements import fr.acinq.eclair.wire._ -import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, LongToBtcAmount, ShortChannelId, TestConstants, TestkitBaseClass, UInt64, randomBytes32} +import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, LongToBtcAmount, ShortChannelId, TestConstants, TestkitBaseClass, UInt64, nodeFee, randomBytes32} import org.scalatest.Outcome import scodec.bits.ByteVector @@ -71,6 +71,36 @@ class RelayerSpec extends TestkitBaseClass { val fwd = register.expectMsgType[Register.ForwardShortId[CMD_ADD_HTLC]] assert(fwd.shortChannelId === channelUpdate_bc.shortChannelId) + assert(fwd.message.amount === amount_bc) + assert(fwd.message.cltvExpiry === expiry_bc) + assert(fwd.message.upstream === Right(add_ab)) + + sender.expectNoMsg(100 millis) + paymentHandler.expectNoMsg(100 millis) + } + + test("relay an htlc-add with onion tlv payload") { f => + import f._ + import fr.acinq.eclair.wire.OnionTlv._ + val sender = TestProbe() + + val finalPayload: Seq[OnionPerHopPayload] = TlvStream[OnionTlv](AmountToForward(finalAmountMsat), OutgoingCltv(finalExpiry)) :: Nil + val (firstAmountMsat, firstExpiry, payloads) = hops.drop(1).reverse.foldLeft((finalAmountMsat, finalExpiry, finalPayload)) { + case ((amountMsat, expiry, currentPayloads), hop) => + val nextFee = nodeFee(hop.lastUpdate.feeBaseMsat, hop.lastUpdate.feeProportionalMillionths, amountMsat) + val payload: OnionPerHopPayload = TlvStream[OnionTlv](AmountToForward(amountMsat), OutgoingCltv(expiry), OutgoingChannelId(hop.lastUpdate.shortChannelId)) + (amountMsat + nextFee, expiry + hop.lastUpdate.cltvExpiryDelta, payload +: currentPayloads) + } + val Sphinx.PacketAndSecrets(onion, _) = buildOnion(hops.map(_.nextNodeId), payloads, paymentHash) + val add_ab = UpdateAddHtlc(channelId_ab, 123456, firstAmountMsat, paymentHash, firstExpiry, onion) + relayer ! LocalChannelUpdate(null, channelId_bc, channelUpdate_bc.shortChannelId, c, None, channelUpdate_bc, makeCommitments(channelId_bc)) + + sender.send(relayer, ForwardAdd(add_ab)) + + val fwd = register.expectMsgType[Register.ForwardShortId[CMD_ADD_HTLC]] + assert(fwd.shortChannelId === channelUpdate_bc.shortChannelId) + assert(fwd.message.amount === amount_bc) + assert(fwd.message.cltvExpiry === expiry_bc) assert(fwd.message.upstream === Right(add_ab)) sender.expectNoMsg(100 millis) @@ -122,6 +152,21 @@ class RelayerSpec extends TestkitBaseClass { paymentHandler.expectNoMsg(100 millis) } + test("relay an htlc-add at the final node to the payment handler") { f => + import f._ + val sender = TestProbe() + + val (cmd, _) = buildCommand(UUID.randomUUID(), finalAmountMsat, finalExpiry, paymentHash, hops.take(1)) + val add_ab = UpdateAddHtlc(channelId = channelId_ab, id = 123456, cmd.amount, cmd.paymentHash, cmd.cltvExpiry, cmd.onion) + sender.send(relayer, ForwardAdd(add_ab)) + + val htlc = paymentHandler.expectMsgType[UpdateAddHtlc] + assert(htlc === add_ab) + + sender.expectNoMsg(100 millis) + register.expectNoMsg(100 millis) + } + test("fail to relay an htlc-add when we have no channel_update for the next channel") { f => import f._ val sender = TestProbe() @@ -244,6 +289,35 @@ class RelayerSpec extends TestkitBaseClass { paymentHandler.expectNoMsg(100 millis) } + test("fail to relay an htlc-add when the onion payload is missing data") { f => + import f._ + import fr.acinq.eclair.wire.OnionTlv._ + + // B is not the last hop and receives an onion missing some routing information. + val invalidPayloads_bc = Seq( + TlvStream[OnionTlv](OutgoingChannelId(channelUpdate_bc.shortChannelId), AmountToForward(amount_bc)), // Missing cltv expiry. + TlvStream[OnionTlv](OutgoingChannelId(channelUpdate_bc.shortChannelId), OutgoingCltv(expiry_bc)), // Missing forwarding amount. + TlvStream[OnionTlv](AmountToForward(amount_bc), OutgoingCltv(expiry_bc))) // Missing channel id. + val payload_cd = TlvStream[OnionTlv](OutgoingChannelId(channelUpdate_cd.shortChannelId), AmountToForward(amount_cd), OutgoingCltv(expiry_cd)) + + val sender = TestProbe() + relayer ! LocalChannelUpdate(null, channelId_bc, channelUpdate_bc.shortChannelId, c, None, channelUpdate_bc, makeCommitments(channelId_bc)) + + for (invalidPayload_bc <- invalidPayloads_bc) { + val Sphinx.PacketAndSecrets(onion, _) = buildOnion(Seq(b, c), Seq(invalidPayload_bc, payload_cd), paymentHash) + val add_ab = UpdateAddHtlc(channelId_ab, 123456, amount_ab, paymentHash, expiry_ab, onion) + sender.send(relayer, ForwardAdd(add_ab)) + + val fail = register.expectMsgType[Register.Forward[CMD_FAIL_MALFORMED_HTLC]].message + assert(fail.id === add_ab.id) + assert(fail.onionHash == Sphinx.PaymentPacket.hash(add_ab.onionRoutingPacket)) + assert(fail.failureCode === (FailureMessageCodecs.BADONION | FailureMessageCodecs.PERM)) + + register.expectNoMsg(100 millis) + paymentHandler.expectNoMsg(100 millis) + } + } + test("fail to relay an htlc-add when amount is below the next hop's requirements") { f => import f._ val sender = TestProbe() @@ -346,6 +420,32 @@ class RelayerSpec extends TestkitBaseClass { paymentHandler.expectNoMsg(100 millis) } + test("fail an htlc-add at the final node when the onion payload is missing data") { f => + import f._ + import fr.acinq.eclair.wire.OnionTlv._ + + // B is the last hop and receives an onion missing some payment information. + val invalidFinalPayloads = Seq( + TlvStream[OnionTlv](AmountToForward(amount_bc)), // Missing cltv expiry. + TlvStream[OnionTlv](OutgoingCltv(expiry_bc))) // Missing forwarding amount. + + val sender = TestProbe() + + for (invalidFinalPayload <- invalidFinalPayloads) { + val Sphinx.PacketAndSecrets(onion, _) = buildOnion(Seq(b), Seq(invalidFinalPayload), paymentHash) + val add_ab = UpdateAddHtlc(channelId_ab, 123456, amount_ab, paymentHash, expiry_ab, onion) + sender.send(relayer, ForwardAdd(add_ab)) + + val fail = register.expectMsgType[Register.Forward[CMD_FAIL_MALFORMED_HTLC]].message + assert(fail.id === add_ab.id) + assert(fail.onionHash == Sphinx.PaymentPacket.hash(add_ab.onionRoutingPacket)) + assert(fail.failureCode === (FailureMessageCodecs.BADONION | FailureMessageCodecs.PERM)) + + register.expectNoMsg(100 millis) + paymentHandler.expectNoMsg(100 millis) + } + } + test("correctly translates errors returned by channel when attempting to add an htlc") { f => import f._ val sender = TestProbe() diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/OnionCodecsSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/OnionCodecsSpec.scala index e885ff0cf8..6b0703993e 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/wire/OnionCodecsSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/wire/OnionCodecsSpec.scala @@ -19,8 +19,9 @@ package fr.acinq.eclair.wire import fr.acinq.bitcoin.ByteVector32 import fr.acinq.eclair.UInt64.Conversions._ import fr.acinq.eclair.wire.OnionCodecs._ +import fr.acinq.eclair.wire.OnionPerHopPayload._ import fr.acinq.eclair.wire.OnionTlv._ -import fr.acinq.eclair.{CltvExpiry, LongToBtcAmount, MilliSatoshi, ShortChannelId} +import fr.acinq.eclair.{CltvExpiry, LongToBtcAmount, ShortChannelId} import org.scalatest.FunSuite import scodec.bits.HexStringSyntax @@ -49,10 +50,10 @@ class OnionCodecsSpec extends FunSuite { ) for ((expected, bin) <- testCases) { - val decoded = perHopPayloadCodec.decode(bin.bits).require.value - assert(decoded === Right(expected)) + val OnionPerHopPayload(Right(decoded)) = perHopPayloadCodec.decode(bin.bits).require.value + assert(decoded === expected) - val encoded = perHopPayloadCodec.encode(Right(expected)).require.bytes + val encoded = perHopPayloadCodec.encode(expected).require.bytes assert(encoded === bin) } } @@ -76,16 +77,16 @@ class OnionCodecsSpec extends FunSuite { test("encode/decode variable-length (tlv) per-hop payload") { val testCases = Map( - TlvStream[OnionTlv](Destination()) -> hex"02 0000", + TlvStream[OnionTlv]() -> hex"00", TlvStream[OnionTlv](AmountToForward(561 msat), OutgoingCltv(CltvExpiry(42)), OutgoingChannelId(ShortChannelId(1105))) -> hex"11 02020231 04012a 06080000000000000451", - TlvStream[OnionTlv](Seq(Destination(), AmountToForward(561 msat), OutgoingCltv(CltvExpiry(42))), Seq(GenericTlv(65535, hex"06c1"))) -> hex"0f 0000 02020231 04012a fdffff0206c1" + TlvStream[OnionTlv](Seq(AmountToForward(561 msat), OutgoingCltv(CltvExpiry(42))), Seq(GenericTlv(65535, hex"06c1"))) -> hex"0d 02020231 04012a fdffff0206c1" ) for ((expected, bin) <- testCases) { - val decoded = perHopPayloadCodec.decode(bin.bits).require.value - assert(decoded === Left(expected)) + val OnionPerHopPayload(Left(decoded)) = perHopPayloadCodec.decode(bin.bits).require.value + assert(decoded === expected) - val encoded = perHopPayloadCodec.encode(Left(expected)).require.bytes + val encoded = perHopPayloadCodec.encode(expected).require.bytes assert(encoded === bin) } } @@ -107,4 +108,41 @@ class OnionCodecsSpec extends FunSuite { } } + test("get payment info") { + val legacyPayload: OnionPerHopPayload = OnionForwardInfo(ShortChannelId(550), 561 msat, CltvExpiry(1105)) + assert(legacyPayload.paymentInfo === Some(OnionPaymentInfo(561 msat, CltvExpiry(1105)))) + + val tlvPayload: OnionPerHopPayload = TlvStream[OnionTlv](AmountToForward(561 msat), OutgoingCltv(CltvExpiry(1105))) + assert(tlvPayload.paymentInfo === Some(OnionPaymentInfo(561 msat, CltvExpiry(1105)))) + + val tlvPayloadUnknown: OnionPerHopPayload = TlvStream[OnionTlv](Seq(AmountToForward(561 msat), OutgoingCltv(CltvExpiry(1105))), Seq(GenericTlv(13, hex"2a"))) + assert(tlvPayloadUnknown.paymentInfo === Some(OnionPaymentInfo(561 msat, CltvExpiry(1105)))) + + val tlvPayloadNoCltv: OnionPerHopPayload = TlvStream[OnionTlv](AmountToForward(561 msat)) + assert(tlvPayloadNoCltv.paymentInfo === None) + + val tlvPayloadNoAmount: OnionPerHopPayload = TlvStream[OnionTlv](OutgoingCltv(CltvExpiry(1105))) + assert(tlvPayloadNoAmount.paymentInfo === None) + } + + test("get forward info") { + val legacyPayload: OnionPerHopPayload = OnionForwardInfo(ShortChannelId(550), 561 msat, CltvExpiry(1105)) + assert(legacyPayload.forwardInfo === Some(OnionForwardInfo(ShortChannelId(550), 561 msat, CltvExpiry(1105)))) + + val tlvPayload: OnionPerHopPayload = TlvStream[OnionTlv](AmountToForward(561 msat), OutgoingCltv(CltvExpiry(1105)), OutgoingChannelId(ShortChannelId(550))) + assert(tlvPayload.forwardInfo === Some(OnionForwardInfo(ShortChannelId(550), 561 msat, CltvExpiry(1105)))) + + val tlvPayloadUnknown: OnionPerHopPayload = TlvStream[OnionTlv](Seq(AmountToForward(561 msat), OutgoingCltv(CltvExpiry(1105)), OutgoingChannelId(ShortChannelId(550))), Seq(GenericTlv(13, hex"2a"))) + assert(tlvPayloadUnknown.forwardInfo === Some(OnionForwardInfo(ShortChannelId(550), 561 msat, CltvExpiry(1105)))) + + val tlvPayloadNoCltv: OnionPerHopPayload = TlvStream[OnionTlv](AmountToForward(561 msat), OutgoingChannelId(ShortChannelId(550))) + assert(tlvPayloadNoCltv.forwardInfo === None) + + val tlvPayloadNoAmount: OnionPerHopPayload = TlvStream[OnionTlv](OutgoingCltv(CltvExpiry(1105)), OutgoingChannelId(ShortChannelId(550))) + assert(tlvPayloadNoAmount.forwardInfo === None) + + val tlvPayloadNoChannelId: OnionPerHopPayload = TlvStream[OnionTlv](AmountToForward(561 msat), OutgoingCltv(CltvExpiry(1105))) + assert(tlvPayloadNoChannelId.forwardInfo === None) + } + } From 09496f7903cc42103f8737c755744e0d1c320ef7 Mon Sep 17 00:00:00 2001 From: Bastien Teinturier Date: Tue, 23 Jul 2019 18:12:04 +0200 Subject: [PATCH 03/11] Clean onion feature bit usage (move to conf). Refactored Features to use BitVector. --- eclair-core/src/main/resources/reference.conf | 2 +- .../main/scala/fr/acinq/eclair/Features.scala | 28 +++++----- .../fr/acinq/eclair/io/Switchboard.scala | 7 +-- .../fr/acinq/eclair/payment/Relayer.scala | 9 ++-- .../acinq/eclair/router/Announcements.scala | 53 +++++++++---------- .../scala/fr/acinq/eclair/router/Router.scala | 4 +- .../scala/fr/acinq/eclair/FeaturesSpec.scala | 14 ++--- .../scala/fr/acinq/eclair/TestConstants.scala | 7 +-- .../acinq/eclair/db/SqliteNetworkDbSpec.scala | 17 +++--- .../eclair/payment/PaymentLifecycleSpec.scala | 3 +- .../fr/acinq/eclair/payment/RelayerSpec.scala | 26 ++++++++- .../eclair/router/AnnouncementsSpec.scala | 2 +- .../acinq/eclair/router/BaseRouterSpec.scala | 14 ++--- .../acinq/eclair/router/RoutingSyncSpec.scala | 20 +++---- 14 files changed, 120 insertions(+), 86 deletions(-) diff --git a/eclair-core/src/main/resources/reference.conf b/eclair-core/src/main/resources/reference.conf index a6ac85de83..c83448ca69 100644 --- a/eclair-core/src/main/resources/reference.conf +++ b/eclair-core/src/main/resources/reference.conf @@ -35,7 +35,7 @@ eclair { node-alias = "eclair" node-color = "49daaa" - global-features = "" + global-features = "0200" // variable_length_onion local-features = "8a" // initial_routing_sync + option_data_loss_protect + option_channel_range_queries override-features = [ // optional per-node features # { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/Features.scala b/eclair-core/src/main/scala/fr/acinq/eclair/Features.scala index 41f8e358ff..bc43027728 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/Features.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/Features.scala @@ -16,10 +16,7 @@ package fr.acinq.eclair - -import java.util.BitSet - -import scodec.bits.ByteVector +import scodec.bits.{BitVector, ByteVector} /** * Created by PM on 13/02/2017. @@ -38,18 +35,24 @@ object Features { val VARIABLE_LENGTH_ONION_MANDATORY = 8 val VARIABLE_LENGTH_ONION_OPTIONAL = 9 - def hasFeature(features: BitSet, bit: Int): Boolean = features.get(bit) + // Note that BitVector indexes from left to right whereas the specification indexes from right to left. + // This is why we have to reverse the bits to check if a feature is set. + + def hasFeature(features: BitVector, bit: Int): Boolean = if (features.sizeLessThanOrEqual(bit)) false else features.reverse.get(bit) - def hasFeature(features: ByteVector, bit: Int): Boolean = hasFeature(BitSet.valueOf(features.reverse.toArray), bit) + def hasFeature(features: ByteVector, bit: Int): Boolean = hasFeature(features.bits, bit) + + def hasVariableLengthOnion(features: ByteVector): Boolean = hasFeature(features, VARIABLE_LENGTH_ONION_MANDATORY) || hasFeature(features, VARIABLE_LENGTH_ONION_OPTIONAL) /** * Check that the features that we understand are correctly specified, and that there are no mandatory features that - * we don't understand (even bits) + * we don't understand (even bits). */ - def areSupported(bitset: BitSet): Boolean = { - val supportedMandatoryFeatures = Set(OPTION_DATA_LOSS_PROTECT_MANDATORY) - for (i <- 0 until bitset.length() by 2) { - if (bitset.get(i) && !supportedMandatoryFeatures.contains(i)) return false + def areSupported(features: BitVector): Boolean = { + val supportedMandatoryFeatures = Set[Long](OPTION_DATA_LOSS_PROTECT_MANDATORY) + val reversed = features.reverse + for (i <- 0L until reversed.length by 2) { + if (reversed.get(i) && !supportedMandatoryFeatures.contains(i)) return false } true @@ -59,5 +62,6 @@ object Features { * A feature set is supported if all even bits are supported. * We just ignore unknown odd bits. */ - def areSupported(features: ByteVector): Boolean = areSupported(BitSet.valueOf(features.reverse.toArray)) + def areSupported(features: ByteVector): Boolean = areSupported(features.bits) + } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/io/Switchboard.scala b/eclair-core/src/main/scala/fr/acinq/eclair/io/Switchboard.scala index 7d29142d2e..2ee641c62f 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/io/Switchboard.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/io/Switchboard.scala @@ -32,6 +32,7 @@ import fr.acinq.eclair.router.Rebroadcast import fr.acinq.eclair.transactions.{IN, OUT} import fr.acinq.eclair.wire.{TemporaryNodeFailure, UpdateAddHtlc} import grizzled.slf4j.Logging +import scodec.bits.ByteVector import scala.util.Success @@ -57,7 +58,7 @@ class Switchboard(nodeParams: NodeParams, authenticator: ActorRef, watcher: Acto }) val peers = nodeParams.db.peers.listPeers() - checkBrokenHtlcsLink(channels, nodeParams.privateKey) match { + checkBrokenHtlcsLink(channels, nodeParams.privateKey, nodeParams.globalFeatures) match { case Nil => () case brokenHtlcs => val brokenHtlcKiller = context.system.actorOf(Props[HtlcReaper], name = "htlc-reaper") @@ -165,7 +166,7 @@ object Switchboard extends Logging { * * This check will detect this and will allow us to fast-fail HTLCs and thus preserve channels. */ - def checkBrokenHtlcsLink(channels: Seq[HasCommitments], privateKey: PrivateKey): Seq[UpdateAddHtlc] = { + def checkBrokenHtlcsLink(channels: Seq[HasCommitments], privateKey: PrivateKey, features: ByteVector): Seq[UpdateAddHtlc] = { // We are interested in incoming HTLCs, that have been *cross-signed* (otherwise they wouldn't have been relayed). // They signed it first, so the HTLC will first appear in our commitment tx, and later on in their commitment when @@ -174,7 +175,7 @@ object Switchboard extends Logging { .flatMap(_.commitments.remoteCommit.spec.htlcs) .filter(_.direction == OUT) .map(_.add) - .map(Relayer.decryptPacket(_, privateKey)) + .map(Relayer.decryptPacket(_, privateKey, features)) .collect { case Right(RelayPayload(add, _, _)) => add } // we only consider htlcs that are relayed, not the ones for which we are the final node // Here we do it differently because we need the origin information. diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Relayer.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Relayer.scala index f5e392bfde..40c9c6ba28 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Relayer.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Relayer.scala @@ -28,8 +28,9 @@ import fr.acinq.eclair.db.OutgoingPaymentStatus import fr.acinq.eclair.payment.PaymentLifecycle.{PaymentFailed, PaymentSucceeded} import fr.acinq.eclair.router.Announcements import fr.acinq.eclair.wire._ -import fr.acinq.eclair.{CltvExpiryDelta, LongToBtcAmount, MilliSatoshi, NodeParams, ShortChannelId, nodeFee} +import fr.acinq.eclair.{CltvExpiryDelta, Features, LongToBtcAmount, MilliSatoshi, NodeParams, ShortChannelId, nodeFee} import grizzled.slf4j.Logging +import scodec.bits.ByteVector import scodec.{Attempt, DecodeResult} import scala.collection.mutable @@ -99,7 +100,7 @@ class Relayer(nodeParams: NodeParams, register: ActorRef, paymentHandler: ActorR case ForwardAdd(add, previousFailures) => log.debug(s"received forwarding request for htlc #${add.id} paymentHash=${add.paymentHash} from channelId=${add.channelId}") - decryptPacket(add, nodeParams.privateKey) match { + decryptPacket(add, nodeParams.privateKey, nodeParams.globalFeatures) match { case Right(p: FinalPayload) => handleFinal(p) match { case Left(cmdFail) => @@ -228,10 +229,12 @@ object Relayer extends Logging { * @param privateKey this node's private key * @return the payload for the next hop or an error. */ - def decryptPacket(add: UpdateAddHtlc, privateKey: PrivateKey): Either[BadOnion, NextPayload] = + def decryptPacket(add: UpdateAddHtlc, privateKey: PrivateKey, features: ByteVector): Either[BadOnion, NextPayload] = Sphinx.PaymentPacket.peel(privateKey, add.paymentHash, add.onionRoutingPacket) match { case Right(p@Sphinx.DecryptedPacket(payload, nextPacket, _)) => OnionCodecs.perHopPayloadCodec.decode(payload.bits) match { + case Attempt.Successful(DecodeResult(OnionPerHopPayload(Left(_)), _)) if !Features.hasVariableLengthOnion(features) => + Left(InvalidOnionPayload(Sphinx.PaymentPacket.hash(add.onionRoutingPacket))) case Attempt.Successful(DecodeResult(perHopPayload, remainder)) => if (remainder.nonEmpty) { logger.warn(s"${remainder.length} bits remaining after per-hop payload decoding: there might be an issue with the onion codec") diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/Announcements.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/Announcements.scala index 2baa73e7a9..b1544ee311 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/router/Announcements.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/Announcements.scala @@ -19,7 +19,7 @@ package fr.acinq.eclair.router import fr.acinq.bitcoin.Crypto.{PrivateKey, PublicKey, sha256, verifySignature} import fr.acinq.bitcoin.{ByteVector32, ByteVector64, Crypto, LexicographicalOrdering} import fr.acinq.eclair.wire._ -import fr.acinq.eclair.{CltvExpiryDelta, Features, MilliSatoshi, ShortChannelId, serializationResult} +import fr.acinq.eclair.{CltvExpiryDelta, MilliSatoshi, ShortChannelId, serializationResult} import scodec.bits.{BitVector, ByteVector} import shapeless.HNil @@ -27,8 +27,8 @@ import scala.compat.Platform import scala.concurrent.duration._ /** - * Created by PM on 03/02/2017. - */ + * Created by PM on 03/02/2017. + */ object Announcements { def channelAnnouncementWitnessEncode(chainHash: ByteVector32, shortChannelId: ShortChannelId, nodeId1: PublicKey, nodeId2: PublicKey, bitcoinKey1: PublicKey, bitcoinKey2: PublicKey, features: ByteVector, unknownFields: ByteVector): ByteVector = @@ -73,9 +73,8 @@ object Announcements { ) } - def makeNodeAnnouncement(nodeSecret: PrivateKey, alias: String, color: Color, nodeAddresses: List[NodeAddress], timestamp: Long = Platform.currentTime.milliseconds.toSeconds): NodeAnnouncement = { + def makeNodeAnnouncement(nodeSecret: PrivateKey, alias: String, color: Color, nodeAddresses: List[NodeAddress], features: ByteVector, timestamp: Long = Platform.currentTime.milliseconds.toSeconds): NodeAnnouncement = { require(alias.length <= 32) - val features = BitVector.fromLong(1 << Features.VARIABLE_LENGTH_ONION_OPTIONAL).bytes val witness = nodeAnnouncementWitnessEncode(timestamp, nodeSecret.publicKey, color, alias, features, nodeAddresses, unknownFields = ByteVector.empty) val sig = Crypto.sign(witness, nodeSecret) NodeAnnouncement( @@ -90,37 +89,37 @@ object Announcements { } /** - * BOLT 7: - * The creating node MUST set node-id-1 and node-id-2 to the public keys of the - * two nodes who are operating the channel, such that node-id-1 is the numerically-lesser - * of the two DER encoded keys sorted in ascending numerical order, - * - * @return true if localNodeId is node1 - */ + * BOLT 7: + * The creating node MUST set node-id-1 and node-id-2 to the public keys of the + * two nodes who are operating the channel, such that node-id-1 is the numerically-lesser + * of the two DER encoded keys sorted in ascending numerical order, + * + * @return true if localNodeId is node1 + */ def isNode1(localNodeId: PublicKey, remoteNodeId: PublicKey) = LexicographicalOrdering.isLessThan(localNodeId.value, remoteNodeId.value) /** - * BOLT 7: - * The creating node [...] MUST set the direction bit of flags to 0 if - * the creating node is node-id-1 in that message, otherwise 1. - * - * @return true if the node who sent these flags is node1 - */ + * BOLT 7: + * The creating node [...] MUST set the direction bit of flags to 0 if + * the creating node is node-id-1 in that message, otherwise 1. + * + * @return true if the node who sent these flags is node1 + */ def isNode1(channelFlags: Byte): Boolean = (channelFlags & 1) == 0 /** - * A node MAY create and send a channel_update with the disable bit set to - * signal the temporary unavailability of a channel - * - * @return - */ + * A node MAY create and send a channel_update with the disable bit set to + * signal the temporary unavailability of a channel + * + * @return + */ def isEnabled(channelFlags: Byte): Boolean = (channelFlags & 2) == 0 /** - * This method compares channel updates, ignoring fields that don't matter, like signature or timestamp - * - * @return true if channel updates are "equal" - */ + * This method compares channel updates, ignoring fields that don't matter, like signature or timestamp + * + * @return true if channel updates are "equal" + */ def areSame(u1: ChannelUpdate, u2: ChannelUpdate): Boolean = u1.copy(signature = ByteVector64.Zeroes, timestamp = 0) == u2.copy(signature = ByteVector64.Zeroes, timestamp = 0) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala index 7001f3bb0f..575265f300 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala @@ -194,7 +194,7 @@ class Router(val nodeParams: NodeParams, watcher: ActorRef, initialized: Option[ // on restart we update our node announcement // note that if we don't currently have public channels, this will be ignored - val nodeAnn = Announcements.makeNodeAnnouncement(nodeParams.privateKey, nodeParams.alias, nodeParams.color, nodeParams.publicAddresses) + val nodeAnn = Announcements.makeNodeAnnouncement(nodeParams.privateKey, nodeParams.alias, nodeParams.color, nodeParams.publicAddresses, nodeParams.globalFeatures) self ! nodeAnn log.info(s"initialization completed, ready to process messages") @@ -289,7 +289,7 @@ class Router(val nodeParams: NodeParams, watcher: ActorRef, initialized: Option[ // in case we just validated our first local channel, we announce the local node if (!d0.nodes.contains(nodeParams.nodeId) && isRelatedTo(c, nodeParams.nodeId)) { log.info("first local channel validated, announcing local node") - val nodeAnn = Announcements.makeNodeAnnouncement(nodeParams.privateKey, nodeParams.alias, nodeParams.color, nodeParams.publicAddresses) + val nodeAnn = Announcements.makeNodeAnnouncement(nodeParams.privateKey, nodeParams.alias, nodeParams.color, nodeParams.publicAddresses, nodeParams.globalFeatures) self ! nodeAnn } Some(PublicChannel(c, tx.txid, capacity, None, None)) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/FeaturesSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/FeaturesSpec.scala index 102e50e532..71733773b3 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/FeaturesSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/FeaturesSpec.scala @@ -16,9 +16,6 @@ package fr.acinq.eclair -import java.nio.ByteOrder - -import fr.acinq.bitcoin.Protocol import fr.acinq.eclair.Features._ import org.scalatest.FunSuite import scodec.bits._ @@ -45,14 +42,17 @@ class FeaturesSpec extends FunSuite { test("'variable_length_onion' feature") { assert(hasFeature(hex"0100", Features.VARIABLE_LENGTH_ONION_MANDATORY)) + assert(hasVariableLengthOnion(hex"0100")) assert(hasFeature(hex"0200", Features.VARIABLE_LENGTH_ONION_OPTIONAL)) + assert(hasVariableLengthOnion(hex"0200")) } test("features compatibility") { - assert(areSupported(Protocol.writeUInt64(1l << INITIAL_ROUTING_SYNC_BIT_OPTIONAL, ByteOrder.BIG_ENDIAN))) - assert(areSupported(Protocol.writeUInt64(1L << OPTION_DATA_LOSS_PROTECT_MANDATORY, ByteOrder.BIG_ENDIAN))) - assert(areSupported(Protocol.writeUInt64(1l << OPTION_DATA_LOSS_PROTECT_OPTIONAL, ByteOrder.BIG_ENDIAN))) - assert(areSupported(Protocol.writeUInt64(1l << VARIABLE_LENGTH_ONION_OPTIONAL, ByteOrder.BIG_ENDIAN))) + assert(areSupported(ByteVector.fromLong(1L << INITIAL_ROUTING_SYNC_BIT_OPTIONAL))) + assert(areSupported(ByteVector.fromLong(1L << OPTION_DATA_LOSS_PROTECT_MANDATORY))) + assert(areSupported(ByteVector.fromLong(1L << OPTION_DATA_LOSS_PROTECT_OPTIONAL))) + assert(areSupported(ByteVector.fromLong(1L << VARIABLE_LENGTH_ONION_OPTIONAL))) + assert(areSupported(hex"0b")) assert(!areSupported(hex"14")) assert(!areSupported(hex"0141")) } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala b/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala index 1495bc9686..8a1845fd13 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala @@ -27,7 +27,7 @@ import fr.acinq.eclair.db._ import fr.acinq.eclair.io.Peer import fr.acinq.eclair.router.RouterConf import fr.acinq.eclair.wire.{Color, EncodingType, NodeAddress} -import scodec.bits.ByteVector +import scodec.bits.{ByteVector, HexStringSyntax} import scala.concurrent.duration._ @@ -36,6 +36,7 @@ import scala.concurrent.duration._ */ object TestConstants { + val globalFeatures = hex"0200" // variable_length_onion val fundingSatoshis = 1000000L sat val pushMsat = 200000000L msat val feeratePerKw = 10000L @@ -67,7 +68,7 @@ object TestConstants { alias = "alice", color = Color(1, 2, 3), publicAddresses = NodeAddress.fromParts("localhost", 9731).get :: Nil, - globalFeatures = ByteVector.empty, + globalFeatures = globalFeatures, localFeatures = ByteVector(0), overrideFeatures = Map.empty, syncWhitelist = Set.empty, @@ -140,7 +141,7 @@ object TestConstants { alias = "bob", color = Color(4, 5, 6), publicAddresses = NodeAddress.fromParts("localhost", 9732).get :: Nil, - globalFeatures = ByteVector.empty, + globalFeatures = globalFeatures, localFeatures = ByteVector.empty, // no announcement overrideFeatures = Map.empty, syncWhitelist = Set.empty, diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteNetworkDbSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteNetworkDbSpec.scala index d2347b2769..faedbb1499 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteNetworkDbSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteNetworkDbSpec.scala @@ -16,16 +16,17 @@ package fr.acinq.eclair.db -import java.sql.{Connection, DriverManager} +import java.sql.Connection import fr.acinq.bitcoin.Crypto.PrivateKey -import fr.acinq.bitcoin.{Block, Crypto, Satoshi} +import fr.acinq.bitcoin.{Block, Crypto} import fr.acinq.eclair.db.sqlite.SqliteNetworkDb import fr.acinq.eclair.db.sqlite.SqliteUtils._ import fr.acinq.eclair.router.{Announcements, PublicChannel} import fr.acinq.eclair.wire.{Color, NodeAddress, Tor2} import fr.acinq.eclair.{CltvExpiryDelta, LongToBtcAmount, ShortChannelId, TestConstants, randomBytes32, randomKey} import org.scalatest.FunSuite +import scodec.bits.HexStringSyntax import scala.collection.SortedMap @@ -82,15 +83,15 @@ class SqliteNetworkDbSpec extends FunSuite { val sqlite = TestConstants.sqliteInMemory() val db = new SqliteNetworkDb(sqlite) - val node_1 = Announcements.makeNodeAnnouncement(randomKey, "node-alice", Color(100.toByte, 200.toByte, 300.toByte), NodeAddress.fromParts("192.168.1.42", 42000).get :: Nil) - val node_2 = Announcements.makeNodeAnnouncement(randomKey, "node-bob", Color(100.toByte, 200.toByte, 300.toByte), NodeAddress.fromParts("192.168.1.42", 42000).get :: Nil) - val node_3 = Announcements.makeNodeAnnouncement(randomKey, "node-charlie", Color(100.toByte, 200.toByte, 300.toByte), NodeAddress.fromParts("192.168.1.42", 42000).get :: Nil) - val node_4 = Announcements.makeNodeAnnouncement(randomKey, "node-charlie", Color(100.toByte, 200.toByte, 300.toByte), Tor2("aaaqeayeaudaocaj", 42000) :: Nil) + val node_1 = Announcements.makeNodeAnnouncement(randomKey, "node-alice", Color(100.toByte, 200.toByte, 300.toByte), NodeAddress.fromParts("192.168.1.42", 42000).get :: Nil, hex"") + val node_2 = Announcements.makeNodeAnnouncement(randomKey, "node-bob", Color(100.toByte, 200.toByte, 300.toByte), NodeAddress.fromParts("192.168.1.42", 42000).get :: Nil, hex"0200") + val node_3 = Announcements.makeNodeAnnouncement(randomKey, "node-charlie", Color(100.toByte, 200.toByte, 300.toByte), NodeAddress.fromParts("192.168.1.42", 42000).get :: Nil, hex"0200") + val node_4 = Announcements.makeNodeAnnouncement(randomKey, "node-charlie", Color(100.toByte, 200.toByte, 300.toByte), Tor2("aaaqeayeaudaocaj", 42000) :: Nil, hex"00") assert(db.listNodes().toSet === Set.empty) db.addNode(node_1) db.addNode(node_1) // duplicate is ignored - assert(db.getNode(node_1.nodeId) == Some(node_1)) + assert(db.getNode(node_1.nodeId) === Some(node_1)) assert(db.listNodes().size === 1) db.addNode(node_2) db.addNode(node_3) @@ -110,7 +111,7 @@ class SqliteNetworkDbSpec extends FunSuite { def generatePubkeyHigherThan(priv: PrivateKey) = { var res = priv - while(!Announcements.isNode1(priv.publicKey, res.publicKey)) res = randomKey + while (!Announcements.isNode1(priv.publicKey, res.publicKey)) res = randomKey res } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala index b36f29d777..684293518c 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala @@ -35,6 +35,7 @@ import fr.acinq.eclair.router.Announcements.{makeChannelUpdate, makeNodeAnnounce import fr.acinq.eclair.router._ import fr.acinq.eclair.transactions.Scripts import fr.acinq.eclair.wire._ +import scodec.bits.HexStringSyntax /** * Created by PM on 29/08/2016. @@ -414,7 +415,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val (priv_g, priv_funding_g) = (randomKey, randomKey) val (g, funding_g) = (priv_g.publicKey, priv_funding_g.publicKey) - val ann_g = makeNodeAnnouncement(priv_g, "node-G", Color(-30, 10, -50), Nil) + val ann_g = makeNodeAnnouncement(priv_g, "node-G", Color(-30, 10, -50), Nil, hex"0200") val channelId_bg = ShortChannelId(420000, 5, 0) val chan_bg = channelAnnouncement(channelId_bg, priv_b, priv_g, priv_funding_b, priv_funding_g) val channelUpdate_bg = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_b, g, channelId_bg, CltvExpiryDelta(9), htlcMinimumMsat = 0 msat, feeBaseMsat = 0 msat, feeProportionalMillionths = 0, htlcMaximumMsat = 500000000 msat) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/RelayerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/RelayerSpec.scala index 63ff2998cc..8f7db6047d 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/RelayerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/RelayerSpec.scala @@ -298,7 +298,7 @@ class RelayerSpec extends TestkitBaseClass { TlvStream[OnionTlv](OutgoingChannelId(channelUpdate_bc.shortChannelId), AmountToForward(amount_bc)), // Missing cltv expiry. TlvStream[OnionTlv](OutgoingChannelId(channelUpdate_bc.shortChannelId), OutgoingCltv(expiry_bc)), // Missing forwarding amount. TlvStream[OnionTlv](AmountToForward(amount_bc), OutgoingCltv(expiry_bc))) // Missing channel id. - val payload_cd = TlvStream[OnionTlv](OutgoingChannelId(channelUpdate_cd.shortChannelId), AmountToForward(amount_cd), OutgoingCltv(expiry_cd)) + val payload_cd = TlvStream[OnionTlv](OutgoingChannelId(channelUpdate_cd.shortChannelId), AmountToForward(amount_cd), OutgoingCltv(expiry_cd)) val sender = TestProbe() relayer ! LocalChannelUpdate(null, channelId_bc, channelUpdate_bc.shortChannelId, c, None, channelUpdate_bc, makeCommitments(channelId_bc)) @@ -318,6 +318,30 @@ class RelayerSpec extends TestkitBaseClass { } } + test("fail to relay an htlc-add when variable length onion is disabled") { f => + import f._ + import fr.acinq.eclair.wire.OnionTlv._ + + val relayer = system.actorOf(Relayer.props(TestConstants.Bob.nodeParams.copy(globalFeatures = ByteVector.empty), register.ref, paymentHandler.ref)) + val sender = TestProbe() + relayer ! LocalChannelUpdate(null, channelId_bc, channelUpdate_bc.shortChannelId, c, None, channelUpdate_bc, makeCommitments(channelId_bc)) + + val payload_bc = TlvStream[OnionTlv](OutgoingChannelId(channelUpdate_bc.shortChannelId), AmountToForward(amount_bc), OutgoingCltv(expiry_bc)) + val payload_cd = TlvStream[OnionTlv](OutgoingChannelId(channelUpdate_cd.shortChannelId), AmountToForward(amount_cd), OutgoingCltv(expiry_cd)) + + val Sphinx.PacketAndSecrets(onion, _) = buildOnion(Seq(b, c), Seq(payload_bc, payload_cd), paymentHash) + val add_ab = UpdateAddHtlc(channelId_ab, 123456, amount_ab, paymentHash, expiry_ab, onion) + sender.send(relayer, ForwardAdd(add_ab)) + + val fail = register.expectMsgType[Register.Forward[CMD_FAIL_MALFORMED_HTLC]].message + assert(fail.id === add_ab.id) + assert(fail.onionHash == Sphinx.PaymentPacket.hash(add_ab.onionRoutingPacket)) + assert(fail.failureCode === (FailureMessageCodecs.BADONION | FailureMessageCodecs.PERM)) + + register.expectNoMsg(100 millis) + paymentHandler.expectNoMsg(100 millis) + } + test("fail to relay an htlc-add when amount is below the next hop's requirements") { f => import f._ val sender = TestProbe() diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/router/AnnouncementsSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/router/AnnouncementsSpec.scala index 2e01c4d516..4f628ede35 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/router/AnnouncementsSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/router/AnnouncementsSpec.scala @@ -48,7 +48,7 @@ class AnnouncementsSpec extends FunSuite { } test("create valid signed node announcement") { - val ann = makeNodeAnnouncement(Alice.nodeParams.privateKey, Alice.nodeParams.alias, Alice.nodeParams.color, Alice.nodeParams.publicAddresses) + val ann = makeNodeAnnouncement(Alice.nodeParams.privateKey, Alice.nodeParams.alias, Alice.nodeParams.color, Alice.nodeParams.publicAddresses, Alice.nodeParams.globalFeatures) assert(Features.hasFeature(ann.features, Features.VARIABLE_LENGTH_ONION_OPTIONAL)) assert(checkSig(ann)) assert(checkSig(ann.copy(timestamp = 153)) === false) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/router/BaseRouterSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/router/BaseRouterSpec.scala index 628cc489a3..c358490b47 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/router/BaseRouterSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/router/BaseRouterSpec.scala @@ -30,7 +30,7 @@ import fr.acinq.eclair.transactions.Scripts import fr.acinq.eclair.wire._ import fr.acinq.eclair.{TestkitBaseClass, randomKey, _} import org.scalatest.Outcome -import scodec.bits.ByteVector +import scodec.bits.{ByteVector, HexStringSyntax} import scala.concurrent.duration._ @@ -55,12 +55,12 @@ abstract class BaseRouterSpec extends TestkitBaseClass { val (priv_funding_a, priv_funding_b, priv_funding_c, priv_funding_d, priv_funding_e, priv_funding_f) = (randomKey, randomKey, randomKey, randomKey, randomKey, randomKey) val (funding_a, funding_b, funding_c, funding_d, funding_e, funding_f) = (priv_funding_a.publicKey, priv_funding_b.publicKey, priv_funding_c.publicKey, priv_funding_d.publicKey, priv_funding_e.publicKey, priv_funding_f.publicKey) - val ann_a = makeNodeAnnouncement(priv_a, "node-A", Color(15, 10, -70), Nil) - val ann_b = makeNodeAnnouncement(priv_b, "node-B", Color(50, 99, -80), Nil) - val ann_c = makeNodeAnnouncement(priv_c, "node-C", Color(123, 100, -40), Nil) - val ann_d = makeNodeAnnouncement(priv_d, "node-D", Color(-120, -20, 60), Nil) - val ann_e = makeNodeAnnouncement(priv_e, "node-E", Color(-50, 0, 10), Nil) - val ann_f = makeNodeAnnouncement(priv_f, "node-F", Color(30, 10, -50), Nil) + val ann_a = makeNodeAnnouncement(priv_a, "node-A", Color(15, 10, -70), Nil, hex"0200") + val ann_b = makeNodeAnnouncement(priv_b, "node-B", Color(50, 99, -80), Nil, hex"") + val ann_c = makeNodeAnnouncement(priv_c, "node-C", Color(123, 100, -40), Nil, hex"0200") + val ann_d = makeNodeAnnouncement(priv_d, "node-D", Color(-120, -20, 60), Nil, hex"00") + val ann_e = makeNodeAnnouncement(priv_e, "node-E", Color(-50, 0, 10), Nil, hex"00") + val ann_f = makeNodeAnnouncement(priv_f, "node-F", Color(30, 10, -50), Nil, hex"00") val channelId_ab = ShortChannelId(420000, 1, 0) val channelId_bc = ShortChannelId(420000, 2, 0) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/router/RoutingSyncSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/router/RoutingSyncSpec.scala index 9144657412..83d5c8c3e4 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/router/RoutingSyncSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/router/RoutingSyncSpec.scala @@ -30,13 +30,13 @@ import fr.acinq.eclair.router.BaseRouterSpec.channelAnnouncement import fr.acinq.eclair.transactions.Scripts import fr.acinq.eclair.wire._ import org.scalatest.FunSuiteLike +import scodec.bits.HexStringSyntax import scala.collection.immutable.TreeMap import scala.collection.{SortedSet, immutable, mutable} import scala.compat.Platform import scala.concurrent.duration._ - class RoutingSyncSpec extends TestKit(ActorSystem("test")) with FunSuiteLike { import RoutingSyncSpec._ @@ -120,7 +120,7 @@ class RoutingSyncSpec extends TestKit(ActorSystem("test")) with FunSuiteLike { } def countUpdates(channels: Map[ShortChannelId, PublicChannel]) = channels.values.foldLeft(0) { - case (count, pc) => count + pc.update_1_opt.map(_ => 1).getOrElse(0) + pc.update_2_opt.map(_ => 1).getOrElse(0) + case (count, pc) => count + pc.update_1_opt.map(_ => 1).getOrElse(0) + pc.update_2_opt.map(_ => 1).getOrElse(0) } test("sync with standard channel queries") { @@ -151,7 +151,7 @@ class RoutingSyncSpec extends TestKit(ActorSystem("test")) with FunSuiteLike { // add some updates to bob and resync fakeRoutingInfo.take(40).values.foreach { - case (pc, na1, na2) => + case (pc, _, _) => sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, pc.update_2_opt.get)) } awaitCond(bob.stateData.channels.size === 40 && countUpdates(bob.stateData.channels) === 80) @@ -172,7 +172,7 @@ class RoutingSyncSpec extends TestKit(ActorSystem("test")) with FunSuiteLike { awaitCond(alice.stateData.channels === bob.stateData.channels, max = 60 seconds) } - def syncWithExtendedQueries(requestNodeAnnouncements: Boolean) = { + def syncWithExtendedQueries(requestNodeAnnouncements: Boolean): Unit = { val watcher = system.actorOf(Props(new YesWatcher())) val alice = TestFSMRef(new Router(Alice.nodeParams.copy(routerConf = Alice.nodeParams.routerConf.copy(requestNodeAnnouncements = requestNodeAnnouncements)), watcher)) val bob = TestFSMRef(new Router(Bob.nodeParams, watcher)) @@ -200,7 +200,7 @@ class RoutingSyncSpec extends TestKit(ActorSystem("test")) with FunSuiteLike { // add some updates to bob and resync fakeRoutingInfo.take(40).values.foreach { - case (pc, na1, na2) => + case (pc, _, _) => sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, pc.update_2_opt.get)) } awaitCond(bob.stateData.channels.size === 40 && countUpdates(bob.stateData.channels) === 80) @@ -216,7 +216,7 @@ class RoutingSyncSpec extends TestKit(ActorSystem("test")) with FunSuiteLike { sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, na1)) sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, na2)) } - awaitCond(bob.stateData.channels.size === fakeRoutingInfo.size && countUpdates(bob.stateData.channels) === 2 * fakeRoutingInfo.size, max = 60 seconds) + awaitCond(bob.stateData.channels.size === fakeRoutingInfo.size && countUpdates(bob.stateData.channels) === 2 * fakeRoutingInfo.size, max = 60 seconds) assert(BasicSyncResult(ranges = 2, queries = 46, channels = fakeRoutingInfo.size - 40, updates = 2 * (fakeRoutingInfo.size - 40), nodes = if (requestNodeAnnouncements) 2 * (fakeRoutingInfo.size - 40) else 0) === sync(alice, bob, extendedQueryFlags_opt).counts) awaitCond(alice.stateData.channels === bob.stateData.channels, max = 60 seconds) @@ -226,7 +226,7 @@ class RoutingSyncSpec extends TestKit(ActorSystem("test")) with FunSuiteLike { makeNewerChannelUpdate(c, if (side) u1 else u2) } - val bumpedUpdates = (List(0, 42, 147, 153, 654, 834, 4301).map(touchUpdate(_, true)) ++ List(1, 42, 150, 200).map(touchUpdate(_, false))).toSet + val bumpedUpdates = (List(0, 42, 147, 153, 654, 834, 4301).map(touchUpdate(_, side = true)) ++ List(1, 42, 150, 200).map(touchUpdate(_, side = false))).toSet bumpedUpdates.foreach(c => sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, c))) assert(BasicSyncResult(ranges = 2, queries = 2, channels = 0, updates = bumpedUpdates.size, nodes = if (requestNodeAnnouncements) 20 else 0) === sync(alice, bob, extendedQueryFlags_opt).counts) awaitCond(alice.stateData.channels === bob.stateData.channels, max = 60 seconds) @@ -319,8 +319,8 @@ object RoutingSyncSpec { val channelAnn_12 = channelAnnouncement(shortChannelId, priv1, priv2, priv_funding1, priv_funding2) val channelUpdate_12 = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv1, priv2.publicKey, shortChannelId, cltvExpiryDelta = CltvExpiryDelta(7), 0 msat, feeBaseMsat = 766000 msat, feeProportionalMillionths = 10, 500000000L msat, timestamp = timestamp) val channelUpdate_21 = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv2, priv1.publicKey, shortChannelId, cltvExpiryDelta = CltvExpiryDelta(7), 0 msat, feeBaseMsat = 766000 msat, feeProportionalMillionths = 10, 500000000L msat, timestamp = timestamp) - val nodeAnnouncement_1 = makeNodeAnnouncement(priv1, "", Color(0, 0, 0), List()) - val nodeAnnouncement_2 = makeNodeAnnouncement(priv2, "", Color(0, 0, 0), List()) + val nodeAnnouncement_1 = makeNodeAnnouncement(priv1, "a", Color(0, 0, 0), List(), hex"0200") + val nodeAnnouncement_2 = makeNodeAnnouncement(priv2, "b", Color(0, 0, 0), List(), hex"00") val publicChannel = PublicChannel(channelAnn_12, ByteVector32.Zeroes, Satoshi(0), Some(channelUpdate_12), Some(channelUpdate_21)) (publicChannel, nodeAnnouncement_1, nodeAnnouncement_2) } @@ -336,7 +336,7 @@ object RoutingSyncSpec { def makeFakeNodeAnnouncement(nodeId: PublicKey): NodeAnnouncement = { val priv = pub2priv(nodeId) - makeNodeAnnouncement(priv, "", Color(0, 0, 0), List()) + makeNodeAnnouncement(priv, "", Color(0, 0, 0), List(), hex"00") } } From 553ea64d1d80c4b55cd5be2ce6ad0569d1cb3b7d Mon Sep 17 00:00:00 2001 From: Bastien Teinturier Date: Tue, 23 Jul 2019 19:09:07 +0200 Subject: [PATCH 04/11] Update InvalidOnionPayload. It should simply be a PERM failure. The spec hasn't decided yet to assign it a specific failure code. --- .../fr/acinq/eclair/payment/Relayer.scala | 10 +++++++--- .../fr/acinq/eclair/wire/FailureMessage.scala | 6 +++--- .../fr/acinq/eclair/payment/RelayerSpec.scala | 18 +++--------------- .../eclair/wire/FailureMessageCodecsSpec.scala | 7 +++---- 4 files changed, 16 insertions(+), 25 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Relayer.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Relayer.scala index 40c9c6ba28..29c6a5598b 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Relayer.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Relayer.scala @@ -119,11 +119,15 @@ class Relayer(nodeParams: NodeParams, register: ActorRef, paymentHandler: ActorR log.info(s"forwarding htlc #${add.id} paymentHash=${add.paymentHash} from channelId=${add.channelId} to shortChannelId=$selectedShortChannelId") register ! Register.ForwardShortId(selectedShortChannelId, cmdAdd) } - case Left(badOnion) => + case Left(badOnion: BadOnion) => log.warning(s"couldn't parse onion: reason=${badOnion.message}") val cmdFail = CMD_FAIL_MALFORMED_HTLC(add.id, badOnion.onionHash, badOnion.code, commit = true) log.info(s"rejecting htlc #${add.id} paymentHash=${add.paymentHash} from channelId=${add.channelId} reason=malformed onionHash=${cmdFail.onionHash} failureCode=${cmdFail.failureCode}") commandBuffer ! CommandBuffer.CommandSend(add.channelId, add.id, cmdFail) + case Left(failure) => + log.warning(s"couldn't process onion: reason=${failure.message}") + val cmdFail = CMD_FAIL_HTLC(add.id, Right(failure), commit = true) + commandBuffer ! CommandBuffer.CommandSend(add.channelId, add.id, cmdFail) } case Status.Failure(Register.ForwardShortIdFailure(Register.ForwardShortId(shortChannelId, CMD_ADD_HTLC(_, _, _, _, Right(add), _, _)))) => @@ -229,12 +233,12 @@ object Relayer extends Logging { * @param privateKey this node's private key * @return the payload for the next hop or an error. */ - def decryptPacket(add: UpdateAddHtlc, privateKey: PrivateKey, features: ByteVector): Either[BadOnion, NextPayload] = + def decryptPacket(add: UpdateAddHtlc, privateKey: PrivateKey, features: ByteVector): Either[FailureMessage, NextPayload] = Sphinx.PaymentPacket.peel(privateKey, add.paymentHash, add.onionRoutingPacket) match { case Right(p@Sphinx.DecryptedPacket(payload, nextPacket, _)) => OnionCodecs.perHopPayloadCodec.decode(payload.bits) match { case Attempt.Successful(DecodeResult(OnionPerHopPayload(Left(_)), _)) if !Features.hasVariableLengthOnion(features) => - Left(InvalidOnionPayload(Sphinx.PaymentPacket.hash(add.onionRoutingPacket))) + Left(InvalidRealm) case Attempt.Successful(DecodeResult(perHopPayload, remainder)) => if (remainder.nonEmpty) { logger.warn(s"${remainder.length} bits remaining after per-hop payload decoding: there might be an issue with the onion codec") diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/FailureMessage.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/FailureMessage.scala index b3afca1f24..0998cbf953 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/FailureMessage.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/FailureMessage.scala @@ -49,7 +49,6 @@ case object RequiredNodeFeatureMissing extends Perm with Node { def message = "p case class InvalidOnionVersion(onionHash: ByteVector32) extends BadOnion with Perm { def message = "onion version was not understood by the processing node" } case class InvalidOnionHmac(onionHash: ByteVector32) extends BadOnion with Perm { def message = "onion HMAC was incorrect when it reached the processing node" } case class InvalidOnionKey(onionHash: ByteVector32) extends BadOnion with Perm { def message = "ephemeral key was unparsable by the processing node" } -case class InvalidOnionPayload(onionHash: ByteVector32) extends BadOnion with Perm { def message = "onion per-hop payload could not be parsed" } case class TemporaryChannelFailure(update: ChannelUpdate) extends Update { def message = s"channel ${update.shortChannelId} is currently unavailable" } case object PermanentChannelFailure extends Perm { def message = "channel is permanently unavailable" } case object RequiredChannelFeatureMissing extends Perm { def message = "channel requires features not present in the onion" } @@ -63,6 +62,7 @@ case class ExpiryTooSoon(update: ChannelUpdate) extends Update { def message = " case class FinalIncorrectCltvExpiry(expiry: CltvExpiry) extends FailureMessage { def message = "payment expiry doesn't match the value in the onion" } case class FinalIncorrectHtlcAmount(amount: MilliSatoshi) extends FailureMessage { def message = "payment amount is incorrect in the final htlc" } case object ExpiryTooFar extends FailureMessage { def message = "payment expiry is too far in the future" } +case class InvalidOnionPayload(onionHash: ByteVector32) extends Perm { def message = "onion per-hop payload is invalid" } /** * We allow remote nodes to send us unknown failure codes (e.g. deprecated failure codes). @@ -97,7 +97,6 @@ object FailureMessageCodecs { .typecase(NODE | 2, provide(TemporaryNodeFailure)) .typecase(PERM | NODE | 2, provide(PermanentNodeFailure)) .typecase(PERM | NODE | 3, provide(RequiredNodeFeatureMissing)) - .typecase(BADONION | PERM, sha256.as[InvalidOnionPayload]) .typecase(BADONION | PERM | 4, sha256.as[InvalidOnionVersion]) .typecase(BADONION | PERM | 5, sha256.as[InvalidOnionHmac]) .typecase(BADONION | PERM | 6, sha256.as[InvalidOnionKey]) @@ -115,7 +114,8 @@ object FailureMessageCodecs { // PERM | 17 (final_expiry_too_soon) has been deprecated because it allowed probing attacks: IncorrectOrUnknownPaymentDetails should be used instead. .typecase(18, ("expiry" | cltvExpiry).as[FinalIncorrectCltvExpiry]) .typecase(19, ("amountMsat" | millisatoshi).as[FinalIncorrectHtlcAmount]) - .typecase(21, provide(ExpiryTooFar)), + .typecase(21, provide(ExpiryTooFar)) + .typecase(PERM, sha256.as[InvalidOnionPayload]), uint16.xmap(code => { val failureMessage = code match { // @formatter:off diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/RelayerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/RelayerSpec.scala index 8f7db6047d..dd04104112 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/RelayerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/RelayerSpec.scala @@ -308,11 +308,7 @@ class RelayerSpec extends TestkitBaseClass { val add_ab = UpdateAddHtlc(channelId_ab, 123456, amount_ab, paymentHash, expiry_ab, onion) sender.send(relayer, ForwardAdd(add_ab)) - val fail = register.expectMsgType[Register.Forward[CMD_FAIL_MALFORMED_HTLC]].message - assert(fail.id === add_ab.id) - assert(fail.onionHash == Sphinx.PaymentPacket.hash(add_ab.onionRoutingPacket)) - assert(fail.failureCode === (FailureMessageCodecs.BADONION | FailureMessageCodecs.PERM)) - + register.expectMsg(Register.Forward(channelId_ab, CMD_FAIL_HTLC(add_ab.id, Right(InvalidOnionPayload(Sphinx.PaymentPacket.hash(add_ab.onionRoutingPacket))), commit = true))) register.expectNoMsg(100 millis) paymentHandler.expectNoMsg(100 millis) } @@ -333,11 +329,7 @@ class RelayerSpec extends TestkitBaseClass { val add_ab = UpdateAddHtlc(channelId_ab, 123456, amount_ab, paymentHash, expiry_ab, onion) sender.send(relayer, ForwardAdd(add_ab)) - val fail = register.expectMsgType[Register.Forward[CMD_FAIL_MALFORMED_HTLC]].message - assert(fail.id === add_ab.id) - assert(fail.onionHash == Sphinx.PaymentPacket.hash(add_ab.onionRoutingPacket)) - assert(fail.failureCode === (FailureMessageCodecs.BADONION | FailureMessageCodecs.PERM)) - + register.expectMsg(Register.Forward(channelId_ab, CMD_FAIL_HTLC(add_ab.id, Right(InvalidRealm), commit = true))) register.expectNoMsg(100 millis) paymentHandler.expectNoMsg(100 millis) } @@ -460,11 +452,7 @@ class RelayerSpec extends TestkitBaseClass { val add_ab = UpdateAddHtlc(channelId_ab, 123456, amount_ab, paymentHash, expiry_ab, onion) sender.send(relayer, ForwardAdd(add_ab)) - val fail = register.expectMsgType[Register.Forward[CMD_FAIL_MALFORMED_HTLC]].message - assert(fail.id === add_ab.id) - assert(fail.onionHash == Sphinx.PaymentPacket.hash(add_ab.onionRoutingPacket)) - assert(fail.failureCode === (FailureMessageCodecs.BADONION | FailureMessageCodecs.PERM)) - + register.expectMsg(Register.Forward(channelId_ab, CMD_FAIL_HTLC(add_ab.id, Right(InvalidOnionPayload(Sphinx.PaymentPacket.hash(add_ab.onionRoutingPacket))), commit = true))) register.expectNoMsg(100 millis) paymentHandler.expectNoMsg(100 millis) } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/FailureMessageCodecsSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/FailureMessageCodecsSpec.scala index 30ea734322..687a054fe8 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/wire/FailureMessageCodecsSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/wire/FailureMessageCodecsSpec.scala @@ -44,10 +44,10 @@ class FailureMessageCodecsSpec extends FunSuite { test("encode/decode all failure messages") { val msgs: List[FailureMessage] = InvalidRealm :: TemporaryNodeFailure :: PermanentNodeFailure :: RequiredNodeFeatureMissing :: - InvalidOnionVersion(randomBytes32) :: InvalidOnionHmac(randomBytes32) :: InvalidOnionKey(randomBytes32) :: InvalidOnionPayload(randomBytes32) :: + InvalidOnionVersion(randomBytes32) :: InvalidOnionHmac(randomBytes32) :: InvalidOnionKey(randomBytes32) :: TemporaryChannelFailure(channelUpdate) :: PermanentChannelFailure :: RequiredChannelFeatureMissing :: UnknownNextPeer :: AmountBelowMinimum(123456 msat, channelUpdate) :: FeeInsufficient(546463 msat, channelUpdate) :: IncorrectCltvExpiry(CltvExpiry(1211), channelUpdate) :: ExpiryTooSoon(channelUpdate) :: - IncorrectOrUnknownPaymentDetails(123456 msat, 1105) :: FinalIncorrectCltvExpiry(CltvExpiry(1234)) :: ChannelDisabled(0, 1, channelUpdate) :: ExpiryTooFar :: Nil + IncorrectOrUnknownPaymentDetails(123456 msat, 1105) :: FinalIncorrectCltvExpiry(CltvExpiry(1234)) :: ChannelDisabled(0, 1, channelUpdate) :: ExpiryTooFar :: InvalidOnionPayload(randomBytes32) :: Nil msgs.foreach { msg => { @@ -83,8 +83,7 @@ class FailureMessageCodecsSpec extends FunSuite { val msgs = Map( (BADONION | PERM | 4) -> InvalidOnionVersion(randomBytes32), (BADONION | PERM | 5) -> InvalidOnionHmac(randomBytes32), - (BADONION | PERM | 6) -> InvalidOnionKey(randomBytes32), - (BADONION | PERM) -> InvalidOnionPayload(randomBytes32) + (BADONION | PERM | 6) -> InvalidOnionKey(randomBytes32) ) for ((code, message) <- msgs) { From 10143ad80c8c6755ac043b19b59cd25f525b3709 Mon Sep 17 00:00:00 2001 From: Bastien Teinturier Date: Thu, 25 Jul 2019 11:57:34 +0200 Subject: [PATCH 05/11] Refactor Payment classes. The PaymentOptions embeds the amount and expiry and additional tlvs. The PaymentInitiator translates payment parameters before forwarding to PaymentLifecycle. This make it easier to add new payment capabilities before the PaymentLifecycle layer (AMP, Loop, Trampoline, etc). --- .../main/scala/fr/acinq/eclair/Eclair.scala | 11 +-- .../fr/acinq/eclair/payment/Autoprobe.scala | 11 +-- .../eclair/payment/PaymentInitiator.scala | 32 ++++++-- .../eclair/payment/PaymentLifecycle.scala | 79 ++++++++++--------- .../fr/acinq/eclair/EclairImplSpec.scala | 22 +++--- .../fr/acinq/eclair/channel/FuzzySpec.scala | 8 +- .../states/StateTestsHelperMethods.scala | 9 ++- .../channel/states/f/ShutdownStateSpec.scala | 4 +- .../eclair/integration/IntegrationSpec.scala | 33 ++++---- .../eclair/payment/HtlcGenerationSpec.scala | 24 +++--- .../eclair/payment/PaymentInitiatorSpec.scala | 61 ++++++++++++++ .../eclair/payment/PaymentLifecycleSpec.scala | 33 ++++---- .../fr/acinq/eclair/payment/RelayerSpec.scala | 34 ++++---- .../scala/fr/acinq/eclair/gui/Handlers.scala | 43 +++++----- 14 files changed, 247 insertions(+), 157 deletions(-) create mode 100644 eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentInitiatorSpec.scala diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala b/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala index 50dde47091..7b485e0461 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala @@ -29,7 +29,8 @@ import fr.acinq.eclair.channel._ import fr.acinq.eclair.db.{IncomingPayment, NetworkFee, OutgoingPayment, Stats} import fr.acinq.eclair.io.Peer.{GetPeerInfo, PeerInfo} import fr.acinq.eclair.io.{NodeURI, Peer} -import fr.acinq.eclair.payment.PaymentLifecycle._ +import fr.acinq.eclair.payment.PaymentInitiator.SendPaymentRequest +import fr.acinq.eclair.payment.PaymentLifecycle.ReceivePayment import fr.acinq.eclair.payment._ import fr.acinq.eclair.router.{ChannelDesc, RouteRequest, RouteResponse, Router} import fr.acinq.eclair.wire.{ChannelAnnouncement, ChannelUpdate, NodeAddress, NodeAnnouncement} @@ -186,7 +187,7 @@ class EclairImpl(appKit: Kit) extends Eclair { } override def sendToRoute(route: Seq[PublicKey], amount: MilliSatoshi, paymentHash: ByteVector32, finalCltvExpiryDelta: CltvExpiryDelta)(implicit timeout: Timeout): Future[UUID] = { - (appKit.paymentInitiator ? SendPaymentToRoute(amount, paymentHash, route, finalCltvExpiryDelta)).mapTo[UUID] + (appKit.paymentInitiator ? SendPaymentRequest(amount, paymentHash, route.last, 1, finalCltvExpiryDelta, route)).mapTo[UUID] } override def send(recipientNodeId: PublicKey, amount: MilliSatoshi, paymentHash: ByteVector32, invoice_opt: Option[PaymentRequest], maxAttempts_opt: Option[Int], feeThreshold_opt: Option[Satoshi], maxFeePct_opt: Option[Double])(implicit timeout: Timeout): Future[UUID] = { @@ -202,12 +203,12 @@ class EclairImpl(appKit: Kit) extends Eclair { case Some(invoice) if invoice.isExpired => Future.failed(new IllegalArgumentException("invoice has expired")) case Some(invoice) => val sendPayment = invoice.minFinalCltvExpiryDelta match { - case Some(minFinalCltvExpiryDelta) => SendPayment(amount, paymentHash, recipientNodeId, invoice.routingInfo, minFinalCltvExpiryDelta, maxAttempts = maxAttempts, routeParams = Some(routeParams)) - case None => SendPayment(amount, paymentHash, recipientNodeId, invoice.routingInfo, maxAttempts = maxAttempts, routeParams = Some(routeParams)) + case Some(minFinalCltvExpiryDelta) => SendPaymentRequest(amount, paymentHash, recipientNodeId, maxAttempts, minFinalCltvExpiryDelta, assistedRoutes = invoice.routingInfo, routeParams = Some(routeParams)) + case None => SendPaymentRequest(amount, paymentHash, recipientNodeId, maxAttempts, assistedRoutes = invoice.routingInfo, routeParams = Some(routeParams)) } (appKit.paymentInitiator ? sendPayment).mapTo[UUID] case None => - val sendPayment = SendPayment(amount, paymentHash, recipientNodeId, maxAttempts = maxAttempts, routeParams = Some(routeParams)) + val sendPayment = SendPaymentRequest(amount, paymentHash, recipientNodeId, maxAttempts = maxAttempts, routeParams = Some(routeParams)) (appKit.paymentInitiator ? sendPayment).mapTo[UUID] } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Autoprobe.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Autoprobe.scala index 2494ad3ffc..40ca336c44 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Autoprobe.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Autoprobe.scala @@ -19,7 +19,8 @@ package fr.acinq.eclair.payment import akka.actor.{Actor, ActorLogging, ActorRef, Props} import fr.acinq.bitcoin.Crypto.PublicKey import fr.acinq.eclair.crypto.Sphinx.DecryptedFailurePacket -import fr.acinq.eclair.payment.PaymentLifecycle.{PaymentFailed, PaymentResult, RemoteFailure, SendPayment} +import fr.acinq.eclair.payment.PaymentInitiator.SendPaymentRequest +import fr.acinq.eclair.payment.PaymentLifecycle.{PaymentFailed, PaymentResult, RemoteFailure} import fr.acinq.eclair.router.{Announcements, Data, PublicChannel} import fr.acinq.eclair.wire.IncorrectOrUnknownPaymentDetails import fr.acinq.eclair.{LongToBtcAmount, NodeParams, randomBytes32, secureRandom} @@ -27,9 +28,9 @@ import fr.acinq.eclair.{LongToBtcAmount, NodeParams, randomBytes32, secureRandom import scala.concurrent.duration._ /** - * This actor periodically probes the network by sending payments to random nodes. The payments will eventually fail - * because the recipient doesn't know the preimage, but it allows us to test channels and improve routing for real payments. - */ + * This actor periodically probes the network by sending payments to random nodes. The payments will eventually fail + * because the recipient doesn't know the preimage, but it allows us to test channels and improve routing for real payments. + */ class Autoprobe(nodeParams: NodeParams, router: ActorRef, paymentInitiator: ActorRef) extends Actor with ActorLogging { import Autoprobe._ @@ -54,7 +55,7 @@ class Autoprobe(nodeParams: NodeParams, router: ActorRef, paymentInitiator: Acto case Some(targetNodeId) => val paymentHash = randomBytes32 // we don't even know the preimage (this needs to be a secure random!) log.info(s"sending payment probe to node=$targetNodeId payment_hash=$paymentHash") - paymentInitiator ! SendPayment(PAYMENT_AMOUNT_MSAT, paymentHash, targetNodeId, maxAttempts = 1) + paymentInitiator ! SendPaymentRequest(PAYMENT_AMOUNT_MSAT, paymentHash, targetNodeId, maxAttempts = 1) case None => log.info(s"could not find a destination, re-scheduling") scheduleProbe() diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentInitiator.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentInitiator.scala index de9bef237a..94f9fa566d 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentInitiator.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentInitiator.scala @@ -17,25 +17,45 @@ package fr.acinq.eclair.payment import java.util.UUID + import akka.actor.{Actor, ActorLogging, ActorRef, Props} -import fr.acinq.eclair.NodeParams -import fr.acinq.eclair.payment.PaymentLifecycle.GenericSendPayment +import fr.acinq.bitcoin.ByteVector32 +import fr.acinq.bitcoin.Crypto.PublicKey +import fr.acinq.eclair.channel.Channel +import fr.acinq.eclair.payment.PaymentLifecycle.{LegacyPayload, SendPayment, SendPaymentToRoute} +import fr.acinq.eclair.payment.PaymentRequest.ExtraHop +import fr.acinq.eclair.router.RouteParams +import fr.acinq.eclair.{CltvExpiryDelta, MilliSatoshi, NodeParams} /** - * Created by PM on 29/08/2016. - */ + * Created by PM on 29/08/2016. + */ class PaymentInitiator(nodeParams: NodeParams, router: ActorRef, register: ActorRef) extends Actor with ActorLogging { override def receive: Receive = { - case c: GenericSendPayment => + case p: PaymentInitiator.SendPaymentRequest => val paymentId = UUID.randomUUID() val payFsm = context.actorOf(PaymentLifecycle.props(nodeParams, paymentId, router, register)) - payFsm forward c + p.predefinedRoute match { + case Nil => payFsm forward SendPayment(p.paymentHash, p.targetNodeId, LegacyPayload(p.amount, p.finalExpiryDelta.toCltvExpiry), p.maxAttempts, p.assistedRoutes, p.routeParams) + case hops => payFsm forward SendPaymentToRoute(p.paymentHash, hops, LegacyPayload(p.amount, p.finalExpiryDelta.toCltvExpiry)) + } sender ! paymentId } } object PaymentInitiator { + def props(nodeParams: NodeParams, router: ActorRef, register: ActorRef) = Props(classOf[PaymentInitiator], nodeParams, router, register) + + case class SendPaymentRequest(amount: MilliSatoshi, + paymentHash: ByteVector32, + targetNodeId: PublicKey, + maxAttempts: Int, + finalExpiryDelta: CltvExpiryDelta = Channel.MIN_CLTV_EXPIRY_DELTA, + predefinedRoute: Seq[PublicKey] = Nil, + assistedRoutes: Seq[Seq[ExtraHop]] = Nil, + routeParams: Option[RouteParams] = None) + } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentLifecycle.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentLifecycle.scala index eab6c11fb5..c72b9f8213 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentLifecycle.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentLifecycle.scala @@ -22,7 +22,7 @@ import akka.actor.{ActorRef, FSM, Props, Status} import fr.acinq.bitcoin.ByteVector32 import fr.acinq.bitcoin.Crypto.PublicKey import fr.acinq.eclair._ -import fr.acinq.eclair.channel.{AddHtlcFailed, CMD_ADD_HTLC, Channel, Register} +import fr.acinq.eclair.channel.{AddHtlcFailed, CMD_ADD_HTLC, Register} import fr.acinq.eclair.crypto.{Sphinx, TransportHandler} import fr.acinq.eclair.db.{OutgoingPayment, OutgoingPaymentStatus} import fr.acinq.eclair.payment.PaymentLifecycle._ @@ -47,14 +47,14 @@ class PaymentLifecycle(nodeParams: NodeParams, id: UUID, router: ActorRef, regis when(WAITING_FOR_REQUEST) { case Event(c: SendPaymentToRoute, WaitingForRequest) => - val send = SendPayment(c.amount, c.paymentHash, c.hops.last, finalCltvExpiryDelta = c.finalCltvExpiryDelta, maxAttempts = 1) - paymentsDb.addOutgoingPayment(OutgoingPayment(id, c.paymentHash, None, c.amount, Platform.currentTime, None, OutgoingPaymentStatus.PENDING)) + val send = SendPayment(c.paymentHash, c.hops.last, c.paymentOptions, maxAttempts = 1) + paymentsDb.addOutgoingPayment(OutgoingPayment(id, c.paymentHash, None, c.paymentOptions.finalAmount, Platform.currentTime, None, OutgoingPaymentStatus.PENDING)) router ! FinalizeRoute(c.hops) goto(WAITING_FOR_ROUTE) using WaitingForRoute(sender, send, failures = Nil) case Event(c: SendPayment, WaitingForRequest) => - router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.amount, c.assistedRoutes, routeParams = c.routeParams) - paymentsDb.addOutgoingPayment(OutgoingPayment(id, c.paymentHash, None, c.amount, Platform.currentTime, None, OutgoingPaymentStatus.PENDING)) + router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.paymentOptions.finalAmount, c.assistedRoutes, routeParams = c.routeParams) + paymentsDb.addOutgoingPayment(OutgoingPayment(id, c.paymentHash, None, c.paymentOptions.finalAmount, Platform.currentTime, None, OutgoingPaymentStatus.PENDING)) goto(WAITING_FOR_ROUTE) using WaitingForRoute(sender, c, failures = Nil) } @@ -62,10 +62,7 @@ class PaymentLifecycle(nodeParams: NodeParams, id: UUID, router: ActorRef, regis case Event(RouteResponse(hops, ignoreNodes, ignoreChannels), WaitingForRoute(s, c, failures)) => log.info(s"route found: attempt=${failures.size + 1}/${c.maxAttempts} route=${hops.map(_.nextNodeId).mkString("->")} channels=${hops.map(_.lastUpdate.shortChannelId).mkString("->")}") val firstHop = hops.head - // we add one block in order to not have our htlc fail when a new block has just been found - val finalExpiry = (c.finalCltvExpiryDelta + 1).toCltvExpiry - - val (cmd, sharedSecrets) = buildCommand(id, c.amount, finalExpiry, c.paymentHash, hops) + val (cmd, sharedSecrets) = buildCommand(id, c.paymentHash, hops, c.paymentOptions) register ! Register.ForwardShortId(firstHop.lastUpdate.shortChannelId, cmd) goto(WAITING_FOR_PAYMENT_COMPLETE) using WaitingForComplete(s, c, cmd, failures, sharedSecrets, ignoreNodes, ignoreChannels, hops) @@ -81,7 +78,7 @@ class PaymentLifecycle(nodeParams: NodeParams, id: UUID, router: ActorRef, regis case Event(fulfill: UpdateFulfillHtlc, WaitingForComplete(s, c, cmd, _, _, _, _, hops)) => paymentsDb.updateOutgoingPayment(id, OutgoingPaymentStatus.SUCCEEDED, preimage = Some(fulfill.paymentPreimage)) reply(s, PaymentSucceeded(id, cmd.amount, c.paymentHash, fulfill.paymentPreimage, hops)) - context.system.eventStream.publish(PaymentSent(id, c.amount, cmd.amount - c.amount, cmd.paymentHash, fulfill.paymentPreimage, fulfill.channelId)) + context.system.eventStream.publish(PaymentSent(id, c.paymentOptions.finalAmount, cmd.amount - c.paymentOptions.finalAmount, cmd.paymentHash, fulfill.paymentPreimage, fulfill.channelId)) stop(FSM.Normal) case Event(fail: UpdateFailHtlc, WaitingForComplete(s, c, _, failures, sharedSecrets, ignoreNodes, ignoreChannels, hops)) => @@ -111,12 +108,12 @@ class PaymentLifecycle(nodeParams: NodeParams, id: UUID, router: ActorRef, regis // in that case we don't know which node is sending garbage, let's try to blacklist all nodes except the one we are directly connected to and the destination node val blacklist = hops.map(_.nextNodeId).drop(1).dropRight(1) log.warning(s"blacklisting intermediate nodes=${blacklist.mkString(",")}") - router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.amount, c.assistedRoutes, ignoreNodes ++ blacklist, ignoreChannels, c.routeParams) + router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.paymentOptions.finalAmount, c.assistedRoutes, ignoreNodes ++ blacklist, ignoreChannels, c.routeParams) goto(WAITING_FOR_ROUTE) using WaitingForRoute(s, c, failures :+ UnreadableRemoteFailure(hops)) case Success(e@Sphinx.DecryptedFailurePacket(nodeId, failureMessage: Node)) => log.info(s"received 'Node' type error message from nodeId=$nodeId, trying to route around it (failure=$failureMessage)") // let's try to route around this node - router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.amount, c.assistedRoutes, ignoreNodes + nodeId, ignoreChannels, c.routeParams) + router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.paymentOptions.finalAmount, c.assistedRoutes, ignoreNodes + nodeId, ignoreChannels, c.routeParams) goto(WAITING_FOR_ROUTE) using WaitingForRoute(s, c, failures :+ RemoteFailure(hops, e)) case Success(e@Sphinx.DecryptedFailurePacket(nodeId, failureMessage: Update)) => log.info(s"received 'Update' type error message from nodeId=$nodeId, retrying payment (failure=$failureMessage)") @@ -144,18 +141,18 @@ class PaymentLifecycle(nodeParams: NodeParams, id: UUID, router: ActorRef, regis // in any case, we forward the update to the router router ! failureMessage.update // let's try again, router will have updated its state - router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.amount, c.assistedRoutes, ignoreNodes, ignoreChannels, c.routeParams) + router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.paymentOptions.finalAmount, c.assistedRoutes, ignoreNodes, ignoreChannels, c.routeParams) } else { // this node is fishy, it gave us a bad sig!! let's filter it out log.warning(s"got bad signature from node=$nodeId update=${failureMessage.update}") - router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.amount, c.assistedRoutes, ignoreNodes + nodeId, ignoreChannels, c.routeParams) + router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.paymentOptions.finalAmount, c.assistedRoutes, ignoreNodes + nodeId, ignoreChannels, c.routeParams) } goto(WAITING_FOR_ROUTE) using WaitingForRoute(s, c, failures :+ RemoteFailure(hops, e)) case Success(e@Sphinx.DecryptedFailurePacket(nodeId, failureMessage)) => log.info(s"received an error message from nodeId=$nodeId, trying to use a different channel (failure=$failureMessage)") // let's try again without the channel outgoing from nodeId val faultyChannel = hops.find(_.nodeId == nodeId).map(hop => ChannelDesc(hop.lastUpdate.shortChannelId, hop.nodeId, hop.nextNodeId)) - router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.amount, c.assistedRoutes, ignoreNodes, ignoreChannels ++ faultyChannel.toSet, c.routeParams) + router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.paymentOptions.finalAmount, c.assistedRoutes, ignoreNodes, ignoreChannels ++ faultyChannel.toSet, c.routeParams) goto(WAITING_FOR_ROUTE) using WaitingForRoute(s, c, failures :+ RemoteFailure(hops, e)) } @@ -175,7 +172,7 @@ class PaymentLifecycle(nodeParams: NodeParams, id: UUID, router: ActorRef, regis } else { log.info(s"received an error message from local, trying to use a different channel (failure=${t.getMessage})") val faultyChannel = ChannelDesc(hops.head.lastUpdate.shortChannelId, hops.head.nodeId, hops.head.nextNodeId) - router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.amount, c.assistedRoutes, ignoreNodes, ignoreChannels + faultyChannel, c.routeParams) + router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.paymentOptions.finalAmount, c.assistedRoutes, ignoreNodes, ignoreChannels + faultyChannel, c.routeParams) goto(WAITING_FOR_ROUTE) using WaitingForRoute(s, c, failures :+ LocalFailure(t)) } @@ -199,16 +196,14 @@ object PaymentLifecycle { // @formatter:off case class ReceivePayment(amount_opt: Option[MilliSatoshi], description: String, expirySeconds_opt: Option[Long] = None, extraHops: List[List[ExtraHop]] = Nil, fallbackAddress: Option[String] = None, paymentPreimage: Option[ByteVector32] = None) - sealed trait GenericSendPayment - case class SendPaymentToRoute(amount: MilliSatoshi, paymentHash: ByteVector32, hops: Seq[PublicKey], finalCltvExpiryDelta: CltvExpiryDelta = Channel.MIN_CLTV_EXPIRY_DELTA) extends GenericSendPayment - case class SendPayment(amount: MilliSatoshi, - paymentHash: ByteVector32, + case class SendPaymentToRoute(paymentHash: ByteVector32, hops: Seq[PublicKey], paymentOptions: PaymentOptions) + case class SendPayment(paymentHash: ByteVector32, targetNodeId: PublicKey, - assistedRoutes: Seq[Seq[ExtraHop]] = Nil, - finalCltvExpiryDelta: CltvExpiryDelta = Channel.MIN_CLTV_EXPIRY_DELTA, + paymentOptions: PaymentOptions, maxAttempts: Int, - routeParams: Option[RouteParams] = None) extends GenericSendPayment { - require(amount > 0.msat, s"amountMsat must be > 0") + assistedRoutes: Seq[Seq[ExtraHop]] = Nil, + routeParams: Option[RouteParams] = None) { + require(paymentOptions.finalAmount > 0.msat, s"amount must be > 0") } sealed trait PaymentResult @@ -219,9 +214,17 @@ object PaymentLifecycle { case class UnreadableRemoteFailure(route: Seq[Hop]) extends PaymentFailure case class PaymentFailed(id: UUID, paymentHash: ByteVector32, failures: Seq[PaymentFailure]) extends PaymentResult - sealed trait PaymentOptions - case object LegacyPayload extends PaymentOptions - case object TlvPayload extends PaymentOptions + /** + * Options to help build the final payload of the payment route. + */ + sealed trait PaymentOptions { + // The final htlc amount in millisatoshis. + val finalAmount: MilliSatoshi + // The final htlc expiry in number of blocks. + val finalExpiry: CltvExpiry + } + case class LegacyPayload(finalAmount: MilliSatoshi, finalExpiry: CltvExpiry) extends PaymentOptions + case class TlvPayload(finalAmount: MilliSatoshi, finalExpiry: CltvExpiry, records: Seq[OnionTlv] = Nil) extends PaymentOptions sealed trait Data case object WaitingForRequest extends Data @@ -249,21 +252,19 @@ object PaymentLifecycle { /** * Build the onion payloads for each hop. * - * @param finalAmount the final htlc amount in millisatoshis - * @param finalExpiry the final htlc expiry in number of blocks - * @param hops the hops as computed by the router + extra routes from payment request - * @param opts options to help build each hop's payload - * @return a (firstAmountMsat, firstExpiry, payloads) tuple where: - * - firstAmountMsat is the amount for the first htlc in the route + * @param hops the hops as computed by the router + extra routes from payment request + * @param opts options to help build each hop's payload (final amount, expiry, additional tlv records, etc) + * @return a (firstAmount, firstExpiry, payloads) tuple where: + * - firstAmount is the amount for the first htlc in the route * - firstExpiry is the cltv expiry for the first htlc in the route * - a sequence of payloads that will be used to build the onion */ - def buildPayloads(finalAmount: MilliSatoshi, finalExpiry: CltvExpiry, hops: Seq[Hop], opts: PaymentOptions = LegacyPayload): (MilliSatoshi, CltvExpiry, Seq[OnionPerHopPayload]) = { + def buildPayloads(hops: Seq[Hop], opts: PaymentOptions): (MilliSatoshi, CltvExpiry, Seq[OnionPerHopPayload]) = { val finalPayload: Seq[OnionPerHopPayload] = opts match { - case LegacyPayload => OnionForwardInfo(ShortChannelId(0L), finalAmount, finalExpiry) :: Nil - case TlvPayload => TlvStream[OnionTlv](OnionTlv.AmountToForward(finalAmount), OnionTlv.OutgoingCltv(finalExpiry)) :: Nil + case p: LegacyPayload => OnionForwardInfo(ShortChannelId(0L), p.finalAmount, p.finalExpiry) :: Nil + case p: TlvPayload => TlvStream[OnionTlv](OnionTlv.AmountToForward(p.finalAmount) +: OnionTlv.OutgoingCltv(p.finalExpiry) +: p.records) :: Nil } - hops.reverse.foldLeft((finalAmount, finalExpiry, finalPayload)) { + hops.reverse.foldLeft((opts.finalAmount, opts.finalExpiry, finalPayload)) { case ((amount, expiry, payloads), hop) => val nextFee = nodeFee(hop.lastUpdate.feeBaseMsat, hop.lastUpdate.feeProportionalMillionths, amount) // Since we don't have any scenario where we add tlv data for intermediate hops, we use legacy payloads. @@ -272,8 +273,8 @@ object PaymentLifecycle { } } - def buildCommand(id: UUID, finalAmount: MilliSatoshi, finalExpiry: CltvExpiry, paymentHash: ByteVector32, hops: Seq[Hop], opts: PaymentOptions = LegacyPayload): (CMD_ADD_HTLC, Seq[(ByteVector32, PublicKey)]) = { - val (firstAmount, firstExpiry, payloads) = buildPayloads(finalAmount, finalExpiry, hops.drop(1), opts) + def buildCommand(id: UUID, paymentHash: ByteVector32, hops: Seq[Hop], opts: PaymentOptions): (CMD_ADD_HTLC, Seq[(ByteVector32, PublicKey)]) = { + val (firstAmount, firstExpiry, payloads) = buildPayloads(hops.drop(1), opts) val nodes = hops.map(_.nextNodeId) // BOLT 2 requires that associatedData == paymentHash val onion = buildOnion(nodes, payloads, paymentHash) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/EclairImplSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/EclairImplSpec.scala index 1802fc5b6e..3e5f5a8dab 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/EclairImplSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/EclairImplSpec.scala @@ -26,7 +26,8 @@ import fr.acinq.eclair.blockchain.TestWallet import fr.acinq.eclair.channel.{CMD_FORCECLOSE, Register, _} import fr.acinq.eclair.db._ import fr.acinq.eclair.io.Peer.OpenChannel -import fr.acinq.eclair.payment.PaymentLifecycle.{ReceivePayment, SendPayment, SendPaymentToRoute} +import fr.acinq.eclair.payment.PaymentInitiator.SendPaymentRequest +import fr.acinq.eclair.payment.PaymentLifecycle.ReceivePayment import fr.acinq.eclair.payment.PaymentRequest.ExtraHop import fr.acinq.eclair.payment.{LocalPaymentHandler, PaymentRequest} import fr.acinq.eclair.router.RouteCalculationSpec.makeUpdate @@ -94,7 +95,7 @@ class EclairImplSpec extends TestKit(ActorSystem("mySystem")) with fixture.FunSu val nodeId = PublicKey(hex"030bb6a5e0c6b203c7e2180fb78c7ba4bdce46126761d8201b91ddac089cdecc87") eclair.send(nodeId, 123 msat, ByteVector32.Zeroes, invoice_opt = None) - val send = paymentInitiator.expectMsgType[SendPayment] + val send = paymentInitiator.expectMsgType[SendPaymentRequest] assert(send.targetNodeId == nodeId) assert(send.amount == 123.msat) assert(send.paymentHash == ByteVector32.Zeroes) @@ -104,7 +105,7 @@ class EclairImplSpec extends TestKit(ActorSystem("mySystem")) with fixture.FunSu val hints = List(List(ExtraHop(Bob.nodeParams.nodeId, ShortChannelId("569178x2331x1"), feeBase = 10 msat, feeProportionalMillionths = 1, cltvExpiryDelta = CltvExpiryDelta(12)))) val invoice1 = PaymentRequest(Block.RegtestGenesisBlock.hash, Some(123 msat), ByteVector32.Zeroes, randomKey, "description", None, None, hints) eclair.send(nodeId, 123 msat, ByteVector32.Zeroes, invoice_opt = Some(invoice1)) - val send1 = paymentInitiator.expectMsgType[SendPayment] + val send1 = paymentInitiator.expectMsgType[SendPaymentRequest] assert(send1.targetNodeId == nodeId) assert(send1.amount == 123.msat) assert(send1.paymentHash == ByteVector32.Zeroes) @@ -113,15 +114,15 @@ class EclairImplSpec extends TestKit(ActorSystem("mySystem")) with fixture.FunSu // with finalCltvExpiry val invoice2 = PaymentRequest("lntb", Some(123 msat), System.currentTimeMillis() / 1000L, nodeId, List(PaymentRequest.MinFinalCltvExpiry(96), PaymentRequest.PaymentHash(ByteVector32.Zeroes), PaymentRequest.Description("description")), ByteVector.empty) eclair.send(nodeId, 123 msat, ByteVector32.Zeroes, invoice_opt = Some(invoice2)) - val send2 = paymentInitiator.expectMsgType[SendPayment] + val send2 = paymentInitiator.expectMsgType[SendPaymentRequest] assert(send2.targetNodeId == nodeId) assert(send2.amount == 123.msat) assert(send2.paymentHash == ByteVector32.Zeroes) - assert(send2.finalCltvExpiryDelta == CltvExpiryDelta(96)) + assert(send2.finalExpiryDelta == CltvExpiryDelta(96)) // with custom route fees parameters eclair.send(nodeId, 123 msat, ByteVector32.Zeroes, invoice_opt = None, feeThreshold_opt = Some(123 sat), maxFeePct_opt = Some(4.20)) - val send3 = paymentInitiator.expectMsgType[SendPayment] + val send3 = paymentInitiator.expectMsgType[SendPaymentRequest] assert(send3.targetNodeId == nodeId) assert(send3.amount == 123.msat) assert(send3.paymentHash == ByteVector32.Zeroes) @@ -252,12 +253,11 @@ class EclairImplSpec extends TestKit(ActorSystem("mySystem")) with fixture.FunSu val eclair = new EclairImpl(kit) eclair.sendToRoute(route, 1234 msat, ByteVector32.One, CltvExpiryDelta(123)) - val send = paymentInitiator.expectMsgType[SendPaymentToRoute] - - assert(send.hops === route) + val send = paymentInitiator.expectMsgType[SendPaymentRequest] + assert(send.predefinedRoute == route) assert(send.amount === 1234.msat) - assert(send.finalCltvExpiryDelta === CltvExpiryDelta(123)) - assert(send.paymentHash === ByteVector32.One) + assert(send.finalExpiryDelta === CltvExpiryDelta(123)) + assert(send.paymentHash == ByteVector32.One) } } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/channel/FuzzySpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/channel/FuzzySpec.scala index 53a0ab5914..924d30227f 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/channel/FuzzySpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/channel/FuzzySpec.scala @@ -27,7 +27,7 @@ import fr.acinq.eclair.TestConstants.{Alice, Bob} import fr.acinq.eclair._ import fr.acinq.eclair.blockchain._ import fr.acinq.eclair.channel.states.StateTestsHelperMethods -import fr.acinq.eclair.payment.PaymentLifecycle.ReceivePayment +import fr.acinq.eclair.payment.PaymentLifecycle.{LegacyPayload, ReceivePayment} import fr.acinq.eclair.payment._ import fr.acinq.eclair.router.Hop import fr.acinq.eclair.wire._ @@ -39,8 +39,8 @@ import scala.concurrent.duration._ import scala.util.Random /** - * Created by PM on 05/07/2016. - */ + * Created by PM on 05/07/2016. + */ class FuzzySpec extends TestkitBaseClass with StateTestsHelperMethods with Logging { @@ -95,7 +95,7 @@ class FuzzySpec extends TestkitBaseClass with StateTestsHelperMethods with Loggi // allow overpaying (no more than 2 times the required amount) val amount = MilliSatoshi(requiredAmount + Random.nextInt(requiredAmount)) val expiry = (Channel.MIN_CLTV_EXPIRY_DELTA + 1).toCltvExpiry - PaymentLifecycle.buildCommand(UUID.randomUUID(), amount, expiry, paymentHash, Hop(null, dest, null) :: Nil)._1 + PaymentLifecycle.buildCommand(UUID.randomUUID(), paymentHash, Hop(null, dest, null) :: Nil, LegacyPayload(amount, expiry))._1 } def initiatePayment(stopping: Boolean) = diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/StateTestsHelperMethods.scala b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/StateTestsHelperMethods.scala index cdcbab0d8b..d37e174e09 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/StateTestsHelperMethods.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/StateTestsHelperMethods.scala @@ -27,6 +27,7 @@ import fr.acinq.eclair.blockchain.fee.FeeTargets import fr.acinq.eclair.channel._ import fr.acinq.eclair.io.Peer import fr.acinq.eclair.payment.PaymentLifecycle +import fr.acinq.eclair.payment.PaymentLifecycle.LegacyPayload import fr.acinq.eclair.router.Hop import fr.acinq.eclair.wire._ import fr.acinq.eclair.{NodeParams, TestConstants, randomBytes32, _} @@ -109,7 +110,7 @@ trait StateTestsHelperMethods extends TestKitBase { val payment_preimage: ByteVector32 = randomBytes32 val payment_hash: ByteVector32 = Crypto.sha256(payment_preimage) val expiry = CltvExpiryDelta(144).toCltvExpiry - val cmd = PaymentLifecycle.buildCommand(UUID.randomUUID, amount, expiry, payment_hash, Hop(null, destination, null) :: Nil)._1.copy(commit = false) + val cmd = PaymentLifecycle.buildCommand(UUID.randomUUID, payment_hash, Hop(null, destination, null) :: Nil, LegacyPayload(amount, expiry))._1.copy(commit = false) (payment_preimage, cmd) } @@ -124,7 +125,7 @@ trait StateTestsHelperMethods extends TestKitBase { (payment_preimage, htlc) } - def fulfillHtlc(id: Long, R: ByteVector32, s: TestFSMRef[State, Data, Channel], r: TestFSMRef[State, Data, Channel], s2r: TestProbe, r2s: TestProbe) = { + def fulfillHtlc(id: Long, R: ByteVector32, s: TestFSMRef[State, Data, Channel], r: TestFSMRef[State, Data, Channel], s2r: TestProbe, r2s: TestProbe): Unit = { val sender = TestProbe() sender.send(s, CMD_FULFILL_HTLC(id, R)) sender.expectMsg("ok") @@ -133,7 +134,7 @@ trait StateTestsHelperMethods extends TestKitBase { awaitCond(r.stateData.asInstanceOf[HasCommitments].commitments.remoteChanges.proposed.contains(fulfill)) } - def crossSign(s: TestFSMRef[State, Data, Channel], r: TestFSMRef[State, Data, Channel], s2r: TestProbe, r2s: TestProbe) = { + def crossSign(s: TestFSMRef[State, Data, Channel], r: TestFSMRef[State, Data, Channel], s2r: TestProbe, r2s: TestProbe): Unit = { val sender = TestProbe() val sCommitIndex = s.stateData.asInstanceOf[HasCommitments].commitments.localCommit.index val rCommitIndex = r.stateData.asInstanceOf[HasCommitments].commitments.localCommit.index @@ -170,12 +171,14 @@ trait StateTestsHelperMethods extends TestKitBase { implicit class ChannelWithTestFeeConf(a: TestFSMRef[State, Data, Channel]) { def feeEstimator: TestFeeEstimator = a.underlyingActor.nodeParams.onChainFeeConf.feeEstimator.asInstanceOf[TestFeeEstimator] + def feeTargets: FeeTargets = a.underlyingActor.nodeParams.onChainFeeConf.feeTargets } implicit class PeerWithTestFeeConf(a: TestFSMRef[Peer.State, Peer.Data, Peer]) { def feeEstimator: TestFeeEstimator = a.underlyingActor.nodeParams.onChainFeeConf.feeEstimator.asInstanceOf[TestFeeEstimator] + def feeTargets: FeeTargets = a.underlyingActor.nodeParams.onChainFeeConf.feeTargets } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/f/ShutdownStateSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/f/ShutdownStateSpec.scala index 5b425eac08..7ee2bd80ee 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/f/ShutdownStateSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/f/ShutdownStateSpec.scala @@ -56,7 +56,7 @@ class ShutdownStateSpec extends TestkitBaseClass with StateTestsHelperMethods { val h1 = Crypto.sha256(r1) val amount1 = 300000000 msat val expiry1 = CltvExpiryDelta(144).toCltvExpiry - val cmd1 = PaymentLifecycle.buildCommand(UUID.randomUUID, amount1, expiry1, h1, Hop(null, TestConstants.Bob.nodeParams.nodeId, null) :: Nil)._1.copy(commit = false) + val cmd1 = PaymentLifecycle.buildCommand(UUID.randomUUID, h1, Hop(null, TestConstants.Bob.nodeParams.nodeId, null) :: Nil, PaymentLifecycle.LegacyPayload(amount1, expiry1))._1.copy(commit = false) sender.send(alice, cmd1) sender.expectMsg("ok") val htlc1 = alice2bob.expectMsgType[UpdateAddHtlc] @@ -66,7 +66,7 @@ class ShutdownStateSpec extends TestkitBaseClass with StateTestsHelperMethods { val h2 = Crypto.sha256(r2) val amount2 = 200000000 msat val expiry2 = CltvExpiryDelta(144).toCltvExpiry - val cmd2 = PaymentLifecycle.buildCommand(UUID.randomUUID, amount2, expiry2, h2, Hop(null, TestConstants.Bob.nodeParams.nodeId, null) :: Nil)._1.copy(commit = false) + val cmd2 = PaymentLifecycle.buildCommand(UUID.randomUUID, h2, Hop(null, TestConstants.Bob.nodeParams.nodeId, null) :: Nil, PaymentLifecycle.LegacyPayload(amount2, expiry2))._1.copy(commit = false) sender.send(alice, cmd2) sender.expectMsg("ok") val htlc2 = alice2bob.expectMsgType[UpdateAddHtlc] diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/integration/IntegrationSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/integration/IntegrationSpec.scala index 2a6db6c66e..dd01ce607e 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/integration/IntegrationSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/integration/IntegrationSpec.scala @@ -34,6 +34,7 @@ import fr.acinq.eclair.channel._ import fr.acinq.eclair.crypto.Sphinx.DecryptedFailurePacket import fr.acinq.eclair.io.Peer import fr.acinq.eclair.io.Peer.{Disconnect, PeerRoutingMessage} +import fr.acinq.eclair.payment.PaymentInitiator.SendPaymentRequest import fr.acinq.eclair.payment.PaymentLifecycle.{State => _, _} import fr.acinq.eclair.payment.{LocalPaymentHandler, PaymentRequest} import fr.acinq.eclair.router.Graph.WeightRatios @@ -262,8 +263,7 @@ class IntegrationSpec extends TestKit(ActorSystem("test")) with BitcoindService sender.send(nodes("D").paymentHandler, ReceivePayment(Some(amountMsat), "1 coffee")) val pr = sender.expectMsgType[PaymentRequest] // then we make the actual payment - sender.send(nodes("A").paymentInitiator, - SendPayment(amountMsat, pr.paymentHash, nodes("D").nodeParams.nodeId, routeParams = integrationTestRouteParams, maxAttempts = 1)) + sender.send(nodes("A").paymentInitiator, SendPaymentRequest(amountMsat, pr.paymentHash, nodes("D").nodeParams.nodeId, routeParams = integrationTestRouteParams, maxAttempts = 1)) val paymentId = sender.expectMsgType[UUID](5 seconds) val ps = sender.expectMsgType[PaymentSucceeded](5 seconds) assert(ps.id == paymentId) @@ -287,7 +287,7 @@ class IntegrationSpec extends TestKit(ActorSystem("test")) with BitcoindService sender.send(nodes("D").paymentHandler, ReceivePayment(Some(amountMsat), "1 coffee")) val pr = sender.expectMsgType[PaymentRequest] // then we make the actual payment, do not randomize the route to make sure we route through node B - val sendReq = SendPayment(amountMsat, pr.paymentHash, nodes("D").nodeParams.nodeId, routeParams = integrationTestRouteParams, maxAttempts = 5) + val sendReq = SendPaymentRequest(amountMsat, pr.paymentHash, nodes("D").nodeParams.nodeId, routeParams = integrationTestRouteParams, maxAttempts = 5) sender.send(nodes("A").paymentInitiator, sendReq) // A will receive an error from B that include the updated channel update, then will retry the payment val paymentId = sender.expectMsgType[UUID](5 seconds) @@ -327,7 +327,7 @@ class IntegrationSpec extends TestKit(ActorSystem("test")) with BitcoindService sender.send(nodes("D").paymentHandler, ReceivePayment(Some(amountMsat), "1 coffee")) val pr = sender.expectMsgType[PaymentRequest] // then we make the payment (B-C has a smaller capacity than A-B and C-D) - val sendReq = SendPayment(amountMsat, pr.paymentHash, nodes("D").nodeParams.nodeId, routeParams = integrationTestRouteParams, maxAttempts = 5) + val sendReq = SendPaymentRequest(amountMsat, pr.paymentHash, nodes("D").nodeParams.nodeId, routeParams = integrationTestRouteParams, maxAttempts = 5) sender.send(nodes("A").paymentInitiator, sendReq) // A will first receive an error from C, then retry and route around C: A->B->E->C->D sender.expectMsgType[UUID](5 seconds) @@ -336,7 +336,7 @@ class IntegrationSpec extends TestKit(ActorSystem("test")) with BitcoindService test("send an HTLC A->D with an unknown payment hash") { val sender = TestProbe() - val pr = SendPayment(100000000 msat, randomBytes32, nodes("D").nodeParams.nodeId, routeParams = integrationTestRouteParams, maxAttempts = 5) + val pr = SendPaymentRequest(100000000 msat, randomBytes32, nodes("D").nodeParams.nodeId, routeParams = integrationTestRouteParams, maxAttempts = 5) sender.send(nodes("A").paymentInitiator, pr) // A will receive an error from D and won't retry @@ -356,7 +356,7 @@ class IntegrationSpec extends TestKit(ActorSystem("test")) with BitcoindService val pr = sender.expectMsgType[PaymentRequest] // A send payment of only 1 mBTC - val sendReq = SendPayment(100000000 msat, pr.paymentHash, nodes("D").nodeParams.nodeId, routeParams = integrationTestRouteParams, maxAttempts = 5) + val sendReq = SendPaymentRequest(100000000 msat, pr.paymentHash, nodes("D").nodeParams.nodeId, routeParams = integrationTestRouteParams, maxAttempts = 5) sender.send(nodes("A").paymentInitiator, sendReq) // A will first receive an IncorrectPaymentAmount error from D @@ -376,7 +376,7 @@ class IntegrationSpec extends TestKit(ActorSystem("test")) with BitcoindService val pr = sender.expectMsgType[PaymentRequest] // A send payment of 6 mBTC - val sendReq = SendPayment(600000000 msat, pr.paymentHash, nodes("D").nodeParams.nodeId, routeParams = integrationTestRouteParams, maxAttempts = 5) + val sendReq = SendPaymentRequest(600000000 msat, pr.paymentHash, nodes("D").nodeParams.nodeId, routeParams = integrationTestRouteParams, maxAttempts = 5) sender.send(nodes("A").paymentInitiator, sendReq) // A will first receive an IncorrectPaymentAmount error from D @@ -396,7 +396,7 @@ class IntegrationSpec extends TestKit(ActorSystem("test")) with BitcoindService val pr = sender.expectMsgType[PaymentRequest] // A send payment of 3 mBTC, more than asked but it should still be accepted - val sendReq = SendPayment(300000000 msat, pr.paymentHash, nodes("D").nodeParams.nodeId, routeParams = integrationTestRouteParams, maxAttempts = 5) + val sendReq = SendPaymentRequest(300000000 msat, pr.paymentHash, nodes("D").nodeParams.nodeId, routeParams = integrationTestRouteParams, maxAttempts = 5) sender.send(nodes("A").paymentInitiator, sendReq) sender.expectMsgType[UUID] } @@ -409,7 +409,7 @@ class IntegrationSpec extends TestKit(ActorSystem("test")) with BitcoindService sender.send(nodes("D").paymentHandler, ReceivePayment(Some(amountMsat), "1 payment")) val pr = sender.expectMsgType[PaymentRequest] - val sendReq = SendPayment(amountMsat, pr.paymentHash, nodes("D").nodeParams.nodeId, routeParams = integrationTestRouteParams, maxAttempts = 5) + val sendReq = SendPaymentRequest(amountMsat, pr.paymentHash, nodes("D").nodeParams.nodeId, routeParams = integrationTestRouteParams, maxAttempts = 5) sender.send(nodes("A").paymentInitiator, sendReq) sender.expectMsgType[UUID] sender.expectMsgType[PaymentSucceeded] // the payment FSM will also reply to the sender after the payment is completed @@ -424,8 +424,7 @@ class IntegrationSpec extends TestKit(ActorSystem("test")) with BitcoindService val pr = sender.expectMsgType[PaymentRequest](30 seconds) // the payment is requesting to use a capacity-optimized route which will select node G even though it's a bit more expensive - sender.send(nodes("A").paymentInitiator, - SendPayment(amountMsat, pr.paymentHash, nodes("C").nodeParams.nodeId, maxAttempts = 1, routeParams = integrationTestRouteParams.map(_.copy(ratios = Some(WeightRatios(0, 0, 1)))))) + sender.send(nodes("A").paymentInitiator, SendPaymentRequest(amountMsat, pr.paymentHash, nodes("C").nodeParams.nodeId, maxAttempts = 1, routeParams = integrationTestRouteParams.map(_.copy(ratios = Some(WeightRatios(0, 0, 1)))))) sender.expectMsgType[UUID](max = 60 seconds) awaitCond({ @@ -469,7 +468,7 @@ class IntegrationSpec extends TestKit(ActorSystem("test")) with BitcoindService val preimage = randomBytes32 val paymentHash = Crypto.sha256(preimage) // A sends a payment to F - val paymentReq = SendPayment(100000000 msat, paymentHash, nodes("F1").nodeParams.nodeId, maxAttempts = 1, routeParams = integrationTestRouteParams) + val paymentReq = SendPaymentRequest(100000000 msat, paymentHash, nodes("F1").nodeParams.nodeId, maxAttempts = 1, routeParams = integrationTestRouteParams) val paymentSender = TestProbe() paymentSender.send(nodes("A").paymentInitiator, paymentReq) paymentSender.expectMsgType[UUID](30 seconds) @@ -549,7 +548,7 @@ class IntegrationSpec extends TestKit(ActorSystem("test")) with BitcoindService val preimage = randomBytes32 val paymentHash = Crypto.sha256(preimage) // A sends a payment to F - val paymentReq = SendPayment(100000000 msat, paymentHash, nodes("F2").nodeParams.nodeId, maxAttempts = 1, routeParams = integrationTestRouteParams) + val paymentReq = SendPaymentRequest(100000000 msat, paymentHash, nodes("F2").nodeParams.nodeId, maxAttempts = 1, routeParams = integrationTestRouteParams) val paymentSender = TestProbe() paymentSender.send(nodes("A").paymentInitiator, paymentReq) paymentSender.expectMsgType[UUID](30 seconds) @@ -626,7 +625,7 @@ class IntegrationSpec extends TestKit(ActorSystem("test")) with BitcoindService val preimage: ByteVector = randomBytes32 val paymentHash = Crypto.sha256(preimage) // A sends a payment to F - val paymentReq = SendPayment(100000000 msat, paymentHash, nodes("F3").nodeParams.nodeId, maxAttempts = 1, routeParams = integrationTestRouteParams) + val paymentReq = SendPaymentRequest(100000000 msat, paymentHash, nodes("F3").nodeParams.nodeId, maxAttempts = 1, routeParams = integrationTestRouteParams) val paymentSender = TestProbe() paymentSender.send(nodes("A").paymentInitiator, paymentReq) val paymentId = paymentSender.expectMsgType[UUID] @@ -688,7 +687,7 @@ class IntegrationSpec extends TestKit(ActorSystem("test")) with BitcoindService val preimage: ByteVector = randomBytes32 val paymentHash = Crypto.sha256(preimage) // A sends a payment to F - val paymentReq = SendPayment(100000000 msat, paymentHash, nodes("F4").nodeParams.nodeId, maxAttempts = 1, routeParams = integrationTestRouteParams) + val paymentReq = SendPaymentRequest(100000000 msat, paymentHash, nodes("F4").nodeParams.nodeId, maxAttempts = 1, routeParams = integrationTestRouteParams) val paymentSender = TestProbe() paymentSender.send(nodes("A").paymentInitiator, paymentReq) val paymentId = paymentSender.expectMsgType[UUID](30 seconds) @@ -762,7 +761,7 @@ class IntegrationSpec extends TestKit(ActorSystem("test")) with BitcoindService val amountMsat = 300000000.msat sender.send(paymentHandlerF, ReceivePayment(Some(amountMsat), "1 coffee")) val pr = sender.expectMsgType[PaymentRequest] - val sendReq = SendPayment(300000000 msat, pr.paymentHash, pr.nodeId, routeParams = integrationTestRouteParams, maxAttempts = 1) + val sendReq = SendPaymentRequest(300000000 msat, pr.paymentHash, pr.nodeId, routeParams = integrationTestRouteParams, maxAttempts = 1) sender.send(nodes("A").paymentInitiator, sendReq) val paymentId = sender.expectMsgType[UUID] // we forward the htlc to the payment handler @@ -776,7 +775,7 @@ class IntegrationSpec extends TestKit(ActorSystem("test")) with BitcoindService def send(amountMsat: MilliSatoshi, paymentHandler: ActorRef, paymentInitiator: ActorRef) = { sender.send(paymentHandler, ReceivePayment(Some(amountMsat), "1 coffee")) val pr = sender.expectMsgType[PaymentRequest] - val sendReq = SendPayment(amountMsat, pr.paymentHash, pr.nodeId, routeParams = integrationTestRouteParams, maxAttempts = 1) + val sendReq = SendPaymentRequest(amountMsat, pr.paymentHash, pr.nodeId, routeParams = integrationTestRouteParams, maxAttempts = 1) sender.send(paymentInitiator, sendReq) sender.expectMsgType[UUID] } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/HtlcGenerationSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/HtlcGenerationSpec.scala index 4a44d531ea..22ff0ca88d 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/HtlcGenerationSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/HtlcGenerationSpec.scala @@ -26,15 +26,19 @@ import fr.acinq.eclair.crypto.Sphinx.{DecryptedPacket, PacketAndSecrets} import fr.acinq.eclair.payment.PaymentLifecycle._ import fr.acinq.eclair.router.Hop import fr.acinq.eclair.wire._ -import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, LongToBtcAmount, MilliSatoshi, ShortChannelId, TestConstants, nodeFee, randomBytes32} -import org.scalatest.FunSuite +import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Globals, LongToBtcAmount, MilliSatoshi, ShortChannelId, TestConstants, nodeFee, randomBytes32} +import org.scalatest.{BeforeAndAfterAll, FunSuite} import scodec.bits.ByteVector /** * Created by PM on 31/05/2016. */ -class HtlcGenerationSpec extends FunSuite { +class HtlcGenerationSpec extends FunSuite with BeforeAndAfterAll { + + override def beforeAll { + Globals.blockCount.set(HtlcGenerationSpec.currentBlockCount) + } test("compute fees") { val feeBaseMsat = 150000 msat @@ -49,7 +53,7 @@ class HtlcGenerationSpec extends FunSuite { import HtlcGenerationSpec._ test("compute payloads with fees and expiry delta") { - val (firstAmountMsat, firstExpiry, payloads) = buildPayloads(finalAmountMsat, finalExpiry, hops.drop(1)) + val (firstAmountMsat, firstExpiry, payloads) = buildPayloads(hops.drop(1), LegacyPayload(finalAmountMsat, finalExpiry)) val expectedPayloads = Seq[OnionPerHopPayload]( OnionForwardInfo(channelUpdate_bc.shortChannelId, amount_bc, expiry_bc), OnionForwardInfo(channelUpdate_cd.shortChannelId, amount_cd, expiry_cd), @@ -62,7 +66,7 @@ class HtlcGenerationSpec extends FunSuite { } test("build onion") { - val (_, _, payloads) = buildPayloads(finalAmountMsat, finalExpiry, hops.drop(1)) + val (_, _, payloads) = buildPayloads(hops.drop(1), LegacyPayload(finalAmountMsat, finalExpiry)) val nodes = hops.map(_.nextNodeId) val PacketAndSecrets(packet_b, _) = buildOnion(nodes, payloads, paymentHash) assert(packet_b.payload.length === Sphinx.PaymentPacket.PayloadLength) @@ -94,7 +98,7 @@ class HtlcGenerationSpec extends FunSuite { } test("build onion with final tlv payload") { - val (_, _, payloads) = buildPayloads(finalAmountMsat, finalExpiry, hops.drop(1), TlvPayload) + val (_, _, payloads) = buildPayloads(hops.drop(1), TlvPayload(finalAmountMsat, finalExpiry)) val nodes = hops.map(_.nextNodeId) val PacketAndSecrets(packet_b, _) = buildOnion(nodes, payloads, paymentHash) assert(packet_b.payload.length === Sphinx.PaymentPacket.PayloadLength) @@ -123,7 +127,7 @@ class HtlcGenerationSpec extends FunSuite { } test("build a command including the onion") { - val (add, _) = buildCommand(UUID.randomUUID, finalAmountMsat, finalExpiry, paymentHash, hops) + val (add, _) = buildCommand(UUID.randomUUID, paymentHash, hops, LegacyPayload(finalAmountMsat, finalExpiry)) assert(add.amount > finalAmountMsat) assert(add.cltvExpiry === finalExpiry + channelUpdate_de.cltvExpiryDelta + channelUpdate_cd.cltvExpiryDelta + channelUpdate_bc.cltvExpiryDelta) @@ -157,7 +161,7 @@ class HtlcGenerationSpec extends FunSuite { } test("build a command with no hops") { - val (add, _) = buildCommand(UUID.randomUUID(), finalAmountMsat, finalExpiry, paymentHash, hops.take(1)) + val (add, _) = buildCommand(UUID.randomUUID(), paymentHash, hops.take(1), LegacyPayload(finalAmountMsat, finalExpiry)) assert(add.amount === finalAmountMsat) assert(add.cltvExpiry === finalExpiry) @@ -202,12 +206,12 @@ object HtlcGenerationSpec { Hop(d, e, channelUpdate_de) :: Nil val finalAmountMsat = 42000000 msat - val currentBlockCount = 420000 + val currentBlockCount = 400000 val finalExpiry = CltvExpiry(currentBlockCount) + Channel.MIN_CLTV_EXPIRY_DELTA val paymentPreimage = randomBytes32 val paymentHash = Crypto.sha256(paymentPreimage) - val expiry_de = CltvExpiry(currentBlockCount) + Channel.MIN_CLTV_EXPIRY_DELTA + val expiry_de = finalExpiry val amount_de = finalAmountMsat val fee_d = nodeFee(channelUpdate_de.feeBaseMsat, channelUpdate_de.feeProportionalMillionths, amount_de) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentInitiatorSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentInitiatorSpec.scala new file mode 100644 index 0000000000..a059dcc19c --- /dev/null +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentInitiatorSpec.scala @@ -0,0 +1,61 @@ +/* + * Copyright 2019 ACINQ SAS + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package fr.acinq.eclair.payment + +import java.util.UUID + +import akka.actor.ActorSystem +import akka.testkit.{TestKit, TestProbe} +import fr.acinq.eclair.payment.HtlcGenerationSpec._ +import fr.acinq.eclair.payment.PaymentRequest.ExtraHop +import fr.acinq.eclair.router.{FinalizeRoute, RouteParams, RouteRequest} +import fr.acinq.eclair.{CltvExpiryDelta, LongToBtcAmount, TestConstants} +import org.scalatest.FunSuiteLike + +/** + * Created by t-bast on 25/07/2019. + */ + +class PaymentInitiatorSpec extends TestKit(ActorSystem("test")) with FunSuiteLike { + + test("forward payment with pre-defined route") { + val sender = TestProbe() + val router = TestProbe() + val paymentInitiator = system.actorOf(PaymentInitiator.props(TestConstants.Alice.nodeParams, router.ref, TestProbe().ref)) + + sender.send(paymentInitiator, PaymentInitiator.SendPaymentRequest(finalAmountMsat, paymentHash, c, 1, predefinedRoute = Seq(a, b, c))) + sender.expectMsgType[UUID] + router.expectMsg(FinalizeRoute(Seq(a, b, c))) + } + + test("forward legacy payment") { + val sender = TestProbe() + val router = TestProbe() + val paymentInitiator = system.actorOf(PaymentInitiator.props(TestConstants.Alice.nodeParams, router.ref, TestProbe().ref)) + + val hints = Seq(Seq(ExtraHop(b, channelUpdate_bc.shortChannelId, feeBase = 10 msat, feeProportionalMillionths = 1, cltvExpiryDelta = CltvExpiryDelta(12)))) + val routeParams = RouteParams(randomize = true, 15 msat, 1.5, 5, CltvExpiryDelta(561), None) + sender.send(paymentInitiator, PaymentInitiator.SendPaymentRequest(finalAmountMsat, paymentHash, c, 1, CltvExpiryDelta(42), assistedRoutes = hints, routeParams = Some(routeParams))) + sender.expectMsgType[UUID] + router.expectMsg(RouteRequest(TestConstants.Alice.nodeParams.nodeId, c, finalAmountMsat, assistedRoutes = hints, routeParams = Some(routeParams))) + + sender.send(paymentInitiator, PaymentInitiator.SendPaymentRequest(finalAmountMsat, paymentHash, e, 3)) + sender.expectMsgType[UUID] + router.expectMsg(RouteRequest(TestConstants.Alice.nodeParams.nodeId, e, finalAmountMsat)) + } + +} diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala index 684293518c..7bb56dd77b 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala @@ -26,7 +26,7 @@ import fr.acinq.bitcoin.{Block, ByteVector32, Transaction, TxOut} import fr.acinq.eclair._ import fr.acinq.eclair.blockchain.{UtxoStatus, ValidateRequest, ValidateResult, WatchSpentBasic} import fr.acinq.eclair.channel.Register.ForwardShortId -import fr.acinq.eclair.channel.{AddHtlcFailed, ChannelUnavailable} +import fr.acinq.eclair.channel.{AddHtlcFailed, Channel, ChannelUnavailable} import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.db.OutgoingPaymentStatus import fr.acinq.eclair.io.Peer.PeerRoutingMessage @@ -44,6 +44,7 @@ import scodec.bits.HexStringSyntax class PaymentLifecycleSpec extends BaseRouterSpec { val defaultAmountMsat = 142000000 msat + val defaultExpiry = Channel.MIN_CLTV_EXPIRY_DELTA.toCltvExpiry test("send to route") { fixture => import fixture._ @@ -61,7 +62,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val CurrentState(_, WAITING_FOR_REQUEST) = monitor.expectMsgClass(classOf[CurrentState[_]]) // pre-computed route going from A to D - val request = SendPaymentToRoute(defaultAmountMsat, defaultPaymentHash, Seq(a, b, c, d)) + val request = SendPaymentToRoute(defaultPaymentHash, Seq(a, b, c, d), LegacyPayload(defaultAmountMsat, defaultExpiry)) sender.send(paymentFSM, request) val Transition(_, WAITING_FOR_REQUEST, WAITING_FOR_ROUTE) = monitor.expectMsgClass(classOf[Transition[_]]) @@ -87,7 +88,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { paymentFSM ! SubscribeTransitionCallBack(monitor.ref) val CurrentState(_, WAITING_FOR_REQUEST) = monitor.expectMsgClass(classOf[CurrentState[_]]) - val request = SendPayment(defaultAmountMsat, defaultPaymentHash, f, maxAttempts = 5) + val request = SendPayment(defaultPaymentHash, f, LegacyPayload(defaultAmountMsat, defaultExpiry), maxAttempts = 5) sender.send(paymentFSM, request) val Transition(_, WAITING_FOR_REQUEST, WAITING_FOR_ROUTE) = monitor.expectMsgClass(classOf[Transition[_]]) val routeRequest = routerForwarder.expectMsgType[RouteRequest] @@ -110,7 +111,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { paymentFSM ! SubscribeTransitionCallBack(monitor.ref) val CurrentState(_, WAITING_FOR_REQUEST) = monitor.expectMsgClass(classOf[CurrentState[_]]) - val request = SendPayment(defaultAmountMsat, randomBytes32, d, routeParams = Some(RouteParams(randomize = false, maxFeeBase = 100 msat, maxFeePct = 0.0, routeMaxLength = 20, routeMaxCltv = CltvExpiryDelta(2016), ratios = None)), maxAttempts = 5) + val request = SendPayment(randomBytes32, d, LegacyPayload(defaultAmountMsat, defaultExpiry), maxAttempts = 5, routeParams = Some(RouteParams(randomize = false, maxFeeBase = 100 msat, maxFeePct = 0.0, routeMaxLength = 20, routeMaxCltv = CltvExpiryDelta(2016), ratios = None))) sender.send(paymentFSM, request) val Transition(_, WAITING_FOR_REQUEST, WAITING_FOR_ROUTE) = monitor.expectMsgClass(classOf[Transition[_]]) @@ -133,7 +134,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { paymentFSM ! SubscribeTransitionCallBack(monitor.ref) val CurrentState(_, WAITING_FOR_REQUEST) = monitor.expectMsgClass(classOf[CurrentState[_]]) - val request = SendPayment(defaultAmountMsat, defaultPaymentHash, d, maxAttempts = 2) + val request = SendPayment(defaultPaymentHash, d, LegacyPayload(defaultAmountMsat, defaultExpiry), maxAttempts = 2) sender.send(paymentFSM, request) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.PENDING)) @@ -176,7 +177,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { paymentFSM ! SubscribeTransitionCallBack(monitor.ref) val CurrentState(_, WAITING_FOR_REQUEST) = monitor.expectMsgClass(classOf[CurrentState[_]]) - val request = SendPayment(defaultAmountMsat, randomBytes32, d, maxAttempts = 2) + val request = SendPayment(randomBytes32, d, LegacyPayload(defaultAmountMsat, defaultExpiry), maxAttempts = 2) sender.send(paymentFSM, request) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.PENDING)) @@ -209,7 +210,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { paymentFSM ! SubscribeTransitionCallBack(monitor.ref) val CurrentState(_, WAITING_FOR_REQUEST) = monitor.expectMsgClass(classOf[CurrentState[_]]) - val request = SendPayment(defaultAmountMsat, defaultPaymentHash, d, maxAttempts = 2) + val request = SendPayment(defaultPaymentHash, d, LegacyPayload(defaultAmountMsat, defaultExpiry), maxAttempts = 2) sender.send(paymentFSM, request) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.PENDING)) @@ -240,7 +241,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { paymentFSM ! SubscribeTransitionCallBack(monitor.ref) val CurrentState(_, WAITING_FOR_REQUEST) = monitor.expectMsgClass(classOf[CurrentState[_]]) - val request = SendPayment(defaultAmountMsat, randomBytes32, d, maxAttempts = 2) + val request = SendPayment(randomBytes32, d, LegacyPayload(defaultAmountMsat, defaultExpiry), maxAttempts = 2) sender.send(paymentFSM, request) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE) val WaitingForRoute(_, _, Nil) = paymentFSM.stateData @@ -280,7 +281,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { paymentFSM ! SubscribeTransitionCallBack(monitor.ref) val CurrentState(_, WAITING_FOR_REQUEST) = monitor.expectMsgClass(classOf[CurrentState[_]]) - val request = SendPayment(defaultAmountMsat, randomBytes32, d, maxAttempts = 5) + val request = SendPayment(randomBytes32, d, LegacyPayload(defaultAmountMsat, defaultExpiry), maxAttempts = 5) sender.send(paymentFSM, request) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.PENDING)) @@ -342,7 +343,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { paymentFSM ! SubscribeTransitionCallBack(monitor.ref) val CurrentState(_, WAITING_FOR_REQUEST) = monitor.expectMsgClass(classOf[CurrentState[_]]) - val request = SendPayment(defaultAmountMsat, randomBytes32, d, maxAttempts = 2) + val request = SendPayment(randomBytes32, d, LegacyPayload(defaultAmountMsat, defaultExpiry), maxAttempts = 2) sender.send(paymentFSM, request) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.PENDING)) @@ -390,7 +391,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { paymentFSM ! SubscribeTransitionCallBack(monitor.ref) val CurrentState(_, WAITING_FOR_REQUEST) = monitor.expectMsgClass(classOf[CurrentState[_]]) - val request = SendPayment(defaultAmountMsat, defaultPaymentHash, d, maxAttempts = 5) + val request = SendPayment(defaultPaymentHash, d, LegacyPayload(defaultAmountMsat, defaultExpiry), maxAttempts = 5) sender.send(paymentFSM, request) val Transition(_, WAITING_FOR_REQUEST, WAITING_FOR_ROUTE) = monitor.expectMsgClass(classOf[Transition[_]]) val Transition(_, WAITING_FOR_ROUTE, WAITING_FOR_PAYMENT_COMPLETE) = monitor.expectMsgClass(classOf[Transition[_]]) @@ -398,9 +399,9 @@ class PaymentLifecycleSpec extends BaseRouterSpec { sender.send(paymentFSM, UpdateFulfillHtlc(ByteVector32.Zeroes, 0, defaultPaymentHash)) val paymentOK = sender.expectMsgType[PaymentSucceeded] - val PaymentSent(_, request.amount, fee, request.paymentHash, paymentOK.paymentPreimage, _, _) = eventListener.expectMsgType[PaymentSent] + val PaymentSent(_, request.paymentOptions.finalAmount, fee, request.paymentHash, paymentOK.paymentPreimage, _, _) = eventListener.expectMsgType[PaymentSent] assert(fee > 0.msat) - assert(fee === paymentOK.amount - request.amount) + assert(fee === paymentOK.amount - request.paymentOptions.finalAmount) awaitCond(paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.SUCCEEDED)) } @@ -440,7 +441,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val CurrentState(_, WAITING_FOR_REQUEST) = monitor.expectMsgClass(classOf[CurrentState[_]]) // we send a payment to G which is just after the - val request = SendPayment(defaultAmountMsat, defaultPaymentHash, g, maxAttempts = 5) + val request = SendPayment(defaultPaymentHash, g, LegacyPayload(defaultAmountMsat, defaultExpiry), maxAttempts = 5) sender.send(paymentFSM, request) // the route will be A -> B -> G where B -> G has a channel_update with fees=0 @@ -450,13 +451,13 @@ class PaymentLifecycleSpec extends BaseRouterSpec { sender.send(paymentFSM, UpdateFulfillHtlc(ByteVector32.Zeroes, 0, defaultPaymentHash)) val paymentOK = sender.expectMsgType[PaymentSucceeded] - val PaymentSent(_, request.amount, fee, request.paymentHash, paymentOK.paymentPreimage, _, _) = eventListener.expectMsgType[PaymentSent] + val PaymentSent(_, request.paymentOptions.finalAmount, fee, request.paymentHash, paymentOK.paymentPreimage, _, _) = eventListener.expectMsgType[PaymentSent] // during the route computation the fees were treated as if they were 1msat but when sending the onion we actually put zero // NB: A -> B doesn't pay fees because it's our direct neighbor // NB: B -> G doesn't asks for fees at all assert(fee === 0.msat) - assert(fee === paymentOK.amount - request.amount) + assert(fee === paymentOK.amount - request.paymentOptions.finalAmount) } test("filter errors properly") { _ => diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/RelayerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/RelayerSpec.scala index dd04104112..6cc7051b1d 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/RelayerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/RelayerSpec.scala @@ -23,7 +23,7 @@ import akka.testkit.TestProbe import fr.acinq.bitcoin.ByteVector32 import fr.acinq.eclair.channel._ import fr.acinq.eclair.crypto.Sphinx -import fr.acinq.eclair.payment.PaymentLifecycle.{buildCommand, buildOnion} +import fr.acinq.eclair.payment.PaymentLifecycle.{LegacyPayload, buildCommand, buildOnion} import fr.acinq.eclair.router.Announcements import fr.acinq.eclair.wire._ import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, LongToBtcAmount, ShortChannelId, TestConstants, TestkitBaseClass, UInt64, nodeFee, randomBytes32} @@ -62,7 +62,7 @@ class RelayerSpec extends TestkitBaseClass { val sender = TestProbe() // we use this to build a valid onion - val (cmd, _) = buildCommand(UUID.randomUUID(), finalAmountMsat, finalExpiry, paymentHash, hops) + val (cmd, _) = buildCommand(UUID.randomUUID(), paymentHash, hops, LegacyPayload(finalAmountMsat, finalExpiry)) // and then manually build an htlc val add_ab = UpdateAddHtlc(channelId = channelId_ab, id = 123456, cmd.amount, cmd.paymentHash, cmd.cltvExpiry, cmd.onion) relayer ! LocalChannelUpdate(null, channelId_bc, channelUpdate_bc.shortChannelId, c, None, channelUpdate_bc, makeCommitments(channelId_bc)) @@ -88,7 +88,7 @@ class RelayerSpec extends TestkitBaseClass { val (firstAmountMsat, firstExpiry, payloads) = hops.drop(1).reverse.foldLeft((finalAmountMsat, finalExpiry, finalPayload)) { case ((amountMsat, expiry, currentPayloads), hop) => val nextFee = nodeFee(hop.lastUpdate.feeBaseMsat, hop.lastUpdate.feeProportionalMillionths, amountMsat) - val payload: OnionPerHopPayload = TlvStream[OnionTlv](AmountToForward(amountMsat), OutgoingCltv(expiry), OutgoingChannelId(hop.lastUpdate.shortChannelId)) + val payload: OnionPerHopPayload = TlvStream[OnionTlv](AmountToForward(amountMsat), OutgoingCltv(expiry), OutgoingChannelId(hop.lastUpdate.shortChannelId)) (amountMsat + nextFee, expiry + hop.lastUpdate.cltvExpiryDelta, payload +: currentPayloads) } val Sphinx.PacketAndSecrets(onion, _) = buildOnion(hops.map(_.nextNodeId), payloads, paymentHash) @@ -112,7 +112,7 @@ class RelayerSpec extends TestkitBaseClass { val sender = TestProbe() // we use this to build a valid onion - val (cmd, _) = buildCommand(UUID.randomUUID(), finalAmountMsat, finalExpiry, paymentHash, hops) + val (cmd, _) = buildCommand(UUID.randomUUID(), paymentHash, hops, LegacyPayload(finalAmountMsat, finalExpiry)) // and then manually build an htlc val add_ab = UpdateAddHtlc(channelId = channelId_ab, id = 123456, cmd.amount, cmd.paymentHash, cmd.cltvExpiry, cmd.onion) @@ -156,7 +156,7 @@ class RelayerSpec extends TestkitBaseClass { import f._ val sender = TestProbe() - val (cmd, _) = buildCommand(UUID.randomUUID(), finalAmountMsat, finalExpiry, paymentHash, hops.take(1)) + val (cmd, _) = buildCommand(UUID.randomUUID(), paymentHash, hops.take(1), LegacyPayload(finalAmountMsat, finalExpiry)) val add_ab = UpdateAddHtlc(channelId = channelId_ab, id = 123456, cmd.amount, cmd.paymentHash, cmd.cltvExpiry, cmd.onion) sender.send(relayer, ForwardAdd(add_ab)) @@ -172,7 +172,7 @@ class RelayerSpec extends TestkitBaseClass { val sender = TestProbe() // we use this to build a valid onion - val (cmd, _) = buildCommand(UUID.randomUUID(), finalAmountMsat, finalExpiry, paymentHash, hops) + val (cmd, _) = buildCommand(UUID.randomUUID(), paymentHash, hops, LegacyPayload(finalAmountMsat, finalExpiry)) // and then manually build an htlc val add_ab = UpdateAddHtlc(channelId = channelId_ab, id = 123456, cmd.amount, cmd.paymentHash, cmd.cltvExpiry, cmd.onion) @@ -192,7 +192,7 @@ class RelayerSpec extends TestkitBaseClass { val sender = TestProbe() // we use this to build a valid onion - val (cmd, _) = buildCommand(UUID.randomUUID(), finalAmountMsat, finalExpiry, paymentHash, hops) + val (cmd, _) = buildCommand(UUID.randomUUID(), paymentHash, hops, LegacyPayload(finalAmountMsat, finalExpiry)) // and then manually build an htlc val add_ab = UpdateAddHtlc(channelId = channelId_ab, id = 123456, cmd.amount, cmd.paymentHash, cmd.cltvExpiry, cmd.onion) relayer ! LocalChannelUpdate(null, channelId_bc, channelUpdate_bc.shortChannelId, c, None, channelUpdate_bc, makeCommitments(channelId_bc)) @@ -219,7 +219,7 @@ class RelayerSpec extends TestkitBaseClass { val sender = TestProbe() // check that payments are sent properly - val (cmd, _) = buildCommand(UUID.randomUUID(), finalAmountMsat, finalExpiry, paymentHash, hops) + val (cmd, _) = buildCommand(UUID.randomUUID(), paymentHash, hops, LegacyPayload(finalAmountMsat, finalExpiry)) val add_ab = UpdateAddHtlc(channelId = channelId_ab, id = 123456, cmd.amount, cmd.paymentHash, cmd.cltvExpiry, cmd.onion) relayer ! LocalChannelUpdate(null, channelId_bc, channelUpdate_bc.shortChannelId, c, None, channelUpdate_bc, makeCommitments(channelId_bc)) @@ -235,7 +235,7 @@ class RelayerSpec extends TestkitBaseClass { // now tell the relayer that the channel is down and try again relayer ! LocalChannelDown(sender.ref, channelId = channelId_bc, shortChannelId = channelUpdate_bc.shortChannelId, remoteNodeId = TestConstants.Bob.nodeParams.nodeId) - val (cmd1, _) = buildCommand(UUID.randomUUID(), finalAmountMsat, finalExpiry, randomBytes32, hops) + val (cmd1, _) = buildCommand(UUID.randomUUID(), randomBytes32, hops, LegacyPayload(finalAmountMsat, finalExpiry)) val add_ab1 = UpdateAddHtlc(channelId = channelId_ab, id = 123456, cmd1.amount, cmd1.paymentHash, cmd1.cltvExpiry, cmd1.onion) sender.send(relayer, ForwardAdd(add_ab1)) @@ -252,7 +252,7 @@ class RelayerSpec extends TestkitBaseClass { val sender = TestProbe() // we use this to build a valid onion - val (cmd, _) = buildCommand(UUID.randomUUID(), finalAmountMsat, finalExpiry, paymentHash, hops) + val (cmd, _) = buildCommand(UUID.randomUUID(), paymentHash, hops, LegacyPayload(finalAmountMsat, finalExpiry)) // and then manually build an htlc val add_ab = UpdateAddHtlc(channelId = channelId_ab, id = 123456, cmd.amount, cmd.paymentHash, cmd.cltvExpiry, cmd.onion) val channelUpdate_bc_disabled = channelUpdate_bc.copy(channelFlags = Announcements.makeChannelFlags(Announcements.isNode1(channelUpdate_bc.channelFlags), enable = false)) @@ -273,7 +273,7 @@ class RelayerSpec extends TestkitBaseClass { val sender = TestProbe() // we use this to build a valid onion - val (cmd, _) = buildCommand(UUID.randomUUID(), finalAmountMsat, finalExpiry, paymentHash, hops) + val (cmd, _) = buildCommand(UUID.randomUUID(), paymentHash, hops, LegacyPayload(finalAmountMsat, finalExpiry)) // and then manually build an htlc with an invalid onion (hmac) val add_ab = UpdateAddHtlc(channelId = channelId_ab, id = 123456, cmd.amount, cmd.paymentHash, cmd.cltvExpiry, cmd.onion.copy(hmac = cmd.onion.hmac.reverse)) relayer ! LocalChannelUpdate(null, channelId_bc, channelUpdate_bc.shortChannelId, c, None, channelUpdate_bc, makeCommitments(channelId_bc)) @@ -339,7 +339,9 @@ class RelayerSpec extends TestkitBaseClass { val sender = TestProbe() // we use this to build a valid onion - val (cmd, _) = buildCommand(UUID.randomUUID(), channelUpdate_bc.htlcMinimumMsat - (1 msat), finalExpiry, paymentHash, hops.map(hop => hop.copy(lastUpdate = hop.lastUpdate.copy(feeBaseMsat = 0 msat, feeProportionalMillionths = 0)))) + val paymentOptions = LegacyPayload(channelUpdate_bc.htlcMinimumMsat - (1 msat), finalExpiry) + val zeroFeeHops = hops.map(hop => hop.copy(lastUpdate = hop.lastUpdate.copy(feeBaseMsat = 0 msat, feeProportionalMillionths = 0))) + val (cmd, _) = buildCommand(UUID.randomUUID(), paymentHash, zeroFeeHops, paymentOptions) // and then manually build an htlc val add_ab = UpdateAddHtlc(channelId = channelId_ab, id = 123456, cmd.amount, cmd.paymentHash, cmd.cltvExpiry, cmd.onion) relayer ! LocalChannelUpdate(null, channelId_bc, channelUpdate_bc.shortChannelId, c, None, channelUpdate_bc, makeCommitments(channelId_bc)) @@ -359,7 +361,7 @@ class RelayerSpec extends TestkitBaseClass { val sender = TestProbe() val hops1 = hops.updated(1, hops(1).copy(lastUpdate = hops(1).lastUpdate.copy(cltvExpiryDelta = CltvExpiryDelta(0)))) - val (cmd, _) = buildCommand(UUID.randomUUID(), finalAmountMsat, finalExpiry, paymentHash, hops1) + val (cmd, _) = buildCommand(UUID.randomUUID(), paymentHash, hops1, LegacyPayload(finalAmountMsat, finalExpiry)) // and then manually build an htlc val add_ab = UpdateAddHtlc(channelId = channelId_ab, id = 123456, cmd.amount, cmd.paymentHash, cmd.cltvExpiry, cmd.onion) relayer ! LocalChannelUpdate(null, channelId_bc, channelUpdate_bc.shortChannelId, c, None, channelUpdate_bc, makeCommitments(channelId_bc)) @@ -379,7 +381,7 @@ class RelayerSpec extends TestkitBaseClass { val sender = TestProbe() val hops1 = hops.updated(1, hops(1).copy(lastUpdate = hops(1).lastUpdate.copy(feeBaseMsat = hops(1).lastUpdate.feeBaseMsat / 2))) - val (cmd, _) = buildCommand(UUID.randomUUID(), finalAmountMsat, finalExpiry, paymentHash, hops1) + val (cmd, _) = buildCommand(UUID.randomUUID(), paymentHash, hops1, LegacyPayload(finalAmountMsat, finalExpiry)) // and then manually build an htlc val add_ab = UpdateAddHtlc(channelId = channelId_ab, id = 123456, cmd.amount, cmd.paymentHash, cmd.cltvExpiry, cmd.onion) relayer ! LocalChannelUpdate(null, channelId_bc, channelUpdate_bc.shortChannelId, c, None, channelUpdate_bc, makeCommitments(channelId_bc)) @@ -400,7 +402,7 @@ class RelayerSpec extends TestkitBaseClass { // to simulate this we use a zero-hop route A->B where A is the 'attacker' val hops1 = hops.head :: Nil - val (cmd, _) = buildCommand(UUID.randomUUID(), finalAmountMsat, finalExpiry, paymentHash, hops1) + val (cmd, _) = buildCommand(UUID.randomUUID(), paymentHash, hops1, LegacyPayload(finalAmountMsat, finalExpiry)) // and then manually build an htlc with a wrong expiry val add_ab = UpdateAddHtlc(channelId = channelId_ab, id = 123456, cmd.amount - (1 msat), cmd.paymentHash, cmd.cltvExpiry, cmd.onion) relayer ! LocalChannelUpdate(null, channelId_bc, channelUpdate_bc.shortChannelId, c, None, channelUpdate_bc, makeCommitments(channelId_bc)) @@ -421,7 +423,7 @@ class RelayerSpec extends TestkitBaseClass { // to simulate this we use a zero-hop route A->B where A is the 'attacker' val hops1 = hops.head :: Nil - val (cmd, _) = buildCommand(UUID.randomUUID(), finalAmountMsat, finalExpiry, paymentHash, hops1) + val (cmd, _) = buildCommand(UUID.randomUUID(), paymentHash, hops1, LegacyPayload(finalAmountMsat, finalExpiry)) // and then manually build an htlc with a wrong expiry val add_ab = UpdateAddHtlc(channelId = channelId_ab, id = 123456, cmd.amount, cmd.paymentHash, cmd.cltvExpiry - CltvExpiryDelta(1), cmd.onion) relayer ! LocalChannelUpdate(null, channelId_bc, channelUpdate_bc.shortChannelId, c, None, channelUpdate_bc, makeCommitments(channelId_bc)) diff --git a/eclair-node-gui/src/main/scala/fr/acinq/eclair/gui/Handlers.scala b/eclair-node-gui/src/main/scala/fr/acinq/eclair/gui/Handlers.scala index e0bb1daf59..b11aeb5888 100644 --- a/eclair-node-gui/src/main/scala/fr/acinq/eclair/gui/Handlers.scala +++ b/eclair-node-gui/src/main/scala/fr/acinq/eclair/gui/Handlers.scala @@ -20,11 +20,11 @@ import java.util.UUID import akka.pattern.{AskTimeoutException, ask} import akka.util.Timeout -import fr.acinq.eclair.MilliSatoshi -import fr.acinq.eclair._ +import fr.acinq.eclair.{MilliSatoshi, _} import fr.acinq.eclair.gui.controllers._ import fr.acinq.eclair.io.{NodeURI, Peer} -import fr.acinq.eclair.payment.PaymentLifecycle.{PaymentResult, ReceivePayment, SendPayment} +import fr.acinq.eclair.payment.PaymentInitiator.SendPaymentRequest +import fr.acinq.eclair.payment.PaymentLifecycle.ReceivePayment import fr.acinq.eclair.payment._ import grizzled.slf4j.Logging @@ -33,26 +33,23 @@ import scala.concurrent.{ExecutionContext, Future} import scala.util.{Failure, Success} /** - * Created by PM on 16/08/2016. - */ + * Created by PM on 16/08/2016. + */ class Handlers(fKit: Future[Kit])(implicit ec: ExecutionContext = ExecutionContext.Implicits.global) extends Logging { implicit val timeout = Timeout(60 seconds) private var notifsController: Option[NotificationsController] = None - def initNotifications(controller: NotificationsController) = { + def initNotifications(controller: NotificationsController): Unit = { notifsController = Option(controller) } /** - * Opens a connection to a node. If the channel option exists this will also open a channel with the node, with a - * `fundingSatoshis` capacity and `pushMsat` amount. - * - * @param nodeUri - * @param channel - */ - def open(nodeUri: NodeURI, channel: Option[Peer.OpenChannel]) = { + * Opens a connection to a node. If the channel option exists this will also open a channel with the node, with a + * `fundingSatoshis` capacity and `pushMsat` amount. + */ + def open(nodeUri: NodeURI, channel: Option[Peer.OpenChannel]): Unit = { logger.info(s"opening a connection to nodeUri=$nodeUri") (for { kit <- fKit @@ -88,8 +85,8 @@ class Handlers(fKit: Future[Kit])(implicit ec: ExecutionContext = ExecutionConte (for { kit <- fKit sendPayment = req.minFinalCltvExpiryDelta match { - case None => SendPayment(MilliSatoshi(amountMsat), req.paymentHash, req.nodeId, req.routingInfo, maxAttempts = kit.nodeParams.maxPaymentAttempts) - case Some(minFinalCltvExpiry) => SendPayment(MilliSatoshi(amountMsat), req.paymentHash, req.nodeId, req.routingInfo, finalCltvExpiryDelta = minFinalCltvExpiry, maxAttempts = kit.nodeParams.maxPaymentAttempts) + case None => SendPaymentRequest(MilliSatoshi(amountMsat), req.paymentHash, req.nodeId, kit.nodeParams.maxPaymentAttempts, assistedRoutes = req.routingInfo) + case Some(minFinalCltvExpiry) => SendPaymentRequest(MilliSatoshi(amountMsat), req.paymentHash, req.nodeId, kit.nodeParams.maxPaymentAttempts, assistedRoutes = req.routingInfo, finalExpiryDelta = minFinalCltvExpiry) } res <- (kit.paymentInitiator ? sendPayment).mapTo[UUID] } yield res).recover { @@ -108,14 +105,14 @@ class Handlers(fKit: Future[Kit])(implicit ec: ExecutionContext = ExecutionConte } yield res /** - * Displays a system notification if the system supports it. - * - * @param title Title of the notification - * @param message main message of the notification, will not wrap - * @param notificationType type of the message, default to NONE - * @param showAppName true if you want the notification title to be preceded by "Eclair - ". True by default - */ - def notification(title: String, message: String, notificationType: NotificationType = NOTIFICATION_NONE, showAppName: Boolean = true) = { + * Displays a system notification if the system supports it. + * + * @param title Title of the notification + * @param message main message of the notification, will not wrap + * @param notificationType type of the message, default to NONE + * @param showAppName true if you want the notification title to be preceded by "Eclair - ". True by default + */ + def notification(title: String, message: String, notificationType: NotificationType = NOTIFICATION_NONE, showAppName: Boolean = true): Unit = { notifsController.foreach(_.addNotification(if (showAppName) s"Eclair - $title" else title, message, notificationType)) } } From fae550e5176a85e3dc28d869d66994b74ba6a3b3 Mon Sep 17 00:00:00 2001 From: Bastien Teinturier Date: Tue, 3 Sep 2019 10:50:20 +0200 Subject: [PATCH 06/11] Implement InvalidOnionPayload final spec decision. We don't provide the offset hint though. --- .../fr/acinq/eclair/payment/Relayer.scala | 21 +++++----- .../fr/acinq/eclair/wire/FailureMessage.scala | 8 ++-- .../scala/fr/acinq/eclair/wire/Onion.scala | 40 +++++++++++++------ .../fr/acinq/eclair/crypto/SphinxSpec.scala | 4 +- .../eclair/payment/HtlcGenerationSpec.scala | 2 +- .../fr/acinq/eclair/payment/RelayerSpec.scala | 18 ++++----- .../wire/FailureMessageCodecsSpec.scala | 4 +- .../acinq/eclair/wire/OnionCodecsSpec.scala | 24 +++++------ 8 files changed, 68 insertions(+), 53 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Relayer.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Relayer.scala index 29c6a5598b..e912bb48f8 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Relayer.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Relayer.scala @@ -28,7 +28,7 @@ import fr.acinq.eclair.db.OutgoingPaymentStatus import fr.acinq.eclair.payment.PaymentLifecycle.{PaymentFailed, PaymentSucceeded} import fr.acinq.eclair.router.Announcements import fr.acinq.eclair.wire._ -import fr.acinq.eclair.{CltvExpiryDelta, Features, LongToBtcAmount, MilliSatoshi, NodeParams, ShortChannelId, nodeFee} +import fr.acinq.eclair.{CltvExpiryDelta, Features, LongToBtcAmount, MilliSatoshi, NodeParams, ShortChannelId, UInt64, nodeFee} import grizzled.slf4j.Logging import scodec.bits.ByteVector import scodec.{Attempt, DecodeResult} @@ -245,18 +245,19 @@ object Relayer extends Logging { } if (p.isLastPacket) { perHopPayload.paymentInfo match { - case Some(_) => Right(FinalPayload(add, perHopPayload)) - case None => Left(InvalidOnionPayload(Sphinx.PaymentPacket.hash(add.onionRoutingPacket))) + case Right(_) => Right(FinalPayload(add, perHopPayload)) + case Left(err) => Left(err) } } else { perHopPayload.forwardInfo match { - case Some(forwardInfo) => Right(RelayPayload(add, forwardInfo, nextPacket)) - case None => Left(InvalidOnionPayload(Sphinx.PaymentPacket.hash(add.onionRoutingPacket))) + case Right(forwardInfo) => Right(RelayPayload(add, forwardInfo, nextPacket)) + case Left(err) => Left(err) } } case Attempt.Failure(_) => // Onion is correctly encrypted but the content of the per-hop payload couldn't be parsed. - Left(InvalidOnionPayload(Sphinx.PaymentPacket.hash(add.onionRoutingPacket))) + // It's hard to provide tag and offset information from scodec failures, so we currently don't do it. + Left(InvalidOnionPayload(UInt64(0), 0)) } case Left(badOnion) => Left(badOnion) } @@ -272,12 +273,12 @@ object Relayer extends Logging { def handleFinal(finalPayload: FinalPayload): Either[CMD_FAIL_HTLC, UpdateAddHtlc] = { import finalPayload.add finalPayload.payload.paymentInfo match { - case Some(OnionPaymentInfo(amountMsat, _)) if amountMsat > add.amountMsat => + case Right(OnionPaymentInfo(amountMsat, _)) if amountMsat > add.amountMsat => Left(CMD_FAIL_HTLC(add.id, Right(FinalIncorrectHtlcAmount(add.amountMsat)), commit = true)) - case Some(OnionPaymentInfo(_, cltvExpiry)) if cltvExpiry != add.cltvExpiry => + case Right(OnionPaymentInfo(_, cltvExpiry)) if cltvExpiry != add.cltvExpiry => Left(CMD_FAIL_HTLC(add.id, Right(FinalIncorrectCltvExpiry(add.cltvExpiry)), commit = true)) - case None => - Left(CMD_FAIL_HTLC(add.id, Right(InvalidOnionPayload(Sphinx.PaymentPacket.hash(add.onionRoutingPacket))), commit = true)) + case Left(err) => + Left(CMD_FAIL_HTLC(add.id, Right(err), commit = true)) case _ => Right(add) } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/FailureMessage.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/FailureMessage.scala index 0998cbf953..442545bfa3 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/FailureMessage.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/FailureMessage.scala @@ -18,10 +18,10 @@ package fr.acinq.eclair.wire import fr.acinq.bitcoin.ByteVector32 import fr.acinq.eclair.crypto.Mac32 -import fr.acinq.eclair.wire.CommonCodecs.{cltvExpiry, discriminatorWithDefault, millisatoshi, sha256} +import fr.acinq.eclair.wire.CommonCodecs._ import fr.acinq.eclair.wire.FailureMessageCodecs.failureMessageCodec import fr.acinq.eclair.wire.LightningMessageCodecs.{channelUpdateCodec, lightningMessageCodec} -import fr.acinq.eclair.{CltvExpiry, LongToBtcAmount, MilliSatoshi} +import fr.acinq.eclair.{CltvExpiry, LongToBtcAmount, MilliSatoshi, UInt64} import scodec.codecs._ import scodec.{Attempt, Codec} @@ -62,7 +62,7 @@ case class ExpiryTooSoon(update: ChannelUpdate) extends Update { def message = " case class FinalIncorrectCltvExpiry(expiry: CltvExpiry) extends FailureMessage { def message = "payment expiry doesn't match the value in the onion" } case class FinalIncorrectHtlcAmount(amount: MilliSatoshi) extends FailureMessage { def message = "payment amount is incorrect in the final htlc" } case object ExpiryTooFar extends FailureMessage { def message = "payment expiry is too far in the future" } -case class InvalidOnionPayload(onionHash: ByteVector32) extends Perm { def message = "onion per-hop payload is invalid" } +case class InvalidOnionPayload(tag: UInt64, offset: Int) extends Perm { def message = "onion per-hop payload is invalid" } /** * We allow remote nodes to send us unknown failure codes (e.g. deprecated failure codes). @@ -115,7 +115,7 @@ object FailureMessageCodecs { .typecase(18, ("expiry" | cltvExpiry).as[FinalIncorrectCltvExpiry]) .typecase(19, ("amountMsat" | millisatoshi).as[FinalIncorrectHtlcAmount]) .typecase(21, provide(ExpiryTooFar)) - .typecase(PERM, sha256.as[InvalidOnionPayload]), + .typecase(PERM | 22, (("tag" | varint) :: ("offset" | uint16)).as[InvalidOnionPayload]), uint16.xmap(code => { val failureMessage = code match { // @formatter:off diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/Onion.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/Onion.scala index 7fad290594..fe84b366e8 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/Onion.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/Onion.scala @@ -48,21 +48,35 @@ case class OnionPaymentInfo(amount: MilliSatoshi, cltvExpiry: CltvExpiry) case class OnionPerHopPayload(payload: Either[TlvStream[OnionTlv], OnionForwardInfo]) { - lazy val paymentInfo: Option[OnionPaymentInfo] = payload match { - case Right(OnionForwardInfo(_, amount, cltv)) => Some(OnionPaymentInfo(amount, cltv)) - case Left(tlv) => for { - amount <- tlv.get[AmountToForward].map(_.amount) - cltv <- tlv.get[OutgoingCltv].map(_.cltv) - } yield OnionPaymentInfo(amount, cltv) + lazy val paymentInfo: Either[InvalidOnionPayload, OnionPaymentInfo] = payload match { + case Right(OnionForwardInfo(_, amount, cltv)) => Right(OnionPaymentInfo(amount, cltv)) + case Left(tlv) => + val amount = tlv.get[AmountToForward].map(_.amount) + val cltv = tlv.get[OutgoingCltv].map(_.cltv) + if (amount.isEmpty) { + Left(InvalidOnionPayload(UInt64(2), 0)) + } else if (cltv.isEmpty) { + Left(InvalidOnionPayload(UInt64(4), 0)) + } else { + Right(OnionPaymentInfo(amount.get, cltv.get)) + } } - lazy val forwardInfo: Option[OnionForwardInfo] = payload match { - case Right(onionForwardInfo) => Some(onionForwardInfo) - case Left(tlv) => for { - shortChannelId <- tlv.get[OutgoingChannelId].map(_.shortChannelId) - amount <- tlv.get[AmountToForward].map(_.amount) - cltv <- tlv.get[OutgoingCltv].map(_.cltv) - } yield OnionForwardInfo(shortChannelId, amount, cltv) + lazy val forwardInfo: Either[InvalidOnionPayload, OnionForwardInfo] = payload match { + case Right(onionForwardInfo) => Right(onionForwardInfo) + case Left(tlv) => + val shortChannelId = tlv.get[OutgoingChannelId].map(_.shortChannelId) + val amount = tlv.get[AmountToForward].map(_.amount) + val cltv = tlv.get[OutgoingCltv].map(_.cltv) + if (amount.isEmpty) { + Left(InvalidOnionPayload(UInt64(2), 0)) + } else if (cltv.isEmpty) { + Left(InvalidOnionPayload(UInt64(4), 0)) + } else if (shortChannelId.isEmpty) { + Left(InvalidOnionPayload(UInt64(6), 0)) + } else { + Right(OnionForwardInfo(shortChannelId.get, amount.get, cltv.get)) + } } } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/crypto/SphinxSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/crypto/SphinxSpec.scala index 59a622b48b..1dbdecffbe 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/crypto/SphinxSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/crypto/SphinxSpec.scala @@ -18,7 +18,7 @@ package fr.acinq.eclair.crypto import fr.acinq.bitcoin.ByteVector32 import fr.acinq.bitcoin.Crypto.{PrivateKey, PublicKey} -import fr.acinq.eclair.wire +import fr.acinq.eclair.{UInt64, wire} import fr.acinq.eclair.wire._ import org.scalatest.FunSuite import scodec.bits._ @@ -249,7 +249,7 @@ class SphinxSpec extends FunSuite { val packet = FailurePacket.wrap( FailurePacket.wrap( - FailurePacket.create(sharedSecrets.head, InvalidOnionPayload(ByteVector32.Zeroes)), + FailurePacket.create(sharedSecrets.head, InvalidOnionPayload(UInt64(0), 0)), sharedSecrets(1)), sharedSecrets(2)) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/HtlcGenerationSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/HtlcGenerationSpec.scala index 22ff0ca88d..d460876180 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/HtlcGenerationSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/HtlcGenerationSpec.scala @@ -123,7 +123,7 @@ class HtlcGenerationSpec extends FunSuite with BeforeAndAfterAll { val payload_e = OnionCodecs.tlvPerHopPayloadCodec.decode(bin_e.toBitVector).require.value val paymentInfo = OnionPerHopPayload(Left(payload_e)).paymentInfo assert(packet_random.payload.length === Sphinx.PaymentPacket.PayloadLength) - assert(paymentInfo === Some(OnionPaymentInfo(finalAmountMsat, finalExpiry))) + assert(paymentInfo === Right(OnionPaymentInfo(finalAmountMsat, finalExpiry))) } test("build a command including the onion") { diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/RelayerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/RelayerSpec.scala index 6cc7051b1d..cb89981103 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/RelayerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/RelayerSpec.scala @@ -295,20 +295,20 @@ class RelayerSpec extends TestkitBaseClass { // B is not the last hop and receives an onion missing some routing information. val invalidPayloads_bc = Seq( - TlvStream[OnionTlv](OutgoingChannelId(channelUpdate_bc.shortChannelId), AmountToForward(amount_bc)), // Missing cltv expiry. - TlvStream[OnionTlv](OutgoingChannelId(channelUpdate_bc.shortChannelId), OutgoingCltv(expiry_bc)), // Missing forwarding amount. - TlvStream[OnionTlv](AmountToForward(amount_bc), OutgoingCltv(expiry_bc))) // Missing channel id. + (InvalidOnionPayload(UInt64(2), 0), TlvStream[OnionTlv](OutgoingChannelId(channelUpdate_bc.shortChannelId), OutgoingCltv(expiry_bc))), // Missing forwarding amount. + (InvalidOnionPayload(UInt64(4), 0), TlvStream[OnionTlv](OutgoingChannelId(channelUpdate_bc.shortChannelId), AmountToForward(amount_bc))), // Missing cltv expiry. + (InvalidOnionPayload(UInt64(6), 0), TlvStream[OnionTlv](AmountToForward(amount_bc), OutgoingCltv(expiry_bc)))) // Missing channel id. val payload_cd = TlvStream[OnionTlv](OutgoingChannelId(channelUpdate_cd.shortChannelId), AmountToForward(amount_cd), OutgoingCltv(expiry_cd)) val sender = TestProbe() relayer ! LocalChannelUpdate(null, channelId_bc, channelUpdate_bc.shortChannelId, c, None, channelUpdate_bc, makeCommitments(channelId_bc)) - for (invalidPayload_bc <- invalidPayloads_bc) { + for ((expectedErr, invalidPayload_bc) <- invalidPayloads_bc) { val Sphinx.PacketAndSecrets(onion, _) = buildOnion(Seq(b, c), Seq(invalidPayload_bc, payload_cd), paymentHash) val add_ab = UpdateAddHtlc(channelId_ab, 123456, amount_ab, paymentHash, expiry_ab, onion) sender.send(relayer, ForwardAdd(add_ab)) - register.expectMsg(Register.Forward(channelId_ab, CMD_FAIL_HTLC(add_ab.id, Right(InvalidOnionPayload(Sphinx.PaymentPacket.hash(add_ab.onionRoutingPacket))), commit = true))) + register.expectMsg(Register.Forward(channelId_ab, CMD_FAIL_HTLC(add_ab.id, Right(expectedErr), commit = true))) register.expectNoMsg(100 millis) paymentHandler.expectNoMsg(100 millis) } @@ -444,17 +444,17 @@ class RelayerSpec extends TestkitBaseClass { // B is the last hop and receives an onion missing some payment information. val invalidFinalPayloads = Seq( - TlvStream[OnionTlv](AmountToForward(amount_bc)), // Missing cltv expiry. - TlvStream[OnionTlv](OutgoingCltv(expiry_bc))) // Missing forwarding amount. + (InvalidOnionPayload(UInt64(2), 0), TlvStream[OnionTlv](OutgoingCltv(expiry_bc))), // Missing forwarding amount. + (InvalidOnionPayload(UInt64(4), 0), TlvStream[OnionTlv](AmountToForward(amount_bc)))) // Missing cltv expiry. val sender = TestProbe() - for (invalidFinalPayload <- invalidFinalPayloads) { + for ((expectedErr, invalidFinalPayload) <- invalidFinalPayloads) { val Sphinx.PacketAndSecrets(onion, _) = buildOnion(Seq(b), Seq(invalidFinalPayload), paymentHash) val add_ab = UpdateAddHtlc(channelId_ab, 123456, amount_ab, paymentHash, expiry_ab, onion) sender.send(relayer, ForwardAdd(add_ab)) - register.expectMsg(Register.Forward(channelId_ab, CMD_FAIL_HTLC(add_ab.id, Right(InvalidOnionPayload(Sphinx.PaymentPacket.hash(add_ab.onionRoutingPacket))), commit = true))) + register.expectMsg(Register.Forward(channelId_ab, CMD_FAIL_HTLC(add_ab.id, Right(expectedErr), commit = true))) register.expectNoMsg(100 millis) paymentHandler.expectNoMsg(100 millis) } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/FailureMessageCodecsSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/FailureMessageCodecsSpec.scala index 687a054fe8..3925b55044 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/wire/FailureMessageCodecsSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/wire/FailureMessageCodecsSpec.scala @@ -19,7 +19,7 @@ package fr.acinq.eclair.wire import fr.acinq.bitcoin.{Block, ByteVector32, ByteVector64} import fr.acinq.eclair.crypto.Hmac256 import fr.acinq.eclair.wire.FailureMessageCodecs._ -import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, LongToBtcAmount, MilliSatoshi, ShortChannelId, randomBytes32, randomBytes64} +import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, LongToBtcAmount, MilliSatoshi, ShortChannelId, UInt64, randomBytes32, randomBytes64} import org.scalatest.FunSuite import scodec.bits._ @@ -47,7 +47,7 @@ class FailureMessageCodecsSpec extends FunSuite { InvalidOnionVersion(randomBytes32) :: InvalidOnionHmac(randomBytes32) :: InvalidOnionKey(randomBytes32) :: TemporaryChannelFailure(channelUpdate) :: PermanentChannelFailure :: RequiredChannelFeatureMissing :: UnknownNextPeer :: AmountBelowMinimum(123456 msat, channelUpdate) :: FeeInsufficient(546463 msat, channelUpdate) :: IncorrectCltvExpiry(CltvExpiry(1211), channelUpdate) :: ExpiryTooSoon(channelUpdate) :: - IncorrectOrUnknownPaymentDetails(123456 msat, 1105) :: FinalIncorrectCltvExpiry(CltvExpiry(1234)) :: ChannelDisabled(0, 1, channelUpdate) :: ExpiryTooFar :: InvalidOnionPayload(randomBytes32) :: Nil + IncorrectOrUnknownPaymentDetails(123456 msat, 1105) :: FinalIncorrectCltvExpiry(CltvExpiry(1234)) :: ChannelDisabled(0, 1, channelUpdate) :: ExpiryTooFar :: InvalidOnionPayload(UInt64(561), 1105) :: Nil msgs.foreach { msg => { diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/OnionCodecsSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/OnionCodecsSpec.scala index 6b0703993e..be9bfff386 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/wire/OnionCodecsSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/wire/OnionCodecsSpec.scala @@ -21,7 +21,7 @@ import fr.acinq.eclair.UInt64.Conversions._ import fr.acinq.eclair.wire.OnionCodecs._ import fr.acinq.eclair.wire.OnionPerHopPayload._ import fr.acinq.eclair.wire.OnionTlv._ -import fr.acinq.eclair.{CltvExpiry, LongToBtcAmount, ShortChannelId} +import fr.acinq.eclair.{CltvExpiry, LongToBtcAmount, ShortChannelId, UInt64} import org.scalatest.FunSuite import scodec.bits.HexStringSyntax @@ -110,39 +110,39 @@ class OnionCodecsSpec extends FunSuite { test("get payment info") { val legacyPayload: OnionPerHopPayload = OnionForwardInfo(ShortChannelId(550), 561 msat, CltvExpiry(1105)) - assert(legacyPayload.paymentInfo === Some(OnionPaymentInfo(561 msat, CltvExpiry(1105)))) + assert(legacyPayload.paymentInfo === Right(OnionPaymentInfo(561 msat, CltvExpiry(1105)))) val tlvPayload: OnionPerHopPayload = TlvStream[OnionTlv](AmountToForward(561 msat), OutgoingCltv(CltvExpiry(1105))) - assert(tlvPayload.paymentInfo === Some(OnionPaymentInfo(561 msat, CltvExpiry(1105)))) + assert(tlvPayload.paymentInfo === Right(OnionPaymentInfo(561 msat, CltvExpiry(1105)))) val tlvPayloadUnknown: OnionPerHopPayload = TlvStream[OnionTlv](Seq(AmountToForward(561 msat), OutgoingCltv(CltvExpiry(1105))), Seq(GenericTlv(13, hex"2a"))) - assert(tlvPayloadUnknown.paymentInfo === Some(OnionPaymentInfo(561 msat, CltvExpiry(1105)))) + assert(tlvPayloadUnknown.paymentInfo === Right(OnionPaymentInfo(561 msat, CltvExpiry(1105)))) val tlvPayloadNoCltv: OnionPerHopPayload = TlvStream[OnionTlv](AmountToForward(561 msat)) - assert(tlvPayloadNoCltv.paymentInfo === None) + assert(tlvPayloadNoCltv.paymentInfo === Left(InvalidOnionPayload(UInt64(4), 0))) val tlvPayloadNoAmount: OnionPerHopPayload = TlvStream[OnionTlv](OutgoingCltv(CltvExpiry(1105))) - assert(tlvPayloadNoAmount.paymentInfo === None) + assert(tlvPayloadNoAmount.paymentInfo === Left(InvalidOnionPayload(UInt64(2), 0))) } test("get forward info") { val legacyPayload: OnionPerHopPayload = OnionForwardInfo(ShortChannelId(550), 561 msat, CltvExpiry(1105)) - assert(legacyPayload.forwardInfo === Some(OnionForwardInfo(ShortChannelId(550), 561 msat, CltvExpiry(1105)))) + assert(legacyPayload.forwardInfo === Right(OnionForwardInfo(ShortChannelId(550), 561 msat, CltvExpiry(1105)))) val tlvPayload: OnionPerHopPayload = TlvStream[OnionTlv](AmountToForward(561 msat), OutgoingCltv(CltvExpiry(1105)), OutgoingChannelId(ShortChannelId(550))) - assert(tlvPayload.forwardInfo === Some(OnionForwardInfo(ShortChannelId(550), 561 msat, CltvExpiry(1105)))) + assert(tlvPayload.forwardInfo === Right(OnionForwardInfo(ShortChannelId(550), 561 msat, CltvExpiry(1105)))) val tlvPayloadUnknown: OnionPerHopPayload = TlvStream[OnionTlv](Seq(AmountToForward(561 msat), OutgoingCltv(CltvExpiry(1105)), OutgoingChannelId(ShortChannelId(550))), Seq(GenericTlv(13, hex"2a"))) - assert(tlvPayloadUnknown.forwardInfo === Some(OnionForwardInfo(ShortChannelId(550), 561 msat, CltvExpiry(1105)))) + assert(tlvPayloadUnknown.forwardInfo === Right(OnionForwardInfo(ShortChannelId(550), 561 msat, CltvExpiry(1105)))) val tlvPayloadNoCltv: OnionPerHopPayload = TlvStream[OnionTlv](AmountToForward(561 msat), OutgoingChannelId(ShortChannelId(550))) - assert(tlvPayloadNoCltv.forwardInfo === None) + assert(tlvPayloadNoCltv.forwardInfo === Left(InvalidOnionPayload(UInt64(4), 0))) val tlvPayloadNoAmount: OnionPerHopPayload = TlvStream[OnionTlv](OutgoingCltv(CltvExpiry(1105)), OutgoingChannelId(ShortChannelId(550))) - assert(tlvPayloadNoAmount.forwardInfo === None) + assert(tlvPayloadNoAmount.forwardInfo === Left(InvalidOnionPayload(UInt64(2), 0))) val tlvPayloadNoChannelId: OnionPerHopPayload = TlvStream[OnionTlv](AmountToForward(561 msat), OutgoingCltv(CltvExpiry(1105))) - assert(tlvPayloadNoChannelId.forwardInfo === None) + assert(tlvPayloadNoChannelId.forwardInfo === Left(InvalidOnionPayload(UInt64(6), 0))) } } From 3783d979fdeea6b3b5fbf54e6a854b6b5a6cd048 Mon Sep 17 00:00:00 2001 From: Bastien Teinturier Date: Tue, 3 Sep 2019 17:08:41 +0200 Subject: [PATCH 07/11] Add 1 block to payment expiry to account for potential new block --- .../scala/fr/acinq/eclair/payment/PaymentInitiator.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentInitiator.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentInitiator.scala index 94f9fa566d..04e516918b 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentInitiator.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentInitiator.scala @@ -35,10 +35,12 @@ class PaymentInitiator(nodeParams: NodeParams, router: ActorRef, register: Actor override def receive: Receive = { case p: PaymentInitiator.SendPaymentRequest => val paymentId = UUID.randomUUID() + // We add one block in order to not have our htlc fail when a new block has just been found. + val finalExpiry = (p.finalExpiryDelta + 1).toCltvExpiry val payFsm = context.actorOf(PaymentLifecycle.props(nodeParams, paymentId, router, register)) p.predefinedRoute match { - case Nil => payFsm forward SendPayment(p.paymentHash, p.targetNodeId, LegacyPayload(p.amount, p.finalExpiryDelta.toCltvExpiry), p.maxAttempts, p.assistedRoutes, p.routeParams) - case hops => payFsm forward SendPaymentToRoute(p.paymentHash, hops, LegacyPayload(p.amount, p.finalExpiryDelta.toCltvExpiry)) + case Nil => payFsm forward SendPayment(p.paymentHash, p.targetNodeId, LegacyPayload(p.amount, finalExpiry), p.maxAttempts, p.assistedRoutes, p.routeParams) + case hops => payFsm forward SendPaymentToRoute(p.paymentHash, hops, LegacyPayload(p.amount, finalExpiry)) } sender ! paymentId } From 589690ba46f41901ea3753a8eec43689faf4b2b9 Mon Sep 17 00:00:00 2001 From: Bastien Teinturier Date: Tue, 3 Sep 2019 17:20:03 +0200 Subject: [PATCH 08/11] Clarify that we don't interpret variable_length_onion_mandatory strictly --- eclair-core/src/main/scala/fr/acinq/eclair/Features.scala | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/Features.scala b/eclair-core/src/main/scala/fr/acinq/eclair/Features.scala index bc43027728..12e9997955 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/Features.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/Features.scala @@ -42,6 +42,11 @@ object Features { def hasFeature(features: ByteVector, bit: Int): Boolean = hasFeature(features.bits, bit) + /** + * We currently don't distinguish mandatory and optional. Interpreting VARIABLE_LENGTH_ONION_MANDATORY strictly would + * be very restrictive and probably fork us out of the network. + * We may implement this distinction later, but for now both flags are interpreted as an optional support. + */ def hasVariableLengthOnion(features: ByteVector): Boolean = hasFeature(features, VARIABLE_LENGTH_ONION_MANDATORY) || hasFeature(features, VARIABLE_LENGTH_ONION_OPTIONAL) /** From 9c81ef5e8e2a3bd3a389fd5d767345ad444bce93 Mon Sep 17 00:00:00 2001 From: Bastien Teinturier Date: Wed, 4 Sep 2019 14:27:54 +0200 Subject: [PATCH 09/11] Change type architecture for onion per-hop payload. Explicitly expand the matrix of possible types (relay/final, legacy/tlv). --- .../eclair/payment/PaymentInitiator.scala | 7 +- .../eclair/payment/PaymentLifecycle.scala | 69 +++----- .../fr/acinq/eclair/payment/Relayer.scala | 82 ++++----- .../scala/fr/acinq/eclair/wire/Onion.scala | 164 ++++++++++-------- .../fr/acinq/eclair/channel/FuzzySpec.scala | 7 +- .../states/StateTestsHelperMethods.scala | 4 +- .../channel/states/f/ShutdownStateSpec.scala | 14 +- .../eclair/payment/ChannelSelectionSpec.scala | 40 ++--- .../eclair/payment/HtlcGenerationSpec.scala | 118 +++++-------- .../eclair/payment/PaymentLifecycleSpec.scala | 31 ++-- .../fr/acinq/eclair/payment/RelayerSpec.scala | 66 ++++--- .../acinq/eclair/wire/OnionCodecsSpec.scala | 131 ++++++++------ 12 files changed, 374 insertions(+), 359 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentInitiator.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentInitiator.scala index 04e516918b..13035aa3ef 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentInitiator.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentInitiator.scala @@ -22,9 +22,10 @@ import akka.actor.{Actor, ActorLogging, ActorRef, Props} import fr.acinq.bitcoin.ByteVector32 import fr.acinq.bitcoin.Crypto.PublicKey import fr.acinq.eclair.channel.Channel -import fr.acinq.eclair.payment.PaymentLifecycle.{LegacyPayload, SendPayment, SendPaymentToRoute} +import fr.acinq.eclair.payment.PaymentLifecycle.{SendPayment, SendPaymentToRoute} import fr.acinq.eclair.payment.PaymentRequest.ExtraHop import fr.acinq.eclair.router.RouteParams +import fr.acinq.eclair.wire.Onion.FinalLegacyPayload import fr.acinq.eclair.{CltvExpiryDelta, MilliSatoshi, NodeParams} /** @@ -39,8 +40,8 @@ class PaymentInitiator(nodeParams: NodeParams, router: ActorRef, register: Actor val finalExpiry = (p.finalExpiryDelta + 1).toCltvExpiry val payFsm = context.actorOf(PaymentLifecycle.props(nodeParams, paymentId, router, register)) p.predefinedRoute match { - case Nil => payFsm forward SendPayment(p.paymentHash, p.targetNodeId, LegacyPayload(p.amount, finalExpiry), p.maxAttempts, p.assistedRoutes, p.routeParams) - case hops => payFsm forward SendPaymentToRoute(p.paymentHash, hops, LegacyPayload(p.amount, finalExpiry)) + case Nil => payFsm forward SendPayment(p.paymentHash, p.targetNodeId, FinalLegacyPayload(p.amount, finalExpiry), p.maxAttempts, p.assistedRoutes, p.routeParams) + case hops => payFsm forward SendPaymentToRoute(p.paymentHash, hops, FinalLegacyPayload(p.amount, finalExpiry)) } sender ! paymentId } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentLifecycle.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentLifecycle.scala index c72b9f8213..0de607d64f 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentLifecycle.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentLifecycle.scala @@ -28,7 +28,7 @@ import fr.acinq.eclair.db.{OutgoingPayment, OutgoingPaymentStatus} import fr.acinq.eclair.payment.PaymentLifecycle._ import fr.acinq.eclair.payment.PaymentRequest.ExtraHop import fr.acinq.eclair.router._ -import fr.acinq.eclair.wire.OnionPerHopPayload._ +import fr.acinq.eclair.wire.Onion._ import fr.acinq.eclair.wire._ import scodec.Attempt import scodec.bits.ByteVector @@ -47,14 +47,14 @@ class PaymentLifecycle(nodeParams: NodeParams, id: UUID, router: ActorRef, regis when(WAITING_FOR_REQUEST) { case Event(c: SendPaymentToRoute, WaitingForRequest) => - val send = SendPayment(c.paymentHash, c.hops.last, c.paymentOptions, maxAttempts = 1) - paymentsDb.addOutgoingPayment(OutgoingPayment(id, c.paymentHash, None, c.paymentOptions.finalAmount, Platform.currentTime, None, OutgoingPaymentStatus.PENDING)) + val send = SendPayment(c.paymentHash, c.hops.last, c.finalPayload, maxAttempts = 1) + paymentsDb.addOutgoingPayment(OutgoingPayment(id, c.paymentHash, None, c.finalPayload.amount, Platform.currentTime, None, OutgoingPaymentStatus.PENDING)) router ! FinalizeRoute(c.hops) goto(WAITING_FOR_ROUTE) using WaitingForRoute(sender, send, failures = Nil) case Event(c: SendPayment, WaitingForRequest) => - router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.paymentOptions.finalAmount, c.assistedRoutes, routeParams = c.routeParams) - paymentsDb.addOutgoingPayment(OutgoingPayment(id, c.paymentHash, None, c.paymentOptions.finalAmount, Platform.currentTime, None, OutgoingPaymentStatus.PENDING)) + router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.finalPayload.amount, c.assistedRoutes, routeParams = c.routeParams) + paymentsDb.addOutgoingPayment(OutgoingPayment(id, c.paymentHash, None, c.finalPayload.amount, Platform.currentTime, None, OutgoingPaymentStatus.PENDING)) goto(WAITING_FOR_ROUTE) using WaitingForRoute(sender, c, failures = Nil) } @@ -62,7 +62,7 @@ class PaymentLifecycle(nodeParams: NodeParams, id: UUID, router: ActorRef, regis case Event(RouteResponse(hops, ignoreNodes, ignoreChannels), WaitingForRoute(s, c, failures)) => log.info(s"route found: attempt=${failures.size + 1}/${c.maxAttempts} route=${hops.map(_.nextNodeId).mkString("->")} channels=${hops.map(_.lastUpdate.shortChannelId).mkString("->")}") val firstHop = hops.head - val (cmd, sharedSecrets) = buildCommand(id, c.paymentHash, hops, c.paymentOptions) + val (cmd, sharedSecrets) = buildCommand(id, c.paymentHash, hops, c.finalPayload) register ! Register.ForwardShortId(firstHop.lastUpdate.shortChannelId, cmd) goto(WAITING_FOR_PAYMENT_COMPLETE) using WaitingForComplete(s, c, cmd, failures, sharedSecrets, ignoreNodes, ignoreChannels, hops) @@ -78,7 +78,7 @@ class PaymentLifecycle(nodeParams: NodeParams, id: UUID, router: ActorRef, regis case Event(fulfill: UpdateFulfillHtlc, WaitingForComplete(s, c, cmd, _, _, _, _, hops)) => paymentsDb.updateOutgoingPayment(id, OutgoingPaymentStatus.SUCCEEDED, preimage = Some(fulfill.paymentPreimage)) reply(s, PaymentSucceeded(id, cmd.amount, c.paymentHash, fulfill.paymentPreimage, hops)) - context.system.eventStream.publish(PaymentSent(id, c.paymentOptions.finalAmount, cmd.amount - c.paymentOptions.finalAmount, cmd.paymentHash, fulfill.paymentPreimage, fulfill.channelId)) + context.system.eventStream.publish(PaymentSent(id, c.finalPayload.amount, cmd.amount - c.finalPayload.amount, cmd.paymentHash, fulfill.paymentPreimage, fulfill.channelId)) stop(FSM.Normal) case Event(fail: UpdateFailHtlc, WaitingForComplete(s, c, _, failures, sharedSecrets, ignoreNodes, ignoreChannels, hops)) => @@ -108,12 +108,12 @@ class PaymentLifecycle(nodeParams: NodeParams, id: UUID, router: ActorRef, regis // in that case we don't know which node is sending garbage, let's try to blacklist all nodes except the one we are directly connected to and the destination node val blacklist = hops.map(_.nextNodeId).drop(1).dropRight(1) log.warning(s"blacklisting intermediate nodes=${blacklist.mkString(",")}") - router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.paymentOptions.finalAmount, c.assistedRoutes, ignoreNodes ++ blacklist, ignoreChannels, c.routeParams) + router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.finalPayload.amount, c.assistedRoutes, ignoreNodes ++ blacklist, ignoreChannels, c.routeParams) goto(WAITING_FOR_ROUTE) using WaitingForRoute(s, c, failures :+ UnreadableRemoteFailure(hops)) case Success(e@Sphinx.DecryptedFailurePacket(nodeId, failureMessage: Node)) => log.info(s"received 'Node' type error message from nodeId=$nodeId, trying to route around it (failure=$failureMessage)") // let's try to route around this node - router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.paymentOptions.finalAmount, c.assistedRoutes, ignoreNodes + nodeId, ignoreChannels, c.routeParams) + router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.finalPayload.amount, c.assistedRoutes, ignoreNodes + nodeId, ignoreChannels, c.routeParams) goto(WAITING_FOR_ROUTE) using WaitingForRoute(s, c, failures :+ RemoteFailure(hops, e)) case Success(e@Sphinx.DecryptedFailurePacket(nodeId, failureMessage: Update)) => log.info(s"received 'Update' type error message from nodeId=$nodeId, retrying payment (failure=$failureMessage)") @@ -141,18 +141,18 @@ class PaymentLifecycle(nodeParams: NodeParams, id: UUID, router: ActorRef, regis // in any case, we forward the update to the router router ! failureMessage.update // let's try again, router will have updated its state - router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.paymentOptions.finalAmount, c.assistedRoutes, ignoreNodes, ignoreChannels, c.routeParams) + router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.finalPayload.amount, c.assistedRoutes, ignoreNodes, ignoreChannels, c.routeParams) } else { // this node is fishy, it gave us a bad sig!! let's filter it out log.warning(s"got bad signature from node=$nodeId update=${failureMessage.update}") - router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.paymentOptions.finalAmount, c.assistedRoutes, ignoreNodes + nodeId, ignoreChannels, c.routeParams) + router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.finalPayload.amount, c.assistedRoutes, ignoreNodes + nodeId, ignoreChannels, c.routeParams) } goto(WAITING_FOR_ROUTE) using WaitingForRoute(s, c, failures :+ RemoteFailure(hops, e)) case Success(e@Sphinx.DecryptedFailurePacket(nodeId, failureMessage)) => log.info(s"received an error message from nodeId=$nodeId, trying to use a different channel (failure=$failureMessage)") // let's try again without the channel outgoing from nodeId val faultyChannel = hops.find(_.nodeId == nodeId).map(hop => ChannelDesc(hop.lastUpdate.shortChannelId, hop.nodeId, hop.nextNodeId)) - router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.paymentOptions.finalAmount, c.assistedRoutes, ignoreNodes, ignoreChannels ++ faultyChannel.toSet, c.routeParams) + router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.finalPayload.amount, c.assistedRoutes, ignoreNodes, ignoreChannels ++ faultyChannel.toSet, c.routeParams) goto(WAITING_FOR_ROUTE) using WaitingForRoute(s, c, failures :+ RemoteFailure(hops, e)) } @@ -172,7 +172,7 @@ class PaymentLifecycle(nodeParams: NodeParams, id: UUID, router: ActorRef, regis } else { log.info(s"received an error message from local, trying to use a different channel (failure=${t.getMessage})") val faultyChannel = ChannelDesc(hops.head.lastUpdate.shortChannelId, hops.head.nodeId, hops.head.nextNodeId) - router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.paymentOptions.finalAmount, c.assistedRoutes, ignoreNodes, ignoreChannels + faultyChannel, c.routeParams) + router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.finalPayload.amount, c.assistedRoutes, ignoreNodes, ignoreChannels + faultyChannel, c.routeParams) goto(WAITING_FOR_ROUTE) using WaitingForRoute(s, c, failures :+ LocalFailure(t)) } @@ -196,14 +196,14 @@ object PaymentLifecycle { // @formatter:off case class ReceivePayment(amount_opt: Option[MilliSatoshi], description: String, expirySeconds_opt: Option[Long] = None, extraHops: List[List[ExtraHop]] = Nil, fallbackAddress: Option[String] = None, paymentPreimage: Option[ByteVector32] = None) - case class SendPaymentToRoute(paymentHash: ByteVector32, hops: Seq[PublicKey], paymentOptions: PaymentOptions) + case class SendPaymentToRoute(paymentHash: ByteVector32, hops: Seq[PublicKey], finalPayload: FinalPayload) case class SendPayment(paymentHash: ByteVector32, targetNodeId: PublicKey, - paymentOptions: PaymentOptions, + finalPayload: FinalPayload, maxAttempts: Int, assistedRoutes: Seq[Seq[ExtraHop]] = Nil, routeParams: Option[RouteParams] = None) { - require(paymentOptions.finalAmount > 0.msat, s"amount must be > 0") + require(finalPayload.amount > 0.msat, s"amount must be > 0") } sealed trait PaymentResult @@ -214,18 +214,6 @@ object PaymentLifecycle { case class UnreadableRemoteFailure(route: Seq[Hop]) extends PaymentFailure case class PaymentFailed(id: UUID, paymentHash: ByteVector32, failures: Seq[PaymentFailure]) extends PaymentResult - /** - * Options to help build the final payload of the payment route. - */ - sealed trait PaymentOptions { - // The final htlc amount in millisatoshis. - val finalAmount: MilliSatoshi - // The final htlc expiry in number of blocks. - val finalExpiry: CltvExpiry - } - case class LegacyPayload(finalAmount: MilliSatoshi, finalExpiry: CltvExpiry) extends PaymentOptions - case class TlvPayload(finalAmount: MilliSatoshi, finalExpiry: CltvExpiry, records: Seq[OnionTlv] = Nil) extends PaymentOptions - sealed trait Data case object WaitingForRequest extends Data case class WaitingForRoute(sender: ActorRef, c: SendPayment, failures: Seq[PaymentFailure]) extends Data @@ -237,11 +225,14 @@ object PaymentLifecycle { case object WAITING_FOR_PAYMENT_COMPLETE extends State // @formatter:on - def buildOnion(nodes: Seq[PublicKey], payloads: Seq[OnionPerHopPayload], associatedData: ByteVector32): Sphinx.PacketAndSecrets = { + def buildOnion(nodes: Seq[PublicKey], payloads: Seq[PerHopPayload], associatedData: ByteVector32): Sphinx.PacketAndSecrets = { require(nodes.size == payloads.size) val sessionKey = randomKey val payloadsBin: Seq[ByteVector] = payloads - .map(OnionCodecs.perHopPayloadCodec.encode) + .map({ + case p: FinalPayload => OnionCodecs.finalPerHopPayloadCodec.encode(p) + case p: RelayPayload => OnionCodecs.relayPerHopPayloadCodec.encode(p) + }) .map { case Attempt.Successful(bitVector) => bitVector.toByteVector case Attempt.Failure(cause) => throw new RuntimeException(s"serialization error: $cause") @@ -252,29 +243,25 @@ object PaymentLifecycle { /** * Build the onion payloads for each hop. * - * @param hops the hops as computed by the router + extra routes from payment request - * @param opts options to help build each hop's payload (final amount, expiry, additional tlv records, etc) + * @param hops the hops as computed by the router + extra routes from payment request + * @param finalPayload payload data for the final node (amount, expiry, additional tlv records, etc) * @return a (firstAmount, firstExpiry, payloads) tuple where: * - firstAmount is the amount for the first htlc in the route * - firstExpiry is the cltv expiry for the first htlc in the route * - a sequence of payloads that will be used to build the onion */ - def buildPayloads(hops: Seq[Hop], opts: PaymentOptions): (MilliSatoshi, CltvExpiry, Seq[OnionPerHopPayload]) = { - val finalPayload: Seq[OnionPerHopPayload] = opts match { - case p: LegacyPayload => OnionForwardInfo(ShortChannelId(0L), p.finalAmount, p.finalExpiry) :: Nil - case p: TlvPayload => TlvStream[OnionTlv](OnionTlv.AmountToForward(p.finalAmount) +: OnionTlv.OutgoingCltv(p.finalExpiry) +: p.records) :: Nil - } - hops.reverse.foldLeft((opts.finalAmount, opts.finalExpiry, finalPayload)) { + def buildPayloads(hops: Seq[Hop], finalPayload: FinalPayload): (MilliSatoshi, CltvExpiry, Seq[PerHopPayload]) = { + hops.reverse.foldLeft((finalPayload.amount, finalPayload.expiry, Seq[PerHopPayload](finalPayload))) { case ((amount, expiry, payloads), hop) => val nextFee = nodeFee(hop.lastUpdate.feeBaseMsat, hop.lastUpdate.feeProportionalMillionths, amount) // Since we don't have any scenario where we add tlv data for intermediate hops, we use legacy payloads. - val payload: OnionPerHopPayload = OnionForwardInfo(hop.lastUpdate.shortChannelId, amount, expiry) + val payload = RelayLegacyPayload(hop.lastUpdate.shortChannelId, amount, expiry) (amount + nextFee, expiry + hop.lastUpdate.cltvExpiryDelta, payload +: payloads) } } - def buildCommand(id: UUID, paymentHash: ByteVector32, hops: Seq[Hop], opts: PaymentOptions): (CMD_ADD_HTLC, Seq[(ByteVector32, PublicKey)]) = { - val (firstAmount, firstExpiry, payloads) = buildPayloads(hops.drop(1), opts) + def buildCommand(id: UUID, paymentHash: ByteVector32, hops: Seq[Hop], finalPayload: FinalPayload): (CMD_ADD_HTLC, Seq[(ByteVector32, PublicKey)]) = { + val (firstAmount, firstExpiry, payloads) = buildPayloads(hops.drop(1), finalPayload) val nodes = hops.map(_.nextNodeId) // BOLT 2 requires that associatedData == paymentHash val onion = buildOnion(nodes, payloads, paymentHash) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Relayer.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Relayer.scala index e912bb48f8..3e747a1eee 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Relayer.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Relayer.scala @@ -36,7 +36,6 @@ import scodec.{Attempt, DecodeResult} import scala.collection.mutable // @formatter:off - sealed trait Origin case class Local(id: UUID, sender: Option[ActorRef]) extends Origin // we don't persist reference to local actors case class Relayed(originChannelId: ByteVector32, originHtlcId: Long, amountIn: MilliSatoshi, amountOut: MilliSatoshi) extends Origin @@ -49,10 +48,8 @@ case class ForwardFailMalformed(fail: UpdateFailMalformedHtlc, to: Origin, htlc: case object GetUsableBalances case class UsableBalances(remoteNodeId: PublicKey, shortChannelId: ShortChannelId, canSend: MilliSatoshi, canReceive: MilliSatoshi, isPublic: Boolean) - // @formatter:on - /** * Created by PM on 01/02/2017. */ @@ -113,7 +110,7 @@ class Relayer(nodeParams: NodeParams, register: ActorRef, paymentHandler: ActorR case Right(r: RelayPayload) => handleRelay(r, channelUpdates, node2channels, previousFailures, nodeParams.chainHash) match { case RelayFailure(cmdFail) => - log.info(s"rejecting htlc #${add.id} paymentHash=${add.paymentHash} from channelId=${add.channelId} to shortChannelId=${r.payload.shortChannelId} reason=${cmdFail.reason}") + log.info(s"rejecting htlc #${add.id} paymentHash=${add.paymentHash} from channelId=${add.channelId} to shortChannelId=${r.payload.outgoingChannelId} reason=${cmdFail.reason}") commandBuffer ! CommandBuffer.CommandSend(add.channelId, add.id, cmdFail) case RelaySuccess(selectedShortChannelId, cmdAdd) => log.info(s"forwarding htlc #${add.id} paymentHash=${add.paymentHash} from channelId=${add.channelId} to shortChannelId=$selectedShortChannelId") @@ -218,16 +215,16 @@ object Relayer extends Logging { // @formatter:off sealed trait NextPayload - case class FinalPayload(add: UpdateAddHtlc, payload: OnionPerHopPayload) extends NextPayload - case class RelayPayload(add: UpdateAddHtlc, payload: OnionForwardInfo, nextPacket: OnionRoutingPacket) extends NextPayload { - val relayFeeMsat: MilliSatoshi = add.amountMsat - payload.amtToForward - val expiryDelta: CltvExpiryDelta = add.cltvExpiry - payload.outgoingCltvValue + case class FinalPayload(add: UpdateAddHtlc, payload: Onion.FinalPayload) extends NextPayload + case class RelayPayload(add: UpdateAddHtlc, payload: Onion.RelayPayload, nextPacket: OnionRoutingPacket) extends NextPayload { + val relayFeeMsat: MilliSatoshi = add.amountMsat - payload.amountToForward + val expiryDelta: CltvExpiryDelta = add.cltvExpiry - payload.outgoingCltv } // @formatter:on /** - * Decrypt the onion of a received htlc, and find out if the payment is to be relayed, - * or if our node is the last one in the route + * Decrypt the onion of a received htlc, and find out if the payment is to be relayed, or if our node is the last one + * in the route. * * @param add incoming htlc * @param privateKey this node's private key @@ -236,28 +233,21 @@ object Relayer extends Logging { def decryptPacket(add: UpdateAddHtlc, privateKey: PrivateKey, features: ByteVector): Either[FailureMessage, NextPayload] = Sphinx.PaymentPacket.peel(privateKey, add.paymentHash, add.onionRoutingPacket) match { case Right(p@Sphinx.DecryptedPacket(payload, nextPacket, _)) => - OnionCodecs.perHopPayloadCodec.decode(payload.bits) match { - case Attempt.Successful(DecodeResult(OnionPerHopPayload(Left(_)), _)) if !Features.hasVariableLengthOnion(features) => - Left(InvalidRealm) + val codec = if (p.isLastPacket) OnionCodecs.finalPerHopPayloadCodec else OnionCodecs.relayPerHopPayloadCodec + codec.decode(payload.bits) match { + case Attempt.Successful(DecodeResult(_: Onion.TlvPayload, _)) if !Features.hasVariableLengthOnion(features) => Left(InvalidRealm) case Attempt.Successful(DecodeResult(perHopPayload, remainder)) => if (remainder.nonEmpty) { logger.warn(s"${remainder.length} bits remaining after per-hop payload decoding: there might be an issue with the onion codec") } - if (p.isLastPacket) { - perHopPayload.paymentInfo match { - case Right(_) => Right(FinalPayload(add, perHopPayload)) - case Left(err) => Left(err) - } - } else { - perHopPayload.forwardInfo match { - case Right(forwardInfo) => Right(RelayPayload(add, forwardInfo, nextPacket)) - case Left(err) => Left(err) - } + perHopPayload match { + case finalPayload: Onion.FinalPayload => Right(FinalPayload(add, finalPayload)) + case relayPayload: Onion.RelayPayload => Right(RelayPayload(add, relayPayload, nextPacket)) } - case Attempt.Failure(_) => - // Onion is correctly encrypted but the content of the per-hop payload couldn't be parsed. - // It's hard to provide tag and offset information from scodec failures, so we currently don't do it. - Left(InvalidOnionPayload(UInt64(0), 0)) + case Attempt.Failure(e: OnionCodecs.MissingRequiredTlv) => Left(e.failureMessage) + // Onion is correctly encrypted but the content of the per-hop payload couldn't be parsed. + // It's hard to provide tag and offset information from scodec failures, so we currently don't do it. + case Attempt.Failure(_) => Left(InvalidOnionPayload(UInt64(0), 0)) } case Left(badOnion) => Left(badOnion) } @@ -265,22 +255,18 @@ object Relayer extends Logging { /** * Handle an incoming htlc when we are the last node * - * @param finalPayload payload + * @param p final payload * @return either: * - a CMD_FAIL_HTLC to be sent back upstream * - an UpdateAddHtlc to forward */ - def handleFinal(finalPayload: FinalPayload): Either[CMD_FAIL_HTLC, UpdateAddHtlc] = { - import finalPayload.add - finalPayload.payload.paymentInfo match { - case Right(OnionPaymentInfo(amountMsat, _)) if amountMsat > add.amountMsat => - Left(CMD_FAIL_HTLC(add.id, Right(FinalIncorrectHtlcAmount(add.amountMsat)), commit = true)) - case Right(OnionPaymentInfo(_, cltvExpiry)) if cltvExpiry != add.cltvExpiry => - Left(CMD_FAIL_HTLC(add.id, Right(FinalIncorrectCltvExpiry(add.cltvExpiry)), commit = true)) - case Left(err) => - Left(CMD_FAIL_HTLC(add.id, Right(err), commit = true)) - case _ => - Right(add) + def handleFinal(p: FinalPayload): Either[CMD_FAIL_HTLC, UpdateAddHtlc] = { + if (p.add.amountMsat < p.payload.amount) { + Left(CMD_FAIL_HTLC(p.add.id, Right(FinalIncorrectHtlcAmount(p.add.amountMsat)), commit = true)) + } else if (p.add.cltvExpiry != p.payload.expiry) { + Left(CMD_FAIL_HTLC(p.add.id, Right(FinalIncorrectCltvExpiry(p.add.cltvExpiry)), commit = true)) + } else { + Right(p.add) } } @@ -300,7 +286,7 @@ object Relayer extends Logging { */ def handleRelay(relayPayload: RelayPayload, channelUpdates: Map[ShortChannelId, OutgoingChannel], node2channels: mutable.Map[PublicKey, mutable.Set[ShortChannelId]] with mutable.MultiMap[PublicKey, ShortChannelId], previousFailures: Seq[AddHtlcFailed], chainHash: ByteVector32)(implicit log: LoggingAdapter): RelayResult = { import relayPayload._ - log.info(s"relaying htlc #${add.id} paymentHash={} from channelId={} to requestedShortChannelId={} previousAttempts={}", add.paymentHash, add.channelId, relayPayload.payload.shortChannelId, previousFailures.size) + log.info(s"relaying htlc #${add.id} paymentHash={} from channelId={} to requestedShortChannelId={} previousAttempts={}", add.paymentHash, add.channelId, relayPayload.payload.outgoingChannelId, previousFailures.size) val alreadyTried = previousFailures.flatMap(_.channelUpdate).map(_.shortChannelId) selectPreferredChannel(relayPayload, channelUpdates, node2channels, alreadyTried) .flatMap(selectedShortChannelId => channelUpdates.get(selectedShortChannelId).map(_.channelUpdate)) match { @@ -308,7 +294,7 @@ object Relayer extends Logging { // no more channels to try val error = previousFailures // we return the error for the initially requested channel if it exists - .find(_.channelUpdate.map(_.shortChannelId).contains(relayPayload.payload.shortChannelId)) + .find(_.channelUpdate.map(_.shortChannelId).contains(relayPayload.payload.outgoingChannelId)) // otherwise we return the error for the first channel tried .getOrElse(previousFailures.head) RelayFailure(CMD_FAIL_HTLC(add.id, Right(translateError(error)), commit = true)) @@ -325,7 +311,7 @@ object Relayer extends Logging { */ def selectPreferredChannel(relayPayload: RelayPayload, channelUpdates: Map[ShortChannelId, OutgoingChannel], node2channels: mutable.Map[PublicKey, mutable.Set[ShortChannelId]] with mutable.MultiMap[PublicKey, ShortChannelId], alreadyTried: Seq[ShortChannelId])(implicit log: LoggingAdapter): Option[ShortChannelId] = { import relayPayload.add - val requestedShortChannelId = relayPayload.payload.shortChannelId + val requestedShortChannelId = relayPayload.payload.outgoingChannelId log.debug(s"selecting next channel for htlc #${add.id} paymentHash={} from channelId={} to requestedShortChannelId={} previousAttempts={}", add.paymentHash, add.channelId, requestedShortChannelId, alreadyTried.size) // first we find out what is the next node val nextNodeId_opt = channelUpdates.get(requestedShortChannelId) match { @@ -350,7 +336,7 @@ object Relayer extends Logging { (shortChannelId, channelInfo_opt, relayResult) } .collect { case (shortChannelId, Some(channelInfo), _: RelaySuccess) => (shortChannelId, channelInfo.commitments.availableBalanceForSend) } - .filter(_._2 > relayPayload.payload.amtToForward) // we only keep channels that have enough balance to handle this payment + .filter(_._2 > relayPayload.payload.amountToForward) // we only keep channels that have enough balance to handle this payment .toList // needed for ordering .sortBy(_._2) // we want to use the channel with the lowest available balance that can process the payment .headOption match { @@ -383,14 +369,14 @@ object Relayer extends Logging { RelayFailure(CMD_FAIL_HTLC(add.id, Right(UnknownNextPeer), commit = true)) case Some(channelUpdate) if !Announcements.isEnabled(channelUpdate.channelFlags) => RelayFailure(CMD_FAIL_HTLC(add.id, Right(ChannelDisabled(channelUpdate.messageFlags, channelUpdate.channelFlags, channelUpdate)), commit = true)) - case Some(channelUpdate) if payload.amtToForward < channelUpdate.htlcMinimumMsat => - RelayFailure(CMD_FAIL_HTLC(add.id, Right(AmountBelowMinimum(payload.amtToForward, channelUpdate)), commit = true)) + case Some(channelUpdate) if payload.amountToForward < channelUpdate.htlcMinimumMsat => + RelayFailure(CMD_FAIL_HTLC(add.id, Right(AmountBelowMinimum(payload.amountToForward, channelUpdate)), commit = true)) case Some(channelUpdate) if relayPayload.expiryDelta != channelUpdate.cltvExpiryDelta => - RelayFailure(CMD_FAIL_HTLC(add.id, Right(IncorrectCltvExpiry(payload.outgoingCltvValue, channelUpdate)), commit = true)) - case Some(channelUpdate) if relayPayload.relayFeeMsat < nodeFee(channelUpdate.feeBaseMsat, channelUpdate.feeProportionalMillionths, payload.amtToForward) => + RelayFailure(CMD_FAIL_HTLC(add.id, Right(IncorrectCltvExpiry(payload.outgoingCltv, channelUpdate)), commit = true)) + case Some(channelUpdate) if relayPayload.relayFeeMsat < nodeFee(channelUpdate.feeBaseMsat, channelUpdate.feeProportionalMillionths, payload.amountToForward) => RelayFailure(CMD_FAIL_HTLC(add.id, Right(FeeInsufficient(add.amountMsat, channelUpdate)), commit = true)) case Some(channelUpdate) => - RelaySuccess(channelUpdate.shortChannelId, CMD_ADD_HTLC(payload.amtToForward, add.paymentHash, payload.outgoingCltvValue, nextPacket, upstream = Right(add), commit = true, previousFailures = previousFailures)) + RelaySuccess(channelUpdate.shortChannelId, CMD_ADD_HTLC(payload.amountToForward, add.paymentHash, payload.outgoingCltv, nextPacket, upstream = Right(add), commit = true, previousFailures = previousFailures)) } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/Onion.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/Onion.scala index fe84b366e8..ee37cdf296 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/Onion.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/Onion.scala @@ -19,98 +19,87 @@ package fr.acinq.eclair.wire import fr.acinq.bitcoin.ByteVector32 import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.wire.CommonCodecs._ -import fr.acinq.eclair.wire.OnionTlv._ import fr.acinq.eclair.wire.TlvCodecs._ import fr.acinq.eclair.{CltvExpiry, MilliSatoshi, ShortChannelId, UInt64} import scodec.bits.{BitVector, ByteVector, HexStringSyntax} -import scodec.codecs._ -import scodec.{Codec, DecodeResult, Decoder} /** * Created by t-bast on 05/07/2019. */ -/** - * Tlv types used inside onion messages. - */ +case class OnionRoutingPacket(version: Int, publicKey: ByteVector, payload: ByteVector, hmac: ByteVector32) + +/** Tlv types used inside onion messages. */ sealed trait OnionTlv extends Tlv -case class OnionRoutingPacket(version: Int, - publicKey: ByteVector, - payload: ByteVector, - hmac: ByteVector32) - -case class OnionForwardInfo(shortChannelId: ShortChannelId, - amtToForward: MilliSatoshi, - outgoingCltvValue: CltvExpiry) - -case class OnionPaymentInfo(amount: MilliSatoshi, cltvExpiry: CltvExpiry) - -case class OnionPerHopPayload(payload: Either[TlvStream[OnionTlv], OnionForwardInfo]) { - - lazy val paymentInfo: Either[InvalidOnionPayload, OnionPaymentInfo] = payload match { - case Right(OnionForwardInfo(_, amount, cltv)) => Right(OnionPaymentInfo(amount, cltv)) - case Left(tlv) => - val amount = tlv.get[AmountToForward].map(_.amount) - val cltv = tlv.get[OutgoingCltv].map(_.cltv) - if (amount.isEmpty) { - Left(InvalidOnionPayload(UInt64(2), 0)) - } else if (cltv.isEmpty) { - Left(InvalidOnionPayload(UInt64(4), 0)) - } else { - Right(OnionPaymentInfo(amount.get, cltv.get)) - } - } +object OnionTlv { - lazy val forwardInfo: Either[InvalidOnionPayload, OnionForwardInfo] = payload match { - case Right(onionForwardInfo) => Right(onionForwardInfo) - case Left(tlv) => - val shortChannelId = tlv.get[OutgoingChannelId].map(_.shortChannelId) - val amount = tlv.get[AmountToForward].map(_.amount) - val cltv = tlv.get[OutgoingCltv].map(_.cltv) - if (amount.isEmpty) { - Left(InvalidOnionPayload(UInt64(2), 0)) - } else if (cltv.isEmpty) { - Left(InvalidOnionPayload(UInt64(4), 0)) - } else if (shortChannelId.isEmpty) { - Left(InvalidOnionPayload(UInt64(6), 0)) - } else { - Right(OnionForwardInfo(shortChannelId.get, amount.get, cltv.get)) - } - } + /** Amount to forward to the next node. */ + case class AmountToForward(amount: MilliSatoshi) extends OnionTlv + + /** CLTV value to use for the HTLC offered to the next node. */ + case class OutgoingCltv(cltv: CltvExpiry) extends OnionTlv + + /** Id of the channel to use to forward a payment to the next node. */ + case class OutgoingChannelId(shortChannelId: ShortChannelId) extends OnionTlv } -object OnionPerHopPayload { +object Onion { - // @formatter:off - implicit def legacyToPerHopPayload(legacy: OnionForwardInfo): OnionPerHopPayload = OnionPerHopPayload(Right(legacy)) - implicit def tlvToPerHopPayload(tlv: TlvStream[OnionTlv]): OnionPerHopPayload = OnionPerHopPayload(Left(tlv)) - // @formatter:on + import OnionTlv._ -} + /** Per-hop payload from an HTLC's payment onion (after decryption and decoding). */ + sealed trait PerHopPayload -object OnionTlv { + /** Legacy fixed-size 65-bytes onion payload. */ + sealed trait LegacyPayload extends PerHopPayload - /** - * Amount to forward to the next node. - */ - case class AmountToForward(amount: MilliSatoshi) extends OnionTlv + /** Variable-length onion payload with optional additional tlv records. */ + sealed trait TlvPayload extends PerHopPayload { + val records: TlvStream[OnionTlv] + } - /** - * CLTV value to use for the HTLC offered to the next node. - */ - case class OutgoingCltv(cltv: CltvExpiry) extends OnionTlv + /** Per-hop payload for an intermediate node. */ + sealed trait RelayPayload extends PerHopPayload { + /** Amount to forward to the next node. */ + val amountToForward: MilliSatoshi + /** CLTV value to use for the HTLC offered to the next node. */ + val outgoingCltv: CltvExpiry + /** Id of the channel to use to forward a payment to the next node. */ + val outgoingChannelId: ShortChannelId + } - /** - * Id of the channel to use to forward a payment to the next node. - */ - case class OutgoingChannelId(shortChannelId: ShortChannelId) extends OnionTlv + /** Per-hop payload for a final node. */ + sealed trait FinalPayload extends PerHopPayload { + val amount: MilliSatoshi + val expiry: CltvExpiry + } + + case class RelayLegacyPayload(outgoingChannelId: ShortChannelId, amountToForward: MilliSatoshi, outgoingCltv: CltvExpiry) extends LegacyPayload with RelayPayload + + case class FinalLegacyPayload(amount: MilliSatoshi, expiry: CltvExpiry) extends LegacyPayload with FinalPayload + + case class RelayTlvPayload(records: TlvStream[OnionTlv]) extends TlvPayload with RelayPayload { + override val amountToForward = records.get[AmountToForward].get.amount + override val outgoingCltv = records.get[OutgoingCltv].get.cltv + override val outgoingChannelId = records.get[OutgoingChannelId].get.shortChannelId + } + + case class FinalTlvPayload(records: TlvStream[OnionTlv]) extends TlvPayload with FinalPayload { + override val amount = records.get[AmountToForward].get.amount + override val expiry = records.get[OutgoingCltv].get.cltv + } } object OnionCodecs { + import Onion._ + import OnionTlv._ + import scodec.codecs._ + import scodec.{Attempt, Codec, DecodeResult, Decoder, Err} + def onionRoutingPacketCodec(payloadLength: Int): Codec[OnionRoutingPacket] = ( ("version" | uint8) :: ("publicKey" | bytes(33)) :: @@ -141,13 +130,48 @@ object OnionCodecs { val tlvPerHopPayloadCodec: Codec[TlvStream[OnionTlv]] = TlvCodecs.lengthPrefixedTlvStream[OnionTlv](onionTlvCodec).complete - val legacyPerHopPayloadCodec: Codec[OnionForwardInfo] = ( + private val legacyRelayPerHopPayloadCodec: Codec[RelayLegacyPayload] = ( ("realm" | constant(ByteVector.fromByte(0))) :: ("short_channel_id" | shortchannelid) :: ("amt_to_forward" | millisatoshi) :: ("outgoing_cltv_value" | cltvExpiry) :: - ("unused_with_v0_version_on_header" | ignore(8 * 12))).as[OnionForwardInfo] + ("unused_with_v0_version_on_header" | ignore(8 * 12))).as[RelayLegacyPayload] + + private val legacyFinalPerHopPayloadCodec: Codec[FinalLegacyPayload] = ( + ("realm" | constant(ByteVector.fromByte(0))) :: + ("short_channel_id" | ignore(8 * 8)) :: + ("amount" | millisatoshi) :: + ("expiry" | cltvExpiry) :: + ("unused_with_v0_version_on_header" | ignore(8 * 12))).as[FinalLegacyPayload] + + case class MissingRequiredTlv(tag: UInt64) extends Err { + // @formatter:off + val failureMessage: FailureMessage = InvalidOnionPayload(tag, 0) + override def message = failureMessage.message + override def context: List[String] = Nil + override def pushContext(ctx: String): Err = this + // @formatter:on + } - val perHopPayloadCodec: Codec[OnionPerHopPayload] = fallback(tlvPerHopPayloadCodec, legacyPerHopPayloadCodec).as[OnionPerHopPayload] + val relayPerHopPayloadCodec: Codec[RelayPayload] = fallback(tlvPerHopPayloadCodec, legacyRelayPerHopPayloadCodec).narrow({ + case Left(tlvs) if tlvs.get[AmountToForward].isEmpty => Attempt.failure(MissingRequiredTlv(UInt64(2))) + case Left(tlvs) if tlvs.get[OutgoingCltv].isEmpty => Attempt.failure(MissingRequiredTlv(UInt64(4))) + case Left(tlvs) if tlvs.get[OutgoingChannelId].isEmpty => Attempt.failure(MissingRequiredTlv(UInt64(6))) + case Left(tlvs) => Attempt.successful(RelayTlvPayload(tlvs)) + case Right(legacy) => Attempt.successful(legacy) + }, { + case legacy: RelayLegacyPayload => Right(legacy) + case RelayTlvPayload(tlvs) => Left(tlvs) + }) + + val finalPerHopPayloadCodec: Codec[FinalPayload] = fallback(tlvPerHopPayloadCodec, legacyFinalPerHopPayloadCodec).narrow({ + case Left(tlvs) if tlvs.get[AmountToForward].isEmpty => Attempt.failure(MissingRequiredTlv(UInt64(2))) + case Left(tlvs) if tlvs.get[OutgoingCltv].isEmpty => Attempt.failure(MissingRequiredTlv(UInt64(4))) + case Left(tlvs) => Attempt.successful(FinalTlvPayload(tlvs)) + case Right(legacy) => Attempt.successful(legacy) + }, { + case legacy: FinalLegacyPayload => Right(legacy) + case FinalTlvPayload(tlvs) => Left(tlvs) + }) } \ No newline at end of file diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/channel/FuzzySpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/channel/FuzzySpec.scala index 924d30227f..f1f464ee1c 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/channel/FuzzySpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/channel/FuzzySpec.scala @@ -27,9 +27,10 @@ import fr.acinq.eclair.TestConstants.{Alice, Bob} import fr.acinq.eclair._ import fr.acinq.eclair.blockchain._ import fr.acinq.eclair.channel.states.StateTestsHelperMethods -import fr.acinq.eclair.payment.PaymentLifecycle.{LegacyPayload, ReceivePayment} +import fr.acinq.eclair.payment.PaymentLifecycle.ReceivePayment import fr.acinq.eclair.payment._ import fr.acinq.eclair.router.Hop +import fr.acinq.eclair.wire.Onion.FinalLegacyPayload import fr.acinq.eclair.wire._ import grizzled.slf4j.Logging import org.scalatest.{Outcome, Tag} @@ -95,10 +96,10 @@ class FuzzySpec extends TestkitBaseClass with StateTestsHelperMethods with Loggi // allow overpaying (no more than 2 times the required amount) val amount = MilliSatoshi(requiredAmount + Random.nextInt(requiredAmount)) val expiry = (Channel.MIN_CLTV_EXPIRY_DELTA + 1).toCltvExpiry - PaymentLifecycle.buildCommand(UUID.randomUUID(), paymentHash, Hop(null, dest, null) :: Nil, LegacyPayload(amount, expiry))._1 + PaymentLifecycle.buildCommand(UUID.randomUUID(), paymentHash, Hop(null, dest, null) :: Nil, FinalLegacyPayload(amount, expiry))._1 } - def initiatePayment(stopping: Boolean) = + def initiatePayment(stopping: Boolean): Unit = if (stopping) { context stop self } else { diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/StateTestsHelperMethods.scala b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/StateTestsHelperMethods.scala index d37e174e09..9ffb4c8634 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/StateTestsHelperMethods.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/StateTestsHelperMethods.scala @@ -27,8 +27,8 @@ import fr.acinq.eclair.blockchain.fee.FeeTargets import fr.acinq.eclair.channel._ import fr.acinq.eclair.io.Peer import fr.acinq.eclair.payment.PaymentLifecycle -import fr.acinq.eclair.payment.PaymentLifecycle.LegacyPayload import fr.acinq.eclair.router.Hop +import fr.acinq.eclair.wire.Onion.FinalLegacyPayload import fr.acinq.eclair.wire._ import fr.acinq.eclair.{NodeParams, TestConstants, randomBytes32, _} @@ -110,7 +110,7 @@ trait StateTestsHelperMethods extends TestKitBase { val payment_preimage: ByteVector32 = randomBytes32 val payment_hash: ByteVector32 = Crypto.sha256(payment_preimage) val expiry = CltvExpiryDelta(144).toCltvExpiry - val cmd = PaymentLifecycle.buildCommand(UUID.randomUUID, payment_hash, Hop(null, destination, null) :: Nil, LegacyPayload(amount, expiry))._1.copy(commit = false) + val cmd = PaymentLifecycle.buildCommand(UUID.randomUUID, payment_hash, Hop(null, destination, null) :: Nil, FinalLegacyPayload(amount, expiry))._1.copy(commit = false) (payment_preimage, cmd) } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/f/ShutdownStateSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/f/ShutdownStateSpec.scala index 7ee2bd80ee..0757f56338 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/f/ShutdownStateSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/f/ShutdownStateSpec.scala @@ -28,6 +28,7 @@ import fr.acinq.eclair.channel._ import fr.acinq.eclair.channel.states.StateTestsHelperMethods import fr.acinq.eclair.payment._ import fr.acinq.eclair.router.Hop +import fr.acinq.eclair.wire.Onion.FinalLegacyPayload import fr.acinq.eclair.wire.{CommitSig, Error, FailureMessageCodecs, PermanentChannelFailure, RevokeAndAck, Shutdown, UpdateAddHtlc, UpdateFailHtlc, UpdateFailMalformedHtlc, UpdateFee, UpdateFulfillHtlc} import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, LongToBtcAmount, TestConstants, TestkitBaseClass, randomBytes32} import org.scalatest.Outcome @@ -56,7 +57,7 @@ class ShutdownStateSpec extends TestkitBaseClass with StateTestsHelperMethods { val h1 = Crypto.sha256(r1) val amount1 = 300000000 msat val expiry1 = CltvExpiryDelta(144).toCltvExpiry - val cmd1 = PaymentLifecycle.buildCommand(UUID.randomUUID, h1, Hop(null, TestConstants.Bob.nodeParams.nodeId, null) :: Nil, PaymentLifecycle.LegacyPayload(amount1, expiry1))._1.copy(commit = false) + val cmd1 = PaymentLifecycle.buildCommand(UUID.randomUUID, h1, Hop(null, TestConstants.Bob.nodeParams.nodeId, null) :: Nil, FinalLegacyPayload(amount1, expiry1))._1.copy(commit = false) sender.send(alice, cmd1) sender.expectMsg("ok") val htlc1 = alice2bob.expectMsgType[UpdateAddHtlc] @@ -66,7 +67,7 @@ class ShutdownStateSpec extends TestkitBaseClass with StateTestsHelperMethods { val h2 = Crypto.sha256(r2) val amount2 = 200000000 msat val expiry2 = CltvExpiryDelta(144).toCltvExpiry - val cmd2 = PaymentLifecycle.buildCommand(UUID.randomUUID, h2, Hop(null, TestConstants.Bob.nodeParams.nodeId, null) :: Nil, PaymentLifecycle.LegacyPayload(amount2, expiry2))._1.copy(commit = false) + val cmd2 = PaymentLifecycle.buildCommand(UUID.randomUUID, h2, Hop(null, TestConstants.Bob.nodeParams.nodeId, null) :: Nil, FinalLegacyPayload(amount2, expiry2))._1.copy(commit = false) sender.send(alice, cmd2) sender.expectMsg("ok") val htlc2 = alice2bob.expectMsgType[UpdateAddHtlc] @@ -216,7 +217,6 @@ class ShutdownStateSpec extends TestkitBaseClass with StateTestsHelperMethods { test("recv CMD_FAIL_HTLC (acknowledge in case of failure)") { f => import f._ val sender = TestProbe() - val r = randomBytes32 val initialState = bob.stateData.asInstanceOf[DATA_SHUTDOWN] sender.send(bob, CMD_FAIL_HTLC(42, Right(PermanentChannelFailure))) // this will fail sender.expectMsg(Failure(UnknownHtlcId(channelId(bob), 42))) @@ -239,7 +239,7 @@ class ShutdownStateSpec extends TestkitBaseClass with StateTestsHelperMethods { import f._ val sender = TestProbe() val initialState = bob.stateData.asInstanceOf[DATA_SHUTDOWN] - sender.send(bob, CMD_FAIL_MALFORMED_HTLC(42, ByteVector32.Zeroes, FailureMessageCodecs.BADONION)) + sender.send(bob, CMD_FAIL_MALFORMED_HTLC(42, randomBytes32, FailureMessageCodecs.BADONION)) sender.expectMsg(Failure(UnknownHtlcId(channelId(bob), 42))) assert(initialState == bob.stateData) } @@ -248,7 +248,7 @@ class ShutdownStateSpec extends TestkitBaseClass with StateTestsHelperMethods { import f._ val sender = TestProbe() val initialState = bob.stateData.asInstanceOf[DATA_SHUTDOWN] - sender.send(bob, CMD_FAIL_MALFORMED_HTLC(42, ByteVector32.Zeroes, 42)) + sender.send(bob, CMD_FAIL_MALFORMED_HTLC(42, randomBytes32, 42)) sender.expectMsg(Failure(InvalidFailureCode(channelId(bob)))) assert(initialState == bob.stateData) } @@ -256,10 +256,8 @@ class ShutdownStateSpec extends TestkitBaseClass with StateTestsHelperMethods { test("recv CMD_FAIL_MALFORMED_HTLC (acknowledge in case of failure)") { f => import f._ val sender = TestProbe() - val r = randomBytes32 val initialState = bob.stateData.asInstanceOf[DATA_SHUTDOWN] - - sender.send(bob, CMD_FAIL_MALFORMED_HTLC(42, ByteVector32.Zeroes, FailureMessageCodecs.BADONION)) // this will fail + sender.send(bob, CMD_FAIL_MALFORMED_HTLC(42, randomBytes32, FailureMessageCodecs.BADONION)) // this will fail sender.expectMsg(Failure(UnknownHtlcId(channelId(bob), 42))) relayerB.expectMsg(CommandBuffer.CommandAck(initialState.channelId, 42)) } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/ChannelSelectionSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/ChannelSelectionSpec.scala index 4238a7008c..e2bbb33201 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/ChannelSelectionSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/ChannelSelectionSpec.scala @@ -22,6 +22,7 @@ import fr.acinq.eclair.channel.{CMD_ADD_HTLC, CMD_FAIL_HTLC} import fr.acinq.eclair.payment.HtlcGenerationSpec.makeCommitments import fr.acinq.eclair.payment.Relayer.{OutgoingChannel, RelayFailure, RelayPayload, RelaySuccess} import fr.acinq.eclair.router.Announcements +import fr.acinq.eclair.wire.Onion.RelayLegacyPayload import fr.acinq.eclair.wire._ import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, LongToBtcAmount, MilliSatoshi, ShortChannelId, TestConstants, randomBytes32, randomKey} import org.scalatest.FunSuite @@ -30,6 +31,8 @@ import scala.collection.mutable class ChannelSelectionSpec extends FunSuite { + implicit val log: akka.event.LoggingAdapter = akka.event.NoLogging + /** * This is just a simplified helper function with random values for fields we are not using here */ @@ -37,42 +40,41 @@ class ChannelSelectionSpec extends FunSuite { Announcements.makeChannelUpdate(Block.RegtestGenesisBlock.hash, randomKey, randomKey.publicKey, shortChannelId, cltvExpiryDelta, htlcMinimumMsat, feeBaseMsat, feeProportionalMillionths, htlcMaximumMsat, enable) test("convert to CMD_FAIL_HTLC/CMD_ADD_HTLC") { + val onionPayload = RelayLegacyPayload(ShortChannelId(12345), 998900 msat, CltvExpiry(60)) val relayPayload = RelayPayload( add = UpdateAddHtlc(randomBytes32, 42, 1000000 msat, randomBytes32, CltvExpiry(70), TestConstants.emptyOnionPacket), - payload = OnionForwardInfo(ShortChannelId(12345), amtToForward = 998900 msat, outgoingCltvValue = CltvExpiry(60)), + payload = onionPayload, nextPacket = TestConstants.emptyOnionPacket // just a placeholder ) val channelUpdate = dummyUpdate(ShortChannelId(12345), CltvExpiryDelta(10), 100 msat, 1000 msat, 100, 10000000 msat, true) - implicit val log = akka.event.NoLogging - // nominal case - assert(Relayer.relayOrFail(relayPayload, Some(channelUpdate)) === RelaySuccess(ShortChannelId(12345), CMD_ADD_HTLC(relayPayload.payload.amtToForward, relayPayload.add.paymentHash, relayPayload.payload.outgoingCltvValue, relayPayload.nextPacket, upstream = Right(relayPayload.add), commit = true))) + assert(Relayer.relayOrFail(relayPayload, Some(channelUpdate)) === RelaySuccess(ShortChannelId(12345), CMD_ADD_HTLC(relayPayload.payload.amountToForward, relayPayload.add.paymentHash, relayPayload.payload.outgoingCltv, relayPayload.nextPacket, upstream = Right(relayPayload.add), commit = true))) // no channel_update assert(Relayer.relayOrFail(relayPayload, channelUpdate_opt = None) === RelayFailure(CMD_FAIL_HTLC(relayPayload.add.id, Right(UnknownNextPeer), commit = true))) // channel disabled val channelUpdate_disabled = channelUpdate.copy(channelFlags = Announcements.makeChannelFlags(isNode1 = true, enable = false)) assert(Relayer.relayOrFail(relayPayload, Some(channelUpdate_disabled)) === RelayFailure(CMD_FAIL_HTLC(relayPayload.add.id, Right(ChannelDisabled(channelUpdate_disabled.messageFlags, channelUpdate_disabled.channelFlags, channelUpdate_disabled)), commit = true))) // amount too low - val relayPayload_toolow = relayPayload.copy(payload = relayPayload.payload.copy(amtToForward = 99 msat)) - assert(Relayer.relayOrFail(relayPayload_toolow, Some(channelUpdate)) === RelayFailure(CMD_FAIL_HTLC(relayPayload.add.id, Right(AmountBelowMinimum(relayPayload_toolow.payload.amtToForward, channelUpdate)), commit = true))) + val relayPayload_toolow = relayPayload.copy(payload = onionPayload.copy(amountToForward = 99 msat)) + assert(Relayer.relayOrFail(relayPayload_toolow, Some(channelUpdate)) === RelayFailure(CMD_FAIL_HTLC(relayPayload.add.id, Right(AmountBelowMinimum(relayPayload_toolow.payload.amountToForward, channelUpdate)), commit = true))) // incorrect cltv expiry - val relayPayload_incorrectcltv = relayPayload.copy(payload = relayPayload.payload.copy(outgoingCltvValue = CltvExpiry(42))) - assert(Relayer.relayOrFail(relayPayload_incorrectcltv, Some(channelUpdate)) === RelayFailure(CMD_FAIL_HTLC(relayPayload.add.id, Right(IncorrectCltvExpiry(relayPayload_incorrectcltv.payload.outgoingCltvValue, channelUpdate)), commit = true))) + val relayPayload_incorrectcltv = relayPayload.copy(payload = onionPayload.copy(outgoingCltv = CltvExpiry(42))) + assert(Relayer.relayOrFail(relayPayload_incorrectcltv, Some(channelUpdate)) === RelayFailure(CMD_FAIL_HTLC(relayPayload.add.id, Right(IncorrectCltvExpiry(relayPayload_incorrectcltv.payload.outgoingCltv, channelUpdate)), commit = true))) // insufficient fee - val relayPayload_insufficientfee = relayPayload.copy(payload = relayPayload.payload.copy(amtToForward = 998910 msat)) + val relayPayload_insufficientfee = relayPayload.copy(payload = onionPayload.copy(amountToForward = 998910 msat)) assert(Relayer.relayOrFail(relayPayload_insufficientfee, Some(channelUpdate)) === RelayFailure(CMD_FAIL_HTLC(relayPayload.add.id, Right(FeeInsufficient(relayPayload_insufficientfee.add.amountMsat, channelUpdate)), commit = true))) // note that a generous fee is ok! - val relayPayload_highfee = relayPayload.copy(payload = relayPayload.payload.copy(amtToForward = 900000 msat)) - assert(Relayer.relayOrFail(relayPayload_highfee, Some(channelUpdate)) === RelaySuccess(ShortChannelId(12345), CMD_ADD_HTLC(relayPayload_highfee.payload.amtToForward, relayPayload_highfee.add.paymentHash, relayPayload_highfee.payload.outgoingCltvValue, relayPayload_highfee.nextPacket, upstream = Right(relayPayload.add), commit = true))) + val relayPayload_highfee = relayPayload.copy(payload = onionPayload.copy(amountToForward = 900000 msat)) + assert(Relayer.relayOrFail(relayPayload_highfee, Some(channelUpdate)) === RelaySuccess(ShortChannelId(12345), CMD_ADD_HTLC(relayPayload_highfee.payload.amountToForward, relayPayload_highfee.add.paymentHash, relayPayload_highfee.payload.outgoingCltv, relayPayload_highfee.nextPacket, upstream = Right(relayPayload.add), commit = true))) } test("channel selection") { - + val onionPayload = RelayLegacyPayload(ShortChannelId(12345), 998900 msat, CltvExpiry(60)) val relayPayload = RelayPayload( add = UpdateAddHtlc(randomBytes32, 42, 1000000 msat, randomBytes32, CltvExpiry(70), TestConstants.emptyOnionPacket), - payload = OnionForwardInfo(ShortChannelId(12345), amtToForward = 998900 msat, outgoingCltvValue = CltvExpiry(60)), + payload = onionPayload, nextPacket = TestConstants.emptyOnionPacket // just a placeholder ) @@ -91,10 +93,6 @@ class ChannelSelectionSpec extends FunSuite { node2channels.put(a, mutable.Set(ShortChannelId(12345), ShortChannelId(11111), ShortChannelId(22222), ShortChannelId(33333))) node2channels.put(b, mutable.Set(ShortChannelId(44444))) - implicit val log = akka.event.NoLogging - - import com.softwaremill.quicklens._ - // select the channel to the same node, with the lowest balance but still high enough to handle the payment assert(Relayer.selectPreferredChannel(relayPayload, channelUpdates, node2channels, Seq.empty) === Some(ShortChannelId(22222))) // select 2nd-to-best channel @@ -104,13 +102,13 @@ class ChannelSelectionSpec extends FunSuite { // all the suitable channels have been tried assert(Relayer.selectPreferredChannel(relayPayload, channelUpdates, node2channels, Seq(ShortChannelId(22222), ShortChannelId(12345), ShortChannelId(11111))) === None) // higher amount payment (have to increased incoming htlc amount for fees to be sufficient) - assert(Relayer.selectPreferredChannel(relayPayload.modify(_.add.amountMsat).setTo(60000000 msat).modify(_.payload.amtToForward).setTo(50000000 msat), channelUpdates, node2channels, Seq.empty) === Some(ShortChannelId(11111))) + assert(Relayer.selectPreferredChannel(relayPayload.copy(add = relayPayload.add.copy(amountMsat = 60000000 msat), payload = onionPayload.copy(amountToForward = 50000000 msat)), channelUpdates, node2channels, Seq.empty) === Some(ShortChannelId(11111))) // lower amount payment - assert(Relayer.selectPreferredChannel(relayPayload.modify(_.payload.amtToForward).setTo(1000 msat), channelUpdates, node2channels, Seq.empty) === Some(ShortChannelId(33333))) + assert(Relayer.selectPreferredChannel(relayPayload.copy(payload = onionPayload.copy(amountToForward = 1000 msat)), channelUpdates, node2channels, Seq.empty) === Some(ShortChannelId(33333))) // payment too high, no suitable channel found - assert(Relayer.selectPreferredChannel(relayPayload.modify(_.payload.amtToForward).setTo(1000000000 msat), channelUpdates, node2channels, Seq.empty) === Some(ShortChannelId(12345))) + assert(Relayer.selectPreferredChannel(relayPayload.copy(payload = onionPayload.copy(amountToForward = 1000000000 msat)), channelUpdates, node2channels, Seq.empty) === Some(ShortChannelId(12345))) // invalid cltv expiry, no suitable channel, we keep the requested one - assert(Relayer.selectPreferredChannel(relayPayload.modify(_.payload.outgoingCltvValue).setTo(CltvExpiry(40)), channelUpdates, node2channels, Seq.empty) === Some(ShortChannelId(12345))) + assert(Relayer.selectPreferredChannel(relayPayload.copy(payload = onionPayload.copy(outgoingCltv = CltvExpiry(40))), channelUpdates, node2channels, Seq.empty) === Some(ShortChannelId(12345))) } } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/HtlcGenerationSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/HtlcGenerationSpec.scala index d460876180..6361c3813c 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/HtlcGenerationSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/HtlcGenerationSpec.scala @@ -25,6 +25,8 @@ import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.crypto.Sphinx.{DecryptedPacket, PacketAndSecrets} import fr.acinq.eclair.payment.PaymentLifecycle._ import fr.acinq.eclair.router.Hop +import fr.acinq.eclair.wire.Onion.{FinalLegacyPayload, FinalTlvPayload, PerHopPayload, RelayLegacyPayload} +import fr.acinq.eclair.wire.OnionTlv.{AmountToForward, OutgoingCltv} import fr.acinq.eclair.wire._ import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Globals, LongToBtcAmount, MilliSatoshi, ShortChannelId, TestConstants, nodeFee, randomBytes32} import org.scalatest.{BeforeAndAfterAll, FunSuite} @@ -53,116 +55,80 @@ class HtlcGenerationSpec extends FunSuite with BeforeAndAfterAll { import HtlcGenerationSpec._ test("compute payloads with fees and expiry delta") { - val (firstAmountMsat, firstExpiry, payloads) = buildPayloads(hops.drop(1), LegacyPayload(finalAmountMsat, finalExpiry)) - val expectedPayloads = Seq[OnionPerHopPayload]( - OnionForwardInfo(channelUpdate_bc.shortChannelId, amount_bc, expiry_bc), - OnionForwardInfo(channelUpdate_cd.shortChannelId, amount_cd, expiry_cd), - OnionForwardInfo(channelUpdate_de.shortChannelId, amount_de, expiry_de), - OnionForwardInfo(ShortChannelId(0L), finalAmountMsat, finalExpiry)) + val (firstAmountMsat, firstExpiry, payloads) = buildPayloads(hops.drop(1), FinalLegacyPayload(finalAmountMsat, finalExpiry)) + val expectedPayloads = Seq[PerHopPayload]( + RelayLegacyPayload(channelUpdate_bc.shortChannelId, amount_bc, expiry_bc), + RelayLegacyPayload(channelUpdate_cd.shortChannelId, amount_cd, expiry_cd), + RelayLegacyPayload(channelUpdate_de.shortChannelId, amount_de, expiry_de), + FinalLegacyPayload(finalAmountMsat, finalExpiry)) assert(firstAmountMsat === amount_ab) assert(firstExpiry === expiry_ab) assert(payloads === expectedPayloads) } - test("build onion") { - val (_, _, payloads) = buildPayloads(hops.drop(1), LegacyPayload(finalAmountMsat, finalExpiry)) + def testBuildOnion(legacy: Boolean): Unit = { + val finalPayload = if (legacy) { + FinalLegacyPayload(finalAmountMsat, finalExpiry) + } else { + FinalTlvPayload(TlvStream[OnionTlv](AmountToForward(finalAmountMsat), OutgoingCltv(finalExpiry))) + } + val (_, _, payloads) = buildPayloads(hops.drop(1), finalPayload) val nodes = hops.map(_.nextNodeId) val PacketAndSecrets(packet_b, _) = buildOnion(nodes, payloads, paymentHash) assert(packet_b.payload.length === Sphinx.PaymentPacket.PayloadLength) // let's peel the onion + testPeelOnion(packet_b) + } + + def testPeelOnion(packet_b: OnionRoutingPacket): Unit = { val Right(DecryptedPacket(bin_b, packet_c, _)) = Sphinx.PaymentPacket.peel(priv_b.privateKey, paymentHash, packet_b) - val payload_b = OnionCodecs.legacyPerHopPayloadCodec.decode(bin_b.toBitVector).require.value + val payload_b = OnionCodecs.relayPerHopPayloadCodec.decode(bin_b.toBitVector).require.value assert(packet_c.payload.length === Sphinx.PaymentPacket.PayloadLength) - assert(payload_b.amtToForward === amount_bc) - assert(payload_b.outgoingCltvValue === expiry_bc) + assert(payload_b.amountToForward === amount_bc) + assert(payload_b.outgoingCltv === expiry_bc) val Right(DecryptedPacket(bin_c, packet_d, _)) = Sphinx.PaymentPacket.peel(priv_c.privateKey, paymentHash, packet_c) - val payload_c = OnionCodecs.legacyPerHopPayloadCodec.decode(bin_c.toBitVector).require.value + val payload_c = OnionCodecs.relayPerHopPayloadCodec.decode(bin_c.toBitVector).require.value assert(packet_d.payload.length === Sphinx.PaymentPacket.PayloadLength) - assert(payload_c.amtToForward === amount_cd) - assert(payload_c.outgoingCltvValue === expiry_cd) + assert(payload_c.amountToForward === amount_cd) + assert(payload_c.outgoingCltv === expiry_cd) val Right(DecryptedPacket(bin_d, packet_e, _)) = Sphinx.PaymentPacket.peel(priv_d.privateKey, paymentHash, packet_d) - val payload_d = OnionCodecs.legacyPerHopPayloadCodec.decode(bin_d.toBitVector).require.value + val payload_d = OnionCodecs.relayPerHopPayloadCodec.decode(bin_d.toBitVector).require.value assert(packet_e.payload.length === Sphinx.PaymentPacket.PayloadLength) - assert(payload_d.amtToForward === amount_de) - assert(payload_d.outgoingCltvValue === expiry_de) + assert(payload_d.amountToForward === amount_de) + assert(payload_d.outgoingCltv === expiry_de) val Right(DecryptedPacket(bin_e, packet_random, _)) = Sphinx.PaymentPacket.peel(priv_e.privateKey, paymentHash, packet_e) - val payload_e = OnionCodecs.legacyPerHopPayloadCodec.decode(bin_e.toBitVector).require.value + val payload_e = OnionCodecs.finalPerHopPayloadCodec.decode(bin_e.toBitVector).require.value assert(packet_random.payload.length === Sphinx.PaymentPacket.PayloadLength) - assert(payload_e.amtToForward === finalAmountMsat) - assert(payload_e.outgoingCltvValue === finalExpiry) + assert(payload_e.amount === finalAmountMsat) + assert(payload_e.expiry === finalExpiry) } - test("build onion with final tlv payload") { - val (_, _, payloads) = buildPayloads(hops.drop(1), TlvPayload(finalAmountMsat, finalExpiry)) - val nodes = hops.map(_.nextNodeId) - val PacketAndSecrets(packet_b, _) = buildOnion(nodes, payloads, paymentHash) - assert(packet_b.payload.length === Sphinx.PaymentPacket.PayloadLength) - - // let's peel the onion - val Right(DecryptedPacket(bin_b, packet_c, _)) = Sphinx.PaymentPacket.peel(priv_b.privateKey, paymentHash, packet_b) - val payload_b = OnionCodecs.legacyPerHopPayloadCodec.decode(bin_b.toBitVector).require.value - assert(packet_c.payload.length === Sphinx.PaymentPacket.PayloadLength) - assert(payload_b === OnionForwardInfo(channelUpdate_bc.shortChannelId, amount_bc, expiry_bc)) - - val Right(DecryptedPacket(bin_c, packet_d, _)) = Sphinx.PaymentPacket.peel(priv_c.privateKey, paymentHash, packet_c) - val payload_c = OnionCodecs.legacyPerHopPayloadCodec.decode(bin_c.toBitVector).require.value - assert(packet_d.payload.length === Sphinx.PaymentPacket.PayloadLength) - assert(payload_c === OnionForwardInfo(channelUpdate_cd.shortChannelId, amount_cd, expiry_cd)) - - val Right(DecryptedPacket(bin_d, packet_e, _)) = Sphinx.PaymentPacket.peel(priv_d.privateKey, paymentHash, packet_d) - val payload_d = OnionCodecs.legacyPerHopPayloadCodec.decode(bin_d.toBitVector).require.value - assert(packet_e.payload.length === Sphinx.PaymentPacket.PayloadLength) - assert(payload_d === OnionForwardInfo(channelUpdate_de.shortChannelId, amount_de, expiry_de)) + test("build onion with final legacy payload") { + testBuildOnion(legacy = true) + } - val Right(DecryptedPacket(bin_e, packet_random, _)) = Sphinx.PaymentPacket.peel(priv_e.privateKey, paymentHash, packet_e) - val payload_e = OnionCodecs.tlvPerHopPayloadCodec.decode(bin_e.toBitVector).require.value - val paymentInfo = OnionPerHopPayload(Left(payload_e)).paymentInfo - assert(packet_random.payload.length === Sphinx.PaymentPacket.PayloadLength) - assert(paymentInfo === Right(OnionPaymentInfo(finalAmountMsat, finalExpiry))) + test("build onion with final tlv payload") { + testBuildOnion(legacy = false) } test("build a command including the onion") { - val (add, _) = buildCommand(UUID.randomUUID, paymentHash, hops, LegacyPayload(finalAmountMsat, finalExpiry)) - + val (add, _) = buildCommand(UUID.randomUUID, paymentHash, hops, FinalLegacyPayload(finalAmountMsat, finalExpiry)) assert(add.amount > finalAmountMsat) assert(add.cltvExpiry === finalExpiry + channelUpdate_de.cltvExpiryDelta + channelUpdate_cd.cltvExpiryDelta + channelUpdate_bc.cltvExpiryDelta) assert(add.paymentHash === paymentHash) assert(add.onion.payload.length === Sphinx.PaymentPacket.PayloadLength) // let's peel the onion - val Right(DecryptedPacket(bin_b, packet_c, _)) = Sphinx.PaymentPacket.peel(priv_b.privateKey, paymentHash, add.onion) - val payload_b = OnionCodecs.legacyPerHopPayloadCodec.decode(bin_b.toBitVector).require.value - assert(packet_c.payload.length === Sphinx.PaymentPacket.PayloadLength) - assert(payload_b.amtToForward === amount_bc) - assert(payload_b.outgoingCltvValue === expiry_bc) - - val Right(DecryptedPacket(bin_c, packet_d, _)) = Sphinx.PaymentPacket.peel(priv_c.privateKey, paymentHash, packet_c) - val payload_c = OnionCodecs.legacyPerHopPayloadCodec.decode(bin_c.toBitVector).require.value - assert(packet_d.payload.length === Sphinx.PaymentPacket.PayloadLength) - assert(payload_c.amtToForward === amount_cd) - assert(payload_c.outgoingCltvValue === expiry_cd) - - val Right(DecryptedPacket(bin_d, packet_e, _)) = Sphinx.PaymentPacket.peel(priv_d.privateKey, paymentHash, packet_d) - val payload_d = OnionCodecs.legacyPerHopPayloadCodec.decode(bin_d.toBitVector).require.value - assert(packet_e.payload.length === Sphinx.PaymentPacket.PayloadLength) - assert(payload_d.amtToForward === amount_de) - assert(payload_d.outgoingCltvValue === expiry_de) - - val Right(DecryptedPacket(bin_e, packet_random, _)) = Sphinx.PaymentPacket.peel(priv_e.privateKey, paymentHash, packet_e) - val payload_e = OnionCodecs.legacyPerHopPayloadCodec.decode(bin_e.toBitVector).require.value - assert(packet_random.payload.length === Sphinx.PaymentPacket.PayloadLength) - assert(payload_e.amtToForward === finalAmountMsat) - assert(payload_e.outgoingCltvValue === finalExpiry) + testPeelOnion(add.onion) } test("build a command with no hops") { - val (add, _) = buildCommand(UUID.randomUUID(), paymentHash, hops.take(1), LegacyPayload(finalAmountMsat, finalExpiry)) - + val (add, _) = buildCommand(UUID.randomUUID(), paymentHash, hops.take(1), FinalLegacyPayload(finalAmountMsat, finalExpiry)) assert(add.amount === finalAmountMsat) assert(add.cltvExpiry === finalExpiry) assert(add.paymentHash === paymentHash) @@ -170,10 +136,10 @@ class HtlcGenerationSpec extends FunSuite with BeforeAndAfterAll { // let's peel the onion val Right(DecryptedPacket(bin_b, packet_random, _)) = Sphinx.PaymentPacket.peel(priv_b.privateKey, paymentHash, add.onion) - val payload_b = OnionCodecs.legacyPerHopPayloadCodec.decode(bin_b.toBitVector).require.value + val payload_b = OnionCodecs.relayPerHopPayloadCodec.decode(bin_b.toBitVector).require.value assert(packet_random.payload.length === Sphinx.PaymentPacket.PayloadLength) - assert(payload_b.amtToForward === finalAmountMsat) - assert(payload_b.outgoingCltvValue === finalExpiry) + assert(payload_b.amountToForward === finalAmountMsat) + assert(payload_b.outgoingCltv === finalExpiry) } } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala index 7bb56dd77b..957dbdef11 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentLifecycleSpec.scala @@ -34,6 +34,7 @@ import fr.acinq.eclair.payment.PaymentLifecycle._ import fr.acinq.eclair.router.Announcements.{makeChannelUpdate, makeNodeAnnouncement} import fr.acinq.eclair.router._ import fr.acinq.eclair.transactions.Scripts +import fr.acinq.eclair.wire.Onion.FinalLegacyPayload import fr.acinq.eclair.wire._ import scodec.bits.HexStringSyntax @@ -62,7 +63,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val CurrentState(_, WAITING_FOR_REQUEST) = monitor.expectMsgClass(classOf[CurrentState[_]]) // pre-computed route going from A to D - val request = SendPaymentToRoute(defaultPaymentHash, Seq(a, b, c, d), LegacyPayload(defaultAmountMsat, defaultExpiry)) + val request = SendPaymentToRoute(defaultPaymentHash, Seq(a, b, c, d), FinalLegacyPayload(defaultAmountMsat, defaultExpiry)) sender.send(paymentFSM, request) val Transition(_, WAITING_FOR_REQUEST, WAITING_FOR_ROUTE) = monitor.expectMsgClass(classOf[Transition[_]]) @@ -88,7 +89,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { paymentFSM ! SubscribeTransitionCallBack(monitor.ref) val CurrentState(_, WAITING_FOR_REQUEST) = monitor.expectMsgClass(classOf[CurrentState[_]]) - val request = SendPayment(defaultPaymentHash, f, LegacyPayload(defaultAmountMsat, defaultExpiry), maxAttempts = 5) + val request = SendPayment(defaultPaymentHash, f, FinalLegacyPayload(defaultAmountMsat, defaultExpiry), maxAttempts = 5) sender.send(paymentFSM, request) val Transition(_, WAITING_FOR_REQUEST, WAITING_FOR_ROUTE) = monitor.expectMsgClass(classOf[Transition[_]]) val routeRequest = routerForwarder.expectMsgType[RouteRequest] @@ -111,7 +112,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { paymentFSM ! SubscribeTransitionCallBack(monitor.ref) val CurrentState(_, WAITING_FOR_REQUEST) = monitor.expectMsgClass(classOf[CurrentState[_]]) - val request = SendPayment(randomBytes32, d, LegacyPayload(defaultAmountMsat, defaultExpiry), maxAttempts = 5, routeParams = Some(RouteParams(randomize = false, maxFeeBase = 100 msat, maxFeePct = 0.0, routeMaxLength = 20, routeMaxCltv = CltvExpiryDelta(2016), ratios = None))) + val request = SendPayment(randomBytes32, d, FinalLegacyPayload(defaultAmountMsat, defaultExpiry), maxAttempts = 5, routeParams = Some(RouteParams(randomize = false, maxFeeBase = 100 msat, maxFeePct = 0.0, routeMaxLength = 20, routeMaxCltv = CltvExpiryDelta(2016), ratios = None))) sender.send(paymentFSM, request) val Transition(_, WAITING_FOR_REQUEST, WAITING_FOR_ROUTE) = monitor.expectMsgClass(classOf[Transition[_]]) @@ -134,7 +135,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { paymentFSM ! SubscribeTransitionCallBack(monitor.ref) val CurrentState(_, WAITING_FOR_REQUEST) = monitor.expectMsgClass(classOf[CurrentState[_]]) - val request = SendPayment(defaultPaymentHash, d, LegacyPayload(defaultAmountMsat, defaultExpiry), maxAttempts = 2) + val request = SendPayment(defaultPaymentHash, d, FinalLegacyPayload(defaultAmountMsat, defaultExpiry), maxAttempts = 2) sender.send(paymentFSM, request) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.PENDING)) @@ -177,7 +178,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { paymentFSM ! SubscribeTransitionCallBack(monitor.ref) val CurrentState(_, WAITING_FOR_REQUEST) = monitor.expectMsgClass(classOf[CurrentState[_]]) - val request = SendPayment(randomBytes32, d, LegacyPayload(defaultAmountMsat, defaultExpiry), maxAttempts = 2) + val request = SendPayment(randomBytes32, d, FinalLegacyPayload(defaultAmountMsat, defaultExpiry), maxAttempts = 2) sender.send(paymentFSM, request) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.PENDING)) @@ -210,7 +211,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { paymentFSM ! SubscribeTransitionCallBack(monitor.ref) val CurrentState(_, WAITING_FOR_REQUEST) = monitor.expectMsgClass(classOf[CurrentState[_]]) - val request = SendPayment(defaultPaymentHash, d, LegacyPayload(defaultAmountMsat, defaultExpiry), maxAttempts = 2) + val request = SendPayment(defaultPaymentHash, d, FinalLegacyPayload(defaultAmountMsat, defaultExpiry), maxAttempts = 2) sender.send(paymentFSM, request) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.PENDING)) @@ -241,7 +242,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { paymentFSM ! SubscribeTransitionCallBack(monitor.ref) val CurrentState(_, WAITING_FOR_REQUEST) = monitor.expectMsgClass(classOf[CurrentState[_]]) - val request = SendPayment(randomBytes32, d, LegacyPayload(defaultAmountMsat, defaultExpiry), maxAttempts = 2) + val request = SendPayment(randomBytes32, d, FinalLegacyPayload(defaultAmountMsat, defaultExpiry), maxAttempts = 2) sender.send(paymentFSM, request) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE) val WaitingForRoute(_, _, Nil) = paymentFSM.stateData @@ -281,7 +282,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { paymentFSM ! SubscribeTransitionCallBack(monitor.ref) val CurrentState(_, WAITING_FOR_REQUEST) = monitor.expectMsgClass(classOf[CurrentState[_]]) - val request = SendPayment(randomBytes32, d, LegacyPayload(defaultAmountMsat, defaultExpiry), maxAttempts = 5) + val request = SendPayment(randomBytes32, d, FinalLegacyPayload(defaultAmountMsat, defaultExpiry), maxAttempts = 5) sender.send(paymentFSM, request) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.PENDING)) @@ -343,7 +344,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { paymentFSM ! SubscribeTransitionCallBack(monitor.ref) val CurrentState(_, WAITING_FOR_REQUEST) = monitor.expectMsgClass(classOf[CurrentState[_]]) - val request = SendPayment(randomBytes32, d, LegacyPayload(defaultAmountMsat, defaultExpiry), maxAttempts = 2) + val request = SendPayment(randomBytes32, d, FinalLegacyPayload(defaultAmountMsat, defaultExpiry), maxAttempts = 2) sender.send(paymentFSM, request) awaitCond(paymentFSM.stateName == WAITING_FOR_ROUTE && paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.PENDING)) @@ -391,7 +392,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { paymentFSM ! SubscribeTransitionCallBack(monitor.ref) val CurrentState(_, WAITING_FOR_REQUEST) = monitor.expectMsgClass(classOf[CurrentState[_]]) - val request = SendPayment(defaultPaymentHash, d, LegacyPayload(defaultAmountMsat, defaultExpiry), maxAttempts = 5) + val request = SendPayment(defaultPaymentHash, d, FinalLegacyPayload(defaultAmountMsat, defaultExpiry), maxAttempts = 5) sender.send(paymentFSM, request) val Transition(_, WAITING_FOR_REQUEST, WAITING_FOR_ROUTE) = monitor.expectMsgClass(classOf[Transition[_]]) val Transition(_, WAITING_FOR_ROUTE, WAITING_FOR_PAYMENT_COMPLETE) = monitor.expectMsgClass(classOf[Transition[_]]) @@ -399,9 +400,9 @@ class PaymentLifecycleSpec extends BaseRouterSpec { sender.send(paymentFSM, UpdateFulfillHtlc(ByteVector32.Zeroes, 0, defaultPaymentHash)) val paymentOK = sender.expectMsgType[PaymentSucceeded] - val PaymentSent(_, request.paymentOptions.finalAmount, fee, request.paymentHash, paymentOK.paymentPreimage, _, _) = eventListener.expectMsgType[PaymentSent] + val PaymentSent(_, request.finalPayload.amount, fee, request.paymentHash, paymentOK.paymentPreimage, _, _) = eventListener.expectMsgType[PaymentSent] assert(fee > 0.msat) - assert(fee === paymentOK.amount - request.paymentOptions.finalAmount) + assert(fee === paymentOK.amount - request.finalPayload.amount) awaitCond(paymentDb.getOutgoingPayment(id).exists(_.status == OutgoingPaymentStatus.SUCCEEDED)) } @@ -441,7 +442,7 @@ class PaymentLifecycleSpec extends BaseRouterSpec { val CurrentState(_, WAITING_FOR_REQUEST) = monitor.expectMsgClass(classOf[CurrentState[_]]) // we send a payment to G which is just after the - val request = SendPayment(defaultPaymentHash, g, LegacyPayload(defaultAmountMsat, defaultExpiry), maxAttempts = 5) + val request = SendPayment(defaultPaymentHash, g, FinalLegacyPayload(defaultAmountMsat, defaultExpiry), maxAttempts = 5) sender.send(paymentFSM, request) // the route will be A -> B -> G where B -> G has a channel_update with fees=0 @@ -451,13 +452,13 @@ class PaymentLifecycleSpec extends BaseRouterSpec { sender.send(paymentFSM, UpdateFulfillHtlc(ByteVector32.Zeroes, 0, defaultPaymentHash)) val paymentOK = sender.expectMsgType[PaymentSucceeded] - val PaymentSent(_, request.paymentOptions.finalAmount, fee, request.paymentHash, paymentOK.paymentPreimage, _, _) = eventListener.expectMsgType[PaymentSent] + val PaymentSent(_, request.finalPayload.amount, fee, request.paymentHash, paymentOK.paymentPreimage, _, _) = eventListener.expectMsgType[PaymentSent] // during the route computation the fees were treated as if they were 1msat but when sending the onion we actually put zero // NB: A -> B doesn't pay fees because it's our direct neighbor // NB: B -> G doesn't asks for fees at all assert(fee === 0.msat) - assert(fee === paymentOK.amount - request.paymentOptions.finalAmount) + assert(fee === paymentOK.amount - request.finalPayload.amount) } test("filter errors properly") { _ => diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/RelayerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/RelayerSpec.scala index cb89981103..299ad3058f 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/RelayerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/RelayerSpec.scala @@ -21,13 +21,16 @@ import java.util.UUID import akka.actor.{ActorRef, Status} import akka.testkit.TestProbe import fr.acinq.bitcoin.ByteVector32 +import fr.acinq.bitcoin.Crypto.PublicKey import fr.acinq.eclair.channel._ import fr.acinq.eclair.crypto.Sphinx -import fr.acinq.eclair.payment.PaymentLifecycle.{LegacyPayload, buildCommand, buildOnion} +import fr.acinq.eclair.payment.PaymentLifecycle.{buildCommand, buildOnion} import fr.acinq.eclair.router.Announcements +import fr.acinq.eclair.wire.Onion.{FinalLegacyPayload, FinalTlvPayload, PerHopPayload, RelayTlvPayload} import fr.acinq.eclair.wire._ -import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, LongToBtcAmount, ShortChannelId, TestConstants, TestkitBaseClass, UInt64, nodeFee, randomBytes32} +import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, LongToBtcAmount, ShortChannelId, TestConstants, TestkitBaseClass, UInt64, nodeFee, randomBytes32, randomKey} import org.scalatest.Outcome +import scodec.Attempt import scodec.bits.ByteVector import scala.concurrent.duration._ @@ -41,6 +44,7 @@ class RelayerSpec extends TestkitBaseClass { // let's reuse the existing test data import HtlcGenerationSpec._ + import RelayerSpec._ case class FixtureParam(relayer: ActorRef, register: TestProbe, paymentHandler: TestProbe) @@ -62,7 +66,7 @@ class RelayerSpec extends TestkitBaseClass { val sender = TestProbe() // we use this to build a valid onion - val (cmd, _) = buildCommand(UUID.randomUUID(), paymentHash, hops, LegacyPayload(finalAmountMsat, finalExpiry)) + val (cmd, _) = buildCommand(UUID.randomUUID(), paymentHash, hops, FinalLegacyPayload(finalAmountMsat, finalExpiry)) // and then manually build an htlc val add_ab = UpdateAddHtlc(channelId = channelId_ab, id = 123456, cmd.amount, cmd.paymentHash, cmd.cltvExpiry, cmd.onion) relayer ! LocalChannelUpdate(null, channelId_bc, channelUpdate_bc.shortChannelId, c, None, channelUpdate_bc, makeCommitments(channelId_bc)) @@ -84,11 +88,12 @@ class RelayerSpec extends TestkitBaseClass { import fr.acinq.eclair.wire.OnionTlv._ val sender = TestProbe() - val finalPayload: Seq[OnionPerHopPayload] = TlvStream[OnionTlv](AmountToForward(finalAmountMsat), OutgoingCltv(finalExpiry)) :: Nil + // Use tlv payloads for all hops (final and intermediate) + val finalPayload: Seq[PerHopPayload] = FinalTlvPayload(TlvStream[OnionTlv](AmountToForward(finalAmountMsat), OutgoingCltv(finalExpiry))) :: Nil val (firstAmountMsat, firstExpiry, payloads) = hops.drop(1).reverse.foldLeft((finalAmountMsat, finalExpiry, finalPayload)) { case ((amountMsat, expiry, currentPayloads), hop) => val nextFee = nodeFee(hop.lastUpdate.feeBaseMsat, hop.lastUpdate.feeProportionalMillionths, amountMsat) - val payload: OnionPerHopPayload = TlvStream[OnionTlv](AmountToForward(amountMsat), OutgoingCltv(expiry), OutgoingChannelId(hop.lastUpdate.shortChannelId)) + val payload = RelayTlvPayload(TlvStream[OnionTlv](AmountToForward(amountMsat), OutgoingCltv(expiry), OutgoingChannelId(hop.lastUpdate.shortChannelId))) (amountMsat + nextFee, expiry + hop.lastUpdate.cltvExpiryDelta, payload +: currentPayloads) } val Sphinx.PacketAndSecrets(onion, _) = buildOnion(hops.map(_.nextNodeId), payloads, paymentHash) @@ -112,7 +117,7 @@ class RelayerSpec extends TestkitBaseClass { val sender = TestProbe() // we use this to build a valid onion - val (cmd, _) = buildCommand(UUID.randomUUID(), paymentHash, hops, LegacyPayload(finalAmountMsat, finalExpiry)) + val (cmd, _) = buildCommand(UUID.randomUUID(), paymentHash, hops, FinalLegacyPayload(finalAmountMsat, finalExpiry)) // and then manually build an htlc val add_ab = UpdateAddHtlc(channelId = channelId_ab, id = 123456, cmd.amount, cmd.paymentHash, cmd.cltvExpiry, cmd.onion) @@ -156,7 +161,7 @@ class RelayerSpec extends TestkitBaseClass { import f._ val sender = TestProbe() - val (cmd, _) = buildCommand(UUID.randomUUID(), paymentHash, hops.take(1), LegacyPayload(finalAmountMsat, finalExpiry)) + val (cmd, _) = buildCommand(UUID.randomUUID(), paymentHash, hops.take(1), FinalLegacyPayload(finalAmountMsat, finalExpiry)) val add_ab = UpdateAddHtlc(channelId = channelId_ab, id = 123456, cmd.amount, cmd.paymentHash, cmd.cltvExpiry, cmd.onion) sender.send(relayer, ForwardAdd(add_ab)) @@ -172,7 +177,7 @@ class RelayerSpec extends TestkitBaseClass { val sender = TestProbe() // we use this to build a valid onion - val (cmd, _) = buildCommand(UUID.randomUUID(), paymentHash, hops, LegacyPayload(finalAmountMsat, finalExpiry)) + val (cmd, _) = buildCommand(UUID.randomUUID(), paymentHash, hops, FinalLegacyPayload(finalAmountMsat, finalExpiry)) // and then manually build an htlc val add_ab = UpdateAddHtlc(channelId = channelId_ab, id = 123456, cmd.amount, cmd.paymentHash, cmd.cltvExpiry, cmd.onion) @@ -192,7 +197,7 @@ class RelayerSpec extends TestkitBaseClass { val sender = TestProbe() // we use this to build a valid onion - val (cmd, _) = buildCommand(UUID.randomUUID(), paymentHash, hops, LegacyPayload(finalAmountMsat, finalExpiry)) + val (cmd, _) = buildCommand(UUID.randomUUID(), paymentHash, hops, FinalLegacyPayload(finalAmountMsat, finalExpiry)) // and then manually build an htlc val add_ab = UpdateAddHtlc(channelId = channelId_ab, id = 123456, cmd.amount, cmd.paymentHash, cmd.cltvExpiry, cmd.onion) relayer ! LocalChannelUpdate(null, channelId_bc, channelUpdate_bc.shortChannelId, c, None, channelUpdate_bc, makeCommitments(channelId_bc)) @@ -219,7 +224,7 @@ class RelayerSpec extends TestkitBaseClass { val sender = TestProbe() // check that payments are sent properly - val (cmd, _) = buildCommand(UUID.randomUUID(), paymentHash, hops, LegacyPayload(finalAmountMsat, finalExpiry)) + val (cmd, _) = buildCommand(UUID.randomUUID(), paymentHash, hops, FinalLegacyPayload(finalAmountMsat, finalExpiry)) val add_ab = UpdateAddHtlc(channelId = channelId_ab, id = 123456, cmd.amount, cmd.paymentHash, cmd.cltvExpiry, cmd.onion) relayer ! LocalChannelUpdate(null, channelId_bc, channelUpdate_bc.shortChannelId, c, None, channelUpdate_bc, makeCommitments(channelId_bc)) @@ -235,7 +240,7 @@ class RelayerSpec extends TestkitBaseClass { // now tell the relayer that the channel is down and try again relayer ! LocalChannelDown(sender.ref, channelId = channelId_bc, shortChannelId = channelUpdate_bc.shortChannelId, remoteNodeId = TestConstants.Bob.nodeParams.nodeId) - val (cmd1, _) = buildCommand(UUID.randomUUID(), randomBytes32, hops, LegacyPayload(finalAmountMsat, finalExpiry)) + val (cmd1, _) = buildCommand(UUID.randomUUID(), randomBytes32, hops, FinalLegacyPayload(finalAmountMsat, finalExpiry)) val add_ab1 = UpdateAddHtlc(channelId = channelId_ab, id = 123456, cmd1.amount, cmd1.paymentHash, cmd1.cltvExpiry, cmd1.onion) sender.send(relayer, ForwardAdd(add_ab1)) @@ -252,7 +257,7 @@ class RelayerSpec extends TestkitBaseClass { val sender = TestProbe() // we use this to build a valid onion - val (cmd, _) = buildCommand(UUID.randomUUID(), paymentHash, hops, LegacyPayload(finalAmountMsat, finalExpiry)) + val (cmd, _) = buildCommand(UUID.randomUUID(), paymentHash, hops, FinalLegacyPayload(finalAmountMsat, finalExpiry)) // and then manually build an htlc val add_ab = UpdateAddHtlc(channelId = channelId_ab, id = 123456, cmd.amount, cmd.paymentHash, cmd.cltvExpiry, cmd.onion) val channelUpdate_bc_disabled = channelUpdate_bc.copy(channelFlags = Announcements.makeChannelFlags(Announcements.isNode1(channelUpdate_bc.channelFlags), enable = false)) @@ -273,7 +278,7 @@ class RelayerSpec extends TestkitBaseClass { val sender = TestProbe() // we use this to build a valid onion - val (cmd, _) = buildCommand(UUID.randomUUID(), paymentHash, hops, LegacyPayload(finalAmountMsat, finalExpiry)) + val (cmd, _) = buildCommand(UUID.randomUUID(), paymentHash, hops, FinalLegacyPayload(finalAmountMsat, finalExpiry)) // and then manually build an htlc with an invalid onion (hmac) val add_ab = UpdateAddHtlc(channelId = channelId_ab, id = 123456, cmd.amount, cmd.paymentHash, cmd.cltvExpiry, cmd.onion.copy(hmac = cmd.onion.hmac.reverse)) relayer ! LocalChannelUpdate(null, channelId_bc, channelUpdate_bc.shortChannelId, c, None, channelUpdate_bc, makeCommitments(channelId_bc)) @@ -304,7 +309,7 @@ class RelayerSpec extends TestkitBaseClass { relayer ! LocalChannelUpdate(null, channelId_bc, channelUpdate_bc.shortChannelId, c, None, channelUpdate_bc, makeCommitments(channelId_bc)) for ((expectedErr, invalidPayload_bc) <- invalidPayloads_bc) { - val Sphinx.PacketAndSecrets(onion, _) = buildOnion(Seq(b, c), Seq(invalidPayload_bc, payload_cd), paymentHash) + val onion = buildTlvOnion(Seq(b, c), Seq(invalidPayload_bc, payload_cd), paymentHash) val add_ab = UpdateAddHtlc(channelId_ab, 123456, amount_ab, paymentHash, expiry_ab, onion) sender.send(relayer, ForwardAdd(add_ab)) @@ -325,7 +330,7 @@ class RelayerSpec extends TestkitBaseClass { val payload_bc = TlvStream[OnionTlv](OutgoingChannelId(channelUpdate_bc.shortChannelId), AmountToForward(amount_bc), OutgoingCltv(expiry_bc)) val payload_cd = TlvStream[OnionTlv](OutgoingChannelId(channelUpdate_cd.shortChannelId), AmountToForward(amount_cd), OutgoingCltv(expiry_cd)) - val Sphinx.PacketAndSecrets(onion, _) = buildOnion(Seq(b, c), Seq(payload_bc, payload_cd), paymentHash) + val onion = buildTlvOnion(Seq(b, c), Seq(payload_bc, payload_cd), paymentHash) val add_ab = UpdateAddHtlc(channelId_ab, 123456, amount_ab, paymentHash, expiry_ab, onion) sender.send(relayer, ForwardAdd(add_ab)) @@ -339,9 +344,9 @@ class RelayerSpec extends TestkitBaseClass { val sender = TestProbe() // we use this to build a valid onion - val paymentOptions = LegacyPayload(channelUpdate_bc.htlcMinimumMsat - (1 msat), finalExpiry) + val finalPayload = FinalLegacyPayload(channelUpdate_bc.htlcMinimumMsat - (1 msat), finalExpiry) val zeroFeeHops = hops.map(hop => hop.copy(lastUpdate = hop.lastUpdate.copy(feeBaseMsat = 0 msat, feeProportionalMillionths = 0))) - val (cmd, _) = buildCommand(UUID.randomUUID(), paymentHash, zeroFeeHops, paymentOptions) + val (cmd, _) = buildCommand(UUID.randomUUID(), paymentHash, zeroFeeHops, finalPayload) // and then manually build an htlc val add_ab = UpdateAddHtlc(channelId = channelId_ab, id = 123456, cmd.amount, cmd.paymentHash, cmd.cltvExpiry, cmd.onion) relayer ! LocalChannelUpdate(null, channelId_bc, channelUpdate_bc.shortChannelId, c, None, channelUpdate_bc, makeCommitments(channelId_bc)) @@ -361,7 +366,7 @@ class RelayerSpec extends TestkitBaseClass { val sender = TestProbe() val hops1 = hops.updated(1, hops(1).copy(lastUpdate = hops(1).lastUpdate.copy(cltvExpiryDelta = CltvExpiryDelta(0)))) - val (cmd, _) = buildCommand(UUID.randomUUID(), paymentHash, hops1, LegacyPayload(finalAmountMsat, finalExpiry)) + val (cmd, _) = buildCommand(UUID.randomUUID(), paymentHash, hops1, FinalLegacyPayload(finalAmountMsat, finalExpiry)) // and then manually build an htlc val add_ab = UpdateAddHtlc(channelId = channelId_ab, id = 123456, cmd.amount, cmd.paymentHash, cmd.cltvExpiry, cmd.onion) relayer ! LocalChannelUpdate(null, channelId_bc, channelUpdate_bc.shortChannelId, c, None, channelUpdate_bc, makeCommitments(channelId_bc)) @@ -381,7 +386,7 @@ class RelayerSpec extends TestkitBaseClass { val sender = TestProbe() val hops1 = hops.updated(1, hops(1).copy(lastUpdate = hops(1).lastUpdate.copy(feeBaseMsat = hops(1).lastUpdate.feeBaseMsat / 2))) - val (cmd, _) = buildCommand(UUID.randomUUID(), paymentHash, hops1, LegacyPayload(finalAmountMsat, finalExpiry)) + val (cmd, _) = buildCommand(UUID.randomUUID(), paymentHash, hops1, FinalLegacyPayload(finalAmountMsat, finalExpiry)) // and then manually build an htlc val add_ab = UpdateAddHtlc(channelId = channelId_ab, id = 123456, cmd.amount, cmd.paymentHash, cmd.cltvExpiry, cmd.onion) relayer ! LocalChannelUpdate(null, channelId_bc, channelUpdate_bc.shortChannelId, c, None, channelUpdate_bc, makeCommitments(channelId_bc)) @@ -402,7 +407,7 @@ class RelayerSpec extends TestkitBaseClass { // to simulate this we use a zero-hop route A->B where A is the 'attacker' val hops1 = hops.head :: Nil - val (cmd, _) = buildCommand(UUID.randomUUID(), paymentHash, hops1, LegacyPayload(finalAmountMsat, finalExpiry)) + val (cmd, _) = buildCommand(UUID.randomUUID(), paymentHash, hops1, FinalLegacyPayload(finalAmountMsat, finalExpiry)) // and then manually build an htlc with a wrong expiry val add_ab = UpdateAddHtlc(channelId = channelId_ab, id = 123456, cmd.amount - (1 msat), cmd.paymentHash, cmd.cltvExpiry, cmd.onion) relayer ! LocalChannelUpdate(null, channelId_bc, channelUpdate_bc.shortChannelId, c, None, channelUpdate_bc, makeCommitments(channelId_bc)) @@ -423,7 +428,7 @@ class RelayerSpec extends TestkitBaseClass { // to simulate this we use a zero-hop route A->B where A is the 'attacker' val hops1 = hops.head :: Nil - val (cmd, _) = buildCommand(UUID.randomUUID(), paymentHash, hops1, LegacyPayload(finalAmountMsat, finalExpiry)) + val (cmd, _) = buildCommand(UUID.randomUUID(), paymentHash, hops1, FinalLegacyPayload(finalAmountMsat, finalExpiry)) // and then manually build an htlc with a wrong expiry val add_ab = UpdateAddHtlc(channelId = channelId_ab, id = 123456, cmd.amount, cmd.paymentHash, cmd.cltvExpiry - CltvExpiryDelta(1), cmd.onion) relayer ! LocalChannelUpdate(null, channelId_bc, channelUpdate_bc.shortChannelId, c, None, channelUpdate_bc, makeCommitments(channelId_bc)) @@ -450,7 +455,7 @@ class RelayerSpec extends TestkitBaseClass { val sender = TestProbe() for ((expectedErr, invalidFinalPayload) <- invalidFinalPayloads) { - val Sphinx.PacketAndSecrets(onion, _) = buildOnion(Seq(b), Seq(invalidFinalPayload), paymentHash) + val onion = buildTlvOnion(Seq(b), Seq(invalidFinalPayload), paymentHash) val add_ab = UpdateAddHtlc(channelId_ab, 123456, amount_ab, paymentHash, expiry_ab, onion) sender.send(relayer, ForwardAdd(add_ab)) @@ -559,3 +564,20 @@ class RelayerSpec extends TestkitBaseClass { assert(usableBalances5.size === 1) } } + +object RelayerSpec { + + /** Build onion from arbitrary tlv stream (potentially invalid). */ + def buildTlvOnion(nodes: Seq[PublicKey], payloads: Seq[TlvStream[OnionTlv]], associatedData: ByteVector32): OnionRoutingPacket = { + require(nodes.size == payloads.size) + val sessionKey = randomKey + val payloadsBin: Seq[ByteVector] = payloads + .map(OnionCodecs.tlvPerHopPayloadCodec.encode) + .map { + case Attempt.Successful(bitVector) => bitVector.toByteVector + case Attempt.Failure(cause) => throw new RuntimeException(s"serialization error: $cause") + } + Sphinx.PaymentPacket.create(sessionKey, nodes, payloadsBin, associatedData).packet + } + +} \ No newline at end of file diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/OnionCodecsSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/OnionCodecsSpec.scala index be9bfff386..c2077d0ebb 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/wire/OnionCodecsSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/wire/OnionCodecsSpec.scala @@ -18,11 +18,12 @@ package fr.acinq.eclair.wire import fr.acinq.bitcoin.ByteVector32 import fr.acinq.eclair.UInt64.Conversions._ +import fr.acinq.eclair.wire.Onion.{FinalLegacyPayload, FinalTlvPayload, RelayLegacyPayload, RelayTlvPayload} import fr.acinq.eclair.wire.OnionCodecs._ -import fr.acinq.eclair.wire.OnionPerHopPayload._ import fr.acinq.eclair.wire.OnionTlv._ import fr.acinq.eclair.{CltvExpiry, LongToBtcAmount, ShortChannelId, UInt64} import org.scalatest.FunSuite +import scodec.Attempt import scodec.bits.HexStringSyntax /** @@ -42,18 +43,34 @@ class OnionCodecsSpec extends FunSuite { assert(encoded.toByteVector === bin) } - test("encode/decode fixed-size (legacy) per-hop payload") { + test("encode/decode fixed-size (legacy) relay per-hop payload") { val testCases = Map( - OnionForwardInfo(ShortChannelId(0), 0 msat, CltvExpiry(0)) -> hex"00 0000000000000000 0000000000000000 00000000 000000000000000000000000", - OnionForwardInfo(ShortChannelId(42), 142000 msat, CltvExpiry(500000)) -> hex"00 000000000000002a 0000000000022ab0 0007a120 000000000000000000000000", - OnionForwardInfo(ShortChannelId(561), 1105 msat, CltvExpiry(1729)) -> hex"00 0000000000000231 0000000000000451 000006c1 000000000000000000000000" + RelayLegacyPayload(ShortChannelId(0), 0 msat, CltvExpiry(0)) -> hex"00 0000000000000000 0000000000000000 00000000 000000000000000000000000", + RelayLegacyPayload(ShortChannelId(42), 142000 msat, CltvExpiry(500000)) -> hex"00 000000000000002a 0000000000022ab0 0007a120 000000000000000000000000", + RelayLegacyPayload(ShortChannelId(561), 1105 msat, CltvExpiry(1729)) -> hex"00 0000000000000231 0000000000000451 000006c1 000000000000000000000000" ) for ((expected, bin) <- testCases) { - val OnionPerHopPayload(Right(decoded)) = perHopPayloadCodec.decode(bin.bits).require.value + val decoded = relayPerHopPayloadCodec.decode(bin.bits).require.value assert(decoded === expected) - val encoded = perHopPayloadCodec.encode(expected).require.bytes + val encoded = relayPerHopPayloadCodec.encode(expected).require.bytes + assert(encoded === bin) + } + } + + test("encode/decode fixed-size (legacy) final per-hop payload") { + val testCases = Map( + FinalLegacyPayload(0 msat, CltvExpiry(0)) -> hex"00 0000000000000000 0000000000000000 00000000 000000000000000000000000", + FinalLegacyPayload(142000 msat, CltvExpiry(500000)) -> hex"00 0000000000000000 0000000000022ab0 0007a120 000000000000000000000000", + FinalLegacyPayload(1105 msat, CltvExpiry(1729)) -> hex"00 0000000000000000 0000000000000451 000006c1 000000000000000000000000" + ) + + for ((expected, bin) <- testCases) { + val decoded = finalPerHopPayloadCodec.decode(bin.bits).require.value + assert(decoded === expected) + + val encoded = finalPerHopPayloadCodec.encode(expected).require.bytes assert(encoded === bin) } } @@ -75,27 +92,77 @@ class OnionCodecsSpec extends FunSuite { } } - test("encode/decode variable-length (tlv) per-hop payload") { + test("encode/decode variable-length (tlv) relay per-hop payload") { val testCases = Map( - TlvStream[OnionTlv]() -> hex"00", + TlvStream[OnionTlv](AmountToForward(561 msat), OutgoingCltv(CltvExpiry(42)), OutgoingChannelId(ShortChannelId(1105))) -> hex"11 02020231 04012a 06080000000000000451", + TlvStream[OnionTlv](Seq(AmountToForward(561 msat), OutgoingCltv(CltvExpiry(42)), OutgoingChannelId(ShortChannelId(1105))), Seq(GenericTlv(65535, hex"06c1"))) -> hex"17 02020231 04012a 06080000000000000451 fdffff0206c1" + ) + + for ((expected, bin) <- testCases) { + val decoded = relayPerHopPayloadCodec.decode(bin.bits).require.value + assert(decoded === RelayTlvPayload(expected)) + assert(decoded.amountToForward === 561.msat) + assert(decoded.outgoingCltv === CltvExpiry(42)) + assert(decoded.outgoingChannelId === ShortChannelId(1105)) + + val encoded = relayPerHopPayloadCodec.encode(RelayTlvPayload(expected)).require.bytes + assert(encoded === bin) + } + } + + test("encode/decode variable-length (tlv) final per-hop payload") { + val testCases = Map( + TlvStream[OnionTlv](AmountToForward(561 msat), OutgoingCltv(CltvExpiry(42))) -> hex"07 02020231 04012a", TlvStream[OnionTlv](AmountToForward(561 msat), OutgoingCltv(CltvExpiry(42)), OutgoingChannelId(ShortChannelId(1105))) -> hex"11 02020231 04012a 06080000000000000451", TlvStream[OnionTlv](Seq(AmountToForward(561 msat), OutgoingCltv(CltvExpiry(42))), Seq(GenericTlv(65535, hex"06c1"))) -> hex"0d 02020231 04012a fdffff0206c1" ) for ((expected, bin) <- testCases) { - val OnionPerHopPayload(Left(decoded)) = perHopPayloadCodec.decode(bin.bits).require.value - assert(decoded === expected) + val decoded = finalPerHopPayloadCodec.decode(bin.bits).require.value + assert(decoded === FinalTlvPayload(expected)) + assert(decoded.amount === 561.msat) + assert(decoded.expiry === CltvExpiry(42)) - val encoded = perHopPayloadCodec.encode(expected).require.bytes + val encoded = finalPerHopPayloadCodec.encode(FinalTlvPayload(expected)).require.bytes assert(encoded === bin) } } + test("decode variable-length (tlv) relay per-hop payload missing information") { + val testCases = Seq( + (InvalidOnionPayload(UInt64(2), 0), hex"0d 04012a 06080000000000000451"), // missing amount + (InvalidOnionPayload(UInt64(4), 0), hex"0e 02020231 06080000000000000451"), // missing cltv + (InvalidOnionPayload(UInt64(6), 0), hex"07 02020231 04012a") // missing channel id + ) + + for ((expectedErr, bin) <- testCases) { + val decoded = relayPerHopPayloadCodec.decode(bin.bits) + assert(decoded.isFailure) + val Attempt.Failure(err: MissingRequiredTlv) = decoded + assert(err.failureMessage === expectedErr) + } + } + + test("decode variable-length (tlv) final per-hop payload missing information") { + val testCases = Seq( + (InvalidOnionPayload(UInt64(2), 0), hex"03 04012a"), // missing amount + (InvalidOnionPayload(UInt64(4), 0), hex"04 02020231") // missing cltv + ) + + for ((expectedErr, bin) <- testCases) { + val decoded = finalPerHopPayloadCodec.decode(bin.bits) + assert(decoded.isFailure) + val Attempt.Failure(err: MissingRequiredTlv) = decoded + assert(err.failureMessage === expectedErr) + } + } + test("decode invalid per-hop payload") { val testCases = Seq( // Invalid fixed-size (legacy) payload. hex"00 000000000000002a 000000000000002a", // invalid length // Invalid variable-length (tlv) payload. + hex"00", // empty payload is missing required information hex"01", // invalid length hex"01 0000", // invalid length hex"04 0000 2a00", // unknown even types @@ -104,45 +171,9 @@ class OnionCodecsSpec extends FunSuite { ) for (testCase <- testCases) { - assert(perHopPayloadCodec.decode(testCase.bits).isFailure) + assert(relayPerHopPayloadCodec.decode(testCase.bits).isFailure) + assert(finalPerHopPayloadCodec.decode(testCase.bits).isFailure) } } - test("get payment info") { - val legacyPayload: OnionPerHopPayload = OnionForwardInfo(ShortChannelId(550), 561 msat, CltvExpiry(1105)) - assert(legacyPayload.paymentInfo === Right(OnionPaymentInfo(561 msat, CltvExpiry(1105)))) - - val tlvPayload: OnionPerHopPayload = TlvStream[OnionTlv](AmountToForward(561 msat), OutgoingCltv(CltvExpiry(1105))) - assert(tlvPayload.paymentInfo === Right(OnionPaymentInfo(561 msat, CltvExpiry(1105)))) - - val tlvPayloadUnknown: OnionPerHopPayload = TlvStream[OnionTlv](Seq(AmountToForward(561 msat), OutgoingCltv(CltvExpiry(1105))), Seq(GenericTlv(13, hex"2a"))) - assert(tlvPayloadUnknown.paymentInfo === Right(OnionPaymentInfo(561 msat, CltvExpiry(1105)))) - - val tlvPayloadNoCltv: OnionPerHopPayload = TlvStream[OnionTlv](AmountToForward(561 msat)) - assert(tlvPayloadNoCltv.paymentInfo === Left(InvalidOnionPayload(UInt64(4), 0))) - - val tlvPayloadNoAmount: OnionPerHopPayload = TlvStream[OnionTlv](OutgoingCltv(CltvExpiry(1105))) - assert(tlvPayloadNoAmount.paymentInfo === Left(InvalidOnionPayload(UInt64(2), 0))) - } - - test("get forward info") { - val legacyPayload: OnionPerHopPayload = OnionForwardInfo(ShortChannelId(550), 561 msat, CltvExpiry(1105)) - assert(legacyPayload.forwardInfo === Right(OnionForwardInfo(ShortChannelId(550), 561 msat, CltvExpiry(1105)))) - - val tlvPayload: OnionPerHopPayload = TlvStream[OnionTlv](AmountToForward(561 msat), OutgoingCltv(CltvExpiry(1105)), OutgoingChannelId(ShortChannelId(550))) - assert(tlvPayload.forwardInfo === Right(OnionForwardInfo(ShortChannelId(550), 561 msat, CltvExpiry(1105)))) - - val tlvPayloadUnknown: OnionPerHopPayload = TlvStream[OnionTlv](Seq(AmountToForward(561 msat), OutgoingCltv(CltvExpiry(1105)), OutgoingChannelId(ShortChannelId(550))), Seq(GenericTlv(13, hex"2a"))) - assert(tlvPayloadUnknown.forwardInfo === Right(OnionForwardInfo(ShortChannelId(550), 561 msat, CltvExpiry(1105)))) - - val tlvPayloadNoCltv: OnionPerHopPayload = TlvStream[OnionTlv](AmountToForward(561 msat), OutgoingChannelId(ShortChannelId(550))) - assert(tlvPayloadNoCltv.forwardInfo === Left(InvalidOnionPayload(UInt64(4), 0))) - - val tlvPayloadNoAmount: OnionPerHopPayload = TlvStream[OnionTlv](OutgoingCltv(CltvExpiry(1105)), OutgoingChannelId(ShortChannelId(550))) - assert(tlvPayloadNoAmount.forwardInfo === Left(InvalidOnionPayload(UInt64(2), 0))) - - val tlvPayloadNoChannelId: OnionPerHopPayload = TlvStream[OnionTlv](AmountToForward(561 msat), OutgoingCltv(CltvExpiry(1105))) - assert(tlvPayloadNoChannelId.forwardInfo === Left(InvalidOnionPayload(UInt64(6), 0))) - } - } From 75e14f0f7c00b41ef3b40bac067dbd1e81848eb6 Mon Sep 17 00:00:00 2001 From: Bastien Teinturier Date: Thu, 5 Sep 2019 10:08:33 +0200 Subject: [PATCH 10/11] Add variable length onion to supported features. Add more comments. Separate per-hop payload content trait from its encoding. --- .../main/scala/fr/acinq/eclair/Features.scala | 2 +- .../eclair/payment/PaymentInitiator.scala | 1 + .../eclair/payment/PaymentLifecycle.scala | 4 ++-- .../scala/fr/acinq/eclair/wire/Onion.scala | 24 ++++++++++--------- .../scala/fr/acinq/eclair/FeaturesSpec.scala | 5 ++-- 5 files changed, 20 insertions(+), 16 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/Features.scala b/eclair-core/src/main/scala/fr/acinq/eclair/Features.scala index 12e9997955..145961614c 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/Features.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/Features.scala @@ -54,7 +54,7 @@ object Features { * we don't understand (even bits). */ def areSupported(features: BitVector): Boolean = { - val supportedMandatoryFeatures = Set[Long](OPTION_DATA_LOSS_PROTECT_MANDATORY) + val supportedMandatoryFeatures = Set[Long](OPTION_DATA_LOSS_PROTECT_MANDATORY, VARIABLE_LENGTH_ONION_MANDATORY) val reversed = features.reverse for (i <- 0L until reversed.length by 2) { if (reversed.get(i) && !supportedMandatoryFeatures.contains(i)) return false diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentInitiator.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentInitiator.scala index 13035aa3ef..a3db9e6dc8 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentInitiator.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentInitiator.scala @@ -39,6 +39,7 @@ class PaymentInitiator(nodeParams: NodeParams, router: ActorRef, register: Actor // We add one block in order to not have our htlc fail when a new block has just been found. val finalExpiry = (p.finalExpiryDelta + 1).toCltvExpiry val payFsm = context.actorOf(PaymentLifecycle.props(nodeParams, paymentId, router, register)) + // NB: we only generate legacy payment onions for now for maximum compatibility. p.predefinedRoute match { case Nil => payFsm forward SendPayment(p.paymentHash, p.targetNodeId, FinalLegacyPayload(p.amount, finalExpiry), p.maxAttempts, p.assistedRoutes, p.routeParams) case hops => payFsm forward SendPaymentToRoute(p.paymentHash, hops, FinalLegacyPayload(p.amount, finalExpiry)) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentLifecycle.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentLifecycle.scala index 0de607d64f..d06566ca7e 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentLifecycle.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentLifecycle.scala @@ -229,10 +229,10 @@ object PaymentLifecycle { require(nodes.size == payloads.size) val sessionKey = randomKey val payloadsBin: Seq[ByteVector] = payloads - .map({ + .map { case p: FinalPayload => OnionCodecs.finalPerHopPayloadCodec.encode(p) case p: RelayPayload => OnionCodecs.relayPerHopPayloadCodec.encode(p) - }) + } .map { case Attempt.Successful(bitVector) => bitVector.toByteVector case Attempt.Failure(cause) => throw new RuntimeException(s"serialization error: $cause") diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/Onion.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/Onion.scala index ee37cdf296..f717249802 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/Onion.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/Onion.scala @@ -49,19 +49,21 @@ object Onion { import OnionTlv._ - /** Per-hop payload from an HTLC's payment onion (after decryption and decoding). */ - sealed trait PerHopPayload + sealed trait PerHopPayloadEncoding /** Legacy fixed-size 65-bytes onion payload. */ - sealed trait LegacyPayload extends PerHopPayload + sealed trait LegacyPayload extends PerHopPayloadEncoding /** Variable-length onion payload with optional additional tlv records. */ - sealed trait TlvPayload extends PerHopPayload { - val records: TlvStream[OnionTlv] + sealed trait TlvPayload extends PerHopPayloadEncoding { + def records: TlvStream[OnionTlv] } + /** Per-hop payload from an HTLC's payment onion (after decryption and decoding). */ + sealed trait PerHopPayload + /** Per-hop payload for an intermediate node. */ - sealed trait RelayPayload extends PerHopPayload { + sealed trait RelayPayload extends PerHopPayload with PerHopPayloadEncoding { /** Amount to forward to the next node. */ val amountToForward: MilliSatoshi /** CLTV value to use for the HTLC offered to the next node. */ @@ -71,22 +73,22 @@ object Onion { } /** Per-hop payload for a final node. */ - sealed trait FinalPayload extends PerHopPayload { + sealed trait FinalPayload extends PerHopPayload with PerHopPayloadEncoding { val amount: MilliSatoshi val expiry: CltvExpiry } - case class RelayLegacyPayload(outgoingChannelId: ShortChannelId, amountToForward: MilliSatoshi, outgoingCltv: CltvExpiry) extends LegacyPayload with RelayPayload + case class RelayLegacyPayload(outgoingChannelId: ShortChannelId, amountToForward: MilliSatoshi, outgoingCltv: CltvExpiry) extends RelayPayload with LegacyPayload - case class FinalLegacyPayload(amount: MilliSatoshi, expiry: CltvExpiry) extends LegacyPayload with FinalPayload + case class FinalLegacyPayload(amount: MilliSatoshi, expiry: CltvExpiry) extends FinalPayload with LegacyPayload - case class RelayTlvPayload(records: TlvStream[OnionTlv]) extends TlvPayload with RelayPayload { + case class RelayTlvPayload(records: TlvStream[OnionTlv]) extends RelayPayload with TlvPayload { override val amountToForward = records.get[AmountToForward].get.amount override val outgoingCltv = records.get[OutgoingCltv].get.cltv override val outgoingChannelId = records.get[OutgoingChannelId].get.shortChannelId } - case class FinalTlvPayload(records: TlvStream[OnionTlv]) extends TlvPayload with FinalPayload { + case class FinalTlvPayload(records: TlvStream[OnionTlv]) extends FinalPayload with TlvPayload { override val amount = records.get[AmountToForward].get.amount override val expiry = records.get[OutgoingCltv].get.cltv } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/FeaturesSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/FeaturesSpec.scala index 71733773b3..fa1c4b5ed0 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/FeaturesSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/FeaturesSpec.scala @@ -35,8 +35,8 @@ class FeaturesSpec extends FunSuite { assert(hasFeature(hex"02", Features.OPTION_DATA_LOSS_PROTECT_OPTIONAL)) } - test("'initial_routing_sync' and 'data_loss_protect' feature") { - val features = hex"0a" + test("'initial_routing_sync', 'data_loss_protect' and 'variable_length_onion' features") { + val features = hex"010a" assert(areSupported(features) && hasFeature(features, OPTION_DATA_LOSS_PROTECT_OPTIONAL) && hasFeature(features, INITIAL_ROUTING_SYNC_BIT_OPTIONAL)) } @@ -52,6 +52,7 @@ class FeaturesSpec extends FunSuite { assert(areSupported(ByteVector.fromLong(1L << OPTION_DATA_LOSS_PROTECT_MANDATORY))) assert(areSupported(ByteVector.fromLong(1L << OPTION_DATA_LOSS_PROTECT_OPTIONAL))) assert(areSupported(ByteVector.fromLong(1L << VARIABLE_LENGTH_ONION_OPTIONAL))) + assert(areSupported(ByteVector.fromLong(1L << VARIABLE_LENGTH_ONION_MANDATORY))) assert(areSupported(hex"0b")) assert(!areSupported(hex"14")) assert(!areSupported(hex"0141")) From 3cf75a98c264df0716a7d7230558cbfe4a4911e9 Mon Sep 17 00:00:00 2001 From: Bastien Teinturier Date: Thu, 5 Sep 2019 16:55:02 +0200 Subject: [PATCH 11/11] fixup! Add variable length onion to supported features. Add more comments. Separate per-hop payload content trait from its encoding. --- .../fr/acinq/eclair/payment/Relayer.scala | 2 +- .../scala/fr/acinq/eclair/wire/Onion.scala | 18 +++++++++--------- .../scala/fr/acinq/eclair/FeaturesSpec.scala | 2 +- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Relayer.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Relayer.scala index 3e747a1eee..1ee391d05e 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Relayer.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Relayer.scala @@ -235,7 +235,7 @@ object Relayer extends Logging { case Right(p@Sphinx.DecryptedPacket(payload, nextPacket, _)) => val codec = if (p.isLastPacket) OnionCodecs.finalPerHopPayloadCodec else OnionCodecs.relayPerHopPayloadCodec codec.decode(payload.bits) match { - case Attempt.Successful(DecodeResult(_: Onion.TlvPayload, _)) if !Features.hasVariableLengthOnion(features) => Left(InvalidRealm) + case Attempt.Successful(DecodeResult(_: Onion.TlvFormat, _)) if !Features.hasVariableLengthOnion(features) => Left(InvalidRealm) case Attempt.Successful(DecodeResult(perHopPayload, remainder)) => if (remainder.nonEmpty) { logger.warn(s"${remainder.length} bits remaining after per-hop payload decoding: there might be an issue with the onion codec") diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/Onion.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/Onion.scala index f717249802..d1ff115951 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/Onion.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/Onion.scala @@ -49,13 +49,13 @@ object Onion { import OnionTlv._ - sealed trait PerHopPayloadEncoding + sealed trait PerHopPayloadFormat /** Legacy fixed-size 65-bytes onion payload. */ - sealed trait LegacyPayload extends PerHopPayloadEncoding + sealed trait LegacyFormat extends PerHopPayloadFormat /** Variable-length onion payload with optional additional tlv records. */ - sealed trait TlvPayload extends PerHopPayloadEncoding { + sealed trait TlvFormat extends PerHopPayloadFormat { def records: TlvStream[OnionTlv] } @@ -63,7 +63,7 @@ object Onion { sealed trait PerHopPayload /** Per-hop payload for an intermediate node. */ - sealed trait RelayPayload extends PerHopPayload with PerHopPayloadEncoding { + sealed trait RelayPayload extends PerHopPayload with PerHopPayloadFormat { /** Amount to forward to the next node. */ val amountToForward: MilliSatoshi /** CLTV value to use for the HTLC offered to the next node. */ @@ -73,22 +73,22 @@ object Onion { } /** Per-hop payload for a final node. */ - sealed trait FinalPayload extends PerHopPayload with PerHopPayloadEncoding { + sealed trait FinalPayload extends PerHopPayload with PerHopPayloadFormat { val amount: MilliSatoshi val expiry: CltvExpiry } - case class RelayLegacyPayload(outgoingChannelId: ShortChannelId, amountToForward: MilliSatoshi, outgoingCltv: CltvExpiry) extends RelayPayload with LegacyPayload + case class RelayLegacyPayload(outgoingChannelId: ShortChannelId, amountToForward: MilliSatoshi, outgoingCltv: CltvExpiry) extends RelayPayload with LegacyFormat - case class FinalLegacyPayload(amount: MilliSatoshi, expiry: CltvExpiry) extends FinalPayload with LegacyPayload + case class FinalLegacyPayload(amount: MilliSatoshi, expiry: CltvExpiry) extends FinalPayload with LegacyFormat - case class RelayTlvPayload(records: TlvStream[OnionTlv]) extends RelayPayload with TlvPayload { + case class RelayTlvPayload(records: TlvStream[OnionTlv]) extends RelayPayload with TlvFormat { override val amountToForward = records.get[AmountToForward].get.amount override val outgoingCltv = records.get[OutgoingCltv].get.cltv override val outgoingChannelId = records.get[OutgoingChannelId].get.shortChannelId } - case class FinalTlvPayload(records: TlvStream[OnionTlv]) extends FinalPayload with TlvPayload { + case class FinalTlvPayload(records: TlvStream[OnionTlv]) extends FinalPayload with TlvFormat { override val amount = records.get[AmountToForward].get.amount override val expiry = records.get[OutgoingCltv].get.cltv } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/FeaturesSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/FeaturesSpec.scala index fa1c4b5ed0..30f06c385d 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/FeaturesSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/FeaturesSpec.scala @@ -37,7 +37,7 @@ class FeaturesSpec extends FunSuite { test("'initial_routing_sync', 'data_loss_protect' and 'variable_length_onion' features") { val features = hex"010a" - assert(areSupported(features) && hasFeature(features, OPTION_DATA_LOSS_PROTECT_OPTIONAL) && hasFeature(features, INITIAL_ROUTING_SYNC_BIT_OPTIONAL)) + assert(areSupported(features) && hasFeature(features, OPTION_DATA_LOSS_PROTECT_OPTIONAL) && hasFeature(features, INITIAL_ROUTING_SYNC_BIT_OPTIONAL) && hasFeature(features, VARIABLE_LENGTH_ONION_MANDATORY)) } test("'variable_length_onion' feature") {