diff --git a/tlv/bench_test.go b/tlv/bench_test.go new file mode 100644 index 00000000000..f71a7eb61ef --- /dev/null +++ b/tlv/bench_test.go @@ -0,0 +1,161 @@ +package tlv_test + +import ( + "bytes" + "io" + "io/ioutil" + "testing" + + "github.com/lightningnetwork/lnd/lnwallet" + "github.com/lightningnetwork/lnd/tlv" + "github.com/lightningnetwork/lnd/watchtower/blob" + "github.com/lightningnetwork/lnd/watchtower/wtwire" +) + +// CreateSessionTLV mirrors the wtwire.CreateSession message, but uses TLV for +// encoding/decoding. +type CreateSessionTLV struct { + BlobType blob.Type + MaxUpdates uint16 + RewardBase uint32 + RewardRate uint32 + SweepFeeRate lnwallet.SatPerKWeight + + tlvStream *tlv.Stream +} + +// EBlobType is an encoder for blob.Type. +func EBlobType(w io.Writer, val interface{}, buf *[8]byte) error { + if t, ok := val.(*blob.Type); ok { + return tlv.EUint16T(w, uint16(*t), buf) + } + return tlv.NewTypeForEncodingErr(val, "blob.Type") +} + +// EBlobType is an decoder for blob.Type. +func DBlobType(r io.Reader, val interface{}, buf *[8]byte, l uint64) error { + if typ, ok := val.(*blob.Type); ok { + var t uint16 + err := tlv.DUint16(r, &t, buf, l) + if err != nil { + return err + } + *typ = blob.Type(t) + return nil + } + return tlv.NewTypeForDecodingErr(val, "blob.Type", l, 2) +} + +// ESatPerKW is an encoder for lnwallet.SatPerKWeight. +func ESatPerKW(w io.Writer, val interface{}, buf *[8]byte) error { + if v, ok := val.(*lnwallet.SatPerKWeight); ok { + return tlv.EUint64(w, uint64(*v), buf) + } + return tlv.NewTypeForEncodingErr(val, "lnwallet.SatPerKWeight") +} + +// DSatPerKW is an decoder for lnwallet.SatPerKWeight. +func DSatPerKW(r io.Reader, val interface{}, buf *[8]byte, l uint64) error { + if v, ok := val.(*lnwallet.SatPerKWeight); ok { + var sat uint64 + err := tlv.DUint64(r, &sat, buf, l) + if err != nil { + return err + } + *v = lnwallet.SatPerKWeight(sat) + return nil + } + return tlv.NewTypeForDecodingErr(val, "lnwallet.SatPerKWeight", l, 8) +} + +// NewCreateSessionTLV initializes a new CreateSessionTLV message. +func NewCreateSessionTLV() *CreateSessionTLV { + m := &CreateSessionTLV{} + m.tlvStream = tlv.MustNewStream( + tlv.MakeStaticRecord(0, &m.BlobType, 2, EBlobType, DBlobType), + tlv.MakePrimitiveRecord(1, &m.MaxUpdates), + tlv.MakePrimitiveRecord(2, &m.RewardBase), + tlv.MakePrimitiveRecord(3, &m.RewardRate), + tlv.MakeStaticRecord(4, &m.SweepFeeRate, 8, ESatPerKW, DSatPerKW), + ) + + return m +} + +// Encode writes the CreateSessionTLV to the passed io.Writer. +func (c *CreateSessionTLV) Encode(w io.Writer) error { + return c.tlvStream.Encode(w) +} + +// Decode reads the CreateSessionTLV from the passed io.Reader. +func (c *CreateSessionTLV) Decode(r io.Reader) error { + return c.tlvStream.Decode(r) +} + +// BenchmarkEncodeCreateSession benchmarks encoding of the non-TLV +// CreateSession. +func BenchmarkEncodeCreateSession(t *testing.B) { + m := &wtwire.CreateSession{} + + t.ReportAllocs() + t.ResetTimer() + + var err error + for i := 0; i < t.N; i++ { + err = m.Encode(ioutil.Discard, 0) + } + _ = err +} + +// BenchmarkEncodeCreateSessionTLV benchmarks encoding of the TLV CreateSession. +func BenchmarkEncodeCreateSessionTLV(t *testing.B) { + m := NewCreateSessionTLV() + + t.ReportAllocs() + t.ResetTimer() + + var err error + for i := 0; i < t.N; i++ { + err = m.Encode(ioutil.Discard) + } + _ = err +} + +// BenchmarkDecodeCreateSession benchmarks encoding of the non-TLV +// CreateSession. +func BenchmarkDecodeCreateSession(t *testing.B) { + m := &wtwire.CreateSession{} + + var b bytes.Buffer + m.Encode(&b, 0) + r := bytes.NewReader(b.Bytes()) + + t.ReportAllocs() + t.ResetTimer() + + var err error + for i := 0; i < t.N; i++ { + r.Seek(0, 0) + err = m.Decode(r, 0) + } + _ = err +} + +// BenchmarkDecodeCreateSessionTLV benchmarks decoding of the TLV CreateSession. +func BenchmarkDecodeCreateSessionTLV(t *testing.B) { + m := NewCreateSessionTLV() + + var b bytes.Buffer + var err error + m.Encode(&b) + r := bytes.NewReader(b.Bytes()) + + t.ReportAllocs() + t.ResetTimer() + + for i := 0; i < t.N; i++ { + r.Seek(0, 0) + err = m.Decode(r) + } + _ = err +} diff --git a/tlv/primitive.go b/tlv/primitive.go new file mode 100644 index 00000000000..3d1f0a067a2 --- /dev/null +++ b/tlv/primitive.go @@ -0,0 +1,309 @@ +package tlv + +import ( + "encoding/binary" + "fmt" + "io" + + "github.com/btcsuite/btcd/btcec" +) + +// ErrTypeForEncoding signals that an incorrect type was passed to an Encoder. +type ErrTypeForEncoding struct { + val interface{} + expType string +} + +// NewTypeForEncodingErr creates a new ErrTypeForEncoding given the incorrect +// val and the expected type. +func NewTypeForEncodingErr(val interface{}, expType string) ErrTypeForEncoding { + return ErrTypeForEncoding{ + val: val, + expType: expType, + } +} + +// Error returns a human-readable description of the type mismatch. +func (e ErrTypeForEncoding) Error() string { + return fmt.Sprintf("ErrTypeForEncoding want (type: *%s), "+ + "got (type: %T)", e.expType, e.val) +} + +// ErrTypeForDecoding signals that an incorrect type was passed to a Decoder or +// that the expected length of the encoding is different from that required by +// the expected type. +type ErrTypeForDecoding struct { + val interface{} + expType string + valLength uint64 + expLength uint64 +} + +// NewTypeForDecodingErr creates a new ErrTypeForDecoding given the incorrect +// val and expected type, or the mismatch in their expected lengths. +func NewTypeForDecodingErr(val interface{}, expType string, + valLength, expLength uint64) ErrTypeForDecoding { + + return ErrTypeForDecoding{ + val: val, + expType: expType, + valLength: valLength, + expLength: expLength, + } +} + +// Error returns a human-readable description of the type mismatch. +func (e ErrTypeForDecoding) Error() string { + return fmt.Sprintf("ErrTypeForDecoding want (type: *%s, length: %v), "+ + "got (type: %T, length: %v)", e.expType, e.expLength, e.val, + e.valLength) +} + +var ( + byteOrder = binary.BigEndian +) + +// EUint8 is an Encoder for uint8 values. An error is returned if val is not a +// *uint8. +func EUint8(w io.Writer, val interface{}, buf *[8]byte) error { + if i, ok := val.(*uint8); ok { + buf[0] = *i + _, err := w.Write(buf[:1]) + return err + } + return ErrTypeForEncoding{val, "uint8"} +} + +// EUint8T encodes a uint8 val to the provided io.Writer. This method is exposed +// so that encodings for custom uint8-like types can be created without +// incurring an extra heap allocation. +func EUint8T(w io.Writer, val uint8, buf *[8]byte) error { + buf[0] = val + _, err := w.Write(buf[:1]) + return err +} + +// EUint16 is an Encoder for uint16 values. An error is returned if val is not a +// *uint16. +func EUint16(w io.Writer, val interface{}, buf *[8]byte) error { + if i, ok := val.(*uint16); ok { + byteOrder.PutUint16(buf[:2], *i) + _, err := w.Write(buf[:2]) + return err + } + return ErrTypeForEncoding{val, "uint16"} +} + +// EUint16T encodes a uint16 val to the provided io.Writer. This method is +// exposed so that encodings for custom uint16-like types can be created without +// incurring an extra heap allocation. +func EUint16T(w io.Writer, val uint16, buf *[8]byte) error { + byteOrder.PutUint16(buf[:2], val) + _, err := w.Write(buf[:2]) + return err +} + +// EUint32 is an Encoder for uint32 values. An error is returned if val is not a +// *uint32. +func EUint32(w io.Writer, val interface{}, buf *[8]byte) error { + if i, ok := val.(*uint32); ok { + byteOrder.PutUint32(buf[:4], *i) + _, err := w.Write(buf[:4]) + return err + } + return ErrTypeForEncoding{val, "uint32"} +} + +// EUint32T encodes a uint32 val to the provided io.Writer. This method is +// exposed so that encodings for custom uint32-like types can be created without +// incurring an extra heap allocation. +func EUint32T(w io.Writer, val uint32, buf *[8]byte) error { + byteOrder.PutUint32(buf[:4], val) + _, err := w.Write(buf[:4]) + return err +} + +// EUint64 is an Encoder for uint64 values. An error is returned if val is not a +// *uint64. +func EUint64(w io.Writer, val interface{}, buf *[8]byte) error { + if i, ok := val.(*uint64); ok { + byteOrder.PutUint64(buf[:], *i) + _, err := w.Write(buf[:]) + return err + } + return ErrTypeForEncoding{val, "uint64"} +} + +// EUint64T encodes a uint64 val to the provided io.Writer. This method is +// exposed so that encodings for custom uint64-like types can be created without +// incurring an extra heap allocation. +func EUint64T(w io.Writer, val uint64, buf *[8]byte) error { + byteOrder.PutUint64(buf[:], val) + _, err := w.Write(buf[:]) + return err +} + +// DUint8 is a Decoder for uint8 values. An error is returned if val is not a +// *uint8. +func DUint8(r io.Reader, val interface{}, buf *[8]byte, l uint64) error { + if i, ok := val.(*uint8); ok && l == 1 { + if _, err := io.ReadFull(r, buf[:1]); err != nil { + return err + } + *i = buf[0] + return nil + } + return ErrTypeForDecoding{val, "uint8", l, 1} +} + +// DUint16 is a Decoder for uint16 values. An error is returned if val is not a +// *uint16. +func DUint16(r io.Reader, val interface{}, buf *[8]byte, l uint64) error { + if i, ok := val.(*uint16); ok && l == 2 { + if _, err := io.ReadFull(r, buf[:2]); err != nil { + return err + } + *i = byteOrder.Uint16(buf[:2]) + return nil + } + return ErrTypeForDecoding{val, "uint16", l, 2} +} + +// DUint32 is a Decoder for uint32 values. An error is returned if val is not a +// *uint32. +func DUint32(r io.Reader, val interface{}, buf *[8]byte, l uint64) error { + if i, ok := val.(*uint32); ok && l == 4 { + if _, err := io.ReadFull(r, buf[:4]); err != nil { + return err + } + *i = byteOrder.Uint32(buf[:4]) + return nil + } + return ErrTypeForDecoding{val, "uint32", l, 4} +} + +// DUint64 is a Decoder for uint64 values. An error is returned if val is not a +// *uint64. +func DUint64(r io.Reader, val interface{}, buf *[8]byte, l uint64) error { + if i, ok := val.(*uint64); ok && l == 8 { + if _, err := io.ReadFull(r, buf[:]); err != nil { + return err + } + *i = byteOrder.Uint64(buf[:]) + return nil + } + return ErrTypeForDecoding{val, "uint64", l, 8} +} + +// EBytes32 is an Encoder for 32-byte arrays. An error is returned if val is not +// a *[32]byte. +func EBytes32(w io.Writer, val interface{}, _ *[8]byte) error { + if b, ok := val.(*[32]byte); ok { + _, err := w.Write(b[:]) + return err + } + return ErrTypeForEncoding{val, "[32]byte"} +} + +// DBytes32 is a Decoder for 32-byte arrays. An error is returned if val is not +// a *[32]byte. +func DBytes32(r io.Reader, val interface{}, _ *[8]byte, l uint64) error { + if b, ok := val.(*[32]byte); ok && l == 32 { + _, err := io.ReadFull(r, b[:]) + return err + } + return ErrTypeForDecoding{val, "[32]byte", l, 32} +} + +// EBytes33 is an Encoder for 33-byte arrays. An error is returned if val is not +// a *[33]byte. +func EBytes33(w io.Writer, val interface{}, _ *[8]byte) error { + if b, ok := val.(*[33]byte); ok { + _, err := w.Write(b[:]) + return err + } + return ErrTypeForEncoding{val, "[33]byte"} +} + +// DBytes33 is a Decoder for 33-byte arrays. An error is returned if val is not +// a *[33]byte. +func DBytes33(r io.Reader, val interface{}, _ *[8]byte, l uint64) error { + if b, ok := val.(*[33]byte); ok { + _, err := io.ReadFull(r, b[:]) + return err + } + return ErrTypeForDecoding{val, "[33]byte", l, 33} +} + +// EBytes64 is an Encoder for 64-byte arrays. An error is returned if val is not +// a *[64]byte. +func EBytes64(w io.Writer, val interface{}, _ *[8]byte) error { + if b, ok := val.(*[64]byte); ok { + _, err := w.Write(b[:]) + return err + } + return ErrTypeForEncoding{val, "[64]byte"} +} + +// DBytes64 is an Decoder for 64-byte arrays. An error is returned if val is not +// a *[64]byte. +func DBytes64(r io.Reader, val interface{}, _ *[8]byte, l uint64) error { + if b, ok := val.(*[64]byte); ok && l == 64 { + _, err := io.ReadFull(r, b[:]) + return err + } + return ErrTypeForDecoding{val, "[64]byte", l, 64} +} + +// EPubKey is an Encoder for *btcec.PublicKey values. An error is returned if +// val is not a **btcec.PublicKey. +func EPubKey(w io.Writer, val interface{}, _ *[8]byte) error { + if pk, ok := val.(**btcec.PublicKey); ok { + _, err := w.Write((*pk).SerializeCompressed()) + return err + } + return ErrTypeForEncoding{val, "*btcec.PublicKey"} +} + +// DPubKey is a Decoder for *btcec.PublicKey values. An error is returned if val +// is not a **btcec.PublicKey. +func DPubKey(r io.Reader, val interface{}, _ *[8]byte, l uint64) error { + if pk, ok := val.(**btcec.PublicKey); ok && l == 33 { + var b [33]byte + _, err := io.ReadFull(r, b[:]) + if err != nil { + return err + } + + p, err := btcec.ParsePubKey(b[:], btcec.S256()) + if err != nil { + return err + } + + *pk = p + + return nil + } + return ErrTypeForDecoding{val, "*btcec.PublicKey", l, 33} +} + +// EVarBytes is an Encoder for variable byte slices. An error is returned if val +// is not *[]byte. +func EVarBytes(w io.Writer, val interface{}, _ *[8]byte) error { + if b, ok := val.(*[]byte); ok { + _, err := w.Write(*b) + return err + } + return ErrTypeForEncoding{val, "[]byte"} +} + +// DVarBytes is a Decoder for variable byte slices. An error is returned if val +// is not *[]byte. +func DVarBytes(r io.Reader, val interface{}, _ *[8]byte, l uint64) error { + if b, ok := val.(*[]byte); ok { + *b = make([]byte, l) + _, err := io.ReadFull(r, *b) + return err + } + return ErrTypeForDecoding{val, "[]byte", l, l} +} diff --git a/tlv/record.go b/tlv/record.go new file mode 100644 index 00000000000..ae21c050a47 --- /dev/null +++ b/tlv/record.go @@ -0,0 +1,153 @@ +package tlv + +import ( + "io" + + "github.com/btcsuite/btcd/btcec" +) + +// Type is an 64-bit identifier for a TLV Record. +type Type uint64 + +// Encoder is a signature for methods that can encode TLV values. An error +// should be returned if the Encoder cannot support the underlying type of val. +// The provided scratch buffer must be non-nil. +type Encoder func(w io.Writer, val interface{}, buf *[8]byte) error + +// Decoder is a signature for methods that can decode TLV values. An error +// should be returned if the Decoder cannot support the underlying type of val. +// The provided scratch buffer must be non-nil. +type Decoder func(r io.Reader, val interface{}, buf *[8]byte, l uint64) error + +// ENOP is an encoder that doesn't modify the io.Writer and never fails. +func ENOP(io.Writer, interface{}, *[8]byte) error { return nil } + +// DNOP is an encoder that doesn't modify the io.Reader and never fails. +func DNOP(io.Reader, interface{}, *[8]byte, uint64) error { return nil } + +// SizeFunc is a function that can compute the length of a given field. Since +// the size of the underlying field can change, this allows the size of the +// field to be evaluated at the time of encoding. +type SizeFunc func() uint64 + +// SizeVarBytes returns a SizeFunc that can compute the length of a byte slice. +func SizeVarBytes(e *[]byte) SizeFunc { + return func() uint64 { + return uint64(len(*e)) + } +} + +// Record holds the required information to encode or decode a TLV record. +type Record struct { + value interface{} + typ Type + staticSize uint64 + sizeFunc SizeFunc + encoder Encoder + decoder Decoder +} + +// Size returns the size of the Record's value. If no static size is known, the +// dynamic size will be evaluated. +func (f *Record) Size() uint64 { + if f.sizeFunc == nil { + return f.staticSize + } + + return f.sizeFunc() +} + +// MakePrimitiveRecord creates a record for common types. +func MakePrimitiveRecord(typ Type, val interface{}) Record { + var ( + staticSize uint64 + sizeFunc SizeFunc + encoder Encoder + decoder Decoder + ) + switch e := val.(type) { + case *uint8: + staticSize = 1 + encoder = EUint8 + decoder = DUint8 + + case *uint16: + staticSize = 2 + encoder = EUint16 + decoder = DUint16 + + case *uint32: + staticSize = 4 + encoder = EUint32 + decoder = DUint32 + + case *uint64: + staticSize = 8 + encoder = EUint64 + decoder = DUint64 + + case *[32]byte: + staticSize = 32 + encoder = EBytes32 + decoder = DBytes32 + + case *[33]byte: + staticSize = 33 + encoder = EBytes33 + decoder = DBytes33 + + case **btcec.PublicKey: + staticSize = 33 + encoder = EPubKey + decoder = DPubKey + + case *[64]byte: + staticSize = 64 + encoder = EBytes64 + decoder = DBytes64 + + case *[]byte: + sizeFunc = SizeVarBytes(e) + encoder = EVarBytes + decoder = DVarBytes + + default: + panic("unknown primitive type") + } + + return Record{ + value: val, + typ: typ, + staticSize: staticSize, + sizeFunc: sizeFunc, + encoder: encoder, + decoder: decoder, + } +} + +// MakeStaticRecord creates a record for a field of fixed-size +func MakeStaticRecord(typ Type, val interface{}, size uint64, encoder Encoder, + decoder Decoder) Record { + + return Record{ + value: val, + typ: typ, + staticSize: size, + encoder: encoder, + decoder: decoder, + } +} + +// MakeDynamicRecord creates a record whose size may vary, and will be +// determined at the time of encoding via sizeFunc. +func MakeDynamicRecord(typ Type, val interface{}, sizeFunc SizeFunc, + encoder Encoder, decoder Decoder) Record { + + return Record{ + value: val, + typ: typ, + sizeFunc: sizeFunc, + encoder: encoder, + decoder: decoder, + } +} diff --git a/tlv/stream.go b/tlv/stream.go new file mode 100644 index 00000000000..49bb70edc10 --- /dev/null +++ b/tlv/stream.go @@ -0,0 +1,280 @@ +package tlv + +import ( + "errors" + "fmt" + "io" + "io/ioutil" + "math" +) + +// ErrStreamNotCanonical signals that a decoded stream does not contain records +// sorting by monotonically-increasing type. +var ErrStreamNotCanonical = errors.New("tlv stream is not canonical") + +// ErrUnknownRequiredType is an error returned when decoding an unknown and even +// type from a Stream. +type ErrUnknownRequiredType Type + +// Error returns a human-readable description of unknown required type. +func (t ErrUnknownRequiredType) Error() string { + return fmt.Sprintf("unknown required type: %d", t) +} + +// Stream defines a TLV stream that can be used for encoding or decoding a set +// of TLV Records. +type Stream struct { + records []Record + buf [8]byte +} + +// NewStream creates a new TLV Stream given an encoding codec, a decoding codec, +// and a set of known records. +func NewStream(records ...Record) (*Stream, error) { + // Assert that the ordering of the Records is canonical and appear in + // ascending order of type. + var ( + min Type + overflow bool + ) + for _, record := range records { + if overflow || record.typ < min { + return nil, ErrStreamNotCanonical + } + if record.encoder == nil { + record.encoder = ENOP + } + if record.decoder == nil { + record.decoder = DNOP + } + if record.typ == math.MaxUint64 { + overflow = true + } + min = record.typ + 1 + } + + return &Stream{ + records: records, + }, nil +} + +// MustNewStream creates a new TLV Stream given an encoding codec, a decoding +// codec, and a set of known records. If an error is encountered in creating the +// stream, this method will panic instead of returning the error. +func MustNewStream(records ...Record) *Stream { + stream, err := NewStream(records...) + if err != nil { + panic(err.Error()) + } + return stream +} + +// Encode writes a Stream to the passed io.Writer. Each of the Records known to +// the Stream is written in ascending order of their type so as to be canonical. +// +// The stream is constructed by concatenating the individual, serialized Records +// where each record has the following format: +// [varint: type] +// [varint: length] +// [length: value] +// +// An error is returned if the io.Writer fails to accept bytes from the +// encoding, and nothing else. The ordering of the Records is asserted upon the +// creation of a Stream, and thus the output will be by definition canonical. +func (s *Stream) Encode(w io.Writer) error { + // Iterate through all known records, if any, serializing each record's + // type, length and value. + for i := range s.records { + rec := &s.records[i] + + // Write the record's type as a varint. + err := WriteVarInt(w, uint64(rec.typ), &s.buf) + if err != nil { + return err + } + + // Write the record's length as a varint. + err = WriteVarInt(w, rec.Size(), &s.buf) + if err != nil { + return err + } + + // Encode the current record's value using the stream's codec. + err = rec.encoder(w, rec.value, &s.buf) + if err != nil { + return err + } + } + + return nil +} + +// Decode deserializes TLV Stream from the passed io.Reader. The Stream will +// inspect each record that is parsed and check to see if it has a corresponding +// Record to facilitate deserialization of that field. If the record is unknown, +// the Stream will discard the record's bytes and proceed to the subsequent +// record. +// +// Each record has the following format: +// [varint: type] +// [varint: length] +// [length: value] +// +// A series of (possibly zero) records are concatenated into a stream, this +// example contains two records: +// +// (t: 0x01, l: 0x04, v: 0xff, 0xff, 0xff, 0xff) +// (t: 0x02, l: 0x01, v: 0x01) +// +// This method asserts that the byte stream is canonical, namely that each +// record is unique and that all records are sorted in ascending order. An +// ErrNotCanonicalStream error is returned if the encoded TLV stream is not. +// +// We permit an io.EOF error only when reading the type byte which signals that +// the last record was read cleanly and we should stop parsing. All other io.EOF +// or io.ErrUnexpectedEOF errors are returned. +func (s *Stream) Decode(r io.Reader) error { + var ( + typ Type + min Type + recordIdx int + overflow bool + ) + + // Iterate through all possible type identifiers. As types are read from + // the io.Reader, min will skip forward to the last read type. + for { + // Read the next varint type. + t, err := ReadVarInt(r, &s.buf) + switch { + + // We'll silence an EOF when zero bytes remain, meaning the + // stream was cleanly encoded. + case err == io.EOF: + return nil + + // Other unexpected errors. + case err != nil: + return err + } + + typ = Type(t) + + // Assert that this type is greater than any previously read. + // If we've already overflowed and we parsed another type, the + // stream is not canonical. This check prevents us from accepts + // encodings that have duplicate records or from accepting an + // unsorted series. + if overflow || typ < min { + return ErrStreamNotCanonical + } + + // Read the varint length. + length, err := ReadVarInt(r, &s.buf) + switch { + + // We'll convert any EOFs to ErrUnexpectedEOF, since this + // results in an invalid record. + case err == io.EOF: + return io.ErrUnexpectedEOF + + // Other unexpected errors. + case err != nil: + return err + } + + // Search the records known to the stream for this type. We'll + // begin the search and recordIdx and walk forward until we find + // it or the next record's type is larger. + rec, newIdx, ok := s.getRecord(typ, recordIdx) + switch { + + // We know of this record type, proceed to decode the value. + // This method asserts that length bytes are read in the + // process, and returns an error if the number of bytes is not + // exactly length. + case ok: + err := rec.decoder(r, rec.value, &s.buf, length) + switch { + + // We'll convert any EOFs to ErrUnexpectedEOF, since this + // results in an invalid record. + case err == io.EOF: + return io.ErrUnexpectedEOF + + // Other unexpected errors. + case err != nil: + return err + } + + // This record type is unknown to the stream, fail if the type + // is even meaning that we are required to understand it. + case typ%2 == 0: + return ErrUnknownRequiredType(typ) + + // Otherwise, the record type is unknown and is odd, discard the + // number of bytes specified by length. + default: + _, err := io.CopyN(ioutil.Discard, r, int64(length)) + switch { + + // We'll convert any EOFs to ErrUnexpectedEOF, since this + // results in an invalid record. + case err == io.EOF: + return io.ErrUnexpectedEOF + + // Other unexpected errors. + case err != nil: + return err + } + } + + // Update our record index so that we can begin our next search + // from where we left off. + recordIdx = newIdx + + // If we've parsed the largest possible type, the next loop will + // overflow back to zero. However, we need to attempt parsing + // the next type to ensure that the stream is empty. + if typ == math.MaxUint64 { + overflow = true + } + + // Finally, set our lower bound on the next accepted type. + min = typ + 1 + } +} + +// getRecord searches for a record matching typ known to the stream. The boolean +// return value indicates whether the record is known to the stream. The integer +// return value carries the index from where getRecord should be invoked on the +// subsequent call. The first call to getRecord should always use an idx of 0. +func (s *Stream) getRecord(typ Type, idx int) (Record, int, bool) { + for idx < len(s.records) { + record := s.records[idx] + switch { + + // Found target record, return it to the caller. The next index + // returned points to the immediately following record. + case record.typ == typ: + return record, idx + 1, true + + // This record's type is lower than the target. Advance our + // index and continue to the next record which will have a + // strictly higher type. + case record.typ < typ: + idx++ + continue + + // This record's type is larger than the target, hence we have + // no record matching the current type. Return the current index + // so that we can start our search from here when processing the + // next tlv record. + default: + return Record{}, idx, false + } + } + + // All known records are exhausted. + return Record{}, idx, false +} diff --git a/tlv/tlv_test.go b/tlv/tlv_test.go new file mode 100644 index 00000000000..c3b996fef63 --- /dev/null +++ b/tlv/tlv_test.go @@ -0,0 +1,559 @@ +package tlv_test + +import ( + "bytes" + "errors" + "io" + "reflect" + "testing" + + "github.com/btcsuite/btcd/btcec" + "github.com/lightningnetwork/lnd/tlv" +) + +type nodeAmts struct { + nodeID *btcec.PublicKey + amt1 uint64 + amt2 uint64 +} + +func ENodeAmts(w io.Writer, val interface{}, buf *[8]byte) error { + if t, ok := val.(*nodeAmts); ok { + if err := tlv.EPubKey(w, &t.nodeID, buf); err != nil { + return err + } + if err := tlv.EUint64T(w, t.amt1, buf); err != nil { + return err + } + return tlv.EUint64T(w, t.amt2, buf) + } + return tlv.NewTypeForEncodingErr(val, "nodeAmts") +} + +func DNodeAmts(r io.Reader, val interface{}, buf *[8]byte, l uint64) error { + if t, ok := val.(*nodeAmts); ok && l == 49 { + if err := tlv.DPubKey(r, &t.nodeID, buf, 33); err != nil { + return err + } + if err := tlv.DUint64(r, &t.amt1, buf, 8); err != nil { + return err + } + return tlv.DUint64(r, &t.amt2, buf, 8) + } + return tlv.NewTypeForDecodingErr(val, "nodeAmts", l, 49) +} + +type N1 struct { + amt uint64 + scid uint64 + nodeAmts nodeAmts + cltvDelta uint16 + + stream *tlv.Stream +} + +func (n *N1) sizeAmt() uint64 { + return tlv.SizeTUint64(n.amt) +} + +func NewN1() *N1 { + n := new(N1) + + n.stream = tlv.MustNewStream( + tlv.MakeDynamicRecord( + 1, &n.amt, n.sizeAmt, tlv.ETUint64, tlv.DTUint64, + ), + tlv.MakePrimitiveRecord(2, &n.scid), + tlv.MakeStaticRecord(3, &n.nodeAmts, 49, ENodeAmts, DNodeAmts), + tlv.MakePrimitiveRecord(254, &n.cltvDelta), + ) + + return n +} + +func (n *N1) Encode(w io.Writer) error { + return n.stream.Encode(w) +} + +func (n *N1) Decode(r io.Reader) error { + return n.stream.Decode(r) +} + +type N2 struct { + amt uint64 + cltvExpiry uint32 + + stream *tlv.Stream +} + +func (n *N2) sizeAmt() uint64 { + return tlv.SizeTUint64(n.amt) +} + +func (n *N2) sizeCltv() uint64 { + return tlv.SizeTUint32(n.cltvExpiry) +} + +func NewN2() *N2 { + n := new(N2) + + n.stream = tlv.MustNewStream( + tlv.MakeDynamicRecord( + 0, &n.amt, n.sizeAmt, tlv.ETUint64, tlv.DTUint64, + ), + tlv.MakeDynamicRecord( + 11, &n.cltvExpiry, n.sizeCltv, tlv.ETUint32, tlv.DTUint32, + ), + ) + + return n +} + +func (n *N2) Encode(w io.Writer) error { + return n.stream.Encode(w) +} + +func (n *N2) Decode(r io.Reader) error { + return n.stream.Decode(r) +} + +var tlvDecodingFailureTests = []struct { + name string + bytes []byte + expErr error + + // skipN2 if true, will cause the test to only be executed on N1. + skipN2 bool +}{ + { + name: "type truncated", + bytes: []byte{0xfd}, + expErr: io.ErrUnexpectedEOF, + }, + { + name: "type truncated", + bytes: []byte{0xfd, 0x01}, + expErr: io.ErrUnexpectedEOF, + }, + { + name: "not minimally encoded type", + bytes: []byte{0xfd, 0x00, 0x01}, // spec has trailing 0x00 + expErr: tlv.ErrVarIntNotCanonical, + }, + { + name: "missing length", + bytes: []byte{0xfd, 0x01, 0x01}, + expErr: io.ErrUnexpectedEOF, + }, + { + name: "length truncated", + bytes: []byte{0x0f, 0xfd}, + expErr: io.ErrUnexpectedEOF, + }, + { + name: "length truncated", + bytes: []byte{0x0f, 0xfd, 0x26}, + expErr: io.ErrUnexpectedEOF, + }, + { + name: "missing value", + bytes: []byte{0x0f, 0xfd, 0x26, 0x02}, + expErr: io.ErrUnexpectedEOF, + }, + { + name: "not minimally encoded length", + bytes: []byte{0x0f, 0xfd, 0x00, 0x01}, // spec has trailing 0x00 + expErr: tlv.ErrVarIntNotCanonical, + }, + { + name: "value truncated", + bytes: []byte{0x0f, 0xfd, 0x02, 0x01, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + }, + expErr: io.ErrUnexpectedEOF, + }, + { + name: "unknown even type", + bytes: []byte{0x12, 0x00}, + expErr: tlv.ErrUnknownRequiredType(0x12), + }, + { + name: "unknown even type", + bytes: []byte{0xfd, 0x01, 0x02, 0x00}, + expErr: tlv.ErrUnknownRequiredType(0x102), + }, + { + name: "unknown even type", + bytes: []byte{0xfe, 0x01, 0x00, 0x00, 0x02, 0x00}, + expErr: tlv.ErrUnknownRequiredType(0x01000002), + }, + { + name: "unknown even type", + bytes: []byte{0xff, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00}, + expErr: tlv.ErrUnknownRequiredType(0x0100000000000002), + }, + { + name: "greater than encoding length for n1's amt", + bytes: []byte{0x01, 0x09, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, + expErr: tlv.NewTypeForDecodingErr(new(uint64), "uint64", 9, 8), + skipN2: true, + }, + { + name: "encoding for n1's amt is not minimal", + bytes: []byte{0x01, 0x01, 0x00}, + expErr: tlv.ErrTUintNotMinimal, + skipN2: true, + }, + { + name: "encoding for n1's amt is not minimal", + bytes: []byte{0x01, 0x02, 0x00, 0x01}, + expErr: tlv.ErrTUintNotMinimal, + skipN2: true, + }, + { + name: "encoding for n1's amt is not minimal", + bytes: []byte{0x01, 0x03, 0x00, 0x01, 0x00}, + expErr: tlv.ErrTUintNotMinimal, + skipN2: true, + }, + { + name: "encoding for n1's amt is not minimal", + bytes: []byte{0x01, 0x04, 0x00, 0x01, 0x00, 0x00}, + expErr: tlv.ErrTUintNotMinimal, + skipN2: true, + }, + { + name: "encoding for n1's amt is not minimal", + bytes: []byte{0x01, 0x05, 0x00, 0x01, 0x00, 0x00, 0x00}, + expErr: tlv.ErrTUintNotMinimal, + skipN2: true, + }, + { + name: "encoding for n1's amt is not minimal", + bytes: []byte{0x01, 0x06, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00}, + expErr: tlv.ErrTUintNotMinimal, + skipN2: true, + }, + { + name: "encoding for n1's amt is not minimal", + bytes: []byte{0x01, 0x07, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00}, + expErr: tlv.ErrTUintNotMinimal, + skipN2: true, + }, + { + name: "encoding for n1's amt is not minimal", + bytes: []byte{0x01, 0x08, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + expErr: tlv.ErrTUintNotMinimal, + skipN2: true, + }, + { + name: "less than encoding length for n1's scid", + bytes: []byte{0x02, 0x07, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01}, + expErr: tlv.NewTypeForDecodingErr(new(uint64), "uint64", 7, 8), + skipN2: true, + }, + { + name: "less than encoding length for n1's scid", + bytes: []byte{0x02, 0x09, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01}, + expErr: tlv.NewTypeForDecodingErr(new(uint64), "uint64", 9, 8), + skipN2: true, + }, + { + name: "less than encoding length for n1's nodeAmts", + bytes: []byte{0x03, 0x29, + 0x02, 0x3d, 0xa0, 0x92, 0xf6, 0x98, 0x0e, 0x58, 0xd2, + 0xc0, 0x37, 0x17, 0x31, 0x80, 0xe9, 0xa4, 0x65, 0x47, + 0x60, 0x26, 0xee, 0x50, 0xf9, 0x66, 0x95, 0x96, 0x3e, + 0x8e, 0xfe, 0x43, 0x6f, 0x54, 0xeb, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x01, + }, + expErr: tlv.NewTypeForDecodingErr(new(nodeAmts), "nodeAmts", 41, 49), + skipN2: true, + }, + { + name: "less than encoding length for n1's nodeAmts", + bytes: []byte{0x03, 0x30, + 0x02, 0x3d, 0xa0, 0x92, 0xf6, 0x98, 0x0e, 0x58, 0xd2, + 0xc0, 0x37, 0x17, 0x31, 0x80, 0xe9, 0xa4, 0x65, 0x47, + 0x60, 0x26, 0xee, 0x50, 0xf9, 0x66, 0x95, 0x96, 0x3e, + 0x8e, 0xfe, 0x43, 0x6f, 0x54, 0xeb, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x01, + }, + expErr: tlv.NewTypeForDecodingErr(new(nodeAmts), "nodeAmts", 48, 49), + skipN2: true, + }, + { + name: "n1's node_id is not a valid point", + bytes: []byte{0x03, 0x31, + 0x04, 0x3d, 0xa0, 0x92, 0xf6, 0x98, 0x0e, 0x58, 0xd2, + 0xc0, 0x37, 0x17, 0x31, 0x80, 0xe9, 0xa4, 0x65, 0x47, + 0x60, 0x26, 0xee, 0x50, 0xf9, 0x66, 0x95, 0x96, 0x3e, + 0x8e, 0xfe, 0x43, 0x6f, 0x54, 0xeb, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x02, + }, + expErr: errors.New("invalid magic in compressed pubkey string: 4"), + skipN2: true, + }, + { + name: "greater than encoding length for n1's nodeAmts", + bytes: []byte{0x03, 0x32, + 0x02, 0x3d, 0xa0, 0x92, 0xf6, 0x98, 0x0e, 0x58, 0xd2, + 0xc0, 0x37, 0x17, 0x31, 0x80, 0xe9, 0xa4, 0x65, 0x47, + 0x60, 0x26, 0xee, 0x50, 0xf9, 0x66, 0x95, 0x96, 0x3e, + 0x8e, 0xfe, 0x43, 0x6f, 0x54, 0xeb, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + }, + expErr: tlv.NewTypeForDecodingErr(new(nodeAmts), "nodeAmts", 50, 49), + skipN2: true, + }, + { + name: "unknown required type or n1", + bytes: []byte{0x00, 0x00}, + expErr: tlv.ErrUnknownRequiredType(0x00), + skipN2: true, + }, + { + name: "less than encoding length for n1's cltvDelta", + bytes: []byte{0xfd, 0x00, 0x0fe, 0x00}, + expErr: tlv.NewTypeForDecodingErr(new(uint16), "uint16", 0, 2), + skipN2: true, + }, + { + name: "less than encoding length for n1's cltvDelta", + bytes: []byte{0xfd, 0x00, 0xfe, 0x01, 0x01}, + expErr: tlv.NewTypeForDecodingErr(new(uint16), "uint16", 1, 2), + skipN2: true, + }, + { + name: "greater than encoding length for n1's cltvDelta", + bytes: []byte{0xfd, 0x00, 0xfe, 0x03, 0x01, 0x01, 0x01}, + expErr: tlv.NewTypeForDecodingErr(new(uint16), "uint16", 3, 2), + skipN2: true, + }, + { + name: "unknown even field for n1's namespace", + bytes: []byte{0x0a, 0x00}, + expErr: tlv.ErrUnknownRequiredType(0x0a), + skipN2: true, + }, + { + name: "valid records but invalid ordering", + bytes: []byte{0x02, 0x08, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x26, 0x01, + 0x01, 0x2a, + }, + expErr: tlv.ErrStreamNotCanonical, + skipN2: true, + }, + { + name: "duplicate tlv type", + bytes: []byte{0x02, 0x08, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x31, 0x02, + 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x51, + }, + expErr: tlv.ErrStreamNotCanonical, + skipN2: true, + }, + { + name: "duplicate ignored tlv type", + bytes: []byte{0x1f, 0x00, 0x1f, 0x01, 0x2a}, + expErr: tlv.ErrStreamNotCanonical, + skipN2: true, + }, + { + name: "type wraparound", + bytes: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00}, + expErr: tlv.ErrStreamNotCanonical, + }, +} + +// TestTLVDecodingSuccess asserts that the TLV parser fails to decode invalid +// TLV streams. +func TestTLVDecodingFailures(t *testing.T) { + for _, test := range tlvDecodingFailureTests { + t.Run(test.name, func(t *testing.T) { + n1 := NewN1() + r := bytes.NewReader(test.bytes) + + err := n1.Decode(r) + if !reflect.DeepEqual(err, test.expErr) { + t.Fatalf("expected N1 decoding failure: %v, "+ + "got: %v", test.expErr, err) + } + + if test.skipN2 { + return + } + + n2 := NewN2() + r = bytes.NewReader(test.bytes) + + err = n2.Decode(r) + if !reflect.DeepEqual(err, test.expErr) { + t.Fatalf("expected N2 decoding failure: %v, "+ + "got: %v", test.expErr, err) + } + }) + } +} + +var tlvDecodingSuccessTests = []struct { + name string + bytes []byte + skipN2 bool +}{ + { + name: "empty", + }, + { + name: "unknown odd type", + bytes: []byte{0x21, 0x00}, + }, + { + name: "unknown odd type", + bytes: []byte{0xfd, 0x02, 0x01, 0x00}, + }, + { + name: "unknown odd type", + bytes: []byte{0xfd, 0x00, 0xfd, 0x00}, + }, + { + name: "unknown odd type", + bytes: []byte{0xfd, 0x00, 0xff, 0x00}, + }, + { + name: "unknown odd type", + bytes: []byte{0xfe, 0x02, 0x00, 0x00, 0x01, 0x00}, + }, + { + name: "unknown odd type", + bytes: []byte{0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00}, + }, + { + name: "N1 amt=0", + bytes: []byte{0x01, 0x00}, + skipN2: true, + }, + { + name: "N1 amt=1", + bytes: []byte{0x01, 0x01, 0x01}, + skipN2: true, + }, + { + name: "N1 amt=256", + bytes: []byte{0x01, 0x02, 0x01, 0x00}, + skipN2: true, + }, + { + name: "N1 amt=65536", + bytes: []byte{0x01, 0x03, 0x01, 0x00, 0x00}, + skipN2: true, + }, + { + name: "N1 amt=16777216", + bytes: []byte{0x01, 0x04, 0x01, 0x00, 0x00, 0x00}, + skipN2: true, + }, + { + name: "N1 amt=4294967296", + bytes: []byte{0x01, 0x05, 0x01, 0x00, 0x00, 0x00, 0x00}, + skipN2: true, + }, + { + name: "N1 amt=1099511627776", + bytes: []byte{0x01, 0x06, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00}, + skipN2: true, + }, + { + name: "N1 amt=281474976710656", + bytes: []byte{0x01, 0x07, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + skipN2: true, + }, + { + name: "N1 amt=72057594037927936", + bytes: []byte{0x01, 0x08, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + skipN2: true, + }, + { + name: "N1 scid=0x0x550", + bytes: []byte{0x02, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x26}, + skipN2: true, + }, + { + name: "N1 node_id=023da092f6980e58d2c037173180e9a465476026ee50f96695963e8efe436f54eb amount_msat_1=1 amount_msat_2=2", + bytes: []byte{0x03, 0x31, + 0x02, 0x3d, 0xa0, 0x92, 0xf6, 0x98, 0x0e, 0x58, 0xd2, + 0xc0, 0x37, 0x17, 0x31, 0x80, 0xe9, 0xa4, 0x65, 0x47, + 0x60, 0x26, 0xee, 0x50, 0xf9, 0x66, 0x95, 0x96, 0x3e, + 0x8e, 0xfe, 0x43, 0x6f, 0x54, 0xeb, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x02}, + skipN2: true, + }, + { + name: "N1 cltv_delta=550", + bytes: []byte{0xfd, 0x00, 0xfe, 0x02, 0x02, 0x26}, + skipN2: true, + }, +} + +// TestTLVDecodingSuccess asserts that the TLV parser is able to successfully +// decode valid TLV streams. +func TestTLVDecodingSuccess(t *testing.T) { + for _, test := range tlvDecodingSuccessTests { + t.Run(test.name, func(t *testing.T) { + n1 := NewN1() + r := bytes.NewReader(test.bytes) + + err := n1.Decode(r) + if err != nil { + t.Fatalf("expected N1 decoding success, got: %v", + err) + } + + if test.skipN2 { + return + } + + n2 := NewN2() + r = bytes.NewReader(test.bytes) + + err = n2.Decode(r) + if err != nil { + t.Fatalf("expected N2 decoding succes, got: %v", + err) + } + }) + } +} diff --git a/tlv/truncated.go b/tlv/truncated.go new file mode 100644 index 00000000000..b35fc166a83 --- /dev/null +++ b/tlv/truncated.go @@ -0,0 +1,180 @@ +package tlv + +import ( + "encoding/binary" + "errors" + "io" +) + +// ErrTUintNotMinimal signals that decoding a truncated uint failed because the +// value was not minimally encoded. +var ErrTUintNotMinimal = errors.New("truncated uint not minimally encoded") + +// numLeadingZeroBytes16 computes the number of leading zeros for a uint16. +func numLeadingZeroBytes16(v uint16) uint64 { + switch { + case v == 0: + return 2 + case v&0xff00 == 0: + return 1 + default: + return 0 + } +} + +// SizeTUint16 returns the number of bytes remaining in a uint16 after +// truncating the leading zeros. +func SizeTUint16(v uint16) uint64 { + return 2 - numLeadingZeroBytes16(v) +} + +// ETUint16 is an Encoder for truncated uint16 values, where leading zeros will +// be omitted. An error is returned if val is not a *uint16. +func ETUint16(w io.Writer, val interface{}, buf *[8]byte) error { + if t, ok := val.(*uint16); ok { + binary.BigEndian.PutUint16(buf[:2], *t) + numZeros := numLeadingZeroBytes16(*t) + _, err := w.Write(buf[numZeros:2]) + return err + } + return NewTypeForEncodingErr(val, "uint16") +} + +// DTUint16 is an Decoder for truncated uint16 values, where leading zeros will +// be resurrected. An error is returned if val is not a *uint16. +func DTUint16(r io.Reader, val interface{}, buf *[8]byte, l uint64) error { + if t, ok := val.(*uint16); ok && l <= 2 { + _, err := io.ReadFull(r, buf[2-l:]) + if err != nil { + return err + } + zero(buf[:2-l]) + *t = binary.BigEndian.Uint16(buf[:2]) + if 2-numLeadingZeroBytes16(*t) != l { + return ErrTUintNotMinimal + } + return nil + } + return NewTypeForDecodingErr(val, "uint16", l, 2) +} + +// numLeadingZeroBytes16 computes the number of leading zeros for a uint32. +func numLeadingZeroBytes32(v uint32) uint64 { + switch { + case v == 0: + return 4 + case v&0xffffff00 == 0: + return 3 + case v&0xffff0000 == 0: + return 2 + case v&0xff000000 == 0: + return 1 + default: + return 0 + } +} + +// SizeTUint32 returns the number of bytes remaining in a uint32 after +// truncating the leading zeros. +func SizeTUint32(v uint32) uint64 { + return 4 - numLeadingZeroBytes32(v) +} + +// ETUint32 is an Encoder for truncated uint32 values, where leading zeros will +// be omitted. An error is returned if val is not a *uint32. +func ETUint32(w io.Writer, val interface{}, buf *[8]byte) error { + if t, ok := val.(*uint32); ok { + binary.BigEndian.PutUint32(buf[:4], *t) + numZeros := numLeadingZeroBytes32(*t) + _, err := w.Write(buf[numZeros:4]) + return err + } + return NewTypeForEncodingErr(val, "uint32") +} + +// DTUint32 is an Decoder for truncated uint32 values, where leading zeros will +// be resurrected. An error is returned if val is not a *uint32. +func DTUint32(r io.Reader, val interface{}, buf *[8]byte, l uint64) error { + if t, ok := val.(*uint32); ok && l <= 4 { + _, err := io.ReadFull(r, buf[4-l:]) + if err != nil { + return err + } + zero(buf[:4-l]) + *t = binary.BigEndian.Uint32(buf[:4]) + if 4-numLeadingZeroBytes32(*t) != l { + return ErrTUintNotMinimal + } + return nil + } + return NewTypeForDecodingErr(val, "uint32", l, 4) +} + +// numLeadingZeroBytes64 computes the number of leading zeros for a uint32. +// +// TODO(conner): optimize using unrolled binary search +func numLeadingZeroBytes64(v uint64) uint64 { + switch { + case v == 0: + return 8 + case v&0xffffffffffffff00 == 0: + return 7 + case v&0xffffffffffff0000 == 0: + return 6 + case v&0xffffffffff000000 == 0: + return 5 + case v&0xffffffff00000000 == 0: + return 4 + case v&0xffffff0000000000 == 0: + return 3 + case v&0xffff000000000000 == 0: + return 2 + case v&0xff00000000000000 == 0: + return 1 + default: + return 0 + } +} + +// SizeTUint64 returns the number of bytes remaining in a uint64 after +// truncating the leading zeros. +func SizeTUint64(v uint64) uint64 { + return 8 - numLeadingZeroBytes64(v) +} + +// ETUint64 is an Encoder for truncated uint64 values, where leading zeros will +// be omitted. An error is returned if val is not a *uint64. +func ETUint64(w io.Writer, val interface{}, buf *[8]byte) error { + if t, ok := val.(*uint64); ok { + binary.BigEndian.PutUint64(buf[:], *t) + numZeros := numLeadingZeroBytes64(*t) + _, err := w.Write(buf[numZeros:]) + return err + } + return NewTypeForEncodingErr(val, "uint64") +} + +// DTUint64 is an Decoder for truncated uint64 values, where leading zeros will +// be resurrected. An error is returned if val is not a *uint64. +func DTUint64(r io.Reader, val interface{}, buf *[8]byte, l uint64) error { + if t, ok := val.(*uint64); ok && l <= 8 { + _, err := io.ReadFull(r, buf[8-l:]) + if err != nil { + return err + } + zero(buf[:8-l]) + *t = binary.BigEndian.Uint64(buf[:]) + if 8-numLeadingZeroBytes64(*t) != l { + return ErrTUintNotMinimal + } + return nil + } + return NewTypeForDecodingErr(val, "uint64", l, 8) +} + +// zero clears the passed byte slice. +func zero(b []byte) { + for i := range b { + b[i] = 0x00 + } +} diff --git a/tlv/varint.go b/tlv/varint.go new file mode 100644 index 00000000000..3888bfcb47e --- /dev/null +++ b/tlv/varint.go @@ -0,0 +1,109 @@ +package tlv + +import ( + "encoding/binary" + "errors" + "io" +) + +// ErrVarIntNotCanonical signals that the decoded varint was not minimally encoded. +var ErrVarIntNotCanonical = errors.New("decoded varint is not canonical") + +// ReadVarInt reads a variable length integer from r and returns it as a uint64. +func ReadVarInt(r io.Reader, buf *[8]byte) (uint64, error) { + _, err := io.ReadFull(r, buf[:1]) + if err != nil { + return 0, err + } + discriminant := buf[0] + + var rv uint64 + switch { + case discriminant < 0xfd: + rv = uint64(discriminant) + + case discriminant == 0xfd: + _, err := io.ReadFull(r, buf[:2]) + switch { + case err == io.EOF: + return 0, io.ErrUnexpectedEOF + case err != nil: + return 0, err + } + rv = uint64(binary.BigEndian.Uint16(buf[:2])) + + // The encoding is not canonical if the value could have been + // encoded using fewer bytes. + if rv < 0xfd { + return 0, ErrVarIntNotCanonical + } + + case discriminant == 0xfe: + _, err := io.ReadFull(r, buf[:4]) + switch { + case err == io.EOF: + return 0, io.ErrUnexpectedEOF + case err != nil: + return 0, err + } + rv = uint64(binary.BigEndian.Uint32(buf[:4])) + + // The encoding is not canonical if the value could have been + // encoded using fewer bytes. + if rv <= 0xffff { + return 0, ErrVarIntNotCanonical + } + + default: + _, err := io.ReadFull(r, buf[:]) + switch { + case err == io.EOF: + return 0, io.ErrUnexpectedEOF + case err != nil: + return 0, err + } + rv = binary.BigEndian.Uint64(buf[:]) + + // The encoding is not canonical if the value could have been + // encoded using fewer bytes. + if rv <= 0xffffffff { + return 0, ErrVarIntNotCanonical + } + } + + return rv, nil +} + +// WriteVarInt serializes val to w using a variable number of bytes depending +// on its value. +func WriteVarInt(w io.Writer, val uint64, buf *[8]byte) error { + var length int + switch { + case val < 0xfd: + buf[0] = uint8(val) + length = 1 + + case val <= 0xffff: + buf[0] = uint8(0xfd) + binary.BigEndian.PutUint16(buf[1:3], uint16(val)) + length = 3 + + case val <= 0xffffffff: + buf[0] = uint8(0xfe) + binary.BigEndian.PutUint32(buf[1:5], uint32(val)) + length = 5 + + default: + buf[0] = uint8(0xff) + _, err := w.Write(buf[:1]) + if err != nil { + return err + } + + binary.BigEndian.PutUint64(buf[:], uint64(val)) + length = 8 + } + + _, err := w.Write(buf[:length]) + return err +} diff --git a/tlv/varint_test.go b/tlv/varint_test.go new file mode 100644 index 00000000000..75cc1085f74 --- /dev/null +++ b/tlv/varint_test.go @@ -0,0 +1,217 @@ +package tlv_test + +import ( + "bytes" + "io" + "math" + "testing" + + "github.com/lightningnetwork/lnd/tlv" +) + +type varIntTest struct { + Name string + Value uint64 + Bytes []byte + ExpErr error +} + +var writeVarIntTests = []varIntTest{ + { + Name: "zero", + Value: 0x00, + Bytes: []byte{0x00}, + }, + { + Name: "one byte high", + Value: 0xfc, + Bytes: []byte{0xfc}, + }, + { + Name: "two byte low", + Value: 0xfd, + Bytes: []byte{0xfd, 0x00, 0xfd}, + }, + { + Name: "two byte high", + Value: 0xffff, + Bytes: []byte{0xfd, 0xff, 0xff}, + }, + { + Name: "four byte low", + Value: 0x10000, + Bytes: []byte{0xfe, 0x00, 0x01, 0x00, 0x00}, + }, + { + Name: "four byte high", + Value: 0xffffffff, + Bytes: []byte{0xfe, 0xff, 0xff, 0xff, 0xff}, + }, + { + Name: "eight byte low", + Value: 0x100000000, + Bytes: []byte{0xff, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00}, + }, + { + Name: "eight byte high", + Value: math.MaxUint64, + Bytes: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, + }, +} + +// TestWriteVarInt asserts the behavior of tlv.WriteVarInt under various +// positive and negative test cases. +func TestWriteVarInt(t *testing.T) { + for _, test := range writeVarIntTests { + t.Run(test.Name, func(t *testing.T) { + testWriteVarInt(t, test) + }) + } +} + +func testWriteVarInt(t *testing.T, test varIntTest) { + var ( + w bytes.Buffer + buf [8]byte + ) + err := tlv.WriteVarInt(&w, test.Value, &buf) + if err != nil { + t.Fatalf("unable to encode %d as varint: %v", + test.Value, err) + } + + if bytes.Compare(w.Bytes(), test.Bytes) != 0 { + t.Fatalf("expected bytes: %v, got %v", + test.Bytes, w.Bytes()) + } +} + +var readVarIntTests = []varIntTest{ + { + Name: "zero", + Value: 0x00, + Bytes: []byte{0x00}, + }, + { + Name: "one byte high", + Value: 0xfc, + Bytes: []byte{0xfc}, + }, + { + Name: "two byte low", + Value: 0xfd, + Bytes: []byte{0xfd, 0x00, 0xfd}, + }, + { + Name: "two byte high", + Value: 0xffff, + Bytes: []byte{0xfd, 0xff, 0xff}, + }, + { + Name: "four byte low", + Value: 0x10000, + Bytes: []byte{0xfe, 0x00, 0x01, 0x00, 0x00}, + }, + { + Name: "four byte high", + Value: 0xffffffff, + Bytes: []byte{0xfe, 0xff, 0xff, 0xff, 0xff}, + }, + { + Name: "eight byte low", + Value: 0x100000000, + Bytes: []byte{0xff, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00}, + }, + { + Name: "eight byte high", + Value: math.MaxUint64, + Bytes: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, + }, + { + Name: "two byte not canonical", + Bytes: []byte{0xfd, 0x00, 0xfc}, + ExpErr: tlv.ErrVarIntNotCanonical, + }, + { + Name: "four byte not canonical", + Bytes: []byte{0xfe, 0x00, 0x00, 0xff, 0xff}, + ExpErr: tlv.ErrVarIntNotCanonical, + }, + { + Name: "eight byte not canonical", + Bytes: []byte{0xff, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff}, + ExpErr: tlv.ErrVarIntNotCanonical, + }, + { + Name: "two byte short read", + Bytes: []byte{0xfd, 0x00}, + ExpErr: io.ErrUnexpectedEOF, + }, + { + Name: "four byte short read", + Bytes: []byte{0xfe, 0xff, 0xff}, + ExpErr: io.ErrUnexpectedEOF, + }, + { + Name: "eight byte short read", + Bytes: []byte{0xff, 0xff, 0xff, 0xff, 0xff}, + ExpErr: io.ErrUnexpectedEOF, + }, + { + Name: "one byte no read", + Bytes: []byte{}, + ExpErr: io.EOF, + }, + // The following cases are the reason for needing to make a custom + // version of the varint for the tlv package. For the varint encodings + // in btcd's wire package these would return io.EOF, since it is + // actually a composite of two calls to io.ReadFull. In TLV, we need to + // be able to distinguish whether no bytes were read at all from no + // Bytes being read on the second read as the latter is not a proper TLV + // stream. We handle this by returning io.ErrUnexpectedEOF if we + // encounter io.EOF on any of these secondary reads for larger values. + { + Name: "two byte no read", + Bytes: []byte{0xfd}, + ExpErr: io.ErrUnexpectedEOF, + }, + { + Name: "four byte no read", + Bytes: []byte{0xfe}, + ExpErr: io.ErrUnexpectedEOF, + }, + { + Name: "eight byte no read", + Bytes: []byte{0xff}, + ExpErr: io.ErrUnexpectedEOF, + }, +} + +// TestReadVarInt asserts the behavior of tlv.ReadVarInt under various positive +// and negative test cases. +func TestReadVarInt(t *testing.T) { + for _, test := range readVarIntTests { + t.Run(test.Name, func(t *testing.T) { + testReadVarInt(t, test) + }) + } +} + +func testReadVarInt(t *testing.T, test varIntTest) { + var buf [8]byte + r := bytes.NewReader(test.Bytes) + val, err := tlv.ReadVarInt(r, &buf) + if err != nil && err != test.ExpErr { + t.Fatalf("expected decoding error: %v, got: %v", + test.ExpErr, err) + } + + // If we expected a decoding error, there's no point checking the value. + if test.ExpErr != nil { + return + } + + if val != test.Value { + t.Fatalf("expected value: %d, got %d", test.Value, val) + } +}