diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala b/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala index 88dda7225a..24c11b41e4 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala @@ -100,8 +100,8 @@ class Peer(val nodeParams: NodeParams, remoteNodeId: PublicKey, authenticator: A log.debug(s"got authenticated connection to $remoteNodeId@${address.getHostString}:${address.getPort}") transport ! TransportHandler.Listener(self) context watch transport - val localInit = nodeParams.overrideFeatures.get(remoteNodeId) match { - case Some(f) => wire.Init(f) + val localFeatures = nodeParams.overrideFeatures.get(remoteNodeId) match { + case Some(f) => f case None => // Eclair-mobile thinks feature bit 15 (payment_secret) is gossip_queries_ex which creates issues, so we mask // off basic_mpp and payment_secret. As long as they're provided in the invoice it's not an issue. @@ -116,9 +116,10 @@ class Peer(val nodeParams: NodeParams, remoteNodeId: PublicKey, authenticator: A // ... and leave the others untouched case (value, _) => value }).reverse.bytes.dropWhile(_ == 0) - wire.Init(tweakedFeatures) + tweakedFeatures } - log.info(s"using features=${localInit.features.toBin}") + log.info(s"using features=${localFeatures.toBin}") + val localInit = wire.Init(localFeatures, TlvStream(InitTlv.Networks(nodeParams.chainHash :: Nil))) transport ! localInit val address_opt = if (outgoing) { @@ -148,9 +149,19 @@ class Peer(val nodeParams: NodeParams, remoteNodeId: PublicKey, authenticator: A case Event(remoteInit: wire.Init, d: InitializingData) => d.transport ! TransportHandler.ReadAck(remoteInit) - log.info(s"peer is using features=${remoteInit.features.toBin}") + log.info(s"peer is using features=${remoteInit.features.toBin}, networks=${remoteInit.networks.mkString(",")}") - if (Features.areSupported(remoteInit.features)) { + if (remoteInit.networks.nonEmpty && !remoteInit.networks.contains(nodeParams.chainHash)) { + log.warning(s"incompatible networks (${remoteInit.networks}), disconnecting") + d.origin_opt.foreach(origin => origin ! Status.Failure(new RuntimeException("incompatible networks"))) + d.transport ! PoisonPill + stay + } else if (!Features.areSupported(remoteInit.features)) { + log.warning("incompatible features, disconnecting") + d.origin_opt.foreach(origin => origin ! Status.Failure(new RuntimeException("incompatible features"))) + d.transport ! PoisonPill + stay + } else { d.origin_opt.foreach(origin => origin ! "connected") def localHasFeature(f: Feature): Boolean = Features.hasFeature(d.localInit.features, f) @@ -181,11 +192,6 @@ class Peer(val nodeParams: NodeParams, remoteNodeId: PublicKey, authenticator: A val rebroadcastDelay = Random.nextInt(nodeParams.routerConf.routerBroadcastInterval.toSeconds.toInt).seconds log.info(s"rebroadcast will be delayed by $rebroadcastDelay") goto(CONNECTED) using ConnectedData(d.address_opt, d.transport, d.localInit, remoteInit, d.channels.map { case (k: ChannelId, v) => (k, v) }, rebroadcastDelay) forMax (30 seconds) // forMax will trigger a StateTimeout - } else { - log.warning(s"incompatible features, disconnecting") - d.origin_opt.foreach(origin => origin ! Status.Failure(new RuntimeException("incompatible features"))) - d.transport ! PoisonPill - stay } case Event(Authenticator.Authenticated(connection, _, _, _, _, origin_opt), _) => diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/InitTlv.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/InitTlv.scala new file mode 100644 index 0000000000..275243860d --- /dev/null +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/InitTlv.scala @@ -0,0 +1,49 @@ +/* + * Copyright 2019 ACINQ SAS + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package fr.acinq.eclair.wire + +import fr.acinq.bitcoin.ByteVector32 +import fr.acinq.eclair.UInt64 +import fr.acinq.eclair.wire.CommonCodecs._ +import scodec.Codec +import scodec.codecs.{discriminated, list, variableSizeBytesLong} + +/** + * Created by t-bast on 13/12/2019. + */ + +/** Tlv types used inside Init messages. */ +sealed trait InitTlv extends Tlv + +object InitTlv { + + /** The chains the node is interested in. */ + case class Networks(chainHashes: List[ByteVector32]) extends InitTlv + +} + +object InitTlvCodecs { + + import InitTlv._ + + private val networks: Codec[Networks] = variableSizeBytesLong(varintoverflow, list(bytes32)).as[Networks] + + val initTlvCodec = TlvCodecs.tlvStream(discriminated[InitTlv].by(varint) + .typecase(UInt64(1), networks) + ) + +} \ No newline at end of file diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/LightningMessageCodecs.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/LightningMessageCodecs.scala index 2e030cc110..aec272e10b 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/LightningMessageCodecs.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/LightningMessageCodecs.scala @@ -39,7 +39,7 @@ object LightningMessageCodecs { }, { features => (ByteVector.empty, features) }) - val initCodec: Codec[Init] = combinedFeaturesCodec.as[Init] + val initCodec: Codec[Init] = (("features" | combinedFeaturesCodec) :: ("tlvStream" | InitTlvCodecs.initTlvCodec)).as[Init] val errorCodec: Codec[Error] = ( ("channelId" | bytes32) :: diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/LightningMessageTypes.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/LightningMessageTypes.scala index 45ae3c69b8..be88cfda64 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/LightningMessageTypes.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/LightningMessageTypes.scala @@ -45,7 +45,9 @@ sealed trait HasChainHash extends LightningMessage { def chainHash: ByteVector32 sealed trait UpdateMessage extends HtlcMessage // <- not in the spec // @formatter:on -case class Init(features: ByteVector) extends SetupMessage +case class Init(features: ByteVector, tlvs: TlvStream[InitTlv] = TlvStream.empty) extends SetupMessage { + val networks = tlvs.get[InitTlv.Networks].map(_.chainHashes).getOrElse(Nil) +} case class Error(channelId: ByteVector32, data: ByteVector) extends SetupMessage with HasChannelId { def toAscii: String = if (fr.acinq.eclair.isAsciiPrintable(data)) new String(data.toArray, StandardCharsets.US_ASCII) else "n/a" diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvTypes.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvTypes.scala index 3ea3d85c21..f108d24f32 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvTypes.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvTypes.scala @@ -22,41 +22,40 @@ import scodec.bits.ByteVector import scala.reflect.ClassTag /** - * Created by t-bast on 20/06/2019. - */ + * Created by t-bast on 20/06/2019. + */ trait Tlv /** - * Generic tlv type we fallback to if we don't understand the incoming tlv. - * - * @param tag tlv tag. - * @param value tlv value (length is implicit, and encoded as a varint). - */ + * Generic tlv type we fallback to if we don't understand the incoming tlv. + * + * @param tag tlv tag. + * @param value tlv value (length is implicit, and encoded as a varint). + */ case class GenericTlv(tag: UInt64, value: ByteVector) extends Tlv /** - * A tlv stream is a collection of tlv records. - * A tlv stream is constrained to a specific tlv namespace that dictates how to parse the tlv records. - * That namespace is provided by a trait extending the top-level tlv trait. - * - * @param records known tlv records. - * @param unknown unknown tlv records. - * @tparam T the stream namespace is a trait extending the top-level tlv trait. - */ + * A tlv stream is a collection of tlv records. + * A tlv stream is constrained to a specific tlv namespace that dictates how to parse the tlv records. + * That namespace is provided by a trait extending the top-level tlv trait. + * + * @param records known tlv records. + * @param unknown unknown tlv records. + * @tparam T the stream namespace is a trait extending the top-level tlv trait. + */ case class TlvStream[T <: Tlv](records: Traversable[T], unknown: Traversable[GenericTlv] = Nil) { /** - * - * @tparam R input type parameter, must be a subtype of the main TLV type - * @return the TLV record of type that matches the input type parameter if any (there can be at most one, since BOLTs specify - * that TLV records are supposed to be unique) - */ + * + * @tparam R input type parameter, must be a subtype of the main TLV type + * @return the TLV record of type that matches the input type parameter if any (there can be at most one, since BOLTs specify + * that TLV records are supposed to be unique) + */ def get[R <: T : ClassTag]: Option[R] = records.collectFirst { case r: R => r } } object TlvStream { - def empty[T <: Tlv] = TlvStream[T](Nil, Nil) + def empty[T <: Tlv]: TlvStream[T] = TlvStream[T](Nil, Nil) def apply[T <: Tlv](records: T*): TlvStream[T] = TlvStream(records, Nil) - } \ No newline at end of file diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpec.scala index 8994595a56..83dbf71b60 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpec.scala @@ -21,6 +21,7 @@ import java.net.{Inet4Address, InetAddress, InetSocketAddress, ServerSocket} import akka.actor.FSM.{CurrentState, SubscribeTransitionCallBack, Transition} import akka.actor.{ActorRef, PoisonPill} import akka.testkit.{TestFSMRef, TestProbe} +import fr.acinq.bitcoin.Block import fr.acinq.bitcoin.Crypto.PublicKey import fr.acinq.eclair.TestConstants._ import fr.acinq.eclair._ @@ -31,7 +32,7 @@ import fr.acinq.eclair.crypto.TransportHandler import fr.acinq.eclair.io.Peer._ import fr.acinq.eclair.router.RoutingSyncSpec.makeFakeRoutingInfo import fr.acinq.eclair.router.{Rebroadcast, RoutingSyncSpec, SendChannelQuery} -import fr.acinq.eclair.wire.{ChannelCodecsSpec, Color, EncodedShortChannelIds, EncodingType, Error, IPv4, LightningMessageCodecs, NodeAddress, NodeAnnouncement, Ping, Pong, QueryShortChannelIds, TlvStream} +import fr.acinq.eclair.wire.{ChannelCodecsSpec, Color, EncodedShortChannelIds, EncodingType, Error, IPv4, InitTlv, LightningMessageCodecs, NodeAddress, NodeAnnouncement, Ping, Pong, QueryShortChannelIds, TlvStream} import org.scalatest.{Outcome, Tag} import scodec.bits.{ByteVector, _} @@ -81,7 +82,8 @@ class PeerSpec extends TestkitBaseClass with StateTestsHelperMethods { probe.send(peer, Peer.Init(None, channels)) authenticator.send(peer, Authenticator.Authenticated(connection.ref, transport.ref, remoteNodeId, fakeIPAddress.socketAddress, outgoing = true, None)) transport.expectMsgType[TransportHandler.Listener] - transport.expectMsgType[wire.Init] + val localInit = transport.expectMsgType[wire.Init] + assert(localInit.networks === List(Block.RegtestGenesisBlock.hash)) transport.send(peer, remoteInit) transport.expectMsgType[TransportHandler.ReadAck] if (expectSync) { @@ -255,6 +257,19 @@ class PeerSpec extends TestkitBaseClass with StateTestsHelperMethods { assert(init.features === sentFeatures.bytes) } } + + test("disconnect if incompatible networks") { f => + import f._ + val probe = TestProbe() + probe.watch(transport.ref) + probe.send(peer, Peer.Init(None, Set.empty)) + authenticator.send(peer, Authenticator.Authenticated(connection.ref, transport.ref, remoteNodeId, new InetSocketAddress("1.2.3.4", 42000), outgoing = true, None)) + transport.expectMsgType[TransportHandler.Listener] + transport.expectMsgType[wire.Init] + transport.send(peer, wire.Init(Bob.nodeParams.features, TlvStream(InitTlv.Networks(Block.LivenetGenesisBlock.hash :: Block.SegnetGenesisBlock.hash :: Nil)))) + transport.expectMsgType[TransportHandler.ReadAck] + probe.expectTerminated(transport.ref) + } test("handle disconnect in status INITIALIZING") { f => import f._ diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/LightningMessageCodecsSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/LightningMessageCodecsSpec.scala index 221b33965f..7b037fdb99 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/wire/LightningMessageCodecsSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/wire/LightningMessageCodecsSpec.scala @@ -44,21 +44,37 @@ class LightningMessageCodecsSpec extends FunSuite { def publicKey(fill: Byte) = PrivateKey(ByteVector.fill(32)(fill)).publicKey test("encode/decode init message") { + case class TestCase(encoded: ByteVector, features: ByteVector, networks: List[ByteVector32], valid: Boolean, reEncoded: Option[ByteVector] = None) + val chainHash1 = ByteVector32(hex"0101010101010101010101010101010101010101010101010101010101010101") + val chainHash2 = ByteVector32(hex"0202020202020202020202020202020202020202020202020202020202020202") val testCases = Seq( - (hex"0000 0000", hex"", hex"0000 0000"), // no features - (hex"0000 0002088a", hex"088a", hex"0000 0002088a"), // no global features - (hex"00020200 0000", hex"0200", hex"0000 00020200"), // no local features - (hex"00020200 0002088a", hex"0a8a", hex"0000 00020a8a"), // local and global - no conflict - same size - (hex"00020200 0003020002", hex"020202", hex"0000 0003020202"), // local and global - no conflict - different sizes - (hex"00020a02 0002088a", hex"0a8a", hex"0000 00020a8a"), // local and global - conflict - same size - (hex"00022200 000302aaa2", hex"02aaa2", hex"0000 000302aaa2") // local and global - conflict - different sizes + TestCase(hex"0000 0000", hex"", Nil, valid = true), // no features + TestCase(hex"0000 0002088a", hex"088a", Nil, valid = true), // no global features + TestCase(hex"00020200 0000", hex"0200", Nil, valid = true, Some(hex"0000 00020200")), // no local features + TestCase(hex"00020200 0002088a", hex"0a8a", Nil, valid = true, Some(hex"0000 00020a8a")), // local and global - no conflict - same size + TestCase(hex"00020200 0003020002", hex"020202", Nil, valid = true, Some(hex"0000 0003020202")), // local and global - no conflict - different sizes + TestCase(hex"00020a02 0002088a", hex"0a8a", Nil, valid = true, Some(hex"0000 00020a8a")), // local and global - conflict - same size + TestCase(hex"00022200 000302aaa2", hex"02aaa2", Nil, valid = true, Some(hex"0000 000302aaa2")), // local and global - conflict - different sizes + TestCase(hex"0000 0002088a 03012a05022aa2", hex"088a", Nil, valid = true), // unknown odd records + TestCase(hex"0000 0002088a 03012a04022aa2", hex"088a", Nil, valid = false), // unknown even records + TestCase(hex"0000 0002088a 0120010101010101010101010101010101010101010101010101010101010101", hex"088a", Nil, valid = false), // invalid tlv stream + TestCase(hex"0000 0002088a 01200101010101010101010101010101010101010101010101010101010101010101", hex"088a", List(chainHash1), valid = true), // single network + TestCase(hex"0000 0002088a 014001010101010101010101010101010101010101010101010101010101010101010202020202020202020202020202020202020202020202020202020202020202", hex"088a", List(chainHash1, chainHash2), valid = true), // multiple networks + TestCase(hex"0000 0002088a 0120010101010101010101010101010101010101010101010101010101010101010103012a", hex"088a", List(chainHash1), valid = true), // network and unknown odd records + TestCase(hex"0000 0002088a 0120010101010101010101010101010101010101010101010101010101010101010102012a", hex"088a", Nil, valid = false) // network and unknown even records ) - for ((bin, features, encoded) <- testCases) { - val init = initCodec.decode(bin.bits).require.value - assert(init.features === features) - assert(initCodec.encode(init).require.bytes === encoded) - assert(initCodec.decode(encoded.bits).require.value === init) + for (testCase <- testCases) { + if (testCase.valid) { + val init = initCodec.decode(testCase.encoded.bits).require.value + assert(init.features === testCase.features) + assert(init.networks === testCase.networks) + val encoded = initCodec.encode(init).require + assert(encoded.bytes === testCase.reEncoded.getOrElse(testCase.encoded)) + assert(initCodec.decode(encoded).require.value === init) + } else { + assert(initCodec.decode(testCase.encoded.bits).isFailure) + } } }