From 3afcb1f22463e38497667707a4d54da83a84a122 Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Wed, 7 Aug 2019 15:03:05 -0700 Subject: [PATCH 1/9] tlv/varint: add modified bitcoin varint This varint has the same serialization as the varint in btcd and bitcoind, but has different behavior wrt returned errors. In order to ensure the inner loop properly detects cleanly written records, ReadVarInt will not only return EOF if it can't read the first byte, as that means the reader has zero bytes left. It also modifies the API to allow the caller to provided a static byte array, which can be reused across all encoding and decoding and increases performance. --- tlv/varint.go | 109 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 109 insertions(+) create mode 100644 tlv/varint.go diff --git a/tlv/varint.go b/tlv/varint.go new file mode 100644 index 00000000000..3888bfcb47e --- /dev/null +++ b/tlv/varint.go @@ -0,0 +1,109 @@ +package tlv + +import ( + "encoding/binary" + "errors" + "io" +) + +// ErrVarIntNotCanonical signals that the decoded varint was not minimally encoded. +var ErrVarIntNotCanonical = errors.New("decoded varint is not canonical") + +// ReadVarInt reads a variable length integer from r and returns it as a uint64. +func ReadVarInt(r io.Reader, buf *[8]byte) (uint64, error) { + _, err := io.ReadFull(r, buf[:1]) + if err != nil { + return 0, err + } + discriminant := buf[0] + + var rv uint64 + switch { + case discriminant < 0xfd: + rv = uint64(discriminant) + + case discriminant == 0xfd: + _, err := io.ReadFull(r, buf[:2]) + switch { + case err == io.EOF: + return 0, io.ErrUnexpectedEOF + case err != nil: + return 0, err + } + rv = uint64(binary.BigEndian.Uint16(buf[:2])) + + // The encoding is not canonical if the value could have been + // encoded using fewer bytes. + if rv < 0xfd { + return 0, ErrVarIntNotCanonical + } + + case discriminant == 0xfe: + _, err := io.ReadFull(r, buf[:4]) + switch { + case err == io.EOF: + return 0, io.ErrUnexpectedEOF + case err != nil: + return 0, err + } + rv = uint64(binary.BigEndian.Uint32(buf[:4])) + + // The encoding is not canonical if the value could have been + // encoded using fewer bytes. + if rv <= 0xffff { + return 0, ErrVarIntNotCanonical + } + + default: + _, err := io.ReadFull(r, buf[:]) + switch { + case err == io.EOF: + return 0, io.ErrUnexpectedEOF + case err != nil: + return 0, err + } + rv = binary.BigEndian.Uint64(buf[:]) + + // The encoding is not canonical if the value could have been + // encoded using fewer bytes. + if rv <= 0xffffffff { + return 0, ErrVarIntNotCanonical + } + } + + return rv, nil +} + +// WriteVarInt serializes val to w using a variable number of bytes depending +// on its value. +func WriteVarInt(w io.Writer, val uint64, buf *[8]byte) error { + var length int + switch { + case val < 0xfd: + buf[0] = uint8(val) + length = 1 + + case val <= 0xffff: + buf[0] = uint8(0xfd) + binary.BigEndian.PutUint16(buf[1:3], uint16(val)) + length = 3 + + case val <= 0xffffffff: + buf[0] = uint8(0xfe) + binary.BigEndian.PutUint32(buf[1:5], uint32(val)) + length = 5 + + default: + buf[0] = uint8(0xff) + _, err := w.Write(buf[:1]) + if err != nil { + return err + } + + binary.BigEndian.PutUint64(buf[:], uint64(val)) + length = 8 + } + + _, err := w.Write(buf[:length]) + return err +} From 75fcf1cee1d9e01c9cd4732aad59552049355ffa Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Wed, 7 Aug 2019 15:03:18 -0700 Subject: [PATCH 2/9] tlv/varint_test: add tests vectors for custom Read/WriteVarInt --- tlv/varint_test.go | 217 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 217 insertions(+) create mode 100644 tlv/varint_test.go diff --git a/tlv/varint_test.go b/tlv/varint_test.go new file mode 100644 index 00000000000..75cc1085f74 --- /dev/null +++ b/tlv/varint_test.go @@ -0,0 +1,217 @@ +package tlv_test + +import ( + "bytes" + "io" + "math" + "testing" + + "github.com/lightningnetwork/lnd/tlv" +) + +type varIntTest struct { + Name string + Value uint64 + Bytes []byte + ExpErr error +} + +var writeVarIntTests = []varIntTest{ + { + Name: "zero", + Value: 0x00, + Bytes: []byte{0x00}, + }, + { + Name: "one byte high", + Value: 0xfc, + Bytes: []byte{0xfc}, + }, + { + Name: "two byte low", + Value: 0xfd, + Bytes: []byte{0xfd, 0x00, 0xfd}, + }, + { + Name: "two byte high", + Value: 0xffff, + Bytes: []byte{0xfd, 0xff, 0xff}, + }, + { + Name: "four byte low", + Value: 0x10000, + Bytes: []byte{0xfe, 0x00, 0x01, 0x00, 0x00}, + }, + { + Name: "four byte high", + Value: 0xffffffff, + Bytes: []byte{0xfe, 0xff, 0xff, 0xff, 0xff}, + }, + { + Name: "eight byte low", + Value: 0x100000000, + Bytes: []byte{0xff, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00}, + }, + { + Name: "eight byte high", + Value: math.MaxUint64, + Bytes: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, + }, +} + +// TestWriteVarInt asserts the behavior of tlv.WriteVarInt under various +// positive and negative test cases. +func TestWriteVarInt(t *testing.T) { + for _, test := range writeVarIntTests { + t.Run(test.Name, func(t *testing.T) { + testWriteVarInt(t, test) + }) + } +} + +func testWriteVarInt(t *testing.T, test varIntTest) { + var ( + w bytes.Buffer + buf [8]byte + ) + err := tlv.WriteVarInt(&w, test.Value, &buf) + if err != nil { + t.Fatalf("unable to encode %d as varint: %v", + test.Value, err) + } + + if bytes.Compare(w.Bytes(), test.Bytes) != 0 { + t.Fatalf("expected bytes: %v, got %v", + test.Bytes, w.Bytes()) + } +} + +var readVarIntTests = []varIntTest{ + { + Name: "zero", + Value: 0x00, + Bytes: []byte{0x00}, + }, + { + Name: "one byte high", + Value: 0xfc, + Bytes: []byte{0xfc}, + }, + { + Name: "two byte low", + Value: 0xfd, + Bytes: []byte{0xfd, 0x00, 0xfd}, + }, + { + Name: "two byte high", + Value: 0xffff, + Bytes: []byte{0xfd, 0xff, 0xff}, + }, + { + Name: "four byte low", + Value: 0x10000, + Bytes: []byte{0xfe, 0x00, 0x01, 0x00, 0x00}, + }, + { + Name: "four byte high", + Value: 0xffffffff, + Bytes: []byte{0xfe, 0xff, 0xff, 0xff, 0xff}, + }, + { + Name: "eight byte low", + Value: 0x100000000, + Bytes: []byte{0xff, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00}, + }, + { + Name: "eight byte high", + Value: math.MaxUint64, + Bytes: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, + }, + { + Name: "two byte not canonical", + Bytes: []byte{0xfd, 0x00, 0xfc}, + ExpErr: tlv.ErrVarIntNotCanonical, + }, + { + Name: "four byte not canonical", + Bytes: []byte{0xfe, 0x00, 0x00, 0xff, 0xff}, + ExpErr: tlv.ErrVarIntNotCanonical, + }, + { + Name: "eight byte not canonical", + Bytes: []byte{0xff, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff}, + ExpErr: tlv.ErrVarIntNotCanonical, + }, + { + Name: "two byte short read", + Bytes: []byte{0xfd, 0x00}, + ExpErr: io.ErrUnexpectedEOF, + }, + { + Name: "four byte short read", + Bytes: []byte{0xfe, 0xff, 0xff}, + ExpErr: io.ErrUnexpectedEOF, + }, + { + Name: "eight byte short read", + Bytes: []byte{0xff, 0xff, 0xff, 0xff, 0xff}, + ExpErr: io.ErrUnexpectedEOF, + }, + { + Name: "one byte no read", + Bytes: []byte{}, + ExpErr: io.EOF, + }, + // The following cases are the reason for needing to make a custom + // version of the varint for the tlv package. For the varint encodings + // in btcd's wire package these would return io.EOF, since it is + // actually a composite of two calls to io.ReadFull. In TLV, we need to + // be able to distinguish whether no bytes were read at all from no + // Bytes being read on the second read as the latter is not a proper TLV + // stream. We handle this by returning io.ErrUnexpectedEOF if we + // encounter io.EOF on any of these secondary reads for larger values. + { + Name: "two byte no read", + Bytes: []byte{0xfd}, + ExpErr: io.ErrUnexpectedEOF, + }, + { + Name: "four byte no read", + Bytes: []byte{0xfe}, + ExpErr: io.ErrUnexpectedEOF, + }, + { + Name: "eight byte no read", + Bytes: []byte{0xff}, + ExpErr: io.ErrUnexpectedEOF, + }, +} + +// TestReadVarInt asserts the behavior of tlv.ReadVarInt under various positive +// and negative test cases. +func TestReadVarInt(t *testing.T) { + for _, test := range readVarIntTests { + t.Run(test.Name, func(t *testing.T) { + testReadVarInt(t, test) + }) + } +} + +func testReadVarInt(t *testing.T, test varIntTest) { + var buf [8]byte + r := bytes.NewReader(test.Bytes) + val, err := tlv.ReadVarInt(r, &buf) + if err != nil && err != test.ExpErr { + t.Fatalf("expected decoding error: %v, got: %v", + test.ExpErr, err) + } + + // If we expected a decoding error, there's no point checking the value. + if test.ExpErr != nil { + return + } + + if val != test.Value { + t.Fatalf("expected value: %d, got %d", test.Value, val) + } +} From 6773d4770afce880173433c4b310fa52e3e8b615 Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Wed, 7 Aug 2019 15:03:30 -0700 Subject: [PATCH 3/9] tlv/primitive: add primitive encodings --- tlv/primitive.go | 273 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 273 insertions(+) create mode 100644 tlv/primitive.go diff --git a/tlv/primitive.go b/tlv/primitive.go new file mode 100644 index 00000000000..46ccfde9e68 --- /dev/null +++ b/tlv/primitive.go @@ -0,0 +1,273 @@ +package tlv + +import ( + "encoding/binary" + "fmt" + "io" + + "github.com/btcsuite/btcd/btcec" +) + +// ErrTypeForEncoding signals that an incorrect type was passed to an Encoder. +type ErrTypeForEncoding struct { + val interface{} + expType string +} + +// NewTypeForEncodingErr creates a new ErrTypeForEncoding given the incorrect +// val and the expected type. +func NewTypeForEncodingErr(val interface{}, expType string) ErrTypeForEncoding { + return ErrTypeForEncoding{ + val: val, + expType: expType, + } +} + +// Error returns a human-readable description of the type mismatch. +func (e ErrTypeForEncoding) Error() string { + return fmt.Sprintf("ErrTypeForEncoding want (type: *%s), "+ + "got (type: %T)", e.expType, e.val) +} + +// ErrTypeForDecoding signals that an incorrect type was passed to a Decoder or +// that the expected length of the encoding is different from that required by +// the expected type. +type ErrTypeForDecoding struct { + val interface{} + expType string + valLength uint64 + expLength uint64 +} + +// NewTypeForDecodingErr creates a new ErrTypeForDecoding given the incorrect +// val and expected type, or the mismatch in their expected lengths. +func NewTypeForDecodingErr(val interface{}, expType string, + valLength, expLength uint64) ErrTypeForDecoding { + + return ErrTypeForDecoding{ + val: val, + expType: expType, + valLength: valLength, + expLength: expLength, + } +} + +// Error returns a human-readable description of the type mismatch. +func (e ErrTypeForDecoding) Error() string { + return fmt.Sprintf("ErrTypeForDecoding want (type: *%s, length: %v), "+ + "got (type: %T, length: %v)", e.expType, e.expLength, e.val, + e.valLength) +} + +var ( + byteOrder = binary.BigEndian +) + +// EUint8 is an Encoder for uint8 values. An error is returned if val is not a +// *uint8. +func EUint8(w io.Writer, val interface{}, buf *[8]byte) error { + if i, ok := val.(*uint8); ok { + buf[0] = *i + _, err := w.Write(buf[:1]) + return err + } + return ErrTypeForEncoding{val, "uint8"} +} + +// EUint16 is an Encoder for uint16 values. An error is returned if val is not a +// *uint16. +func EUint16(w io.Writer, val interface{}, buf *[8]byte) error { + if i, ok := val.(*uint16); ok { + byteOrder.PutUint16(buf[:2], *i) + _, err := w.Write(buf[:2]) + return err + } + return ErrTypeForEncoding{val, "uint16"} +} + +// EUint32 is an Encoder for uint32 values. An error is returned if val is not a +// *uint32. +func EUint32(w io.Writer, val interface{}, buf *[8]byte) error { + if i, ok := val.(*uint32); ok { + byteOrder.PutUint32(buf[:4], *i) + _, err := w.Write(buf[:4]) + return err + } + return ErrTypeForEncoding{val, "uint32"} +} + +// EUint64 is an Encoder for uint64 values. An error is returned if val is not a +// *uint64. +func EUint64(w io.Writer, val interface{}, buf *[8]byte) error { + if i, ok := val.(*uint64); ok { + byteOrder.PutUint64(buf[:], *i) + _, err := w.Write(buf[:]) + return err + } + return ErrTypeForEncoding{val, "uint64"} +} + +// DUint8 is a Decoder for uint8 values. An error is returned if val is not a +// *uint8. +func DUint8(r io.Reader, val interface{}, buf *[8]byte, l uint64) error { + if i, ok := val.(*uint8); ok && l == 1 { + if _, err := io.ReadFull(r, buf[:1]); err != nil { + return err + } + *i = buf[0] + return nil + } + return ErrTypeForDecoding{val, "uint8", l, 1} +} + +// DUint16 is a Decoder for uint16 values. An error is returned if val is not a +// *uint16. +func DUint16(r io.Reader, val interface{}, buf *[8]byte, l uint64) error { + if i, ok := val.(*uint16); ok && l == 2 { + if _, err := io.ReadFull(r, buf[:2]); err != nil { + return err + } + *i = byteOrder.Uint16(buf[:2]) + return nil + } + return ErrTypeForDecoding{val, "uint16", l, 2} +} + +// DUint32 is a Decoder for uint32 values. An error is returned if val is not a +// *uint32. +func DUint32(r io.Reader, val interface{}, buf *[8]byte, l uint64) error { + if i, ok := val.(*uint32); ok && l == 4 { + if _, err := io.ReadFull(r, buf[:4]); err != nil { + return err + } + *i = byteOrder.Uint32(buf[:4]) + return nil + } + return ErrTypeForDecoding{val, "uint32", l, 4} +} + +// DUint64 is a Decoder for uint64 values. An error is returned if val is not a +// *uint64. +func DUint64(r io.Reader, val interface{}, buf *[8]byte, l uint64) error { + if i, ok := val.(*uint64); ok && l == 8 { + if _, err := io.ReadFull(r, buf[:]); err != nil { + return err + } + *i = byteOrder.Uint64(buf[:]) + return nil + } + return ErrTypeForDecoding{val, "uint64", l, 8} +} + +// EBytes32 is an Encoder for 32-byte arrays. An error is returned if val is not +// a *[32]byte. +func EBytes32(w io.Writer, val interface{}, _ *[8]byte) error { + if b, ok := val.(*[32]byte); ok { + _, err := w.Write(b[:]) + return err + } + return ErrTypeForEncoding{val, "[32]byte"} +} + +// DBytes32 is a Decoder for 32-byte arrays. An error is returned if val is not +// a *[32]byte. +func DBytes32(r io.Reader, val interface{}, _ *[8]byte, l uint64) error { + if b, ok := val.(*[32]byte); ok && l == 32 { + _, err := io.ReadFull(r, b[:]) + return err + } + return ErrTypeForDecoding{val, "[32]byte", l, 32} +} + +// EBytes33 is an Encoder for 33-byte arrays. An error is returned if val is not +// a *[33]byte. +func EBytes33(w io.Writer, val interface{}, _ *[8]byte) error { + if b, ok := val.(*[33]byte); ok { + _, err := w.Write(b[:]) + return err + } + return ErrTypeForEncoding{val, "[33]byte"} +} + +// DBytes33 is a Decoder for 33-byte arrays. An error is returned if val is not +// a *[33]byte. +func DBytes33(r io.Reader, val interface{}, _ *[8]byte, l uint64) error { + if b, ok := val.(*[33]byte); ok { + _, err := io.ReadFull(r, b[:]) + return err + } + return ErrTypeForDecoding{val, "[33]byte", l, 33} +} + +// EBytes64 is an Encoder for 64-byte arrays. An error is returned if val is not +// a *[64]byte. +func EBytes64(w io.Writer, val interface{}, _ *[8]byte) error { + if b, ok := val.(*[64]byte); ok { + _, err := w.Write(b[:]) + return err + } + return ErrTypeForEncoding{val, "[64]byte"} +} + +// DBytes64 is an Decoder for 64-byte arrays. An error is returned if val is not +// a *[64]byte. +func DBytes64(r io.Reader, val interface{}, _ *[8]byte, l uint64) error { + if b, ok := val.(*[64]byte); ok && l == 64 { + _, err := io.ReadFull(r, b[:]) + return err + } + return ErrTypeForDecoding{val, "[64]byte", l, 64} +} + +// EPubKey is an Encoder for *btcec.PublicKey values. An error is returned if +// val is not a **btcec.PublicKey. +func EPubKey(w io.Writer, val interface{}, _ *[8]byte) error { + if pk, ok := val.(**btcec.PublicKey); ok { + _, err := w.Write((*pk).SerializeCompressed()) + return err + } + return ErrTypeForEncoding{val, "*btcec.PublicKey"} +} + +// DPubKey is a Decoder for *btcec.PublicKey values. An error is returned if val +// is not a **btcec.PublicKey. +func DPubKey(r io.Reader, val interface{}, _ *[8]byte, l uint64) error { + if pk, ok := val.(**btcec.PublicKey); ok && l == 33 { + var b [33]byte + _, err := io.ReadFull(r, b[:]) + if err != nil { + return err + } + + p, err := btcec.ParsePubKey(b[:], btcec.S256()) + if err != nil { + return err + } + + *pk = p + + return nil + } + return ErrTypeForDecoding{val, "*btcec.PublicKey", l, 33} +} + +// EVarBytes is an Encoder for variable byte slices. An error is returned if val +// is not *[]byte. +func EVarBytes(w io.Writer, val interface{}, _ *[8]byte) error { + if b, ok := val.(*[]byte); ok { + _, err := w.Write(*b) + return err + } + return ErrTypeForEncoding{val, "[]byte"} +} + +// DVarBytes is a Decoder for variable byte slices. An error is returned if val +// is not *[]byte. +func DVarBytes(r io.Reader, val interface{}, _ *[8]byte, l uint64) error { + if b, ok := val.(*[]byte); ok { + *b = make([]byte, l) + _, err := io.ReadFull(r, *b) + return err + } + return ErrTypeForDecoding{val, "[]byte", l, l} +} From 96e0bb1411207bafda51ba777141fb2fc75b119d Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Wed, 7 Aug 2019 15:03:43 -0700 Subject: [PATCH 4/9] tlv/record: adds various tlv record constructors --- tlv/record.go | 153 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 153 insertions(+) create mode 100644 tlv/record.go diff --git a/tlv/record.go b/tlv/record.go new file mode 100644 index 00000000000..ae21c050a47 --- /dev/null +++ b/tlv/record.go @@ -0,0 +1,153 @@ +package tlv + +import ( + "io" + + "github.com/btcsuite/btcd/btcec" +) + +// Type is an 64-bit identifier for a TLV Record. +type Type uint64 + +// Encoder is a signature for methods that can encode TLV values. An error +// should be returned if the Encoder cannot support the underlying type of val. +// The provided scratch buffer must be non-nil. +type Encoder func(w io.Writer, val interface{}, buf *[8]byte) error + +// Decoder is a signature for methods that can decode TLV values. An error +// should be returned if the Decoder cannot support the underlying type of val. +// The provided scratch buffer must be non-nil. +type Decoder func(r io.Reader, val interface{}, buf *[8]byte, l uint64) error + +// ENOP is an encoder that doesn't modify the io.Writer and never fails. +func ENOP(io.Writer, interface{}, *[8]byte) error { return nil } + +// DNOP is an encoder that doesn't modify the io.Reader and never fails. +func DNOP(io.Reader, interface{}, *[8]byte, uint64) error { return nil } + +// SizeFunc is a function that can compute the length of a given field. Since +// the size of the underlying field can change, this allows the size of the +// field to be evaluated at the time of encoding. +type SizeFunc func() uint64 + +// SizeVarBytes returns a SizeFunc that can compute the length of a byte slice. +func SizeVarBytes(e *[]byte) SizeFunc { + return func() uint64 { + return uint64(len(*e)) + } +} + +// Record holds the required information to encode or decode a TLV record. +type Record struct { + value interface{} + typ Type + staticSize uint64 + sizeFunc SizeFunc + encoder Encoder + decoder Decoder +} + +// Size returns the size of the Record's value. If no static size is known, the +// dynamic size will be evaluated. +func (f *Record) Size() uint64 { + if f.sizeFunc == nil { + return f.staticSize + } + + return f.sizeFunc() +} + +// MakePrimitiveRecord creates a record for common types. +func MakePrimitiveRecord(typ Type, val interface{}) Record { + var ( + staticSize uint64 + sizeFunc SizeFunc + encoder Encoder + decoder Decoder + ) + switch e := val.(type) { + case *uint8: + staticSize = 1 + encoder = EUint8 + decoder = DUint8 + + case *uint16: + staticSize = 2 + encoder = EUint16 + decoder = DUint16 + + case *uint32: + staticSize = 4 + encoder = EUint32 + decoder = DUint32 + + case *uint64: + staticSize = 8 + encoder = EUint64 + decoder = DUint64 + + case *[32]byte: + staticSize = 32 + encoder = EBytes32 + decoder = DBytes32 + + case *[33]byte: + staticSize = 33 + encoder = EBytes33 + decoder = DBytes33 + + case **btcec.PublicKey: + staticSize = 33 + encoder = EPubKey + decoder = DPubKey + + case *[64]byte: + staticSize = 64 + encoder = EBytes64 + decoder = DBytes64 + + case *[]byte: + sizeFunc = SizeVarBytes(e) + encoder = EVarBytes + decoder = DVarBytes + + default: + panic("unknown primitive type") + } + + return Record{ + value: val, + typ: typ, + staticSize: staticSize, + sizeFunc: sizeFunc, + encoder: encoder, + decoder: decoder, + } +} + +// MakeStaticRecord creates a record for a field of fixed-size +func MakeStaticRecord(typ Type, val interface{}, size uint64, encoder Encoder, + decoder Decoder) Record { + + return Record{ + value: val, + typ: typ, + staticSize: size, + encoder: encoder, + decoder: decoder, + } +} + +// MakeDynamicRecord creates a record whose size may vary, and will be +// determined at the time of encoding via sizeFunc. +func MakeDynamicRecord(typ Type, val interface{}, sizeFunc SizeFunc, + encoder Encoder, decoder Decoder) Record { + + return Record{ + value: val, + typ: typ, + sizeFunc: sizeFunc, + encoder: encoder, + decoder: decoder, + } +} From bc1f23d98a38d08f7ff409738b984bd3fdf5eae2 Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Wed, 7 Aug 2019 15:03:56 -0700 Subject: [PATCH 5/9] tlv/stream: adds tlv stream encoding/decoding --- tlv/stream.go | 280 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 280 insertions(+) create mode 100644 tlv/stream.go diff --git a/tlv/stream.go b/tlv/stream.go new file mode 100644 index 00000000000..49bb70edc10 --- /dev/null +++ b/tlv/stream.go @@ -0,0 +1,280 @@ +package tlv + +import ( + "errors" + "fmt" + "io" + "io/ioutil" + "math" +) + +// ErrStreamNotCanonical signals that a decoded stream does not contain records +// sorting by monotonically-increasing type. +var ErrStreamNotCanonical = errors.New("tlv stream is not canonical") + +// ErrUnknownRequiredType is an error returned when decoding an unknown and even +// type from a Stream. +type ErrUnknownRequiredType Type + +// Error returns a human-readable description of unknown required type. +func (t ErrUnknownRequiredType) Error() string { + return fmt.Sprintf("unknown required type: %d", t) +} + +// Stream defines a TLV stream that can be used for encoding or decoding a set +// of TLV Records. +type Stream struct { + records []Record + buf [8]byte +} + +// NewStream creates a new TLV Stream given an encoding codec, a decoding codec, +// and a set of known records. +func NewStream(records ...Record) (*Stream, error) { + // Assert that the ordering of the Records is canonical and appear in + // ascending order of type. + var ( + min Type + overflow bool + ) + for _, record := range records { + if overflow || record.typ < min { + return nil, ErrStreamNotCanonical + } + if record.encoder == nil { + record.encoder = ENOP + } + if record.decoder == nil { + record.decoder = DNOP + } + if record.typ == math.MaxUint64 { + overflow = true + } + min = record.typ + 1 + } + + return &Stream{ + records: records, + }, nil +} + +// MustNewStream creates a new TLV Stream given an encoding codec, a decoding +// codec, and a set of known records. If an error is encountered in creating the +// stream, this method will panic instead of returning the error. +func MustNewStream(records ...Record) *Stream { + stream, err := NewStream(records...) + if err != nil { + panic(err.Error()) + } + return stream +} + +// Encode writes a Stream to the passed io.Writer. Each of the Records known to +// the Stream is written in ascending order of their type so as to be canonical. +// +// The stream is constructed by concatenating the individual, serialized Records +// where each record has the following format: +// [varint: type] +// [varint: length] +// [length: value] +// +// An error is returned if the io.Writer fails to accept bytes from the +// encoding, and nothing else. The ordering of the Records is asserted upon the +// creation of a Stream, and thus the output will be by definition canonical. +func (s *Stream) Encode(w io.Writer) error { + // Iterate through all known records, if any, serializing each record's + // type, length and value. + for i := range s.records { + rec := &s.records[i] + + // Write the record's type as a varint. + err := WriteVarInt(w, uint64(rec.typ), &s.buf) + if err != nil { + return err + } + + // Write the record's length as a varint. + err = WriteVarInt(w, rec.Size(), &s.buf) + if err != nil { + return err + } + + // Encode the current record's value using the stream's codec. + err = rec.encoder(w, rec.value, &s.buf) + if err != nil { + return err + } + } + + return nil +} + +// Decode deserializes TLV Stream from the passed io.Reader. The Stream will +// inspect each record that is parsed and check to see if it has a corresponding +// Record to facilitate deserialization of that field. If the record is unknown, +// the Stream will discard the record's bytes and proceed to the subsequent +// record. +// +// Each record has the following format: +// [varint: type] +// [varint: length] +// [length: value] +// +// A series of (possibly zero) records are concatenated into a stream, this +// example contains two records: +// +// (t: 0x01, l: 0x04, v: 0xff, 0xff, 0xff, 0xff) +// (t: 0x02, l: 0x01, v: 0x01) +// +// This method asserts that the byte stream is canonical, namely that each +// record is unique and that all records are sorted in ascending order. An +// ErrNotCanonicalStream error is returned if the encoded TLV stream is not. +// +// We permit an io.EOF error only when reading the type byte which signals that +// the last record was read cleanly and we should stop parsing. All other io.EOF +// or io.ErrUnexpectedEOF errors are returned. +func (s *Stream) Decode(r io.Reader) error { + var ( + typ Type + min Type + recordIdx int + overflow bool + ) + + // Iterate through all possible type identifiers. As types are read from + // the io.Reader, min will skip forward to the last read type. + for { + // Read the next varint type. + t, err := ReadVarInt(r, &s.buf) + switch { + + // We'll silence an EOF when zero bytes remain, meaning the + // stream was cleanly encoded. + case err == io.EOF: + return nil + + // Other unexpected errors. + case err != nil: + return err + } + + typ = Type(t) + + // Assert that this type is greater than any previously read. + // If we've already overflowed and we parsed another type, the + // stream is not canonical. This check prevents us from accepts + // encodings that have duplicate records or from accepting an + // unsorted series. + if overflow || typ < min { + return ErrStreamNotCanonical + } + + // Read the varint length. + length, err := ReadVarInt(r, &s.buf) + switch { + + // We'll convert any EOFs to ErrUnexpectedEOF, since this + // results in an invalid record. + case err == io.EOF: + return io.ErrUnexpectedEOF + + // Other unexpected errors. + case err != nil: + return err + } + + // Search the records known to the stream for this type. We'll + // begin the search and recordIdx and walk forward until we find + // it or the next record's type is larger. + rec, newIdx, ok := s.getRecord(typ, recordIdx) + switch { + + // We know of this record type, proceed to decode the value. + // This method asserts that length bytes are read in the + // process, and returns an error if the number of bytes is not + // exactly length. + case ok: + err := rec.decoder(r, rec.value, &s.buf, length) + switch { + + // We'll convert any EOFs to ErrUnexpectedEOF, since this + // results in an invalid record. + case err == io.EOF: + return io.ErrUnexpectedEOF + + // Other unexpected errors. + case err != nil: + return err + } + + // This record type is unknown to the stream, fail if the type + // is even meaning that we are required to understand it. + case typ%2 == 0: + return ErrUnknownRequiredType(typ) + + // Otherwise, the record type is unknown and is odd, discard the + // number of bytes specified by length. + default: + _, err := io.CopyN(ioutil.Discard, r, int64(length)) + switch { + + // We'll convert any EOFs to ErrUnexpectedEOF, since this + // results in an invalid record. + case err == io.EOF: + return io.ErrUnexpectedEOF + + // Other unexpected errors. + case err != nil: + return err + } + } + + // Update our record index so that we can begin our next search + // from where we left off. + recordIdx = newIdx + + // If we've parsed the largest possible type, the next loop will + // overflow back to zero. However, we need to attempt parsing + // the next type to ensure that the stream is empty. + if typ == math.MaxUint64 { + overflow = true + } + + // Finally, set our lower bound on the next accepted type. + min = typ + 1 + } +} + +// getRecord searches for a record matching typ known to the stream. The boolean +// return value indicates whether the record is known to the stream. The integer +// return value carries the index from where getRecord should be invoked on the +// subsequent call. The first call to getRecord should always use an idx of 0. +func (s *Stream) getRecord(typ Type, idx int) (Record, int, bool) { + for idx < len(s.records) { + record := s.records[idx] + switch { + + // Found target record, return it to the caller. The next index + // returned points to the immediately following record. + case record.typ == typ: + return record, idx + 1, true + + // This record's type is lower than the target. Advance our + // index and continue to the next record which will have a + // strictly higher type. + case record.typ < typ: + idx++ + continue + + // This record's type is larger than the target, hence we have + // no record matching the current type. Return the current index + // so that we can start our search from here when processing the + // next tlv record. + default: + return Record{}, idx, false + } + } + + // All known records are exhausted. + return Record{}, idx, false +} From a0ebaeaa6c2544f98da01eda51339c1754cf6155 Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Wed, 7 Aug 2019 15:04:08 -0700 Subject: [PATCH 6/9] tlv: zero alloc encoding for extended types This commit adds concrete encoding methods for primitive integral types. When external libs need to create custom encoders, this allows them to do so without incurring an extra allocation on the heap. Previously, the need to pass a pointer to the integer using an interface{} would cause the argument to escape, which we avoid by having them copied directly. --- tlv/primitive.go | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/tlv/primitive.go b/tlv/primitive.go index 46ccfde9e68..3d1f0a067a2 100644 --- a/tlv/primitive.go +++ b/tlv/primitive.go @@ -74,6 +74,15 @@ func EUint8(w io.Writer, val interface{}, buf *[8]byte) error { return ErrTypeForEncoding{val, "uint8"} } +// EUint8T encodes a uint8 val to the provided io.Writer. This method is exposed +// so that encodings for custom uint8-like types can be created without +// incurring an extra heap allocation. +func EUint8T(w io.Writer, val uint8, buf *[8]byte) error { + buf[0] = val + _, err := w.Write(buf[:1]) + return err +} + // EUint16 is an Encoder for uint16 values. An error is returned if val is not a // *uint16. func EUint16(w io.Writer, val interface{}, buf *[8]byte) error { @@ -85,6 +94,15 @@ func EUint16(w io.Writer, val interface{}, buf *[8]byte) error { return ErrTypeForEncoding{val, "uint16"} } +// EUint16T encodes a uint16 val to the provided io.Writer. This method is +// exposed so that encodings for custom uint16-like types can be created without +// incurring an extra heap allocation. +func EUint16T(w io.Writer, val uint16, buf *[8]byte) error { + byteOrder.PutUint16(buf[:2], val) + _, err := w.Write(buf[:2]) + return err +} + // EUint32 is an Encoder for uint32 values. An error is returned if val is not a // *uint32. func EUint32(w io.Writer, val interface{}, buf *[8]byte) error { @@ -96,6 +114,15 @@ func EUint32(w io.Writer, val interface{}, buf *[8]byte) error { return ErrTypeForEncoding{val, "uint32"} } +// EUint32T encodes a uint32 val to the provided io.Writer. This method is +// exposed so that encodings for custom uint32-like types can be created without +// incurring an extra heap allocation. +func EUint32T(w io.Writer, val uint32, buf *[8]byte) error { + byteOrder.PutUint32(buf[:4], val) + _, err := w.Write(buf[:4]) + return err +} + // EUint64 is an Encoder for uint64 values. An error is returned if val is not a // *uint64. func EUint64(w io.Writer, val interface{}, buf *[8]byte) error { @@ -107,6 +134,15 @@ func EUint64(w io.Writer, val interface{}, buf *[8]byte) error { return ErrTypeForEncoding{val, "uint64"} } +// EUint64T encodes a uint64 val to the provided io.Writer. This method is +// exposed so that encodings for custom uint64-like types can be created without +// incurring an extra heap allocation. +func EUint64T(w io.Writer, val uint64, buf *[8]byte) error { + byteOrder.PutUint64(buf[:], val) + _, err := w.Write(buf[:]) + return err +} + // DUint8 is a Decoder for uint8 values. An error is returned if val is not a // *uint8. func DUint8(r io.Reader, val interface{}, buf *[8]byte, l uint64) error { From abdcd47dcc240d9a52af67204750c450b74134bf Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Wed, 7 Aug 2019 15:04:20 -0700 Subject: [PATCH 7/9] tlv/bench_test: add basic benchmark --- tlv/bench_test.go | 161 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 161 insertions(+) create mode 100644 tlv/bench_test.go diff --git a/tlv/bench_test.go b/tlv/bench_test.go new file mode 100644 index 00000000000..f71a7eb61ef --- /dev/null +++ b/tlv/bench_test.go @@ -0,0 +1,161 @@ +package tlv_test + +import ( + "bytes" + "io" + "io/ioutil" + "testing" + + "github.com/lightningnetwork/lnd/lnwallet" + "github.com/lightningnetwork/lnd/tlv" + "github.com/lightningnetwork/lnd/watchtower/blob" + "github.com/lightningnetwork/lnd/watchtower/wtwire" +) + +// CreateSessionTLV mirrors the wtwire.CreateSession message, but uses TLV for +// encoding/decoding. +type CreateSessionTLV struct { + BlobType blob.Type + MaxUpdates uint16 + RewardBase uint32 + RewardRate uint32 + SweepFeeRate lnwallet.SatPerKWeight + + tlvStream *tlv.Stream +} + +// EBlobType is an encoder for blob.Type. +func EBlobType(w io.Writer, val interface{}, buf *[8]byte) error { + if t, ok := val.(*blob.Type); ok { + return tlv.EUint16T(w, uint16(*t), buf) + } + return tlv.NewTypeForEncodingErr(val, "blob.Type") +} + +// EBlobType is an decoder for blob.Type. +func DBlobType(r io.Reader, val interface{}, buf *[8]byte, l uint64) error { + if typ, ok := val.(*blob.Type); ok { + var t uint16 + err := tlv.DUint16(r, &t, buf, l) + if err != nil { + return err + } + *typ = blob.Type(t) + return nil + } + return tlv.NewTypeForDecodingErr(val, "blob.Type", l, 2) +} + +// ESatPerKW is an encoder for lnwallet.SatPerKWeight. +func ESatPerKW(w io.Writer, val interface{}, buf *[8]byte) error { + if v, ok := val.(*lnwallet.SatPerKWeight); ok { + return tlv.EUint64(w, uint64(*v), buf) + } + return tlv.NewTypeForEncodingErr(val, "lnwallet.SatPerKWeight") +} + +// DSatPerKW is an decoder for lnwallet.SatPerKWeight. +func DSatPerKW(r io.Reader, val interface{}, buf *[8]byte, l uint64) error { + if v, ok := val.(*lnwallet.SatPerKWeight); ok { + var sat uint64 + err := tlv.DUint64(r, &sat, buf, l) + if err != nil { + return err + } + *v = lnwallet.SatPerKWeight(sat) + return nil + } + return tlv.NewTypeForDecodingErr(val, "lnwallet.SatPerKWeight", l, 8) +} + +// NewCreateSessionTLV initializes a new CreateSessionTLV message. +func NewCreateSessionTLV() *CreateSessionTLV { + m := &CreateSessionTLV{} + m.tlvStream = tlv.MustNewStream( + tlv.MakeStaticRecord(0, &m.BlobType, 2, EBlobType, DBlobType), + tlv.MakePrimitiveRecord(1, &m.MaxUpdates), + tlv.MakePrimitiveRecord(2, &m.RewardBase), + tlv.MakePrimitiveRecord(3, &m.RewardRate), + tlv.MakeStaticRecord(4, &m.SweepFeeRate, 8, ESatPerKW, DSatPerKW), + ) + + return m +} + +// Encode writes the CreateSessionTLV to the passed io.Writer. +func (c *CreateSessionTLV) Encode(w io.Writer) error { + return c.tlvStream.Encode(w) +} + +// Decode reads the CreateSessionTLV from the passed io.Reader. +func (c *CreateSessionTLV) Decode(r io.Reader) error { + return c.tlvStream.Decode(r) +} + +// BenchmarkEncodeCreateSession benchmarks encoding of the non-TLV +// CreateSession. +func BenchmarkEncodeCreateSession(t *testing.B) { + m := &wtwire.CreateSession{} + + t.ReportAllocs() + t.ResetTimer() + + var err error + for i := 0; i < t.N; i++ { + err = m.Encode(ioutil.Discard, 0) + } + _ = err +} + +// BenchmarkEncodeCreateSessionTLV benchmarks encoding of the TLV CreateSession. +func BenchmarkEncodeCreateSessionTLV(t *testing.B) { + m := NewCreateSessionTLV() + + t.ReportAllocs() + t.ResetTimer() + + var err error + for i := 0; i < t.N; i++ { + err = m.Encode(ioutil.Discard) + } + _ = err +} + +// BenchmarkDecodeCreateSession benchmarks encoding of the non-TLV +// CreateSession. +func BenchmarkDecodeCreateSession(t *testing.B) { + m := &wtwire.CreateSession{} + + var b bytes.Buffer + m.Encode(&b, 0) + r := bytes.NewReader(b.Bytes()) + + t.ReportAllocs() + t.ResetTimer() + + var err error + for i := 0; i < t.N; i++ { + r.Seek(0, 0) + err = m.Decode(r, 0) + } + _ = err +} + +// BenchmarkDecodeCreateSessionTLV benchmarks decoding of the TLV CreateSession. +func BenchmarkDecodeCreateSessionTLV(t *testing.B) { + m := NewCreateSessionTLV() + + var b bytes.Buffer + var err error + m.Encode(&b) + r := bytes.NewReader(b.Bytes()) + + t.ReportAllocs() + t.ResetTimer() + + for i := 0; i < t.N; i++ { + r.Seek(0, 0) + err = m.Decode(r) + } + _ = err +} From 7c94bbb4a2f66c1585b30a6055360bf0a9f93cc1 Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Wed, 7 Aug 2019 15:04:33 -0700 Subject: [PATCH 8/9] tlv/truncated: add truncated integer encodings This commit adds the truncated integer encodings used in the variable-size onion payloads. The amount and cltv delta both use the truncated encoding to shave bytes in the overall size, and will likely be used in the future for additional extensions where size is a constraint. --- tlv/truncated.go | 180 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 180 insertions(+) create mode 100644 tlv/truncated.go diff --git a/tlv/truncated.go b/tlv/truncated.go new file mode 100644 index 00000000000..b35fc166a83 --- /dev/null +++ b/tlv/truncated.go @@ -0,0 +1,180 @@ +package tlv + +import ( + "encoding/binary" + "errors" + "io" +) + +// ErrTUintNotMinimal signals that decoding a truncated uint failed because the +// value was not minimally encoded. +var ErrTUintNotMinimal = errors.New("truncated uint not minimally encoded") + +// numLeadingZeroBytes16 computes the number of leading zeros for a uint16. +func numLeadingZeroBytes16(v uint16) uint64 { + switch { + case v == 0: + return 2 + case v&0xff00 == 0: + return 1 + default: + return 0 + } +} + +// SizeTUint16 returns the number of bytes remaining in a uint16 after +// truncating the leading zeros. +func SizeTUint16(v uint16) uint64 { + return 2 - numLeadingZeroBytes16(v) +} + +// ETUint16 is an Encoder for truncated uint16 values, where leading zeros will +// be omitted. An error is returned if val is not a *uint16. +func ETUint16(w io.Writer, val interface{}, buf *[8]byte) error { + if t, ok := val.(*uint16); ok { + binary.BigEndian.PutUint16(buf[:2], *t) + numZeros := numLeadingZeroBytes16(*t) + _, err := w.Write(buf[numZeros:2]) + return err + } + return NewTypeForEncodingErr(val, "uint16") +} + +// DTUint16 is an Decoder for truncated uint16 values, where leading zeros will +// be resurrected. An error is returned if val is not a *uint16. +func DTUint16(r io.Reader, val interface{}, buf *[8]byte, l uint64) error { + if t, ok := val.(*uint16); ok && l <= 2 { + _, err := io.ReadFull(r, buf[2-l:]) + if err != nil { + return err + } + zero(buf[:2-l]) + *t = binary.BigEndian.Uint16(buf[:2]) + if 2-numLeadingZeroBytes16(*t) != l { + return ErrTUintNotMinimal + } + return nil + } + return NewTypeForDecodingErr(val, "uint16", l, 2) +} + +// numLeadingZeroBytes16 computes the number of leading zeros for a uint32. +func numLeadingZeroBytes32(v uint32) uint64 { + switch { + case v == 0: + return 4 + case v&0xffffff00 == 0: + return 3 + case v&0xffff0000 == 0: + return 2 + case v&0xff000000 == 0: + return 1 + default: + return 0 + } +} + +// SizeTUint32 returns the number of bytes remaining in a uint32 after +// truncating the leading zeros. +func SizeTUint32(v uint32) uint64 { + return 4 - numLeadingZeroBytes32(v) +} + +// ETUint32 is an Encoder for truncated uint32 values, where leading zeros will +// be omitted. An error is returned if val is not a *uint32. +func ETUint32(w io.Writer, val interface{}, buf *[8]byte) error { + if t, ok := val.(*uint32); ok { + binary.BigEndian.PutUint32(buf[:4], *t) + numZeros := numLeadingZeroBytes32(*t) + _, err := w.Write(buf[numZeros:4]) + return err + } + return NewTypeForEncodingErr(val, "uint32") +} + +// DTUint32 is an Decoder for truncated uint32 values, where leading zeros will +// be resurrected. An error is returned if val is not a *uint32. +func DTUint32(r io.Reader, val interface{}, buf *[8]byte, l uint64) error { + if t, ok := val.(*uint32); ok && l <= 4 { + _, err := io.ReadFull(r, buf[4-l:]) + if err != nil { + return err + } + zero(buf[:4-l]) + *t = binary.BigEndian.Uint32(buf[:4]) + if 4-numLeadingZeroBytes32(*t) != l { + return ErrTUintNotMinimal + } + return nil + } + return NewTypeForDecodingErr(val, "uint32", l, 4) +} + +// numLeadingZeroBytes64 computes the number of leading zeros for a uint32. +// +// TODO(conner): optimize using unrolled binary search +func numLeadingZeroBytes64(v uint64) uint64 { + switch { + case v == 0: + return 8 + case v&0xffffffffffffff00 == 0: + return 7 + case v&0xffffffffffff0000 == 0: + return 6 + case v&0xffffffffff000000 == 0: + return 5 + case v&0xffffffff00000000 == 0: + return 4 + case v&0xffffff0000000000 == 0: + return 3 + case v&0xffff000000000000 == 0: + return 2 + case v&0xff00000000000000 == 0: + return 1 + default: + return 0 + } +} + +// SizeTUint64 returns the number of bytes remaining in a uint64 after +// truncating the leading zeros. +func SizeTUint64(v uint64) uint64 { + return 8 - numLeadingZeroBytes64(v) +} + +// ETUint64 is an Encoder for truncated uint64 values, where leading zeros will +// be omitted. An error is returned if val is not a *uint64. +func ETUint64(w io.Writer, val interface{}, buf *[8]byte) error { + if t, ok := val.(*uint64); ok { + binary.BigEndian.PutUint64(buf[:], *t) + numZeros := numLeadingZeroBytes64(*t) + _, err := w.Write(buf[numZeros:]) + return err + } + return NewTypeForEncodingErr(val, "uint64") +} + +// DTUint64 is an Decoder for truncated uint64 values, where leading zeros will +// be resurrected. An error is returned if val is not a *uint64. +func DTUint64(r io.Reader, val interface{}, buf *[8]byte, l uint64) error { + if t, ok := val.(*uint64); ok && l <= 8 { + _, err := io.ReadFull(r, buf[8-l:]) + if err != nil { + return err + } + zero(buf[:8-l]) + *t = binary.BigEndian.Uint64(buf[:]) + if 8-numLeadingZeroBytes64(*t) != l { + return ErrTUintNotMinimal + } + return nil + } + return NewTypeForDecodingErr(val, "uint64", l, 8) +} + +// zero clears the passed byte slice. +func zero(b []byte) { + for i := range b { + b[i] = 0x00 + } +} From 36909995117795209b5c9a31556b1d310d01b706 Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Wed, 7 Aug 2019 15:04:45 -0700 Subject: [PATCH 9/9] tlv/tlv_test: add BOLT1 test vectors --- tlv/tlv_test.go | 559 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 559 insertions(+) create mode 100644 tlv/tlv_test.go diff --git a/tlv/tlv_test.go b/tlv/tlv_test.go new file mode 100644 index 00000000000..c3b996fef63 --- /dev/null +++ b/tlv/tlv_test.go @@ -0,0 +1,559 @@ +package tlv_test + +import ( + "bytes" + "errors" + "io" + "reflect" + "testing" + + "github.com/btcsuite/btcd/btcec" + "github.com/lightningnetwork/lnd/tlv" +) + +type nodeAmts struct { + nodeID *btcec.PublicKey + amt1 uint64 + amt2 uint64 +} + +func ENodeAmts(w io.Writer, val interface{}, buf *[8]byte) error { + if t, ok := val.(*nodeAmts); ok { + if err := tlv.EPubKey(w, &t.nodeID, buf); err != nil { + return err + } + if err := tlv.EUint64T(w, t.amt1, buf); err != nil { + return err + } + return tlv.EUint64T(w, t.amt2, buf) + } + return tlv.NewTypeForEncodingErr(val, "nodeAmts") +} + +func DNodeAmts(r io.Reader, val interface{}, buf *[8]byte, l uint64) error { + if t, ok := val.(*nodeAmts); ok && l == 49 { + if err := tlv.DPubKey(r, &t.nodeID, buf, 33); err != nil { + return err + } + if err := tlv.DUint64(r, &t.amt1, buf, 8); err != nil { + return err + } + return tlv.DUint64(r, &t.amt2, buf, 8) + } + return tlv.NewTypeForDecodingErr(val, "nodeAmts", l, 49) +} + +type N1 struct { + amt uint64 + scid uint64 + nodeAmts nodeAmts + cltvDelta uint16 + + stream *tlv.Stream +} + +func (n *N1) sizeAmt() uint64 { + return tlv.SizeTUint64(n.amt) +} + +func NewN1() *N1 { + n := new(N1) + + n.stream = tlv.MustNewStream( + tlv.MakeDynamicRecord( + 1, &n.amt, n.sizeAmt, tlv.ETUint64, tlv.DTUint64, + ), + tlv.MakePrimitiveRecord(2, &n.scid), + tlv.MakeStaticRecord(3, &n.nodeAmts, 49, ENodeAmts, DNodeAmts), + tlv.MakePrimitiveRecord(254, &n.cltvDelta), + ) + + return n +} + +func (n *N1) Encode(w io.Writer) error { + return n.stream.Encode(w) +} + +func (n *N1) Decode(r io.Reader) error { + return n.stream.Decode(r) +} + +type N2 struct { + amt uint64 + cltvExpiry uint32 + + stream *tlv.Stream +} + +func (n *N2) sizeAmt() uint64 { + return tlv.SizeTUint64(n.amt) +} + +func (n *N2) sizeCltv() uint64 { + return tlv.SizeTUint32(n.cltvExpiry) +} + +func NewN2() *N2 { + n := new(N2) + + n.stream = tlv.MustNewStream( + tlv.MakeDynamicRecord( + 0, &n.amt, n.sizeAmt, tlv.ETUint64, tlv.DTUint64, + ), + tlv.MakeDynamicRecord( + 11, &n.cltvExpiry, n.sizeCltv, tlv.ETUint32, tlv.DTUint32, + ), + ) + + return n +} + +func (n *N2) Encode(w io.Writer) error { + return n.stream.Encode(w) +} + +func (n *N2) Decode(r io.Reader) error { + return n.stream.Decode(r) +} + +var tlvDecodingFailureTests = []struct { + name string + bytes []byte + expErr error + + // skipN2 if true, will cause the test to only be executed on N1. + skipN2 bool +}{ + { + name: "type truncated", + bytes: []byte{0xfd}, + expErr: io.ErrUnexpectedEOF, + }, + { + name: "type truncated", + bytes: []byte{0xfd, 0x01}, + expErr: io.ErrUnexpectedEOF, + }, + { + name: "not minimally encoded type", + bytes: []byte{0xfd, 0x00, 0x01}, // spec has trailing 0x00 + expErr: tlv.ErrVarIntNotCanonical, + }, + { + name: "missing length", + bytes: []byte{0xfd, 0x01, 0x01}, + expErr: io.ErrUnexpectedEOF, + }, + { + name: "length truncated", + bytes: []byte{0x0f, 0xfd}, + expErr: io.ErrUnexpectedEOF, + }, + { + name: "length truncated", + bytes: []byte{0x0f, 0xfd, 0x26}, + expErr: io.ErrUnexpectedEOF, + }, + { + name: "missing value", + bytes: []byte{0x0f, 0xfd, 0x26, 0x02}, + expErr: io.ErrUnexpectedEOF, + }, + { + name: "not minimally encoded length", + bytes: []byte{0x0f, 0xfd, 0x00, 0x01}, // spec has trailing 0x00 + expErr: tlv.ErrVarIntNotCanonical, + }, + { + name: "value truncated", + bytes: []byte{0x0f, 0xfd, 0x02, 0x01, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + }, + expErr: io.ErrUnexpectedEOF, + }, + { + name: "unknown even type", + bytes: []byte{0x12, 0x00}, + expErr: tlv.ErrUnknownRequiredType(0x12), + }, + { + name: "unknown even type", + bytes: []byte{0xfd, 0x01, 0x02, 0x00}, + expErr: tlv.ErrUnknownRequiredType(0x102), + }, + { + name: "unknown even type", + bytes: []byte{0xfe, 0x01, 0x00, 0x00, 0x02, 0x00}, + expErr: tlv.ErrUnknownRequiredType(0x01000002), + }, + { + name: "unknown even type", + bytes: []byte{0xff, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00}, + expErr: tlv.ErrUnknownRequiredType(0x0100000000000002), + }, + { + name: "greater than encoding length for n1's amt", + bytes: []byte{0x01, 0x09, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, + expErr: tlv.NewTypeForDecodingErr(new(uint64), "uint64", 9, 8), + skipN2: true, + }, + { + name: "encoding for n1's amt is not minimal", + bytes: []byte{0x01, 0x01, 0x00}, + expErr: tlv.ErrTUintNotMinimal, + skipN2: true, + }, + { + name: "encoding for n1's amt is not minimal", + bytes: []byte{0x01, 0x02, 0x00, 0x01}, + expErr: tlv.ErrTUintNotMinimal, + skipN2: true, + }, + { + name: "encoding for n1's amt is not minimal", + bytes: []byte{0x01, 0x03, 0x00, 0x01, 0x00}, + expErr: tlv.ErrTUintNotMinimal, + skipN2: true, + }, + { + name: "encoding for n1's amt is not minimal", + bytes: []byte{0x01, 0x04, 0x00, 0x01, 0x00, 0x00}, + expErr: tlv.ErrTUintNotMinimal, + skipN2: true, + }, + { + name: "encoding for n1's amt is not minimal", + bytes: []byte{0x01, 0x05, 0x00, 0x01, 0x00, 0x00, 0x00}, + expErr: tlv.ErrTUintNotMinimal, + skipN2: true, + }, + { + name: "encoding for n1's amt is not minimal", + bytes: []byte{0x01, 0x06, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00}, + expErr: tlv.ErrTUintNotMinimal, + skipN2: true, + }, + { + name: "encoding for n1's amt is not minimal", + bytes: []byte{0x01, 0x07, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00}, + expErr: tlv.ErrTUintNotMinimal, + skipN2: true, + }, + { + name: "encoding for n1's amt is not minimal", + bytes: []byte{0x01, 0x08, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + expErr: tlv.ErrTUintNotMinimal, + skipN2: true, + }, + { + name: "less than encoding length for n1's scid", + bytes: []byte{0x02, 0x07, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01}, + expErr: tlv.NewTypeForDecodingErr(new(uint64), "uint64", 7, 8), + skipN2: true, + }, + { + name: "less than encoding length for n1's scid", + bytes: []byte{0x02, 0x09, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01}, + expErr: tlv.NewTypeForDecodingErr(new(uint64), "uint64", 9, 8), + skipN2: true, + }, + { + name: "less than encoding length for n1's nodeAmts", + bytes: []byte{0x03, 0x29, + 0x02, 0x3d, 0xa0, 0x92, 0xf6, 0x98, 0x0e, 0x58, 0xd2, + 0xc0, 0x37, 0x17, 0x31, 0x80, 0xe9, 0xa4, 0x65, 0x47, + 0x60, 0x26, 0xee, 0x50, 0xf9, 0x66, 0x95, 0x96, 0x3e, + 0x8e, 0xfe, 0x43, 0x6f, 0x54, 0xeb, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x01, + }, + expErr: tlv.NewTypeForDecodingErr(new(nodeAmts), "nodeAmts", 41, 49), + skipN2: true, + }, + { + name: "less than encoding length for n1's nodeAmts", + bytes: []byte{0x03, 0x30, + 0x02, 0x3d, 0xa0, 0x92, 0xf6, 0x98, 0x0e, 0x58, 0xd2, + 0xc0, 0x37, 0x17, 0x31, 0x80, 0xe9, 0xa4, 0x65, 0x47, + 0x60, 0x26, 0xee, 0x50, 0xf9, 0x66, 0x95, 0x96, 0x3e, + 0x8e, 0xfe, 0x43, 0x6f, 0x54, 0xeb, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x01, + }, + expErr: tlv.NewTypeForDecodingErr(new(nodeAmts), "nodeAmts", 48, 49), + skipN2: true, + }, + { + name: "n1's node_id is not a valid point", + bytes: []byte{0x03, 0x31, + 0x04, 0x3d, 0xa0, 0x92, 0xf6, 0x98, 0x0e, 0x58, 0xd2, + 0xc0, 0x37, 0x17, 0x31, 0x80, 0xe9, 0xa4, 0x65, 0x47, + 0x60, 0x26, 0xee, 0x50, 0xf9, 0x66, 0x95, 0x96, 0x3e, + 0x8e, 0xfe, 0x43, 0x6f, 0x54, 0xeb, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x02, + }, + expErr: errors.New("invalid magic in compressed pubkey string: 4"), + skipN2: true, + }, + { + name: "greater than encoding length for n1's nodeAmts", + bytes: []byte{0x03, 0x32, + 0x02, 0x3d, 0xa0, 0x92, 0xf6, 0x98, 0x0e, 0x58, 0xd2, + 0xc0, 0x37, 0x17, 0x31, 0x80, 0xe9, 0xa4, 0x65, 0x47, + 0x60, 0x26, 0xee, 0x50, 0xf9, 0x66, 0x95, 0x96, 0x3e, + 0x8e, 0xfe, 0x43, 0x6f, 0x54, 0xeb, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + }, + expErr: tlv.NewTypeForDecodingErr(new(nodeAmts), "nodeAmts", 50, 49), + skipN2: true, + }, + { + name: "unknown required type or n1", + bytes: []byte{0x00, 0x00}, + expErr: tlv.ErrUnknownRequiredType(0x00), + skipN2: true, + }, + { + name: "less than encoding length for n1's cltvDelta", + bytes: []byte{0xfd, 0x00, 0x0fe, 0x00}, + expErr: tlv.NewTypeForDecodingErr(new(uint16), "uint16", 0, 2), + skipN2: true, + }, + { + name: "less than encoding length for n1's cltvDelta", + bytes: []byte{0xfd, 0x00, 0xfe, 0x01, 0x01}, + expErr: tlv.NewTypeForDecodingErr(new(uint16), "uint16", 1, 2), + skipN2: true, + }, + { + name: "greater than encoding length for n1's cltvDelta", + bytes: []byte{0xfd, 0x00, 0xfe, 0x03, 0x01, 0x01, 0x01}, + expErr: tlv.NewTypeForDecodingErr(new(uint16), "uint16", 3, 2), + skipN2: true, + }, + { + name: "unknown even field for n1's namespace", + bytes: []byte{0x0a, 0x00}, + expErr: tlv.ErrUnknownRequiredType(0x0a), + skipN2: true, + }, + { + name: "valid records but invalid ordering", + bytes: []byte{0x02, 0x08, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x26, 0x01, + 0x01, 0x2a, + }, + expErr: tlv.ErrStreamNotCanonical, + skipN2: true, + }, + { + name: "duplicate tlv type", + bytes: []byte{0x02, 0x08, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x31, 0x02, + 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x51, + }, + expErr: tlv.ErrStreamNotCanonical, + skipN2: true, + }, + { + name: "duplicate ignored tlv type", + bytes: []byte{0x1f, 0x00, 0x1f, 0x01, 0x2a}, + expErr: tlv.ErrStreamNotCanonical, + skipN2: true, + }, + { + name: "type wraparound", + bytes: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00}, + expErr: tlv.ErrStreamNotCanonical, + }, +} + +// TestTLVDecodingSuccess asserts that the TLV parser fails to decode invalid +// TLV streams. +func TestTLVDecodingFailures(t *testing.T) { + for _, test := range tlvDecodingFailureTests { + t.Run(test.name, func(t *testing.T) { + n1 := NewN1() + r := bytes.NewReader(test.bytes) + + err := n1.Decode(r) + if !reflect.DeepEqual(err, test.expErr) { + t.Fatalf("expected N1 decoding failure: %v, "+ + "got: %v", test.expErr, err) + } + + if test.skipN2 { + return + } + + n2 := NewN2() + r = bytes.NewReader(test.bytes) + + err = n2.Decode(r) + if !reflect.DeepEqual(err, test.expErr) { + t.Fatalf("expected N2 decoding failure: %v, "+ + "got: %v", test.expErr, err) + } + }) + } +} + +var tlvDecodingSuccessTests = []struct { + name string + bytes []byte + skipN2 bool +}{ + { + name: "empty", + }, + { + name: "unknown odd type", + bytes: []byte{0x21, 0x00}, + }, + { + name: "unknown odd type", + bytes: []byte{0xfd, 0x02, 0x01, 0x00}, + }, + { + name: "unknown odd type", + bytes: []byte{0xfd, 0x00, 0xfd, 0x00}, + }, + { + name: "unknown odd type", + bytes: []byte{0xfd, 0x00, 0xff, 0x00}, + }, + { + name: "unknown odd type", + bytes: []byte{0xfe, 0x02, 0x00, 0x00, 0x01, 0x00}, + }, + { + name: "unknown odd type", + bytes: []byte{0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00}, + }, + { + name: "N1 amt=0", + bytes: []byte{0x01, 0x00}, + skipN2: true, + }, + { + name: "N1 amt=1", + bytes: []byte{0x01, 0x01, 0x01}, + skipN2: true, + }, + { + name: "N1 amt=256", + bytes: []byte{0x01, 0x02, 0x01, 0x00}, + skipN2: true, + }, + { + name: "N1 amt=65536", + bytes: []byte{0x01, 0x03, 0x01, 0x00, 0x00}, + skipN2: true, + }, + { + name: "N1 amt=16777216", + bytes: []byte{0x01, 0x04, 0x01, 0x00, 0x00, 0x00}, + skipN2: true, + }, + { + name: "N1 amt=4294967296", + bytes: []byte{0x01, 0x05, 0x01, 0x00, 0x00, 0x00, 0x00}, + skipN2: true, + }, + { + name: "N1 amt=1099511627776", + bytes: []byte{0x01, 0x06, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00}, + skipN2: true, + }, + { + name: "N1 amt=281474976710656", + bytes: []byte{0x01, 0x07, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + skipN2: true, + }, + { + name: "N1 amt=72057594037927936", + bytes: []byte{0x01, 0x08, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + skipN2: true, + }, + { + name: "N1 scid=0x0x550", + bytes: []byte{0x02, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x26}, + skipN2: true, + }, + { + name: "N1 node_id=023da092f6980e58d2c037173180e9a465476026ee50f96695963e8efe436f54eb amount_msat_1=1 amount_msat_2=2", + bytes: []byte{0x03, 0x31, + 0x02, 0x3d, 0xa0, 0x92, 0xf6, 0x98, 0x0e, 0x58, 0xd2, + 0xc0, 0x37, 0x17, 0x31, 0x80, 0xe9, 0xa4, 0x65, 0x47, + 0x60, 0x26, 0xee, 0x50, 0xf9, 0x66, 0x95, 0x96, 0x3e, + 0x8e, 0xfe, 0x43, 0x6f, 0x54, 0xeb, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x02}, + skipN2: true, + }, + { + name: "N1 cltv_delta=550", + bytes: []byte{0xfd, 0x00, 0xfe, 0x02, 0x02, 0x26}, + skipN2: true, + }, +} + +// TestTLVDecodingSuccess asserts that the TLV parser is able to successfully +// decode valid TLV streams. +func TestTLVDecodingSuccess(t *testing.T) { + for _, test := range tlvDecodingSuccessTests { + t.Run(test.name, func(t *testing.T) { + n1 := NewN1() + r := bytes.NewReader(test.bytes) + + err := n1.Decode(r) + if err != nil { + t.Fatalf("expected N1 decoding success, got: %v", + err) + } + + if test.skipN2 { + return + } + + n2 := NewN2() + r = bytes.NewReader(test.bytes) + + err = n2.Decode(r) + if err != nil { + t.Fatalf("expected N2 decoding succes, got: %v", + err) + } + }) + } +}