From f8b5fa3dd41d2349c480fb7da2b0f2d4bc0c25d7 Mon Sep 17 00:00:00 2001 From: Bastien Teinturier Date: Thu, 20 Jun 2019 09:49:46 +0200 Subject: [PATCH 01/12] LightningMessageTypes/Codecs: clean-up warnings --- .../eclair/wire/LightningMessageCodecs.scala | 31 +++++++++---------- .../eclair/wire/LightningMessageTypes.scala | 4 --- .../wire/LightningMessageCodecsSpec.scala | 6 ++-- 3 files changed, 17 insertions(+), 24 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/LightningMessageCodecs.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/LightningMessageCodecs.scala index e984eaf5e4..bc68f24daa 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/LightningMessageCodecs.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/LightningMessageCodecs.scala @@ -18,7 +18,6 @@ package fr.acinq.eclair.wire import java.net.{Inet4Address, Inet6Address, InetAddress} -import com.google.common.cache.{CacheBuilder, CacheLoader} import fr.acinq.bitcoin.Crypto.{PrivateKey, PublicKey} import fr.acinq.bitcoin.{ByteVector32, ByteVector64} import fr.acinq.eclair.crypto.Sphinx @@ -27,7 +26,7 @@ import fr.acinq.eclair.{ShortChannelId, UInt64, wire} import org.apache.commons.codec.binary.Base32 import scodec.bits.{BitVector, ByteVector} import scodec.codecs._ -import scodec.{Attempt, Codec, DecodeResult, Err, SizeBound} +import scodec.{Attempt, Codec, Err} import scala.util.{Failure, Success, Try} @@ -216,14 +215,13 @@ object LightningMessageCodecs { ("nodeSignature" | bytes64) :: ("bitcoinSignature" | bytes64)).as[AnnouncementSignatures] - val channelAnnouncementWitnessCodec = ( - ("features" | varsizebinarydata) :: - ("chainHash" | bytes32) :: - ("shortChannelId" | shortchannelid) :: - ("nodeId1" | publicKey) :: - ("nodeId2" | publicKey) :: - ("bitcoinKey1" | publicKey) :: - ("bitcoinKey2" | publicKey)) + val channelAnnouncementWitnessCodec = ("features" | varsizebinarydata) :: + ("chainHash" | bytes32) :: + ("shortChannelId" | shortchannelid) :: + ("nodeId1" | publicKey) :: + ("nodeId2" | publicKey) :: + ("bitcoinKey1" | publicKey) :: + ("bitcoinKey2" | publicKey) val channelAnnouncementCodec: Codec[ChannelAnnouncement] = ( ("nodeSignature1" | bytes64) :: @@ -232,13 +230,12 @@ object LightningMessageCodecs { ("bitcoinSignature2" | bytes64) :: channelAnnouncementWitnessCodec).as[ChannelAnnouncement] - val nodeAnnouncementWitnessCodec = ( - ("features" | varsizebinarydata) :: - ("timestamp" | uint32) :: - ("nodeId" | publicKey) :: - ("rgbColor" | rgb) :: - ("alias" | zeropaddedstring(32)) :: - ("addresses" | listofnodeaddresses)) + val nodeAnnouncementWitnessCodec = ("features" | varsizebinarydata) :: + ("timestamp" | uint32) :: + ("nodeId" | publicKey) :: + ("rgbColor" | rgb) :: + ("alias" | zeropaddedstring(32)) :: + ("addresses" | listofnodeaddresses) val nodeAnnouncementCodec: Codec[NodeAnnouncement] = ( ("signature" | bytes64) :: diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/LightningMessageTypes.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/LightningMessageTypes.scala index e6019d7d94..9cf827eb54 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/LightningMessageTypes.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/LightningMessageTypes.scala @@ -181,10 +181,6 @@ object NodeAddress { * * We don't attempt to resolve onion addresses (it will be done by the tor proxy), so we just recognize them based on * the .onion TLD and rely on their length to separate v2/v3. - * - * @param host - * @param port - * @return */ def fromParts(host: String, port: Int): Try[NodeAddress] = Try { host match { diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/LightningMessageCodecsSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/LightningMessageCodecsSpec.scala index 93917a0425..2e6f9a8bc5 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/wire/LightningMessageCodecsSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/wire/LightningMessageCodecsSpec.scala @@ -139,7 +139,7 @@ class LightningMessageCodecsSpec extends FunSuite { { val alias = "IRATEMONK" val bin = c.encode(alias).require - assert(bin === BitVector(alias.getBytes("UTF-8") ++ Array.fill[Byte](32 - alias.size)(0))) + assert(bin === BitVector(alias.getBytes("UTF-8") ++ Array.fill[Byte](32 - alias.length)(0))) val alias2 = c.decode(bin).require.value assert(alias === alias2) } @@ -147,7 +147,7 @@ class LightningMessageCodecsSpec extends FunSuite { { val alias = "this-alias-is-exactly-32-B-long." val bin = c.encode(alias).require - assert(bin === BitVector(alias.getBytes("UTF-8") ++ Array.fill[Byte](32 - alias.size)(0))) + assert(bin === BitVector(alias.getBytes("UTF-8") ++ Array.fill[Byte](32 - alias.length)(0))) val alias2 = c.decode(bin).require.value assert(alias === alias2) } @@ -222,7 +222,7 @@ class LightningMessageCodecsSpec extends FunSuite { channel_announcement :: node_announcement :: channel_update :: gossip_timestamp_filter :: query_short_channel_id :: query_channel_range :: reply_channel_range :: announcement_signatures :: ping :: pong :: channel_reestablish :: Nil msgs.foreach { - case msg => { + msg => { val encoded = lightningMessageCodec.encode(msg).require val decoded = lightningMessageCodec.decode(encoded).require assert(msg === decoded.value) From 9960c3dabf39f3d4a2a2c74e444958192f9af716 Mon Sep 17 00:00:00 2001 From: Bastien Teinturier Date: Thu, 20 Jun 2019 10:19:19 +0200 Subject: [PATCH 02/12] Add Bitcoin varInt codec --- .../eclair/wire/LightningMessageCodecs.scala | 43 ++++++++++++++++ .../wire/LightningMessageCodecsSpec.scala | 49 ++++++++++++++++++- 2 files changed, 91 insertions(+), 1 deletion(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/LightningMessageCodecs.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/LightningMessageCodecs.scala index bc68f24daa..1ab37d8751 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/LightningMessageCodecs.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/LightningMessageCodecs.scala @@ -45,8 +45,51 @@ object LightningMessageCodecs { // (for something smarter see https://github.com/yzernik/bitcoin-scodec/blob/master/src/main/scala/io/github/yzernik/bitcoinscodec/structures/UInt64.scala) val uint64: Codec[Long] = int64.narrow(l => if (l >= 0) Attempt.Successful(l) else Attempt.failure(Err(s"overflow for value $l")), l => l) + val uint64L: Codec[Long] = int64L.narrow(l => if (l >= 0) Attempt.Successful(l) else Attempt.failure(Err(s"overflow for value $l")), l => l) + val uint64ex: Codec[UInt64] = bytes(8).xmap(b => UInt64(b), a => a.toByteVector.padLeft(8)) + // Bitcoin-style varint codec (CompactSize) + val varInt = Codec[Long]( + (n: Long) => + n match { + case i if i < 0xfd => + uint8L.encode(i.toInt) + case i if i < 0xffff => + for { + a <- uint8L.encode(0xfd) + b <- uint16L.encode(i.toInt) + } yield a ++ b + case i if i < 0xffffffffL => + for { + a <- uint8L.encode(0xfe) + b <- uint32L.encode(i) + } yield a ++ b + case i => + for { + a <- uint8L.encode(0xff) + b <- uint64L.encode(i) + } yield a ++ b + }, + (buf: BitVector) => { + uint8L.decode(buf) match { + case scodec.Attempt.Successful(b) => + b.value match { + case 0xff => + uint64L.decode(b.remainder) + case 0xfe => + uint32L.decode(b.remainder) + case 0xfd => + uint16L.decode(b.remainder) + .map(b => b.map(_.toLong)) + case _ => + scodec.Attempt.Successful(scodec.DecodeResult(b.value.toLong, b.remainder)) + } + case scodec.Attempt.Failure(err) => + scodec.Attempt.Failure(err) + } + }) + def bytes32: Codec[ByteVector32] = limitedSizeBytes(32, bytesStrict(32).xmap(d => ByteVector32(d), d => d.bytes)) def bytes64: Codec[ByteVector64] = limitedSizeBytes(64, bytesStrict(64).xmap(d => ByteVector64(d), d => d.bytes)) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/LightningMessageCodecsSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/LightningMessageCodecsSpec.scala index 2e6f9a8bc5..8332b65e5a 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/wire/LightningMessageCodecsSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/wire/LightningMessageCodecsSpec.scala @@ -50,6 +50,7 @@ class LightningMessageCodecsSpec extends FunSuite { UInt64(42) -> hex"00 00 00 00 00 00 00 2a", UInt64(hex"ffffffffffffffff") -> hex"ff ff ff ff ff ff ff ff" ).mapValues(_.toBitVector) + for ((uint, ref) <- expected) { val encoded = uint64ex.encode(uint).require assert(ref === encoded) @@ -58,6 +59,53 @@ class LightningMessageCodecsSpec extends FunSuite { } } + test("encode/decode with uint64L codec") { + val expected = Map( + 0L -> hex"00 00 00 00 00 00 00 00", + 42L -> hex"2a 00 00 00 00 00 00 00", + 6211610197754262546L -> hex"12 34 56 78 90 12 34 56" + ).mapValues(_.toBitVector) + + for ((long, ref) <- expected) { + val encoded = uint64L.encode(long).require + assert(ref === encoded) + val decoded = uint64L.decode(encoded).require.value + assert(long === decoded) + } + } + + test("encode/decode with varint codec") { + val expected = Map( + 0L -> hex"00", + 42L -> hex"2a", + 550L -> hex"fd 26 02", + 998000L -> hex"fe 70 3a 0f 00", + 6211610197754262546L -> hex"ff 12 34 56 78 90 12 34 56" + ).mapValues(_.toBitVector) + + for ((long, ref) <- expected) { + val encoded = varInt.encode(long).require + assert(ref === encoded) + val decoded = varInt.decode(encoded).require.value + assert(long === decoded) + } + } + + test("decode invalid varint") { + val testCases = Seq( + hex"fd", + hex"fe 01", + hex"fe", + hex"fe 12 34", + hex"ff", + hex"ff 12 34 56 78" + ).map(_.toBitVector) + + for (testCase <- testCases) { + assert(varInt.decode(testCase).isFailure) + } + } + test("encode/decode with rgb codec") { val color = Color(47.toByte, 255.toByte, 142.toByte) val bin = rgb.encode(color).require @@ -189,7 +237,6 @@ class LightningMessageCodecsSpec extends FunSuite { } test("encode/decode all channel messages") { - val open = OpenChannel(randomBytes32, randomBytes32, 3, 4, 5, UInt64(6), 7, 8, 9, 10, 11, publicKey(1), point(2), point(3), point(4), point(5), point(6), 0.toByte) val accept = AcceptChannel(randomBytes32, 3, UInt64(4), 5, 6, 7, 8, 9, publicKey(1), point(2), point(3), point(4), point(5), point(6)) val funding_created = FundingCreated(randomBytes32, bin32(0), 3, randomBytes64) From 8a2a2498544289455174832d6a59ee644cdb7e33 Mon Sep 17 00:00:00 2001 From: Bastien Teinturier Date: Thu, 20 Jun 2019 10:26:06 +0200 Subject: [PATCH 03/12] Remove unused FixedSizeStrictCodec. Scodec has integrated that feature already. --- .../eclair/wire/FixedSizeStrictCodec.scala | 75 ------------------- .../eclair/wire/LightningMessageCodecs.scala | 2 - 2 files changed, 77 deletions(-) delete mode 100644 eclair-core/src/main/scala/fr/acinq/eclair/wire/FixedSizeStrictCodec.scala diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/FixedSizeStrictCodec.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/FixedSizeStrictCodec.scala deleted file mode 100644 index d5afa984cc..0000000000 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/FixedSizeStrictCodec.scala +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Copyright 2019 ACINQ SAS - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package fr.acinq.eclair.wire - -import scodec.bits.{BitVector, ByteVector} -import scodec.{Attempt, Codec, DecodeResult, Err, SizeBound, codecs} - -/** - * - * REMOVE THIS A NEW VERSION OF SCODEC IS RELEASED THAT INCLUDES CHANGES MADE IN - * https://github.com/scodec/scodec/pull/99/files - * - * Created by PM on 02/06/2017. - */ -final class FixedSizeStrictCodec[A](size: Long, codec: Codec[A]) extends Codec[A] { - - override def sizeBound = SizeBound.exact(size) - - override def encode(a: A) = for { - encoded <- codec.encode(a) - result <- { - if (encoded.size != size) - Attempt.failure(Err(s"[$a] requires ${encoded.size} bits but field is fixed size of exactly $size bits")) - else - Attempt.successful(encoded.padTo(size)) - } - } yield result - - override def decode(buffer: BitVector) = { - if (buffer.size == size) { - codec.decode(buffer.take(size)) map { res => - DecodeResult(res.value, buffer.drop(size)) - } - } else { - Attempt.failure(Err(s"expected exactly $size bits but got ${buffer.size} bits")) - } - } - - override def toString = s"fixedSizeBitsStrict($size, $codec)" -} - -object FixedSizeStrictCodec { - /** - * Encodes by returning the supplied byte vector if its length is `size` bytes, otherwise returning error; - * decodes by taking `size * 8` bits from the supplied bit vector and converting to a byte vector. - * - * @param size number of bits to encode/decode - * @group bits - */ - def bytesStrict(size: Int): Codec[ByteVector] = new Codec[ByteVector] { - private val codec = new FixedSizeStrictCodec(size * 8L, codecs.bits).xmap[ByteVector](_.toByteVector, _.toBitVector) - - def sizeBound = codec.sizeBound - - def encode(b: ByteVector) = codec.encode(b) - - def decode(b: BitVector) = codec.decode(b) - - override def toString = s"bytesStrict($size)" - } -} \ No newline at end of file diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/LightningMessageCodecs.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/LightningMessageCodecs.scala index 1ab37d8751..a3920268b0 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/LightningMessageCodecs.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/LightningMessageCodecs.scala @@ -21,7 +21,6 @@ import java.net.{Inet4Address, Inet6Address, InetAddress} import fr.acinq.bitcoin.Crypto.{PrivateKey, PublicKey} import fr.acinq.bitcoin.{ByteVector32, ByteVector64} import fr.acinq.eclair.crypto.Sphinx -import fr.acinq.eclair.wire.FixedSizeStrictCodec.bytesStrict import fr.acinq.eclair.{ShortChannelId, UInt64, wire} import org.apache.commons.codec.binary.Base32 import scodec.bits.{BitVector, ByteVector} @@ -30,7 +29,6 @@ import scodec.{Attempt, Codec, Err} import scala.util.{Failure, Success, Try} - /** * Created by PM on 15/11/2016. */ From 7dadd2d6f1afc9a58affc8fcd334738dd97c3c56 Mon Sep 17 00:00:00 2001 From: Bastien Teinturier Date: Thu, 20 Jun 2019 11:30:36 +0200 Subject: [PATCH 04/12] Add generic TLV codec. Codecs will be namespaced, hence the use of child traits of the Tlv trait. --- .../fr/acinq/eclair/wire/TlvCodecs.scala | 39 +++++++++ .../scala/fr/acinq/eclair/wire/TlvTypes.scala | 36 ++++++++ .../fr/acinq/eclair/wire/TlvCodecsSpec.scala | 86 +++++++++++++++++++ 3 files changed, 161 insertions(+) create mode 100644 eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvCodecs.scala create mode 100644 eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvTypes.scala create mode 100644 eclair-core/src/test/scala/fr/acinq/eclair/wire/TlvCodecsSpec.scala diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvCodecs.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvCodecs.scala new file mode 100644 index 0000000000..4cff651d1c --- /dev/null +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvCodecs.scala @@ -0,0 +1,39 @@ +/* + * Copyright 2019 ACINQ SAS + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package fr.acinq.eclair.wire + +import fr.acinq.eclair.wire.LightningMessageCodecs._ +import scodec.codecs._ +import scodec.Codec + +/** + * Created by t-bast on 20/06/2019. + */ + +object TlvCodecs { + + val genericTlv: Codec[GenericTlv] = (("type" | varInt) :: variableSizeBytesLong(varInt, bytes)).as[GenericTlv] + + def tlvFallback(codec: Codec[Tlv]): Codec[Tlv] = discriminatorFallback(genericTlv, codec).xmap(_ match { + case Left(l) => l + case Right(r) => r + }, _ match { + case g: GenericTlv => Left(g) + case o => Right(o) + }) + +} diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvTypes.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvTypes.scala new file mode 100644 index 0000000000..fa084fcd97 --- /dev/null +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvTypes.scala @@ -0,0 +1,36 @@ +/* + * Copyright 2019 ACINQ SAS + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package fr.acinq.eclair.wire + +import scodec.bits.ByteVector + +/** + * Created by t-bast on 20/06/2019. + */ + +// @formatter:off +trait Tlv +sealed trait OnionTlv extends Tlv +// @formatter:on + +/** + * Generic tlv type we fallback to if we don't understand the incoming type. + * + * @param `type` tlv type. + * @param value tlv value (length is implicit, and encoded as a varint). + */ +case class GenericTlv(`type`: Long, value: ByteVector) extends Tlv diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/TlvCodecsSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/TlvCodecsSpec.scala new file mode 100644 index 0000000000..9611d8456a --- /dev/null +++ b/eclair-core/src/test/scala/fr/acinq/eclair/wire/TlvCodecsSpec.scala @@ -0,0 +1,86 @@ +/* + * Copyright 2019 ACINQ SAS + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package fr.acinq.eclair.wire + +import fr.acinq.bitcoin.Crypto.PublicKey +import fr.acinq.eclair.ShortChannelId +import fr.acinq.eclair.wire.LightningMessageCodecs.{publicKey, shortchannelid, uint64, varInt} +import fr.acinq.eclair.wire.TlvCodecs._ +import org.scalatest.FunSuite +import scodec.bits.HexStringSyntax +import scodec.codecs._ +import scodec.Codec + +/** + * Created by t-bast on 20/06/2019. + */ + +class TlvCodecsSpec extends FunSuite { + + import TlvCodecsSpec._ + + test("encode/decode tlv") { + val testCases = Seq( + (hex"0x01 08 000000000000002a", TestType1(42)), + (hex"0x02 08 0000000000000226", TestType2(ShortChannelId(550))), + (hex"0x03 31 02eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619 0000000000000231 0000000000000451", TestType3(PublicKey(hex"02eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619"), 561, 1105)), + (hex"0xff1234567890123456 fd0001 10101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010010101010101", GenericTlv(6211610197754262546L, hex"10101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010010101010101")) + ) + + for ((bin, expected) <- testCases) { + val decoded = testTlvCodec.decode(bin.toBitVector).require.value.asInstanceOf[Tlv] + assert(decoded === expected) + val encoded = testTlvCodec.encode(expected).require.toByteVector + assert(encoded === bin) + } + } + + test("decode invalid tlv") { + val testCases = Seq( + hex"0xfd022a", // type truncated + hex"0x2a fd022a", // length truncated + hex"0x2a fd2602 0231", // value truncated + hex"0x02 01 2a", // short channel id too short + hex"0x02 09 010101010101010101" // short channel id length too big + ) + + for (testCase <- testCases) { + assert(testTlvCodec.decode(testCase.toBitVector).isFailure) + } + } + +} + +object TlvCodecsSpec { + + // @formatter:off + sealed trait TestTlv extends Tlv + case class TestType1(longValue: Long) extends TestTlv + case class TestType2(shortChannelId: ShortChannelId) extends TestTlv + case class TestType3(nodeId: PublicKey, value1: Long, value2: Long) extends TestTlv + + val testCodec1: Codec[TestType1] = (("length" | constant(hex"0x08")) :: ("value" | uint64)).as[TestType1] + val testCodec2: Codec[TestType2] = (("length" | constant(hex"0x08")) :: ("short_channel_id" | shortchannelid)).as[TestType2] + val testCodec3: Codec[TestType3] = (("length" | constant(hex"0x31")) :: ("node_id" | publicKey) :: ("value_1" | uint64) :: ("value_2" | uint64)).as[TestType3] + val testTlvCodec = tlvFallback(discriminated[Tlv].by(varInt) + .typecase(1, testCodec1) + .typecase(2, testCodec2) + .typecase(3, testCodec3) + ) + // @formatter:on + +} From fa8867a6acae38385f29e57c287c319edd845ac6 Mon Sep 17 00:00:00 2001 From: Bastien Teinturier Date: Thu, 20 Jun 2019 13:55:53 +0200 Subject: [PATCH 05/12] Codecs refactoring. Move common codecs to their own file. Harmonize use of val vs def for fields. --- .../electrum/db/sqlite/SqliteWalletDb.scala | 2 +- .../fr/acinq/eclair/crypto/ShaChain.scala | 4 +- .../eclair/db/sqlite/SqlitePeersDb.scala | 4 +- .../acinq/eclair/payment/PaymentRequest.scala | 4 +- .../fr/acinq/eclair/wire/ChannelCodecs.scala | 8 +- .../fr/acinq/eclair/wire/CommandCodecs.scala | 2 +- .../fr/acinq/eclair/wire/CommonCodecs.scala | 135 ++++++++++++ .../fr/acinq/eclair/wire/FailureMessage.scala | 23 +- .../eclair/wire/LightningMessageCodecs.scala | 110 +-------- .../fr/acinq/eclair/wire/TlvCodecs.scala | 2 +- .../eclair/crypto/TransportHandlerSpec.scala | 14 +- .../acinq/eclair/wire/CommonCodecsSpec.scala | 208 ++++++++++++++++++ .../wire/LightningMessageCodecsSpec.scala | 177 +-------------- .../fr/acinq/eclair/wire/TlvCodecsSpec.scala | 2 +- 14 files changed, 379 insertions(+), 316 deletions(-) create mode 100644 eclair-core/src/main/scala/fr/acinq/eclair/wire/CommonCodecs.scala create mode 100644 eclair-core/src/test/scala/fr/acinq/eclair/wire/CommonCodecsSpec.scala diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/blockchain/electrum/db/sqlite/SqliteWalletDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/blockchain/electrum/db/sqlite/SqliteWalletDb.scala index ca7b4fbbb9..907cd45f7d 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/blockchain/electrum/db/sqlite/SqliteWalletDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/blockchain/electrum/db/sqlite/SqliteWalletDb.scala @@ -136,7 +136,7 @@ class SqliteWalletDb(sqlite: Connection) extends WalletDb { object SqliteWalletDb { import fr.acinq.eclair.wire.ChannelCodecs._ - import fr.acinq.eclair.wire.LightningMessageCodecs._ + import fr.acinq.eclair.wire.CommonCodecs._ import scodec.Codec import scodec.bits.BitVector import scodec.codecs._ diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/crypto/ShaChain.scala b/eclair-core/src/main/scala/fr/acinq/eclair/crypto/ShaChain.scala index 985136ae82..94438b65ef 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/crypto/ShaChain.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/crypto/ShaChain.scala @@ -17,7 +17,7 @@ package fr.acinq.eclair.crypto import fr.acinq.bitcoin._ -import fr.acinq.eclair.wire.LightningMessageCodecs +import fr.acinq.eclair.wire.CommonCodecs import scodec.Codec import scala.annotation.tailrec @@ -117,7 +117,7 @@ object ShaChain { import scodec.codecs._ // codec for a single map entry (i.e. Vector[Boolean] -> ByteVector - val entryCodec = vectorOfN(uint16, bool) ~ variableSizeBytes(uint16, LightningMessageCodecs.bytes32) + val entryCodec = vectorOfN(uint16, bool) ~ variableSizeBytes(uint16, CommonCodecs.bytes32) // codec for a Map[Vector[Boolean], ByteVector]: write all k -> v pairs using the codec defined above val mapCodec: Codec[Map[Vector[Boolean], ByteVector32]] = Codec[Map[Vector[Boolean], ByteVector32]]( diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePeersDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePeersDb.scala index f34d98d500..8d9e828ba3 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePeersDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePeersDb.scala @@ -38,7 +38,7 @@ import SqliteUtils.ExtendedResultSet._ } override def addOrUpdatePeer(nodeId: Crypto.PublicKey, nodeaddress: NodeAddress): Unit = { - val data = LightningMessageCodecs.nodeaddress.encode(nodeaddress).require.toByteArray + val data = CommonCodecs.nodeaddress.encode(nodeaddress).require.toByteArray using(sqlite.prepareStatement("UPDATE peers SET data=? WHERE node_id=?")) { update => update.setBytes(1, data) update.setBytes(2, nodeId.value.toArray) @@ -65,7 +65,7 @@ import SqliteUtils.ExtendedResultSet._ var m: Map[PublicKey, NodeAddress] = Map() while (rs.next()) { val nodeid = PublicKey(rs.getByteVector("node_id")) - val nodeaddress = LightningMessageCodecs.nodeaddress.decode(BitVector(rs.getBytes("data"))).require.value + val nodeaddress = CommonCodecs.nodeaddress.decode(BitVector(rs.getBytes("data"))).require.value m += (nodeid -> nodeaddress) } m diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentRequest.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentRequest.scala index 9e8a9323cf..b35b0d7b19 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentRequest.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentRequest.scala @@ -16,8 +16,6 @@ package fr.acinq.eclair.payment -import java.math.BigInteger - import fr.acinq.bitcoin.Crypto.{PrivateKey, PublicKey} import fr.acinq.bitcoin.{MilliSatoshi, _} import fr.acinq.eclair.ShortChannelId @@ -303,7 +301,7 @@ object PaymentRequest { object Codecs { - import fr.acinq.eclair.wire.LightningMessageCodecs._ + import fr.acinq.eclair.wire.CommonCodecs._ import scodec.bits.BitVector import scodec.codecs._ import scodec.{Attempt, Codec, DecodeResult} diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/ChannelCodecs.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/ChannelCodecs.scala index 66569bc304..11ff20bf25 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/ChannelCodecs.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/ChannelCodecs.scala @@ -26,6 +26,7 @@ import fr.acinq.eclair.crypto.ShaChain import fr.acinq.eclair.payment.{Local, Origin, Relayed} import fr.acinq.eclair.transactions.Transactions._ import fr.acinq.eclair.transactions._ +import fr.acinq.eclair.wire.CommonCodecs._ import fr.acinq.eclair.wire.LightningMessageCodecs._ import grizzled.slf4j.Logging import scodec.bits.BitVector @@ -35,7 +36,6 @@ import scodec.{Attempt, Codec} import scala.compat.Platform import scala.concurrent.duration._ - /** * Created by PM on 02/06/2017. */ @@ -100,11 +100,11 @@ object ChannelCodecs extends Logging { ("toLocalMsat" | uint64) :: ("toRemoteMsat" | uint64)).as[CommitmentSpec] - def outPointCodec: Codec[OutPoint] = variableSizeBytes(uint16, bytes.xmap(d => OutPoint.read(d.toArray), d => OutPoint.write(d))) + val outPointCodec: Codec[OutPoint] = variableSizeBytes(uint16, bytes.xmap(d => OutPoint.read(d.toArray), d => OutPoint.write(d))) - def txOutCodec: Codec[TxOut] = variableSizeBytes(uint16, bytes.xmap(d => TxOut.read(d.toArray), d => TxOut.write(d))) + val txOutCodec: Codec[TxOut] = variableSizeBytes(uint16, bytes.xmap(d => TxOut.read(d.toArray), d => TxOut.write(d))) - def txCodec: Codec[Transaction] = variableSizeBytes(uint16, bytes.xmap(d => Transaction.read(d.toArray), d => Transaction.write(d))) + val txCodec: Codec[Transaction] = variableSizeBytes(uint16, bytes.xmap(d => Transaction.read(d.toArray), d => Transaction.write(d))) val inputInfoCodec: Codec[InputInfo] = ( ("outPoint" | outPointCodec) :: diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/CommandCodecs.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/CommandCodecs.scala index e70677aa5a..dac5191ae2 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/CommandCodecs.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/CommandCodecs.scala @@ -17,8 +17,8 @@ package fr.acinq.eclair.wire import fr.acinq.eclair.channel.{CMD_FAIL_HTLC, CMD_FAIL_MALFORMED_HTLC, CMD_FULFILL_HTLC, Command} +import fr.acinq.eclair.wire.CommonCodecs._ import fr.acinq.eclair.wire.FailureMessageCodecs.failureMessageCodec -import fr.acinq.eclair.wire.LightningMessageCodecs._ import scodec.Codec import scodec.codecs._ diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/CommonCodecs.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/CommonCodecs.scala new file mode 100644 index 0000000000..de0aff19d8 --- /dev/null +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/CommonCodecs.scala @@ -0,0 +1,135 @@ +/* + * Copyright 2019 ACINQ SAS + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package fr.acinq.eclair.wire + +import java.net.{Inet4Address, Inet6Address, InetAddress} + +import fr.acinq.bitcoin.{ByteVector32, ByteVector64} +import fr.acinq.bitcoin.Crypto.{PrivateKey, PublicKey} +import fr.acinq.eclair.{ShortChannelId, UInt64} +import org.apache.commons.codec.binary.Base32 +import scodec.{Attempt, Codec, Err} +import scodec.bits.{BitVector, ByteVector} +import scodec.codecs._ + +import scala.util.{Failure, Success, Try} + +/** + * Created by t-bast on 20/06/2019. + */ + +object CommonCodecs { + + def attemptFromTry[T](f: => T): Attempt[T] = Try(f) match { + case Success(t) => Attempt.successful(t) + case Failure(t) => Attempt.failure(Err(s"deserialization error: ${t.getMessage}")) + } + + // this codec can be safely used for values < 2^63 and will fail otherwise + // (for something smarter see https://github.com/yzernik/bitcoin-scodec/blob/master/src/main/scala/io/github/yzernik/bitcoinscodec/structures/UInt64.scala) + val uint64: Codec[Long] = int64.narrow(l => if (l >= 0) Attempt.Successful(l) else Attempt.failure(Err(s"overflow for value $l")), l => l) + + val uint64L: Codec[Long] = int64L.narrow(l => if (l >= 0) Attempt.Successful(l) else Attempt.failure(Err(s"overflow for value $l")), l => l) + + val uint64ex: Codec[UInt64] = bytes(8).xmap(b => UInt64(b), a => a.toByteVector.padLeft(8)) + + // Bitcoin-style varint codec (CompactSize) + val varInt = Codec[Long]( + (n: Long) => + n match { + case i if i < 0xfd => + uint8L.encode(i.toInt) + case i if i < 0xffff => + for { + a <- uint8L.encode(0xfd) + b <- uint16L.encode(i.toInt) + } yield a ++ b + case i if i < 0xffffffffL => + for { + a <- uint8L.encode(0xfe) + b <- uint32L.encode(i) + } yield a ++ b + case i => + for { + a <- uint8L.encode(0xff) + b <- uint64L.encode(i) + } yield a ++ b + }, + (buf: BitVector) => { + uint8L.decode(buf) match { + case scodec.Attempt.Successful(b) => + b.value match { + case 0xff => + uint64L.decode(b.remainder) + case 0xfe => + uint32L.decode(b.remainder) + case 0xfd => + uint16L.decode(b.remainder) + .map(b => b.map(_.toLong)) + case _ => + scodec.Attempt.Successful(scodec.DecodeResult(b.value.toLong, b.remainder)) + } + case scodec.Attempt.Failure(err) => + scodec.Attempt.Failure(err) + } + }) + + val bytes32: Codec[ByteVector32] = limitedSizeBytes(32, bytesStrict(32).xmap(d => ByteVector32(d), d => d.bytes)) + + val bytes64: Codec[ByteVector64] = limitedSizeBytes(64, bytesStrict(64).xmap(d => ByteVector64(d), d => d.bytes)) + + val sha256: Codec[ByteVector32] = bytes32 + + val varsizebinarydata: Codec[ByteVector] = variableSizeBytes(uint16, bytes) + + val listofsignatures: Codec[List[ByteVector64]] = listOfN(uint16, bytes64) + + val ipv4address: Codec[Inet4Address] = bytes(4).xmap(b => InetAddress.getByAddress(b.toArray).asInstanceOf[Inet4Address], a => ByteVector(a.getAddress)) + + val ipv6address: Codec[Inet6Address] = bytes(16).exmap(b => attemptFromTry(Inet6Address.getByAddress(null, b.toArray, null)), a => attemptFromTry(ByteVector(a.getAddress))) + + def base32(size: Int): Codec[String] = bytes(size).xmap(b => new Base32().encodeAsString(b.toArray).toLowerCase, a => ByteVector(new Base32().decode(a.toUpperCase()))) + + val nodeaddress: Codec[NodeAddress] = + discriminated[NodeAddress].by(uint8) + .typecase(1, (ipv4address :: uint16).as[IPv4]) + .typecase(2, (ipv6address :: uint16).as[IPv6]) + .typecase(3, (base32(10) :: uint16).as[Tor2]) + .typecase(4, (base32(35) :: uint16).as[Tor3]) + + // this one is a bit different from most other codecs: the first 'len' element is *not* the number of items + // in the list but rather the number of bytes of the encoded list. The rationale is once we've read this + // number of bytes we can just skip to the next field + val listofnodeaddresses: Codec[List[NodeAddress]] = variableSizeBytes(uint16, list(nodeaddress)) + + val shortchannelid: Codec[ShortChannelId] = int64.xmap(l => ShortChannelId(l), s => s.toLong) + + val privateKey: Codec[PrivateKey] = Codec[PrivateKey]( + (priv: PrivateKey) => bytes(32).encode(priv.value), + (wire: BitVector) => bytes(32).decode(wire).map(_.map(b => PrivateKey(b))) + ) + + val publicKey: Codec[PublicKey] = Codec[PublicKey]( + (pub: PublicKey) => bytes(33).encode(pub.value), + (wire: BitVector) => bytes(33).decode(wire).map(_.map(b => PublicKey(b))) + ) + + val rgb: Codec[Color] = bytes(3).xmap(buf => Color(buf(0), buf(1), buf(2)), t => ByteVector(t.r, t.g, t.b)) + + def zeropaddedstring(size: Int): Codec[String] = fixedSizeBytes(32, utf8).xmap(s => s.takeWhile(_ != '\u0000'), s => s) + +} diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/FailureMessage.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/FailureMessage.scala index 7039f00361..5ff9b67188 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/FailureMessage.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/FailureMessage.scala @@ -17,9 +17,10 @@ package fr.acinq.eclair.wire import fr.acinq.bitcoin.ByteVector32 -import fr.acinq.eclair.wire.LightningMessageCodecs.{bytes32, channelUpdateCodec, uint64} +import fr.acinq.eclair.wire.CommonCodecs.{sha256, uint64} +import fr.acinq.eclair.wire.LightningMessageCodecs.channelUpdateCodec import scodec.codecs._ -import scodec.{Attempt, Codec} +import scodec.Attempt /** * see https://github.com/lightningnetwork/lightning-rfc/blob/master/04-onion-routing.md @@ -63,8 +64,6 @@ object FailureMessageCodecs { val NODE = 0x2000 val UPDATE = 0x1000 - val sha256Codec: Codec[ByteVector32] = ("sha256Codec" | bytes32) - val channelUpdateCodecWithType = LightningMessageCodecs.lightningMessageCodec.narrow[ChannelUpdate](f => Attempt.successful(f.asInstanceOf[ChannelUpdate]), g => g) // NB: for historical reasons some implementations were including/ommitting the message type (258 for ChannelUpdate) @@ -76,22 +75,22 @@ object FailureMessageCodecs { .typecase(NODE | 2, provide(TemporaryNodeFailure)) .typecase(PERM | 2, provide(PermanentNodeFailure)) .typecase(PERM | NODE | 3, provide(RequiredNodeFeatureMissing)) - .typecase(BADONION | PERM | 4, sha256Codec.as[InvalidOnionVersion]) - .typecase(BADONION | PERM | 5, sha256Codec.as[InvalidOnionHmac]) - .typecase(BADONION | PERM | 6, sha256Codec.as[InvalidOnionKey]) - .typecase(UPDATE | 7, (("channelUpdate" | channelUpdateWithLengthCodec)).as[TemporaryChannelFailure]) + .typecase(BADONION | PERM | 4, sha256.as[InvalidOnionVersion]) + .typecase(BADONION | PERM | 5, sha256.as[InvalidOnionHmac]) + .typecase(BADONION | PERM | 6, sha256.as[InvalidOnionKey]) + .typecase(UPDATE | 7, ("channelUpdate" | channelUpdateWithLengthCodec).as[TemporaryChannelFailure]) .typecase(PERM | 8, provide(PermanentChannelFailure)) .typecase(PERM | 9, provide(RequiredChannelFeatureMissing)) .typecase(PERM | 10, provide(UnknownNextPeer)) .typecase(UPDATE | 11, (("amountMsat" | uint64) :: ("channelUpdate" | channelUpdateWithLengthCodec)).as[AmountBelowMinimum]) .typecase(UPDATE | 12, (("amountMsat" | uint64) :: ("channelUpdate" | channelUpdateWithLengthCodec)).as[FeeInsufficient]) .typecase(UPDATE | 13, (("expiry" | uint32) :: ("channelUpdate" | channelUpdateWithLengthCodec)).as[IncorrectCltvExpiry]) - .typecase(UPDATE | 14, (("channelUpdate" | channelUpdateWithLengthCodec)).as[ExpiryTooSoon]) + .typecase(UPDATE | 14, ("channelUpdate" | channelUpdateWithLengthCodec).as[ExpiryTooSoon]) .typecase(UPDATE | 20, (("messageFlags" | byte) :: ("channelFlags" | byte) :: ("channelUpdate" | channelUpdateWithLengthCodec)).as[ChannelDisabled]) - .typecase(PERM | 15, (("amountMsat" | withDefaultValue(optional(bitsRemaining, uint64), 0L))).as[IncorrectOrUnknownPaymentDetails]) + .typecase(PERM | 15, ("amountMsat" | withDefaultValue(optional(bitsRemaining, uint64), 0L)).as[IncorrectOrUnknownPaymentDetails]) .typecase(PERM | 16, provide(IncorrectPaymentAmount)) .typecase(17, provide(FinalExpiryTooSoon)) - .typecase(18, (("expiry" | uint32)).as[FinalIncorrectCltvExpiry]) - .typecase(19, (("amountMsat" | uint64)).as[FinalIncorrectHtlcAmount]) + .typecase(18, ("expiry" | uint32).as[FinalIncorrectCltvExpiry]) + .typecase(19, ("amountMsat" | uint64).as[FinalIncorrectHtlcAmount]) .typecase(21, provide(ExpiryTooFar)) } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/LightningMessageCodecs.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/LightningMessageCodecs.scala index a3920268b0..9aa43b1e75 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/LightningMessageCodecs.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/LightningMessageCodecs.scala @@ -16,120 +16,18 @@ package fr.acinq.eclair.wire -import java.net.{Inet4Address, Inet6Address, InetAddress} - -import fr.acinq.bitcoin.Crypto.{PrivateKey, PublicKey} -import fr.acinq.bitcoin.{ByteVector32, ByteVector64} import fr.acinq.eclair.crypto.Sphinx -import fr.acinq.eclair.{ShortChannelId, UInt64, wire} -import org.apache.commons.codec.binary.Base32 -import scodec.bits.{BitVector, ByteVector} +import fr.acinq.eclair.wire +import fr.acinq.eclair.wire.CommonCodecs._ +import scodec.bits.ByteVector import scodec.codecs._ -import scodec.{Attempt, Codec, Err} - -import scala.util.{Failure, Success, Try} +import scodec.Codec /** * Created by PM on 15/11/2016. */ object LightningMessageCodecs { - def attemptFromTry[T](f: => T): Attempt[T] = Try(f) match { - case Success(t) => Attempt.successful(t) - case Failure(t) => Attempt.failure(Err(s"deserialization error: ${t.getMessage}")) - } - - // this codec can be safely used for values < 2^63 and will fail otherwise - // (for something smarter see https://github.com/yzernik/bitcoin-scodec/blob/master/src/main/scala/io/github/yzernik/bitcoinscodec/structures/UInt64.scala) - val uint64: Codec[Long] = int64.narrow(l => if (l >= 0) Attempt.Successful(l) else Attempt.failure(Err(s"overflow for value $l")), l => l) - - val uint64L: Codec[Long] = int64L.narrow(l => if (l >= 0) Attempt.Successful(l) else Attempt.failure(Err(s"overflow for value $l")), l => l) - - val uint64ex: Codec[UInt64] = bytes(8).xmap(b => UInt64(b), a => a.toByteVector.padLeft(8)) - - // Bitcoin-style varint codec (CompactSize) - val varInt = Codec[Long]( - (n: Long) => - n match { - case i if i < 0xfd => - uint8L.encode(i.toInt) - case i if i < 0xffff => - for { - a <- uint8L.encode(0xfd) - b <- uint16L.encode(i.toInt) - } yield a ++ b - case i if i < 0xffffffffL => - for { - a <- uint8L.encode(0xfe) - b <- uint32L.encode(i) - } yield a ++ b - case i => - for { - a <- uint8L.encode(0xff) - b <- uint64L.encode(i) - } yield a ++ b - }, - (buf: BitVector) => { - uint8L.decode(buf) match { - case scodec.Attempt.Successful(b) => - b.value match { - case 0xff => - uint64L.decode(b.remainder) - case 0xfe => - uint32L.decode(b.remainder) - case 0xfd => - uint16L.decode(b.remainder) - .map(b => b.map(_.toLong)) - case _ => - scodec.Attempt.Successful(scodec.DecodeResult(b.value.toLong, b.remainder)) - } - case scodec.Attempt.Failure(err) => - scodec.Attempt.Failure(err) - } - }) - - def bytes32: Codec[ByteVector32] = limitedSizeBytes(32, bytesStrict(32).xmap(d => ByteVector32(d), d => d.bytes)) - - def bytes64: Codec[ByteVector64] = limitedSizeBytes(64, bytesStrict(64).xmap(d => ByteVector64(d), d => d.bytes)) - - def varsizebinarydata: Codec[ByteVector] = variableSizeBytes(uint16, bytes) - - def listofsignatures: Codec[List[ByteVector64]] = listOfN(uint16, bytes64) - - def ipv4address: Codec[Inet4Address] = bytes(4).xmap(b => InetAddress.getByAddress(b.toArray).asInstanceOf[Inet4Address], a => ByteVector(a.getAddress)) - - def ipv6address: Codec[Inet6Address] = bytes(16).exmap(b => attemptFromTry(Inet6Address.getByAddress(null, b.toArray, null)), a => attemptFromTry(ByteVector(a.getAddress))) - - def base32(size: Int): Codec[String] = bytes(size).xmap(b => new Base32().encodeAsString(b.toArray).toLowerCase, a => ByteVector(new Base32().decode(a.toUpperCase()))) - - def nodeaddress: Codec[NodeAddress] = - discriminated[NodeAddress].by(uint8) - .typecase(1, (ipv4address :: uint16).as[IPv4]) - .typecase(2, (ipv6address :: uint16).as[IPv6]) - .typecase(3, (base32(10) :: uint16).as[Tor2]) - .typecase(4, (base32(35) :: uint16).as[Tor3]) - - // this one is a bit different from most other codecs: the first 'len' element is *not* the number of items - // in the list but rather the number of bytes of the encoded list. The rationale is once we've read this - // number of bytes we can just skip to the next field - def listofnodeaddresses: Codec[List[NodeAddress]] = variableSizeBytes(uint16, list(nodeaddress)) - - def shortchannelid: Codec[ShortChannelId] = int64.xmap(l => ShortChannelId(l), s => s.toLong) - - def privateKey: Codec[PrivateKey] = Codec[PrivateKey]( - (priv: PrivateKey) => bytes(32).encode(priv.value), - (wire: BitVector) => bytes(32).decode(wire).map(_.map(b => PrivateKey(b))) - ) - - def publicKey: Codec[PublicKey] = Codec[PublicKey]( - (pub: PublicKey) => bytes(33).encode(pub.value), - (wire: BitVector) => bytes(33).decode(wire).map(_.map(b => PublicKey(b))) - ) - - def rgb: Codec[Color] = bytes(3).xmap(buf => Color(buf(0), buf(1), buf(2)), t => ByteVector(t.r, t.g, t.b)) - - def zeropaddedstring(size: Int): Codec[String] = fixedSizeBytes(32, utf8).xmap(s => s.takeWhile(_ != '\u0000'), s => s) - val initCodec: Codec[Init] = ( ("globalFeatures" | varsizebinarydata) :: ("localFeatures" | varsizebinarydata)).as[Init] diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvCodecs.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvCodecs.scala index 4cff651d1c..e557485726 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvCodecs.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvCodecs.scala @@ -16,7 +16,7 @@ package fr.acinq.eclair.wire -import fr.acinq.eclair.wire.LightningMessageCodecs._ +import fr.acinq.eclair.wire.CommonCodecs._ import scodec.codecs._ import scodec.Codec diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/crypto/TransportHandlerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/crypto/TransportHandlerSpec.scala index 4f137245b2..37f2e58fed 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/crypto/TransportHandlerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/crypto/TransportHandlerSpec.scala @@ -23,7 +23,7 @@ import akka.io.Tcp import akka.testkit.{TestActorRef, TestFSMRef, TestKit, TestProbe} import fr.acinq.eclair.crypto.Noise.{Chacha20Poly1305CipherFunctions, CipherState} import fr.acinq.eclair.crypto.TransportHandler.{Encryptor, ExtendedCipherState, Listener} -import fr.acinq.eclair.wire.LightningMessageCodecs +import fr.acinq.eclair.wire.CommonCodecs import org.scalatest.{BeforeAndAfterAll, FunSuiteLike} import scodec.Codec import scodec.bits._ @@ -49,8 +49,8 @@ class TransportHandlerSpec extends TestKit(ActorSystem("test")) with FunSuiteLik val pipe = system.actorOf(Props[MyPipe]) val probe1 = TestProbe() val probe2 = TestProbe() - val initiator = TestFSMRef(new TransportHandler(Initiator.s, Some(Responder.s.pub), pipe, LightningMessageCodecs.varsizebinarydata)) - val responder = TestFSMRef(new TransportHandler(Responder.s, None, pipe, LightningMessageCodecs.varsizebinarydata)) + val initiator = TestFSMRef(new TransportHandler(Initiator.s, Some(Responder.s.pub), pipe, CommonCodecs.varsizebinarydata)) + val responder = TestFSMRef(new TransportHandler(Responder.s, None, pipe, CommonCodecs.varsizebinarydata)) pipe ! (initiator, responder) awaitCond(initiator.stateName == TransportHandler.WaitingForListener) @@ -111,8 +111,8 @@ class TransportHandlerSpec extends TestKit(ActorSystem("test")) with FunSuiteLik val pipe = system.actorOf(Props[MyPipeSplitter]) val probe1 = TestProbe() val probe2 = TestProbe() - val initiator = TestFSMRef(new TransportHandler(Initiator.s, Some(Responder.s.pub), pipe, LightningMessageCodecs.varsizebinarydata)) - val responder = TestFSMRef(new TransportHandler(Responder.s, None, pipe, LightningMessageCodecs.varsizebinarydata)) + val initiator = TestFSMRef(new TransportHandler(Initiator.s, Some(Responder.s.pub), pipe, CommonCodecs.varsizebinarydata)) + val responder = TestFSMRef(new TransportHandler(Responder.s, None, pipe, CommonCodecs.varsizebinarydata)) pipe ! (initiator, responder) awaitCond(initiator.stateName == TransportHandler.WaitingForListener) @@ -141,8 +141,8 @@ class TransportHandlerSpec extends TestKit(ActorSystem("test")) with FunSuiteLik val pipe = system.actorOf(Props[MyPipe]) val probe1 = TestProbe() val supervisor = TestActorRef(Props(new MySupervisor())) - val initiator = TestFSMRef(new TransportHandler(Initiator.s, Some(Initiator.s.pub), pipe, LightningMessageCodecs.varsizebinarydata), supervisor, "ini") - val responder = TestFSMRef(new TransportHandler(Responder.s, None, pipe, LightningMessageCodecs.varsizebinarydata), supervisor, "res") + val initiator = TestFSMRef(new TransportHandler(Initiator.s, Some(Initiator.s.pub), pipe, CommonCodecs.varsizebinarydata), supervisor, "ini") + val responder = TestFSMRef(new TransportHandler(Responder.s, None, pipe, CommonCodecs.varsizebinarydata), supervisor, "res") probe1.watch(responder) pipe ! (initiator, responder) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/CommonCodecsSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/CommonCodecsSpec.scala new file mode 100644 index 0000000000..59db780424 --- /dev/null +++ b/eclair-core/src/test/scala/fr/acinq/eclair/wire/CommonCodecsSpec.scala @@ -0,0 +1,208 @@ +/* + * Copyright 2019 ACINQ SAS + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package fr.acinq.eclair.wire + +import java.net.{Inet4Address, Inet6Address, InetAddress} + +import com.google.common.net.InetAddresses +import fr.acinq.bitcoin.Crypto.PrivateKey +import fr.acinq.eclair.{UInt64, randomBytes32} +import fr.acinq.eclair.wire.CommonCodecs._ +import org.scalatest.FunSuite +import scodec.bits.{BitVector, HexStringSyntax} + +/** + * Created by t-bast on 20/06/2019. + */ + +class CommonCodecsSpec extends FunSuite { + + test("encode/decode with uint64 codec") { + val expected = Map( + UInt64(0) -> hex"00 00 00 00 00 00 00 00", + UInt64(42) -> hex"00 00 00 00 00 00 00 2a", + UInt64(hex"ffffffffffffffff") -> hex"ff ff ff ff ff ff ff ff" + ).mapValues(_.toBitVector) + + for ((uint, ref) <- expected) { + val encoded = uint64ex.encode(uint).require + assert(ref === encoded) + val decoded = uint64ex.decode(encoded).require.value + assert(uint === decoded) + } + } + + test("encode/decode with uint64L codec") { + val expected = Map( + 0L -> hex"00 00 00 00 00 00 00 00", + 42L -> hex"2a 00 00 00 00 00 00 00", + 6211610197754262546L -> hex"12 34 56 78 90 12 34 56" + ).mapValues(_.toBitVector) + + for ((long, ref) <- expected) { + val encoded = uint64L.encode(long).require + assert(ref === encoded) + val decoded = uint64L.decode(encoded).require.value + assert(long === decoded) + } + } + + test("encode/decode with varint codec") { + val expected = Map( + 0L -> hex"00", + 42L -> hex"2a", + 550L -> hex"fd 26 02", + 998000L -> hex"fe 70 3a 0f 00", + 6211610197754262546L -> hex"ff 12 34 56 78 90 12 34 56" + ).mapValues(_.toBitVector) + + for ((long, ref) <- expected) { + val encoded = varInt.encode(long).require + assert(ref === encoded) + val decoded = varInt.decode(encoded).require.value + assert(long === decoded) + } + } + + test("decode invalid varint") { + val testCases = Seq( + hex"fd", + hex"fe 01", + hex"fe", + hex"fe 12 34", + hex"ff", + hex"ff 12 34 56 78" + ).map(_.toBitVector) + + for (testCase <- testCases) { + assert(varInt.decode(testCase).isFailure) + } + } + + test("encode/decode with rgb codec") { + val color = Color(47.toByte, 255.toByte, 142.toByte) + val bin = rgb.encode(color).require + assert(bin === hex"2f ff 8e".toBitVector) + val color2 = rgb.decode(bin).require.value + assert(color === color2) + } + + test("encode/decode all kind of IPv6 addresses with ipv6address codec") { + { + // IPv4 mapped + val bin = hex"00000000000000000000ffffae8a0b08".toBitVector + val ipv6 = Inet6Address.getByAddress(null, bin.toByteArray, null) + val bin2 = ipv6address.encode(ipv6).require + assert(bin === bin2) + } + + { + // regular IPv6 address + val ipv6 = InetAddresses.forString("1080:0:0:0:8:800:200C:417A").asInstanceOf[Inet6Address] + val bin = ipv6address.encode(ipv6).require + val ipv62 = ipv6address.decode(bin).require.value + assert(ipv6 === ipv62) + } + } + + test("encode/decode with nodeaddress codec") { + { + val ipv4addr = InetAddress.getByAddress(Array[Byte](192.toByte, 168.toByte, 1.toByte, 42.toByte)).asInstanceOf[Inet4Address] + val nodeaddr = IPv4(ipv4addr, 4231) + val bin = nodeaddress.encode(nodeaddr).require + assert(bin === hex"01 C0 A8 01 2A 10 87".toBitVector) + val nodeaddr2 = nodeaddress.decode(bin).require.value + assert(nodeaddr === nodeaddr2) + } + { + val ipv6addr = InetAddress.getByAddress(hex"2001 0db8 0000 85a3 0000 0000 ac1f 8001".toArray).asInstanceOf[Inet6Address] + val nodeaddr = IPv6(ipv6addr, 4231) + val bin = nodeaddress.encode(nodeaddr).require + assert(bin === hex"02 2001 0db8 0000 85a3 0000 0000 ac1f 8001 1087".toBitVector) + val nodeaddr2 = nodeaddress.decode(bin).require.value + assert(nodeaddr === nodeaddr2) + } + { + val nodeaddr = Tor2("z4zif3fy7fe7bpg3", 4231) + val bin = nodeaddress.encode(nodeaddr).require + assert(bin === hex"03 cf3282ecb8f949f0bcdb 1087".toBitVector) + val nodeaddr2 = nodeaddress.decode(bin).require.value + assert(nodeaddr === nodeaddr2) + } + { + val nodeaddr = Tor3("mrl2d3ilhctt2vw4qzvmz3etzjvpnc6dczliq5chrxetthgbuczuggyd", 4231) + val bin = nodeaddress.encode(nodeaddr).require + assert(bin === hex"04 6457a1ed0b38a73d56dc866accec93ca6af68bc316568874478dc9399cc1a0b3431b03 1087".toBitVector) + val nodeaddr2 = nodeaddress.decode(bin).require.value + assert(nodeaddr === nodeaddr2) + } + } + + test("encode/decode with private key codec") { + val value = PrivateKey(randomBytes32) + val wire = privateKey.encode(value).require + assert(wire.length == 256) + val value1 = privateKey.decode(wire).require.value + assert(value1 == value) + } + + test("encode/decode with public key codec") { + val value = PrivateKey(randomBytes32).publicKey + val wire = CommonCodecs.publicKey.encode(value).require + assert(wire.length == 33 * 8) + val value1 = CommonCodecs.publicKey.decode(wire).require.value + assert(value1 == value) + } + + test("encode/decode with zeropaddedstring codec") { + val c = zeropaddedstring(32) + + { + val alias = "IRATEMONK" + val bin = c.encode(alias).require + assert(bin === BitVector(alias.getBytes("UTF-8") ++ Array.fill[Byte](32 - alias.length)(0))) + val alias2 = c.decode(bin).require.value + assert(alias === alias2) + } + + { + val alias = "this-alias-is-exactly-32-B-long." + val bin = c.encode(alias).require + assert(bin === BitVector(alias.getBytes("UTF-8") ++ Array.fill[Byte](32 - alias.length)(0))) + val alias2 = c.decode(bin).require.value + assert(alias === alias2) + } + + { + val alias = "this-alias-is-far-too-long-because-we-are-limited-to-32-bytes" + assert(c.encode(alias).isFailure) + } + } + + test("encode/decode UInt64") { + val codec = uint64ex + Seq( + UInt64(hex"ffffffffffffffff"), + UInt64(hex"fffffffffffffffe"), + UInt64(hex"efffffffffffffff"), + UInt64(hex"effffffffffffffe") + ).map(value => { + assert(codec.decode(codec.encode(value).require).require.value === value) + }) + } + +} diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/LightningMessageCodecsSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/LightningMessageCodecsSpec.scala index 8332b65e5a..700b552971 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/wire/LightningMessageCodecsSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/wire/LightningMessageCodecsSpec.scala @@ -16,9 +16,8 @@ package fr.acinq.eclair.wire -import java.net.{Inet4Address, Inet6Address, InetAddress} +import java.net.{Inet4Address, InetAddress} -import com.google.common.net.InetAddresses import fr.acinq.bitcoin.Crypto.{PrivateKey, PublicKey} import fr.acinq.bitcoin.{Block, ByteVector32, ByteVector64} import fr.acinq.eclair._ @@ -44,180 +43,6 @@ class LightningMessageCodecsSpec extends FunSuite { def publicKey(fill: Byte) = PrivateKey(ByteVector.fill(32)(fill)).publicKey - test("encode/decode with uint64 codec") { - val expected = Map( - UInt64(0) -> hex"00 00 00 00 00 00 00 00", - UInt64(42) -> hex"00 00 00 00 00 00 00 2a", - UInt64(hex"ffffffffffffffff") -> hex"ff ff ff ff ff ff ff ff" - ).mapValues(_.toBitVector) - - for ((uint, ref) <- expected) { - val encoded = uint64ex.encode(uint).require - assert(ref === encoded) - val decoded = uint64ex.decode(encoded).require.value - assert(uint === decoded) - } - } - - test("encode/decode with uint64L codec") { - val expected = Map( - 0L -> hex"00 00 00 00 00 00 00 00", - 42L -> hex"2a 00 00 00 00 00 00 00", - 6211610197754262546L -> hex"12 34 56 78 90 12 34 56" - ).mapValues(_.toBitVector) - - for ((long, ref) <- expected) { - val encoded = uint64L.encode(long).require - assert(ref === encoded) - val decoded = uint64L.decode(encoded).require.value - assert(long === decoded) - } - } - - test("encode/decode with varint codec") { - val expected = Map( - 0L -> hex"00", - 42L -> hex"2a", - 550L -> hex"fd 26 02", - 998000L -> hex"fe 70 3a 0f 00", - 6211610197754262546L -> hex"ff 12 34 56 78 90 12 34 56" - ).mapValues(_.toBitVector) - - for ((long, ref) <- expected) { - val encoded = varInt.encode(long).require - assert(ref === encoded) - val decoded = varInt.decode(encoded).require.value - assert(long === decoded) - } - } - - test("decode invalid varint") { - val testCases = Seq( - hex"fd", - hex"fe 01", - hex"fe", - hex"fe 12 34", - hex"ff", - hex"ff 12 34 56 78" - ).map(_.toBitVector) - - for (testCase <- testCases) { - assert(varInt.decode(testCase).isFailure) - } - } - - test("encode/decode with rgb codec") { - val color = Color(47.toByte, 255.toByte, 142.toByte) - val bin = rgb.encode(color).require - assert(bin === hex"2f ff 8e".toBitVector) - val color2 = rgb.decode(bin).require.value - assert(color === color2) - } - - test("encode/decode all kind of IPv6 addresses with ipv6address codec") { - { - // IPv4 mapped - val bin = hex"00000000000000000000ffffae8a0b08".toBitVector - val ipv6 = Inet6Address.getByAddress(null, bin.toByteArray, null) - val bin2 = ipv6address.encode(ipv6).require - assert(bin === bin2) - } - - { - // regular IPv6 address - val ipv6 = InetAddresses.forString("1080:0:0:0:8:800:200C:417A").asInstanceOf[Inet6Address] - val bin = ipv6address.encode(ipv6).require - val ipv62 = ipv6address.decode(bin).require.value - assert(ipv6 === ipv62) - } - } - - test("encode/decode with nodeaddress codec") { - { - val ipv4addr = InetAddress.getByAddress(Array[Byte](192.toByte, 168.toByte, 1.toByte, 42.toByte)).asInstanceOf[Inet4Address] - val nodeaddr = IPv4(ipv4addr, 4231) - val bin = nodeaddress.encode(nodeaddr).require - assert(bin === hex"01 C0 A8 01 2A 10 87".toBitVector) - val nodeaddr2 = nodeaddress.decode(bin).require.value - assert(nodeaddr === nodeaddr2) - } - { - val ipv6addr = InetAddress.getByAddress(hex"2001 0db8 0000 85a3 0000 0000 ac1f 8001".toArray).asInstanceOf[Inet6Address] - val nodeaddr = IPv6(ipv6addr, 4231) - val bin = nodeaddress.encode(nodeaddr).require - assert(bin === hex"02 2001 0db8 0000 85a3 0000 0000 ac1f 8001 1087".toBitVector) - val nodeaddr2 = nodeaddress.decode(bin).require.value - assert(nodeaddr === nodeaddr2) - } - { - val nodeaddr = Tor2("z4zif3fy7fe7bpg3", 4231) - val bin = nodeaddress.encode(nodeaddr).require - assert(bin === hex"03 cf3282ecb8f949f0bcdb 1087".toBitVector) - val nodeaddr2 = nodeaddress.decode(bin).require.value - assert(nodeaddr === nodeaddr2) - } - { - val nodeaddr = Tor3("mrl2d3ilhctt2vw4qzvmz3etzjvpnc6dczliq5chrxetthgbuczuggyd", 4231) - val bin = nodeaddress.encode(nodeaddr).require - assert(bin === hex"04 6457a1ed0b38a73d56dc866accec93ca6af68bc316568874478dc9399cc1a0b3431b03 1087".toBitVector) - val nodeaddr2 = nodeaddress.decode(bin).require.value - assert(nodeaddr === nodeaddr2) - } - } - - test("encode/decode with private key codec") { - val value = PrivateKey(randomBytes32) - val wire = LightningMessageCodecs.privateKey.encode(value).require - assert(wire.length == 256) - val value1 = LightningMessageCodecs.privateKey.decode(wire).require.value - assert(value1 == value) - } - - test("encode/decode with public key codec") { - val value = PrivateKey(randomBytes32).publicKey - val wire = LightningMessageCodecs.publicKey.encode(value).require - assert(wire.length == 33 * 8) - val value1 = LightningMessageCodecs.publicKey.decode(wire).require.value - assert(value1 == value) - } - - test("encode/decode with zeropaddedstring codec") { - val c = zeropaddedstring(32) - - { - val alias = "IRATEMONK" - val bin = c.encode(alias).require - assert(bin === BitVector(alias.getBytes("UTF-8") ++ Array.fill[Byte](32 - alias.length)(0))) - val alias2 = c.decode(bin).require.value - assert(alias === alias2) - } - - { - val alias = "this-alias-is-exactly-32-B-long." - val bin = c.encode(alias).require - assert(bin === BitVector(alias.getBytes("UTF-8") ++ Array.fill[Byte](32 - alias.length)(0))) - val alias2 = c.decode(bin).require.value - assert(alias === alias2) - } - - { - val alias = "this-alias-is-far-too-long-because-we-are-limited-to-32-bytes" - assert(c.encode(alias).isFailure) - } - } - - test("encode/decode UInt64") { - val codec = uint64ex - Seq( - UInt64(hex"ffffffffffffffff"), - UInt64(hex"fffffffffffffffe"), - UInt64(hex"efffffffffffffff"), - UInt64(hex"effffffffffffffe") - ).map(value => { - assert(codec.decode(codec.encode(value).require).require.value === value) - }) - } - test("encode/decode live node_announcements") { val anns = List( hex"a58338c9660d135fd7d087eb62afd24a33562c54507a9334e79f0dc4f17d407e6d7c61f0e2f3d0d38599502f61704cf1ae93608df027014ade7ff592f27ce26900005acdf50702d2eabbbacc7c25bbd73b39e65d28237705f7bde76f557e94fb41cb18a9ec00841122116c6e302e646563656e7465722e776f726c64000000000000000000000000000000130200000000000000000000ffffae8a0b082607" diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/TlvCodecsSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/TlvCodecsSpec.scala index 9611d8456a..646c8d5a61 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/wire/TlvCodecsSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/wire/TlvCodecsSpec.scala @@ -18,7 +18,7 @@ package fr.acinq.eclair.wire import fr.acinq.bitcoin.Crypto.PublicKey import fr.acinq.eclair.ShortChannelId -import fr.acinq.eclair.wire.LightningMessageCodecs.{publicKey, shortchannelid, uint64, varInt} +import fr.acinq.eclair.wire.CommonCodecs.{publicKey, shortchannelid, uint64, varInt} import fr.acinq.eclair.wire.TlvCodecs._ import org.scalatest.FunSuite import scodec.bits.HexStringSyntax From 74942ec808d26c2066e84aef83ace5787cbe2226 Mon Sep 17 00:00:00 2001 From: Bastien Teinturier Date: Fri, 21 Jun 2019 18:00:08 +0200 Subject: [PATCH 06/12] Add tlv stream codec. Add minimal encoding enforcement in varInt codec. --- .../fr/acinq/eclair/wire/CommonCodecs.scala | 35 +++--- .../fr/acinq/eclair/wire/TlvCodecs.scala | 33 +++++- .../scala/fr/acinq/eclair/wire/TlvTypes.scala | 39 ++++++- .../acinq/eclair/wire/CommonCodecsSpec.scala | 24 ++-- .../fr/acinq/eclair/wire/TlvCodecsSpec.scala | 103 +++++++++++++++--- 5 files changed, 195 insertions(+), 39 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/CommonCodecs.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/CommonCodecs.scala index de0aff19d8..73b25fe519 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/CommonCodecs.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/CommonCodecs.scala @@ -22,7 +22,7 @@ import fr.acinq.bitcoin.{ByteVector32, ByteVector64} import fr.acinq.bitcoin.Crypto.{PrivateKey, PublicKey} import fr.acinq.eclair.{ShortChannelId, UInt64} import org.apache.commons.codec.binary.Base32 -import scodec.{Attempt, Codec, Err} +import scodec.{Attempt, Codec, DecodeResult, Err} import scodec.bits.{BitVector, ByteVector} import scodec.codecs._ @@ -47,6 +47,21 @@ object CommonCodecs { val uint64ex: Codec[UInt64] = bytes(8).xmap(b => UInt64(b), a => a.toByteVector.padLeft(8)) + /** + * We impose a minimal encoding on varint values to ensure that signed hashes can be reproduced easily. + * If a value could be encoded with less bytes, it's considered invalid and results in a failed decoding attempt. + * + * @param min the minimal value that should be encoded. + * @param attempt the decoding attempt. + */ + def verifyMinimalEncoding(min: Long, attempt: Attempt[DecodeResult[Long]]): Attempt[DecodeResult[Long]] = { + attempt match { + case Attempt.Successful(res) if res.value < min => Attempt.Failure(scodec.Err("varint was not minimally encoded")) + case Attempt.Successful(res) => Attempt.Successful(res) + case Attempt.Failure(err) => Attempt.Failure(err) + } + } + // Bitcoin-style varint codec (CompactSize) val varInt = Codec[Long]( (n: Long) => @@ -71,20 +86,14 @@ object CommonCodecs { }, (buf: BitVector) => { uint8L.decode(buf) match { - case scodec.Attempt.Successful(b) => + case Attempt.Successful(b) => b.value match { - case 0xff => - uint64L.decode(b.remainder) - case 0xfe => - uint32L.decode(b.remainder) - case 0xfd => - uint16L.decode(b.remainder) - .map(b => b.map(_.toLong)) - case _ => - scodec.Attempt.Successful(scodec.DecodeResult(b.value.toLong, b.remainder)) + case 0xff => verifyMinimalEncoding(0x100000000L, uint64L.decode(b.remainder)) + case 0xfe => verifyMinimalEncoding(0x10000L, uint32L.decode(b.remainder)) + case 0xfd => verifyMinimalEncoding(0xfdL, uint16L.decode(b.remainder).map(b => b.map(_.toLong))) + case _ => Attempt.Successful(DecodeResult(b.value.toLong, b.remainder)) } - case scodec.Attempt.Failure(err) => - scodec.Attempt.Failure(err) + case Attempt.Failure(err) => Attempt.Failure(err) } }) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvCodecs.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvCodecs.scala index e557485726..4bfc80e33d 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvCodecs.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvCodecs.scala @@ -17,8 +17,11 @@ package fr.acinq.eclair.wire import fr.acinq.eclair.wire.CommonCodecs._ +import scodec.{Attempt, Codec, DecodeResult, Decoder, Encoder} +import scodec.bits.BitVector import scodec.codecs._ -import scodec.Codec + +import scala.collection.compat._ /** * Created by t-bast on 20/06/2019. @@ -36,4 +39,32 @@ object TlvCodecs { case o => Right(o) }) + /** + * A tlv stream codec relies on an underlying tlv codec. + * This allows tlv streams to have different namespaces, increasing the total number of tlv types available. + * + * @param codec codec used for the tlv records contained in the stream. + */ + def tlvStream(codec: Codec[Tlv]): Codec[TlvStream] = { + Codec[TlvStream]( + (s: TlvStream) => { + val recordTypes = s.records.map(_.`type`) + if (recordTypes.length != recordTypes.distinct.length) { + Attempt.Failure(scodec.Err("duplicate tlv records aren't allowed")) + } else { + Encoder.encodeSeq(codec)(s.records.sortBy(_.`type`).toList) + } + }, + (buf: BitVector) => { + Decoder.decodeCollect[List, Tlv](codec, None)(buf).map(_.map(TlvStream(_))) match { + case Attempt.Failure(err) => Attempt.Failure(err) + case Attempt.Successful(res@DecodeResult(stream, _)) => stream.validate match { + case None => Attempt.Successful(res) + case Some(err) => Attempt.Failure(scodec.Err(err.message)) + } + } + } + ) + } + } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvTypes.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvTypes.scala index fa084fcd97..f3540a480d 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvTypes.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvTypes.scala @@ -18,12 +18,16 @@ package fr.acinq.eclair.wire import scodec.bits.ByteVector +import scala.annotation.tailrec + /** * Created by t-bast on 20/06/2019. */ // @formatter:off -trait Tlv +trait Tlv { + val `type`: Long +} sealed trait OnionTlv extends Tlv // @formatter:on @@ -34,3 +38,36 @@ sealed trait OnionTlv extends Tlv * @param value tlv value (length is implicit, and encoded as a varint). */ case class GenericTlv(`type`: Long, value: ByteVector) extends Tlv + +/** + * A tlv stream is a collection of tlv records. + * A tlv stream is part of a given namespace that dictates how to parse the tlv records. + * That namespace is indicated by a trait extending the top-level tlv trait. + * + * @param records tlv records. + */ +case class TlvStream(records: Seq[Tlv]) { + + // @formatter:off + sealed trait Error { val message: String } + case object RecordsNotOrdered extends Error { override val message = "tlv records must be ordered by monotonically-increasing types" } + case object DuplicateRecords extends Error { override val message = "tlv streams must not contain duplicate records" } + case object UnknownEvenTlv extends Error { override val message = "tlv streams must not contain unknown even tlv types" } + // @formatter:on + + def validate: Option[Error] = { + @tailrec + def loop(previous: Long, next: Seq[Tlv]): Option[Error] = { + next.headOption match { + case Some(record) if record.`type` == previous => Some(DuplicateRecords) + case Some(record) if record.`type` < previous => Some(RecordsNotOrdered) + case Some(record) if (record.`type` % 2) == 0 && record.isInstanceOf[GenericTlv] => Some(UnknownEvenTlv) + case Some(record) => loop(record.`type`, next.tail) + case None => None + } + } + + loop(-1L, records) + } + +} \ No newline at end of file diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/CommonCodecsSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/CommonCodecsSpec.scala index 59db780424..bd280f352f 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/wire/CommonCodecsSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/wire/CommonCodecsSpec.scala @@ -65,6 +65,9 @@ class CommonCodecsSpec extends FunSuite { val expected = Map( 0L -> hex"00", 42L -> hex"2a", + 253L -> hex"fd fd 00", + 254L -> hex"fd fe 00", + 255L -> hex"fd ff 00", 550L -> hex"fd 26 02", 998000L -> hex"fe 70 3a 0f 00", 6211610197754262546L -> hex"ff 12 34 56 78 90 12 34 56" @@ -72,24 +75,27 @@ class CommonCodecsSpec extends FunSuite { for ((long, ref) <- expected) { val encoded = varInt.encode(long).require - assert(ref === encoded) + assert(ref === encoded, ref) val decoded = varInt.decode(encoded).require.value - assert(long === decoded) + assert(long === decoded, long) } } test("decode invalid varint") { val testCases = Seq( - hex"fd", - hex"fe 01", - hex"fe", - hex"fe 12 34", - hex"ff", - hex"ff 12 34 56 78" + hex"fd", // truncated + hex"fe 01", // truncated + hex"fe", // truncated + hex"fe 12 34", // truncated + hex"ff", // truncated + hex"ff 12 34 56 78", // truncated + hex"fd fc 00", // not minimally-encoded + hex"fe ff ff 00 00", // not minimally-encoded + hex"ff ff ff ff ff 00 00 00 00" // not minimally-encoded ).map(_.toBitVector) for (testCase <- testCases) { - assert(varInt.decode(testCase).isFailure) + assert(varInt.decode(testCase).isFailure, testCase.toByteVector) } } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/TlvCodecsSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/TlvCodecsSpec.scala index 646c8d5a61..0209e7c70a 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/wire/TlvCodecsSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/wire/TlvCodecsSpec.scala @@ -35,10 +35,10 @@ class TlvCodecsSpec extends FunSuite { test("encode/decode tlv") { val testCases = Seq( - (hex"0x01 08 000000000000002a", TestType1(42)), - (hex"0x02 08 0000000000000226", TestType2(ShortChannelId(550))), - (hex"0x03 31 02eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619 0000000000000231 0000000000000451", TestType3(PublicKey(hex"02eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619"), 561, 1105)), - (hex"0xff1234567890123456 fd0001 10101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010010101010101", GenericTlv(6211610197754262546L, hex"10101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010010101010101")) + (hex"01 08 000000000000002a", TestType1(42)), + (hex"02 08 0000000000000226", TestType2(ShortChannelId(550))), + (hex"03 31 02eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619 0000000000000231 0000000000000451", TestType3(PublicKey(hex"02eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619"), 561, 1105)), + (hex"ff1234567890123456 fd0001 10101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010010101010101", GenericTlv(6211610197754262546L, hex"10101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010010101010101")) ) for ((bin, expected) <- testCases) { @@ -51,11 +51,15 @@ class TlvCodecsSpec extends FunSuite { test("decode invalid tlv") { val testCases = Seq( - hex"0xfd022a", // type truncated - hex"0x2a fd022a", // length truncated - hex"0x2a fd2602 0231", // value truncated - hex"0x02 01 2a", // short channel id too short - hex"0x02 09 010101010101010101" // short channel id length too big + hex"fd02", // type truncated + hex"fd022a", // truncated after type + hex"fd0100", // not minimally encoded type + hex"2a fd02", // length truncated + hex"2a fd0226", // truncated after length + hex"2a fe01010000", // not minimally encoded length + hex"2a fd2602 0231", // value truncated + hex"02 01 2a", // short channel id too short + hex"02 09 010101010101010101" // short channel id length too big ) for (testCase <- testCases) { @@ -63,23 +67,92 @@ class TlvCodecsSpec extends FunSuite { } } + test("decode invalid tlv stream") { + val testCases = Seq( + hex"0108000000000000002a 01", // valid tlv record followed by invalid tlv record (only type, length and value are missing) + hex"02080000000000000226 0108000000000000002a", // valid tlv records but invalid ordering + hex"02080000000000000231 02080000000000000451", // duplicate tlv type + hex"0108000000000000002a 2a0101", // unknown even type + hex"0a080000000000000231 0b0400000451" // valid tlv records but from different namespace + ) + + for (testCase <- testCases) { + assert(tlvStream(testTlvCodec).decode(testCase.toBitVector).isFailure, testCase) + } + } + + test("encode invalid tlv stream") { + val testCases = Seq( + TlvStream(Seq(TestType1(561), TestType2(ShortChannelId(1105)), OtherType1(42))), + TlvStream(Seq(TestType1(561), TestType1(1105))) + ) + + for (testCase <- testCases) { + assert(tlvStream(testTlvCodec).encode(testCase).isFailure, testCase) + } + } + + test("encode/decode tlv stream") { + val bin = hex"01080000000000000231 02080000000000000451 033102eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f28368661900000000000002310000000000000451" + val expected = Seq( + TestType1(561), + TestType2(ShortChannelId(1105)), + TestType3(PublicKey(hex"02eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619"), 561, 1105) + ) + + val decoded = tlvStream(testTlvCodec).decode(bin.toBitVector).require.value + assert(decoded === TlvStream(expected)) + + val encoded = tlvStream(testTlvCodec).encode(TlvStream(expected.reverse)).require.toByteVector + assert(encoded === bin) + } + + test("encode/decode tlv stream with unknown odd type") { + val bin = hex"01080000000000000231 0b0400000451 0d02002a" + val expected = Seq( + TestType1(561), + GenericTlv(11, hex"00000451"), + TestType13(42) + ) + + val decoded = tlvStream(testTlvCodec).decode(bin.toBitVector).require.value + assert(decoded === TlvStream(expected)) + + val encoded = tlvStream(testTlvCodec).encode(TlvStream(expected.reverse)).require.toByteVector + assert(encoded === bin) + } + } object TlvCodecsSpec { // @formatter:off sealed trait TestTlv extends Tlv - case class TestType1(longValue: Long) extends TestTlv - case class TestType2(shortChannelId: ShortChannelId) extends TestTlv - case class TestType3(nodeId: PublicKey, value1: Long, value2: Long) extends TestTlv + case class TestType1(longValue: Long) extends TestTlv { override val `type` = 1L } + case class TestType2(shortChannelId: ShortChannelId) extends TestTlv { override val `type` = 2L } + case class TestType3(nodeId: PublicKey, value1: Long, value2: Long) extends TestTlv { override val `type` = 3L } + case class TestType13(intValue: Int) extends TestTlv { override val `type` = 13L } - val testCodec1: Codec[TestType1] = (("length" | constant(hex"0x08")) :: ("value" | uint64)).as[TestType1] - val testCodec2: Codec[TestType2] = (("length" | constant(hex"0x08")) :: ("short_channel_id" | shortchannelid)).as[TestType2] - val testCodec3: Codec[TestType3] = (("length" | constant(hex"0x31")) :: ("node_id" | publicKey) :: ("value_1" | uint64) :: ("value_2" | uint64)).as[TestType3] + val testCodec1: Codec[TestType1] = (("length" | constant(hex"08")) :: ("value" | uint64)).as[TestType1] + val testCodec2: Codec[TestType2] = (("length" | constant(hex"08")) :: ("short_channel_id" | shortchannelid)).as[TestType2] + val testCodec3: Codec[TestType3] = (("length" | constant(hex"31")) :: ("node_id" | publicKey) :: ("value_1" | uint64) :: ("value_2" | uint64)).as[TestType3] + val testCodec13: Codec[TestType13] = (("length" | constant(hex"02")) :: ("value" | uint16)).as[TestType13] val testTlvCodec = tlvFallback(discriminated[Tlv].by(varInt) .typecase(1, testCodec1) .typecase(2, testCodec2) .typecase(3, testCodec3) + .typecase(13, testCodec13) + ) + + sealed trait OtherTlv extends Tlv + case class OtherType1(longValue: Long) extends OtherTlv { override val `type` = 10L } + case class OtherType2(lessLongValue: Long) extends OtherTlv { override val `type` = 11L } + + val otherCodec1: Codec[OtherType1] = (("length" | constant(hex"08")) :: ("value" | uint64)).as[OtherType1] + val otherCodec2: Codec[OtherType2] = (("length" | constant(hex"04")) :: ("value" | uint32)).as[OtherType2] + val otherTlvCodec = tlvFallback(discriminated[Tlv].by(varInt) + .typecase(10, otherCodec1) + .typecase(11, otherCodec2) ) // @formatter:on From beaf0b4ffdec24cc30604cb20fd0bb7f2b7e1249 Mon Sep 17 00:00:00 2001 From: Bastien Teinturier Date: Fri, 21 Jun 2019 18:05:14 +0200 Subject: [PATCH 07/12] Replace custom attemptFromTry by scodec's built-in Attempt.fromTry() --- .../main/scala/fr/acinq/eclair/wire/CommonCodecs.scala | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/CommonCodecs.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/CommonCodecs.scala index 73b25fe519..b4449af303 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/CommonCodecs.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/CommonCodecs.scala @@ -26,7 +26,7 @@ import scodec.{Attempt, Codec, DecodeResult, Err} import scodec.bits.{BitVector, ByteVector} import scodec.codecs._ -import scala.util.{Failure, Success, Try} +import scala.util.Try /** * Created by t-bast on 20/06/2019. @@ -34,11 +34,6 @@ import scala.util.{Failure, Success, Try} object CommonCodecs { - def attemptFromTry[T](f: => T): Attempt[T] = Try(f) match { - case Success(t) => Attempt.successful(t) - case Failure(t) => Attempt.failure(Err(s"deserialization error: ${t.getMessage}")) - } - // this codec can be safely used for values < 2^63 and will fail otherwise // (for something smarter see https://github.com/yzernik/bitcoin-scodec/blob/master/src/main/scala/io/github/yzernik/bitcoinscodec/structures/UInt64.scala) val uint64: Codec[Long] = int64.narrow(l => if (l >= 0) Attempt.Successful(l) else Attempt.failure(Err(s"overflow for value $l")), l => l) @@ -109,7 +104,7 @@ object CommonCodecs { val ipv4address: Codec[Inet4Address] = bytes(4).xmap(b => InetAddress.getByAddress(b.toArray).asInstanceOf[Inet4Address], a => ByteVector(a.getAddress)) - val ipv6address: Codec[Inet6Address] = bytes(16).exmap(b => attemptFromTry(Inet6Address.getByAddress(null, b.toArray, null)), a => attemptFromTry(ByteVector(a.getAddress))) + val ipv6address: Codec[Inet6Address] = bytes(16).exmap(b => Attempt.fromTry(Try(Inet6Address.getByAddress(null, b.toArray, null))), a => Attempt.fromTry(Try(ByteVector(a.getAddress)))) def base32(size: Int): Codec[String] = bytes(size).xmap(b => new Base32().encodeAsString(b.toArray).toLowerCase, a => ByteVector(new Base32().decode(a.toUpperCase()))) From 35b09538f5a9163dfc67dd598ad61ff3831cf8f4 Mon Sep 17 00:00:00 2001 From: Bastien Teinturier Date: Mon, 24 Jun 2019 13:52:42 +0200 Subject: [PATCH 08/12] Make varint (CompactSize) a Codec[UInt64] instead of Codec[Long] --- .../fr/acinq/eclair/wire/CommonCodecs.scala | 37 +++++---- .../fr/acinq/eclair/wire/TlvCodecs.scala | 2 +- .../scala/fr/acinq/eclair/wire/TlvTypes.scala | 17 ++-- .../acinq/eclair/wire/CommonCodecsSpec.scala | 79 ++++++++++++++----- .../fr/acinq/eclair/wire/TlvCodecsSpec.scala | 24 +++--- 5 files changed, 103 insertions(+), 56 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/CommonCodecs.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/CommonCodecs.scala index b4449af303..744f8b00ba 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/CommonCodecs.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/CommonCodecs.scala @@ -38,10 +38,10 @@ object CommonCodecs { // (for something smarter see https://github.com/yzernik/bitcoin-scodec/blob/master/src/main/scala/io/github/yzernik/bitcoinscodec/structures/UInt64.scala) val uint64: Codec[Long] = int64.narrow(l => if (l >= 0) Attempt.Successful(l) else Attempt.failure(Err(s"overflow for value $l")), l => l) - val uint64L: Codec[Long] = int64L.narrow(l => if (l >= 0) Attempt.Successful(l) else Attempt.failure(Err(s"overflow for value $l")), l => l) - val uint64ex: Codec[UInt64] = bytes(8).xmap(b => UInt64(b), a => a.toByteVector.padLeft(8)) + val uint64L: Codec[UInt64] = bytes(8).xmap(b => UInt64(b.reverse), a => a.toByteVector.padLeft(8).reverse) + /** * We impose a minimal encoding on varint values to ensure that signed hashes can be reproduced easily. * If a value could be encoded with less bytes, it's considered invalid and results in a failed decoding attempt. @@ -49,29 +49,30 @@ object CommonCodecs { * @param min the minimal value that should be encoded. * @param attempt the decoding attempt. */ - def verifyMinimalEncoding(min: Long, attempt: Attempt[DecodeResult[Long]]): Attempt[DecodeResult[Long]] = { + def verifyMinimalEncoding(min: Long, attempt: Attempt[DecodeResult[UInt64]]): Attempt[DecodeResult[UInt64]] = { attempt match { - case Attempt.Successful(res) if res.value < min => Attempt.Failure(scodec.Err("varint was not minimally encoded")) + case Attempt.Successful(res) if res.value < UInt64(min) => Attempt.Failure(scodec.Err("varint was not minimally encoded")) case Attempt.Successful(res) => Attempt.Successful(res) case Attempt.Failure(err) => Attempt.Failure(err) } } - // Bitcoin-style varint codec (CompactSize) - val varInt = Codec[Long]( - (n: Long) => + // Bitcoin-style varint codec (CompactSize). + // See https://bitcoin.org/en/developer-reference#compactsize-unsigned-integers for reference. + val varint = Codec[UInt64]( + (n: UInt64) => n match { - case i if i < 0xfd => - uint8L.encode(i.toInt) - case i if i < 0xffff => + case i if i < UInt64(0xfd) => + uint8L.encode(i.toBigInt.toInt) + case i if i < UInt64(0xffff) => for { a <- uint8L.encode(0xfd) - b <- uint16L.encode(i.toInt) + b <- uint16L.encode(i.toBigInt.toInt) } yield a ++ b - case i if i < 0xffffffffL => + case i if i < UInt64(0xffffffffL) => for { a <- uint8L.encode(0xfe) - b <- uint32L.encode(i) + b <- uint32L.encode(i.toBigInt.toLong) } yield a ++ b case i => for { @@ -84,14 +85,18 @@ object CommonCodecs { case Attempt.Successful(b) => b.value match { case 0xff => verifyMinimalEncoding(0x100000000L, uint64L.decode(b.remainder)) - case 0xfe => verifyMinimalEncoding(0x10000L, uint32L.decode(b.remainder)) - case 0xfd => verifyMinimalEncoding(0xfdL, uint16L.decode(b.remainder).map(b => b.map(_.toLong))) - case _ => Attempt.Successful(DecodeResult(b.value.toLong, b.remainder)) + case 0xfe => verifyMinimalEncoding(0x10000L, uint32L.decode(b.remainder).map(b => b.map(UInt64(_)))) + case 0xfd => verifyMinimalEncoding(0xfdL, uint16L.decode(b.remainder).map(b => b.map(UInt64(_)))) + case _ => Attempt.Successful(DecodeResult(UInt64(b.value), b.remainder)) } case Attempt.Failure(err) => Attempt.Failure(err) } }) + // This codec can be safely used for values < 2^63 and will fail otherwise. + // It is useful in combination with variableSizeBytesLong to encode/decode TLV lengths because those will always be < 2^63. + val varlong: Codec[Long] = varint.narrow(l => if (l <= UInt64(Long.MaxValue)) Attempt.successful(l.toBigInt.toLong) else Attempt.failure(Err(s"overflow for value $l")), l => UInt64(l)) + val bytes32: Codec[ByteVector32] = limitedSizeBytes(32, bytesStrict(32).xmap(d => ByteVector32(d), d => d.bytes)) val bytes64: Codec[ByteVector64] = limitedSizeBytes(64, bytesStrict(64).xmap(d => ByteVector64(d), d => d.bytes)) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvCodecs.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvCodecs.scala index 4bfc80e33d..9cfaff9743 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvCodecs.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvCodecs.scala @@ -29,7 +29,7 @@ import scala.collection.compat._ object TlvCodecs { - val genericTlv: Codec[GenericTlv] = (("type" | varInt) :: variableSizeBytesLong(varInt, bytes)).as[GenericTlv] + val genericTlv: Codec[GenericTlv] = (("type" | varint) :: variableSizeBytesLong(varlong, bytes)).as[GenericTlv] def tlvFallback(codec: Codec[Tlv]): Codec[Tlv] = discriminatorFallback(genericTlv, codec).xmap(_ match { case Left(l) => l diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvTypes.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvTypes.scala index f3540a480d..a3524acb06 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvTypes.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvTypes.scala @@ -16,6 +16,7 @@ package fr.acinq.eclair.wire +import fr.acinq.eclair.UInt64 import scodec.bits.ByteVector import scala.annotation.tailrec @@ -26,7 +27,7 @@ import scala.annotation.tailrec // @formatter:off trait Tlv { - val `type`: Long + val `type`: UInt64 } sealed trait OnionTlv extends Tlv // @formatter:on @@ -37,7 +38,7 @@ sealed trait OnionTlv extends Tlv * @param `type` tlv type. * @param value tlv value (length is implicit, and encoded as a varint). */ -case class GenericTlv(`type`: Long, value: ByteVector) extends Tlv +case class GenericTlv(`type`: UInt64, value: ByteVector) extends Tlv /** * A tlv stream is a collection of tlv records. @@ -57,17 +58,17 @@ case class TlvStream(records: Seq[Tlv]) { def validate: Option[Error] = { @tailrec - def loop(previous: Long, next: Seq[Tlv]): Option[Error] = { + def loop(previous: Option[UInt64], next: Seq[Tlv]): Option[Error] = { next.headOption match { - case Some(record) if record.`type` == previous => Some(DuplicateRecords) - case Some(record) if record.`type` < previous => Some(RecordsNotOrdered) - case Some(record) if (record.`type` % 2) == 0 && record.isInstanceOf[GenericTlv] => Some(UnknownEvenTlv) - case Some(record) => loop(record.`type`, next.tail) + case Some(record) if previous.contains(record.`type`) => Some(DuplicateRecords) + case Some(record) if previous.isDefined && record.`type` < previous.get => Some(RecordsNotOrdered) + case Some(record) if (record.`type`.toBigInt % 2) == 0 && record.isInstanceOf[GenericTlv] => Some(UnknownEvenTlv) + case Some(record) => loop(Some(record.`type`), next.tail) case None => None } } - loop(-1L, records) + loop(None, records) } } \ No newline at end of file diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/CommonCodecsSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/CommonCodecsSpec.scala index bd280f352f..e4c5afc703 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/wire/CommonCodecsSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/wire/CommonCodecsSpec.scala @@ -35,7 +35,8 @@ class CommonCodecsSpec extends FunSuite { val expected = Map( UInt64(0) -> hex"00 00 00 00 00 00 00 00", UInt64(42) -> hex"00 00 00 00 00 00 00 2a", - UInt64(hex"ffffffffffffffff") -> hex"ff ff ff ff ff ff ff ff" + UInt64(6211610197754262546L) -> hex"56 34 12 90 78 56 34 12", + UInt64(hex"ff ff ff ff ff ff ff ff") -> hex"ff ff ff ff ff ff ff ff" ).mapValues(_.toBitVector) for ((uint, ref) <- expected) { @@ -48,36 +49,38 @@ class CommonCodecsSpec extends FunSuite { test("encode/decode with uint64L codec") { val expected = Map( - 0L -> hex"00 00 00 00 00 00 00 00", - 42L -> hex"2a 00 00 00 00 00 00 00", - 6211610197754262546L -> hex"12 34 56 78 90 12 34 56" + UInt64(0) -> hex"00 00 00 00 00 00 00 00", + UInt64(42) -> hex"2a 00 00 00 00 00 00 00", + UInt64(6211610197754262546L) -> hex"12 34 56 78 90 12 34 56", + UInt64(hex"ff ff ff ff ff ff ff ff") -> hex"ff ff ff ff ff ff ff ff" ).mapValues(_.toBitVector) - for ((long, ref) <- expected) { - val encoded = uint64L.encode(long).require + for ((uint, ref) <- expected) { + val encoded = uint64L.encode(uint).require assert(ref === encoded) val decoded = uint64L.decode(encoded).require.value - assert(long === decoded) + assert(uint === decoded) } } test("encode/decode with varint codec") { val expected = Map( - 0L -> hex"00", - 42L -> hex"2a", - 253L -> hex"fd fd 00", - 254L -> hex"fd fe 00", - 255L -> hex"fd ff 00", - 550L -> hex"fd 26 02", - 998000L -> hex"fe 70 3a 0f 00", - 6211610197754262546L -> hex"ff 12 34 56 78 90 12 34 56" + UInt64(0L) -> hex"00", + UInt64(42L) -> hex"2a", + UInt64(253L) -> hex"fd fd 00", + UInt64(254L) -> hex"fd fe 00", + UInt64(255L) -> hex"fd ff 00", + UInt64(550L) -> hex"fd 26 02", + UInt64(998000L) -> hex"fe 70 3a 0f 00", + UInt64(6211610197754262546L) -> hex"ff 12 34 56 78 90 12 34 56", + UInt64.MaxValue -> hex"ff ff ff ff ff ff ff ff ff" ).mapValues(_.toBitVector) - for ((long, ref) <- expected) { - val encoded = varInt.encode(long).require + for ((uint, ref) <- expected) { + val encoded = varint.encode(uint).require assert(ref === encoded, ref) - val decoded = varInt.decode(encoded).require.value - assert(long === decoded, long) + val decoded = varint.decode(encoded).require.value + assert(uint === decoded, uint) } } @@ -89,13 +92,49 @@ class CommonCodecsSpec extends FunSuite { hex"fe 12 34", // truncated hex"ff", // truncated hex"ff 12 34 56 78", // truncated + hex"fd 00 00", // not minimally-encoded hex"fd fc 00", // not minimally-encoded + hex"fe 00 00 00 00", // not minimally-encoded hex"fe ff ff 00 00", // not minimally-encoded + hex"ff 00 00 00 00 00 00 00 00", // not minimally-encoded + hex"ff ff ff ff 01 00 00 00 00", // not minimally-encoded hex"ff ff ff ff ff 00 00 00 00" // not minimally-encoded ).map(_.toBitVector) for (testCase <- testCases) { - assert(varInt.decode(testCase).isFailure, testCase.toByteVector) + assert(varint.decode(testCase).isFailure, testCase.toByteVector) + } + } + + test("encode/decode with varlong codec") { + val expected = Map( + 0L -> hex"00", + 42L -> hex"2a", + 253L -> hex"fd fd 00", + 254L -> hex"fd fe 00", + 255L -> hex"fd ff 00", + 550L -> hex"fd 26 02", + 998000L -> hex"fe 70 3a 0f 00", + 6211610197754262546L -> hex"ff 12 34 56 78 90 12 34 56", + Long.MaxValue -> hex"ff ff ff ff ff ff ff ff 7f" + ).mapValues(_.toBitVector) + + for ((long, ref) <- expected) { + val encoded = varlong.encode(long).require + assert(ref === encoded, ref) + val decoded = varlong.decode(encoded).require.value + assert(long === decoded, long) + } + } + + test("decode invalid varlong") { + val testCases = Seq( + hex"ff 00 00 00 00 00 00 00 80", + hex"ff ff ff ff ff ff ff ff ff" + ).map(_.toBitVector) + + for (testCase <- testCases) { + assert(varlong.decode(testCase).isFailure, testCase.toByteVector) } } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/TlvCodecsSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/TlvCodecsSpec.scala index 0209e7c70a..557a172b12 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/wire/TlvCodecsSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/wire/TlvCodecsSpec.scala @@ -17,8 +17,9 @@ package fr.acinq.eclair.wire import fr.acinq.bitcoin.Crypto.PublicKey -import fr.acinq.eclair.ShortChannelId -import fr.acinq.eclair.wire.CommonCodecs.{publicKey, shortchannelid, uint64, varInt} +import fr.acinq.eclair.{ShortChannelId, UInt64} +import fr.acinq.eclair.UInt64.Conversions._ +import fr.acinq.eclair.wire.CommonCodecs.{publicKey, shortchannelid, uint64, varint} import fr.acinq.eclair.wire.TlvCodecs._ import org.scalatest.FunSuite import scodec.bits.HexStringSyntax @@ -59,7 +60,8 @@ class TlvCodecsSpec extends FunSuite { hex"2a fe01010000", // not minimally encoded length hex"2a fd2602 0231", // value truncated hex"02 01 2a", // short channel id too short - hex"02 09 010101010101010101" // short channel id length too big + hex"02 09 010101010101010101", // short channel id length too big + hex"2a ff0000000000000080" // invalid length (too big to fit inside a long) ) for (testCase <- testCases) { @@ -128,16 +130,16 @@ object TlvCodecsSpec { // @formatter:off sealed trait TestTlv extends Tlv - case class TestType1(longValue: Long) extends TestTlv { override val `type` = 1L } - case class TestType2(shortChannelId: ShortChannelId) extends TestTlv { override val `type` = 2L } - case class TestType3(nodeId: PublicKey, value1: Long, value2: Long) extends TestTlv { override val `type` = 3L } - case class TestType13(intValue: Int) extends TestTlv { override val `type` = 13L } + case class TestType1(longValue: Long) extends TestTlv { override val `type` = UInt64(1) } + case class TestType2(shortChannelId: ShortChannelId) extends TestTlv { override val `type` = UInt64(2) } + case class TestType3(nodeId: PublicKey, value1: Long, value2: Long) extends TestTlv { override val `type` = UInt64(3) } + case class TestType13(intValue: Int) extends TestTlv { override val `type` = UInt64(13) } val testCodec1: Codec[TestType1] = (("length" | constant(hex"08")) :: ("value" | uint64)).as[TestType1] val testCodec2: Codec[TestType2] = (("length" | constant(hex"08")) :: ("short_channel_id" | shortchannelid)).as[TestType2] val testCodec3: Codec[TestType3] = (("length" | constant(hex"31")) :: ("node_id" | publicKey) :: ("value_1" | uint64) :: ("value_2" | uint64)).as[TestType3] val testCodec13: Codec[TestType13] = (("length" | constant(hex"02")) :: ("value" | uint16)).as[TestType13] - val testTlvCodec = tlvFallback(discriminated[Tlv].by(varInt) + val testTlvCodec = tlvFallback(discriminated[Tlv].by(varint) .typecase(1, testCodec1) .typecase(2, testCodec2) .typecase(3, testCodec3) @@ -145,12 +147,12 @@ object TlvCodecsSpec { ) sealed trait OtherTlv extends Tlv - case class OtherType1(longValue: Long) extends OtherTlv { override val `type` = 10L } - case class OtherType2(lessLongValue: Long) extends OtherTlv { override val `type` = 11L } + case class OtherType1(longValue: Long) extends OtherTlv { override val `type` = UInt64(10) } + case class OtherType2(lessLongValue: Long) extends OtherTlv { override val `type` = UInt64(11) } val otherCodec1: Codec[OtherType1] = (("length" | constant(hex"08")) :: ("value" | uint64)).as[OtherType1] val otherCodec2: Codec[OtherType2] = (("length" | constant(hex"04")) :: ("value" | uint32)).as[OtherType2] - val otherTlvCodec = tlvFallback(discriminated[Tlv].by(varInt) + val otherTlvCodec = tlvFallback(discriminated[Tlv].by(varint) .typecase(10, otherCodec1) .typecase(11, otherCodec2) ) From 451697c7b144584f3b4882397480140ef768abd0 Mon Sep 17 00:00:00 2001 From: Bastien Teinturier Date: Fri, 28 Jun 2019 10:12:05 +0200 Subject: [PATCH 09/12] Move TlvStream errors to companion object. Rename uint64 codecs to harmonize. --- .../fr/acinq/eclair/wire/ChannelCodecs.scala | 34 +++++++-------- .../fr/acinq/eclair/wire/CommonCodecs.scala | 4 +- .../fr/acinq/eclair/wire/FailureMessage.scala | 10 ++--- .../eclair/wire/LightningMessageCodecs.scala | 42 +++++++++---------- .../scala/fr/acinq/eclair/wire/TlvTypes.scala | 22 ++++++---- .../acinq/eclair/wire/CommonCodecsSpec.scala | 6 +-- .../fr/acinq/eclair/wire/TlvCodecsSpec.scala | 8 ++-- 7 files changed, 66 insertions(+), 60 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/ChannelCodecs.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/ChannelCodecs.scala index 11ff20bf25..b5164b485e 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/ChannelCodecs.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/ChannelCodecs.scala @@ -53,10 +53,10 @@ object ChannelCodecs extends Logging { val localParamsCodec: Codec[LocalParams] = ( ("nodeId" | publicKey) :: ("channelPath" | keyPathCodec) :: - ("dustLimitSatoshis" | uint64) :: - ("maxHtlcValueInFlightMsat" | uint64ex) :: - ("channelReserveSatoshis" | uint64) :: - ("htlcMinimumMsat" | uint64) :: + ("dustLimitSatoshis" | uint64overflow) :: + ("maxHtlcValueInFlightMsat" | uint64) :: + ("channelReserveSatoshis" | uint64overflow) :: + ("htlcMinimumMsat" | uint64overflow) :: ("toSelfDelay" | uint16) :: ("maxAcceptedHtlcs" | uint16) :: ("isFunder" | bool) :: @@ -66,10 +66,10 @@ object ChannelCodecs extends Logging { val remoteParamsCodec: Codec[RemoteParams] = ( ("nodeId" | publicKey) :: - ("dustLimitSatoshis" | uint64) :: - ("maxHtlcValueInFlightMsat" | uint64ex) :: - ("channelReserveSatoshis" | uint64) :: - ("htlcMinimumMsat" | uint64) :: + ("dustLimitSatoshis" | uint64overflow) :: + ("maxHtlcValueInFlightMsat" | uint64) :: + ("channelReserveSatoshis" | uint64overflow) :: + ("htlcMinimumMsat" | uint64overflow) :: ("toSelfDelay" | uint16) :: ("maxAcceptedHtlcs" | uint16) :: ("fundingPubKey" | publicKey) :: @@ -97,8 +97,8 @@ object ChannelCodecs extends Logging { val commitmentSpecCodec: Codec[CommitmentSpec] = ( ("htlcs" | setCodec(htlcCodec)) :: ("feeratePerKw" | uint32) :: - ("toLocalMsat" | uint64) :: - ("toRemoteMsat" | uint64)).as[CommitmentSpec] + ("toLocalMsat" | uint64overflow) :: + ("toRemoteMsat" | uint64overflow)).as[CommitmentSpec] val outPointCodec: Codec[OutPoint] = variableSizeBytes(uint16, bytes.xmap(d => OutPoint.read(d.toArray), d => OutPoint.write(d))) @@ -142,12 +142,12 @@ object ChannelCodecs extends Logging { ("htlcTxsAndSigs" | listOfN(uint16, htlcTxAndSigsCodec))).as[PublishableTxs] val localCommitCodec: Codec[LocalCommit] = ( - ("index" | uint64) :: + ("index" | uint64overflow) :: ("spec" | commitmentSpecCodec) :: ("publishableTxs" | publishableTxsCodec)).as[LocalCommit] val remoteCommitCodec: Codec[RemoteCommit] = ( - ("index" | uint64) :: + ("index" | uint64overflow) :: ("spec" | commitmentSpecCodec) :: ("txid" | bytes32) :: ("remotePerCommitmentPoint" | publicKey)).as[RemoteCommit] @@ -167,7 +167,7 @@ object ChannelCodecs extends Logging { val waitingForRevocationCodec: Codec[WaitingForRevocation] = ( ("nextRemoteCommit" | remoteCommitCodec) :: ("sent" | commitSigCodec) :: - ("sentAfterLocalCommitIndex" | uint64) :: + ("sentAfterLocalCommitIndex" | uint64overflow) :: ("reSignAsap" | bool)).as[WaitingForRevocation] val localCodec: Codec[Local] = ( @@ -178,8 +178,8 @@ object ChannelCodecs extends Logging { val relayedCodec: Codec[Relayed] = ( ("originChannelId" | bytes32) :: ("originHtlcId" | int64) :: - ("amountMsatIn" | uint64) :: - ("amountMsatOut" | uint64)).as[Relayed] + ("amountMsatIn" | uint64overflow) :: + ("amountMsatOut" | uint64overflow)).as[Relayed] // this is for backward compatibility to handle legacy payments that didn't have identifiers val UNKNOWN_UUID = UUID.fromString("00000000-0000-0000-0000-000000000000") @@ -211,8 +211,8 @@ object ChannelCodecs extends Logging { ("remoteCommit" | remoteCommitCodec) :: ("localChanges" | localChangesCodec) :: ("remoteChanges" | remoteChangesCodec) :: - ("localNextHtlcId" | uint64) :: - ("remoteNextHtlcId" | uint64) :: + ("localNextHtlcId" | uint64overflow) :: + ("remoteNextHtlcId" | uint64overflow) :: ("originChannels" | originsMapCodec) :: ("remoteNextCommitInfo" | either(bool, waitingForRevocationCodec, publicKey)) :: ("commitInput" | inputInfoCodec) :: diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/CommonCodecs.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/CommonCodecs.scala index 744f8b00ba..97cd3fa67a 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/CommonCodecs.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/CommonCodecs.scala @@ -36,9 +36,9 @@ object CommonCodecs { // this codec can be safely used for values < 2^63 and will fail otherwise // (for something smarter see https://github.com/yzernik/bitcoin-scodec/blob/master/src/main/scala/io/github/yzernik/bitcoinscodec/structures/UInt64.scala) - val uint64: Codec[Long] = int64.narrow(l => if (l >= 0) Attempt.Successful(l) else Attempt.failure(Err(s"overflow for value $l")), l => l) + val uint64overflow: Codec[Long] = int64.narrow(l => if (l >= 0) Attempt.Successful(l) else Attempt.failure(Err(s"overflow for value $l")), l => l) - val uint64ex: Codec[UInt64] = bytes(8).xmap(b => UInt64(b), a => a.toByteVector.padLeft(8)) + val uint64: Codec[UInt64] = bytes(8).xmap(b => UInt64(b), a => a.toByteVector.padLeft(8)) val uint64L: Codec[UInt64] = bytes(8).xmap(b => UInt64(b.reverse), a => a.toByteVector.padLeft(8).reverse) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/FailureMessage.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/FailureMessage.scala index 5ff9b67188..fc1233283b 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/FailureMessage.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/FailureMessage.scala @@ -17,7 +17,7 @@ package fr.acinq.eclair.wire import fr.acinq.bitcoin.ByteVector32 -import fr.acinq.eclair.wire.CommonCodecs.{sha256, uint64} +import fr.acinq.eclair.wire.CommonCodecs.{sha256, uint64overflow} import fr.acinq.eclair.wire.LightningMessageCodecs.channelUpdateCodec import scodec.codecs._ import scodec.Attempt @@ -82,15 +82,15 @@ object FailureMessageCodecs { .typecase(PERM | 8, provide(PermanentChannelFailure)) .typecase(PERM | 9, provide(RequiredChannelFeatureMissing)) .typecase(PERM | 10, provide(UnknownNextPeer)) - .typecase(UPDATE | 11, (("amountMsat" | uint64) :: ("channelUpdate" | channelUpdateWithLengthCodec)).as[AmountBelowMinimum]) - .typecase(UPDATE | 12, (("amountMsat" | uint64) :: ("channelUpdate" | channelUpdateWithLengthCodec)).as[FeeInsufficient]) + .typecase(UPDATE | 11, (("amountMsat" | uint64overflow) :: ("channelUpdate" | channelUpdateWithLengthCodec)).as[AmountBelowMinimum]) + .typecase(UPDATE | 12, (("amountMsat" | uint64overflow) :: ("channelUpdate" | channelUpdateWithLengthCodec)).as[FeeInsufficient]) .typecase(UPDATE | 13, (("expiry" | uint32) :: ("channelUpdate" | channelUpdateWithLengthCodec)).as[IncorrectCltvExpiry]) .typecase(UPDATE | 14, ("channelUpdate" | channelUpdateWithLengthCodec).as[ExpiryTooSoon]) .typecase(UPDATE | 20, (("messageFlags" | byte) :: ("channelFlags" | byte) :: ("channelUpdate" | channelUpdateWithLengthCodec)).as[ChannelDisabled]) - .typecase(PERM | 15, ("amountMsat" | withDefaultValue(optional(bitsRemaining, uint64), 0L)).as[IncorrectOrUnknownPaymentDetails]) + .typecase(PERM | 15, ("amountMsat" | withDefaultValue(optional(bitsRemaining, uint64overflow), 0L)).as[IncorrectOrUnknownPaymentDetails]) .typecase(PERM | 16, provide(IncorrectPaymentAmount)) .typecase(17, provide(FinalExpiryTooSoon)) .typecase(18, ("expiry" | uint32).as[FinalIncorrectCltvExpiry]) - .typecase(19, ("amountMsat" | uint64).as[FinalIncorrectHtlcAmount]) + .typecase(19, ("amountMsat" | uint64overflow).as[FinalIncorrectHtlcAmount]) .typecase(21, provide(ExpiryTooFar)) } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/LightningMessageCodecs.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/LightningMessageCodecs.scala index 9aa43b1e75..468573788a 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/LightningMessageCodecs.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/LightningMessageCodecs.scala @@ -45,20 +45,20 @@ object LightningMessageCodecs { val channelReestablishCodec: Codec[ChannelReestablish] = ( ("channelId" | bytes32) :: - ("nextLocalCommitmentNumber" | uint64) :: - ("nextRemoteRevocationNumber" | uint64) :: + ("nextLocalCommitmentNumber" | uint64overflow) :: + ("nextRemoteRevocationNumber" | uint64overflow) :: ("yourLastPerCommitmentSecret" | optional(bitsRemaining, privateKey)) :: ("myCurrentPerCommitmentPoint" | optional(bitsRemaining, publicKey))).as[ChannelReestablish] val openChannelCodec: Codec[OpenChannel] = ( ("chainHash" | bytes32) :: ("temporaryChannelId" | bytes32) :: - ("fundingSatoshis" | uint64) :: - ("pushMsat" | uint64) :: - ("dustLimitSatoshis" | uint64) :: - ("maxHtlcValueInFlightMsat" | uint64ex) :: - ("channelReserveSatoshis" | uint64) :: - ("htlcMinimumMsat" | uint64) :: + ("fundingSatoshis" | uint64overflow) :: + ("pushMsat" | uint64overflow) :: + ("dustLimitSatoshis" | uint64overflow) :: + ("maxHtlcValueInFlightMsat" | uint64) :: + ("channelReserveSatoshis" | uint64overflow) :: + ("htlcMinimumMsat" | uint64overflow) :: ("feeratePerKw" | uint32) :: ("toSelfDelay" | uint16) :: ("maxAcceptedHtlcs" | uint16) :: @@ -72,10 +72,10 @@ object LightningMessageCodecs { val acceptChannelCodec: Codec[AcceptChannel] = ( ("temporaryChannelId" | bytes32) :: - ("dustLimitSatoshis" | uint64) :: - ("maxHtlcValueInFlightMsat" | uint64ex) :: - ("channelReserveSatoshis" | uint64) :: - ("htlcMinimumMsat" | uint64) :: + ("dustLimitSatoshis" | uint64overflow) :: + ("maxHtlcValueInFlightMsat" | uint64) :: + ("channelReserveSatoshis" | uint64overflow) :: + ("htlcMinimumMsat" | uint64overflow) :: ("minimumDepth" | uint32) :: ("toSelfDelay" | uint16) :: ("maxAcceptedHtlcs" | uint16) :: @@ -106,30 +106,30 @@ object LightningMessageCodecs { val closingSignedCodec: Codec[ClosingSigned] = ( ("channelId" | bytes32) :: - ("feeSatoshis" | uint64) :: + ("feeSatoshis" | uint64overflow) :: ("signature" | bytes64)).as[ClosingSigned] val updateAddHtlcCodec: Codec[UpdateAddHtlc] = ( ("channelId" | bytes32) :: - ("id" | uint64) :: - ("amountMsat" | uint64) :: + ("id" | uint64overflow) :: + ("amountMsat" | uint64overflow) :: ("paymentHash" | bytes32) :: ("expiry" | uint32) :: ("onionRoutingPacket" | bytes(Sphinx.PacketLength))).as[UpdateAddHtlc] val updateFulfillHtlcCodec: Codec[UpdateFulfillHtlc] = ( ("channelId" | bytes32) :: - ("id" | uint64) :: + ("id" | uint64overflow) :: ("paymentPreimage" | bytes32)).as[UpdateFulfillHtlc] val updateFailHtlcCodec: Codec[UpdateFailHtlc] = ( ("channelId" | bytes32) :: - ("id" | uint64) :: + ("id" | uint64overflow) :: ("reason" | varsizebinarydata)).as[UpdateFailHtlc] val updateFailMalformedHtlcCodec: Codec[UpdateFailMalformedHtlc] = ( ("channelId" | bytes32) :: - ("id" | uint64) :: + ("id" | uint64overflow) :: ("onionHash" | bytes32) :: ("failureCode" | uint16)).as[UpdateFailMalformedHtlc] @@ -187,10 +187,10 @@ object LightningMessageCodecs { (("messageFlags" | byte) >>:~ { messageFlags => ("channelFlags" | byte) :: ("cltvExpiryDelta" | uint16) :: - ("htlcMinimumMsat" | uint64) :: + ("htlcMinimumMsat" | uint64overflow) :: ("feeBaseMsat" | uint32) :: ("feeProportionalMillionths" | uint32) :: - ("htlcMaximumMsat" | conditional((messageFlags & 1) != 0, uint64)) + ("htlcMaximumMsat" | conditional((messageFlags & 1) != 0, uint64overflow)) }) val channelUpdateCodec: Codec[ChannelUpdate] = ( @@ -260,7 +260,7 @@ object LightningMessageCodecs { val perHopPayloadCodec: Codec[PerHopPayload] = ( ("realm" | constant(ByteVector.fromByte(0))) :: ("short_channel_id" | shortchannelid) :: - ("amt_to_forward" | uint64) :: + ("amt_to_forward" | uint64overflow) :: ("outgoing_cltv_value" | uint32) :: ("unused_with_v0_version_on_header" | ignore(8 * 12))).as[PerHopPayload] diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvTypes.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvTypes.scala index a3524acb06..939e9a4cfe 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvTypes.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvTypes.scala @@ -49,16 +49,11 @@ case class GenericTlv(`type`: UInt64, value: ByteVector) extends Tlv */ case class TlvStream(records: Seq[Tlv]) { - // @formatter:off - sealed trait Error { val message: String } - case object RecordsNotOrdered extends Error { override val message = "tlv records must be ordered by monotonically-increasing types" } - case object DuplicateRecords extends Error { override val message = "tlv streams must not contain duplicate records" } - case object UnknownEvenTlv extends Error { override val message = "tlv streams must not contain unknown even tlv types" } - // @formatter:on + import TlvStream._ - def validate: Option[Error] = { + def validate: Option[TlvStream.Error] = { @tailrec - def loop(previous: Option[UInt64], next: Seq[Tlv]): Option[Error] = { + def loop(previous: Option[UInt64], next: Seq[Tlv]): Option[TlvStream.Error] = { next.headOption match { case Some(record) if previous.contains(record.`type`) => Some(DuplicateRecords) case Some(record) if previous.isDefined && record.`type` < previous.get => Some(RecordsNotOrdered) @@ -71,4 +66,15 @@ case class TlvStream(records: Seq[Tlv]) { loop(None, records) } +} + +object TlvStream { + + // @formatter:off + sealed trait Error { val message: String } + case object RecordsNotOrdered extends Error { override val message = "tlv records must be ordered by monotonically-increasing types" } + case object DuplicateRecords extends Error { override val message = "tlv streams must not contain duplicate records" } + case object UnknownEvenTlv extends Error { override val message = "tlv streams must not contain unknown even tlv types" } + // @formatter:on + } \ No newline at end of file diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/CommonCodecsSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/CommonCodecsSpec.scala index e4c5afc703..c2a618d3fc 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/wire/CommonCodecsSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/wire/CommonCodecsSpec.scala @@ -40,9 +40,9 @@ class CommonCodecsSpec extends FunSuite { ).mapValues(_.toBitVector) for ((uint, ref) <- expected) { - val encoded = uint64ex.encode(uint).require + val encoded = uint64.encode(uint).require assert(ref === encoded) - val decoded = uint64ex.decode(encoded).require.value + val decoded = uint64.decode(encoded).require.value assert(uint === decoded) } } @@ -239,7 +239,7 @@ class CommonCodecsSpec extends FunSuite { } test("encode/decode UInt64") { - val codec = uint64ex + val codec = uint64 Seq( UInt64(hex"ffffffffffffffff"), UInt64(hex"fffffffffffffffe"), diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/TlvCodecsSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/TlvCodecsSpec.scala index 557a172b12..906811a90c 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/wire/TlvCodecsSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/wire/TlvCodecsSpec.scala @@ -130,9 +130,9 @@ object TlvCodecsSpec { // @formatter:off sealed trait TestTlv extends Tlv - case class TestType1(longValue: Long) extends TestTlv { override val `type` = UInt64(1) } + case class TestType1(uintValue: UInt64) extends TestTlv { override val `type` = UInt64(1) } case class TestType2(shortChannelId: ShortChannelId) extends TestTlv { override val `type` = UInt64(2) } - case class TestType3(nodeId: PublicKey, value1: Long, value2: Long) extends TestTlv { override val `type` = UInt64(3) } + case class TestType3(nodeId: PublicKey, value1: UInt64, value2: UInt64) extends TestTlv { override val `type` = UInt64(3) } case class TestType13(intValue: Int) extends TestTlv { override val `type` = UInt64(13) } val testCodec1: Codec[TestType1] = (("length" | constant(hex"08")) :: ("value" | uint64)).as[TestType1] @@ -147,8 +147,8 @@ object TlvCodecsSpec { ) sealed trait OtherTlv extends Tlv - case class OtherType1(longValue: Long) extends OtherTlv { override val `type` = UInt64(10) } - case class OtherType2(lessLongValue: Long) extends OtherTlv { override val `type` = UInt64(11) } + case class OtherType1(uintValue: UInt64) extends OtherTlv { override val `type` = UInt64(10) } + case class OtherType2(smallValue: Long) extends OtherTlv { override val `type` = UInt64(11) } val otherCodec1: Codec[OtherType1] = (("length" | constant(hex"08")) :: ("value" | uint64)).as[OtherType1] val otherCodec2: Codec[OtherType2] = (("length" | constant(hex"04")) :: ("value" | uint32)).as[OtherType2] From f87f92449d80842d6ec9c26a2de7a1ec3988fc7c Mon Sep 17 00:00:00 2001 From: Bastien Teinturier Date: Fri, 28 Jun 2019 17:00:14 +0200 Subject: [PATCH 10/12] Rename varlong to varintoverflow for consistency --- .../src/main/scala/fr/acinq/eclair/wire/CommonCodecs.scala | 2 +- .../src/main/scala/fr/acinq/eclair/wire/TlvCodecs.scala | 2 +- .../test/scala/fr/acinq/eclair/wire/CommonCodecsSpec.scala | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/CommonCodecs.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/CommonCodecs.scala index 97cd3fa67a..9239df7b94 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/CommonCodecs.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/CommonCodecs.scala @@ -95,7 +95,7 @@ object CommonCodecs { // This codec can be safely used for values < 2^63 and will fail otherwise. // It is useful in combination with variableSizeBytesLong to encode/decode TLV lengths because those will always be < 2^63. - val varlong: Codec[Long] = varint.narrow(l => if (l <= UInt64(Long.MaxValue)) Attempt.successful(l.toBigInt.toLong) else Attempt.failure(Err(s"overflow for value $l")), l => UInt64(l)) + val varintoverflow: Codec[Long] = varint.narrow(l => if (l <= UInt64(Long.MaxValue)) Attempt.successful(l.toBigInt.toLong) else Attempt.failure(Err(s"overflow for value $l")), l => UInt64(l)) val bytes32: Codec[ByteVector32] = limitedSizeBytes(32, bytesStrict(32).xmap(d => ByteVector32(d), d => d.bytes)) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvCodecs.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvCodecs.scala index 9cfaff9743..ab5fca042d 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvCodecs.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvCodecs.scala @@ -29,7 +29,7 @@ import scala.collection.compat._ object TlvCodecs { - val genericTlv: Codec[GenericTlv] = (("type" | varint) :: variableSizeBytesLong(varlong, bytes)).as[GenericTlv] + val genericTlv: Codec[GenericTlv] = (("type" | varint) :: variableSizeBytesLong(varintoverflow, bytes)).as[GenericTlv] def tlvFallback(codec: Codec[Tlv]): Codec[Tlv] = discriminatorFallback(genericTlv, codec).xmap(_ match { case Left(l) => l diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/CommonCodecsSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/CommonCodecsSpec.scala index c2a618d3fc..f17bc45786 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/wire/CommonCodecsSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/wire/CommonCodecsSpec.scala @@ -120,9 +120,9 @@ class CommonCodecsSpec extends FunSuite { ).mapValues(_.toBitVector) for ((long, ref) <- expected) { - val encoded = varlong.encode(long).require + val encoded = varintoverflow.encode(long).require assert(ref === encoded, ref) - val decoded = varlong.decode(encoded).require.value + val decoded = varintoverflow.decode(encoded).require.value assert(long === decoded, long) } } @@ -134,7 +134,7 @@ class CommonCodecsSpec extends FunSuite { ).map(_.toBitVector) for (testCase <- testCases) { - assert(varlong.decode(testCase).isFailure, testCase.toByteVector) + assert(varintoverflow.decode(testCase).isFailure, testCase.toByteVector) } } From 8229a0bb526e00af09d5e8cee844a3e7de6eef95 Mon Sep 17 00:00:00 2001 From: Bastien Teinturier Date: Mon, 1 Jul 2019 10:03:00 +0200 Subject: [PATCH 11/12] Clean up Tlv stream. Validation during case class `apply`. Simplify codec using the built-in `list`. --- .../fr/acinq/eclair/wire/TlvCodecs.scala | 30 ++++------------ .../scala/fr/acinq/eclair/wire/TlvTypes.scala | 35 +++++-------------- .../fr/acinq/eclair/wire/TlvCodecsSpec.scala | 20 +++++------ 3 files changed, 23 insertions(+), 62 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvCodecs.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvCodecs.scala index ab5fca042d..b1651251d9 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvCodecs.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvCodecs.scala @@ -17,11 +17,10 @@ package fr.acinq.eclair.wire import fr.acinq.eclair.wire.CommonCodecs._ -import scodec.{Attempt, Codec, DecodeResult, Decoder, Encoder} -import scodec.bits.BitVector +import scodec.{Attempt, Codec} import scodec.codecs._ -import scala.collection.compat._ +import scala.util.Try /** * Created by t-bast on 20/06/2019. @@ -45,26 +44,9 @@ object TlvCodecs { * * @param codec codec used for the tlv records contained in the stream. */ - def tlvStream(codec: Codec[Tlv]): Codec[TlvStream] = { - Codec[TlvStream]( - (s: TlvStream) => { - val recordTypes = s.records.map(_.`type`) - if (recordTypes.length != recordTypes.distinct.length) { - Attempt.Failure(scodec.Err("duplicate tlv records aren't allowed")) - } else { - Encoder.encodeSeq(codec)(s.records.sortBy(_.`type`).toList) - } - }, - (buf: BitVector) => { - Decoder.decodeCollect[List, Tlv](codec, None)(buf).map(_.map(TlvStream(_))) match { - case Attempt.Failure(err) => Attempt.Failure(err) - case Attempt.Successful(res@DecodeResult(stream, _)) => stream.validate match { - case None => Attempt.Successful(res) - case Some(err) => Attempt.Failure(scodec.Err(err.message)) - } - } - } - ) - } + def tlvStream(codec: Codec[Tlv]): Codec[TlvStream] = list(codec).exmap( + records => Attempt.fromTry(Try(TlvStream(records))), + stream => Attempt.successful(stream.records.toList) + ) } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvTypes.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvTypes.scala index 939e9a4cfe..84d38a090f 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvTypes.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvTypes.scala @@ -49,32 +49,15 @@ case class GenericTlv(`type`: UInt64, value: ByteVector) extends Tlv */ case class TlvStream(records: Seq[Tlv]) { - import TlvStream._ - - def validate: Option[TlvStream.Error] = { - @tailrec - def loop(previous: Option[UInt64], next: Seq[Tlv]): Option[TlvStream.Error] = { - next.headOption match { - case Some(record) if previous.contains(record.`type`) => Some(DuplicateRecords) - case Some(record) if previous.isDefined && record.`type` < previous.get => Some(RecordsNotOrdered) - case Some(record) if (record.`type`.toBigInt % 2) == 0 && record.isInstanceOf[GenericTlv] => Some(UnknownEvenTlv) - case Some(record) => loop(Some(record.`type`), next.tail) - case None => None - } - } - - loop(None, records) + records.foldLeft(Option.empty[Tlv]) { + case (None, record) => + require(!record.isInstanceOf[GenericTlv] || record.`type`.toBigInt % 2 != 0, "tlv streams must not contain unknown even tlv types") + Some(record) + case (Some(previousRecord), record) => + require(record.`type` != previousRecord.`type`, "tlv streams must not contain duplicate records") + require(record.`type` > previousRecord.`type`, "tlv records must be ordered by monotonically-increasing types") + require(!record.isInstanceOf[GenericTlv] || record.`type`.toBigInt % 2 != 0, "tlv streams must not contain unknown even tlv types") + Some(record) } -} - -object TlvStream { - - // @formatter:off - sealed trait Error { val message: String } - case object RecordsNotOrdered extends Error { override val message = "tlv records must be ordered by monotonically-increasing types" } - case object DuplicateRecords extends Error { override val message = "tlv streams must not contain duplicate records" } - case object UnknownEvenTlv extends Error { override val message = "tlv streams must not contain unknown even tlv types" } - // @formatter:on - } \ No newline at end of file diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/TlvCodecsSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/TlvCodecsSpec.scala index 906811a90c..a4c4541b8e 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/wire/TlvCodecsSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/wire/TlvCodecsSpec.scala @@ -71,7 +71,7 @@ class TlvCodecsSpec extends FunSuite { test("decode invalid tlv stream") { val testCases = Seq( - hex"0108000000000000002a 01", // valid tlv record followed by invalid tlv record (only type, length and value are missing) + hex"0108000000000000002a 02", // valid tlv record followed by invalid tlv record (only type, length and value are missing) hex"02080000000000000226 0108000000000000002a", // valid tlv records but invalid ordering hex"02080000000000000231 02080000000000000451", // duplicate tlv type hex"0108000000000000002a 2a0101", // unknown even type @@ -83,15 +83,11 @@ class TlvCodecsSpec extends FunSuite { } } - test("encode invalid tlv stream") { - val testCases = Seq( - TlvStream(Seq(TestType1(561), TestType2(ShortChannelId(1105)), OtherType1(42))), - TlvStream(Seq(TestType1(561), TestType1(1105))) - ) - - for (testCase <- testCases) { - assert(tlvStream(testTlvCodec).encode(testCase).isFailure, testCase) - } + test("create invalid tlv stream") { + assertThrows[IllegalArgumentException](TlvStream(Seq(GenericTlv(42, hex"2a")))) // unknown even type + assertThrows[IllegalArgumentException](TlvStream(Seq(TestType1(561), TestType2(ShortChannelId(1105)), GenericTlv(42, hex"2a")))) // unknown even type + assertThrows[IllegalArgumentException](TlvStream(Seq(TestType1(561), TestType1(1105)))) // duplicate type + assertThrows[IllegalArgumentException](TlvStream(Seq(TestType2(ShortChannelId(1105)), TestType1(561)))) // invalid ordering } test("encode/decode tlv stream") { @@ -105,7 +101,7 @@ class TlvCodecsSpec extends FunSuite { val decoded = tlvStream(testTlvCodec).decode(bin.toBitVector).require.value assert(decoded === TlvStream(expected)) - val encoded = tlvStream(testTlvCodec).encode(TlvStream(expected.reverse)).require.toByteVector + val encoded = tlvStream(testTlvCodec).encode(TlvStream(expected)).require.toByteVector assert(encoded === bin) } @@ -120,7 +116,7 @@ class TlvCodecsSpec extends FunSuite { val decoded = tlvStream(testTlvCodec).decode(bin.toBitVector).require.value assert(decoded === TlvStream(expected)) - val encoded = tlvStream(testTlvCodec).encode(TlvStream(expected.reverse)).require.toByteVector + val encoded = tlvStream(testTlvCodec).encode(TlvStream(expected)).require.toByteVector assert(encoded === bin) } From d63388b157034f0c1843722f3aa86ef47d88be02 Mon Sep 17 00:00:00 2001 From: Bastien Teinturier Date: Mon, 1 Jul 2019 14:15:58 +0200 Subject: [PATCH 12/12] Simplify varint codec (credits to pm47) --- .../fr/acinq/eclair/wire/CommonCodecs.scala | 70 +++++++------------ .../fr/acinq/eclair/wire/TlvCodecs.scala | 4 +- 2 files changed, 29 insertions(+), 45 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/CommonCodecs.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/CommonCodecs.scala index 9239df7b94..143a516667 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/CommonCodecs.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/CommonCodecs.scala @@ -22,7 +22,7 @@ import fr.acinq.bitcoin.{ByteVector32, ByteVector64} import fr.acinq.bitcoin.Crypto.{PrivateKey, PublicKey} import fr.acinq.eclair.{ShortChannelId, UInt64} import org.apache.commons.codec.binary.Base32 -import scodec.{Attempt, Codec, DecodeResult, Err} +import scodec.{Attempt, Codec, DecodeResult, Err, SizeBound} import scodec.bits.{BitVector, ByteVector} import scodec.codecs._ @@ -34,6 +34,19 @@ import scala.util.Try object CommonCodecs { + /** + * Discriminator codec with a default fallback codec (of the same type). + */ + def discriminatorWithDefault[A](discriminator: Codec[A], fallback: Codec[A]): Codec[A] = new Codec[A] { + def sizeBound: SizeBound = discriminator.sizeBound | fallback.sizeBound + + def encode(e: A): Attempt[BitVector] = discriminator.encode(e).recoverWith { case _ => fallback.encode(e) } + + def decode(b: BitVector): Attempt[DecodeResult[A]] = discriminator.decode(b).recoverWith { + case _: KnownDiscriminatorType[_]#UnknownDiscriminator => fallback.decode(b) + } + } + // this codec can be safely used for values < 2^63 and will fail otherwise // (for something smarter see https://github.com/yzernik/bitcoin-scodec/blob/master/src/main/scala/io/github/yzernik/bitcoinscodec/structures/UInt64.scala) val uint64overflow: Codec[Long] = int64.narrow(l => if (l >= 0) Attempt.Successful(l) else Attempt.failure(Err(s"overflow for value $l")), l => l) @@ -46,52 +59,23 @@ object CommonCodecs { * We impose a minimal encoding on varint values to ensure that signed hashes can be reproduced easily. * If a value could be encoded with less bytes, it's considered invalid and results in a failed decoding attempt. * - * @param min the minimal value that should be encoded. - * @param attempt the decoding attempt. + * @param codec the integer codec (depends on the value). + * @param min the minimal value that should be encoded. */ - def verifyMinimalEncoding(min: Long, attempt: Attempt[DecodeResult[UInt64]]): Attempt[DecodeResult[UInt64]] = { - attempt match { - case Attempt.Successful(res) if res.value < UInt64(min) => Attempt.Failure(scodec.Err("varint was not minimally encoded")) - case Attempt.Successful(res) => Attempt.Successful(res) - case Attempt.Failure(err) => Attempt.Failure(err) - } - } + def uint64min(codec: Codec[UInt64], min: UInt64): Codec[UInt64] = codec.exmap({ + case i if i < min => Attempt.failure(Err("varint was not minimally encoded")) + case i => Attempt.successful(i) + }, Attempt.successful) // Bitcoin-style varint codec (CompactSize). // See https://bitcoin.org/en/developer-reference#compactsize-unsigned-integers for reference. - val varint = Codec[UInt64]( - (n: UInt64) => - n match { - case i if i < UInt64(0xfd) => - uint8L.encode(i.toBigInt.toInt) - case i if i < UInt64(0xffff) => - for { - a <- uint8L.encode(0xfd) - b <- uint16L.encode(i.toBigInt.toInt) - } yield a ++ b - case i if i < UInt64(0xffffffffL) => - for { - a <- uint8L.encode(0xfe) - b <- uint32L.encode(i.toBigInt.toLong) - } yield a ++ b - case i => - for { - a <- uint8L.encode(0xff) - b <- uint64L.encode(i) - } yield a ++ b - }, - (buf: BitVector) => { - uint8L.decode(buf) match { - case Attempt.Successful(b) => - b.value match { - case 0xff => verifyMinimalEncoding(0x100000000L, uint64L.decode(b.remainder)) - case 0xfe => verifyMinimalEncoding(0x10000L, uint32L.decode(b.remainder).map(b => b.map(UInt64(_)))) - case 0xfd => verifyMinimalEncoding(0xfdL, uint16L.decode(b.remainder).map(b => b.map(UInt64(_)))) - case _ => Attempt.Successful(DecodeResult(UInt64(b.value), b.remainder)) - } - case Attempt.Failure(err) => Attempt.Failure(err) - } - }) + val varint: Codec[UInt64] = discriminatorWithDefault( + discriminated[UInt64].by(uint8L) + .\(0xff) { case i if i >= UInt64(0x100000000L) => i }(uint64min(uint64L, UInt64(0x100000000L))) + .\(0xfe) { case i if i >= UInt64(0x10000) => i }(uint64min(uint32L.xmap(UInt64(_), _.toBigInt.toLong), UInt64(0x10000))) + .\(0xfd) { case i if i >= UInt64(0xfd) => i }(uint64min(uint16L.xmap(UInt64(_), _.toBigInt.toInt), UInt64(0xfd))), + uint8L.xmap(UInt64(_), _.toBigInt.toInt) + ) // This codec can be safely used for values < 2^63 and will fail otherwise. // It is useful in combination with variableSizeBytesLong to encode/decode TLV lengths because those will always be < 2^63. diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvCodecs.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvCodecs.scala index b1651251d9..986f7c1d3a 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvCodecs.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvCodecs.scala @@ -30,10 +30,10 @@ object TlvCodecs { val genericTlv: Codec[GenericTlv] = (("type" | varint) :: variableSizeBytesLong(varintoverflow, bytes)).as[GenericTlv] - def tlvFallback(codec: Codec[Tlv]): Codec[Tlv] = discriminatorFallback(genericTlv, codec).xmap(_ match { + def tlvFallback(codec: Codec[Tlv]): Codec[Tlv] = discriminatorFallback(genericTlv, codec).xmap({ case Left(l) => l case Right(r) => r - }, _ match { + }, { case g: GenericTlv => Left(g) case o => Right(o) })