diff --git a/eclair-core/src/main/resources/reference.conf b/eclair-core/src/main/resources/reference.conf index df16913553..ef4b8bc38e 100644 --- a/eclair-core/src/main/resources/reference.conf +++ b/eclair-core/src/main/resources/reference.conf @@ -54,10 +54,21 @@ eclair { trampoline_payment = disabled keysend = disabled } + channel-types { + // The following parameter contains the list of all supported lightning transaction formats (order by preference). + // You can reorder this list or remove entries to change what types of channels can be created. + commitment-format = [ + // standard lightning channels with option_static_remotekey applied (simplified funds recovery in case of data loss) + "static_remotekey", + // standard lightning channels (as defined in v1.0 of the LN specification) + "standard" + ] + } override-features = [ // optional per-node features # { # nodeid = "02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", # features { } + # channel-types { } # } ] sync-whitelist = [] // a list of public keys; if non-empty, we will only do the initial sync with those peers @@ -321,6 +332,6 @@ akka { backend.min-nr-of-members = 1 frontend.min-nr-of-members = 0 } - seed-nodes = [ "akka://eclair-node@127.0.0.1:25520" ] + seed-nodes = ["akka://eclair-node@127.0.0.1:25520"] } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala b/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala index 4601e65cb4..df49a6befd 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala @@ -21,7 +21,7 @@ import fr.acinq.bitcoin.Crypto.PublicKey import fr.acinq.bitcoin.{Block, ByteVector32, Crypto, Satoshi} import fr.acinq.eclair.Setup.Seeds import fr.acinq.eclair.blockchain.fee._ -import fr.acinq.eclair.channel.Channel +import fr.acinq.eclair.channel.{Channel, ChannelType} import fr.acinq.eclair.crypto.Noise.KeyPair import fr.acinq.eclair.crypto.keymanager.{ChannelKeyManager, NodeKeyManager} import fr.acinq.eclair.db._ @@ -52,7 +52,8 @@ case class NodeParams(nodeKeyManager: NodeKeyManager, color: Color, publicAddresses: List[NodeAddress], features: Features, - private val overrideFeatures: Map[PublicKey, Features], + channelTypes: List[ChannelType], + private val overrideFeatures: Map[PublicKey, (Features, List[ChannelType])], syncWhitelist: Set[PublicKey], pluginParams: Seq[PluginParams], dustLimit: Satoshi, @@ -100,7 +101,16 @@ case class NodeParams(nodeKeyManager: NodeKeyManager, def currentBlockHeight: Long = blockCount.get - def featuresFor(nodeId: PublicKey): Features = overrideFeatures.getOrElse(nodeId, features) + def featuresFor(nodeId: PublicKey): Features = overrideFeatures.get(nodeId) match { + case Some((featuresOverride, _)) if featuresOverride.activated.nonEmpty => featuresOverride + case _ => features + } + + def channelTypesFor(nodeId: PublicKey): List[ChannelType] = overrideFeatures.get(nodeId) match { + case Some((_, channelTypesOverride)) if channelTypesOverride.nonEmpty => channelTypesOverride + case _ => channelTypes + } + } object NodeParams extends Logging { @@ -247,6 +257,28 @@ object NodeParams extends Logging { val features = Features.fromConfiguration(config) validateFeatures(features) + def parseChannelTypes(config: Config): List[ChannelType] = { + if (!config.hasPath("channel-types")) { + Nil + } else { + config.getStringList("channel-types.commitment-format").asScala.toList.map { + case "standard" => ChannelType(Features.empty) + case "static_remotekey" => ChannelType(Features(Features.StaticRemoteKey -> FeatureSupport.Optional)) + case "anchor_outputs" => ChannelType(Features(Features.StaticRemoteKey -> FeatureSupport.Optional, Features.AnchorOutputs -> FeatureSupport.Optional)) + case unknown => throw new RuntimeException(s"unsupported channel type: $unknown") + } + } + } + + def validateChannelTypes(features: Features, channelTypes: List[ChannelType]): Unit = { + channelTypes.foreach(channelType => channelType.features.activated.keys.foreach(f => + require(features.hasFeature(f), s"feature $f is necessary for channel-type $channelType: you must either enable $f or disable $channelType") + )) + } + + val channelTypes = parseChannelTypes(config) + validateChannelTypes(features, channelTypes) + require(pluginMessageParams.forall(_.feature.mandatory > 128), "Plugin mandatory feature bit is too low, must be > 128") require(pluginMessageParams.forall(_.feature.mandatory % 2 == 0), "Plugin mandatory feature bit is odd, must be even") require(pluginMessageParams.flatMap(_.messageTags).forall(_ > 32768), "Plugin messages tags must be > 32768") @@ -256,11 +288,19 @@ object NodeParams extends Logging { val coreAndPluginFeatures = features.copy(unknown = features.unknown ++ pluginMessageParams.map(_.pluginFeature)) - val overrideFeatures: Map[PublicKey, Features] = config.getConfigList("override-features").asScala.map { e => + val overrideFeatures: Map[PublicKey, (Features, List[ChannelType])] = config.getConfigList("override-features").asScala.map { e => val p = PublicKey(ByteVector.fromValidHex(e.getString("nodeid"))) - val f = Features.fromConfiguration(e) - validateFeatures(f) - p -> f.copy(unknown = f.unknown ++ pluginMessageParams.map(_.pluginFeature)) + val featuresOverride = Features.fromConfiguration(e) match { + case f if f.activated.nonEmpty => f + case _ => features + } + validateFeatures(featuresOverride) + val channelTypesOverride = parseChannelTypes(e) match { + case ct if ct.nonEmpty => ct + case _ => channelTypes + } + validateChannelTypes(featuresOverride, channelTypesOverride) + p -> (featuresOverride.copy(unknown = featuresOverride.unknown ++ pluginMessageParams.map(_.pluginFeature)), channelTypesOverride) }.toMap val syncWhitelist: Set[PublicKey] = config.getStringList("sync-whitelist").asScala.map(s => PublicKey(ByteVector.fromValidHex(s))).toSet @@ -309,6 +349,7 @@ object NodeParams extends Logging { color = Color(color(0), color(1), color(2)), publicAddresses = addresses, features = coreAndPluginFeatures, + channelTypes = channelTypes, pluginParams = pluginParams, overrideFeatures = overrideFeatures, syncWhitelist = syncWhitelist, diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/channel/Channel.scala b/eclair-core/src/main/scala/fr/acinq/eclair/channel/Channel.scala index 7756577489..66ba585cb8 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/channel/Channel.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/channel/Channel.scala @@ -201,6 +201,7 @@ class Channel(val nodeParams: NodeParams, val wallet: EclairWallet, remoteNodeId txPublisher ! SetChannelId(remoteNodeId, temporaryChannelId) val fundingPubKey = keyManager.fundingPublicKey(localParams.fundingKeyPath).publicKey val channelKeyPath = keyManager.keyPath(localParams, channelVersion) + val channelTypes = channelVersion.filterChannelTypes(nodeParams.channelTypesFor(remoteNodeId)) val open = OpenChannel(nodeParams.chainHash, temporaryChannelId = temporaryChannelId, fundingSatoshis = fundingSatoshis, @@ -221,7 +222,7 @@ class Channel(val nodeParams: NodeParams, val wallet: EclairWallet, remoteNodeId channelFlags = channelFlags, // In order to allow TLV extensions and keep backwards-compatibility, we include an empty upfront_shutdown_script. // See https://github.com/lightningnetwork/lightning-rfc/pull/714. - tlvStream = TlvStream(ChannelTlv.UpfrontShutdownScript(ByteVector.empty))) + tlvStream = TlvStream(ChannelTlv.UpfrontShutdownScript(ByteVector.empty), OpenChannelTlv.ChannelTypes(channelTypes.map(_.features).toList))) goto(WAIT_FOR_ACCEPT_CHANNEL) using DATA_WAIT_FOR_ACCEPT_CHANNEL(initFunder, open) sending open case Event(inputFundee@INPUT_INIT_FUNDEE(_, localParams, remote, _, _), Nothing) if !localParams.isFunder => @@ -362,7 +363,7 @@ class Channel(val nodeParams: NodeParams, val wallet: EclairWallet, remoteNodeId firstPerCommitmentPoint = keyManager.commitmentPoint(channelKeyPath, 0), // In order to allow TLV extensions and keep backwards-compatibility, we include an empty upfront_shutdown_script. // See https://github.com/lightningnetwork/lightning-rfc/pull/714. - tlvStream = TlvStream(ChannelTlv.UpfrontShutdownScript(ByteVector.empty))) + tlvStream = TlvStream(ChannelTlv.UpfrontShutdownScript(ByteVector.empty), AcceptChannelTlv.ChannelType(channelVersion.channelType.features))) val remoteParams = RemoteParams( nodeId = remoteNodeId, dustLimit = open.dustLimitSatoshis, @@ -389,11 +390,11 @@ class Channel(val nodeParams: NodeParams, val wallet: EclairWallet, remoteNodeId }) when(WAIT_FOR_ACCEPT_CHANNEL)(handleExceptions { - case Event(accept: AcceptChannel, d@DATA_WAIT_FOR_ACCEPT_CHANNEL(INPUT_INIT_FUNDER(temporaryChannelId, fundingSatoshis, pushMsat, initialFeeratePerKw, fundingTxFeeratePerKw, initialRelayFees_opt, localParams, _, remoteInit, _, channelVersion), open)) => + case Event(accept: AcceptChannel, d@DATA_WAIT_FOR_ACCEPT_CHANNEL(INPUT_INIT_FUNDER(temporaryChannelId, fundingSatoshis, pushMsat, initialFeeratePerKw, fundingTxFeeratePerKw, initialRelayFees_opt, localParams, _, remoteInit, _, initialChannelVersion), open)) => log.info(s"received AcceptChannel=$accept") - Helpers.validateParamsFunder(nodeParams, open, accept) match { + Helpers.validateParamsFunder(nodeParams, open, accept, initialChannelVersion) match { case Left(t) => handleLocalError(t, d, Some(accept)) - case _ => + case Right(finalChannelVersion) => val remoteParams = RemoteParams( nodeId = remoteNodeId, dustLimit = accept.dustLimitSatoshis, @@ -412,7 +413,7 @@ class Channel(val nodeParams: NodeParams, val wallet: EclairWallet, remoteNodeId val localFundingPubkey = keyManager.fundingPublicKey(localParams.fundingKeyPath) val fundingPubkeyScript = Script.write(Script.pay2wsh(Scripts.multiSig2of2(localFundingPubkey.publicKey, remoteParams.fundingPubKey))) wallet.makeFundingTx(fundingPubkeyScript, fundingSatoshis, fundingTxFeeratePerKw).pipeTo(self) - goto(WAIT_FOR_FUNDING_INTERNAL) using DATA_WAIT_FOR_FUNDING_INTERNAL(temporaryChannelId, localParams, remoteParams, fundingSatoshis, pushMsat, initialFeeratePerKw, initialRelayFees_opt, accept.firstPerCommitmentPoint, channelVersion, open) + goto(WAIT_FOR_FUNDING_INTERNAL) using DATA_WAIT_FOR_FUNDING_INTERNAL(temporaryChannelId, localParams, remoteParams, fundingSatoshis, pushMsat, initialFeeratePerKw, initialRelayFees_opt, accept.firstPerCommitmentPoint, finalChannelVersion, open) } case Event(c: CloseCommand, d: DATA_WAIT_FOR_ACCEPT_CHANNEL) => diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/channel/ChannelExceptions.scala b/eclair-core/src/main/scala/fr/acinq/eclair/channel/ChannelExceptions.scala index f04f0a2b2d..6e20991e8c 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/channel/ChannelExceptions.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/channel/ChannelExceptions.scala @@ -20,7 +20,7 @@ import fr.acinq.bitcoin.Crypto.PrivateKey import fr.acinq.bitcoin.{ByteVector32, Satoshi, Transaction} import fr.acinq.eclair.blockchain.fee.FeeratePerKw import fr.acinq.eclair.wire.protocol.{AnnouncementSignatures, Error, UpdateAddHtlc} -import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, MilliSatoshi, UInt64} +import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Features, MilliSatoshi, UInt64} /** * Created by PM on 11/04/2017. @@ -40,6 +40,7 @@ case class InvalidChainHash (override val channelId: Byte case class InvalidFundingAmount (override val channelId: ByteVector32, fundingAmount: Satoshi, min: Satoshi, max: Satoshi) extends ChannelException(channelId, s"invalid funding_satoshis=$fundingAmount (min=$min max=$max)") case class InvalidPushAmount (override val channelId: ByteVector32, pushAmount: MilliSatoshi, max: MilliSatoshi) extends ChannelException(channelId, s"invalid pushAmount=$pushAmount (max=$max)") case class InvalidMaxAcceptedHtlcs (override val channelId: ByteVector32, maxAcceptedHtlcs: Int, max: Int) extends ChannelException(channelId, s"invalid max_accepted_htlcs=$maxAcceptedHtlcs (max=$max)") +case class IncompatibleChannelTypes (override val channelId: ByteVector32, supportedChannelTypes: Seq[Features]) extends ChannelException(channelId, s"incompatible channel types (we support ${supportedChannelTypes.map(_.toByteVector.toHex).mkString(" or ")})") case class DustLimitTooSmall (override val channelId: ByteVector32, dustLimit: Satoshi, min: Satoshi) extends ChannelException(channelId, s"dustLimit=$dustLimit is too small (min=$min)") case class DustLimitTooLarge (override val channelId: ByteVector32, dustLimit: Satoshi, max: Satoshi) extends ChannelException(channelId, s"dustLimit=$dustLimit is too large (max=$max)") case class DustLimitAboveOurChannelReserve (override val channelId: ByteVector32, dustLimit: Satoshi, channelReserve: Satoshi) extends ChannelException(channelId, s"dustLimit=$dustLimit is above our channelReserve=$channelReserve") diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/channel/ChannelTypes.scala b/eclair-core/src/main/scala/fr/acinq/eclair/channel/ChannelTypes.scala index 254b8ff8f1..3871db6c03 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/channel/ChannelTypes.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/channel/ChannelTypes.scala @@ -25,7 +25,7 @@ import fr.acinq.eclair.router.Announcements import fr.acinq.eclair.transactions.CommitmentSpec import fr.acinq.eclair.transactions.Transactions._ import fr.acinq.eclair.wire.protocol.{AcceptChannel, ChannelAnnouncement, ChannelReestablish, ChannelUpdate, ClosingSigned, FailureMessage, FundingCreated, FundingLocked, FundingSigned, Init, OnionRoutingPacket, OpenChannel, Shutdown, UpdateAddHtlc, UpdateFailHtlc, UpdateFailMalformedHtlc, UpdateFulfillHtlc} -import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Features, MilliSatoshi, ShortChannelId, UInt64} +import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, FeatureSupport, Features, MilliSatoshi, ShortChannelId, UInt64} import scodec.bits.{BitVector, ByteVector} import java.util.UUID @@ -483,6 +483,12 @@ object ChannelFlags { val Empty = 0x00.toByte } +/** Each channel type is a specific combination of features listed in the RFC (Bolt 2). */ +case class ChannelType(features: Features) { + /** True if our main output in the remote commitment is directly sent (without any delay) to one of our wallet addresses. */ + def paysDirectlyToWallet: Boolean = features.hasFeature(Features.StaticRemoteKey) && !features.hasFeature(Features.AnchorOutputs) +} + case class ChannelVersion(bits: BitVector) { import ChannelVersion._ @@ -494,6 +500,23 @@ case class ChannelVersion(bits: BitVector) { DefaultCommitmentFormat } + val channelType: ChannelType = { + if (hasAnchorOutputs) { + ChannelType(Features(Features.StaticRemoteKey -> FeatureSupport.Optional, Features.AnchorOutputs -> FeatureSupport.Optional)) + } else if (hasStaticRemotekey) { + ChannelType(Features(Features.StaticRemoteKey -> FeatureSupport.Optional)) + } else { + ChannelType(Features.empty) + } + } + + /** Filter channel types to keep only those compatible with the current channel version. */ + def filterChannelTypes(channelTypes: Seq[ChannelType]): Seq[ChannelType] = { + // We ensure we don't mix channel types that pay to our wallet with channel types that don't, since they use + // different methods to obtain the payment basepoint. + channelTypes.filter(_.paysDirectlyToWallet == paysDirectlyToWallet) + } + def |(other: ChannelVersion) = ChannelVersion(bits | other.bits) def &(other: ChannelVersion) = ChannelVersion(bits & other.bits) def ^(other: ChannelVersion) = ChannelVersion(bits ^ other.bits) @@ -503,8 +526,7 @@ case class ChannelVersion(bits: BitVector) { def hasPubkeyKeyPath: Boolean = isSet(USE_PUBKEY_KEYPATH_BIT) def hasStaticRemotekey: Boolean = isSet(USE_STATIC_REMOTEKEY_BIT) def hasAnchorOutputs: Boolean = isSet(USE_ANCHOR_OUTPUTS_BIT) - /** True if our main output in the remote commitment is directly sent (without any delay) to one of our wallet addresses. */ - def paysDirectlyToWallet: Boolean = hasStaticRemotekey && !hasAnchorOutputs + def paysDirectlyToWallet: Boolean = channelType.paysDirectlyToWallet } object ChannelVersion { @@ -518,6 +540,10 @@ object ChannelVersion { def fromBit(bit: Int): ChannelVersion = ChannelVersion(BitVector.low(LENGTH_BITS).set(bit).reverse) + /** + * Pick the channel version that should be applied based on features alone (in case our peer doesn't support explicit + * channel type negotiation). + */ def pickChannelVersion(localFeatures: Features, remoteFeatures: Features): ChannelVersion = { if (Features.canUseFeature(localFeatures, remoteFeatures, Features.AnchorOutputs)) { ANCHOR_OUTPUTS @@ -528,6 +554,17 @@ object ChannelVersion { } } + /** Pick a channel version that matches the negotiated channel type. */ + def pickChannelVersion(channelType: ChannelType): ChannelVersion = { + if (channelType.features.hasFeature(Features.AnchorOutputs)) { + ANCHOR_OUTPUTS + } else if (channelType.features.hasFeature(Features.StaticRemoteKey)) { + STATIC_REMOTEKEY + } else { + STANDARD + } + } + val ZEROES = ChannelVersion(bin"00000000000000000000000000000000") val STANDARD = ZEROES | fromBit(USE_PUBKEY_KEYPATH_BIT) val STATIC_REMOTEKEY = STANDARD | fromBit(USE_STATIC_REMOTEKEY_BIT) // PUBKEY_KEYPATH + STATIC_REMOTEKEY diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/channel/Helpers.scala b/eclair-core/src/main/scala/fr/acinq/eclair/channel/Helpers.scala index 6cc1d64e45..1fd3e29cda 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/channel/Helpers.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/channel/Helpers.scala @@ -135,7 +135,7 @@ object Helpers { /** * Called by the funder */ - def validateParamsFunder(nodeParams: NodeParams, open: OpenChannel, accept: AcceptChannel): Either[ChannelException, Unit] = { + def validateParamsFunder(nodeParams: NodeParams, open: OpenChannel, accept: AcceptChannel, proposedChannelVersion: ChannelVersion): Either[ChannelException, ChannelVersion] = { if (accept.maxAcceptedHtlcs > Channel.MAX_ACCEPTED_HTLCS) return Left(InvalidMaxAcceptedHtlcs(accept.temporaryChannelId, accept.maxAcceptedHtlcs, Channel.MAX_ACCEPTED_HTLCS)) // only enforce dust limit check on mainnet if (nodeParams.chainHash == Block.LivenetGenesisBlock.hash) { @@ -162,7 +162,14 @@ object Helpers { val reserveToFundingRatio = accept.channelReserveSatoshis.toLong.toDouble / Math.max(open.fundingSatoshis.toLong, 1) if (reserveToFundingRatio > nodeParams.maxReserveToFundingRatio) return Left(ChannelReserveTooHigh(open.temporaryChannelId, accept.channelReserveSatoshis, reserveToFundingRatio, nodeParams.maxReserveToFundingRatio)) - Right() + accept.channelType_opt match { + case Some(channelType) if !open.channelTypes.contains(channelType) => + Left(IncompatibleChannelTypes(accept.temporaryChannelId, open.channelTypes)) + case Some(channelType) => + Right(ChannelVersion.pickChannelVersion(ChannelType(channelType))) + case None => + Right(proposedChannelVersion) + } } /** diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala b/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala index 509980a442..d1058e3d0b 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala @@ -120,6 +120,7 @@ class Peer(val nodeParams: NodeParams, remoteNodeId: PublicKey, wallet: EclairWa stay case Event(c: Peer.OpenChannel, d: ConnectedData) => + val channelVersion = ChannelVersion.pickChannelVersion(d.localFeatures, d.remoteFeatures) if (c.fundingSatoshis >= Channel.MAX_FUNDING && !d.localFeatures.hasFeature(Wumbo)) { sender ! Status.Failure(new RuntimeException(s"fundingSatoshis=${c.fundingSatoshis} is too big, you must enable large channels support in 'eclair.features' to use funding above ${Channel.MAX_FUNDING} (see eclair.conf)")) stay @@ -129,8 +130,10 @@ class Peer(val nodeParams: NodeParams, remoteNodeId: PublicKey, wallet: EclairWa } else if (c.fundingSatoshis > nodeParams.maxFundingSatoshis) { sender ! Status.Failure(new RuntimeException(s"fundingSatoshis=${c.fundingSatoshis} is too big for the current settings, increase 'eclair.max-funding-satoshis' (see eclair.conf)")) stay + } else if (channelVersion.filterChannelTypes(nodeParams.channelTypesFor(remoteNodeId)).isEmpty) { + sender ! Status.Failure(new RuntimeException(s"cannot find a suitable channel type with $remoteNodeId, make sure that 'channel-types' and 'features' are properly configured (see eclair.conf)")) + stay } else { - val channelVersion = ChannelVersion.pickChannelVersion(d.localFeatures, d.remoteFeatures) val (channel, localParams) = createNewChannel(nodeParams, d.localFeatures, funder = true, c.fundingSatoshis, origin_opt = Some(sender), channelVersion) c.timeout_opt.map(openTimeout => context.system.scheduler.scheduleOnce(openTimeout.duration, channel, Channel.TickChannelOpenTimeout)(context.dispatcher)) val temporaryChannelId = randomBytes32() @@ -144,13 +147,26 @@ class Peer(val nodeParams: NodeParams, remoteNodeId: PublicKey, wallet: EclairWa case Event(msg: protocol.OpenChannel, d: ConnectedData) => d.channels.get(TemporaryChannelId(msg.temporaryChannelId)) match { case None => - val channelVersion = ChannelVersion.pickChannelVersion(d.localFeatures, d.remoteFeatures) - val (channel, localParams) = createNewChannel(nodeParams, d.localFeatures, funder = false, fundingAmount = msg.fundingSatoshis, origin_opt = None, channelVersion) - val temporaryChannelId = msg.temporaryChannelId - log.info(s"accepting a new channel with temporaryChannelId=$temporaryChannelId localParams=$localParams") - channel ! INPUT_INIT_FUNDEE(temporaryChannelId, localParams, d.peerConnection, d.remoteInit, channelVersion) - channel ! msg - stay using d.copy(channels = d.channels + (TemporaryChannelId(temporaryChannelId) -> channel)) + val channelVersion_opt = msg.channelTypes match { + case proposedChannelTypes if proposedChannelTypes.nonEmpty => + // We select our preferred channel version based on their proposed channel types. + nodeParams.channelTypesFor(remoteNodeId).find(ct => msg.channelTypes.contains(ct.features)).map(chosenChannelType => ChannelVersion.pickChannelVersion(chosenChannelType)) + case _ => + Some(ChannelVersion.pickChannelVersion(d.localFeatures, d.remoteFeatures)) + } + channelVersion_opt match { + case Some(channelVersion) => + val (channel, localParams) = createNewChannel(nodeParams, d.localFeatures, funder = false, fundingAmount = msg.fundingSatoshis, origin_opt = None, channelVersion) + val temporaryChannelId = msg.temporaryChannelId + log.info(s"accepting a new channel with temporaryChannelId=$temporaryChannelId localParams=$localParams") + channel ! INPUT_INIT_FUNDEE(temporaryChannelId, localParams, d.peerConnection, d.remoteInit, channelVersion) + channel ! msg + stay using d.copy(channels = d.channels + (TemporaryChannelId(temporaryChannelId) -> channel)) + case None => + log.warning(s"rejecting channel: the proposed channel types are not supported: ${msg.channelTypes.mkString(", ")}") + d.peerConnection ! Error(msg.temporaryChannelId, IncompatibleChannelTypes(msg.temporaryChannelId, msg.channelTypes).getMessage) + stay + } case Some(_) => log.warning(s"ignoring open_channel with duplicate temporaryChannelId=${msg.temporaryChannelId}") stay diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/ChannelTlv.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/ChannelTlv.scala index 9f98a00d28..d06ad1741c 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/ChannelTlv.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/ChannelTlv.scala @@ -16,9 +16,10 @@ package fr.acinq.eclair.wire.protocol -import fr.acinq.eclair.UInt64 -import fr.acinq.eclair.wire.protocol.TlvCodecs.tlvStream import fr.acinq.eclair.wire.protocol.CommonCodecs._ +import fr.acinq.eclair.wire.protocol.LightningMessageCodecs.featuresCodec +import fr.acinq.eclair.wire.protocol.TlvCodecs.tlvStream +import fr.acinq.eclair.{Features, UInt64} import scodec.Codec import scodec.bits.ByteVector import scodec.codecs._ @@ -40,8 +41,11 @@ object OpenChannelTlv { import ChannelTlv._ + case class ChannelTypes(proposed: List[Features]) extends OpenChannelTlv + val openTlvCodec: Codec[TlvStream[OpenChannelTlv]] = tlvStream(discriminated[OpenChannelTlv].by(varint) .typecase(UInt64(0), variableSizeBytesLong(varintoverflow, bytes).as[UpfrontShutdownScript]) + .typecase(UInt64(1), variableSizeBytesLong(varintoverflow, list(featuresCodec)).as[ChannelTypes]) ) } @@ -50,7 +54,11 @@ object AcceptChannelTlv { import ChannelTlv._ + case class ChannelType(features: Features) extends AcceptChannelTlv + val acceptTlvCodec: Codec[TlvStream[AcceptChannelTlv]] = tlvStream(discriminated[AcceptChannelTlv].by(varint) .typecase(UInt64(0), variableSizeBytesLong(varintoverflow, bytes).as[UpfrontShutdownScript]) + .typecase(UInt64(1), variableSizeBytesLong(varintoverflow, featuresCodec).as[ChannelType]) ) + } \ No newline at end of file diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/LightningMessageTypes.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/LightningMessageTypes.scala index 7fda142976..8041c44a62 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/LightningMessageTypes.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/LightningMessageTypes.scala @@ -87,7 +87,9 @@ case class OpenChannel(chainHash: ByteVector32, htlcBasepoint: PublicKey, firstPerCommitmentPoint: PublicKey, channelFlags: Byte, - tlvStream: TlvStream[OpenChannelTlv] = TlvStream.empty) extends ChannelMessage with HasTemporaryChannelId with HasChainHash + tlvStream: TlvStream[OpenChannelTlv] = TlvStream.empty) extends ChannelMessage with HasTemporaryChannelId with HasChainHash { + val channelTypes: List[Features] = tlvStream.get[OpenChannelTlv.ChannelTypes].map(_.proposed).getOrElse(Nil) +} case class AcceptChannel(temporaryChannelId: ByteVector32, dustLimitSatoshis: Satoshi, @@ -103,7 +105,9 @@ case class AcceptChannel(temporaryChannelId: ByteVector32, delayedPaymentBasepoint: PublicKey, htlcBasepoint: PublicKey, firstPerCommitmentPoint: PublicKey, - tlvStream: TlvStream[AcceptChannelTlv] = TlvStream.empty) extends ChannelMessage with HasTemporaryChannelId + tlvStream: TlvStream[AcceptChannelTlv] = TlvStream.empty) extends ChannelMessage with HasTemporaryChannelId { + val channelType_opt: Option[Features] = tlvStream.get[AcceptChannelTlv.ChannelType].map(_.features) +} case class FundingCreated(temporaryChannelId: ByteVector32, fundingTxid: ByteVector32, diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/StartupSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/StartupSpec.scala index 3ed044e2fe..f43de97053 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/StartupSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/StartupSpec.scala @@ -22,6 +22,7 @@ import fr.acinq.bitcoin.{Block, SatoshiLong} import fr.acinq.eclair.FeatureSupport.{Mandatory, Optional} import fr.acinq.eclair.Features._ import fr.acinq.eclair.blockchain.fee.{FeeratePerByte, FeeratePerKw, FeerateTolerance} +import fr.acinq.eclair.channel.ChannelType import fr.acinq.eclair.crypto.keymanager.{LocalChannelKeyManager, LocalNodeKeyManager} import org.scalatest.funsuite.AnyFunSuite import scodec.bits.{ByteVector, HexStringSyntax} @@ -86,9 +87,19 @@ class StartupSpec extends AnyFunSuite { } test("NodeParams should fail if features are inconsistent") { + // We don't want to test inconsistencies between channel types and features in this test, so we only keep the + // standard channel type that requires no feature. + val defaultChannelTypes = ConfigFactory.parseString( + """ + | channel-types { + | commitment-format = ["standard"] + | } + """.stripMargin + ) + // Because of https://github.com/ACINQ/eclair/issues/1434, we need to remove the default features when falling back // to the default configuration. - def finalizeConf(testCfg: Config): Config = testCfg.withFallback(defaultConf.withoutPath("features")) + def finalizeConf(testCfg: Config): Config = testCfg.withFallback(defaultChannelTypes).withFallback(defaultConf.withoutPath("features")) val legalFeaturesConf = ConfigFactory.parseMap(Map( s"features.${OptionDataLossProtect.rfcName}" -> "optional", @@ -135,6 +146,7 @@ class StartupSpec extends AnyFunSuite { s"features.${ChannelRangeQueriesExtended.rfcName}" -> "optional" ).asJava) + makeNodeParamsWithDefaults(finalizeConf(legalFeaturesConf)) assert(Try(makeNodeParamsWithDefaults(finalizeConf(legalFeaturesConf))).isSuccess) assert(Try(makeNodeParamsWithDefaults(finalizeConf(noVariableLengthOnionConf))).isFailure) assert(Try(makeNodeParamsWithDefaults(finalizeConf(optionalVarOnionOptinConf))).isFailure) @@ -143,25 +155,178 @@ class StartupSpec extends AnyFunSuite { assert(Try(makeNodeParamsWithDefaults(finalizeConf(illegalFeaturesConf))).isFailure) } - test("parse human readable override features") { + test("NodeParams should fail if channel types are inconsistent with features") { + val validConf1 = ConfigFactory.parseString( + """ + | features { + | var_onion_optin = mandatory + | payment_secret = mandatory + | basic_mpp = mandatory + | option_static_remotekey = optional + | } + | channel-types { + | commitment-format = ["standard", "static_remotekey"] + | } + """.stripMargin + ) + + val validConf2 = ConfigFactory.parseString( + """ + | features { + | var_onion_optin = mandatory + | payment_secret = mandatory + | basic_mpp = mandatory + | option_static_remotekey = optional + | option_anchor_outputs = optional + | } + | channel-types { + | commitment-format = ["standard", "static_remotekey", "anchor_outputs"] + | } + """.stripMargin + ) + + val unknownChannelType = ConfigFactory.parseString( + """ + | features { + | var_onion_optin = mandatory + | payment_secret = mandatory + | basic_mpp = mandatory + | option_static_remotekey = optional + | } + | channel-types { + | commitment-format = ["standard", "much_commitment_very_format"] + | } + """.stripMargin + ) + + val missingFeature = ConfigFactory.parseString( + """ + | features { + | var_onion_optin = mandatory + | payment_secret = mandatory + | basic_mpp = mandatory + | option_static_remotekey = optional + | } + | channel-types { + | commitment-format = ["standard", "anchor_outputs"] + | } + """.stripMargin + ) + + assert(Try(makeNodeParamsWithDefaults(validConf1.withFallback(defaultConf))).isSuccess) + assert(Try(makeNodeParamsWithDefaults(validConf2.withFallback(defaultConf))).isSuccess) + assert(Try(makeNodeParamsWithDefaults(unknownChannelType.withFallback(defaultConf))).isFailure) + assert(Try(makeNodeParamsWithDefaults(missingFeature.withFallback(defaultConf))).isFailure) + } + + test("override features and channel types") { + val defaultFeatures = ConfigFactory.parseString( + """ + | features { + | var_onion_optin = mandatory + | payment_secret = mandatory + | option_static_remotekey = optional + | } + """.stripMargin + ).withFallback(defaultConf.withoutPath("features")) + val defaultChannelTypes = ConfigFactory.parseString( + """ + | channel-types { + | commitment-format = ["standard", "static_remotekey"] + | } + """.stripMargin + ) val perNodeConf = ConfigFactory.parseString( """ | override-features = [ // optional per-node features | { | nodeid = "02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", - | features { - | var_onion_optin = mandatory - | payment_secret = mandatory - | basic_mpp = mandatory - | } + | features { + | var_onion_optin = mandatory + | payment_secret = mandatory + | basic_mpp = mandatory + | } + | channel-types { + | commitment-format = ["standard"] + | } + | }, + | { + | nodeid = "02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb", + | features { + | var_onion_optin = mandatory + | payment_secret = mandatory + | option_static_remotekey = optional + | option_anchor_outputs = optional + | } + | channel-types { + | commitment-format = ["static_remotekey", "anchor_outputs"] + | } + | }, + | { + | nodeid = "02cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc", + | features { + | var_onion_optin = mandatory + | payment_secret = mandatory + | option_static_remotekey = optional + | option_anchor_outputs = optional + | } + | }, + | { + | nodeid = "02dddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddd", + | channel-types { + | commitment-format = ["static_remotekey"] + | } + | } + | ] + """.stripMargin + ) + val invalidPerNodeConf = ConfigFactory.parseString( + """ + | override-features = [ // optional per-node features + | { + | nodeid = "02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", + | features { + | var_onion_optin = mandatory + | payment_secret = mandatory + | basic_mpp = mandatory + | } + | channel-types = ["basic", "anchor_outputs"] | } | ] """.stripMargin ) - val nodeParams = makeNodeParamsWithDefaults(perNodeConf.withFallback(defaultConf)) - val perNodeFeatures = nodeParams.featuresFor(PublicKey(ByteVector.fromValidHex("02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"))) - assert(perNodeFeatures === Features(VariableLengthOnion -> Mandatory, PaymentSecret -> Mandatory, BasicMultiPartPayment -> Mandatory)) + assert(Try(makeNodeParamsWithDefaults(invalidPerNodeConf.withFallback(defaultChannelTypes).withFallback(defaultFeatures))).isFailure) + val nodeParams = makeNodeParamsWithDefaults(perNodeConf.withFallback(defaultChannelTypes).withFallback(defaultFeatures)) + + { + val alice = PublicKey(ByteVector.fromValidHex("02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa")) + val perNodeFeatures = nodeParams.featuresFor(alice) + val perNodeChannelTypes = nodeParams.channelTypesFor(alice) + assert(perNodeFeatures === Features(VariableLengthOnion -> Mandatory, PaymentSecret -> Mandatory, BasicMultiPartPayment -> Mandatory)) + assert(perNodeChannelTypes === List(ChannelType(Features.empty))) + } + { + val bob = PublicKey(ByteVector.fromValidHex("02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb")) + val perNodeFeatures = nodeParams.featuresFor(bob) + val perNodeChannelTypes = nodeParams.channelTypesFor(bob) + assert(perNodeFeatures === Features(VariableLengthOnion -> Mandatory, PaymentSecret -> Mandatory, StaticRemoteKey -> Optional, AnchorOutputs -> Optional)) + assert(perNodeChannelTypes === List(ChannelType(Features(StaticRemoteKey -> Optional)), ChannelType(Features(StaticRemoteKey -> Optional, AnchorOutputs -> Optional)))) + } + { + val carol = PublicKey(ByteVector.fromValidHex("02cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc")) + val perNodeFeatures = nodeParams.featuresFor(carol) + val perNodeChannelTypes = nodeParams.channelTypesFor(carol) + assert(perNodeFeatures === Features(VariableLengthOnion -> Mandatory, PaymentSecret -> Mandatory, StaticRemoteKey -> Optional, AnchorOutputs -> Optional)) + assert(perNodeChannelTypes === List(ChannelType(Features.empty), ChannelType(Features(StaticRemoteKey -> Optional)))) + } + { + val dave = PublicKey(ByteVector.fromValidHex("02dddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddd")) + val perNodeFeatures = nodeParams.featuresFor(dave) + val perNodeChannelTypes = nodeParams.channelTypesFor(dave) + assert(perNodeFeatures === Features(VariableLengthOnion -> Mandatory, PaymentSecret -> Mandatory, StaticRemoteKey -> Optional)) + assert(perNodeChannelTypes === List(ChannelType(Features(StaticRemoteKey -> Optional)))) + } } test("override feerate mismatch tolerance") { diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala b/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala index 6e5497fe1d..9e16f103e6 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala @@ -20,7 +20,7 @@ import fr.acinq.bitcoin.{Block, ByteVector32, Satoshi, SatoshiLong, Script} import fr.acinq.eclair.FeatureSupport.{Mandatory, Optional} import fr.acinq.eclair.Features._ import fr.acinq.eclair.blockchain.fee.{FeeEstimator, FeeTargets, FeeratesPerKw, OnChainFeeConf, _} -import fr.acinq.eclair.channel.LocalParams +import fr.acinq.eclair.channel.{ChannelType, LocalParams} import fr.acinq.eclair.crypto.keymanager.{LocalChannelKeyManager, LocalNodeKeyManager} import fr.acinq.eclair.io.{Peer, PeerConnection} import fr.acinq.eclair.router.Router.RouterConf @@ -93,6 +93,7 @@ object TestConstants { ), Set(UnknownFeature(TestFeature.optional)) ), + channelTypes = List(ChannelType(Features.empty)), pluginParams = List(pluginParams), overrideFeatures = Map.empty, syncWhitelist = Set.empty, @@ -199,6 +200,7 @@ object TestConstants { PaymentSecret -> Mandatory, BasicMultiPartPayment -> Optional ), + channelTypes = List(ChannelType(Features.empty)), pluginParams = Nil, overrideFeatures = Map.empty, syncWhitelist = Set.empty, diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/channel/ChannelTypesSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/channel/ChannelTypesSpec.scala index b8a34492be..2c12f5dcc0 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/channel/ChannelTypesSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/channel/ChannelTypesSpec.scala @@ -2,13 +2,15 @@ package fr.acinq.eclair.channel import akka.testkit.{TestFSMRef, TestProbe} import fr.acinq.bitcoin.{ByteVector32, OutPoint, SatoshiLong, Transaction, TxIn, TxOut} +import fr.acinq.eclair.FeatureSupport._ +import fr.acinq.eclair.Features._ import fr.acinq.eclair.blockchain.bitcoind.ZmqWatcher.WatchFundingSpentTriggered import fr.acinq.eclair.channel.Helpers.Closing import fr.acinq.eclair.channel.states.StateTestsHelperMethods import fr.acinq.eclair.transactions.Transactions import fr.acinq.eclair.transactions.Transactions._ import fr.acinq.eclair.wire.protocol.{CommitSig, RevokeAndAck, UpdateAddHtlc} -import fr.acinq.eclair.{MilliSatoshiLong, TestKitBaseClass} +import fr.acinq.eclair.{Features, MilliSatoshiLong, TestKitBaseClass} import org.scalatest.funsuite.AnyFunSuiteLike import scodec.bits.ByteVector @@ -36,10 +38,6 @@ class ChannelTypesSpec extends TestKitBaseClass with AnyFunSuiteLike with StateT } test("pick channel version based on local and remote features") { - import fr.acinq.eclair.FeatureSupport._ - import fr.acinq.eclair.Features - import fr.acinq.eclair.Features._ - case class TestCase(localFeatures: Features, remoteFeatures: Features, expectedChannelVersion: ChannelVersion) val testCases = Seq( TestCase(Features.empty, Features.empty, ChannelVersion.STANDARD), @@ -56,6 +54,39 @@ class ChannelTypesSpec extends TestKitBaseClass with AnyFunSuiteLike with StateT } } + test("pick channel version based on channel type") { + val testCases = Seq( + ChannelType(Features.empty) -> ChannelVersion.STANDARD, + ChannelType(Features(StaticRemoteKey -> Optional)) -> ChannelVersion.STATIC_REMOTEKEY, + ChannelType(Features(StaticRemoteKey -> Optional, AnchorOutputs -> Optional)) -> ChannelVersion.ANCHOR_OUTPUTS, + // These channel types are invalid and should be filtered out, but we're able to handle them just in case + ChannelType(Features(StaticRemoteKey -> Mandatory)) -> ChannelVersion.STATIC_REMOTEKEY, + ChannelType(Features(StaticRemoteKey -> Mandatory, AnchorOutputs -> Optional)) -> ChannelVersion.ANCHOR_OUTPUTS, + ) + + for ((channelType, expectedChannelVersion) <- testCases) { + assert(ChannelVersion.pickChannelVersion(channelType) === expectedChannelVersion) + } + } + + test("filter compatible channel types") { + val standard = ChannelType(Features.empty) + val staticRemoteKey = ChannelType(Features(StaticRemoteKey -> Optional)) + val anchorOutputs = ChannelType(Features(StaticRemoteKey -> Optional, AnchorOutputs -> Optional)) + assert(ChannelVersion.STANDARD.channelType === standard) + assert(ChannelVersion.STANDARD.filterChannelTypes(Seq(standard)) === Seq(standard)) + assert(ChannelVersion.STANDARD.filterChannelTypes(Seq(staticRemoteKey)) === Nil) + assert(ChannelVersion.STANDARD.filterChannelTypes(Seq(standard, staticRemoteKey, anchorOutputs)) === Seq(standard, anchorOutputs)) + assert(ChannelVersion.STATIC_REMOTEKEY.channelType === staticRemoteKey) + assert(ChannelVersion.STATIC_REMOTEKEY.filterChannelTypes(Seq(standard)) === Nil) + assert(ChannelVersion.STATIC_REMOTEKEY.filterChannelTypes(Seq(staticRemoteKey)) === Seq(staticRemoteKey)) + assert(ChannelVersion.STATIC_REMOTEKEY.filterChannelTypes(Seq(standard, staticRemoteKey, anchorOutputs)) === Seq(staticRemoteKey)) + assert(ChannelVersion.ANCHOR_OUTPUTS.channelType === anchorOutputs) + assert(ChannelVersion.ANCHOR_OUTPUTS.filterChannelTypes(Seq(standard)) === Seq(standard)) + assert(ChannelVersion.ANCHOR_OUTPUTS.filterChannelTypes(Seq(anchorOutputs)) === Seq(anchorOutputs)) + assert(ChannelVersion.ANCHOR_OUTPUTS.filterChannelTypes(Seq(standard, staticRemoteKey, anchorOutputs)) === Seq(standard, anchorOutputs)) + } + case class HtlcWithPreimage(preimage: ByteVector32, htlc: UpdateAddHtlc) case class Fixture(alice: TestFSMRef[State, Data, Channel], alicePendingHtlc: HtlcWithPreimage, bob: TestFSMRef[State, Data, Channel], bobPendingHtlc: HtlcWithPreimage, probe: TestProbe) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/a/WaitForAcceptChannelStateSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/a/WaitForAcceptChannelStateSpec.scala index 03adaece48..b416fac9ae 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/a/WaitForAcceptChannelStateSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/a/WaitForAcceptChannelStateSpec.scala @@ -24,8 +24,8 @@ import fr.acinq.eclair.blockchain.{MakeFundingTxResponse, TestWallet} import fr.acinq.eclair.channel.Channel.TickChannelOpenTimeout import fr.acinq.eclair.channel._ import fr.acinq.eclair.channel.states.{StateTestsBase, StateTestsTags} -import fr.acinq.eclair.wire.protocol.{AcceptChannel, ChannelTlv, Error, Init, OpenChannel, TlvStream} -import fr.acinq.eclair.{CltvExpiryDelta, TestConstants, TestKitBaseClass} +import fr.acinq.eclair.wire.protocol.{AcceptChannel, AcceptChannelTlv, ChannelTlv, Error, Init, OpenChannel, TlvStream} +import fr.acinq.eclair.{CltvExpiryDelta, FeatureSupport, Features, TestConstants, TestKitBaseClass} import org.scalatest.funsuite.FixtureAnyFunSuiteLike import org.scalatest.{Outcome, Tag} import scodec.bits.ByteVector @@ -79,7 +79,8 @@ class WaitForAcceptChannelStateSpec extends TestKitBaseClass with FixtureAnyFunS import f._ val accept = bob2alice.expectMsgType[AcceptChannel] // Since https://github.com/lightningnetwork/lightning-rfc/pull/714 we must include an empty upfront_shutdown_script. - assert(accept.tlvStream === TlvStream(ChannelTlv.UpfrontShutdownScript(ByteVector.empty))) + assert(accept.tlvStream.get[ChannelTlv.UpfrontShutdownScript] === Some(ChannelTlv.UpfrontShutdownScript(ByteVector.empty))) + assert(accept.channelType_opt === Some(Features.empty)) bob2alice.forward(alice) awaitCond(alice.stateName == WAIT_FOR_FUNDING_INTERNAL) } @@ -177,6 +178,16 @@ class WaitForAcceptChannelStateSpec extends TestKitBaseClass with FixtureAnyFunS awaitCond(alice.stateName == WAIT_FOR_FUNDING_INTERNAL) } + test("recv AcceptChannel (incompatible channel type)") { f => + import f._ + val accept = bob2alice.expectMsgType[AcceptChannel].copy( + tlvStream = TlvStream(AcceptChannelTlv.ChannelType(Features(Features.StaticRemoteKey -> FeatureSupport.Optional))) + ) + bob2alice.forward(alice, accept) + alice2bob.expectMsgType[Error] + awaitCond(alice.stateName == CLOSED) + } + test("recv Error") { f => import f._ alice ! Error(ByteVector32.Zeroes, "oops") diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/a/WaitForOpenChannelStateSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/a/WaitForOpenChannelStateSpec.scala index 6f865d27b2..020954a7ff 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/a/WaitForOpenChannelStateSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/a/WaitForOpenChannelStateSpec.scala @@ -22,8 +22,8 @@ import fr.acinq.eclair.TestConstants.{Alice, Bob} import fr.acinq.eclair.blockchain.fee.FeeratePerKw import fr.acinq.eclair.channel._ import fr.acinq.eclair.channel.states.{StateTestsBase, StateTestsTags} -import fr.acinq.eclair.wire.protocol.{AcceptChannel, ChannelTlv, Error, Init, OpenChannel, TlvStream} -import fr.acinq.eclair.{CltvExpiryDelta, MilliSatoshiLong, TestConstants, TestKitBaseClass, ToMilliSatoshiConversion} +import fr.acinq.eclair.wire.protocol.{AcceptChannel, ChannelTlv, Error, Init, OpenChannel} +import fr.acinq.eclair.{CltvExpiryDelta, Features, MilliSatoshiLong, TestConstants, TestKitBaseClass, ToMilliSatoshiConversion} import org.scalatest.funsuite.FixtureAnyFunSuiteLike import org.scalatest.{Outcome, Tag} import scodec.bits.ByteVector @@ -63,7 +63,8 @@ class WaitForOpenChannelStateSpec extends TestKitBaseClass with FixtureAnyFunSui import f._ val open = alice2bob.expectMsgType[OpenChannel] // Since https://github.com/lightningnetwork/lightning-rfc/pull/714 we must include an empty upfront_shutdown_script. - assert(open.tlvStream === TlvStream(ChannelTlv.UpfrontShutdownScript(ByteVector.empty))) + assert(open.tlvStream.get[ChannelTlv.UpfrontShutdownScript] === Some(ChannelTlv.UpfrontShutdownScript(ByteVector.empty))) + assert(open.channelTypes === List(Features.empty)) alice2bob.forward(bob) awaitCond(bob.stateName == WAIT_FOR_FUNDING_CREATED) } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpec.scala index 1040f115a0..4cb2689af6 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpec.scala @@ -65,6 +65,7 @@ class PeerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with Paralle import com.softwaremill.quicklens._ val aliceParams = TestConstants.Alice.nodeParams .modify(_.features).setToIf(test.tags.contains("static_remotekey"))(Features(StaticRemoteKey -> Optional)) + .modify(_.channelTypes).setToIf(test.tags.contains("static_remotekey"))(List(ChannelType(Features.empty), ChannelType(Features(StaticRemoteKey -> Optional)))) .modify(_.features).setToIf(test.tags.contains("wumbo"))(Features(Wumbo -> Optional)) .modify(_.features).setToIf(test.tags.contains("anchor_outputs"))(Features(StaticRemoteKey -> Optional, AnchorOutputs -> Optional)) .modify(_.maxFundingSatoshis).setToIf(test.tags.contains("high-max-funding-satoshis"))(Btc(0.9)) @@ -303,6 +304,76 @@ class PeerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with Paralle assert(probe.expectMsgType[Failure].cause.getMessage == s"fundingSatoshis=$fundingAmountBig is too big for the current settings, increase 'eclair.max-funding-satoshis' (see eclair.conf)") } + test("don't spawn a channel if channel types and features are incompatible") { f => + import f._ + + // Alice only wants to use standard channels, but both peers have turned on the static_remotekey feature. + // Alice can't know beforehand if Bob supports explicit channel type negotiation. + // If Bob does support it, they would open a standard channel. + // If Bob doesn't support it, they would open a static_remotekey channel (based on node feature bits). + // These two outcomes are incompatible, so we report a configuration issue. + val nodeParams = TestConstants.Alice.nodeParams.copy( + channelTypes = ChannelType(Features.empty) :: Nil, + features = Features(StaticRemoteKey -> Optional) + ) + val peer = TestFSMRef(new Peer(nodeParams, remoteNodeId, new TestWallet(), FakeChannelFactory(channel))) + connect(remoteNodeId, peer, peerConnection, remoteInit = protocol.Init(Features(StaticRemoteKey -> Optional))) + assert(peer.stateData.channels.isEmpty) + + val probe = TestProbe() + probe.send(peer, Peer.OpenChannel(remoteNodeId, 25000 sat, 0 msat, None, None, None, None)) + assert(probe.expectMsgType[Failure].cause.getMessage.contains("cannot find a suitable channel type")) + } + + test("don't spawn a channel if channel types and features are incompatible (with peer override)") { f => + import f._ + + // Alice only wants to use standard channels with Bob, but both peers have turned on the static_remotekey feature. + // Alice can't know beforehand if Bob supports explicit channel type negotiation. + // If Bob does support it, they would open a standard channel. + // If Bob doesn't support it, they would open a static_remotekey channel (based on node feature bits). + // These two outcomes are incompatible, so we report a configuration issue. + val nodeParams = TestConstants.Alice.nodeParams.copy( + channelTypes = ChannelType(Features.empty) :: ChannelType(Features(StaticRemoteKey -> Optional)) :: Nil, + features = Features(StaticRemoteKey -> Optional), + overrideFeatures = Map(remoteNodeId -> (Features(StaticRemoteKey -> Optional), ChannelType(Features.empty) :: Nil)) + ) + val peer = TestFSMRef(new Peer(nodeParams, remoteNodeId, new TestWallet(), FakeChannelFactory(channel))) + connect(remoteNodeId, peer, peerConnection, remoteInit = protocol.Init(Features(StaticRemoteKey -> Optional))) + assert(peer.stateData.channels.isEmpty) + + val probe = TestProbe() + probe.send(peer, Peer.OpenChannel(remoteNodeId, 25000 sat, 0 msat, None, None, None, None)) + assert(probe.expectMsgType[Failure].cause.getMessage.contains("cannot find a suitable channel type")) + } + + test("don't spawn a channel if we don't support their channel types") { f => + import f._ + + connect(remoteNodeId, peer, peerConnection) + assert(peer.stateData.channels.isEmpty) + + // They only support anchor outputs and we don't. + val openTlv = TlvStream[OpenChannelTlv](OpenChannelTlv.ChannelTypes(Features(StaticRemoteKey -> Optional, AnchorOutputs -> Optional) :: Nil)) + val open = protocol.OpenChannel(Block.RegtestGenesisBlock.hash, randomBytes32(), 25000 sat, 0 msat, 483 sat, UInt64(100), 1000 sat, 1 msat, TestConstants.feeratePerKw, CltvExpiryDelta(144), 10, randomKey().publicKey, randomKey().publicKey, randomKey().publicKey, randomKey().publicKey, randomKey().publicKey, randomKey().publicKey, 0, openTlv) + peerConnection.send(peer, open) + peerConnection.expectMsg(Error(open.temporaryChannelId, "incompatible channel types")) + } + + test("choose from their channel types when spawning a channel", Tag("static_remotekey")) { f => + import f._ + + // We both support option_static_remotekey but they don't propose it in their channel types. + connect(remoteNodeId, peer, peerConnection, remoteInit = protocol.Init(Features(StaticRemoteKey -> Optional))) + assert(peer.stateData.channels.isEmpty) + val openTlv = TlvStream[OpenChannelTlv](OpenChannelTlv.ChannelTypes(Features(StaticRemoteKey -> Optional, AnchorOutputs -> Optional) :: Features.empty :: Nil)) + val open = protocol.OpenChannel(Block.RegtestGenesisBlock.hash, randomBytes32(), 25000 sat, 0 msat, 483 sat, UInt64(100), 1000 sat, 1 msat, TestConstants.feeratePerKw, CltvExpiryDelta(144), 10, randomKey().publicKey, randomKey().publicKey, randomKey().publicKey, randomKey().publicKey, randomKey().publicKey, randomKey().publicKey, 0, openTlv) + peerConnection.send(peer, open) + awaitCond(peer.stateData.channels.nonEmpty) + assert(channel.expectMsgType[INPUT_INIT_FUNDEE].channelVersion === ChannelVersion.STANDARD) + channel.expectMsg(open) + } + test("use correct fee rates when spawning a channel") { f => import f._ @@ -331,7 +402,7 @@ class PeerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with Paralle feeEstimator.setFeerate(FeeratesPerKw.single(TestConstants.anchorOutputsFeeratePerKw * 2)) probe.send(peer, Peer.OpenChannel(remoteNodeId, 15000 sat, 0 msat, None, None, None, None)) val init = channel.expectMsgType[INPUT_INIT_FUNDER] - assert(init.channelVersion.hasAnchorOutputs) + assert(init.channelVersion === ChannelVersion.ANCHOR_OUTPUTS) assert(init.fundingAmount === 15000.sat) assert(init.initialRelayFees_opt === None) assert(init.initialFeeratePerKw === TestConstants.anchorOutputsFeeratePerKw) @@ -345,7 +416,7 @@ class PeerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with Paralle connect(remoteNodeId, peer, peerConnection, remoteInit = protocol.Init(Features(StaticRemoteKey -> Optional))) probe.send(peer, Peer.OpenChannel(remoteNodeId, 24000 sat, 0 msat, None, None, None, None)) val init = channel.expectMsgType[INPUT_INIT_FUNDER] - assert(init.channelVersion.hasStaticRemotekey) + assert(init.channelVersion === ChannelVersion.STATIC_REMOTEKEY) assert(init.localParams.walletStaticPaymentBasepoint.isDefined) assert(init.localParams.defaultFinalScriptPubKey === Script.write(Script.pay2wpkh(init.localParams.walletStaticPaymentBasepoint.get))) } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/LightningMessageCodecsSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/LightningMessageCodecsSpec.scala index da3089154d..e52a028134 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/LightningMessageCodecsSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/LightningMessageCodecsSpec.scala @@ -18,6 +18,8 @@ package fr.acinq.eclair.wire.protocol import fr.acinq.bitcoin.Crypto.{PrivateKey, PublicKey} import fr.acinq.bitcoin.{Block, ByteVector32, ByteVector64, SatoshiLong} +import fr.acinq.eclair.FeatureSupport.Optional +import fr.acinq.eclair.Features.{AnchorOutputs, StaticRemoteKey} import fr.acinq.eclair._ import fr.acinq.eclair.blockchain.fee.FeeratePerKw import fr.acinq.eclair.router.Announcements @@ -107,7 +109,11 @@ class LightningMessageCodecsSpec extends AnyFunSuite { // empty upfront_shutdown_script + unknown odd tlv records defaultEncoded ++ hex"0000 0302002a 050102" -> defaultOpen.copy(tlvStream = TlvStream(Seq(ChannelTlv.UpfrontShutdownScript(ByteVector.empty)), Seq(GenericTlv(UInt64(3), hex"002a"), GenericTlv(UInt64(5), hex"02")))), // non-empty upfront_shutdown_script + unknown odd tlv records - defaultEncoded ++ hex"0002 1234 0303010203" -> defaultOpen.copy(tlvStream = TlvStream(Seq(ChannelTlv.UpfrontShutdownScript(hex"1234")), Seq(GenericTlv(UInt64(3), hex"010203")))) + defaultEncoded ++ hex"0002 1234 0303010203" -> defaultOpen.copy(tlvStream = TlvStream(Seq(ChannelTlv.UpfrontShutdownScript(hex"1234")), Seq(GenericTlv(UInt64(3), hex"010203")))), + // empty upfront_shutdown_script + channel types + defaultEncoded ++ hex"0000" ++ hex"0106000000022000" -> defaultOpen.copy(tlvStream = TlvStream(ChannelTlv.UpfrontShutdownScript(ByteVector.empty), OpenChannelTlv.ChannelTypes(List(Features(), Features(StaticRemoteKey -> Optional))))), + // non-empty upfront_shutdown_script + channel types + defaultEncoded ++ hex"0004 01abcdef" ++ hex"0109000320200000022000" -> defaultOpen.copy(tlvStream = TlvStream(ChannelTlv.UpfrontShutdownScript(hex"01abcdef"), OpenChannelTlv.ChannelTypes(List(Features(StaticRemoteKey -> Optional, AnchorOutputs -> Optional), Features(StaticRemoteKey -> Optional))))) ) for ((encoded, expected) <- testCases) { @@ -124,6 +130,7 @@ class LightningMessageCodecsSpec extends AnyFunSuite { defaultEncoded ++ hex"00", // truncated length defaultEncoded ++ hex"01", // truncated length defaultEncoded ++ hex"0004 123456", // truncated upfront_shutdown_script + defaultEncoded ++ hex"0000 010400040123", // truncated channel types defaultEncoded ++ hex"0000 02012a", // invalid tlv stream (unknown even record) defaultEncoded ++ hex"0000 01012a 030201", // invalid tlv stream (truncated) defaultEncoded ++ hex"02012a", // invalid tlv stream (unknown even record) @@ -145,8 +152,10 @@ class LightningMessageCodecsSpec extends AnyFunSuite { val testCases = Map( defaultEncoded -> defaultAccept, // legacy encoding without upfront_shutdown_script defaultEncoded ++ hex"0000" -> defaultAccept.copy(tlvStream = TlvStream(ChannelTlv.UpfrontShutdownScript(ByteVector.empty))), // empty upfront_shutdown_script + defaultEncoded ++ hex"0000" ++ hex"01020000" -> defaultAccept.copy(tlvStream = TlvStream(ChannelTlv.UpfrontShutdownScript(ByteVector.empty), AcceptChannelTlv.ChannelType(Features()))), // empty upfront_shutdown_script with channel type defaultEncoded ++ hex"0004 01abcdef" -> defaultAccept.copy(tlvStream = TlvStream(ChannelTlv.UpfrontShutdownScript(hex"01abcdef"))), // non-empty upfront_shutdown_script - defaultEncoded ++ hex"0000 0102002a 030102" -> defaultAccept.copy(tlvStream = TlvStream(ChannelTlv.UpfrontShutdownScript(ByteVector.empty) :: Nil, GenericTlv(UInt64(1), hex"002a") :: GenericTlv(UInt64(3), hex"02") :: Nil)), // empty upfront_shutdown_script + unknown odd tlv records + defaultEncoded ++ hex"0004 01abcdef" ++ hex"010400022000" -> defaultAccept.copy(tlvStream = TlvStream(ChannelTlv.UpfrontShutdownScript(hex"01abcdef"), AcceptChannelTlv.ChannelType(Features(StaticRemoteKey -> Optional)))), // non-empty upfront_shutdown_script with channel type + defaultEncoded ++ hex"0000 0302002a 050102" -> defaultAccept.copy(tlvStream = TlvStream(ChannelTlv.UpfrontShutdownScript(ByteVector.empty) :: Nil, GenericTlv(UInt64(3), hex"002a") :: GenericTlv(UInt64(5), hex"02") :: Nil)), // empty upfront_shutdown_script + unknown odd tlv records defaultEncoded ++ hex"0002 1234 0303010203" -> defaultAccept.copy(tlvStream = TlvStream(ChannelTlv.UpfrontShutdownScript(hex"1234") :: Nil, GenericTlv(UInt64(3), hex"010203") :: Nil)), // non-empty upfront_shutdown_script + unknown odd tlv records defaultEncoded ++ hex"0303010203 05020123" -> defaultAccept.copy(tlvStream = TlvStream(Nil, GenericTlv(UInt64(3), hex"010203") :: GenericTlv(UInt64(5), hex"0123") :: Nil)) // no upfront_shutdown_script + unknown odd tlv records )