diff --git a/.gitignore b/.gitignore index f23749bd118..371b57b6815 100644 --- a/.gitignore +++ b/.gitignore @@ -36,6 +36,8 @@ lntest/itest/output*.log lntest/itest/pprof*.log lntest/itest/.backendlogs lntest/itest/.minerlogs +lntest/itest/lnd-itest +lntest/itest/.logs-* cmd/cmd *.key diff --git a/.travis.yml b/.travis.yml index feed17cc11e..51cf9402c17 100644 --- a/.travis.yml +++ b/.travis.yml @@ -50,32 +50,30 @@ jobs: - stage: Integration Test name: Btcd Integration script: - - make itest + - make itest-parallel - name: Bitcoind Integration (txindex enabled) script: - bash ./scripts/install_bitcoind.sh - - make itest backend=bitcoind + - make itest-parallel backend=bitcoind - name: Bitcoind Integration (txindex disabled) script: - bash ./scripts/install_bitcoind.sh - - make itest backend="bitcoind notxindex" + - make itest-parallel backend="bitcoind notxindex" - name: Neutrino Integration script: - - make itest backend=neutrino + - make itest-parallel backend=neutrino - name: Btcd Integration ARM script: - - GOARM=7 GOARCH=arm GOOS=linux CGO_ENABLED=0 make btcd build-itest - - file lnd-itest - - GOARM=7 GOARCH=arm GOOS=linux CGO_ENABLED=0 make itest-only + - GOARM=7 GOARCH=arm GOOS=linux make itest-parallel arch: arm64 - name: Btcd Integration Windows script: - - make itest-windows + - make itest-parallel-windows os: windows before_install: - choco upgrade --no-progress -y make netcat curl findutils @@ -85,7 +83,8 @@ jobs: case $TRAVIS_OS_NAME in windows) echo "Uploading to termbin.com..." - for f in ./lntest/itest/*.log; do cat $f | nc termbin.com 9999 | xargs -r0 printf "$f"' uploaded to %s'; done + LOG_FILES=$(find ./lntest/itest -name '*.log') + for f in $LOG_FILES; do echo -n $f; cat $f | nc termbin.com 9999 | xargs -r0 printf ' uploaded to %s'; done ;; esac @@ -97,8 +96,8 @@ after_failure: ;; *) - LOG_FILES=./lntest/itest/*.log - echo "Uploading to termbin.com..." && find $LOG_FILES | xargs -I{} sh -c "cat {} | nc termbin.com 9999 | xargs -r0 printf '{} uploaded to %s'" + LOG_FILES=$(find ./lntest/itest -name '*.log') + echo "Uploading to termbin.com..." && for f in $LOG_FILES; do echo -n $f; cat $f | nc termbin.com 9999 | xargs -r0 printf ' uploaded to %s'; done echo "Uploading to file.io..." && tar -zcvO $LOG_FILES | curl -s -F 'file=@-;filename=logs.tar.gz' https://file.io | xargs -r0 printf 'logs.tar.gz uploaded to %s\n' ;; esac diff --git a/Makefile b/Makefile index 5f55f0bc75d..ec0e322751a 100644 --- a/Makefile +++ b/Makefile @@ -175,6 +175,27 @@ itest-only: itest: btcd build-itest itest-only +itest-parallel: btcd + @$(call print, "Building lnd binary") + CGO_ENABLED=0 $(GOBUILD) -tags="$(ITEST_TAGS)" -o lntest/itest/lnd-itest $(ITEST_LDFLAGS) $(PKG)/cmd/lnd + + @$(call print, "Building itest binary for $(backend) backend") + CGO_ENABLED=0 $(GOTEST) -v ./lntest/itest -tags="$(DEV_TAGS) $(RPC_TAGS) rpctest $(backend)" -logoutput -goroutinedump -c -o lntest/itest/itest.test + + @$(call print, "Running tests") + rm -rf lntest/itest/*.log lntest/itest/.logs-* + echo -n "$$(seq 0 $$(expr $(NUM_ITEST_TRANCHES) - 1))" | xargs -P $(NUM_ITEST_TRANCHES) -n 1 -I {} scripts/itest_part.sh {} $(NUM_ITEST_TRANCHES) $(TEST_FLAGS) + +itest-parallel-windows: btcd + @$(call print, "Building lnd binary") + CGO_ENABLED=0 $(GOBUILD) -tags="$(ITEST_TAGS)" -o lntest/itest/lnd-itest.exe $(ITEST_LDFLAGS) $(PKG)/cmd/lnd + + @$(call print, "Building itest binary for $(backend) backend") + CGO_ENABLED=0 $(GOTEST) -v ./lntest/itest -tags="$(DEV_TAGS) $(RPC_TAGS) rpctest $(backend)" -logoutput -goroutinedump -c -o lntest/itest/itest.test.exe + + @$(call print, "Running tests") + EXEC_SUFFIX=".exe" echo -n "$$(seq 0 $$(expr $(NUM_ITEST_TRANCHES) - 1))" | xargs -P $(NUM_ITEST_TRANCHES) -n 1 -I {} scripts/itest_part.sh {} $(NUM_ITEST_TRANCHES) $(TEST_FLAGS) + itest-windows: btcd build-itest-windows itest-only unit: btcd diff --git a/chanbackup/multi.go b/chanbackup/multi.go index e90bd613e41..78e6197f03d 100644 --- a/chanbackup/multi.go +++ b/chanbackup/multi.go @@ -63,7 +63,9 @@ func (m Multi) PackToWriter(w io.Writer, keyRing keychain.KeyRing) error { var multiBackupBuffer bytes.Buffer // First, we'll write out the version of this multi channel baackup. - err := lnwire.WriteElements(&multiBackupBuffer, byte(m.Version)) + err := lnwire.WriteElements( + &multiBackupBuffer, lnwire.ProtocolVersionTLV, byte(m.Version), + ) if err != nil { return err } @@ -111,7 +113,9 @@ func (m *Multi) UnpackFromReader(r io.Reader, keyRing keychain.KeyRing) error { // First, we'll need to read the version of this multi-back up so we // can know how to unpack each of the individual SCB's. var multiVersion byte - err = lnwire.ReadElements(backupReader, &multiVersion) + err = lnwire.ReadElements( + backupReader, lnwire.ProtocolVersionTLV, &multiVersion, + ) if err != nil { return err } @@ -127,7 +131,9 @@ func (m *Multi) UnpackFromReader(r io.Reader, keyRing keychain.KeyRing) error { // backup is the same size, so we can continue until we've // parsed out everything. var numBackups uint32 - err = lnwire.ReadElements(backupReader, &numBackups) + err = lnwire.ReadElements( + backupReader, lnwire.ProtocolVersionTLV, &numBackups, + ) if err != nil { return err } diff --git a/chanbackup/single.go b/chanbackup/single.go index 490657b90dd..5d4c9a20eef 100644 --- a/chanbackup/single.go +++ b/chanbackup/single.go @@ -207,6 +207,7 @@ func (s *Single) Serialize(w io.Writer) error { var singleBytes bytes.Buffer if err := lnwire.WriteElements( &singleBytes, + lnwire.ProtocolVersionTLV, s.IsInitiator, s.ChainHash[:], s.FundingOutpoint, @@ -249,6 +250,7 @@ func (s *Single) Serialize(w io.Writer) error { return lnwire.WriteElements( w, + lnwire.ProtocolVersionTLV, byte(s.Version), uint16(len(singleBytes.Bytes())), singleBytes.Bytes(), @@ -290,12 +292,14 @@ func readLocalKeyDesc(r io.Reader) (keychain.KeyDescriptor, error) { var keyDesc keychain.KeyDescriptor var keyFam uint32 - if err := lnwire.ReadElements(r, &keyFam); err != nil { + err := lnwire.ReadElements(r, lnwire.ProtocolVersionTLV, &keyFam) + if err != nil { return keyDesc, err } keyDesc.Family = keychain.KeyFamily(keyFam) - if err := lnwire.ReadElements(r, &keyDesc.Index); err != nil { + err = lnwire.ReadElements(r, lnwire.ProtocolVersionTLV, &keyDesc.Index) + if err != nil { return keyDesc, err } @@ -333,7 +337,7 @@ func (s *Single) Deserialize(r io.Reader) error { // First, we'll need to read the version of this single-back up so we // can know how to unpack each of the SCB. var version byte - err := lnwire.ReadElements(r, &version) + err := lnwire.ReadElements(r, lnwire.ProtocolVersionTLV, &version) if err != nil { return err } @@ -350,19 +354,23 @@ func (s *Single) Deserialize(r io.Reader) error { } var length uint16 - if err := lnwire.ReadElements(r, &length); err != nil { + err = lnwire.ReadElements(r, lnwire.ProtocolVersionTLV, &length) + if err != nil { return err } err = lnwire.ReadElements( - r, &s.IsInitiator, s.ChainHash[:], &s.FundingOutpoint, - &s.ShortChannelID, &s.RemoteNodePub, &s.Addresses, &s.Capacity, + r, lnwire.ProtocolVersionTLV, &s.IsInitiator, s.ChainHash[:], + &s.FundingOutpoint, &s.ShortChannelID, &s.RemoteNodePub, + &s.Addresses, &s.Capacity, ) if err != nil { return err } - err = lnwire.ReadElements(r, &s.LocalChanCfg.CsvDelay) + err = lnwire.ReadElements( + r, lnwire.ProtocolVersionTLV, &s.LocalChanCfg.CsvDelay, + ) if err != nil { return err } @@ -387,7 +395,9 @@ func (s *Single) Deserialize(r io.Reader) error { return err } - err = lnwire.ReadElements(r, &s.RemoteChanCfg.CsvDelay) + err = lnwire.ReadElements( + r, lnwire.ProtocolVersionTLV, &s.RemoteChanCfg.CsvDelay, + ) if err != nil { return err } @@ -417,7 +427,8 @@ func (s *Single) Deserialize(r io.Reader) error { shaChainPub [33]byte zeroPub [33]byte ) - if err := lnwire.ReadElements(r, shaChainPub[:]); err != nil { + err = lnwire.ReadElements(r, lnwire.ProtocolVersionTLV, shaChainPub[:]) + if err != nil { return err } @@ -433,12 +444,17 @@ func (s *Single) Deserialize(r io.Reader) error { } var shaKeyFam uint32 - if err := lnwire.ReadElements(r, &shaKeyFam); err != nil { + err = lnwire.ReadElements( + r, lnwire.ProtocolVersionTLV, &shaKeyFam, + ) + if err != nil { return err } s.ShaChainRootDesc.KeyLocator.Family = keychain.KeyFamily(shaKeyFam) - return lnwire.ReadElements(r, &s.ShaChainRootDesc.KeyLocator.Index) + return lnwire.ReadElements( + r, lnwire.ProtocolVersionTLV, &s.ShaChainRootDesc.KeyLocator.Index, + ) } // UnpackFromReader is similar to Deserialize method, but it expects the passed diff --git a/channeldb/channel.go b/channeldb/channel.go index 35a0700d472..f34cd43e847 100644 --- a/channeldb/channel.go +++ b/channeldb/channel.go @@ -1871,7 +1871,8 @@ func serializeCommitDiff(w io.Writer, diff *CommitDiff) error { return err } - if err := diff.CommitSig.Encode(w, 0); err != nil { + err := diff.CommitSig.Encode(w, lnwire.ProtocolVersionTLV) + if err != nil { return err } @@ -1918,7 +1919,7 @@ func deserializeCommitDiff(r io.Reader) (*CommitDiff, error) { } d.CommitSig = &lnwire.CommitSig{} - if err := d.CommitSig.Decode(r, 0); err != nil { + if err := d.CommitSig.Decode(r, lnwire.ProtocolVersionTLV); err != nil { return nil, err } diff --git a/channeldb/channel_test.go b/channeldb/channel_test.go index 656a885bf55..6f319003dc7 100644 --- a/channeldb/channel_test.go +++ b/channeldb/channel_test.go @@ -639,7 +639,8 @@ func TestChannelStateTransition(t *testing.T) { { LogIndex: 2, UpdateMsg: &lnwire.UpdateAddHTLC{ - ChanID: lnwire.ChannelID{1, 2, 3}, + ChanID: lnwire.ChannelID{1, 2, 3}, + ExtraData: make([]byte, 0), }, }, } @@ -660,7 +661,9 @@ func TestChannelStateTransition(t *testing.T) { if !reflect.DeepEqual( dbUnsignedAckedUpdates[0], unsignedAckedUpdates[0], ) { - t.Fatalf("unexpected update") + t.Fatalf("unexpected update: expected %v, got %v", + spew.Sdump(unsignedAckedUpdates[0]), + spew.Sdump(dbUnsignedAckedUpdates)) } // The balances, new update, the HTLCs and the changes to the fake @@ -702,22 +705,25 @@ func TestChannelStateTransition(t *testing.T) { wireSig, wireSig, }, + ExtraData: make([]byte, 0), }, LogUpdates: []LogUpdate{ { LogIndex: 1, UpdateMsg: &lnwire.UpdateAddHTLC{ - ID: 1, - Amount: lnwire.NewMSatFromSatoshis(100), - Expiry: 25, + ID: 1, + Amount: lnwire.NewMSatFromSatoshis(100), + Expiry: 25, + ExtraData: make([]byte, 0), }, }, { LogIndex: 2, UpdateMsg: &lnwire.UpdateAddHTLC{ - ID: 2, - Amount: lnwire.NewMSatFromSatoshis(200), - Expiry: 50, + ID: 2, + Amount: lnwire.NewMSatFromSatoshis(200), + Expiry: 50, + ExtraData: make([]byte, 0), }, }, }, diff --git a/channeldb/codec.go b/channeldb/codec.go index f6903175f8d..deb47017886 100644 --- a/channeldb/codec.go +++ b/channeldb/codec.go @@ -178,7 +178,8 @@ func WriteElement(w io.Writer, element interface{}) error { } case lnwire.Message: - if _, err := lnwire.WriteMessage(w, e, 0); err != nil { + _, err := lnwire.WriteMessage(w, e, lnwire.ProtocolVersionTLV) + if err != nil { return err } @@ -394,7 +395,7 @@ func ReadElement(r io.Reader, element interface{}) error { *e = bytes case *lnwire.Message: - msg, err := lnwire.ReadMessage(r, 0) + msg, err := lnwire.ReadMessage(r, lnwire.ProtocolVersionTLV) if err != nil { return err } diff --git a/channeldb/migration_01_to_11/codec.go b/channeldb/migration_01_to_11/codec.go index 1727c8c997d..cf590cb2890 100644 --- a/channeldb/migration_01_to_11/codec.go +++ b/channeldb/migration_01_to_11/codec.go @@ -172,7 +172,8 @@ func WriteElement(w io.Writer, element interface{}) error { } case lnwire.Message: - if _, err := lnwire.WriteMessage(w, e, 0); err != nil { + _, err := lnwire.WriteMessage(w, e, lnwire.ProtocolVersionLegacy) + if err != nil { return err } @@ -383,7 +384,7 @@ func ReadElement(r io.Reader, element interface{}) error { *e = bytes case *lnwire.Message: - msg, err := lnwire.ReadMessage(r, 0) + msg, err := lnwire.ReadMessage(r, lnwire.ProtocolVersionLegacy) if err != nil { return err } diff --git a/channeldb/migration_01_to_11/migrations.go b/channeldb/migration_01_to_11/migrations.go index 35be510e996..d46711b0791 100644 --- a/channeldb/migration_01_to_11/migrations.go +++ b/channeldb/migration_01_to_11/migrations.go @@ -724,7 +724,8 @@ func MigrateGossipMessageStoreKeys(tx kvdb.RwTx) error { // Serialize the message with its wire encoding. var b bytes.Buffer - if _, err := lnwire.WriteMessage(&b, msg, 0); err != nil { + _, err := lnwire.WriteMessage(&b, msg, lnwire.ProtocolVersionTLV) + if err != nil { return err } diff --git a/channeldb/migration_01_to_11/migrations_test.go b/channeldb/migration_01_to_11/migrations_test.go index 6cd855e85dd..3692af93232 100644 --- a/channeldb/migration_01_to_11/migrations_test.go +++ b/channeldb/migration_01_to_11/migrations_test.go @@ -464,7 +464,10 @@ func TestMigrateGossipMessageStoreKeys(t *testing.T) { // Construct the message which we'll use to test the migration, along // with its old and new key formats. shortChanID := lnwire.ShortChannelID{BlockHeight: 10} - msg := &lnwire.AnnounceSignatures{ShortChannelID: shortChanID} + msg := &lnwire.AnnounceSignatures{ + ShortChannelID: shortChanID, + ExtraOpaqueData: make([]byte, 0), + } var oldMsgKey [33 + 8]byte copy(oldMsgKey[:33], pubKey.SerializeCompressed()) @@ -526,7 +529,9 @@ func TestMigrateGossipMessageStoreKeys(t *testing.T) { t.Fatal(err) } - gotMsg, err := lnwire.ReadMessage(bytes.NewReader(rawMsg), 0) + gotMsg, err := lnwire.ReadMessage( + bytes.NewReader(rawMsg), lnwire.ProtocolVersionLegacy, + ) if err != nil { t.Fatalf("unable to deserialize raw message: %v", err) } diff --git a/channeldb/waitingproof.go b/channeldb/waitingproof.go index 2ea706c8405..e69afc05679 100644 --- a/channeldb/waitingproof.go +++ b/channeldb/waitingproof.go @@ -227,7 +227,8 @@ func (p *WaitingProof) Encode(w io.Writer) error { return err } - if err := p.AnnounceSignatures.Encode(w, 0); err != nil { + err := p.AnnounceSignatures.Encode(w, lnwire.ProtocolVersionTLV) + if err != nil { return err } @@ -242,7 +243,7 @@ func (p *WaitingProof) Decode(r io.Reader) error { } msg := &lnwire.AnnounceSignatures{} - if err := msg.Decode(r, 0); err != nil { + if err := msg.Decode(r, lnwire.ProtocolVersionTLV); err != nil { return err } diff --git a/channeldb/waitingproof_test.go b/channeldb/waitingproof_test.go index 12679b69f4a..24207da813e 100644 --- a/channeldb/waitingproof_test.go +++ b/channeldb/waitingproof_test.go @@ -5,6 +5,7 @@ import ( "reflect" + "github.com/davecgh/go-spew/spew" "github.com/go-errors/errors" "github.com/lightningnetwork/lnd/lnwire" ) @@ -23,6 +24,7 @@ func TestWaitingProofStore(t *testing.T) { proof1 := NewWaitingProof(true, &lnwire.AnnounceSignatures{ NodeSignature: wireSig, BitcoinSignature: wireSig, + ExtraOpaqueData: make([]byte, 0), }) store, err := NewWaitingProofStore(db) @@ -40,7 +42,8 @@ func TestWaitingProofStore(t *testing.T) { t.Fatalf("unable retrieve proof from storage: %v", err) } if !reflect.DeepEqual(proof1, proof2) { - t.Fatal("wrong proof retrieved") + t.Fatalf("wrong proof retrieved: expected %v, got %v", + spew.Sdump(proof1), spew.Sdump(proof2)) } if _, err := store.Get(proof1.OppositeKey()); err != ErrWaitingProofNotFound { diff --git a/discovery/message_store.go b/discovery/message_store.go index f86ede20860..8399dd49f61 100644 --- a/discovery/message_store.go +++ b/discovery/message_store.go @@ -120,7 +120,8 @@ func (s *MessageStore) AddMessage(msg lnwire.Message, peerPubKey [33]byte) error // Serialize the message with its wire encoding. var b bytes.Buffer - if _, err := lnwire.WriteMessage(&b, msg, 0); err != nil { + _, err = lnwire.WriteMessage(&b, msg, lnwire.ProtocolVersionTLV) + if err != nil { return err } @@ -163,7 +164,9 @@ func (s *MessageStore) DeleteMessage(msg lnwire.Message, return nil } - dbMsg, err := lnwire.ReadMessage(bytes.NewReader(v), 0) + dbMsg, err := lnwire.ReadMessage( + bytes.NewReader(v), lnwire.ProtocolVersionTLV, + ) if err != nil { return err } @@ -182,7 +185,9 @@ func (s *MessageStore) DeleteMessage(msg lnwire.Message, // readMessage reads a message from its serialized form and ensures its // supported by the current version of the message store. func readMessage(msgBytes []byte) (lnwire.Message, error) { - msg, err := lnwire.ReadMessage(bytes.NewReader(msgBytes), 0) + msg, err := lnwire.ReadMessage( + bytes.NewReader(msgBytes), lnwire.ProtocolVersionTLV, + ) if err != nil { return nil, err } diff --git a/discovery/message_store_test.go b/discovery/message_store_test.go index fc7ba3360e6..d62dd16050c 100644 --- a/discovery/message_store_test.go +++ b/discovery/message_store_test.go @@ -64,13 +64,15 @@ func randCompressedPubKey(t *testing.T) [33]byte { func randAnnounceSignatures() *lnwire.AnnounceSignatures { return &lnwire.AnnounceSignatures{ - ShortChannelID: lnwire.NewShortChanIDFromInt(rand.Uint64()), + ShortChannelID: lnwire.NewShortChanIDFromInt(rand.Uint64()), + ExtraOpaqueData: make([]byte, 0), } } func randChannelUpdate() *lnwire.ChannelUpdate { return &lnwire.ChannelUpdate{ - ShortChannelID: lnwire.NewShortChanIDFromInt(rand.Uint64()), + ShortChannelID: lnwire.NewShortChanIDFromInt(rand.Uint64()), + ExtraOpaqueData: make([]byte, 0), } } diff --git a/discovery/sync_manager_test.go b/discovery/sync_manager_test.go index c7a228f8cf6..b8b309249b9 100644 --- a/discovery/sync_manager_test.go +++ b/discovery/sync_manager_test.go @@ -536,8 +536,9 @@ func assertTransitionToChansSynced(t *testing.T, s *GossipSyncer, peer *mockPeer assertMsgSent(t, peer, query) s.ProcessQueryMsg(&lnwire.ReplyChannelRange{ - QueryChannelRange: *query, - Complete: 1, + FirstBlockHeight: 0, + NumBlocks: math.MaxUint32, + Complete: 1, }, nil) chanSeries := s.cfg.channelSeries.(*mockChannelGraphTimeSeries) diff --git a/discovery/syncer.go b/discovery/syncer.go index 10b6d4205da..8417fda5049 100644 --- a/discovery/syncer.go +++ b/discovery/syncer.go @@ -685,7 +685,9 @@ func (g *GossipSyncer) synchronizeChanIDs() (bool, error) { func isLegacyReplyChannelRange(query *lnwire.QueryChannelRange, reply *lnwire.ReplyChannelRange) bool { - return reply.QueryChannelRange == *query + return (reply.ChainHash == query.ChainHash && + reply.FirstBlockHeight == query.FirstBlockHeight && + reply.NumBlocks == query.NumBlocks) } // processChanRangeReply is called each time the GossipSyncer receives a new @@ -705,7 +707,7 @@ func (g *GossipSyncer) processChanRangeReply(msg *lnwire.ReplyChannelRange) erro // The last block should also be. We don't need to check the // intermediate ones because they should already be in sorted // order. - replyLastHeight := msg.QueryChannelRange.LastBlockHeight() + replyLastHeight := msg.LastBlockHeight() queryLastHeight := g.curQueryRangeMsg.LastBlockHeight() if replyLastHeight > queryLastHeight { return fmt.Errorf("reply includes channels for height "+ @@ -754,7 +756,7 @@ func (g *GossipSyncer) processChanRangeReply(msg *lnwire.ReplyChannelRange) erro // Otherwise, we'll look at the reply's height range. default: - replyLastHeight := msg.QueryChannelRange.LastBlockHeight() + replyLastHeight := msg.LastBlockHeight() queryLastHeight := g.curQueryRangeMsg.LastBlockHeight() // TODO(wilmer): This might require some padding if the remote @@ -904,10 +906,12 @@ func (g *GossipSyncer) replyChanRangeQuery(query *lnwire.QueryChannelRange) erro g.cfg.chainHash) return g.cfg.sendToPeerSync(&lnwire.ReplyChannelRange{ - QueryChannelRange: *query, - Complete: 0, - EncodingType: g.cfg.encodingType, - ShortChanIDs: nil, + ChainHash: query.ChainHash, + FirstBlockHeight: query.FirstBlockHeight, + NumBlocks: query.NumBlocks, + Complete: 0, + EncodingType: g.cfg.encodingType, + ShortChanIDs: nil, }) } @@ -1001,14 +1005,12 @@ func (g *GossipSyncer) replyChanRangeQuery(query *lnwire.QueryChannelRange) erro // With our chunk assembled, we'll now send to the remote peer // the current chunk. replyChunk := lnwire.ReplyChannelRange{ - QueryChannelRange: lnwire.QueryChannelRange{ - ChainHash: query.ChainHash, - NumBlocks: numBlocksInResp, - FirstBlockHeight: firstBlockHeight, - }, - Complete: 0, - EncodingType: g.cfg.encodingType, - ShortChanIDs: channelChunk, + ChainHash: query.ChainHash, + NumBlocks: numBlocksInResp, + FirstBlockHeight: firstBlockHeight, + Complete: 0, + EncodingType: g.cfg.encodingType, + ShortChanIDs: channelChunk, } if isFinalChunk { replyChunk.Complete = 1 diff --git a/discovery/syncer_test.go b/discovery/syncer_test.go index 8e99fa49efa..b0d649de896 100644 --- a/discovery/syncer_test.go +++ b/discovery/syncer_test.go @@ -576,10 +576,9 @@ func TestGossipSyncerQueryChannelRangeWrongChainHash(t *testing.T) { t.Fatalf("expected lnwire.ReplyChannelRange, got %T", msg) } - if msg.QueryChannelRange != *query { - t.Fatalf("wrong query channel range in reply: "+ - "expected: %v\ngot: %v", spew.Sdump(*query), - spew.Sdump(msg.QueryChannelRange)) + if msg.ChainHash != query.ChainHash { + t.Fatalf("wrong chain hash: expected %v got %v", + query.ChainHash, msg.ChainHash) } if msg.Complete != 0 { t.Fatalf("expected complete set to 0, got %v", @@ -1192,34 +1191,13 @@ func testGossipSyncerProcessChanRangeReply(t *testing.T, legacy bool) { t.Fatalf("unable to generate channel range query: %v", err) } - var replyQueries []*lnwire.QueryChannelRange - if legacy { - // Each reply query is the same as the original query in the - // legacy mode. - replyQueries = []*lnwire.QueryChannelRange{query, query, query} - } else { - // When interpreting block ranges, the first reply should start - // from our requested first block, and the last should end at - // our requested last block. - replyQueries = []*lnwire.QueryChannelRange{ - { - FirstBlockHeight: 0, - NumBlocks: 11, - }, - { - FirstBlockHeight: 11, - NumBlocks: 1, - }, - { - FirstBlockHeight: 12, - NumBlocks: query.NumBlocks - 12, - }, - } - } - + // When interpreting block ranges, the first reply should start from + // our requested first block, and the last should end at our requested + // last block. replies := []*lnwire.ReplyChannelRange{ { - QueryChannelRange: *replyQueries[0], + FirstBlockHeight: 0, + NumBlocks: 11, ShortChanIDs: []lnwire.ShortChannelID{ { BlockHeight: 10, @@ -1227,7 +1205,8 @@ func testGossipSyncerProcessChanRangeReply(t *testing.T, legacy bool) { }, }, { - QueryChannelRange: *replyQueries[1], + FirstBlockHeight: 11, + NumBlocks: 1, ShortChanIDs: []lnwire.ShortChannelID{ { BlockHeight: 11, @@ -1235,8 +1214,9 @@ func testGossipSyncerProcessChanRangeReply(t *testing.T, legacy bool) { }, }, { - QueryChannelRange: *replyQueries[2], - Complete: 1, + FirstBlockHeight: 12, + NumBlocks: query.NumBlocks - 12, + Complete: 1, ShortChanIDs: []lnwire.ShortChannelID{ { BlockHeight: 12, @@ -1245,6 +1225,19 @@ func testGossipSyncerProcessChanRangeReply(t *testing.T, legacy bool) { }, } + // Each reply query is the same as the original query in the legacy + // mode. + if legacy { + replies[0].FirstBlockHeight = query.FirstBlockHeight + replies[0].NumBlocks = query.NumBlocks + + replies[1].FirstBlockHeight = query.FirstBlockHeight + replies[1].NumBlocks = query.NumBlocks + + replies[2].FirstBlockHeight = query.FirstBlockHeight + replies[2].NumBlocks = query.NumBlocks + } + // We'll begin by sending the syncer a set of non-complete channel // range replies. if err := syncer.processChanRangeReply(replies[0]); err != nil { diff --git a/fundingmanager.go b/fundingmanager.go index 43e3c29d3db..49e6edc5c4a 100644 --- a/fundingmanager.go +++ b/fundingmanager.go @@ -1437,7 +1437,9 @@ func (f *fundingManager) handleFundingOpen(peer lnpeer.Peer, PubKey: copyPubKey(msg.HtlcPoint), }, }, - UpfrontShutdown: msg.UpfrontShutdownScript, + UpfrontShutdown: lnwire.DeliveryAddress( + msg.UpfrontShutdownScript, + ), } err = reservation.ProcessSingleContribution(remoteContribution) if err != nil { @@ -1455,21 +1457,23 @@ func (f *fundingManager) handleFundingOpen(peer lnpeer.Peer, // contribution in the next message of the workflow. ourContribution := reservation.OurContribution() fundingAccept := lnwire.AcceptChannel{ - PendingChannelID: msg.PendingChannelID, - DustLimit: ourContribution.DustLimit, - MaxValueInFlight: remoteMaxValue, - ChannelReserve: chanReserve, - MinAcceptDepth: uint32(numConfsReq), - HtlcMinimum: minHtlc, - CsvDelay: remoteCsvDelay, - MaxAcceptedHTLCs: maxHtlcs, - FundingKey: ourContribution.MultiSigKey.PubKey, - RevocationPoint: ourContribution.RevocationBasePoint.PubKey, - PaymentPoint: ourContribution.PaymentBasePoint.PubKey, - DelayedPaymentPoint: ourContribution.DelayBasePoint.PubKey, - HtlcPoint: ourContribution.HtlcBasePoint.PubKey, - FirstCommitmentPoint: ourContribution.FirstCommitmentPoint, - UpfrontShutdownScript: ourContribution.UpfrontShutdown, + PendingChannelID: msg.PendingChannelID, + DustLimit: ourContribution.DustLimit, + MaxValueInFlight: remoteMaxValue, + ChannelReserve: chanReserve, + MinAcceptDepth: uint32(numConfsReq), + HtlcMinimum: minHtlc, + CsvDelay: remoteCsvDelay, + MaxAcceptedHTLCs: maxHtlcs, + FundingKey: ourContribution.MultiSigKey.PubKey, + RevocationPoint: ourContribution.RevocationBasePoint.PubKey, + PaymentPoint: ourContribution.PaymentBasePoint.PubKey, + DelayedPaymentPoint: ourContribution.DelayBasePoint.PubKey, + HtlcPoint: ourContribution.HtlcBasePoint.PubKey, + FirstCommitmentPoint: ourContribution.FirstCommitmentPoint, + UpfrontShutdownScript: lnwire.TypedDeliveryAddress( + ourContribution.UpfrontShutdown, + ), } if err := peer.SendMessage(true, &fundingAccept); err != nil { @@ -1568,7 +1572,9 @@ func (f *fundingManager) handleFundingAccept(peer lnpeer.Peer, PubKey: copyPubKey(msg.HtlcPoint), }, }, - UpfrontShutdown: msg.UpfrontShutdownScript, + UpfrontShutdown: lnwire.DeliveryAddress( + msg.UpfrontShutdownScript, + ), } err = resCtx.reservation.ProcessContribution(remoteContribution) @@ -3255,7 +3261,7 @@ func (f *fundingManager) handleInitFundingMsg(msg *initFundingMsg) { DelayedPaymentPoint: ourContribution.DelayBasePoint.PubKey, FirstCommitmentPoint: ourContribution.FirstCommitmentPoint, ChannelFlags: channelFlags, - UpfrontShutdownScript: shutdown, + UpfrontShutdownScript: lnwire.TypedDeliveryAddress(shutdown), } if err := msg.peer.SendMessage(true, &fundingOpen); err != nil { e := fmt.Errorf("unable to send funding request message: %v", diff --git a/htlcswitch/payment_result.go b/htlcswitch/payment_result.go index e6a1e59f103..cef1d4e8434 100644 --- a/htlcswitch/payment_result.go +++ b/htlcswitch/payment_result.go @@ -76,7 +76,7 @@ func deserializeNetworkResult(r io.Reader) (*networkResult, error) { n := &networkResult{} - n.msg, err = lnwire.ReadMessage(r, 0) + n.msg, err = lnwire.ReadMessage(r, lnwire.ProtocolVersionTLV) if err != nil { return nil, err } diff --git a/htlcswitch/payment_result_test.go b/htlcswitch/payment_result_test.go index 04ff57d8f72..aa7cbc173ec 100644 --- a/htlcswitch/payment_result_test.go +++ b/htlcswitch/payment_result_test.go @@ -39,18 +39,21 @@ func TestNetworkResultSerialization(t *testing.T) { ChanID: chanID, ID: 2, PaymentPreimage: preimage, + ExtraData: make([]byte, 0), } fail := &lnwire.UpdateFailHTLC{ - ChanID: chanID, - ID: 1, - Reason: []byte{}, + ChanID: chanID, + ID: 1, + Reason: []byte{}, + ExtraData: make([]byte, 0), } fail2 := &lnwire.UpdateFailHTLC{ - ChanID: chanID, - ID: 1, - Reason: reason[:], + ChanID: chanID, + ID: 1, + Reason: reason[:], + ExtraData: make([]byte, 0), } testCases := []*networkResult{ diff --git a/lntest/bitcoind_common.go b/lntest/bitcoind_common.go index b59fdac85d5..f673400abc5 100644 --- a/lntest/bitcoind_common.go +++ b/lntest/bitcoind_common.go @@ -6,7 +6,6 @@ import ( "errors" "fmt" "io/ioutil" - "math/rand" "os" "os/exec" "path/filepath" @@ -16,8 +15,8 @@ import ( "github.com/btcsuite/btcd/rpcclient" ) -// logDir is the name of the temporary log directory. -const logDir = "./.backendlogs" +// logDirPattern is the pattern of the name of the temporary log directory. +const logDirPattern = "%s/.backendlogs" // BitcoindBackendConfig is an implementation of the BackendConfig interface // backed by a Bitcoind node. @@ -74,15 +73,16 @@ func (b BitcoindBackendConfig) Name() string { func newBackend(miner string, netParams *chaincfg.Params, extraArgs []string) ( *BitcoindBackendConfig, func() error, error) { + baseLogDir := fmt.Sprintf(logDirPattern, GetLogDir()) if netParams != &chaincfg.RegressionNetParams { return nil, nil, fmt.Errorf("only regtest supported") } - if err := os.MkdirAll(logDir, 0700); err != nil { + if err := os.MkdirAll(baseLogDir, 0700); err != nil { return nil, nil, err } - logFile, err := filepath.Abs(logDir + "/bitcoind.log") + logFile, err := filepath.Abs(baseLogDir + "/bitcoind.log") if err != nil { return nil, nil, err } @@ -93,10 +93,10 @@ func newBackend(miner string, netParams *chaincfg.Params, extraArgs []string) ( fmt.Errorf("unable to create temp directory: %v", err) } - zmqBlockPath := "ipc:///" + tempBitcoindDir + "/blocks.socket" - zmqTxPath := "ipc:///" + tempBitcoindDir + "/txs.socket" - rpcPort := rand.Int()%(65536-1024) + 1024 - p2pPort := rand.Int()%(65536-1024) + 1024 + zmqBlockAddr := fmt.Sprintf("tcp://127.0.0.1:%d", nextAvailablePort()) + zmqTxAddr := fmt.Sprintf("tcp://127.0.0.1:%d", nextAvailablePort()) + rpcPort := nextAvailablePort() + p2pPort := nextAvailablePort() cmdArgs := []string{ "-datadir=" + tempBitcoindDir, @@ -106,8 +106,8 @@ func newBackend(miner string, netParams *chaincfg.Params, extraArgs []string) ( "220110063096c221be9933c82d38e1", fmt.Sprintf("-rpcport=%d", rpcPort), fmt.Sprintf("-port=%d", p2pPort), - "-zmqpubrawblock=" + zmqBlockPath, - "-zmqpubrawtx=" + zmqTxPath, + "-zmqpubrawblock=" + zmqBlockAddr, + "-zmqpubrawtx=" + zmqTxAddr, "-debuglogfile=" + logFile, } cmdArgs = append(cmdArgs, extraArgs...) @@ -129,13 +129,16 @@ func newBackend(miner string, netParams *chaincfg.Params, extraArgs []string) ( var errStr string // After shutting down the chain backend, we'll make a copy of // the log file before deleting the temporary log dir. - err := CopyFile("./output_bitcoind_chainbackend.log", logFile) + logDestination := fmt.Sprintf( + "%s/output_bitcoind_chainbackend.log", GetLogDir(), + ) + err := CopyFile(logDestination, logFile) if err != nil { errStr += fmt.Sprintf("unable to copy file: %v\n", err) } - if err = os.RemoveAll(logDir); err != nil { + if err = os.RemoveAll(baseLogDir); err != nil { errStr += fmt.Sprintf( - "cannot remove dir %s: %v\n", logDir, err, + "cannot remove dir %s: %v\n", baseLogDir, err, ) } if err := os.RemoveAll(tempBitcoindDir); err != nil { @@ -178,8 +181,8 @@ func newBackend(miner string, netParams *chaincfg.Params, extraArgs []string) ( rpcHost: rpcHost, rpcUser: rpcUser, rpcPass: rpcPass, - zmqBlockPath: zmqBlockPath, - zmqTxPath: zmqTxPath, + zmqBlockPath: zmqBlockAddr, + zmqTxPath: zmqTxAddr, p2pPort: p2pPort, rpcClient: client, minerAddr: miner, diff --git a/lntest/btcd.go b/lntest/btcd.go index 11e19d3fd9f..e8b8cac43cd 100644 --- a/lntest/btcd.go +++ b/lntest/btcd.go @@ -14,8 +14,8 @@ import ( "github.com/btcsuite/btcd/rpcclient" ) -// logDir is the name of the temporary log directory. -const logDir = "./.backendlogs" +// logDirPattern is the pattern of the name of the temporary log directory. +const logDirPattern = "%s/.backendlogs" // temp is used to signal we want to establish a temporary connection using the // btcd Node API. @@ -75,12 +75,13 @@ func (b BtcdBackendConfig) Name() string { func NewBackend(miner string, netParams *chaincfg.Params) ( *BtcdBackendConfig, func() error, error) { + baseLogDir := fmt.Sprintf(logDirPattern, GetLogDir()) args := []string{ "--rejectnonstd", "--txindex", "--trickleinterval=100ms", "--debuglevel=debug", - "--logdir=" + logDir, + "--logdir=" + baseLogDir, "--nowinservice", // The miner will get banned and disconnected from the node if // its requested data are not found. We add a nobanning flag to @@ -110,14 +111,17 @@ func NewBackend(miner string, netParams *chaincfg.Params) ( // After shutting down the chain backend, we'll make a copy of // the log file before deleting the temporary log dir. - logFile := logDir + "/" + netParams.Name + "/btcd.log" - err := CopyFile("./output_btcd_chainbackend.log", logFile) + logFile := baseLogDir + "/" + netParams.Name + "/btcd.log" + logDestination := fmt.Sprintf( + "%s/output_btcd_chainbackend.log", GetLogDir(), + ) + err := CopyFile(logDestination, logFile) if err != nil { errStr += fmt.Sprintf("unable to copy file: %v\n", err) } - if err = os.RemoveAll(logDir); err != nil { + if err = os.RemoveAll(baseLogDir); err != nil { errStr += fmt.Sprintf( - "cannot remove dir %s: %v\n", logDir, err, + "cannot remove dir %s: %v\n", baseLogDir, err, ) } if errStr != "" { diff --git a/lntest/fee_service.go b/lntest/fee_service.go index 68e7d435a41..d71dae7d9f2 100644 --- a/lntest/fee_service.go +++ b/lntest/fee_service.go @@ -16,9 +16,6 @@ const ( // is returned. Requests for higher confirmation targets will fall back // to this. feeServiceTarget = 2 - - // feeServicePort is the tcp port on which the service runs. - feeServicePort = 16534 ) // feeService runs a web service that provides fee estimation information. @@ -40,16 +37,15 @@ type feeEstimates struct { // startFeeService spins up a go-routine to serve fee estimates. func startFeeService() *feeService { + port := nextAvailablePort() f := feeService{ - url: fmt.Sprintf( - "http://localhost:%v/fee-estimates.json", feeServicePort, - ), + url: fmt.Sprintf("http://localhost:%v/fee-estimates.json", port), } // Initialize default fee estimate. f.Fees = map[uint32]uint32{feeServiceTarget: 50000} - listenAddr := fmt.Sprintf(":%v", feeServicePort) + listenAddr := fmt.Sprintf(":%v", port) f.srv = &http.Server{ Addr: listenAddr, } diff --git a/lntest/itest/lnd_channel_backup_test.go b/lntest/itest/lnd_channel_backup_test.go index 5a8bf87dcf7..2a7e0eb6c28 100644 --- a/lntest/itest/lnd_channel_backup_test.go +++ b/lntest/itest/lnd_channel_backup_test.go @@ -1012,6 +1012,10 @@ func testChanRestoreScenario(t *harnessTest, net *lntest.NetworkHarness, require.Contains(t.t, err.Error(), "cannot close channel with state: ") require.Contains(t.t, err.Error(), "ChanStatusRestored") + // Increase the fee estimate so that the following force close tx will + // be cpfp'ed in case of anchor commitments. + net.SetFeeEstimate(30000) + // Now that we have ensured that the channels restored by the backup are // in the correct state even without the remote peer telling us so, // let's start up Carol again. diff --git a/lntest/itest/lnd_multi-hop_htlc_remote_chain_claim_test.go b/lntest/itest/lnd_multi-hop_htlc_remote_chain_claim_test.go index 1aeff01e2a9..72a3c63cfff 100644 --- a/lntest/itest/lnd_multi-hop_htlc_remote_chain_claim_test.go +++ b/lntest/itest/lnd_multi-hop_htlc_remote_chain_claim_test.go @@ -101,15 +101,17 @@ func testMultiHopHtlcRemoteChainClaim(net *lntest.NetworkHarness, t *harnessTest // bob will attempt to redeem his anchor commitment (if the channel // type is of that type). if c == commitTypeAnchors { - _, err = waitForNTxsInMempool(net.Miner.Node, 1, minerMempoolTimeout) + _, err = waitForNTxsInMempool( + net.Miner.Node, 1, minerMempoolTimeout, + ) if err != nil { - t.Fatalf("unable to find bob's anchor commit sweep: %v", err) - + t.Fatalf("unable to find bob's anchor commit sweep: %v", + err) } } // Mine enough blocks for Alice to sweep her funds from the force - // closed channel. closeCHannelAndAssertType() already mined a block + // closed channel. closeChannelAndAssertType() already mined a block // containing the commitment tx and the commit sweep tx will be // broadcast immediately before it can be included in a block, so mine // one less than defaultCSV in order to perform mempool assertions. diff --git a/lntest/itest/lnd_test.go b/lntest/itest/lnd_test.go index 5dc089cbc99..e49c7090c49 100644 --- a/lntest/itest/lnd_test.go +++ b/lntest/itest/lnd_test.go @@ -6,6 +6,7 @@ import ( "crypto/rand" "crypto/sha256" "encoding/hex" + "flag" "fmt" "io" "io/ioutil" @@ -13,7 +14,6 @@ import ( "os" "path/filepath" "reflect" - "runtime" "strings" "sync" "sync/atomic" @@ -53,6 +53,60 @@ import ( "github.com/stretchr/testify/require" ) +const ( + // defaultSplitTranches is the default number of tranches we split the + // test cases into. + defaultSplitTranches uint = 1 + + // defaultRunTranche is the default index of the test cases tranche that + // we run. + defaultRunTranche uint = 0 +) + +var ( + // testCasesSplitParts is the number of tranches the test cases should + // be split into. By default this is set to 1, so no splitting happens. + // If this value is increased, then the -runtranche flag must be + // specified as well to indicate which part should be run in the current + // invocation. + testCasesSplitTranches = flag.Uint( + "splittranches", defaultSplitTranches, "split the test cases "+ + "in this many tranches and run the tranche at "+ + "0-based index specified by the -runtranche flag", + ) + + // testCasesRunTranche is the 0-based index of the split test cases + // tranche to run in the current invocation. + testCasesRunTranche = flag.Uint( + "runtranche", defaultRunTranche, "run the tranche of the "+ + "split test cases with the given (0-based) index", + ) +) + +// getTestCaseSplitTranche returns the sub slice of the test cases that should +// be run as the current split tranche as well as the index and slice offset of +// the tranche. +func getTestCaseSplitTranche() ([]*testCase, uint, uint) { + numTranches := defaultSplitTranches + if testCasesSplitTranches != nil { + numTranches = *testCasesSplitTranches + } + runTranche := defaultRunTranche + if testCasesRunTranche != nil { + runTranche = *testCasesRunTranche + } + + numCases := uint(len(allTestCases)) + testsPerTranche := numCases / numTranches + trancheOffset := runTranche * testsPerTranche + trancheEnd := trancheOffset + testsPerTranche + if trancheEnd > numCases || runTranche == numTranches-1 { + trancheEnd = numCases + } + + return allTestCases[trancheOffset:trancheEnd], runTranche, trancheOffset +} + func rpcPointToWirePoint(t *harnessTest, chanPoint *lnrpc.ChannelPoint) wire.OutPoint { txid, err := lnd.GetChanPointFundingTxid(chanPoint) if err != nil { @@ -2380,7 +2434,7 @@ func testOpenChannelAfterReorg(net *lntest.NetworkHarness, t *harnessTest) { ) // Set up a new miner that we can use to cause a reorg. - tempLogDir := "./.tempminerlogs" + tempLogDir := fmt.Sprintf("%s/.tempminerlogs", lntest.GetLogDir()) logFilename := "output-open_channel_reorg-temp_miner.log" tempMiner, tempMinerCleanUp, err := lntest.NewMiner( tempLogDir, logFilename, @@ -14098,10 +14152,16 @@ func getPaymentResult(stream routerrpc.Router_SendPaymentV2Client) ( // programmatically driven network of lnd nodes. func TestLightningNetworkDaemon(t *testing.T) { // If no tests are registered, then we can exit early. - if len(testsCases) == 0 { + if len(allTestCases) == 0 { t.Skip("integration tests not selected with flag 'rpctest'") } + // Parse testing flags that influence our test execution. + logDir := lntest.GetLogDir() + require.NoError(t, os.MkdirAll(logDir, 0700)) + testCases, trancheIndex, trancheOffset := getTestCaseSplitTranche() + lntest.ApplyPortOffset(uint32(trancheIndex) * 1000) + ht := newHarnessTest(t, nil) // Declare the network harness here to gain access to its @@ -14117,7 +14177,7 @@ func TestLightningNetworkDaemon(t *testing.T) { // guarantees of getting included in to blocks. // // We will also connect it to our chain backend. - minerLogDir := "./.minerlogs" + minerLogDir := fmt.Sprintf("%s/.minerlogs", logDir) miner, minerCleanUp, err := lntest.NewMiner( minerLogDir, "output_btcd_miner.log", harnessNetParams, &rpcclient.NotificationHandlers{}, @@ -14149,27 +14209,12 @@ func TestLightningNetworkDaemon(t *testing.T) { // Connect chainbackend to miner. require.NoError( - t, chainBackend.ConnectMiner(), - "failed to connect to miner", + t, chainBackend.ConnectMiner(), "failed to connect to miner", ) - binary := itestLndBinary - if runtime.GOOS == "windows" { - // Windows (even in a bash like environment like git bash as on - // Travis) doesn't seem to like relative paths to exe files... - currentDir, err := os.Getwd() - if err != nil { - ht.Fatalf("unable to get working directory: %v", err) - } - targetPath := filepath.Join(currentDir, "../../lnd-itest.exe") - binary, err = filepath.Abs(targetPath) - if err != nil { - ht.Fatalf("unable to get absolute path: %v", err) - } - } - // Now we can set up our test harness (LND instance), with the chain // backend we just created. + binary := ht.getLndBinary() lndHarness, err = lntest.NewNetworkHarness(miner, chainBackend, binary) if err != nil { ht.Fatalf("unable to create lightning network harness: %v", err) @@ -14187,7 +14232,8 @@ func TestLightningNetworkDaemon(t *testing.T) { if !more { return } - ht.Logf("lnd finished with error (stderr):\n%v", err) + ht.Logf("lnd finished with error (stderr):\n%v", + err) } } }() @@ -14210,8 +14256,9 @@ func TestLightningNetworkDaemon(t *testing.T) { ht.Fatalf("unable to set up test lightning network: %v", err) } - t.Logf("Running %v integration tests", len(testsCases)) - for _, testCase := range testsCases { + // Run the subset of the test cases selected in this tranche. + for idx, testCase := range testCases { + testCase := testCase logLine := fmt.Sprintf("STARTING ============ %v ============\n", testCase.name) @@ -14232,7 +14279,10 @@ func TestLightningNetworkDaemon(t *testing.T) { // Start every test with the default static fee estimate. lndHarness.SetFeeEstimate(12500) - success := t.Run(testCase.name, func(t1 *testing.T) { + name := fmt.Sprintf("%02d-of-%d/%s/%s", + trancheOffset+uint(idx)+1, len(allTestCases), + chainBackend.Name(), testCase.name) + success := t.Run(name, func(t1 *testing.T) { ht := newHarnessTest(t1, lndHarness) ht.RunTestCase(testCase) }) @@ -14242,8 +14292,9 @@ func TestLightningNetworkDaemon(t *testing.T) { if !success { // Log failure time to help relate the lnd logs to the // failure. - t.Logf("Failure time: %v", - time.Now().Format("2006-01-02 15:04:05.000")) + t.Logf("Failure time: %v", time.Now().Format( + "2006-01-02 15:04:05.000", + )) break } } diff --git a/lntest/itest/lnd_test_list_off_test.go b/lntest/itest/lnd_test_list_off_test.go index ae18d5e0ca3..59795f1d1bb 100644 --- a/lntest/itest/lnd_test_list_off_test.go +++ b/lntest/itest/lnd_test_list_off_test.go @@ -2,4 +2,4 @@ package itest -var testsCases = []*testCase{} +var allTestCases = []*testCase{} diff --git a/lntest/itest/lnd_test_list_on_test.go b/lntest/itest/lnd_test_list_on_test.go index 98910d22b92..3575213cd43 100644 --- a/lntest/itest/lnd_test_list_on_test.go +++ b/lntest/itest/lnd_test_list_on_test.go @@ -2,7 +2,11 @@ package itest -var testsCases = []*testCase{ +var allTestCases = []*testCase{ + { + name: "test multi-hop htlc", + test: testMultiHopHtlcClaims, + }, { name: "sweep coins", test: testSweepAllCoins, @@ -144,10 +148,6 @@ var testsCases = []*testCase{ name: "async bidirectional payments", test: testBidirectionalAsyncPayments, }, - { - name: "test multi-hop htlc", - test: testMultiHopHtlcClaims, - }, { name: "switch circuit persistence", test: testSwitchCircuitPersistence, diff --git a/lntest/itest/test_harness.go b/lntest/itest/test_harness.go index a3c4752893a..45248f60256 100644 --- a/lntest/itest/test_harness.go +++ b/lntest/itest/test_harness.go @@ -3,8 +3,12 @@ package itest import ( "bytes" "context" + "flag" "fmt" "math" + "os" + "path/filepath" + "runtime" "testing" "time" @@ -20,6 +24,11 @@ import ( var ( harnessNetParams = &chaincfg.RegressionNetParams + + // lndExecutable is the full path to the lnd binary. + lndExecutable = flag.String( + "lndexec", itestLndBinary, "full path to lnd binary", + ) ) const ( @@ -111,6 +120,31 @@ func (h *harnessTest) Log(args ...interface{}) { h.t.Log(args...) } +func (h *harnessTest) getLndBinary() string { + binary := itestLndBinary + lndExec := "" + if lndExecutable != nil && *lndExecutable != "" { + lndExec = *lndExecutable + } + if lndExec == "" && runtime.GOOS == "windows" { + // Windows (even in a bash like environment like git bash as on + // Travis) doesn't seem to like relative paths to exe files... + currentDir, err := os.Getwd() + if err != nil { + h.Fatalf("unable to get working directory: %v", err) + } + targetPath := filepath.Join(currentDir, "../../lnd-itest.exe") + binary, err = filepath.Abs(targetPath) + if err != nil { + h.Fatalf("unable to get absolute path: %v", err) + } + } else if lndExec != "" { + binary = lndExec + } + + return binary +} + type testCase struct { name string test func(net *lntest.NetworkHarness, t *harnessTest) diff --git a/lntest/node.go b/lntest/node.go index cdf0be03bb6..cb368e75046 100644 --- a/lntest/node.go +++ b/lntest/node.go @@ -43,7 +43,7 @@ const ( // defaultNodePort is the start of the range for listening ports of // harness nodes. Ports are monotonically increasing starting from this // number and are determined by the results of nextAvailablePort(). - defaultNodePort = 19555 + defaultNodePort = 5555 // logPubKeyBytes is the number of bytes of the node's PubKey that will // be appended to the log file name. The whole PubKey is too long and @@ -70,6 +70,10 @@ var ( logOutput = flag.Bool("logoutput", false, "log output from node n to file output-n.log") + // logSubDir is the default directory where the logs are written to if + // logOutput is true. + logSubDir = flag.String("logdir", ".", "default dir to write logs to") + // goroutineDump is a flag that can be set to dump the active // goroutines of test nodes on failure. goroutineDump = flag.Bool("goroutinedump", false, @@ -104,6 +108,21 @@ func nextAvailablePort() int { panic("no ports available for listening") } +// ApplyPortOffset adds the given offset to the lastPort variable, making it +// possible to run the tests in parallel without colliding on the same ports. +func ApplyPortOffset(offset uint32) { + _ = atomic.AddUint32(&lastPort, offset) +} + +// GetLogDir returns the passed --logdir flag or the default value if it wasn't +// set. +func GetLogDir() string { + if logSubDir != nil && *logSubDir != "" { + return *logSubDir + } + return "." +} + // generateListeningPorts returns four ints representing ports to listen on // designated for the current lightning network test. This returns the next // available ports for the p2p, rpc, rest and profiling services. @@ -386,11 +405,9 @@ func NewMiner(logDir, logFilename string, netParams *chaincfg.Params, // After shutting down the miner, we'll make a copy of the log // file before deleting the temporary log dir. - logFile := fmt.Sprintf( - "%s/%s/btcd.log", logDir, netParams.Name, - ) - copyPath := fmt.Sprintf("./%s", logFilename) - err := CopyFile(copyPath, logFile) + logFile := fmt.Sprintf("%s/%s/btcd.log", logDir, netParams.Name) + copyPath := fmt.Sprintf("%s/../%s", logDir, logFilename) + err := CopyFile(filepath.Clean(copyPath), logFile) if err != nil { return fmt.Errorf("unable to copy file: %v", err) } @@ -475,24 +492,28 @@ func (hn *HarnessNode) start(lndBinary string, lndError chan<- error) error { // If the logoutput flag is passed, redirect output from the nodes to // log files. if *logOutput { - fileName := fmt.Sprintf("output-%d-%s-%s.log", hn.NodeID, + dir := GetLogDir() + fileName := fmt.Sprintf("%s/output-%d-%s-%s.log", dir, hn.NodeID, hn.Cfg.Name, hex.EncodeToString(hn.PubKey[:logPubKeyBytes])) // If the node's PubKey is not yet initialized, create a temporary // file name. Later, after the PubKey has been initialized, the // file can be moved to its final name with the PubKey included. if bytes.Equal(hn.PubKey[:4], []byte{0, 0, 0, 0}) { - fileName = fmt.Sprintf("output-%d-%s-tmp__.log", hn.NodeID, - hn.Cfg.Name) + fileName = fmt.Sprintf("%s/output-%d-%s-tmp__.log", + dir, hn.NodeID, hn.Cfg.Name) // Once the node has done its work, the log file can be renamed. finalizeLogfile = func() { if hn.logFile != nil { hn.logFile.Close() - newFileName := fmt.Sprintf("output-%d-%s-%s.log", - hn.NodeID, hn.Cfg.Name, - hex.EncodeToString(hn.PubKey[:logPubKeyBytes])) + pubKeyHex := hex.EncodeToString( + hn.PubKey[:logPubKeyBytes], + ) + newFileName := fmt.Sprintf("%s/output"+ + "-%d-%s-%s.log", dir, hn.NodeID, + hn.Cfg.Name, pubKeyHex) err := os.Rename(fileName, newFileName) if err != nil { fmt.Printf("could not rename "+ diff --git a/lnwallet/channel_test.go b/lnwallet/channel_test.go index 190e31c1255..038d4f16da6 100644 --- a/lnwallet/channel_test.go +++ b/lnwallet/channel_test.go @@ -3176,6 +3176,7 @@ func TestChanSyncOweCommitment(t *testing.T) { Amount: htlcAmt, Expiry: uint32(10), OnionBlob: fakeOnionBlob, + ExtraData: make([]byte, 0), } htlcIndex, err := bobChannel.AddHTLC(h, nil) @@ -3220,6 +3221,7 @@ func TestChanSyncOweCommitment(t *testing.T) { Amount: htlcAmt, Expiry: uint32(10), OnionBlob: fakeOnionBlob, + ExtraData: make([]byte, 0), } aliceHtlcIndex, err := aliceChannel.AddHTLC(aliceHtlc, nil) if err != nil { diff --git a/lnwire/accept_channel.go b/lnwire/accept_channel.go index da9daa69b32..5cbd8b2fe65 100644 --- a/lnwire/accept_channel.go +++ b/lnwire/accept_channel.go @@ -91,7 +91,12 @@ type AcceptChannel struct { // be paid when mutually closing the channel. This field is optional, and // and has a length prefix, so a zero will be written if it is not set // and its length followed by the script will be written if it is set. - UpfrontShutdownScript DeliveryAddress + UpfrontShutdownScript TypedDeliveryAddress + + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraData ExtraOpaqueData } // A compile time check to ensure AcceptChannel implements the lnwire.Message @@ -105,6 +110,7 @@ var _ Message = (*AcceptChannel)(nil) // This is part of the lnwire.Message interface. func (a *AcceptChannel) Encode(w io.Writer, pver uint32) error { return WriteElements(w, + pver, a.PendingChannelID[:], a.DustLimit, a.MaxValueInFlight, @@ -120,6 +126,7 @@ func (a *AcceptChannel) Encode(w io.Writer, pver uint32) error { a.HtlcPoint, a.FirstCommitmentPoint, a.UpfrontShutdownScript, + a.ExtraData, ) } @@ -130,7 +137,8 @@ func (a *AcceptChannel) Encode(w io.Writer, pver uint32) error { // This is part of the lnwire.Message interface. func (a *AcceptChannel) Decode(r io.Reader, pver uint32) error { // Read all the mandatory fields in the accept message. - err := ReadElements(r, + return ReadElements(r, + pver, a.PendingChannelID[:], &a.DustLimit, &a.MaxValueInFlight, @@ -145,18 +153,9 @@ func (a *AcceptChannel) Decode(r io.Reader, pver uint32) error { &a.DelayedPaymentPoint, &a.HtlcPoint, &a.FirstCommitmentPoint, + &a.UpfrontShutdownScript, + &a.ExtraData, ) - if err != nil { - return err - } - - // Check for the optional upfront shutdown script field. If it is not there, - // silence the EOF error. - err = ReadElement(r, &a.UpfrontShutdownScript) - if err != nil && err != io.EOF { - return err - } - return nil } // MsgType returns the MessageType code which uniquely identifies this message @@ -172,11 +171,5 @@ func (a *AcceptChannel) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (a *AcceptChannel) MaxPayloadLength(uint32) uint32 { - // 32 + (8 * 4) + (4 * 1) + (2 * 2) + (33 * 6) - var length uint32 = 270 // base length - - // Upfront shutdown script max length. - length += 2 + deliveryAddressMaxSize - - return length + return MaxMsgBody } diff --git a/lnwire/accept_channel_test.go b/lnwire/accept_channel_test.go index a1ab2be48c4..18ad1639ed2 100644 --- a/lnwire/accept_channel_test.go +++ b/lnwire/accept_channel_test.go @@ -12,7 +12,7 @@ import ( func TestDecodeAcceptChannel(t *testing.T) { tests := []struct { name string - shutdownScript DeliveryAddress + shutdownScript TypedDeliveryAddress }{ { name: "no upfront shutdown script", diff --git a/lnwire/announcement_signatures.go b/lnwire/announcement_signatures.go index 0124521354f..cb9fe990b73 100644 --- a/lnwire/announcement_signatures.go +++ b/lnwire/announcement_signatures.go @@ -2,7 +2,6 @@ package lnwire import ( "io" - "io/ioutil" ) // AnnounceSignatures this is a direct message between two endpoints of a @@ -40,7 +39,7 @@ type AnnounceSignatures struct { // properly validate the set of signatures that cover these new fields, // and ensure we're able to make upgrades to the network in a forwards // compatible manner. - ExtraOpaqueData []byte + ExtraOpaqueData ExtraOpaqueData } // A compile time check to ensure AnnounceSignatures implements the @@ -52,29 +51,14 @@ var _ Message = (*AnnounceSignatures)(nil) // // This is part of the lnwire.Message interface. func (a *AnnounceSignatures) Decode(r io.Reader, pver uint32) error { - err := ReadElements(r, + return ReadElements(r, + pver, &a.ChannelID, &a.ShortChannelID, &a.NodeSignature, &a.BitcoinSignature, + &a.ExtraOpaqueData, ) - if err != nil { - return err - } - - // Now that we've read out all the fields that we explicitly know of, - // we'll collect the remainder into the ExtraOpaqueData field. If there - // aren't any bytes, then we'll snip off the slice to avoid carrying - // around excess capacity. - a.ExtraOpaqueData, err = ioutil.ReadAll(r) - if err != nil { - return err - } - if len(a.ExtraOpaqueData) == 0 { - a.ExtraOpaqueData = nil - } - - return nil } // Encode serializes the target AnnounceSignatures into the passed io.Writer @@ -83,6 +67,7 @@ func (a *AnnounceSignatures) Decode(r io.Reader, pver uint32) error { // This is part of the lnwire.Message interface. func (a *AnnounceSignatures) Encode(w io.Writer, pver uint32) error { return WriteElements(w, + pver, a.ChannelID, a.ShortChannelID, a.NodeSignature, @@ -104,5 +89,5 @@ func (a *AnnounceSignatures) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (a *AnnounceSignatures) MaxPayloadLength(pver uint32) uint32 { - return 65533 + return MaxMsgBody } diff --git a/lnwire/channel_announcement.go b/lnwire/channel_announcement.go index 46efeed8807..92ccb520441 100644 --- a/lnwire/channel_announcement.go +++ b/lnwire/channel_announcement.go @@ -3,7 +3,6 @@ package lnwire import ( "bytes" "io" - "io/ioutil" "github.com/btcsuite/btcd/chaincfg/chainhash" ) @@ -56,7 +55,7 @@ type ChannelAnnouncement struct { // properly validate the set of signatures that cover these new fields, // and ensure we're able to make upgrades to the network in a forwards // compatible manner. - ExtraOpaqueData []byte + ExtraOpaqueData ExtraOpaqueData } // A compile time check to ensure ChannelAnnouncement implements the @@ -68,7 +67,8 @@ var _ Message = (*ChannelAnnouncement)(nil) // // This is part of the lnwire.Message interface. func (a *ChannelAnnouncement) Decode(r io.Reader, pver uint32) error { - err := ReadElements(r, + return ReadElements(r, + pver, &a.NodeSig1, &a.NodeSig2, &a.BitcoinSig1, @@ -80,24 +80,8 @@ func (a *ChannelAnnouncement) Decode(r io.Reader, pver uint32) error { &a.NodeID2, &a.BitcoinKey1, &a.BitcoinKey2, + &a.ExtraOpaqueData, ) - if err != nil { - return err - } - - // Now that we've read out all the fields that we explicitly know of, - // we'll collect the remainder into the ExtraOpaqueData field. If there - // aren't any bytes, then we'll snip off the slice to avoid carrying - // around excess capacity. - a.ExtraOpaqueData, err = ioutil.ReadAll(r) - if err != nil { - return err - } - if len(a.ExtraOpaqueData) == 0 { - a.ExtraOpaqueData = nil - } - - return nil } // Encode serializes the target ChannelAnnouncement into the passed io.Writer @@ -106,6 +90,7 @@ func (a *ChannelAnnouncement) Decode(r io.Reader, pver uint32) error { // This is part of the lnwire.Message interface. func (a *ChannelAnnouncement) Encode(w io.Writer, pver uint32) error { return WriteElements(w, + pver, a.NodeSig1, a.NodeSig2, a.BitcoinSig1, @@ -134,7 +119,7 @@ func (a *ChannelAnnouncement) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (a *ChannelAnnouncement) MaxPayloadLength(pver uint32) uint32 { - return 65533 + return MaxMsgBody } // DataToSign is used to retrieve part of the announcement message which should @@ -143,6 +128,10 @@ func (a *ChannelAnnouncement) DataToSign() ([]byte, error) { // We should not include the signatures itself. var w bytes.Buffer err := WriteElements(&w, + // We always use the modern protocol version here as we always + // need to include any optional data in the signature digest + // for forwards compatibility. + ProtocolVersionTLV, a.Features, a.ChainHash[:], a.ShortChannelID, diff --git a/lnwire/channel_reestablish.go b/lnwire/channel_reestablish.go index 42abcf95d7c..aec99d491ec 100644 --- a/lnwire/channel_reestablish.go +++ b/lnwire/channel_reestablish.go @@ -60,6 +60,11 @@ type ChannelReestablish struct { // LocalUnrevokedCommitPoint is the commitment point used in the // current un-revoked commitment transaction of the sending party. LocalUnrevokedCommitPoint *btcec.PublicKey + + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraData ExtraOpaqueData } // A compile time check to ensure ChannelReestablish implements the @@ -72,6 +77,7 @@ var _ Message = (*ChannelReestablish)(nil) // This is part of the lnwire.Message interface. func (a *ChannelReestablish) Encode(w io.Writer, pver uint32) error { err := WriteElements(w, + pver, a.ChanID, a.NextLocalCommitHeight, a.RemoteCommitTailHeight, @@ -83,12 +89,21 @@ func (a *ChannelReestablish) Encode(w io.Writer, pver uint32) error { // If the commit point wasn't sent, then we won't write out any of the // remaining fields as they're optional. if a.LocalUnrevokedCommitPoint == nil { - return nil + // However, we'll still write out the extra data if it's + // present. + // + // NOTE: This is here primarily for the quickcheck tests, in + // practice, we'll always populate this field. + return WriteElements(w, pver, a.ExtraData) } // Otherwise, we'll write out the remaining elements. - return WriteElements(w, a.LastRemoteCommitSecret[:], - a.LocalUnrevokedCommitPoint) + return WriteElements(w, + pver, + a.LastRemoteCommitSecret[:], + a.LocalUnrevokedCommitPoint, + a.ExtraData, + ) } // Decode deserializes a serialized ChannelReestablish stored in the passed @@ -97,6 +112,7 @@ func (a *ChannelReestablish) Encode(w io.Writer, pver uint32) error { // This is part of the lnwire.Message interface. func (a *ChannelReestablish) Decode(r io.Reader, pver uint32) error { err := ReadElements(r, + pver, &a.ChanID, &a.NextLocalCommitHeight, &a.RemoteCommitTailHeight, @@ -118,6 +134,9 @@ func (a *ChannelReestablish) Decode(r io.Reader, pver uint32) error { var buf [32]byte _, err = io.ReadFull(r, buf[:32]) if err == io.EOF { + // If there aren't any more bytes, then we'll emplace an empty + // extra data to make our quickcheck tests happy. + a.ExtraData = make([]byte, 0) return nil } else if err != nil { return err @@ -127,9 +146,13 @@ func (a *ChannelReestablish) Decode(r io.Reader, pver uint32) error { copy(a.LastRemoteCommitSecret[:], buf[:]) // We'll conclude by parsing out the commitment point. We don't check - // the error in this case, as it hey included the commit secret, then + // the error in this case, as it they included the commit secret, then // they MUST also include the commit point. - return ReadElement(r, &a.LocalUnrevokedCommitPoint) + if err = ReadElement(r, pver, &a.LocalUnrevokedCommitPoint); err != nil { + return err + } + + return a.ExtraData.Decode(r, pver) } // MsgType returns the integer uniquely identifying this message type on the @@ -145,22 +168,5 @@ func (a *ChannelReestablish) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (a *ChannelReestablish) MaxPayloadLength(pver uint32) uint32 { - var length uint32 - - // ChanID - 32 bytes - length += 32 - - // NextLocalCommitHeight - 8 bytes - length += 8 - - // RemoteCommitTailHeight - 8 bytes - length += 8 - - // LastRemoteCommitSecret - 32 bytes - length += 32 - - // LocalUnrevokedCommitPoint - 33 bytes - length += 33 - - return length + return MaxMsgBody } diff --git a/lnwire/channel_update.go b/lnwire/channel_update.go index fd627646b6d..8503959b0f2 100644 --- a/lnwire/channel_update.go +++ b/lnwire/channel_update.go @@ -4,7 +4,6 @@ import ( "bytes" "fmt" "io" - "io/ioutil" "github.com/btcsuite/btcd/chaincfg/chainhash" ) @@ -110,13 +109,10 @@ type ChannelUpdate struct { // HtlcMaximumMsat is the maximum HTLC value which will be accepted. HtlcMaximumMsat MilliSatoshi - // ExtraOpaqueData is the set of data that was appended to this - // message, some of which we may not actually know how to iterate or - // parse. By holding onto this data, we ensure that we're able to - // properly validate the set of signatures that cover these new fields, - // and ensure we're able to make upgrades to the network in a forwards - // compatible manner. - ExtraOpaqueData []byte + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraOpaqueData ExtraOpaqueData } // A compile time check to ensure ChannelUpdate implements the lnwire.Message @@ -129,6 +125,7 @@ var _ Message = (*ChannelUpdate)(nil) // This is part of the lnwire.Message interface. func (a *ChannelUpdate) Decode(r io.Reader, pver uint32) error { err := ReadElements(r, + pver, &a.Signature, a.ChainHash[:], &a.ShortChannelID, @@ -146,24 +143,12 @@ func (a *ChannelUpdate) Decode(r io.Reader, pver uint32) error { // Now check whether the max HTLC field is present and read it if so. if a.MessageFlags.HasMaxHtlc() { - if err := ReadElements(r, &a.HtlcMaximumMsat); err != nil { + if err := ReadElements(r, pver, &a.HtlcMaximumMsat); err != nil { return err } } - // Now that we've read out all the fields that we explicitly know of, - // we'll collect the remainder into the ExtraOpaqueData field. If there - // aren't any bytes, then we'll snip off the slice to avoid carrying - // around excess capacity. - a.ExtraOpaqueData, err = ioutil.ReadAll(r) - if err != nil { - return err - } - if len(a.ExtraOpaqueData) == 0 { - a.ExtraOpaqueData = nil - } - - return nil + return a.ExtraOpaqueData.Decode(r, pver) } // Encode serializes the target ChannelUpdate into the passed io.Writer @@ -172,6 +157,7 @@ func (a *ChannelUpdate) Decode(r io.Reader, pver uint32) error { // This is part of the lnwire.Message interface. func (a *ChannelUpdate) Encode(w io.Writer, pver uint32) error { err := WriteElements(w, + pver, a.Signature, a.ChainHash[:], a.ShortChannelID, @@ -190,13 +176,13 @@ func (a *ChannelUpdate) Encode(w io.Writer, pver uint32) error { // Now append optional fields if they are set. Currently, the only // optional field is max HTLC. if a.MessageFlags.HasMaxHtlc() { - if err := WriteElements(w, a.HtlcMaximumMsat); err != nil { + if err := WriteElements(w, pver, a.HtlcMaximumMsat); err != nil { return err } } // Finally, append any extra opaque data. - return WriteElements(w, a.ExtraOpaqueData) + return a.ExtraOpaqueData.Encode(w, pver) } // MsgType returns the integer uniquely identifying this message type on the @@ -212,16 +198,16 @@ func (a *ChannelUpdate) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (a *ChannelUpdate) MaxPayloadLength(pver uint32) uint32 { - return 65533 + return MaxMsgBody } // DataToSign is used to retrieve part of the announcement message which should // be signed. func (a *ChannelUpdate) DataToSign() ([]byte, error) { - // We should not include the signatures itself. var w bytes.Buffer err := WriteElements(&w, + ProtocolVersionTLV, a.ChainHash[:], a.ShortChannelID, a.Timestamp, @@ -239,13 +225,18 @@ func (a *ChannelUpdate) DataToSign() ([]byte, error) { // Now append optional fields if they are set. Currently, the only // optional field is max HTLC. if a.MessageFlags.HasMaxHtlc() { - if err := WriteElements(&w, a.HtlcMaximumMsat); err != nil { + err := WriteElements( + &w, ProtocolVersionTLV, a.HtlcMaximumMsat, + ) + if err != nil { return nil, err } } - // Finally, append any extra opaque data. - if err := WriteElements(&w, a.ExtraOpaqueData); err != nil { + // Finally, append any extra opaque data. We always pass in the modern + // protocol version here as we always need to include any extra bytes + // in the signature digest. + if err := a.ExtraOpaqueData.Encode(&w, ProtocolVersionTLV); err != nil { return nil, err } diff --git a/lnwire/closing_signed.go b/lnwire/closing_signed.go index 91b90646a02..c669b8b8520 100644 --- a/lnwire/closing_signed.go +++ b/lnwire/closing_signed.go @@ -27,6 +27,11 @@ type ClosingSigned struct { // Signature is for the proposed channel close transaction. Signature Sig + + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraData ExtraOpaqueData } // NewClosingSigned creates a new empty ClosingSigned message. @@ -49,7 +54,10 @@ var _ Message = (*ClosingSigned)(nil) // // This is part of the lnwire.Message interface. func (c *ClosingSigned) Decode(r io.Reader, pver uint32) error { - return ReadElements(r, &c.ChannelID, &c.FeeSatoshis, &c.Signature) + return ReadElements( + r, pver, &c.ChannelID, &c.FeeSatoshis, &c.Signature, + &c.ExtraData, + ) } // Encode serializes the target ClosingSigned into the passed io.Writer @@ -57,7 +65,9 @@ func (c *ClosingSigned) Decode(r io.Reader, pver uint32) error { // // This is part of the lnwire.Message interface. func (c *ClosingSigned) Encode(w io.Writer, pver uint32) error { - return WriteElements(w, c.ChannelID, c.FeeSatoshis, c.Signature) + return WriteElements( + w, pver, c.ChannelID, c.FeeSatoshis, c.Signature, c.ExtraData, + ) } // MsgType returns the integer uniquely identifying this message type on the @@ -73,16 +83,5 @@ func (c *ClosingSigned) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (c *ClosingSigned) MaxPayloadLength(uint32) uint32 { - var length uint32 - - // ChannelID - 32 bytes - length += 32 - - // FeeSatoshis - 8 bytes - length += 8 - - // Signature - 64 bytes - length += 64 - - return length + return MaxMsgBody } diff --git a/lnwire/commit_sig.go b/lnwire/commit_sig.go index 2455c016570..6b469af763e 100644 --- a/lnwire/commit_sig.go +++ b/lnwire/commit_sig.go @@ -34,11 +34,18 @@ type CommitSig struct { // should be signed, for each incoming HTLC the HTLC timeout // transaction should be signed. HtlcSigs []Sig + + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraData ExtraOpaqueData } // NewCommitSig creates a new empty CommitSig message. func NewCommitSig() *CommitSig { - return &CommitSig{} + return &CommitSig{ + ExtraData: make([]byte, 0), + } } // A compile time check to ensure CommitSig implements the lnwire.Message @@ -51,9 +58,11 @@ var _ Message = (*CommitSig)(nil) // This is part of the lnwire.Message interface. func (c *CommitSig) Decode(r io.Reader, pver uint32) error { return ReadElements(r, + pver, &c.ChanID, &c.CommitSig, &c.HtlcSigs, + &c.ExtraData, ) } @@ -63,9 +72,11 @@ func (c *CommitSig) Decode(r io.Reader, pver uint32) error { // This is part of the lnwire.Message interface. func (c *CommitSig) Encode(w io.Writer, pver uint32) error { return WriteElements(w, + pver, c.ChanID, c.CommitSig, c.HtlcSigs, + c.ExtraData, ) } @@ -82,8 +93,7 @@ func (c *CommitSig) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (c *CommitSig) MaxPayloadLength(uint32) uint32 { - // 32 + 64 + 2 + max_allowed_htlcs - return MaxMessagePayload + return MaxMsgBody } // TargetChanID returns the channel id of the link for which this message is diff --git a/lnwire/error.go b/lnwire/error.go index c9fa39a8a45..5ee9881f603 100644 --- a/lnwire/error.go +++ b/lnwire/error.go @@ -94,6 +94,7 @@ func (c *Error) Error() string { // This is part of the lnwire.Message interface. func (c *Error) Decode(r io.Reader, pver uint32) error { return ReadElements(r, + pver, &c.ChanID, &c.Data, ) @@ -105,6 +106,7 @@ func (c *Error) Decode(r io.Reader, pver uint32) error { // This is part of the lnwire.Message interface. func (c *Error) Encode(w io.Writer, pver uint32) error { return WriteElements(w, + pver, c.ChanID, c.Data, ) @@ -123,8 +125,7 @@ func (c *Error) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (c *Error) MaxPayloadLength(uint32) uint32 { - // 32 + 2 + 65501 - return 65535 + return MaxMsgBody } // isASCII is a helper method that checks whether all bytes in `data` would be diff --git a/lnwire/extra_bytes.go b/lnwire/extra_bytes.go new file mode 100644 index 00000000000..8d3b4267cbe --- /dev/null +++ b/lnwire/extra_bytes.go @@ -0,0 +1,96 @@ +package lnwire + +import ( + "bytes" + "io" + "io/ioutil" + + "github.com/lightningnetwork/lnd/tlv" +) + +// ExtraOpaqueData is the set of data that was appended to this message, some +// of which we may not actually know how to iterate or parse. By holding onto +// this data, we ensure that we're able to properly validate the set of +// signatures that cover these new fields, and ensure we're able to make +// upgrades to the network in a forwards compatible manner. +type ExtraOpaqueData []byte + +// Encode attempts to encode the raw extra bytes into the passed io.Writer. +func (e *ExtraOpaqueData) Encode(w io.Writer, pver uint32) error { + // Only write out the extra data if we're using the new modern protocol + // version. + if pver != ProtocolVersionTLV { + return nil + } + + eBytes := []byte((*e)[:]) + if err := WriteElements(w, pver, eBytes); err != nil { + return err + } + + return nil +} + +// Decode attempts to unpack the raw bytes encoded in the passed io.Reader as a +// set of extra opaque data. +func (e *ExtraOpaqueData) Decode(r io.Reader, pver uint32) error { + // Only if we're using the modern protocl version will we attempt to + // keep on decoding past the end of the "main message". + if pver != ProtocolVersionTLV { + return nil + } + + // First, we'll attempt to read a set of bytes contained within the + // passed io.Reader (if any exist). + rawBytes, err := ioutil.ReadAll(r) + if err != nil { + return err + } + + // If we _do_ have some bytes, then we'll swap out our backing pointer. + // This ensures that any struct that embeds this type will properly + // store the bytes once this method exits. + if len(rawBytes) > 0 { + *e = ExtraOpaqueData(rawBytes) + } else { + *e = make([]byte, 0) + } + + return nil +} + +// PackRecords attempts to encode the set of tlv records into the target +// ExtraOpaqueData instance. The records will be encoded as a raw TLV stream +// and stored within the backing slice pointer. +func (e *ExtraOpaqueData) PackRecords(records []tlv.Record) error { + tlvStream, err := tlv.NewStream(records...) + if err != nil { + return err + } + + var extraBytesWriter bytes.Buffer + if err := tlvStream.Encode(&extraBytesWriter); err != nil { + return err + } + + *e = ExtraOpaqueData(extraBytesWriter.Bytes()) + + return nil +} + +// ExtractRecords attempts to decode any types in the internal raw bytes as if +// it were a tlv stream. The set of raw parsed types is returned, and any +// passed records (if found in the stream) will be parsed into the proper +// tlv.Record. +func (e *ExtraOpaqueData) ExtractRecords(records ...tlv.Record) ( + tlv.TypeMap, error) { + + extraBytesReader := bytes.NewReader(*e) + + tlvStream, err := tlv.NewStream(records...) + if err != nil { + return nil, err + } + + return tlvStream.DecodeWithParsedTypes(extraBytesReader) +} diff --git a/lnwire/extra_bytes_test.go b/lnwire/extra_bytes_test.go new file mode 100644 index 00000000000..3d1573d94c6 --- /dev/null +++ b/lnwire/extra_bytes_test.go @@ -0,0 +1,200 @@ +package lnwire + +import ( + "bytes" + "math/rand" + "reflect" + "testing" + "testing/quick" + + "github.com/lightningnetwork/lnd/tlv" +) + +// TestExtraOpaqueDataEncodeDecode tests that we're able to encode/decode +// arbitrary payloads. +func TestExtraOpaqueDataEncodeDecode(t *testing.T) { + t.Parallel() + + type testCase struct { + // emptyBytes indicates if we should try to encode empty bytes + // or not. + emptyBytes bool + + // inputBytes if emptyBytes is false, then we'll read in this + // set of bytes instead. + inputBytes []byte + } + + // We should be able to read in an arbitrary set of bytes as an + // ExtraOpaqueData, then encode those new bytes into a new instance. + // The final two instances should be identical. + scenario := func(test testCase) bool { + var ( + extraData ExtraOpaqueData + b bytes.Buffer + ) + + copy(extraData[:], test.inputBytes) + + if err := extraData.Encode(&b, ProtocolVersionTLV); err != nil { + t.Fatalf("unable to encode extra data: %v", err) + return false + } + + var newBytes ExtraOpaqueData + if err := newBytes.Decode(&b, ProtocolVersionTLV); err != nil { + t.Fatalf("unable to decode extra bytes: %v", err) + return false + } + + if !bytes.Equal(extraData[:], newBytes[:]) { + t.Fatalf("expected %x, got %x", extraData, + newBytes) + return false + } + + return true + } + + // We'll make a function to generate random test data. Half of the + // time, we'll actually feed in blank bytes. + quickCfg := &quick.Config{ + Values: func(v []reflect.Value, r *rand.Rand) { + + var newTestCase testCase + if r.Int31()%2 == 0 { + newTestCase.emptyBytes = true + } + + if !newTestCase.emptyBytes { + numBytes := r.Int31n(1000) + newTestCase.inputBytes = make([]byte, numBytes) + + _, err := r.Read(newTestCase.inputBytes) + if err != nil { + t.Fatalf("unable to gen random bytes: %v", err) + return + } + } + + v[0] = reflect.ValueOf(newTestCase) + }, + } + + if err := quick.Check(scenario, quickCfg); err != nil { + t.Fatalf("encode+decode test failed: %v", err) + } +} + +// TestExtraOpaqueDataPackUnpackRecords tests that we're able to pack a set of +// tlv.Records into a stream, and unpack them on the other side to obtain the +// same set of records. +func TestExtraOpaqueDataPackUnpackRecords(t *testing.T) { + t.Parallel() + + var ( + type1 tlv.Type = 1 + type2 tlv.Type = 2 + + channelType1 uint8 = 2 + channelType2 uint8 + + hop1 uint32 = 99 + hop2 uint32 + ) + testRecords := []tlv.Record{ + tlv.MakePrimitiveRecord(type1, &channelType1), + tlv.MakePrimitiveRecord(type2, &hop1), + } + + // Now that we have our set of sample records and types, we'll encode + // them into the passed ExtraOpaqueData instance. + var extraBytes ExtraOpaqueData + if err := extraBytes.PackRecords(testRecords); err != nil { + t.Fatalf("unable to pack records: %v", err) + } + + // We'll now simulate decoding these types _back_ into records on the + // other side. + newRecords := []tlv.Record{ + tlv.MakePrimitiveRecord(type1, &channelType2), + tlv.MakePrimitiveRecord(type2, &hop2), + } + typeMap, err := extraBytes.ExtractRecords(newRecords...) + if err != nil { + t.Fatalf("unable to extract record: %v", err) + } + + // We should find that the new backing values have been populated with + // the proper value. + switch { + case channelType1 != channelType2: + t.Fatalf("wrong record for channel type: expected %v, got %v", + channelType1, channelType2) + + case hop1 != hop2: + t.Fatalf("wrong record for hop: expected %v, got %v", hop1, + hop2) + } + + // Both types we created above should be found in the type map. + if _, ok := typeMap[type1]; !ok { + t.Fatalf("type1 not found in typeMap") + } + if _, ok := typeMap[type2]; !ok { + t.Fatalf("type2 not found in typeMap") + } +} + +// TestExtraOpaqueDataProtocolVersion tests that the encode/decode methods will +// observe the passed protocol version. +func TestExtraOpaqueDataProtocolVersion(t *testing.T) { + t.Parallel() + + extraData := ExtraOpaqueData([]byte("kek")) + + var b bytes.Buffer + if err := extraData.Encode(&b, ProtocolVersionLegacy); err != nil { + t.Fatalf("unable to encode: %v", err) + } + + // The statement above shouldn't have included the extra data since + // we're using the legacy protocol version. + if len(b.Bytes()) != 0 { + t.Fatalf("bytes were encoded using legacy "+ + "protocol version: %x", b.Bytes()) + } + + // If we encode using the proper version, then we should find the same + // data encoded on the other side. + if err := extraData.Encode(&b, ProtocolVersionTLV); err != nil { + t.Fatalf("unable to encode: %v", err) + } + if !bytes.Equal(b.Bytes(), extraData[:]) { + t.Fatalf("encoding mismatch: expected %x, got %x", + b.Bytes(), extraData[:]) + } + + // Now for the other direction, we'll attempt to decode into a fresh + // buffer, but using the legacy version. In the end, no bytes should be + // decoded. + var newExtraData ExtraOpaqueData + if err := newExtraData.Decode(&b, ProtocolVersionLegacy); err != nil { + t.Fatalf("unable to decode data: %v", err) + } + + if len(newExtraData[:]) != 0 { + t.Fatalf("expected not data to be decoded!") + } + + // Finally, if we decode using the proper protocol version, we should get + // the same bytes out that we put in. + if err := newExtraData.Decode(&b, ProtocolVersionTLV); err != nil { + t.Fatalf("unable to decode data: %v", err) + } + if !bytes.Equal(extraData[:], newExtraData[:]) { + t.Fatalf("encoding mismatch: expected %x, got %x", + extraData, newExtraData) + } + +} diff --git a/lnwire/funding_created.go b/lnwire/funding_created.go index c14321ec8f9..7eb7b2b2abd 100644 --- a/lnwire/funding_created.go +++ b/lnwire/funding_created.go @@ -24,6 +24,11 @@ type FundingCreated struct { // CommitSig is Alice's signature from Bob's version of the commitment // transaction. CommitSig Sig + + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraData ExtraOpaqueData } // A compile time check to ensure FundingCreated implements the lnwire.Message @@ -36,7 +41,10 @@ var _ Message = (*FundingCreated)(nil) // // This is part of the lnwire.Message interface. func (f *FundingCreated) Encode(w io.Writer, pver uint32) error { - return WriteElements(w, f.PendingChannelID[:], f.FundingPoint, f.CommitSig) + return WriteElements( + w, pver, f.PendingChannelID[:], f.FundingPoint, f.CommitSig, + f.ExtraData, + ) } // Decode deserializes the serialized FundingCreated stored in the passed @@ -45,7 +53,10 @@ func (f *FundingCreated) Encode(w io.Writer, pver uint32) error { // // This is part of the lnwire.Message interface. func (f *FundingCreated) Decode(r io.Reader, pver uint32) error { - return ReadElements(r, f.PendingChannelID[:], &f.FundingPoint, &f.CommitSig) + return ReadElements( + r, pver, f.PendingChannelID[:], &f.FundingPoint, &f.CommitSig, + &f.ExtraData, + ) } // MsgType returns the uint32 code which uniquely identifies this message as a @@ -61,6 +72,5 @@ func (f *FundingCreated) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (f *FundingCreated) MaxPayloadLength(uint32) uint32 { - // 32 + 32 + 2 + 64 - return 130 + return MaxMsgBody } diff --git a/lnwire/funding_locked.go b/lnwire/funding_locked.go index c441b0be621..af857207ba2 100644 --- a/lnwire/funding_locked.go +++ b/lnwire/funding_locked.go @@ -19,6 +19,11 @@ type FundingLocked struct { // NextPerCommitmentPoint is the secret that can be used to revoke the // next commitment transaction for the channel. NextPerCommitmentPoint *btcec.PublicKey + + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraData ExtraOpaqueData } // NewFundingLocked creates a new FundingLocked message, populating it with the @@ -27,6 +32,7 @@ func NewFundingLocked(cid ChannelID, npcp *btcec.PublicKey) *FundingLocked { return &FundingLocked{ ChanID: cid, NextPerCommitmentPoint: npcp, + ExtraData: make([]byte, 0), } } @@ -41,8 +47,11 @@ var _ Message = (*FundingLocked)(nil) // This is part of the lnwire.Message interface. func (c *FundingLocked) Decode(r io.Reader, pver uint32) error { return ReadElements(r, + pver, &c.ChanID, - &c.NextPerCommitmentPoint) + &c.NextPerCommitmentPoint, + &c.ExtraData, + ) } // Encode serializes the target FundingLocked message into the passed io.Writer @@ -52,8 +61,11 @@ func (c *FundingLocked) Decode(r io.Reader, pver uint32) error { // This is part of the lnwire.Message interface. func (c *FundingLocked) Encode(w io.Writer, pver uint32) error { return WriteElements(w, + pver, c.ChanID, - c.NextPerCommitmentPoint) + c.NextPerCommitmentPoint, + c.ExtraData, + ) } // MsgType returns the uint32 code which uniquely identifies this message as a @@ -70,14 +82,5 @@ func (c *FundingLocked) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (c *FundingLocked) MaxPayloadLength(uint32) uint32 { - var length uint32 - - // ChanID - 32 bytes - length += 32 - - // NextPerCommitmentPoint - 33 bytes - length += 33 - - // 65 bytes - return length + return MaxMsgBody } diff --git a/lnwire/funding_signed.go b/lnwire/funding_signed.go index 620f8b37317..844d53a5f61 100644 --- a/lnwire/funding_signed.go +++ b/lnwire/funding_signed.go @@ -13,6 +13,11 @@ type FundingSigned struct { // CommitSig is Bob's signature for Alice's version of the commitment // transaction. CommitSig Sig + + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraData ExtraOpaqueData } // A compile time check to ensure FundingSigned implements the lnwire.Message @@ -25,7 +30,7 @@ var _ Message = (*FundingSigned)(nil) // // This is part of the lnwire.Message interface. func (f *FundingSigned) Encode(w io.Writer, pver uint32) error { - return WriteElements(w, f.ChanID, f.CommitSig) + return WriteElements(w, pver, f.ChanID, f.CommitSig, f.ExtraData) } // Decode deserializes the serialized FundingSigned stored in the passed @@ -34,7 +39,7 @@ func (f *FundingSigned) Encode(w io.Writer, pver uint32) error { // // This is part of the lnwire.Message interface. func (f *FundingSigned) Decode(r io.Reader, pver uint32) error { - return ReadElements(r, &f.ChanID, &f.CommitSig) + return ReadElements(r, pver, &f.ChanID, &f.CommitSig, &f.ExtraData) } // MsgType returns the uint32 code which uniquely identifies this message as a @@ -50,6 +55,5 @@ func (f *FundingSigned) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (f *FundingSigned) MaxPayloadLength(uint32) uint32 { - // 32 + 64 - return 96 + return MaxMsgBody } diff --git a/lnwire/gossip_timestamp_range.go b/lnwire/gossip_timestamp_range.go index 3c28cd056c2..4d4c834c503 100644 --- a/lnwire/gossip_timestamp_range.go +++ b/lnwire/gossip_timestamp_range.go @@ -24,6 +24,11 @@ type GossipTimestampRange struct { // NOT send any announcements that have a timestamp greater than // FirstTimestamp + TimestampRange. TimestampRange uint32 + + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraData ExtraOpaqueData } // NewGossipTimestampRange creates a new empty GossipTimestampRange message. @@ -41,9 +46,11 @@ var _ Message = (*GossipTimestampRange)(nil) // This is part of the lnwire.Message interface. func (g *GossipTimestampRange) Decode(r io.Reader, pver uint32) error { return ReadElements(r, + pver, g.ChainHash[:], &g.FirstTimestamp, &g.TimestampRange, + &g.ExtraData, ) } @@ -53,9 +60,11 @@ func (g *GossipTimestampRange) Decode(r io.Reader, pver uint32) error { // This is part of the lnwire.Message interface. func (g *GossipTimestampRange) Encode(w io.Writer, pver uint32) error { return WriteElements(w, + pver, g.ChainHash[:], g.FirstTimestamp, g.TimestampRange, + g.ExtraData, ) } @@ -73,8 +82,5 @@ func (g *GossipTimestampRange) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (g *GossipTimestampRange) MaxPayloadLength(uint32) uint32 { - // 32 + 4 + 4 - // - // TODO(roasbeef): update to 8 byte timestmaps? - return 40 + return MaxMsgBody } diff --git a/lnwire/init_message.go b/lnwire/init_message.go index 0236a71f84c..402e6bfabbc 100644 --- a/lnwire/init_message.go +++ b/lnwire/init_message.go @@ -20,6 +20,11 @@ type Init struct { // message, any GlobalFeatures should be merged into the unified // Features field. Features *RawFeatureVector + + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraData ExtraOpaqueData } // NewInitMessage creates new instance of init message object. @@ -27,6 +32,7 @@ func NewInitMessage(gf *RawFeatureVector, f *RawFeatureVector) *Init { return &Init{ GlobalFeatures: gf, Features: f, + ExtraData: make([]byte, 0), } } @@ -40,8 +46,10 @@ var _ Message = (*Init)(nil) // This is part of the lnwire.Message interface. func (msg *Init) Decode(r io.Reader, pver uint32) error { return ReadElements(r, + pver, &msg.GlobalFeatures, &msg.Features, + &msg.ExtraData, ) } @@ -51,8 +59,10 @@ func (msg *Init) Decode(r io.Reader, pver uint32) error { // This is part of the lnwire.Message interface. func (msg *Init) Encode(w io.Writer, pver uint32) error { return WriteElements(w, + pver, msg.GlobalFeatures, msg.Features, + msg.ExtraData, ) } @@ -69,5 +79,5 @@ func (msg *Init) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (msg *Init) MaxPayloadLength(uint32) uint32 { - return 2 + 2 + maxAllowedSize + 2 + maxAllowedSize + return MaxMsgBody } diff --git a/lnwire/lnwire.go b/lnwire/lnwire.go index ca0e449e5a5..fdbea7a3cbe 100644 --- a/lnwire/lnwire.go +++ b/lnwire/lnwire.go @@ -18,9 +18,16 @@ import ( "github.com/lightningnetwork/lnd/tor" ) -// MaxSliceLength is the maximum allowed length for any opaque byte slices in -// the wire protocol. -const MaxSliceLength = 65535 +const ( + // MaxSliceLength is the maximum allowed length for any opaque byte + // slices in the wire protocol. + MaxSliceLength = 65535 + + // MaxMsgBody is the largest payload any message is allowed to provide. + // This is two less than the MaxSliceLength as each message has a 2 + // byte type that precedes the message body. + MaxMsgBody = 65533 +) // PkScript is simple type definition which represents a raw serialized public // key script. @@ -70,11 +77,12 @@ func (a addressType) AddrLen() uint16 { // WriteElement is a one-stop shop to write the big endian representation of // any element which is to be serialized for the wire protocol. The passed // io.Writer should be backed by an appropriately sized byte slice, or be able -// to dynamically expand to accommodate additional data. +// to dynamically expand to accommodate additional data. The passed protocol +// version may affect how items are encoded. // // TODO(roasbeef): this should eventually draw from a buffer pool for // serialization. -func WriteElement(w io.Writer, element interface{}) error { +func WriteElement(w io.Writer, pver uint32, element interface{}) error { switch e := element.(type) { case NodeAlias: if _, err := w.Write(e[:]); err != nil { @@ -161,7 +169,7 @@ func WriteElement(w io.Writer, element interface{}) error { } for _, sig := range e { - if err := WriteElement(w, sig); err != nil { + if err := WriteElement(w, pver, sig); err != nil { return err } } @@ -262,7 +270,7 @@ func WriteElement(w io.Writer, element interface{}) error { return err } case FailCode: - if err := WriteElement(w, uint16(e)); err != nil { + if err := WriteElement(w, pver, uint16(e)); err != nil { return err } case ShortChannelID: @@ -376,7 +384,8 @@ func WriteElement(w io.Writer, element interface{}) error { // length of the addresses. var addrBuf bytes.Buffer for _, address := range e { - if err := WriteElement(&addrBuf, address); err != nil { + err := WriteElement(&addrBuf, pver, address) + if err != nil { return err } } @@ -384,7 +393,7 @@ func WriteElement(w io.Writer, element interface{}) error { // With the addresses fully encoded, we can now write out the // number of bytes needed to encode them. addrLen := addrBuf.Len() - if err := WriteElement(w, uint16(addrLen)); err != nil { + if err := WriteElement(w, pver, uint16(addrLen)); err != nil { return err } @@ -396,7 +405,7 @@ func WriteElement(w io.Writer, element interface{}) error { } } case color.RGBA: - if err := WriteElements(w, e.R, e.G, e.B); err != nil { + if err := WriteElements(w, pver, e.R, e.G, e.B); err != nil { return err } @@ -418,6 +427,13 @@ func WriteElement(w io.Writer, element interface{}) error { if _, err := w.Write(b[:]); err != nil { return err } + + case ExtraOpaqueData: + return e.Encode(w, pver) + + case TypedDeliveryAddress: + return e.Encode(w) + default: return fmt.Errorf("unknown type in WriteElement: %T", e) } @@ -427,9 +443,9 @@ func WriteElement(w io.Writer, element interface{}) error { // WriteElements is writes each element in the elements slice to the passed // io.Writer using WriteElement. -func WriteElements(w io.Writer, elements ...interface{}) error { +func WriteElements(w io.Writer, pver uint32, elements ...interface{}) error { for _, element := range elements { - err := WriteElement(w, element) + err := WriteElement(w, pver, element) if err != nil { return err } @@ -438,8 +454,9 @@ func WriteElements(w io.Writer, elements ...interface{}) error { } // ReadElement is a one-stop utility function to deserialize any datastructure -// encoded using the serialization format of lnwire. -func ReadElement(r io.Reader, element interface{}) error { +// encoded using the serialization format of lnwire. The passed protocol +// version may affect how items are decoded. +func ReadElement(r io.Reader, pver uint32, element interface{}) error { var err error switch e := element.(type) { case *bool: @@ -555,7 +572,8 @@ func ReadElement(r io.Reader, element interface{}) error { if numSigs > 0 { sigs = make([]Sig, numSigs) for i := 0; i < int(numSigs); i++ { - if err := ReadElement(r, &sigs[i]); err != nil { + err := ReadElement(r, pver, &sigs[i]) + if err != nil { return err } } @@ -647,7 +665,7 @@ func ReadElement(r io.Reader, element interface{}) error { Index: uint32(index), } case *FailCode: - if err := ReadElement(r, (*uint16)(e)); err != nil { + if err := ReadElement(r, pver, (*uint16)(e)); err != nil { return err } case *ChannelID: @@ -802,6 +820,7 @@ func ReadElement(r io.Reader, element interface{}) error { *e = addresses case *color.RGBA: err := ReadElements(r, + pver, &e.R, &e.G, &e.B, @@ -824,6 +843,13 @@ func ReadElement(r io.Reader, element interface{}) error { return err } *e = addrBytes[:length] + + case *ExtraOpaqueData: + return e.Decode(r, pver) + + case *TypedDeliveryAddress: + return e.Decode(r) + default: return fmt.Errorf("unknown type in ReadElement: %T", e) } @@ -833,10 +859,10 @@ func ReadElement(r io.Reader, element interface{}) error { // ReadElements deserializes a variable number of elements into the passed // io.Reader, with each element being deserialized according to the ReadElement -// function. -func ReadElements(r io.Reader, elements ...interface{}) error { +// function. The passed protocol version may affect how the items are encoded. +func ReadElements(r io.Reader, pver uint32, elements ...interface{}) error { for _, element := range elements { - err := ReadElement(r, element) + err := ReadElement(r, pver, element) if err != nil { return err } diff --git a/lnwire/lnwire_test.go b/lnwire/lnwire_test.go index 02023b02317..d6cb76e7d7a 100644 --- a/lnwire/lnwire_test.go +++ b/lnwire/lnwire_test.go @@ -67,10 +67,10 @@ func randRawKey() ([33]byte, error) { return n, nil } -func randDeliveryAddress(r *rand.Rand) (DeliveryAddress, error) { +func randDeliveryAddress(r *rand.Rand) (TypedDeliveryAddress, error) { // Generate size minimum one. Empty scripts should be tested specifically. size := r.Intn(deliveryAddressMaxSize) + 1 - da := DeliveryAddress(make([]byte, size)) + da := TypedDeliveryAddress(make([]byte, size)) _, err := r.Read(da) return da, err @@ -232,7 +232,7 @@ func TestMaxOutPointIndex(t *testing.T) { } var b bytes.Buffer - if err := WriteElement(&b, op); err == nil { + if err := WriteElement(&b, ProtocolVersionTLV, op); err == nil { t.Fatalf("write of outPoint should fail, index exceeds 16-bits") } } @@ -265,7 +265,8 @@ func TestLightningWireProtocol(t *testing.T) { // Give a new message, we'll serialize the message into a new // bytes buffer. var b bytes.Buffer - if _, err := WriteMessage(&b, msg, 0); err != nil { + _, err := WriteMessage(&b, msg, ProtocolVersionTLV) + if err != nil { t.Fatalf("unable to write msg: %v", err) return false } @@ -282,7 +283,7 @@ func TestLightningWireProtocol(t *testing.T) { // Finally, we'll deserialize the message from the written // buffer, and finally assert that the messages are equal. - newMsg, err := ReadMessage(&b, 0) + newMsg, err := ReadMessage(&b, ProtocolVersionTLV) if err != nil { t.Fatalf("unable to read msg: %v", err) return false @@ -321,6 +322,7 @@ func TestLightningWireProtocol(t *testing.T) { CsvDelay: uint16(r.Int31()), MaxAcceptedHTLCs: uint16(r.Int31()), ChannelFlags: FundingFlag(uint8(r.Int31())), + ExtraData: make([]byte, 0), } if _, err := r.Read(req.ChainHash[:]); err != nil { @@ -387,6 +389,7 @@ func TestLightningWireProtocol(t *testing.T) { HtlcMinimum: MilliSatoshi(r.Int31()), CsvDelay: uint16(r.Int31()), MaxAcceptedHTLCs: uint16(r.Int31()), + ExtraData: make([]byte, 0), } if _, err := r.Read(req.PendingChannelID[:]); err != nil { @@ -440,7 +443,9 @@ func TestLightningWireProtocol(t *testing.T) { v[0] = reflect.ValueOf(req) }, MsgFundingCreated: func(v []reflect.Value, r *rand.Rand) { - req := FundingCreated{} + req := FundingCreated{ + ExtraData: make([]byte, 0), + } if _, err := r.Read(req.PendingChannelID[:]); err != nil { t.Fatalf("unable to generate pending chan id: %v", err) @@ -471,7 +476,8 @@ func TestLightningWireProtocol(t *testing.T) { } req := FundingSigned{ - ChanID: ChannelID(c), + ChanID: ChannelID(c), + ExtraData: make([]byte, 0), } req.CommitSig, err = NewSigFromSignature(testSig) if err != nil { @@ -502,6 +508,7 @@ func TestLightningWireProtocol(t *testing.T) { MsgClosingSigned: func(v []reflect.Value, r *rand.Rand) { req := ClosingSigned{ FeeSatoshis: btcutil.Amount(r.Int63()), + ExtraData: make([]byte, 0), } var err error req.Signature, err = NewSigFromSignature(testSig) @@ -570,8 +577,9 @@ func TestLightningWireProtocol(t *testing.T) { MsgChannelAnnouncement: func(v []reflect.Value, r *rand.Rand) { var err error req := ChannelAnnouncement{ - ShortChannelID: NewShortChanIDFromInt(uint64(r.Int63())), - Features: randRawFeatureVector(r), + ShortChannelID: NewShortChanIDFromInt(uint64(r.Int63())), + Features: randRawFeatureVector(r), + ExtraOpaqueData: make([]byte, 0), } req.NodeSig1, err = NewSigFromSignature(testSig) if err != nil { @@ -643,6 +651,7 @@ func TestLightningWireProtocol(t *testing.T) { G: uint8(r.Int31()), B: uint8(r.Int31()), }, + ExtraOpaqueData: make([]byte, 0), } req.Signature, err = NewSigFromSignature(testSig) if err != nil { @@ -698,6 +707,7 @@ func TestLightningWireProtocol(t *testing.T) { HtlcMaximumMsat: maxHtlc, BaseFee: uint32(r.Int31()), FeeRate: uint32(r.Int31()), + ExtraOpaqueData: make([]byte, 0), } req.Signature, err = NewSigFromSignature(testSig) if err != nil { @@ -726,7 +736,8 @@ func TestLightningWireProtocol(t *testing.T) { MsgAnnounceSignatures: func(v []reflect.Value, r *rand.Rand) { var err error req := AnnounceSignatures{ - ShortChannelID: NewShortChanIDFromInt(uint64(r.Int63())), + ShortChannelID: NewShortChanIDFromInt(uint64(r.Int63())), + ExtraOpaqueData: make([]byte, 0), } req.NodeSignature, err = NewSigFromSignature(testSig) @@ -763,6 +774,7 @@ func TestLightningWireProtocol(t *testing.T) { req := ChannelReestablish{ NextLocalCommitHeight: uint64(r.Int63()), RemoteCommitTailHeight: uint64(r.Int63()), + ExtraData: make([]byte, 0), } // With a 50/50 probability, we'll include the @@ -785,7 +797,9 @@ func TestLightningWireProtocol(t *testing.T) { v[0] = reflect.ValueOf(req) }, MsgQueryShortChanIDs: func(v []reflect.Value, r *rand.Rand) { - req := QueryShortChanIDs{} + req := QueryShortChanIDs{ + ExtraData: make([]byte, 0), + } // With a 50/50 change, we'll either use zlib encoding, // or regular encoding. @@ -810,10 +824,9 @@ func TestLightningWireProtocol(t *testing.T) { }, MsgReplyChannelRange: func(v []reflect.Value, r *rand.Rand) { req := ReplyChannelRange{ - QueryChannelRange: QueryChannelRange{ - FirstBlockHeight: uint32(r.Int31()), - NumBlocks: uint32(r.Int31()), - }, + FirstBlockHeight: uint32(r.Int31()), + NumBlocks: uint32(r.Int31()), + ExtraData: make([]byte, 0), } if _, err := rand.Read(req.ChainHash[:]); err != nil { diff --git a/lnwire/message.go b/lnwire/message.go index b5c27339e9e..58c95ca95fc 100644 --- a/lnwire/message.go +++ b/lnwire/message.go @@ -56,6 +56,20 @@ const ( MsgGossipTimestampRange = 265 ) +const ( + // ProtocolVersionLegacy is the legacy protocol version. When reading + // or writing messages using this protocol version, any optional fields + // appended to the end of the message will be ignored. + ProtocolVersionLegacy uint32 = 0 + + // ProtocolVersionTLV is the current modern protocol version. When + // reading/writing messages with this version, decoding will continue + // until the entire payload has been ready. When writing with this + // version, any optional fields appended to the end of the main message + // will also be written out. + ProtocolVersionTLV uint32 = 1 +) + // String return the string representation of message type. func (t MessageType) String() string { switch t { diff --git a/lnwire/node_announcement.go b/lnwire/node_announcement.go index f0d897bc91d..eb257509547 100644 --- a/lnwire/node_announcement.go +++ b/lnwire/node_announcement.go @@ -5,7 +5,6 @@ import ( "fmt" "image/color" "io" - "io/ioutil" "net" "unicode/utf8" ) @@ -98,7 +97,7 @@ type NodeAnnouncement struct { // properly validate the set of signatures that cover these new fields, // and ensure we're able to make upgrades to the network in a forwards // compatible manner. - ExtraOpaqueData []byte + ExtraOpaqueData ExtraOpaqueData } // A compile time check to ensure NodeAnnouncement implements the @@ -110,7 +109,8 @@ var _ Message = (*NodeAnnouncement)(nil) // // This is part of the lnwire.Message interface. func (a *NodeAnnouncement) Decode(r io.Reader, pver uint32) error { - err := ReadElements(r, + return ReadElements(r, + pver, &a.Signature, &a.Features, &a.Timestamp, @@ -118,24 +118,8 @@ func (a *NodeAnnouncement) Decode(r io.Reader, pver uint32) error { &a.RGBColor, &a.Alias, &a.Addresses, + &a.ExtraOpaqueData, ) - if err != nil { - return err - } - - // Now that we've read out all the fields that we explicitly know of, - // we'll collect the remainder into the ExtraOpaqueData field. If there - // aren't any bytes, then we'll snip off the slice to avoid carrying - // around excess capacity. - a.ExtraOpaqueData, err = ioutil.ReadAll(r) - if err != nil { - return err - } - if len(a.ExtraOpaqueData) == 0 { - a.ExtraOpaqueData = nil - } - - return nil } // Encode serializes the target NodeAnnouncement into the passed io.Writer @@ -143,6 +127,7 @@ func (a *NodeAnnouncement) Decode(r io.Reader, pver uint32) error { // func (a *NodeAnnouncement) Encode(w io.Writer, pver uint32) error { return WriteElements(w, + pver, a.Signature, a.Features, a.Timestamp, @@ -167,7 +152,7 @@ func (a *NodeAnnouncement) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (a *NodeAnnouncement) MaxPayloadLength(pver uint32) uint32 { - return 65533 + return MaxMsgBody } // DataToSign returns the part of the message that should be signed. @@ -176,6 +161,9 @@ func (a *NodeAnnouncement) DataToSign() ([]byte, error) { // We should not include the signatures itself. var w bytes.Buffer err := WriteElements(&w, + // We always use the modern protocol version as we need to + // include all data for forawrds compatability. + ProtocolVersionTLV, a.Features, a.Timestamp, a.NodeID, diff --git a/lnwire/onion_error.go b/lnwire/onion_error.go index c6235552e9c..1ab024de629 100644 --- a/lnwire/onion_error.go +++ b/lnwire/onion_error.go @@ -389,7 +389,7 @@ func (f *FailIncorrectDetails) Error() string { // // NOTE: Part of the Serializable interface. func (f *FailIncorrectDetails) Decode(r io.Reader, pver uint32) error { - err := ReadElement(r, &f.amount) + err := ReadElement(r, pver, &f.amount) switch { // This is an optional tack on that was added later in the protocol. As // a result, older nodes may not include this value. We'll account for @@ -404,7 +404,7 @@ func (f *FailIncorrectDetails) Decode(r io.Reader, pver uint32) error { // At a later stage, the height field was also tacked on. We need to // check for io.EOF here as well. - err = ReadElement(r, &f.height) + err = ReadElement(r, pver, &f.height) switch { case err == io.EOF: return nil @@ -420,7 +420,7 @@ func (f *FailIncorrectDetails) Decode(r io.Reader, pver uint32) error { // // NOTE: Part of the Serializable interface. func (f *FailIncorrectDetails) Encode(w io.Writer, pver uint32) error { - return WriteElements(w, f.amount, f.height) + return WriteElements(w, pver, f.amount, f.height) } // FailFinalExpiryTooSoon is returned if the cltv_expiry is too low, the final @@ -479,14 +479,14 @@ func (f *FailInvalidOnionVersion) Code() FailCode { // // NOTE: Part of the Serializable interface. func (f *FailInvalidOnionVersion) Decode(r io.Reader, pver uint32) error { - return ReadElement(r, f.OnionSHA256[:]) + return ReadElement(r, pver, f.OnionSHA256[:]) } // Encode writes the failure in bytes stream. // // NOTE: Part of the Serializable interface. func (f *FailInvalidOnionVersion) Encode(w io.Writer, pver uint32) error { - return WriteElement(w, f.OnionSHA256[:]) + return WriteElement(w, pver, f.OnionSHA256[:]) } // FailInvalidOnionHmac is return if the onion HMAC is incorrect. @@ -513,14 +513,14 @@ func (f *FailInvalidOnionHmac) Code() FailCode { // // NOTE: Part of the Serializable interface. func (f *FailInvalidOnionHmac) Decode(r io.Reader, pver uint32) error { - return ReadElement(r, f.OnionSHA256[:]) + return ReadElement(r, pver, f.OnionSHA256[:]) } // Encode writes the failure in bytes stream. // // NOTE: Part of the Serializable interface. func (f *FailInvalidOnionHmac) Encode(w io.Writer, pver uint32) error { - return WriteElement(w, f.OnionSHA256[:]) + return WriteElement(w, pver, f.OnionSHA256[:]) } // Returns a human readable string describing the target FailureMessage. @@ -555,14 +555,14 @@ func (f *FailInvalidOnionKey) Code() FailCode { // // NOTE: Part of the Serializable interface. func (f *FailInvalidOnionKey) Decode(r io.Reader, pver uint32) error { - return ReadElement(r, f.OnionSHA256[:]) + return ReadElement(r, pver, f.OnionSHA256[:]) } // Encode writes the failure in bytes stream. // // NOTE: Part of the Serializable interface. func (f *FailInvalidOnionKey) Encode(w io.Writer, pver uint32) error { - return WriteElement(w, f.OnionSHA256[:]) + return WriteElement(w, pver, f.OnionSHA256[:]) } // Returns a human readable string describing the target FailureMessage. @@ -652,7 +652,7 @@ func (f *FailTemporaryChannelFailure) Error() string { // NOTE: Part of the Serializable interface. func (f *FailTemporaryChannelFailure) Decode(r io.Reader, pver uint32) error { var length uint16 - err := ReadElement(r, &length) + err := ReadElement(r, pver, &length) if err != nil { return err } @@ -680,7 +680,7 @@ func (f *FailTemporaryChannelFailure) Encode(w io.Writer, pver uint32) error { payload = bw.Bytes() } - if err := WriteElement(w, uint16(len(payload))); err != nil { + if err := WriteElement(w, pver, uint16(len(payload))); err != nil { return err } @@ -731,12 +731,12 @@ func (f *FailAmountBelowMinimum) Error() string { // // NOTE: Part of the Serializable interface. func (f *FailAmountBelowMinimum) Decode(r io.Reader, pver uint32) error { - if err := ReadElement(r, &f.HtlcMsat); err != nil { + if err := ReadElement(r, pver, &f.HtlcMsat); err != nil { return err } var length uint16 - if err := ReadElement(r, &length); err != nil { + if err := ReadElement(r, pver, &length); err != nil { return err } @@ -750,7 +750,7 @@ func (f *FailAmountBelowMinimum) Decode(r io.Reader, pver uint32) error { // // NOTE: Part of the Serializable interface. func (f *FailAmountBelowMinimum) Encode(w io.Writer, pver uint32) error { - if err := WriteElement(w, f.HtlcMsat); err != nil { + if err := WriteElement(w, pver, f.HtlcMsat); err != nil { return err } @@ -799,12 +799,12 @@ func (f *FailFeeInsufficient) Error() string { // // NOTE: Part of the Serializable interface. func (f *FailFeeInsufficient) Decode(r io.Reader, pver uint32) error { - if err := ReadElement(r, &f.HtlcMsat); err != nil { + if err := ReadElement(r, pver, &f.HtlcMsat); err != nil { return err } var length uint16 - if err := ReadElement(r, &length); err != nil { + if err := ReadElement(r, pver, &length); err != nil { return err } @@ -818,7 +818,7 @@ func (f *FailFeeInsufficient) Decode(r io.Reader, pver uint32) error { // // NOTE: Part of the Serializable interface. func (f *FailFeeInsufficient) Encode(w io.Writer, pver uint32) error { - if err := WriteElement(w, f.HtlcMsat); err != nil { + if err := WriteElement(w, pver, f.HtlcMsat); err != nil { return err } @@ -867,12 +867,12 @@ func (f *FailIncorrectCltvExpiry) Error() string { // // NOTE: Part of the Serializable interface. func (f *FailIncorrectCltvExpiry) Decode(r io.Reader, pver uint32) error { - if err := ReadElement(r, &f.CltvExpiry); err != nil { + if err := ReadElement(r, pver, &f.CltvExpiry); err != nil { return err } var length uint16 - if err := ReadElement(r, &length); err != nil { + if err := ReadElement(r, pver, &length); err != nil { return err } @@ -886,7 +886,7 @@ func (f *FailIncorrectCltvExpiry) Decode(r io.Reader, pver uint32) error { // // NOTE: Part of the Serializable interface. func (f *FailIncorrectCltvExpiry) Encode(w io.Writer, pver uint32) error { - if err := WriteElement(w, f.CltvExpiry); err != nil { + if err := WriteElement(w, pver, f.CltvExpiry); err != nil { return err } @@ -929,7 +929,7 @@ func (f *FailExpiryTooSoon) Error() string { // NOTE: Part of the Serializable interface. func (f *FailExpiryTooSoon) Decode(r io.Reader, pver uint32) error { var length uint16 - if err := ReadElement(r, &length); err != nil { + if err := ReadElement(r, pver, &length); err != nil { return err } @@ -988,12 +988,12 @@ func (f *FailChannelDisabled) Error() string { // // NOTE: Part of the Serializable interface. func (f *FailChannelDisabled) Decode(r io.Reader, pver uint32) error { - if err := ReadElement(r, &f.Flags); err != nil { + if err := ReadElement(r, pver, &f.Flags); err != nil { return err } var length uint16 - if err := ReadElement(r, &length); err != nil { + if err := ReadElement(r, pver, &length); err != nil { return err } @@ -1007,7 +1007,7 @@ func (f *FailChannelDisabled) Decode(r io.Reader, pver uint32) error { // // NOTE: Part of the Serializable interface. func (f *FailChannelDisabled) Encode(w io.Writer, pver uint32) error { - if err := WriteElement(w, f.Flags); err != nil { + if err := WriteElement(w, pver, f.Flags); err != nil { return err } @@ -1050,14 +1050,14 @@ func (f *FailFinalIncorrectCltvExpiry) Code() FailCode { // // NOTE: Part of the Serializable interface. func (f *FailFinalIncorrectCltvExpiry) Decode(r io.Reader, pver uint32) error { - return ReadElement(r, &f.CltvExpiry) + return ReadElement(r, pver, &f.CltvExpiry) } // Encode writes the failure in bytes stream. // // NOTE: Part of the Serializable interface. func (f *FailFinalIncorrectCltvExpiry) Encode(w io.Writer, pver uint32) error { - return WriteElement(w, f.CltvExpiry) + return WriteElement(w, pver, f.CltvExpiry) } // FailFinalIncorrectHtlcAmount is returned if the amt_to_forward is higher @@ -1096,14 +1096,14 @@ func (f *FailFinalIncorrectHtlcAmount) Code() FailCode { // // NOTE: Part of the Serializable interface. func (f *FailFinalIncorrectHtlcAmount) Decode(r io.Reader, pver uint32) error { - return ReadElement(r, &f.IncomingHTLCAmount) + return ReadElement(r, pver, &f.IncomingHTLCAmount) } // Encode writes the failure in bytes stream. // // NOTE: Part of the Serializable interface. func (f *FailFinalIncorrectHtlcAmount) Encode(w io.Writer, pver uint32) error { - return WriteElement(w, f.IncomingHTLCAmount) + return WriteElement(w, pver, f.IncomingHTLCAmount) } // FailExpiryTooFar is returned if the CLTV expiry in the HTLC is too far in the @@ -1171,7 +1171,7 @@ func (f *InvalidOnionPayload) Decode(r io.Reader, pver uint32) error { } f.Type = typ - return ReadElements(r, &f.Offset) + return ReadElements(r, pver, &f.Offset) } // Encode writes the failure in bytes stream. @@ -1183,7 +1183,7 @@ func (f *InvalidOnionPayload) Encode(w io.Writer, pver uint32) error { return err } - return WriteElements(w, f.Offset) + return WriteElements(w, pver, f.Offset) } // FailMPPTimeout is returned if the complete amount for a multi part payment @@ -1212,7 +1212,7 @@ func DecodeFailure(r io.Reader, pver uint32) (FailureMessage, error) { // First, we'll parse out the encapsulated failure message itself. This // is a 2 byte length followed by the payload itself. var failureLength uint16 - if err := ReadElement(r, &failureLength); err != nil { + if err := ReadElement(r, pver, &failureLength); err != nil { return nil, fmt.Errorf("unable to read error len: %v", err) } if failureLength > FailureMessageLength { @@ -1284,6 +1284,7 @@ func EncodeFailure(w io.Writer, failure FailureMessage, pver uint32) error { pad := make([]byte, FailureMessageLength-len(failureMessage)) return WriteElements(w, + pver, uint16(len(failureMessage)), failureMessage, uint16(len(pad)), @@ -1414,7 +1415,7 @@ func writeOnionErrorChanUpdate(w io.Writer, chanUpdate *ChannelUpdate, // Now that we know the size, we can write the length out in the main // writer. updateLen := b.Len() - if err := WriteElement(w, uint16(updateLen)); err != nil { + if err := WriteElement(w, pver, uint16(updateLen)); err != nil { return err } diff --git a/lnwire/onion_error_test.go b/lnwire/onion_error_test.go index 3ec147d1ddb..a1ed4fbe1d6 100644 --- a/lnwire/onion_error_test.go +++ b/lnwire/onion_error_test.go @@ -20,11 +20,12 @@ var ( testOffset = uint16(24) sig, _ = NewSigFromSignature(testSig) testChannelUpdate = ChannelUpdate{ - Signature: sig, - ShortChannelID: NewShortChanIDFromInt(1), - Timestamp: 1, - MessageFlags: 0, - ChannelFlags: 1, + Signature: sig, + ShortChannelID: NewShortChanIDFromInt(1), + Timestamp: 1, + MessageFlags: 0, + ChannelFlags: 1, + ExtraOpaqueData: make([]byte, 0), } ) @@ -62,12 +63,13 @@ func TestEncodeDecodeCode(t *testing.T) { for _, failure1 := range onionFailures { var b bytes.Buffer - if err := EncodeFailure(&b, failure1, 0); err != nil { + err := EncodeFailure(&b, failure1, ProtocolVersionTLV) + if err != nil { t.Fatalf("unable to encode failure code(%v): %v", failure1.Code(), err) } - failure2, err := DecodeFailure(&b, 0) + failure2, err := DecodeFailure(&b, ProtocolVersionTLV) if err != nil { t.Fatalf("unable to decode failure code(%v): %v", failure1.Code(), err) @@ -89,7 +91,7 @@ func TestChannelUpdateCompatabilityParsing(t *testing.T) { // We'll start by taking out test channel update, and encoding it into // a set of raw bytes. var b bytes.Buffer - if err := testChannelUpdate.Encode(&b, 0); err != nil { + if err := testChannelUpdate.Encode(&b, ProtocolVersionTLV); err != nil { t.Fatalf("unable to encode chan update: %v", err) } @@ -98,7 +100,7 @@ func TestChannelUpdateCompatabilityParsing(t *testing.T) { // encoded channel update message. var newChanUpdate ChannelUpdate err := parseChannelUpdateCompatabilityMode( - bufio.NewReader(&b), &newChanUpdate, 0, + bufio.NewReader(&b), &newChanUpdate, ProtocolVersionTLV, ) if err != nil { t.Fatalf("unable to parse channel update: %v", err) @@ -119,7 +121,7 @@ func TestChannelUpdateCompatabilityParsing(t *testing.T) { var tByte [2]byte binary.BigEndian.PutUint16(tByte[:], MsgChannelUpdate) b.Write(tByte[:]) - if err := testChannelUpdate.Encode(&b, 0); err != nil { + if err := testChannelUpdate.Encode(&b, ProtocolVersionTLV); err != nil { t.Fatalf("unable to encode chan update: %v", err) } @@ -127,7 +129,7 @@ func TestChannelUpdateCompatabilityParsing(t *testing.T) { // message even with the extra two bytes. var newChanUpdate2 ChannelUpdate err = parseChannelUpdateCompatabilityMode( - bufio.NewReader(&b), &newChanUpdate2, 0, + bufio.NewReader(&b), &newChanUpdate2, ProtocolVersionTLV, ) if err != nil { t.Fatalf("unable to parse channel update: %v", err) @@ -164,7 +166,7 @@ func TestWriteOnionErrorChanUpdate(t *testing.T) { // Finally, read the length encoded and ensure that it matches the raw // length. var encodedLen uint16 - if err := ReadElement(&errorBuf, &encodedLen); err != nil { + if err := ReadElement(&errorBuf, ProtocolVersionTLV, &encodedLen); err != nil { t.Fatalf("unable to read len: %v", err) } if uint16(trueUpdateLength) != encodedLen { @@ -275,5 +277,5 @@ func (f *mockFailIncorrectDetailsNoHeight) Decode(r io.Reader, pver uint32) erro } func (f *mockFailIncorrectDetailsNoHeight) Encode(w io.Writer, pver uint32) error { - return WriteElement(w, f.amount) + return WriteElement(w, ProtocolVersionTLV, f.amount) } diff --git a/lnwire/open_channel.go b/lnwire/open_channel.go index f78cc26eff5..bfe25599b2a 100644 --- a/lnwire/open_channel.go +++ b/lnwire/open_channel.go @@ -127,7 +127,12 @@ type OpenChannel struct { // be paid when mutually closing the channel. This field is optional, and // and has a length prefix, so a zero will be written if it is not set // and its length followed by the script will be written if it is set. - UpfrontShutdownScript DeliveryAddress + UpfrontShutdownScript TypedDeliveryAddress + + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraData ExtraOpaqueData } // A compile time check to ensure OpenChannel implements the lnwire.Message @@ -141,6 +146,7 @@ var _ Message = (*OpenChannel)(nil) // This is part of the lnwire.Message interface. func (o *OpenChannel) Encode(w io.Writer, pver uint32) error { return WriteElements(w, + pver, o.ChainHash[:], o.PendingChannelID[:], o.FundingAmount, @@ -160,6 +166,7 @@ func (o *OpenChannel) Encode(w io.Writer, pver uint32) error { o.FirstCommitmentPoint, o.ChannelFlags, o.UpfrontShutdownScript, + o.ExtraData, ) } @@ -169,7 +176,8 @@ func (o *OpenChannel) Encode(w io.Writer, pver uint32) error { // // This is part of the lnwire.Message interface. func (o *OpenChannel) Decode(r io.Reader, pver uint32) error { - if err := ReadElements(r, + return ReadElements(r, + pver, o.ChainHash[:], o.PendingChannelID[:], &o.FundingAmount, @@ -188,18 +196,9 @@ func (o *OpenChannel) Decode(r io.Reader, pver uint32) error { &o.HtlcPoint, &o.FirstCommitmentPoint, &o.ChannelFlags, - ); err != nil { - return err - } - - // Check for the optional upfront shutdown script field. If it is not there, - // silence the EOF error. - err := ReadElement(r, &o.UpfrontShutdownScript) - if err != nil && err != io.EOF { - return err - } - - return nil + &o.UpfrontShutdownScript, + &o.ExtraData, + ) } // MsgType returns the MessageType code which uniquely identifies this message @@ -215,11 +214,5 @@ func (o *OpenChannel) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (o *OpenChannel) MaxPayloadLength(uint32) uint32 { - // (32 * 2) + (8 * 6) + (4 * 1) + (2 * 2) + (33 * 6) + 1 - var length uint32 = 319 // base length - - // Upfront shutdown script max length. - length += 2 + deliveryAddressMaxSize - - return length + return MaxMsgBody } diff --git a/lnwire/ping.go b/lnwire/ping.go index cf9a83b78ce..cc75c276eca 100644 --- a/lnwire/ping.go +++ b/lnwire/ping.go @@ -36,6 +36,7 @@ var _ Message = (*Ping)(nil) // This is part of the lnwire.Message interface. func (p *Ping) Decode(r io.Reader, pver uint32) error { return ReadElements(r, + pver, &p.NumPongBytes, &p.PaddingBytes) } @@ -46,6 +47,7 @@ func (p *Ping) Decode(r io.Reader, pver uint32) error { // This is part of the lnwire.Message interface. func (p *Ping) Encode(w io.Writer, pver uint32) error { return WriteElements(w, + pver, p.NumPongBytes, p.PaddingBytes) } diff --git a/lnwire/pong.go b/lnwire/pong.go index c3166aaf6d0..3057cd57953 100644 --- a/lnwire/pong.go +++ b/lnwire/pong.go @@ -32,7 +32,7 @@ var _ Message = (*Pong)(nil) // This is part of the lnwire.Message interface. func (p *Pong) Decode(r io.Reader, pver uint32) error { return ReadElements(r, - &p.PongBytes, + pver, &p.PongBytes, ) } @@ -42,7 +42,7 @@ func (p *Pong) Decode(r io.Reader, pver uint32) error { // This is part of the lnwire.Message interface. func (p *Pong) Encode(w io.Writer, pver uint32) error { return WriteElements(w, - p.PongBytes, + pver, p.PongBytes, ) } diff --git a/lnwire/query_channel_range.go b/lnwire/query_channel_range.go index 9546fcd32a1..6d07763c399 100644 --- a/lnwire/query_channel_range.go +++ b/lnwire/query_channel_range.go @@ -25,6 +25,11 @@ type QueryChannelRange struct { // NumBlocks is the number of blocks beyond the first block that short // channel ID's should be sent for. NumBlocks uint32 + + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraData ExtraOpaqueData } // NewQueryChannelRange creates a new empty QueryChannelRange message. @@ -42,9 +47,11 @@ var _ Message = (*QueryChannelRange)(nil) // This is part of the lnwire.Message interface. func (q *QueryChannelRange) Decode(r io.Reader, pver uint32) error { return ReadElements(r, + pver, q.ChainHash[:], &q.FirstBlockHeight, &q.NumBlocks, + &q.ExtraData, ) } @@ -54,9 +61,11 @@ func (q *QueryChannelRange) Decode(r io.Reader, pver uint32) error { // This is part of the lnwire.Message interface. func (q *QueryChannelRange) Encode(w io.Writer, pver uint32) error { return WriteElements(w, + pver, q.ChainHash[:], q.FirstBlockHeight, q.NumBlocks, + q.ExtraData, ) } @@ -73,8 +82,7 @@ func (q *QueryChannelRange) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (q *QueryChannelRange) MaxPayloadLength(uint32) uint32 { - // 32 + 4 + 4 - return 40 + return MaxMsgBody } // LastBlockHeight returns the last block height covered by the range of a diff --git a/lnwire/query_short_chan_ids.go b/lnwire/query_short_chan_ids.go index cb24178b39e..c6dfab2de37 100644 --- a/lnwire/query_short_chan_ids.go +++ b/lnwire/query_short_chan_ids.go @@ -81,6 +81,11 @@ type QueryShortChanIDs struct { // ShortChanIDs is a slice of decoded short channel ID's. ShortChanIDs []ShortChannelID + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraData ExtraOpaqueData + // noSort indicates whether or not to sort the short channel ids before // writing them out. // @@ -108,25 +113,30 @@ var _ Message = (*QueryShortChanIDs)(nil) // // This is part of the lnwire.Message interface. func (q *QueryShortChanIDs) Decode(r io.Reader, pver uint32) error { - err := ReadElements(r, q.ChainHash[:]) + err := ReadElements(r, pver, q.ChainHash[:]) if err != nil { return err } - q.EncodingType, q.ShortChanIDs, err = decodeShortChanIDs(r) + q.EncodingType, q.ShortChanIDs, err = decodeShortChanIDs(r, pver) + if err != nil { + return err + } - return err + return q.ExtraData.Decode(r, pver) } // decodeShortChanIDs decodes a set of short channel ID's that have been // encoded. The first byte of the body details how the short chan ID's were // encoded. We'll use this type to govern exactly how we go about encoding the -// set of short channel ID's. -func decodeShortChanIDs(r io.Reader) (ShortChanIDEncoding, []ShortChannelID, error) { +// set of short channel ID's. The protocol version may affect how the IDs are +// decoded. +func decodeShortChanIDs(r io.Reader, + pver uint32) (ShortChanIDEncoding, []ShortChannelID, error) { // First, we'll attempt to read the number of bytes in the body of the // set of encoded short channel ID's. var numBytesResp uint16 - err := ReadElements(r, &numBytesResp) + err := ReadElements(r, pver, &numBytesResp) if err != nil { return 0, nil, err } @@ -179,7 +189,8 @@ func decodeShortChanIDs(r io.Reader) (ShortChanIDEncoding, []ShortChannelID, err bodyReader := bytes.NewReader(queryBody) var lastChanID ShortChannelID for i := 0; i < numShortChanIDs; i++ { - if err := ReadElements(bodyReader, &shortChanIDs[i]); err != nil { + err := ReadElements(bodyReader, pver, &shortChanIDs[i]) + if err != nil { return 0, nil, fmt.Errorf("unable to parse "+ "short chan ID: %v", err) } @@ -235,7 +246,7 @@ func decodeShortChanIDs(r io.Reader) (ShortChanIDEncoding, []ShortChannelID, err // We'll now attempt to read the next short channel ID // encoded in the payload. var cid ShortChannelID - err := ReadElements(limitedDecompressor, &cid) + err := ReadElements(limitedDecompressor, pver, &cid) switch { // If we get an EOF error, then that either means we've @@ -285,20 +296,28 @@ func decodeShortChanIDs(r io.Reader) (ShortChanIDEncoding, []ShortChannelID, err // This is part of the lnwire.Message interface. func (q *QueryShortChanIDs) Encode(w io.Writer, pver uint32) error { // First, we'll write out the chain hash. - err := WriteElements(w, q.ChainHash[:]) + err := WriteElements(w, pver, q.ChainHash[:]) if err != nil { return err } // Base on our encoding type, we'll write out the set of short channel // ID's. - return encodeShortChanIDs(w, q.EncodingType, q.ShortChanIDs, q.noSort) + err = encodeShortChanIDs( + w, q.EncodingType, q.ShortChanIDs, q.noSort, pver, + ) + if err != nil { + return err + } + + return q.ExtraData.Encode(w, pver) } // encodeShortChanIDs encodes the passed short channel ID's into the passed -// io.Writer, respecting the specified encoding type. +// io.Writer, respecting the specified encoding type. The protocol version may +// affect how the items are encoded. func encodeShortChanIDs(w io.Writer, encodingType ShortChanIDEncoding, - shortChanIDs []ShortChannelID, noSort bool) error { + shortChanIDs []ShortChannelID, noSort bool, pver uint32) error { // For both of the current encoding types, the channel ID's are to be // sorted in place, so we'll do that now. The sorting is applied unless @@ -319,20 +338,20 @@ func encodeShortChanIDs(w io.Writer, encodingType ShortChanIDEncoding, // body. We add 1 as the response will have the encoding type // prepended to it. numBytesBody := uint16(len(shortChanIDs)*8) + 1 - if err := WriteElements(w, numBytesBody); err != nil { + if err := WriteElements(w, pver, numBytesBody); err != nil { return err } // We'll then write out the encoding that that follows the // actual encoded short channel ID's. - if err := WriteElements(w, encodingType); err != nil { + if err := WriteElements(w, pver, encodingType); err != nil { return err } // Now that we know they're sorted, we can write out each short // channel ID to the buffer. for _, chanID := range shortChanIDs { - if err := WriteElements(w, chanID); err != nil { + if err := WriteElements(w, pver, chanID); err != nil { return fmt.Errorf("unable to write short chan "+ "ID: %v", err) } @@ -363,7 +382,7 @@ func encodeShortChanIDs(w io.Writer, encodingType ShortChanIDEncoding, // into the zlib writer, which will do compressing on // the fly. for _, chanID := range shortChanIDs { - err := WriteElements(zlibWriter, chanID) + err := WriteElements(zlibWriter, pver, chanID) if err != nil { return fmt.Errorf("unable to write short chan "+ "ID: %v", err) @@ -394,10 +413,10 @@ func encodeShortChanIDs(w io.Writer, encodingType ShortChanIDEncoding, // Finally, we can write out the number of bytes, the // compression type, and finally the buffer itself. - if err := WriteElements(w, uint16(numBytesBody)); err != nil { + if err := WriteElements(w, pver, uint16(numBytesBody)); err != nil { return err } - if err := WriteElements(w, encodingType); err != nil { + if err := WriteElements(w, pver, encodingType); err != nil { return err } @@ -425,5 +444,5 @@ func (q *QueryShortChanIDs) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (q *QueryShortChanIDs) MaxPayloadLength(uint32) uint32 { - return MaxMessagePayload + return MaxMsgBody } diff --git a/lnwire/reply_channel_range.go b/lnwire/reply_channel_range.go index 430606025c2..2c49c0b5a82 100644 --- a/lnwire/reply_channel_range.go +++ b/lnwire/reply_channel_range.go @@ -1,14 +1,29 @@ package lnwire -import "io" +import ( + "io" + "math" + + "github.com/btcsuite/btcd/chaincfg/chainhash" +) // ReplyChannelRange is the response to the QueryChannelRange message. It // includes the original query, and the next streaming chunk of encoded short // channel ID's as the response. We'll also include a byte that indicates if // this is the last query in the message. type ReplyChannelRange struct { - // QueryChannelRange is the corresponding query to this response. - QueryChannelRange + // ChainHash denotes the target chain that we're trying to synchronize + // channel graph state for. + ChainHash chainhash.Hash + + // FirstBlockHeight is the first block in the query range. The + // responder should send all new short channel IDs from this block + // until this block plus the specified number of blocks. + FirstBlockHeight uint32 + + // NumBlocks is the number of blocks beyond the first block that short + // channel ID's should be sent for. + NumBlocks uint32 // Complete denotes if this is the conclusion of the set of streaming // responses to the original query. @@ -22,6 +37,11 @@ type ReplyChannelRange struct { // ShortChanIDs is a slice of decoded short channel ID's. ShortChanIDs []ShortChannelID + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraData ExtraOpaqueData + // noSort indicates whether or not to sort the short channel ids before // writing them out. // @@ -43,18 +63,23 @@ var _ Message = (*ReplyChannelRange)(nil) // // This is part of the lnwire.Message interface. func (c *ReplyChannelRange) Decode(r io.Reader, pver uint32) error { - err := c.QueryChannelRange.Decode(r, pver) + err := ReadElements(r, + pver, + c.ChainHash[:], + &c.FirstBlockHeight, + &c.NumBlocks, + &c.Complete, + ) if err != nil { return err } - if err := ReadElements(r, &c.Complete); err != nil { + c.EncodingType, c.ShortChanIDs, err = decodeShortChanIDs(r, pver) + if err != nil { return err } - c.EncodingType, c.ShortChanIDs, err = decodeShortChanIDs(r) - - return err + return c.ExtraData.Decode(r, pver) } // Encode serializes the target ReplyChannelRange into the passed io.Writer @@ -62,15 +87,25 @@ func (c *ReplyChannelRange) Decode(r io.Reader, pver uint32) error { // // This is part of the lnwire.Message interface. func (c *ReplyChannelRange) Encode(w io.Writer, pver uint32) error { - if err := c.QueryChannelRange.Encode(w, pver); err != nil { + err := WriteElements(w, + pver, + c.ChainHash[:], + c.FirstBlockHeight, + c.NumBlocks, + c.Complete, + ) + if err != nil { return err } - if err := WriteElements(w, c.Complete); err != nil { + err = encodeShortChanIDs( + w, c.EncodingType, c.ShortChanIDs, c.noSort, pver, + ) + if err != nil { return err } - return encodeShortChanIDs(w, c.EncodingType, c.ShortChanIDs, c.noSort) + return c.ExtraData.Encode(w, pver) } // MsgType returns the integer uniquely identifying this message type on the @@ -86,5 +121,16 @@ func (c *ReplyChannelRange) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (c *ReplyChannelRange) MaxPayloadLength(uint32) uint32 { - return MaxMessagePayload + return MaxMsgBody +} + +// LastBlockHeight returns the last block height covered by the range of a +// QueryChannelRange message. +func (c *ReplyChannelRange) LastBlockHeight() uint32 { + // Handle overflows by casting to uint64. + lastBlockHeight := uint64(c.FirstBlockHeight) + uint64(c.NumBlocks) - 1 + if lastBlockHeight > math.MaxUint32 { + return math.MaxUint32 + } + return uint32(lastBlockHeight) } diff --git a/lnwire/reply_channel_range_test.go b/lnwire/reply_channel_range_test.go index d2c8df68c68..a40356040d3 100644 --- a/lnwire/reply_channel_range_test.go +++ b/lnwire/reply_channel_range_test.go @@ -30,7 +30,7 @@ func TestReplyChannelRangeUnsorted(t *testing.T) { var req2 ReplyChannelRange err = req2.Decode(bytes.NewReader(b.Bytes()), 0) if _, ok := err.(ErrUnsortedSIDs); !ok { - t.Fatalf("expected ErrUnsortedSIDs, got: %T", + t.Fatalf("expected ErrUnsortedSIDs, got: %v", err) } }) @@ -67,13 +67,12 @@ func TestReplyChannelRangeEmpty(t *testing.T) { test := test t.Run(test.name, func(t *testing.T) { req := ReplyChannelRange{ - QueryChannelRange: QueryChannelRange{ - FirstBlockHeight: 1, - NumBlocks: 2, - }, - Complete: 1, - EncodingType: test.encType, - ShortChanIDs: nil, + FirstBlockHeight: 1, + NumBlocks: 2, + Complete: 1, + EncodingType: test.encType, + ShortChanIDs: nil, + ExtraData: make([]byte, 0), } // First decode the hex string in the test case into a @@ -81,7 +80,9 @@ func TestReplyChannelRangeEmpty(t *testing.T) { // identical to the one created above. var req2 ReplyChannelRange b, _ := hex.DecodeString(test.encodedHex) - err := req2.Decode(bytes.NewReader(b), 0) + err := req2.Decode( + bytes.NewReader(b), ProtocolVersionTLV, + ) if err != nil { t.Fatalf("unable to decode req: %v", err) } @@ -94,7 +95,7 @@ func TestReplyChannelRangeEmpty(t *testing.T) { // request created above, and assert that it matches // the raw byte encoding. var b2 bytes.Buffer - err = req.Encode(&b2, 0) + err = req.Encode(&b2, ProtocolVersionTLV) if err != nil { t.Fatalf("unable to encode req: %v", err) } diff --git a/lnwire/reply_short_chan_ids_end.go b/lnwire/reply_short_chan_ids_end.go index ce5f8f740bd..64341ba2e8c 100644 --- a/lnwire/reply_short_chan_ids_end.go +++ b/lnwire/reply_short_chan_ids_end.go @@ -22,6 +22,11 @@ type ReplyShortChanIDsEnd struct { // set of short chan ID's in the corresponding QueryShortChanIDs // message. Complete uint8 + + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraData ExtraOpaqueData } // NewReplyShortChanIDsEnd creates a new empty ReplyShortChanIDsEnd message. @@ -39,8 +44,10 @@ var _ Message = (*ReplyShortChanIDsEnd)(nil) // This is part of the lnwire.Message interface. func (c *ReplyShortChanIDsEnd) Decode(r io.Reader, pver uint32) error { return ReadElements(r, + pver, c.ChainHash[:], &c.Complete, + &c.ExtraData, ) } @@ -50,8 +57,10 @@ func (c *ReplyShortChanIDsEnd) Decode(r io.Reader, pver uint32) error { // This is part of the lnwire.Message interface. func (c *ReplyShortChanIDsEnd) Encode(w io.Writer, pver uint32) error { return WriteElements(w, + pver, c.ChainHash[:], c.Complete, + c.ExtraData, ) } @@ -69,6 +78,5 @@ func (c *ReplyShortChanIDsEnd) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (c *ReplyShortChanIDsEnd) MaxPayloadLength(uint32) uint32 { - // 32 (chain hash) + 1 (complete) - return 33 + return MaxMsgBody } diff --git a/lnwire/revoke_and_ack.go b/lnwire/revoke_and_ack.go index d685f0f3256..1e288877dcc 100644 --- a/lnwire/revoke_and_ack.go +++ b/lnwire/revoke_and_ack.go @@ -30,11 +30,18 @@ type RevokeAndAck struct { // create the proper revocation key used within the commitment // transaction. NextRevocationKey *btcec.PublicKey + + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraData ExtraOpaqueData } // NewRevokeAndAck creates a new RevokeAndAck message. func NewRevokeAndAck() *RevokeAndAck { - return &RevokeAndAck{} + return &RevokeAndAck{ + ExtraData: make([]byte, 0), + } } // A compile time check to ensure RevokeAndAck implements the lnwire.Message @@ -47,9 +54,11 @@ var _ Message = (*RevokeAndAck)(nil) // This is part of the lnwire.Message interface. func (c *RevokeAndAck) Decode(r io.Reader, pver uint32) error { return ReadElements(r, + pver, &c.ChanID, c.Revocation[:], &c.NextRevocationKey, + &c.ExtraData, ) } @@ -59,9 +68,11 @@ func (c *RevokeAndAck) Decode(r io.Reader, pver uint32) error { // This is part of the lnwire.Message interface. func (c *RevokeAndAck) Encode(w io.Writer, pver uint32) error { return WriteElements(w, + pver, c.ChanID, c.Revocation[:], c.NextRevocationKey, + c.ExtraData, ) } @@ -78,8 +89,7 @@ func (c *RevokeAndAck) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (c *RevokeAndAck) MaxPayloadLength(uint32) uint32 { - // 32 + 32 + 33 - return 97 + return MaxMsgBody } // TargetChanID returns the channel id of the link for which this message is diff --git a/lnwire/shutdown.go b/lnwire/shutdown.go index 94d10a9080c..07270f8ddea 100644 --- a/lnwire/shutdown.go +++ b/lnwire/shutdown.go @@ -15,6 +15,11 @@ type Shutdown struct { // Address is the script to which the channel funds will be paid. Address DeliveryAddress + + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraData ExtraOpaqueData } // DeliveryAddress is used to communicate the address to which funds from a @@ -48,7 +53,7 @@ var _ Message = (*Shutdown)(nil) // // This is part of the lnwire.Message interface. func (s *Shutdown) Decode(r io.Reader, pver uint32) error { - return ReadElements(r, &s.ChannelID, &s.Address) + return ReadElements(r, pver, &s.ChannelID, &s.Address, &s.ExtraData) } // Encode serializes the target Shutdown into the passed io.Writer observing @@ -56,7 +61,7 @@ func (s *Shutdown) Decode(r io.Reader, pver uint32) error { // // This is part of the lnwire.Message interface. func (s *Shutdown) Encode(w io.Writer, pver uint32) error { - return WriteElements(w, s.ChannelID, s.Address) + return WriteElements(w, pver, s.ChannelID, s.Address, s.ExtraData) } // MsgType returns the integer uniquely identifying this message type on the @@ -72,16 +77,5 @@ func (s *Shutdown) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (s *Shutdown) MaxPayloadLength(pver uint32) uint32 { - var length uint32 - - // ChannelID - 32bytes - length += 32 - - // Len - 2 bytes - length += 2 - - // ScriptPubKey - maximum delivery address size. - length += deliveryAddressMaxSize - - return length + return MaxMsgBody } diff --git a/lnwire/typed_delivery_addr.go b/lnwire/typed_delivery_addr.go new file mode 100644 index 00000000000..fd546afd096 --- /dev/null +++ b/lnwire/typed_delivery_addr.go @@ -0,0 +1,65 @@ +package lnwire + +import ( + "io" + + "github.com/lightningnetwork/lnd/tlv" +) + +const ( + // DeliveryAddrType is the TLV record type for delivery addreses within + // the name space of the OpenChannel and AcceptChannel messages. + DeliveryAddrType = 0 +) + +// TypedDeliveryAddress is similar to the DeliveryAddrType type, but it's +// encoded using a mini TLV stream. This tyupe was intorudced in order to allow +// the OpenChannel/AcceptChannel messages to properly be extended with TLV types. +type TypedDeliveryAddress []byte + +// Encode encodes the target TypedDeliveryAddress into the target io.Writer +// using a TLV stream. +func (t *TypedDeliveryAddress) Encode(w io.Writer) error { + addrBytes := []byte((*t)[:]) + + records := []tlv.Record{ + tlv.MakeDynamicRecord( + DeliveryAddrType, &addrBytes, + func() uint64 { + return uint64(len(addrBytes)) + }, + tlv.EVarBytes, tlv.DVarBytes, + ), + } + tlvStream, err := tlv.NewStream(records...) + if err != nil { + return err + } + + return tlvStream.Encode(w) +} + +// Decode decodes a set of bytes from the targer io.Reader into the target +// TypedDeliveryAddress. +func (t *TypedDeliveryAddress) Decode(r io.Reader) error { + addrBytes := []byte((*t)[:]) + + records := []tlv.Record{ + tlv.MakeDynamicRecord( + DeliveryAddrType, &addrBytes, nil, + tlv.EVarBytes, tlv.DVarBytes, + ), + } + + tlvStream, err := tlv.NewStream(records...) + if err != nil { + return err + } + if err := tlvStream.Decode(r); err != nil { + return err + } + + *t = addrBytes + + return nil +} diff --git a/lnwire/typed_delivery_addr_test.go b/lnwire/typed_delivery_addr_test.go new file mode 100644 index 00000000000..c84915a6beb --- /dev/null +++ b/lnwire/typed_delivery_addr_test.go @@ -0,0 +1,31 @@ +package lnwire + +import ( + "bytes" + "testing" +) + +// TestTypedDeliveryAddressEncodeDecode tests that we're able to properly +// encode and decode typed delivery addresses. +func TestTypedDeliveryAddressEncodeDecode(t *testing.T) { + t.Parallel() + + addr := TypedDeliveryAddress( + bytes.Repeat([]byte("a"), deliveryAddressMaxSize), + ) + + var b bytes.Buffer + if err := addr.Encode(&b); err != nil { + t.Fatalf("unable to encode addr: %v", err) + } + + var addr2 TypedDeliveryAddress + if err := addr2.Decode(&b); err != nil { + t.Fatalf("unable to decode addr: %v", err) + } + + if !bytes.Equal(addr, addr2) { + t.Fatalf("addr mismatch: expected %x, got %x", addr[:], + addr2[:]) + } +} diff --git a/lnwire/update_add_htlc.go b/lnwire/update_add_htlc.go index 028c6320d72..691a071d91f 100644 --- a/lnwire/update_add_htlc.go +++ b/lnwire/update_add_htlc.go @@ -52,6 +52,11 @@ type UpdateAddHTLC struct { // should strip off a layer of encryption, exposing the next hop to be // used in the subsequent UpdateAddHTLC message. OnionBlob [OnionPacketSize]byte + + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraData ExtraOpaqueData } // NewUpdateAddHTLC returns a new empty UpdateAddHTLC message. @@ -69,12 +74,14 @@ var _ Message = (*UpdateAddHTLC)(nil) // This is part of the lnwire.Message interface. func (c *UpdateAddHTLC) Decode(r io.Reader, pver uint32) error { return ReadElements(r, + pver, &c.ChanID, &c.ID, &c.Amount, c.PaymentHash[:], &c.Expiry, c.OnionBlob[:], + &c.ExtraData, ) } @@ -84,12 +91,14 @@ func (c *UpdateAddHTLC) Decode(r io.Reader, pver uint32) error { // This is part of the lnwire.Message interface. func (c *UpdateAddHTLC) Encode(w io.Writer, pver uint32) error { return WriteElements(w, + pver, c.ChanID, c.ID, c.Amount, c.PaymentHash[:], c.Expiry, c.OnionBlob[:], + c.ExtraData, ) } @@ -106,8 +115,7 @@ func (c *UpdateAddHTLC) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (c *UpdateAddHTLC) MaxPayloadLength(uint32) uint32 { - // 1450 - return 32 + 8 + 4 + 8 + 32 + 1366 + return MaxMsgBody } // TargetChanID returns the channel id of the link for which this message is diff --git a/lnwire/update_fail_htlc.go b/lnwire/update_fail_htlc.go index 194f2ecd000..54592636f51 100644 --- a/lnwire/update_fail_htlc.go +++ b/lnwire/update_fail_htlc.go @@ -26,6 +26,11 @@ type UpdateFailHTLC struct { // failed. This blob is only fully decryptable by the initiator of the // HTLC message. Reason OpaqueReason + + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraData ExtraOpaqueData } // A compile time check to ensure UpdateFailHTLC implements the lnwire.Message @@ -38,9 +43,11 @@ var _ Message = (*UpdateFailHTLC)(nil) // This is part of the lnwire.Message interface. func (c *UpdateFailHTLC) Decode(r io.Reader, pver uint32) error { return ReadElements(r, + pver, &c.ChanID, &c.ID, &c.Reason, + &c.ExtraData, ) } @@ -50,9 +57,11 @@ func (c *UpdateFailHTLC) Decode(r io.Reader, pver uint32) error { // This is part of the lnwire.Message interface. func (c *UpdateFailHTLC) Encode(w io.Writer, pver uint32) error { return WriteElements(w, + pver, c.ChanID, c.ID, c.Reason, + c.ExtraData, ) } @@ -69,21 +78,7 @@ func (c *UpdateFailHTLC) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (c *UpdateFailHTLC) MaxPayloadLength(uint32) uint32 { - var length uint32 - - // Length of the ChanID - length += 32 - - // Length of the ID - length += 8 - - // Length of the length opaque reason - length += 2 - - // Length of the Reason - length += 292 - - return length + return MaxMsgBody } // TargetChanID returns the channel id of the link for which this message is diff --git a/lnwire/update_fail_malformed_htlc.go b/lnwire/update_fail_malformed_htlc.go index 39d4b8709e2..60d04c5136f 100644 --- a/lnwire/update_fail_malformed_htlc.go +++ b/lnwire/update_fail_malformed_htlc.go @@ -24,6 +24,11 @@ type UpdateFailMalformedHTLC struct { // FailureCode the exact reason why onion blob haven't been parsed. FailureCode FailCode + + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraData ExtraOpaqueData } // A compile time check to ensure UpdateFailMalformedHTLC implements the @@ -36,10 +41,12 @@ var _ Message = (*UpdateFailMalformedHTLC)(nil) // This is part of the lnwire.Message interface. func (c *UpdateFailMalformedHTLC) Decode(r io.Reader, pver uint32) error { return ReadElements(r, + pver, &c.ChanID, &c.ID, c.ShaOnionBlob[:], &c.FailureCode, + &c.ExtraData, ) } @@ -49,10 +56,12 @@ func (c *UpdateFailMalformedHTLC) Decode(r io.Reader, pver uint32) error { // This is part of the lnwire.Message interface. func (c *UpdateFailMalformedHTLC) Encode(w io.Writer, pver uint32) error { return WriteElements(w, + pver, c.ChanID, c.ID, c.ShaOnionBlob[:], c.FailureCode, + c.ExtraData, ) } @@ -70,8 +79,7 @@ func (c *UpdateFailMalformedHTLC) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (c *UpdateFailMalformedHTLC) MaxPayloadLength(uint32) uint32 { - // 32 + 8 + 32 + 2 - return 74 + return MaxMsgBody } // TargetChanID returns the channel id of the link for which this message is diff --git a/lnwire/update_fee.go b/lnwire/update_fee.go index 2d27c3772f7..4953cf6a4fc 100644 --- a/lnwire/update_fee.go +++ b/lnwire/update_fee.go @@ -16,6 +16,11 @@ type UpdateFee struct { // TODO(halseth): make SatPerKWeight when fee estimation is moved to // own package. Currently this will cause an import cycle. FeePerKw uint32 + + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraData ExtraOpaqueData } // NewUpdateFee creates a new UpdateFee message. @@ -36,8 +41,10 @@ var _ Message = (*UpdateFee)(nil) // This is part of the lnwire.Message interface. func (c *UpdateFee) Decode(r io.Reader, pver uint32) error { return ReadElements(r, + pver, &c.ChanID, &c.FeePerKw, + &c.ExtraData, ) } @@ -47,8 +54,10 @@ func (c *UpdateFee) Decode(r io.Reader, pver uint32) error { // This is part of the lnwire.Message interface. func (c *UpdateFee) Encode(w io.Writer, pver uint32) error { return WriteElements(w, + pver, c.ChanID, c.FeePerKw, + c.ExtraData, ) } @@ -65,8 +74,7 @@ func (c *UpdateFee) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (c *UpdateFee) MaxPayloadLength(uint32) uint32 { - // 32 + 4 - return 36 + return MaxMsgBody } // TargetChanID returns the channel id of the link for which this message is diff --git a/lnwire/update_fulfill_htlc.go b/lnwire/update_fulfill_htlc.go index 6c0e6339ff6..150ddf3595f 100644 --- a/lnwire/update_fulfill_htlc.go +++ b/lnwire/update_fulfill_htlc.go @@ -21,6 +21,11 @@ type UpdateFulfillHTLC struct { // PaymentPreimage is the R-value preimage required to fully settle an // HTLC. PaymentPreimage [32]byte + + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraData ExtraOpaqueData } // NewUpdateFulfillHTLC returns a new empty UpdateFulfillHTLC. @@ -44,9 +49,11 @@ var _ Message = (*UpdateFulfillHTLC)(nil) // This is part of the lnwire.Message interface. func (c *UpdateFulfillHTLC) Decode(r io.Reader, pver uint32) error { return ReadElements(r, + pver, &c.ChanID, &c.ID, c.PaymentPreimage[:], + &c.ExtraData, ) } @@ -56,9 +63,11 @@ func (c *UpdateFulfillHTLC) Decode(r io.Reader, pver uint32) error { // This is part of the lnwire.Message interface. func (c *UpdateFulfillHTLC) Encode(w io.Writer, pver uint32) error { return WriteElements(w, + pver, c.ChanID, c.ID, c.PaymentPreimage[:], + c.ExtraData, ) } @@ -75,8 +84,7 @@ func (c *UpdateFulfillHTLC) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (c *UpdateFulfillHTLC) MaxPayloadLength(uint32) uint32 { - // 32 + 8 + 32 - return 72 + return MaxMsgBody } // TargetChanID returns the channel id of the link for which this message is diff --git a/make/testing_flags.mk b/make/testing_flags.mk index 1443ab5b225..f64d859c896 100644 --- a/make/testing_flags.mk +++ b/make/testing_flags.mk @@ -3,12 +3,18 @@ RPC_TAGS = autopilotrpc chainrpc invoicesrpc routerrpc signrpc verrpc walletrpc LOG_TAGS = TEST_FLAGS = COVER_PKG = $$(go list -deps ./... | grep '$(PKG)' | grep -v lnrpc) +NUM_ITEST_TRANCHES = 6 # If rpc option is set also add all extra RPC tags to DEV_TAGS ifneq ($(with-rpc),) DEV_TAGS += $(RPC_TAGS) endif +# Scale the number of parallel running itest tranches. +ifneq ($(tranches),) +NUM_ITEST_TRANCHES = $(tranches) +endif + # If specific package is being unit tested, construct the full name of the # subpackage. ifneq ($(pkg),) @@ -25,7 +31,7 @@ endif # Define the integration test.run filter if the icase argument was provided. ifneq ($(icase),) -TEST_FLAGS += -test.run=TestLightningNetworkDaemon/$(icase) +TEST_FLAGS += -test.run="TestLightningNetworkDaemon/.*-of-.*/.*/$(icase)" endif ifneq ($(tags),) diff --git a/peer/brontide.go b/peer/brontide.go index c50f4ebae0a..85f32ae0ee0 100644 --- a/peer/brontide.go +++ b/peer/brontide.go @@ -919,7 +919,7 @@ func (p *Brontide) readNextMessage() (lnwire.Message, error) { // Next, create a new io.Reader implementation from the raw message, // and use this to decode the message directly from. msgReader := bytes.NewReader(rawMsg) - nextMsg, err := lnwire.ReadMessage(msgReader, 0) + nextMsg, err := lnwire.ReadMessage(msgReader, lnwire.ProtocolVersionTLV) if err != nil { return nil, err } diff --git a/routing/ann_validation.go b/routing/ann_validation.go index cc8530bb165..76d58cb85a8 100644 --- a/routing/ann_validation.go +++ b/routing/ann_validation.go @@ -110,7 +110,10 @@ func ValidateNodeAnn(a *lnwire.NodeAnnouncement) error { dataHash := chainhash.DoubleHashB(data) if !nodeSig.Verify(dataHash, nodeKey) { var msgBuf bytes.Buffer - if _, err := lnwire.WriteMessage(&msgBuf, a, 0); err != nil { + _, err := lnwire.WriteMessage( + &msgBuf, a, lnwire.ProtocolVersionTLV, + ) + if err != nil { return err } diff --git a/scripts/itest_part.sh b/scripts/itest_part.sh new file mode 100755 index 00000000000..52c3481c8c7 --- /dev/null +++ b/scripts/itest_part.sh @@ -0,0 +1,23 @@ +#!/bin/bash + +# Let's work with absolute paths only, we run in the itest directory itself. +WORKDIR=$(pwd)/lntest/itest + +TRANCHE=$1 +NUM_TRANCHES=$2 + +# Shift the passed parameters by two, giving us all remaining testing flags in +# the $@ special variable. +shift +shift + +# Windows insists on having the .exe suffix for an executable, we need to add +# that here if necessary. +EXEC="$WORKDIR"/itest.test"$EXEC_SUFFIX" +LND_EXEC="$WORKDIR"/lnd-itest"$EXEC_SUFFIX" +echo $EXEC -test.v "$@" -logoutput -goroutinedump -logdir=.logs-tranche$TRANCHE -lndexec=$LND_EXEC -splittranches=$NUM_TRANCHES -runtranche=$TRANCHE + +# Exit code 255 causes the parallel jobs to abort, so if one part fails the +# other is aborted too. +cd "$WORKDIR" || exit 255 +$EXEC -test.v "$@" -logoutput -goroutinedump -logdir=.logs-tranche$TRANCHE -lndexec=$LND_EXEC -splittranches=$NUM_TRANCHES -runtranche=$TRANCHE || exit 255