From 916952416ca87c38d4dce6b5962892762b269838 Mon Sep 17 00:00:00 2001 From: Oliver Gugger Date: Wed, 4 Nov 2020 11:03:25 +0100 Subject: [PATCH 01/43] lntest: lower initial port, add ApplyPortOffset function To allow running multiple test tranches in parallel, we need a way to make sure the TCP ports don't collide. We'll work with offsets for the ports, using a different offset for each tranche. --- lntest/node.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/lntest/node.go b/lntest/node.go index cdf0be03bb6..8c72a1edb57 100644 --- a/lntest/node.go +++ b/lntest/node.go @@ -43,7 +43,7 @@ const ( // defaultNodePort is the start of the range for listening ports of // harness nodes. Ports are monotonically increasing starting from this // number and are determined by the results of nextAvailablePort(). - defaultNodePort = 19555 + defaultNodePort = 5555 // logPubKeyBytes is the number of bytes of the node's PubKey that will // be appended to the log file name. The whole PubKey is too long and @@ -104,6 +104,12 @@ func nextAvailablePort() int { panic("no ports available for listening") } +// ApplyPortOffset adds the given offset to the lastPort variable, making it +// possible to run the tests in parallel without colliding on the same ports. +func ApplyPortOffset(offset uint32) { + _ = atomic.AddUint32(&lastPort, offset) +} + // generateListeningPorts returns four ints representing ports to listen on // designated for the current lightning network test. This returns the next // available ports for the p2p, rpc, rest and profiling services. From b4c57eb7b2b5ea0fd4a0874fa9eb845e36588c83 Mon Sep 17 00:00:00 2001 From: Oliver Gugger Date: Wed, 4 Nov 2020 11:03:26 +0100 Subject: [PATCH 02/43] lntest: use nextAvailablePort for fee service --- lntest/fee_service.go | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/lntest/fee_service.go b/lntest/fee_service.go index 68e7d435a41..d71dae7d9f2 100644 --- a/lntest/fee_service.go +++ b/lntest/fee_service.go @@ -16,9 +16,6 @@ const ( // is returned. Requests for higher confirmation targets will fall back // to this. feeServiceTarget = 2 - - // feeServicePort is the tcp port on which the service runs. - feeServicePort = 16534 ) // feeService runs a web service that provides fee estimation information. @@ -40,16 +37,15 @@ type feeEstimates struct { // startFeeService spins up a go-routine to serve fee estimates. func startFeeService() *feeService { + port := nextAvailablePort() f := feeService{ - url: fmt.Sprintf( - "http://localhost:%v/fee-estimates.json", feeServicePort, - ), + url: fmt.Sprintf("http://localhost:%v/fee-estimates.json", port), } // Initialize default fee estimate. f.Fees = map[uint32]uint32{feeServiceTarget: 50000} - listenAddr := fmt.Sprintf(":%v", feeServicePort) + listenAddr := fmt.Sprintf(":%v", port) f.srv = &http.Server{ Addr: listenAddr, } From 05ac6aaca54e3915570c7c93599f9542402048b4 Mon Sep 17 00:00:00 2001 From: Oliver Gugger Date: Wed, 4 Nov 2020 11:03:27 +0100 Subject: [PATCH 03/43] lntest: use nextAvailablePort for bitcoind --- lntest/bitcoind_common.go | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/lntest/bitcoind_common.go b/lntest/bitcoind_common.go index b59fdac85d5..019ac268ee9 100644 --- a/lntest/bitcoind_common.go +++ b/lntest/bitcoind_common.go @@ -6,7 +6,6 @@ import ( "errors" "fmt" "io/ioutil" - "math/rand" "os" "os/exec" "path/filepath" @@ -93,10 +92,10 @@ func newBackend(miner string, netParams *chaincfg.Params, extraArgs []string) ( fmt.Errorf("unable to create temp directory: %v", err) } - zmqBlockPath := "ipc:///" + tempBitcoindDir + "/blocks.socket" - zmqTxPath := "ipc:///" + tempBitcoindDir + "/txs.socket" - rpcPort := rand.Int()%(65536-1024) + 1024 - p2pPort := rand.Int()%(65536-1024) + 1024 + zmqBlockAddr := fmt.Sprintf("tcp://127.0.0.1:%d", nextAvailablePort()) + zmqTxAddr := fmt.Sprintf("tcp://127.0.0.1:%d", nextAvailablePort()) + rpcPort := nextAvailablePort() + p2pPort := nextAvailablePort() cmdArgs := []string{ "-datadir=" + tempBitcoindDir, @@ -106,8 +105,8 @@ func newBackend(miner string, netParams *chaincfg.Params, extraArgs []string) ( "220110063096c221be9933c82d38e1", fmt.Sprintf("-rpcport=%d", rpcPort), fmt.Sprintf("-port=%d", p2pPort), - "-zmqpubrawblock=" + zmqBlockPath, - "-zmqpubrawtx=" + zmqTxPath, + "-zmqpubrawblock=" + zmqBlockAddr, + "-zmqpubrawtx=" + zmqTxAddr, "-debuglogfile=" + logFile, } cmdArgs = append(cmdArgs, extraArgs...) @@ -178,8 +177,8 @@ func newBackend(miner string, netParams *chaincfg.Params, extraArgs []string) ( rpcHost: rpcHost, rpcUser: rpcUser, rpcPass: rpcPass, - zmqBlockPath: zmqBlockPath, - zmqTxPath: zmqTxPath, + zmqBlockPath: zmqBlockAddr, + zmqTxPath: zmqTxAddr, p2pPort: p2pPort, rpcClient: client, minerAddr: miner, From 7326c1549fba2bad2527992d65894992b9e99491 Mon Sep 17 00:00:00 2001 From: Oliver Gugger Date: Wed, 4 Nov 2020 11:03:29 +0100 Subject: [PATCH 04/43] itest: split tests into dynamic tranches --- lntest/itest/lnd_test.go | 82 +++++++++++++++++++++++--- lntest/itest/lnd_test_list_off_test.go | 2 +- lntest/itest/lnd_test_list_on_test.go | 2 +- 3 files changed, 75 insertions(+), 11 deletions(-) diff --git a/lntest/itest/lnd_test.go b/lntest/itest/lnd_test.go index 5dc089cbc99..1b1533fbe3b 100644 --- a/lntest/itest/lnd_test.go +++ b/lntest/itest/lnd_test.go @@ -6,6 +6,7 @@ import ( "crypto/rand" "crypto/sha256" "encoding/hex" + "flag" "fmt" "io" "io/ioutil" @@ -53,6 +54,60 @@ import ( "github.com/stretchr/testify/require" ) +const ( + // defaultSplitTranches is the default number of tranches we split the + // test cases into. + defaultSplitTranches uint = 1 + + // defaultRunTranche is the default index of the test cases tranche that + // we run. + defaultRunTranche uint = 0 +) + +var ( + // testCasesSplitParts is the number of tranches the test cases should + // be split into. By default this is set to 1, so no splitting happens. + // If this value is increased, then the -runtranche flag must be + // specified as well to indicate which part should be run in the current + // invocation. + testCasesSplitTranches = flag.Uint( + "splittranches", defaultSplitTranches, "split the test cases "+ + "in this many tranches and run the tranche at "+ + "0-based index specified by the -runtranche flag", + ) + + // testCasesRunTranche is the 0-based index of the split test cases + // tranche to run in the current invocation. + testCasesRunTranche = flag.Uint( + "runtranche", defaultRunTranche, "run the tranche of the "+ + "split test cases with the given (0-based) index", + ) +) + +// getTestCaseSplitTranche returns the sub slice of the test cases that should +// be run as the current split tranche as well as the index and slice offset of +// the tranche. +func getTestCaseSplitTranche() ([]*testCase, uint, uint) { + numTranches := defaultSplitTranches + if testCasesSplitTranches != nil { + numTranches = *testCasesSplitTranches + } + runTranche := defaultRunTranche + if testCasesRunTranche != nil { + runTranche = *testCasesRunTranche + } + + numCases := uint(len(allTestCases)) + testsPerTranche := numCases / numTranches + trancheOffset := runTranche * testsPerTranche + trancheEnd := trancheOffset + testsPerTranche + if trancheEnd > numCases || runTranche == numTranches-1 { + trancheEnd = numCases + } + + return allTestCases[trancheOffset:trancheEnd], runTranche, trancheOffset +} + func rpcPointToWirePoint(t *harnessTest, chanPoint *lnrpc.ChannelPoint) wire.OutPoint { txid, err := lnd.GetChanPointFundingTxid(chanPoint) if err != nil { @@ -14098,10 +14153,14 @@ func getPaymentResult(stream routerrpc.Router_SendPaymentV2Client) ( // programmatically driven network of lnd nodes. func TestLightningNetworkDaemon(t *testing.T) { // If no tests are registered, then we can exit early. - if len(testsCases) == 0 { + if len(allTestCases) == 0 { t.Skip("integration tests not selected with flag 'rpctest'") } + // Parse testing flags that influence our test execution. + testCases, trancheIndex, trancheOffset := getTestCaseSplitTranche() + lntest.ApplyPortOffset(uint32(trancheIndex) * 1000) + ht := newHarnessTest(t, nil) // Declare the network harness here to gain access to its @@ -14149,8 +14208,7 @@ func TestLightningNetworkDaemon(t *testing.T) { // Connect chainbackend to miner. require.NoError( - t, chainBackend.ConnectMiner(), - "failed to connect to miner", + t, chainBackend.ConnectMiner(), "failed to connect to miner", ) binary := itestLndBinary @@ -14187,7 +14245,8 @@ func TestLightningNetworkDaemon(t *testing.T) { if !more { return } - ht.Logf("lnd finished with error (stderr):\n%v", err) + ht.Logf("lnd finished with error (stderr):\n%v", + err) } } }() @@ -14210,8 +14269,9 @@ func TestLightningNetworkDaemon(t *testing.T) { ht.Fatalf("unable to set up test lightning network: %v", err) } - t.Logf("Running %v integration tests", len(testsCases)) - for _, testCase := range testsCases { + // Run the subset of the test cases selected in this tranche. + for idx, testCase := range testCases { + testCase := testCase logLine := fmt.Sprintf("STARTING ============ %v ============\n", testCase.name) @@ -14232,7 +14292,10 @@ func TestLightningNetworkDaemon(t *testing.T) { // Start every test with the default static fee estimate. lndHarness.SetFeeEstimate(12500) - success := t.Run(testCase.name, func(t1 *testing.T) { + name := fmt.Sprintf("%02d-of-%d/%s/%s", + trancheOffset+uint(idx)+1, len(allTestCases), + chainBackend.Name(), testCase.name) + success := t.Run(name, func(t1 *testing.T) { ht := newHarnessTest(t1, lndHarness) ht.RunTestCase(testCase) }) @@ -14242,8 +14305,9 @@ func TestLightningNetworkDaemon(t *testing.T) { if !success { // Log failure time to help relate the lnd logs to the // failure. - t.Logf("Failure time: %v", - time.Now().Format("2006-01-02 15:04:05.000")) + t.Logf("Failure time: %v", time.Now().Format( + "2006-01-02 15:04:05.000", + )) break } } diff --git a/lntest/itest/lnd_test_list_off_test.go b/lntest/itest/lnd_test_list_off_test.go index ae18d5e0ca3..59795f1d1bb 100644 --- a/lntest/itest/lnd_test_list_off_test.go +++ b/lntest/itest/lnd_test_list_off_test.go @@ -2,4 +2,4 @@ package itest -var testsCases = []*testCase{} +var allTestCases = []*testCase{} diff --git a/lntest/itest/lnd_test_list_on_test.go b/lntest/itest/lnd_test_list_on_test.go index 98910d22b92..420331c7e5b 100644 --- a/lntest/itest/lnd_test_list_on_test.go +++ b/lntest/itest/lnd_test_list_on_test.go @@ -2,7 +2,7 @@ package itest -var testsCases = []*testCase{ +var allTestCases = []*testCase{ { name: "sweep coins", test: testSweepAllCoins, From a0483733a417734c1e84d5a6297761377d11d60f Mon Sep 17 00:00:00 2001 From: Oliver Gugger Date: Wed, 4 Nov 2020 11:03:30 +0100 Subject: [PATCH 05/43] lntest: add log dir flag --- lntest/bitcoind_common.go | 18 +++++++++++------- lntest/btcd.go | 18 +++++++++++------- lntest/itest/lnd_test.go | 6 ++++-- lntest/node.go | 37 ++++++++++++++++++++++++++----------- 4 files changed, 52 insertions(+), 27 deletions(-) diff --git a/lntest/bitcoind_common.go b/lntest/bitcoind_common.go index 019ac268ee9..f673400abc5 100644 --- a/lntest/bitcoind_common.go +++ b/lntest/bitcoind_common.go @@ -15,8 +15,8 @@ import ( "github.com/btcsuite/btcd/rpcclient" ) -// logDir is the name of the temporary log directory. -const logDir = "./.backendlogs" +// logDirPattern is the pattern of the name of the temporary log directory. +const logDirPattern = "%s/.backendlogs" // BitcoindBackendConfig is an implementation of the BackendConfig interface // backed by a Bitcoind node. @@ -73,15 +73,16 @@ func (b BitcoindBackendConfig) Name() string { func newBackend(miner string, netParams *chaincfg.Params, extraArgs []string) ( *BitcoindBackendConfig, func() error, error) { + baseLogDir := fmt.Sprintf(logDirPattern, GetLogDir()) if netParams != &chaincfg.RegressionNetParams { return nil, nil, fmt.Errorf("only regtest supported") } - if err := os.MkdirAll(logDir, 0700); err != nil { + if err := os.MkdirAll(baseLogDir, 0700); err != nil { return nil, nil, err } - logFile, err := filepath.Abs(logDir + "/bitcoind.log") + logFile, err := filepath.Abs(baseLogDir + "/bitcoind.log") if err != nil { return nil, nil, err } @@ -128,13 +129,16 @@ func newBackend(miner string, netParams *chaincfg.Params, extraArgs []string) ( var errStr string // After shutting down the chain backend, we'll make a copy of // the log file before deleting the temporary log dir. - err := CopyFile("./output_bitcoind_chainbackend.log", logFile) + logDestination := fmt.Sprintf( + "%s/output_bitcoind_chainbackend.log", GetLogDir(), + ) + err := CopyFile(logDestination, logFile) if err != nil { errStr += fmt.Sprintf("unable to copy file: %v\n", err) } - if err = os.RemoveAll(logDir); err != nil { + if err = os.RemoveAll(baseLogDir); err != nil { errStr += fmt.Sprintf( - "cannot remove dir %s: %v\n", logDir, err, + "cannot remove dir %s: %v\n", baseLogDir, err, ) } if err := os.RemoveAll(tempBitcoindDir); err != nil { diff --git a/lntest/btcd.go b/lntest/btcd.go index 11e19d3fd9f..e8b8cac43cd 100644 --- a/lntest/btcd.go +++ b/lntest/btcd.go @@ -14,8 +14,8 @@ import ( "github.com/btcsuite/btcd/rpcclient" ) -// logDir is the name of the temporary log directory. -const logDir = "./.backendlogs" +// logDirPattern is the pattern of the name of the temporary log directory. +const logDirPattern = "%s/.backendlogs" // temp is used to signal we want to establish a temporary connection using the // btcd Node API. @@ -75,12 +75,13 @@ func (b BtcdBackendConfig) Name() string { func NewBackend(miner string, netParams *chaincfg.Params) ( *BtcdBackendConfig, func() error, error) { + baseLogDir := fmt.Sprintf(logDirPattern, GetLogDir()) args := []string{ "--rejectnonstd", "--txindex", "--trickleinterval=100ms", "--debuglevel=debug", - "--logdir=" + logDir, + "--logdir=" + baseLogDir, "--nowinservice", // The miner will get banned and disconnected from the node if // its requested data are not found. We add a nobanning flag to @@ -110,14 +111,17 @@ func NewBackend(miner string, netParams *chaincfg.Params) ( // After shutting down the chain backend, we'll make a copy of // the log file before deleting the temporary log dir. - logFile := logDir + "/" + netParams.Name + "/btcd.log" - err := CopyFile("./output_btcd_chainbackend.log", logFile) + logFile := baseLogDir + "/" + netParams.Name + "/btcd.log" + logDestination := fmt.Sprintf( + "%s/output_btcd_chainbackend.log", GetLogDir(), + ) + err := CopyFile(logDestination, logFile) if err != nil { errStr += fmt.Sprintf("unable to copy file: %v\n", err) } - if err = os.RemoveAll(logDir); err != nil { + if err = os.RemoveAll(baseLogDir); err != nil { errStr += fmt.Sprintf( - "cannot remove dir %s: %v\n", logDir, err, + "cannot remove dir %s: %v\n", baseLogDir, err, ) } if errStr != "" { diff --git a/lntest/itest/lnd_test.go b/lntest/itest/lnd_test.go index 1b1533fbe3b..c2ceac58c8b 100644 --- a/lntest/itest/lnd_test.go +++ b/lntest/itest/lnd_test.go @@ -2435,7 +2435,7 @@ func testOpenChannelAfterReorg(net *lntest.NetworkHarness, t *harnessTest) { ) // Set up a new miner that we can use to cause a reorg. - tempLogDir := "./.tempminerlogs" + tempLogDir := fmt.Sprintf("%s/.tempminerlogs", lntest.GetLogDir()) logFilename := "output-open_channel_reorg-temp_miner.log" tempMiner, tempMinerCleanUp, err := lntest.NewMiner( tempLogDir, logFilename, @@ -14158,6 +14158,8 @@ func TestLightningNetworkDaemon(t *testing.T) { } // Parse testing flags that influence our test execution. + logDir := lntest.GetLogDir() + require.NoError(t, os.MkdirAll(logDir, 0700)) testCases, trancheIndex, trancheOffset := getTestCaseSplitTranche() lntest.ApplyPortOffset(uint32(trancheIndex) * 1000) @@ -14176,7 +14178,7 @@ func TestLightningNetworkDaemon(t *testing.T) { // guarantees of getting included in to blocks. // // We will also connect it to our chain backend. - minerLogDir := "./.minerlogs" + minerLogDir := fmt.Sprintf("%s/.minerlogs", logDir) miner, minerCleanUp, err := lntest.NewMiner( minerLogDir, "output_btcd_miner.log", harnessNetParams, &rpcclient.NotificationHandlers{}, diff --git a/lntest/node.go b/lntest/node.go index 8c72a1edb57..cb368e75046 100644 --- a/lntest/node.go +++ b/lntest/node.go @@ -70,6 +70,10 @@ var ( logOutput = flag.Bool("logoutput", false, "log output from node n to file output-n.log") + // logSubDir is the default directory where the logs are written to if + // logOutput is true. + logSubDir = flag.String("logdir", ".", "default dir to write logs to") + // goroutineDump is a flag that can be set to dump the active // goroutines of test nodes on failure. goroutineDump = flag.Bool("goroutinedump", false, @@ -110,6 +114,15 @@ func ApplyPortOffset(offset uint32) { _ = atomic.AddUint32(&lastPort, offset) } +// GetLogDir returns the passed --logdir flag or the default value if it wasn't +// set. +func GetLogDir() string { + if logSubDir != nil && *logSubDir != "" { + return *logSubDir + } + return "." +} + // generateListeningPorts returns four ints representing ports to listen on // designated for the current lightning network test. This returns the next // available ports for the p2p, rpc, rest and profiling services. @@ -392,11 +405,9 @@ func NewMiner(logDir, logFilename string, netParams *chaincfg.Params, // After shutting down the miner, we'll make a copy of the log // file before deleting the temporary log dir. - logFile := fmt.Sprintf( - "%s/%s/btcd.log", logDir, netParams.Name, - ) - copyPath := fmt.Sprintf("./%s", logFilename) - err := CopyFile(copyPath, logFile) + logFile := fmt.Sprintf("%s/%s/btcd.log", logDir, netParams.Name) + copyPath := fmt.Sprintf("%s/../%s", logDir, logFilename) + err := CopyFile(filepath.Clean(copyPath), logFile) if err != nil { return fmt.Errorf("unable to copy file: %v", err) } @@ -481,24 +492,28 @@ func (hn *HarnessNode) start(lndBinary string, lndError chan<- error) error { // If the logoutput flag is passed, redirect output from the nodes to // log files. if *logOutput { - fileName := fmt.Sprintf("output-%d-%s-%s.log", hn.NodeID, + dir := GetLogDir() + fileName := fmt.Sprintf("%s/output-%d-%s-%s.log", dir, hn.NodeID, hn.Cfg.Name, hex.EncodeToString(hn.PubKey[:logPubKeyBytes])) // If the node's PubKey is not yet initialized, create a temporary // file name. Later, after the PubKey has been initialized, the // file can be moved to its final name with the PubKey included. if bytes.Equal(hn.PubKey[:4], []byte{0, 0, 0, 0}) { - fileName = fmt.Sprintf("output-%d-%s-tmp__.log", hn.NodeID, - hn.Cfg.Name) + fileName = fmt.Sprintf("%s/output-%d-%s-tmp__.log", + dir, hn.NodeID, hn.Cfg.Name) // Once the node has done its work, the log file can be renamed. finalizeLogfile = func() { if hn.logFile != nil { hn.logFile.Close() - newFileName := fmt.Sprintf("output-%d-%s-%s.log", - hn.NodeID, hn.Cfg.Name, - hex.EncodeToString(hn.PubKey[:logPubKeyBytes])) + pubKeyHex := hex.EncodeToString( + hn.PubKey[:logPubKeyBytes], + ) + newFileName := fmt.Sprintf("%s/output"+ + "-%d-%s-%s.log", dir, hn.NodeID, + hn.Cfg.Name, pubKeyHex) err := os.Rename(fileName, newFileName) if err != nil { fmt.Printf("could not rename "+ From 6a1c02665608f9bade3ff0d766e2efce26ae0e5c Mon Sep 17 00:00:00 2001 From: Oliver Gugger Date: Wed, 4 Nov 2020 11:03:31 +0100 Subject: [PATCH 06/43] itest: add flags for lnd executable --- lntest/itest/lnd_test.go | 17 +---------------- lntest/itest/test_harness.go | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 16 deletions(-) diff --git a/lntest/itest/lnd_test.go b/lntest/itest/lnd_test.go index c2ceac58c8b..e49c7090c49 100644 --- a/lntest/itest/lnd_test.go +++ b/lntest/itest/lnd_test.go @@ -14,7 +14,6 @@ import ( "os" "path/filepath" "reflect" - "runtime" "strings" "sync" "sync/atomic" @@ -14213,23 +14212,9 @@ func TestLightningNetworkDaemon(t *testing.T) { t, chainBackend.ConnectMiner(), "failed to connect to miner", ) - binary := itestLndBinary - if runtime.GOOS == "windows" { - // Windows (even in a bash like environment like git bash as on - // Travis) doesn't seem to like relative paths to exe files... - currentDir, err := os.Getwd() - if err != nil { - ht.Fatalf("unable to get working directory: %v", err) - } - targetPath := filepath.Join(currentDir, "../../lnd-itest.exe") - binary, err = filepath.Abs(targetPath) - if err != nil { - ht.Fatalf("unable to get absolute path: %v", err) - } - } - // Now we can set up our test harness (LND instance), with the chain // backend we just created. + binary := ht.getLndBinary() lndHarness, err = lntest.NewNetworkHarness(miner, chainBackend, binary) if err != nil { ht.Fatalf("unable to create lightning network harness: %v", err) diff --git a/lntest/itest/test_harness.go b/lntest/itest/test_harness.go index a3c4752893a..45248f60256 100644 --- a/lntest/itest/test_harness.go +++ b/lntest/itest/test_harness.go @@ -3,8 +3,12 @@ package itest import ( "bytes" "context" + "flag" "fmt" "math" + "os" + "path/filepath" + "runtime" "testing" "time" @@ -20,6 +24,11 @@ import ( var ( harnessNetParams = &chaincfg.RegressionNetParams + + // lndExecutable is the full path to the lnd binary. + lndExecutable = flag.String( + "lndexec", itestLndBinary, "full path to lnd binary", + ) ) const ( @@ -111,6 +120,31 @@ func (h *harnessTest) Log(args ...interface{}) { h.t.Log(args...) } +func (h *harnessTest) getLndBinary() string { + binary := itestLndBinary + lndExec := "" + if lndExecutable != nil && *lndExecutable != "" { + lndExec = *lndExecutable + } + if lndExec == "" && runtime.GOOS == "windows" { + // Windows (even in a bash like environment like git bash as on + // Travis) doesn't seem to like relative paths to exe files... + currentDir, err := os.Getwd() + if err != nil { + h.Fatalf("unable to get working directory: %v", err) + } + targetPath := filepath.Join(currentDir, "../../lnd-itest.exe") + binary, err = filepath.Abs(targetPath) + if err != nil { + h.Fatalf("unable to get absolute path: %v", err) + } + } else if lndExec != "" { + binary = lndExec + } + + return binary +} + type testCase struct { name string test func(net *lntest.NetworkHarness, t *harnessTest) From 42f7ea51c504e231e1eea1d54f267980ac38eb77 Mon Sep 17 00:00:00 2001 From: Oliver Gugger Date: Wed, 4 Nov 2020 11:03:33 +0100 Subject: [PATCH 07/43] travis+make: execute test groups in parallel --- .gitignore | 2 ++ .travis.yml | 21 ++++++++++----------- Makefile | 21 +++++++++++++++++++++ make/testing_flags.mk | 8 +++++++- scripts/itest_part.sh | 23 +++++++++++++++++++++++ 5 files changed, 63 insertions(+), 12 deletions(-) create mode 100755 scripts/itest_part.sh diff --git a/.gitignore b/.gitignore index f23749bd118..371b57b6815 100644 --- a/.gitignore +++ b/.gitignore @@ -36,6 +36,8 @@ lntest/itest/output*.log lntest/itest/pprof*.log lntest/itest/.backendlogs lntest/itest/.minerlogs +lntest/itest/lnd-itest +lntest/itest/.logs-* cmd/cmd *.key diff --git a/.travis.yml b/.travis.yml index feed17cc11e..51cf9402c17 100644 --- a/.travis.yml +++ b/.travis.yml @@ -50,32 +50,30 @@ jobs: - stage: Integration Test name: Btcd Integration script: - - make itest + - make itest-parallel - name: Bitcoind Integration (txindex enabled) script: - bash ./scripts/install_bitcoind.sh - - make itest backend=bitcoind + - make itest-parallel backend=bitcoind - name: Bitcoind Integration (txindex disabled) script: - bash ./scripts/install_bitcoind.sh - - make itest backend="bitcoind notxindex" + - make itest-parallel backend="bitcoind notxindex" - name: Neutrino Integration script: - - make itest backend=neutrino + - make itest-parallel backend=neutrino - name: Btcd Integration ARM script: - - GOARM=7 GOARCH=arm GOOS=linux CGO_ENABLED=0 make btcd build-itest - - file lnd-itest - - GOARM=7 GOARCH=arm GOOS=linux CGO_ENABLED=0 make itest-only + - GOARM=7 GOARCH=arm GOOS=linux make itest-parallel arch: arm64 - name: Btcd Integration Windows script: - - make itest-windows + - make itest-parallel-windows os: windows before_install: - choco upgrade --no-progress -y make netcat curl findutils @@ -85,7 +83,8 @@ jobs: case $TRAVIS_OS_NAME in windows) echo "Uploading to termbin.com..." - for f in ./lntest/itest/*.log; do cat $f | nc termbin.com 9999 | xargs -r0 printf "$f"' uploaded to %s'; done + LOG_FILES=$(find ./lntest/itest -name '*.log') + for f in $LOG_FILES; do echo -n $f; cat $f | nc termbin.com 9999 | xargs -r0 printf ' uploaded to %s'; done ;; esac @@ -97,8 +96,8 @@ after_failure: ;; *) - LOG_FILES=./lntest/itest/*.log - echo "Uploading to termbin.com..." && find $LOG_FILES | xargs -I{} sh -c "cat {} | nc termbin.com 9999 | xargs -r0 printf '{} uploaded to %s'" + LOG_FILES=$(find ./lntest/itest -name '*.log') + echo "Uploading to termbin.com..." && for f in $LOG_FILES; do echo -n $f; cat $f | nc termbin.com 9999 | xargs -r0 printf ' uploaded to %s'; done echo "Uploading to file.io..." && tar -zcvO $LOG_FILES | curl -s -F 'file=@-;filename=logs.tar.gz' https://file.io | xargs -r0 printf 'logs.tar.gz uploaded to %s\n' ;; esac diff --git a/Makefile b/Makefile index 5f55f0bc75d..ec0e322751a 100644 --- a/Makefile +++ b/Makefile @@ -175,6 +175,27 @@ itest-only: itest: btcd build-itest itest-only +itest-parallel: btcd + @$(call print, "Building lnd binary") + CGO_ENABLED=0 $(GOBUILD) -tags="$(ITEST_TAGS)" -o lntest/itest/lnd-itest $(ITEST_LDFLAGS) $(PKG)/cmd/lnd + + @$(call print, "Building itest binary for $(backend) backend") + CGO_ENABLED=0 $(GOTEST) -v ./lntest/itest -tags="$(DEV_TAGS) $(RPC_TAGS) rpctest $(backend)" -logoutput -goroutinedump -c -o lntest/itest/itest.test + + @$(call print, "Running tests") + rm -rf lntest/itest/*.log lntest/itest/.logs-* + echo -n "$$(seq 0 $$(expr $(NUM_ITEST_TRANCHES) - 1))" | xargs -P $(NUM_ITEST_TRANCHES) -n 1 -I {} scripts/itest_part.sh {} $(NUM_ITEST_TRANCHES) $(TEST_FLAGS) + +itest-parallel-windows: btcd + @$(call print, "Building lnd binary") + CGO_ENABLED=0 $(GOBUILD) -tags="$(ITEST_TAGS)" -o lntest/itest/lnd-itest.exe $(ITEST_LDFLAGS) $(PKG)/cmd/lnd + + @$(call print, "Building itest binary for $(backend) backend") + CGO_ENABLED=0 $(GOTEST) -v ./lntest/itest -tags="$(DEV_TAGS) $(RPC_TAGS) rpctest $(backend)" -logoutput -goroutinedump -c -o lntest/itest/itest.test.exe + + @$(call print, "Running tests") + EXEC_SUFFIX=".exe" echo -n "$$(seq 0 $$(expr $(NUM_ITEST_TRANCHES) - 1))" | xargs -P $(NUM_ITEST_TRANCHES) -n 1 -I {} scripts/itest_part.sh {} $(NUM_ITEST_TRANCHES) $(TEST_FLAGS) + itest-windows: btcd build-itest-windows itest-only unit: btcd diff --git a/make/testing_flags.mk b/make/testing_flags.mk index 1443ab5b225..f64d859c896 100644 --- a/make/testing_flags.mk +++ b/make/testing_flags.mk @@ -3,12 +3,18 @@ RPC_TAGS = autopilotrpc chainrpc invoicesrpc routerrpc signrpc verrpc walletrpc LOG_TAGS = TEST_FLAGS = COVER_PKG = $$(go list -deps ./... | grep '$(PKG)' | grep -v lnrpc) +NUM_ITEST_TRANCHES = 6 # If rpc option is set also add all extra RPC tags to DEV_TAGS ifneq ($(with-rpc),) DEV_TAGS += $(RPC_TAGS) endif +# Scale the number of parallel running itest tranches. +ifneq ($(tranches),) +NUM_ITEST_TRANCHES = $(tranches) +endif + # If specific package is being unit tested, construct the full name of the # subpackage. ifneq ($(pkg),) @@ -25,7 +31,7 @@ endif # Define the integration test.run filter if the icase argument was provided. ifneq ($(icase),) -TEST_FLAGS += -test.run=TestLightningNetworkDaemon/$(icase) +TEST_FLAGS += -test.run="TestLightningNetworkDaemon/.*-of-.*/.*/$(icase)" endif ifneq ($(tags),) diff --git a/scripts/itest_part.sh b/scripts/itest_part.sh new file mode 100755 index 00000000000..52c3481c8c7 --- /dev/null +++ b/scripts/itest_part.sh @@ -0,0 +1,23 @@ +#!/bin/bash + +# Let's work with absolute paths only, we run in the itest directory itself. +WORKDIR=$(pwd)/lntest/itest + +TRANCHE=$1 +NUM_TRANCHES=$2 + +# Shift the passed parameters by two, giving us all remaining testing flags in +# the $@ special variable. +shift +shift + +# Windows insists on having the .exe suffix for an executable, we need to add +# that here if necessary. +EXEC="$WORKDIR"/itest.test"$EXEC_SUFFIX" +LND_EXEC="$WORKDIR"/lnd-itest"$EXEC_SUFFIX" +echo $EXEC -test.v "$@" -logoutput -goroutinedump -logdir=.logs-tranche$TRANCHE -lndexec=$LND_EXEC -splittranches=$NUM_TRANCHES -runtranche=$TRANCHE + +# Exit code 255 causes the parallel jobs to abort, so if one part fails the +# other is aborted too. +cd "$WORKDIR" || exit 255 +$EXEC -test.v "$@" -logoutput -goroutinedump -logdir=.logs-tranche$TRANCHE -lndexec=$LND_EXEC -splittranches=$NUM_TRANCHES -runtranche=$TRANCHE || exit 255 From 72ea3e323204e5241d3f71f75f1b17f671dddc73 Mon Sep 17 00:00:00 2001 From: Oliver Gugger Date: Wed, 4 Nov 2020 11:03:34 +0100 Subject: [PATCH 08/43] itest: fix chanbackup restore flake Updating the fee of the mock estimator _after_ starting carol turned out to be flaky and could lead to the new fee not being picked up in time for the force close. That lead to carol not cpfp'ing the force closed transaction. --- lntest/itest/lnd_channel_backup_test.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/lntest/itest/lnd_channel_backup_test.go b/lntest/itest/lnd_channel_backup_test.go index 5a8bf87dcf7..2a7e0eb6c28 100644 --- a/lntest/itest/lnd_channel_backup_test.go +++ b/lntest/itest/lnd_channel_backup_test.go @@ -1012,6 +1012,10 @@ func testChanRestoreScenario(t *harnessTest, net *lntest.NetworkHarness, require.Contains(t.t, err.Error(), "cannot close channel with state: ") require.Contains(t.t, err.Error(), "ChanStatusRestored") + // Increase the fee estimate so that the following force close tx will + // be cpfp'ed in case of anchor commitments. + net.SetFeeEstimate(30000) + // Now that we have ensured that the channels restored by the backup are // in the correct state even without the remote peer telling us so, // let's start up Carol again. From 2e62b5c0fc62030f7b342b0321ed3f4946d94a27 Mon Sep 17 00:00:00 2001 From: Oliver Gugger Date: Wed, 4 Nov 2020 11:03:36 +0100 Subject: [PATCH 09/43] itest: fix typo and formatting --- .../lnd_multi-hop_htlc_remote_chain_claim_test.go | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/lntest/itest/lnd_multi-hop_htlc_remote_chain_claim_test.go b/lntest/itest/lnd_multi-hop_htlc_remote_chain_claim_test.go index 1aeff01e2a9..72a3c63cfff 100644 --- a/lntest/itest/lnd_multi-hop_htlc_remote_chain_claim_test.go +++ b/lntest/itest/lnd_multi-hop_htlc_remote_chain_claim_test.go @@ -101,15 +101,17 @@ func testMultiHopHtlcRemoteChainClaim(net *lntest.NetworkHarness, t *harnessTest // bob will attempt to redeem his anchor commitment (if the channel // type is of that type). if c == commitTypeAnchors { - _, err = waitForNTxsInMempool(net.Miner.Node, 1, minerMempoolTimeout) + _, err = waitForNTxsInMempool( + net.Miner.Node, 1, minerMempoolTimeout, + ) if err != nil { - t.Fatalf("unable to find bob's anchor commit sweep: %v", err) - + t.Fatalf("unable to find bob's anchor commit sweep: %v", + err) } } // Mine enough blocks for Alice to sweep her funds from the force - // closed channel. closeCHannelAndAssertType() already mined a block + // closed channel. closeChannelAndAssertType() already mined a block // containing the commitment tx and the commit sweep tx will be // broadcast immediately before it can be included in a block, so mine // one less than defaultCSV in order to perform mempool assertions. From 463373138be9b0e68a8c5b67b1b95a95c4404ea6 Mon Sep 17 00:00:00 2001 From: Oliver Gugger Date: Wed, 4 Nov 2020 11:03:37 +0100 Subject: [PATCH 10/43] itest: move longest test to beginning To make sure the test that takes the longest overall time is always started first, independent of the number of test tranches we run, we move it to the beginning of the list. Because that test involves a lot of waiting, it allows us to play around with the number of tranches more efficiently. --- lntest/itest/lnd_test_list_on_test.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lntest/itest/lnd_test_list_on_test.go b/lntest/itest/lnd_test_list_on_test.go index 420331c7e5b..3575213cd43 100644 --- a/lntest/itest/lnd_test_list_on_test.go +++ b/lntest/itest/lnd_test_list_on_test.go @@ -3,6 +3,10 @@ package itest var allTestCases = []*testCase{ + { + name: "test multi-hop htlc", + test: testMultiHopHtlcClaims, + }, { name: "sweep coins", test: testSweepAllCoins, @@ -144,10 +148,6 @@ var allTestCases = []*testCase{ name: "async bidirectional payments", test: testBidirectionalAsyncPayments, }, - { - name: "test multi-hop htlc", - test: testMultiHopHtlcClaims, - }, { name: "switch circuit persistence", test: testSwitchCircuitPersistence, From 222e3c933d2594947d7603611d86349d2c2cd710 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Mon, 27 Jan 2020 17:25:01 -0800 Subject: [PATCH 11/43] lnwire: create new ExtraOpaqueData type for parsing TLV extensions In this commit, we create a new `ExtraOpaqueData` based on the field with the same name that's present in all the announcement related messages. In later commits, we'll embed this new type in each message, so we'll have a generic way to add/parse TLV extensions from messages. --- lnwire/extra_bytes.go | 84 +++++++++++++++++++++ lnwire/extra_bytes_test.go | 147 +++++++++++++++++++++++++++++++++++++ lnwire/lnwire.go | 21 +++++- 3 files changed, 249 insertions(+), 3 deletions(-) create mode 100644 lnwire/extra_bytes.go create mode 100644 lnwire/extra_bytes_test.go diff --git a/lnwire/extra_bytes.go b/lnwire/extra_bytes.go new file mode 100644 index 00000000000..a94a948ee2c --- /dev/null +++ b/lnwire/extra_bytes.go @@ -0,0 +1,84 @@ +package lnwire + +import ( + "bytes" + "io" + "io/ioutil" + + "github.com/lightningnetwork/lnd/tlv" +) + +// ExtraOpaqueData is the set of data that was appended to this message, some +// of which we may not actually know how to iterate or parse. By holding onto +// this data, we ensure that we're able to properly validate the set of +// signatures that cover these new fields, and ensure we're able to make +// upgrades to the network in a forwards compatible manner. +type ExtraOpaqueData []byte + +// Encode attempts to encode the raw extra bytes into the passed io.Writer. +func (e *ExtraOpaqueData) Encode(w io.Writer) error { + eBytes := []byte((*e)[:]) + if err := WriteElements(w, eBytes); err != nil { + return err + } + + return nil +} + +// Decode attempts to unpack the raw bytes encoded in the passed io.Reader as a +// set of extra opaque data. +func (e *ExtraOpaqueData) Decode(r io.Reader) error { + // First, we'll attempt to read a set of bytes contained within the + // passed io.Reader (if any exist). + rawBytes, err := ioutil.ReadAll(r) + if err != nil { + return err + } + + // If we _do_ have some bytes, then we'll swap out our backing pointer. + // This ensures that any struct that embeds this type will properly + // store the bytes once this method exits. + if len(rawBytes) > 0 { + *e = ExtraOpaqueData(rawBytes) + } else { + *e = make([]byte, 0) + } + + return nil +} + +// PackRecords attempts to encode the set of tlv records into the target +// ExtraOpaqueData instance. The records will be encoded as a raw TLV stream +// and stored within the backing slice pointer. +func (e *ExtraOpaqueData) PackRecords(records []tlv.Record) error { + tlvStream, err := tlv.NewStream(records...) + if err != nil { + return err + } + + var extraBytesWriter bytes.Buffer + if err := tlvStream.Encode(&extraBytesWriter); err != nil { + return err + } + + *e = ExtraOpaqueData(extraBytesWriter.Bytes()) + + return nil +} + +// ExtractRecords attempts to decode any types in the internal raw bytes as if +// it were a tlv stream. The set of raw parsed types is returned, and any +// passed records (if found in the stream) will be parsed into the proper +// tlv.Record. +func (e *ExtraOpaqueData) ExtractRecords(records ...tlv.Record) ( + tlv.TypeMap, error) { + + extraBytesReader := bytes.NewReader(*e) + + tlvStream, err := tlv.NewStream(records...) + if err != nil { + return nil, err + } + + return tlvStream.DecodeWithParsedTypes(extraBytesReader) +} diff --git a/lnwire/extra_bytes_test.go b/lnwire/extra_bytes_test.go new file mode 100644 index 00000000000..55acfee61fe --- /dev/null +++ b/lnwire/extra_bytes_test.go @@ -0,0 +1,147 @@ +package lnwire + +import ( + "bytes" + "math/rand" + "reflect" + "testing" + "testing/quick" + + "github.com/lightningnetwork/lnd/tlv" +) + +// TestExtraOpaqueDataEncodeDecode tests that we're able to encode/decode +// arbitrary payloads. +func TestExtraOpaqueDataEncodeDecode(t *testing.T) { + t.Parallel() + + type testCase struct { + // emptyBytes indicates if we should try to encode empty bytes + // or not. + emptyBytes bool + + // inputBytes if emptyBytes is false, then we'll read in this + // set of bytes instead. + inputBytes []byte + } + + // We should be able to read in an arbitrary set of bytes as an + // ExtraOpaqueData, then encode those new bytes into a new instance. + // The final two instances should be identical. + scenario := func(test testCase) bool { + var ( + extraData ExtraOpaqueData + b bytes.Buffer + ) + + copy(extraData[:], test.inputBytes) + + if err := extraData.Encode(&b); err != nil { + t.Fatalf("unable to encode extra data: %v", err) + return false + } + + var newBytes ExtraOpaqueData + if err := newBytes.Decode(&b); err != nil { + t.Fatalf("unable to decode extra bytes: %v", err) + return false + } + + if !bytes.Equal(extraData[:], newBytes[:]) { + t.Fatalf("expected %x, got %x", extraData, + newBytes) + return false + } + + return true + } + + // We'll make a function to generate random test data. Half of the + // time, we'll actually feed in blank bytes. + quickCfg := &quick.Config{ + Values: func(v []reflect.Value, r *rand.Rand) { + + var newTestCase testCase + if r.Int31()%2 == 0 { + newTestCase.emptyBytes = true + } + + if !newTestCase.emptyBytes { + numBytes := r.Int31n(1000) + newTestCase.inputBytes = make([]byte, numBytes) + + _, err := r.Read(newTestCase.inputBytes) + if err != nil { + t.Fatalf("unable to gen random bytes: %v", err) + return + } + } + + v[0] = reflect.ValueOf(newTestCase) + }, + } + + if err := quick.Check(scenario, quickCfg); err != nil { + t.Fatalf("encode+decode test failed: %v", err) + } +} + +// TestExtraOpaqueDataPackUnpackRecords tests that we're able to pack a set of +// tlv.Records into a stream, and unpack them on the other side to obtain the +// same set of records. +func TestExtraOpaqueDataPackUnpackRecords(t *testing.T) { + t.Parallel() + + var ( + type1 tlv.Type = 1 + type2 tlv.Type = 2 + + channelType1 uint8 = 2 + channelType2 uint8 + + hop1 uint32 = 99 + hop2 uint32 + ) + testRecords := []tlv.Record{ + tlv.MakePrimitiveRecord(type1, &channelType1), + tlv.MakePrimitiveRecord(type2, &hop1), + } + + // Now that we have our set of sample records and types, we'll encode + // them into the passed ExtraOpaqueData instance. + var extraBytes ExtraOpaqueData + if err := extraBytes.PackRecords(testRecords); err != nil { + t.Fatalf("unable to pack records: %v", err) + } + + // We'll now simulate decoding these types _back_ into records on the + // other side. + newRecords := []tlv.Record{ + tlv.MakePrimitiveRecord(type1, &channelType2), + tlv.MakePrimitiveRecord(type2, &hop2), + } + typeMap, err := extraBytes.ExtractRecords(newRecords...) + if err != nil { + t.Fatalf("unable to extract record: %v", err) + } + + // We should find that the new backing values have been populated with + // the proper value. + switch { + case channelType1 != channelType2: + t.Fatalf("wrong record for channel type: expected %v, got %v", + channelType1, channelType2) + + case hop1 != hop2: + t.Fatalf("wrong record for hop: expected %v, got %v", hop1, + hop2) + } + + // Both types we created above should be found in the type map. + if _, ok := typeMap[type1]; !ok { + t.Fatalf("type1 not found in typeMap") + } + if _, ok := typeMap[type2]; !ok { + t.Fatalf("type2 not found in typeMap") + } +} diff --git a/lnwire/lnwire.go b/lnwire/lnwire.go index ca0e449e5a5..c180cad3883 100644 --- a/lnwire/lnwire.go +++ b/lnwire/lnwire.go @@ -18,9 +18,16 @@ import ( "github.com/lightningnetwork/lnd/tor" ) -// MaxSliceLength is the maximum allowed length for any opaque byte slices in -// the wire protocol. -const MaxSliceLength = 65535 +const ( + // MaxSliceLength is the maximum allowed length for any opaque byte + // slices in the wire protocol. + MaxSliceLength = 65535 + + // MaxMsgBody is the largest payload any message is allowed to provide. + // This is two less than the MaxSliceLength as each message has a 2 + // byte type that precedes the message body. + MaxMsgBody = 65533 +) // PkScript is simple type definition which represents a raw serialized public // key script. @@ -418,6 +425,10 @@ func WriteElement(w io.Writer, element interface{}) error { if _, err := w.Write(b[:]); err != nil { return err } + + case ExtraOpaqueData: + return e.Encode(w) + default: return fmt.Errorf("unknown type in WriteElement: %T", e) } @@ -824,6 +835,10 @@ func ReadElement(r io.Reader, element interface{}) error { return err } *e = addrBytes[:length] + + case *ExtraOpaqueData: + return e.Decode(r) + default: return fmt.Errorf("unknown type in ReadElement: %T", e) } From 574bff8c07dd326f7f3d2c946eba08a3a1df1f63 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Mon, 27 Jan 2020 17:25:36 -0800 Subject: [PATCH 12/43] lnwire: prep UpdateFulfillHTLC for TLV extensions --- lnwire/update_fulfill_htlc.go | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/lnwire/update_fulfill_htlc.go b/lnwire/update_fulfill_htlc.go index 6c0e6339ff6..36977b1e928 100644 --- a/lnwire/update_fulfill_htlc.go +++ b/lnwire/update_fulfill_htlc.go @@ -21,6 +21,11 @@ type UpdateFulfillHTLC struct { // PaymentPreimage is the R-value preimage required to fully settle an // HTLC. PaymentPreimage [32]byte + + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraData ExtraOpaqueData } // NewUpdateFulfillHTLC returns a new empty UpdateFulfillHTLC. @@ -47,6 +52,7 @@ func (c *UpdateFulfillHTLC) Decode(r io.Reader, pver uint32) error { &c.ChanID, &c.ID, c.PaymentPreimage[:], + &c.ExtraData, ) } @@ -59,6 +65,7 @@ func (c *UpdateFulfillHTLC) Encode(w io.Writer, pver uint32) error { c.ChanID, c.ID, c.PaymentPreimage[:], + c.ExtraData, ) } @@ -75,8 +82,7 @@ func (c *UpdateFulfillHTLC) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (c *UpdateFulfillHTLC) MaxPayloadLength(uint32) uint32 { - // 32 + 8 + 32 - return 72 + return MaxMsgBody } // TargetChanID returns the channel id of the link for which this message is From b4acf79fd40b82588e110e60024547493a9f3a0d Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Mon, 27 Jan 2020 17:25:46 -0800 Subject: [PATCH 13/43] lnwire: prep UpdateFee for TLV extensions --- lnwire/update_fee.go | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/lnwire/update_fee.go b/lnwire/update_fee.go index 2d27c3772f7..25ab180c2df 100644 --- a/lnwire/update_fee.go +++ b/lnwire/update_fee.go @@ -16,6 +16,11 @@ type UpdateFee struct { // TODO(halseth): make SatPerKWeight when fee estimation is moved to // own package. Currently this will cause an import cycle. FeePerKw uint32 + + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraData ExtraOpaqueData } // NewUpdateFee creates a new UpdateFee message. @@ -38,6 +43,7 @@ func (c *UpdateFee) Decode(r io.Reader, pver uint32) error { return ReadElements(r, &c.ChanID, &c.FeePerKw, + &c.ExtraData, ) } @@ -49,6 +55,7 @@ func (c *UpdateFee) Encode(w io.Writer, pver uint32) error { return WriteElements(w, c.ChanID, c.FeePerKw, + c.ExtraData, ) } @@ -65,8 +72,7 @@ func (c *UpdateFee) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (c *UpdateFee) MaxPayloadLength(uint32) uint32 { - // 32 + 4 - return 36 + return MaxMsgBody } // TargetChanID returns the channel id of the link for which this message is From 6002919438b17ea40b90939e74c08a2204abc1d8 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Mon, 27 Jan 2020 17:25:55 -0800 Subject: [PATCH 14/43] lnwire: prep UpdateFailMalformedHTLC for TLV extensions --- lnwire/update_fail_malformed_htlc.go | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/lnwire/update_fail_malformed_htlc.go b/lnwire/update_fail_malformed_htlc.go index 39d4b8709e2..b28ec29ff4e 100644 --- a/lnwire/update_fail_malformed_htlc.go +++ b/lnwire/update_fail_malformed_htlc.go @@ -24,6 +24,11 @@ type UpdateFailMalformedHTLC struct { // FailureCode the exact reason why onion blob haven't been parsed. FailureCode FailCode + + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraData ExtraOpaqueData } // A compile time check to ensure UpdateFailMalformedHTLC implements the @@ -40,6 +45,7 @@ func (c *UpdateFailMalformedHTLC) Decode(r io.Reader, pver uint32) error { &c.ID, c.ShaOnionBlob[:], &c.FailureCode, + &c.ExtraData, ) } @@ -53,6 +59,7 @@ func (c *UpdateFailMalformedHTLC) Encode(w io.Writer, pver uint32) error { c.ID, c.ShaOnionBlob[:], c.FailureCode, + c.ExtraData, ) } @@ -70,8 +77,7 @@ func (c *UpdateFailMalformedHTLC) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (c *UpdateFailMalformedHTLC) MaxPayloadLength(uint32) uint32 { - // 32 + 8 + 32 + 2 - return 74 + return MaxMsgBody } // TargetChanID returns the channel id of the link for which this message is From e652f1dbc2b74817957e1f965b360f4a467cf5a9 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Mon, 27 Jan 2020 17:26:07 -0800 Subject: [PATCH 15/43] lnwire: prep UpdateFailHTLC for TLV extensions --- lnwire/update_fail_htlc.go | 23 ++++++++--------------- 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/lnwire/update_fail_htlc.go b/lnwire/update_fail_htlc.go index 194f2ecd000..09666ac25ff 100644 --- a/lnwire/update_fail_htlc.go +++ b/lnwire/update_fail_htlc.go @@ -26,6 +26,11 @@ type UpdateFailHTLC struct { // failed. This blob is only fully decryptable by the initiator of the // HTLC message. Reason OpaqueReason + + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraData ExtraOpaqueData } // A compile time check to ensure UpdateFailHTLC implements the lnwire.Message @@ -41,6 +46,7 @@ func (c *UpdateFailHTLC) Decode(r io.Reader, pver uint32) error { &c.ChanID, &c.ID, &c.Reason, + &c.ExtraData, ) } @@ -53,6 +59,7 @@ func (c *UpdateFailHTLC) Encode(w io.Writer, pver uint32) error { c.ChanID, c.ID, c.Reason, + c.ExtraData, ) } @@ -69,21 +76,7 @@ func (c *UpdateFailHTLC) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (c *UpdateFailHTLC) MaxPayloadLength(uint32) uint32 { - var length uint32 - - // Length of the ChanID - length += 32 - - // Length of the ID - length += 8 - - // Length of the length opaque reason - length += 2 - - // Length of the Reason - length += 292 - - return length + return MaxMsgBody } // TargetChanID returns the channel id of the link for which this message is From 9f356b6dad03c09e008840fe0b94bd526d1ffa88 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Mon, 27 Jan 2020 17:26:23 -0800 Subject: [PATCH 16/43] lnwire: prep UpdateAddHTLC for TLV extensions --- lnwire/update_add_htlc.go | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/lnwire/update_add_htlc.go b/lnwire/update_add_htlc.go index 028c6320d72..9211d39ffb0 100644 --- a/lnwire/update_add_htlc.go +++ b/lnwire/update_add_htlc.go @@ -52,6 +52,11 @@ type UpdateAddHTLC struct { // should strip off a layer of encryption, exposing the next hop to be // used in the subsequent UpdateAddHTLC message. OnionBlob [OnionPacketSize]byte + + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraData ExtraOpaqueData } // NewUpdateAddHTLC returns a new empty UpdateAddHTLC message. @@ -75,6 +80,7 @@ func (c *UpdateAddHTLC) Decode(r io.Reader, pver uint32) error { c.PaymentHash[:], &c.Expiry, c.OnionBlob[:], + &c.ExtraData, ) } @@ -90,6 +96,7 @@ func (c *UpdateAddHTLC) Encode(w io.Writer, pver uint32) error { c.PaymentHash[:], c.Expiry, c.OnionBlob[:], + c.ExtraData, ) } @@ -106,8 +113,7 @@ func (c *UpdateAddHTLC) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (c *UpdateAddHTLC) MaxPayloadLength(uint32) uint32 { - // 1450 - return 32 + 8 + 4 + 8 + 32 + 1366 + return MaxMsgBody } // TargetChanID returns the channel id of the link for which this message is From 2bc6393734d014b96d7e3478b7fe4757ad5e407f Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Mon, 27 Jan 2020 17:26:33 -0800 Subject: [PATCH 17/43] lnwire: prep Shutdown for TLV extensions --- lnwire/shutdown.go | 22 ++++++++-------------- 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/lnwire/shutdown.go b/lnwire/shutdown.go index 94d10a9080c..e27681e4e60 100644 --- a/lnwire/shutdown.go +++ b/lnwire/shutdown.go @@ -15,6 +15,11 @@ type Shutdown struct { // Address is the script to which the channel funds will be paid. Address DeliveryAddress + + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraData ExtraOpaqueData } // DeliveryAddress is used to communicate the address to which funds from a @@ -48,7 +53,7 @@ var _ Message = (*Shutdown)(nil) // // This is part of the lnwire.Message interface. func (s *Shutdown) Decode(r io.Reader, pver uint32) error { - return ReadElements(r, &s.ChannelID, &s.Address) + return ReadElements(r, &s.ChannelID, &s.Address, &s.ExtraData) } // Encode serializes the target Shutdown into the passed io.Writer observing @@ -56,7 +61,7 @@ func (s *Shutdown) Decode(r io.Reader, pver uint32) error { // // This is part of the lnwire.Message interface. func (s *Shutdown) Encode(w io.Writer, pver uint32) error { - return WriteElements(w, s.ChannelID, s.Address) + return WriteElements(w, s.ChannelID, s.Address, s.ExtraData) } // MsgType returns the integer uniquely identifying this message type on the @@ -72,16 +77,5 @@ func (s *Shutdown) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (s *Shutdown) MaxPayloadLength(pver uint32) uint32 { - var length uint32 - - // ChannelID - 32bytes - length += 32 - - // Len - 2 bytes - length += 2 - - // ScriptPubKey - maximum delivery address size. - length += deliveryAddressMaxSize - - return length + return MaxMsgBody } From fe168320783e20c6eb86aa4bcecf17e7ab47bfef Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Mon, 27 Jan 2020 17:26:41 -0800 Subject: [PATCH 18/43] lnwire: prep RevokeAndAck for TLV extensions --- lnwire/revoke_and_ack.go | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/lnwire/revoke_and_ack.go b/lnwire/revoke_and_ack.go index d685f0f3256..6eaf5cafd69 100644 --- a/lnwire/revoke_and_ack.go +++ b/lnwire/revoke_and_ack.go @@ -30,11 +30,18 @@ type RevokeAndAck struct { // create the proper revocation key used within the commitment // transaction. NextRevocationKey *btcec.PublicKey + + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraData ExtraOpaqueData } // NewRevokeAndAck creates a new RevokeAndAck message. func NewRevokeAndAck() *RevokeAndAck { - return &RevokeAndAck{} + return &RevokeAndAck{ + ExtraData: make([]byte, 0), + } } // A compile time check to ensure RevokeAndAck implements the lnwire.Message @@ -50,6 +57,7 @@ func (c *RevokeAndAck) Decode(r io.Reader, pver uint32) error { &c.ChanID, c.Revocation[:], &c.NextRevocationKey, + &c.ExtraData, ) } @@ -62,6 +70,7 @@ func (c *RevokeAndAck) Encode(w io.Writer, pver uint32) error { c.ChanID, c.Revocation[:], c.NextRevocationKey, + c.ExtraData, ) } @@ -78,8 +87,7 @@ func (c *RevokeAndAck) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (c *RevokeAndAck) MaxPayloadLength(uint32) uint32 { - // 32 + 32 + 33 - return 97 + return MaxMsgBody } // TargetChanID returns the channel id of the link for which this message is From bc73120a66218389adac8e250c7666e6d6e12e42 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Mon, 27 Jan 2020 17:26:56 -0800 Subject: [PATCH 19/43] lnwire: prep ReplyShortChanIDsEnd for TLV extensions --- lnwire/reply_short_chan_ids_end.go | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/lnwire/reply_short_chan_ids_end.go b/lnwire/reply_short_chan_ids_end.go index ce5f8f740bd..1412b50f960 100644 --- a/lnwire/reply_short_chan_ids_end.go +++ b/lnwire/reply_short_chan_ids_end.go @@ -22,6 +22,11 @@ type ReplyShortChanIDsEnd struct { // set of short chan ID's in the corresponding QueryShortChanIDs // message. Complete uint8 + + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraData ExtraOpaqueData } // NewReplyShortChanIDsEnd creates a new empty ReplyShortChanIDsEnd message. @@ -41,6 +46,7 @@ func (c *ReplyShortChanIDsEnd) Decode(r io.Reader, pver uint32) error { return ReadElements(r, c.ChainHash[:], &c.Complete, + &c.ExtraData, ) } @@ -52,6 +58,7 @@ func (c *ReplyShortChanIDsEnd) Encode(w io.Writer, pver uint32) error { return WriteElements(w, c.ChainHash[:], c.Complete, + c.ExtraData, ) } @@ -69,6 +76,5 @@ func (c *ReplyShortChanIDsEnd) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (c *ReplyShortChanIDsEnd) MaxPayloadLength(uint32) uint32 { - // 32 (chain hash) + 1 (complete) - return 33 + return MaxMsgBody } From 68f8fc781c4d6715e4e0ba282e96da615dbc1cef Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Mon, 27 Jan 2020 17:30:54 -0800 Subject: [PATCH 20/43] discovery+lnwire: remove embedding within ReplyChannelRange In order to prep for allowing TLV extensions for the `ReplyChannelRange` and `QueryChannelRange` messages, we'll need to remove the struct embedding as is. If we don't remove this, then we'll attempt to decode TLV extensions from both the embedded and outer struct. All relevant call sites have been updated to reflect this minor change. --- discovery/sync_manager_test.go | 5 ++- discovery/syncer.go | 32 ++++++++-------- discovery/syncer_test.go | 59 +++++++++++++----------------- lnwire/lnwire_test.go | 6 +-- lnwire/reply_channel_range.go | 55 +++++++++++++++++++++++----- lnwire/reply_channel_range_test.go | 14 +++---- 6 files changed, 99 insertions(+), 72 deletions(-) diff --git a/discovery/sync_manager_test.go b/discovery/sync_manager_test.go index c7a228f8cf6..b8b309249b9 100644 --- a/discovery/sync_manager_test.go +++ b/discovery/sync_manager_test.go @@ -536,8 +536,9 @@ func assertTransitionToChansSynced(t *testing.T, s *GossipSyncer, peer *mockPeer assertMsgSent(t, peer, query) s.ProcessQueryMsg(&lnwire.ReplyChannelRange{ - QueryChannelRange: *query, - Complete: 1, + FirstBlockHeight: 0, + NumBlocks: math.MaxUint32, + Complete: 1, }, nil) chanSeries := s.cfg.channelSeries.(*mockChannelGraphTimeSeries) diff --git a/discovery/syncer.go b/discovery/syncer.go index 10b6d4205da..8417fda5049 100644 --- a/discovery/syncer.go +++ b/discovery/syncer.go @@ -685,7 +685,9 @@ func (g *GossipSyncer) synchronizeChanIDs() (bool, error) { func isLegacyReplyChannelRange(query *lnwire.QueryChannelRange, reply *lnwire.ReplyChannelRange) bool { - return reply.QueryChannelRange == *query + return (reply.ChainHash == query.ChainHash && + reply.FirstBlockHeight == query.FirstBlockHeight && + reply.NumBlocks == query.NumBlocks) } // processChanRangeReply is called each time the GossipSyncer receives a new @@ -705,7 +707,7 @@ func (g *GossipSyncer) processChanRangeReply(msg *lnwire.ReplyChannelRange) erro // The last block should also be. We don't need to check the // intermediate ones because they should already be in sorted // order. - replyLastHeight := msg.QueryChannelRange.LastBlockHeight() + replyLastHeight := msg.LastBlockHeight() queryLastHeight := g.curQueryRangeMsg.LastBlockHeight() if replyLastHeight > queryLastHeight { return fmt.Errorf("reply includes channels for height "+ @@ -754,7 +756,7 @@ func (g *GossipSyncer) processChanRangeReply(msg *lnwire.ReplyChannelRange) erro // Otherwise, we'll look at the reply's height range. default: - replyLastHeight := msg.QueryChannelRange.LastBlockHeight() + replyLastHeight := msg.LastBlockHeight() queryLastHeight := g.curQueryRangeMsg.LastBlockHeight() // TODO(wilmer): This might require some padding if the remote @@ -904,10 +906,12 @@ func (g *GossipSyncer) replyChanRangeQuery(query *lnwire.QueryChannelRange) erro g.cfg.chainHash) return g.cfg.sendToPeerSync(&lnwire.ReplyChannelRange{ - QueryChannelRange: *query, - Complete: 0, - EncodingType: g.cfg.encodingType, - ShortChanIDs: nil, + ChainHash: query.ChainHash, + FirstBlockHeight: query.FirstBlockHeight, + NumBlocks: query.NumBlocks, + Complete: 0, + EncodingType: g.cfg.encodingType, + ShortChanIDs: nil, }) } @@ -1001,14 +1005,12 @@ func (g *GossipSyncer) replyChanRangeQuery(query *lnwire.QueryChannelRange) erro // With our chunk assembled, we'll now send to the remote peer // the current chunk. replyChunk := lnwire.ReplyChannelRange{ - QueryChannelRange: lnwire.QueryChannelRange{ - ChainHash: query.ChainHash, - NumBlocks: numBlocksInResp, - FirstBlockHeight: firstBlockHeight, - }, - Complete: 0, - EncodingType: g.cfg.encodingType, - ShortChanIDs: channelChunk, + ChainHash: query.ChainHash, + NumBlocks: numBlocksInResp, + FirstBlockHeight: firstBlockHeight, + Complete: 0, + EncodingType: g.cfg.encodingType, + ShortChanIDs: channelChunk, } if isFinalChunk { replyChunk.Complete = 1 diff --git a/discovery/syncer_test.go b/discovery/syncer_test.go index 8e99fa49efa..b0d649de896 100644 --- a/discovery/syncer_test.go +++ b/discovery/syncer_test.go @@ -576,10 +576,9 @@ func TestGossipSyncerQueryChannelRangeWrongChainHash(t *testing.T) { t.Fatalf("expected lnwire.ReplyChannelRange, got %T", msg) } - if msg.QueryChannelRange != *query { - t.Fatalf("wrong query channel range in reply: "+ - "expected: %v\ngot: %v", spew.Sdump(*query), - spew.Sdump(msg.QueryChannelRange)) + if msg.ChainHash != query.ChainHash { + t.Fatalf("wrong chain hash: expected %v got %v", + query.ChainHash, msg.ChainHash) } if msg.Complete != 0 { t.Fatalf("expected complete set to 0, got %v", @@ -1192,34 +1191,13 @@ func testGossipSyncerProcessChanRangeReply(t *testing.T, legacy bool) { t.Fatalf("unable to generate channel range query: %v", err) } - var replyQueries []*lnwire.QueryChannelRange - if legacy { - // Each reply query is the same as the original query in the - // legacy mode. - replyQueries = []*lnwire.QueryChannelRange{query, query, query} - } else { - // When interpreting block ranges, the first reply should start - // from our requested first block, and the last should end at - // our requested last block. - replyQueries = []*lnwire.QueryChannelRange{ - { - FirstBlockHeight: 0, - NumBlocks: 11, - }, - { - FirstBlockHeight: 11, - NumBlocks: 1, - }, - { - FirstBlockHeight: 12, - NumBlocks: query.NumBlocks - 12, - }, - } - } - + // When interpreting block ranges, the first reply should start from + // our requested first block, and the last should end at our requested + // last block. replies := []*lnwire.ReplyChannelRange{ { - QueryChannelRange: *replyQueries[0], + FirstBlockHeight: 0, + NumBlocks: 11, ShortChanIDs: []lnwire.ShortChannelID{ { BlockHeight: 10, @@ -1227,7 +1205,8 @@ func testGossipSyncerProcessChanRangeReply(t *testing.T, legacy bool) { }, }, { - QueryChannelRange: *replyQueries[1], + FirstBlockHeight: 11, + NumBlocks: 1, ShortChanIDs: []lnwire.ShortChannelID{ { BlockHeight: 11, @@ -1235,8 +1214,9 @@ func testGossipSyncerProcessChanRangeReply(t *testing.T, legacy bool) { }, }, { - QueryChannelRange: *replyQueries[2], - Complete: 1, + FirstBlockHeight: 12, + NumBlocks: query.NumBlocks - 12, + Complete: 1, ShortChanIDs: []lnwire.ShortChannelID{ { BlockHeight: 12, @@ -1245,6 +1225,19 @@ func testGossipSyncerProcessChanRangeReply(t *testing.T, legacy bool) { }, } + // Each reply query is the same as the original query in the legacy + // mode. + if legacy { + replies[0].FirstBlockHeight = query.FirstBlockHeight + replies[0].NumBlocks = query.NumBlocks + + replies[1].FirstBlockHeight = query.FirstBlockHeight + replies[1].NumBlocks = query.NumBlocks + + replies[2].FirstBlockHeight = query.FirstBlockHeight + replies[2].NumBlocks = query.NumBlocks + } + // We'll begin by sending the syncer a set of non-complete channel // range replies. if err := syncer.processChanRangeReply(replies[0]); err != nil { diff --git a/lnwire/lnwire_test.go b/lnwire/lnwire_test.go index 02023b02317..ea90f3c0cfa 100644 --- a/lnwire/lnwire_test.go +++ b/lnwire/lnwire_test.go @@ -810,10 +810,8 @@ func TestLightningWireProtocol(t *testing.T) { }, MsgReplyChannelRange: func(v []reflect.Value, r *rand.Rand) { req := ReplyChannelRange{ - QueryChannelRange: QueryChannelRange{ - FirstBlockHeight: uint32(r.Int31()), - NumBlocks: uint32(r.Int31()), - }, + FirstBlockHeight: uint32(r.Int31()), + NumBlocks: uint32(r.Int31()), } if _, err := rand.Read(req.ChainHash[:]); err != nil { diff --git a/lnwire/reply_channel_range.go b/lnwire/reply_channel_range.go index 430606025c2..0c97c211e7a 100644 --- a/lnwire/reply_channel_range.go +++ b/lnwire/reply_channel_range.go @@ -1,14 +1,29 @@ package lnwire -import "io" +import ( + "io" + "math" + + "github.com/btcsuite/btcd/chaincfg/chainhash" +) // ReplyChannelRange is the response to the QueryChannelRange message. It // includes the original query, and the next streaming chunk of encoded short // channel ID's as the response. We'll also include a byte that indicates if // this is the last query in the message. type ReplyChannelRange struct { - // QueryChannelRange is the corresponding query to this response. - QueryChannelRange + // ChainHash denotes the target chain that we're trying to synchronize + // channel graph state for. + ChainHash chainhash.Hash + + // FirstBlockHeight is the first block in the query range. The + // responder should send all new short channel IDs from this block + // until this block plus the specified number of blocks. + FirstBlockHeight uint32 + + // NumBlocks is the number of blocks beyond the first block that short + // channel ID's should be sent for. + NumBlocks uint32 // Complete denotes if this is the conclusion of the set of streaming // responses to the original query. @@ -43,17 +58,21 @@ var _ Message = (*ReplyChannelRange)(nil) // // This is part of the lnwire.Message interface. func (c *ReplyChannelRange) Decode(r io.Reader, pver uint32) error { - err := c.QueryChannelRange.Decode(r, pver) + err := ReadElements(r, + c.ChainHash[:], + &c.FirstBlockHeight, + &c.NumBlocks, + &c.Complete, + ) if err != nil { return err } - if err := ReadElements(r, &c.Complete); err != nil { + c.EncodingType, c.ShortChanIDs, err = decodeShortChanIDs(r) + if err != nil { return err } - c.EncodingType, c.ShortChanIDs, err = decodeShortChanIDs(r) - return err } @@ -62,15 +81,21 @@ func (c *ReplyChannelRange) Decode(r io.Reader, pver uint32) error { // // This is part of the lnwire.Message interface. func (c *ReplyChannelRange) Encode(w io.Writer, pver uint32) error { - if err := c.QueryChannelRange.Encode(w, pver); err != nil { + err := WriteElements(w, + c.ChainHash[:], + c.FirstBlockHeight, + c.NumBlocks, + c.Complete, + ) + if err != nil { return err } - if err := WriteElements(w, c.Complete); err != nil { + err = encodeShortChanIDs(w, c.EncodingType, c.ShortChanIDs, c.noSort) + if err != nil { return err } - return encodeShortChanIDs(w, c.EncodingType, c.ShortChanIDs, c.noSort) } // MsgType returns the integer uniquely identifying this message type on the @@ -87,4 +112,14 @@ func (c *ReplyChannelRange) MsgType() MessageType { // This is part of the lnwire.Message interface. func (c *ReplyChannelRange) MaxPayloadLength(uint32) uint32 { return MaxMessagePayload + +// LastBlockHeight returns the last block height covered by the range of a +// QueryChannelRange message. +func (c *ReplyChannelRange) LastBlockHeight() uint32 { + // Handle overflows by casting to uint64. + lastBlockHeight := uint64(c.FirstBlockHeight) + uint64(c.NumBlocks) - 1 + if lastBlockHeight > math.MaxUint32 { + return math.MaxUint32 + } + return uint32(lastBlockHeight) } diff --git a/lnwire/reply_channel_range_test.go b/lnwire/reply_channel_range_test.go index d2c8df68c68..d656db55d0c 100644 --- a/lnwire/reply_channel_range_test.go +++ b/lnwire/reply_channel_range_test.go @@ -30,7 +30,7 @@ func TestReplyChannelRangeUnsorted(t *testing.T) { var req2 ReplyChannelRange err = req2.Decode(bytes.NewReader(b.Bytes()), 0) if _, ok := err.(ErrUnsortedSIDs); !ok { - t.Fatalf("expected ErrUnsortedSIDs, got: %T", + t.Fatalf("expected ErrUnsortedSIDs, got: %v", err) } }) @@ -67,13 +67,11 @@ func TestReplyChannelRangeEmpty(t *testing.T) { test := test t.Run(test.name, func(t *testing.T) { req := ReplyChannelRange{ - QueryChannelRange: QueryChannelRange{ - FirstBlockHeight: 1, - NumBlocks: 2, - }, - Complete: 1, - EncodingType: test.encType, - ShortChanIDs: nil, + FirstBlockHeight: 1, + NumBlocks: 2, + Complete: 1, + EncodingType: test.encType, + ShortChanIDs: nil, } // First decode the hex string in the test case into a From eab6a63ac5957e02229c2abffed344a6ff566b98 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Mon, 27 Jan 2020 17:31:09 -0800 Subject: [PATCH 21/43] lnwire: prep ReplyChannelRange for TLV extensions --- lnwire/reply_channel_range.go | 11 +++++++++-- lnwire/reply_channel_range_test.go | 1 + 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/lnwire/reply_channel_range.go b/lnwire/reply_channel_range.go index 0c97c211e7a..5167cc5a51c 100644 --- a/lnwire/reply_channel_range.go +++ b/lnwire/reply_channel_range.go @@ -37,6 +37,11 @@ type ReplyChannelRange struct { // ShortChanIDs is a slice of decoded short channel ID's. ShortChanIDs []ShortChannelID + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraData ExtraOpaqueData + // noSort indicates whether or not to sort the short channel ids before // writing them out. // @@ -73,7 +78,7 @@ func (c *ReplyChannelRange) Decode(r io.Reader, pver uint32) error { return err } - return err + return c.ExtraData.Decode(r) } // Encode serializes the target ReplyChannelRange into the passed io.Writer @@ -96,6 +101,7 @@ func (c *ReplyChannelRange) Encode(w io.Writer, pver uint32) error { return err } + return c.ExtraData.Encode(w) } // MsgType returns the integer uniquely identifying this message type on the @@ -111,7 +117,8 @@ func (c *ReplyChannelRange) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (c *ReplyChannelRange) MaxPayloadLength(uint32) uint32 { - return MaxMessagePayload + return MaxMsgBody +} // LastBlockHeight returns the last block height covered by the range of a // QueryChannelRange message. diff --git a/lnwire/reply_channel_range_test.go b/lnwire/reply_channel_range_test.go index d656db55d0c..ff3414958e3 100644 --- a/lnwire/reply_channel_range_test.go +++ b/lnwire/reply_channel_range_test.go @@ -72,6 +72,7 @@ func TestReplyChannelRangeEmpty(t *testing.T) { Complete: 1, EncodingType: test.encType, ShortChanIDs: nil, + ExtraData: make([]byte, 0), } // First decode the hex string in the test case into a From e33a371383dee4389baa55a493a59d66e2205f8c Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Mon, 27 Jan 2020 17:31:16 -0800 Subject: [PATCH 22/43] lnwire: prep QueryShortChanIDs for TLV extensions --- lnwire/query_short_chan_ids.go | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/lnwire/query_short_chan_ids.go b/lnwire/query_short_chan_ids.go index cb24178b39e..43a271333f1 100644 --- a/lnwire/query_short_chan_ids.go +++ b/lnwire/query_short_chan_ids.go @@ -81,6 +81,11 @@ type QueryShortChanIDs struct { // ShortChanIDs is a slice of decoded short channel ID's. ShortChanIDs []ShortChannelID + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraData ExtraOpaqueData + // noSort indicates whether or not to sort the short channel ids before // writing them out. // @@ -114,8 +119,11 @@ func (q *QueryShortChanIDs) Decode(r io.Reader, pver uint32) error { } q.EncodingType, q.ShortChanIDs, err = decodeShortChanIDs(r) + if err != nil { + return err + } - return err + return q.ExtraData.Decode(r) } // decodeShortChanIDs decodes a set of short channel ID's that have been @@ -292,7 +300,12 @@ func (q *QueryShortChanIDs) Encode(w io.Writer, pver uint32) error { // Base on our encoding type, we'll write out the set of short channel // ID's. - return encodeShortChanIDs(w, q.EncodingType, q.ShortChanIDs, q.noSort) + err = encodeShortChanIDs(w, q.EncodingType, q.ShortChanIDs, q.noSort) + if err != nil { + return err + } + + return q.ExtraData.Encode(w) } // encodeShortChanIDs encodes the passed short channel ID's into the passed @@ -425,5 +438,5 @@ func (q *QueryShortChanIDs) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (q *QueryShortChanIDs) MaxPayloadLength(uint32) uint32 { - return MaxMessagePayload + return MaxMsgBody } From 0186026c0e191e59699011a9630a941dab8e4103 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Mon, 27 Jan 2020 17:31:24 -0800 Subject: [PATCH 23/43] lnwire: prep QueryChannelRange for TLV extensions --- lnwire/query_channel_range.go | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/lnwire/query_channel_range.go b/lnwire/query_channel_range.go index 9546fcd32a1..3bdb30e5eca 100644 --- a/lnwire/query_channel_range.go +++ b/lnwire/query_channel_range.go @@ -25,6 +25,11 @@ type QueryChannelRange struct { // NumBlocks is the number of blocks beyond the first block that short // channel ID's should be sent for. NumBlocks uint32 + + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraData ExtraOpaqueData } // NewQueryChannelRange creates a new empty QueryChannelRange message. @@ -45,6 +50,7 @@ func (q *QueryChannelRange) Decode(r io.Reader, pver uint32) error { q.ChainHash[:], &q.FirstBlockHeight, &q.NumBlocks, + &q.ExtraData, ) } @@ -57,6 +63,7 @@ func (q *QueryChannelRange) Encode(w io.Writer, pver uint32) error { q.ChainHash[:], q.FirstBlockHeight, q.NumBlocks, + q.ExtraData, ) } @@ -73,8 +80,7 @@ func (q *QueryChannelRange) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (q *QueryChannelRange) MaxPayloadLength(uint32) uint32 { - // 32 + 4 + 4 - return 40 + return MaxMsgBody } // LastBlockHeight returns the last block height covered by the range of a From 707782d35af8ee9c924a2e387dda1342a603a8c4 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Mon, 27 Jan 2020 17:31:44 -0800 Subject: [PATCH 24/43] lnwire: prep NodeAnnouncement for TLV extensions --- lnwire/node_announcement.go | 25 ++++--------------------- 1 file changed, 4 insertions(+), 21 deletions(-) diff --git a/lnwire/node_announcement.go b/lnwire/node_announcement.go index f0d897bc91d..c794e5b52ef 100644 --- a/lnwire/node_announcement.go +++ b/lnwire/node_announcement.go @@ -5,7 +5,6 @@ import ( "fmt" "image/color" "io" - "io/ioutil" "net" "unicode/utf8" ) @@ -98,7 +97,7 @@ type NodeAnnouncement struct { // properly validate the set of signatures that cover these new fields, // and ensure we're able to make upgrades to the network in a forwards // compatible manner. - ExtraOpaqueData []byte + ExtraOpaqueData ExtraOpaqueData } // A compile time check to ensure NodeAnnouncement implements the @@ -110,7 +109,7 @@ var _ Message = (*NodeAnnouncement)(nil) // // This is part of the lnwire.Message interface. func (a *NodeAnnouncement) Decode(r io.Reader, pver uint32) error { - err := ReadElements(r, + return ReadElements(r, &a.Signature, &a.Features, &a.Timestamp, @@ -118,24 +117,8 @@ func (a *NodeAnnouncement) Decode(r io.Reader, pver uint32) error { &a.RGBColor, &a.Alias, &a.Addresses, + &a.ExtraOpaqueData, ) - if err != nil { - return err - } - - // Now that we've read out all the fields that we explicitly know of, - // we'll collect the remainder into the ExtraOpaqueData field. If there - // aren't any bytes, then we'll snip off the slice to avoid carrying - // around excess capacity. - a.ExtraOpaqueData, err = ioutil.ReadAll(r) - if err != nil { - return err - } - if len(a.ExtraOpaqueData) == 0 { - a.ExtraOpaqueData = nil - } - - return nil } // Encode serializes the target NodeAnnouncement into the passed io.Writer @@ -167,7 +150,7 @@ func (a *NodeAnnouncement) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (a *NodeAnnouncement) MaxPayloadLength(pver uint32) uint32 { - return 65533 + return MaxMsgBody } // DataToSign returns the part of the message that should be signed. From fdac8297771acc3f69c4762493199e356bc44b1b Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Mon, 27 Jan 2020 17:31:52 -0800 Subject: [PATCH 25/43] lnwire: prep Init for TLV extensions --- lnwire/init_message.go | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/lnwire/init_message.go b/lnwire/init_message.go index 0236a71f84c..18af1d7da9a 100644 --- a/lnwire/init_message.go +++ b/lnwire/init_message.go @@ -20,6 +20,11 @@ type Init struct { // message, any GlobalFeatures should be merged into the unified // Features field. Features *RawFeatureVector + + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraData ExtraOpaqueData } // NewInitMessage creates new instance of init message object. @@ -27,6 +32,7 @@ func NewInitMessage(gf *RawFeatureVector, f *RawFeatureVector) *Init { return &Init{ GlobalFeatures: gf, Features: f, + ExtraData: make([]byte, 0), } } @@ -42,6 +48,7 @@ func (msg *Init) Decode(r io.Reader, pver uint32) error { return ReadElements(r, &msg.GlobalFeatures, &msg.Features, + &msg.ExtraData, ) } @@ -53,6 +60,7 @@ func (msg *Init) Encode(w io.Writer, pver uint32) error { return WriteElements(w, msg.GlobalFeatures, msg.Features, + msg.ExtraData, ) } @@ -69,5 +77,5 @@ func (msg *Init) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (msg *Init) MaxPayloadLength(uint32) uint32 { - return 2 + 2 + maxAllowedSize + 2 + maxAllowedSize + return MaxMsgBody } From 484686ddd1ae42e3152e1571124a42d4ab1d5c3f Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Mon, 27 Jan 2020 17:31:59 -0800 Subject: [PATCH 26/43] lnwire: prep GossipTimestampRange for TLV extensions --- lnwire/gossip_timestamp_range.go | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/lnwire/gossip_timestamp_range.go b/lnwire/gossip_timestamp_range.go index 3c28cd056c2..fb62e27213f 100644 --- a/lnwire/gossip_timestamp_range.go +++ b/lnwire/gossip_timestamp_range.go @@ -24,6 +24,11 @@ type GossipTimestampRange struct { // NOT send any announcements that have a timestamp greater than // FirstTimestamp + TimestampRange. TimestampRange uint32 + + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraData ExtraOpaqueData } // NewGossipTimestampRange creates a new empty GossipTimestampRange message. @@ -44,6 +49,7 @@ func (g *GossipTimestampRange) Decode(r io.Reader, pver uint32) error { g.ChainHash[:], &g.FirstTimestamp, &g.TimestampRange, + &g.ExtraData, ) } @@ -56,6 +62,7 @@ func (g *GossipTimestampRange) Encode(w io.Writer, pver uint32) error { g.ChainHash[:], g.FirstTimestamp, g.TimestampRange, + g.ExtraData, ) } @@ -73,8 +80,5 @@ func (g *GossipTimestampRange) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (g *GossipTimestampRange) MaxPayloadLength(uint32) uint32 { - // 32 + 4 + 4 - // - // TODO(roasbeef): update to 8 byte timestmaps? - return 40 + return MaxMsgBody } From a000aa27922d2b61cafdfd72c614dd8416c6174e Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Mon, 27 Jan 2020 17:32:11 -0800 Subject: [PATCH 27/43] lnwire: prep FundingSigned for TLV extensions --- lnwire/funding_signed.go | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/lnwire/funding_signed.go b/lnwire/funding_signed.go index 620f8b37317..1ef1556802e 100644 --- a/lnwire/funding_signed.go +++ b/lnwire/funding_signed.go @@ -13,6 +13,11 @@ type FundingSigned struct { // CommitSig is Bob's signature for Alice's version of the commitment // transaction. CommitSig Sig + + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraData ExtraOpaqueData } // A compile time check to ensure FundingSigned implements the lnwire.Message @@ -25,7 +30,7 @@ var _ Message = (*FundingSigned)(nil) // // This is part of the lnwire.Message interface. func (f *FundingSigned) Encode(w io.Writer, pver uint32) error { - return WriteElements(w, f.ChanID, f.CommitSig) + return WriteElements(w, f.ChanID, f.CommitSig, f.ExtraData) } // Decode deserializes the serialized FundingSigned stored in the passed @@ -34,7 +39,7 @@ func (f *FundingSigned) Encode(w io.Writer, pver uint32) error { // // This is part of the lnwire.Message interface. func (f *FundingSigned) Decode(r io.Reader, pver uint32) error { - return ReadElements(r, &f.ChanID, &f.CommitSig) + return ReadElements(r, &f.ChanID, &f.CommitSig, &f.ExtraData) } // MsgType returns the uint32 code which uniquely identifies this message as a @@ -50,6 +55,5 @@ func (f *FundingSigned) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (f *FundingSigned) MaxPayloadLength(uint32) uint32 { - // 32 + 64 - return 96 + return MaxMsgBody } From 8041476e9f1a1f0de6ecf0c6e6a26df6b21a18bb Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Mon, 27 Jan 2020 17:32:21 -0800 Subject: [PATCH 28/43] lnwire: prep FundingLocked for TLV extensions --- lnwire/funding_locked.go | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/lnwire/funding_locked.go b/lnwire/funding_locked.go index c441b0be621..1eeddfb6cc6 100644 --- a/lnwire/funding_locked.go +++ b/lnwire/funding_locked.go @@ -19,6 +19,11 @@ type FundingLocked struct { // NextPerCommitmentPoint is the secret that can be used to revoke the // next commitment transaction for the channel. NextPerCommitmentPoint *btcec.PublicKey + + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraData ExtraOpaqueData } // NewFundingLocked creates a new FundingLocked message, populating it with the @@ -27,6 +32,7 @@ func NewFundingLocked(cid ChannelID, npcp *btcec.PublicKey) *FundingLocked { return &FundingLocked{ ChanID: cid, NextPerCommitmentPoint: npcp, + ExtraData: make([]byte, 0), } } @@ -42,7 +48,9 @@ var _ Message = (*FundingLocked)(nil) func (c *FundingLocked) Decode(r io.Reader, pver uint32) error { return ReadElements(r, &c.ChanID, - &c.NextPerCommitmentPoint) + &c.NextPerCommitmentPoint, + &c.ExtraData, + ) } // Encode serializes the target FundingLocked message into the passed io.Writer @@ -53,7 +61,9 @@ func (c *FundingLocked) Decode(r io.Reader, pver uint32) error { func (c *FundingLocked) Encode(w io.Writer, pver uint32) error { return WriteElements(w, c.ChanID, - c.NextPerCommitmentPoint) + c.NextPerCommitmentPoint, + c.ExtraData, + ) } // MsgType returns the uint32 code which uniquely identifies this message as a @@ -70,14 +80,5 @@ func (c *FundingLocked) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (c *FundingLocked) MaxPayloadLength(uint32) uint32 { - var length uint32 - - // ChanID - 32 bytes - length += 32 - - // NextPerCommitmentPoint - 33 bytes - length += 33 - - // 65 bytes - return length + return MaxMsgBody } From ba8c1d1904d296744e88dfe45f8f1aeb8aba8fa0 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Mon, 27 Jan 2020 17:32:29 -0800 Subject: [PATCH 29/43] lnwire: prep FundingCreated for TLV extensions --- lnwire/funding_created.go | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/lnwire/funding_created.go b/lnwire/funding_created.go index c14321ec8f9..437b1b6a8c5 100644 --- a/lnwire/funding_created.go +++ b/lnwire/funding_created.go @@ -24,6 +24,11 @@ type FundingCreated struct { // CommitSig is Alice's signature from Bob's version of the commitment // transaction. CommitSig Sig + + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraData ExtraOpaqueData } // A compile time check to ensure FundingCreated implements the lnwire.Message @@ -36,7 +41,10 @@ var _ Message = (*FundingCreated)(nil) // // This is part of the lnwire.Message interface. func (f *FundingCreated) Encode(w io.Writer, pver uint32) error { - return WriteElements(w, f.PendingChannelID[:], f.FundingPoint, f.CommitSig) + return WriteElements( + w, f.PendingChannelID[:], f.FundingPoint, f.CommitSig, + f.ExtraData, + ) } // Decode deserializes the serialized FundingCreated stored in the passed @@ -45,7 +53,10 @@ func (f *FundingCreated) Encode(w io.Writer, pver uint32) error { // // This is part of the lnwire.Message interface. func (f *FundingCreated) Decode(r io.Reader, pver uint32) error { - return ReadElements(r, f.PendingChannelID[:], &f.FundingPoint, &f.CommitSig) + return ReadElements( + r, f.PendingChannelID[:], &f.FundingPoint, &f.CommitSig, + &f.ExtraData, + ) } // MsgType returns the uint32 code which uniquely identifies this message as a @@ -61,6 +72,5 @@ func (f *FundingCreated) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (f *FundingCreated) MaxPayloadLength(uint32) uint32 { - // 32 + 32 + 2 + 64 - return 130 + return MaxMsgBody } From eec879fc9c39fb38d7ad8be6f02e6412edd2b5f6 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Mon, 27 Jan 2020 17:32:37 -0800 Subject: [PATCH 30/43] lnwire: prep CommitSig for TLV extensions --- lnwire/commit_sig.go | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/lnwire/commit_sig.go b/lnwire/commit_sig.go index 2455c016570..2ac71ddc99e 100644 --- a/lnwire/commit_sig.go +++ b/lnwire/commit_sig.go @@ -34,11 +34,18 @@ type CommitSig struct { // should be signed, for each incoming HTLC the HTLC timeout // transaction should be signed. HtlcSigs []Sig + + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraData ExtraOpaqueData } // NewCommitSig creates a new empty CommitSig message. func NewCommitSig() *CommitSig { - return &CommitSig{} + return &CommitSig{ + ExtraData: make([]byte, 0), + } } // A compile time check to ensure CommitSig implements the lnwire.Message @@ -54,6 +61,7 @@ func (c *CommitSig) Decode(r io.Reader, pver uint32) error { &c.ChanID, &c.CommitSig, &c.HtlcSigs, + &c.ExtraData, ) } @@ -66,6 +74,7 @@ func (c *CommitSig) Encode(w io.Writer, pver uint32) error { c.ChanID, c.CommitSig, c.HtlcSigs, + c.ExtraData, ) } @@ -82,8 +91,7 @@ func (c *CommitSig) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (c *CommitSig) MaxPayloadLength(uint32) uint32 { - // 32 + 64 + 2 + max_allowed_htlcs - return MaxMessagePayload + return MaxMsgBody } // TargetChanID returns the channel id of the link for which this message is From 2137c0afb10441d693a11b116bae6a8f5a2e5040 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Mon, 27 Jan 2020 17:32:45 -0800 Subject: [PATCH 31/43] lnwire: prep ClosingSigned for TLV extensions --- lnwire/closing_signed.go | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/lnwire/closing_signed.go b/lnwire/closing_signed.go index 91b90646a02..7732715bbf9 100644 --- a/lnwire/closing_signed.go +++ b/lnwire/closing_signed.go @@ -27,6 +27,11 @@ type ClosingSigned struct { // Signature is for the proposed channel close transaction. Signature Sig + + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraData ExtraOpaqueData } // NewClosingSigned creates a new empty ClosingSigned message. @@ -49,7 +54,9 @@ var _ Message = (*ClosingSigned)(nil) // // This is part of the lnwire.Message interface. func (c *ClosingSigned) Decode(r io.Reader, pver uint32) error { - return ReadElements(r, &c.ChannelID, &c.FeeSatoshis, &c.Signature) + return ReadElements( + r, &c.ChannelID, &c.FeeSatoshis, &c.Signature, &c.ExtraData, + ) } // Encode serializes the target ClosingSigned into the passed io.Writer @@ -57,7 +64,9 @@ func (c *ClosingSigned) Decode(r io.Reader, pver uint32) error { // // This is part of the lnwire.Message interface. func (c *ClosingSigned) Encode(w io.Writer, pver uint32) error { - return WriteElements(w, c.ChannelID, c.FeeSatoshis, c.Signature) + return WriteElements( + w, c.ChannelID, c.FeeSatoshis, c.Signature, c.ExtraData, + ) } // MsgType returns the integer uniquely identifying this message type on the @@ -73,16 +82,5 @@ func (c *ClosingSigned) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (c *ClosingSigned) MaxPayloadLength(uint32) uint32 { - var length uint32 - - // ChannelID - 32 bytes - length += 32 - - // FeeSatoshis - 8 bytes - length += 8 - - // Signature - 64 bytes - length += 64 - - return length + return MaxMsgBody } From c518bec0d9742939f28d3e456decd9a95ca3835d Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Mon, 27 Jan 2020 17:33:01 -0800 Subject: [PATCH 32/43] lnwire: prep ChannelUpdate for TLV extensions --- lnwire/channel_update.go | 32 ++++++++------------------------ lnwire/onion_error_test.go | 11 ++++++----- 2 files changed, 14 insertions(+), 29 deletions(-) diff --git a/lnwire/channel_update.go b/lnwire/channel_update.go index fd627646b6d..e44cde7a62b 100644 --- a/lnwire/channel_update.go +++ b/lnwire/channel_update.go @@ -4,7 +4,6 @@ import ( "bytes" "fmt" "io" - "io/ioutil" "github.com/btcsuite/btcd/chaincfg/chainhash" ) @@ -110,13 +109,10 @@ type ChannelUpdate struct { // HtlcMaximumMsat is the maximum HTLC value which will be accepted. HtlcMaximumMsat MilliSatoshi - // ExtraOpaqueData is the set of data that was appended to this - // message, some of which we may not actually know how to iterate or - // parse. By holding onto this data, we ensure that we're able to - // properly validate the set of signatures that cover these new fields, - // and ensure we're able to make upgrades to the network in a forwards - // compatible manner. - ExtraOpaqueData []byte + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraOpaqueData ExtraOpaqueData } // A compile time check to ensure ChannelUpdate implements the lnwire.Message @@ -151,19 +147,7 @@ func (a *ChannelUpdate) Decode(r io.Reader, pver uint32) error { } } - // Now that we've read out all the fields that we explicitly know of, - // we'll collect the remainder into the ExtraOpaqueData field. If there - // aren't any bytes, then we'll snip off the slice to avoid carrying - // around excess capacity. - a.ExtraOpaqueData, err = ioutil.ReadAll(r) - if err != nil { - return err - } - if len(a.ExtraOpaqueData) == 0 { - a.ExtraOpaqueData = nil - } - - return nil + return a.ExtraOpaqueData.Decode(r) } // Encode serializes the target ChannelUpdate into the passed io.Writer @@ -196,7 +180,7 @@ func (a *ChannelUpdate) Encode(w io.Writer, pver uint32) error { } // Finally, append any extra opaque data. - return WriteElements(w, a.ExtraOpaqueData) + return a.ExtraOpaqueData.Encode(w) } // MsgType returns the integer uniquely identifying this message type on the @@ -212,7 +196,7 @@ func (a *ChannelUpdate) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (a *ChannelUpdate) MaxPayloadLength(pver uint32) uint32 { - return 65533 + return MaxMsgBody } // DataToSign is used to retrieve part of the announcement message which should @@ -245,7 +229,7 @@ func (a *ChannelUpdate) DataToSign() ([]byte, error) { } // Finally, append any extra opaque data. - if err := WriteElements(&w, a.ExtraOpaqueData); err != nil { + if err := a.ExtraOpaqueData.Encode(&w); err != nil { return nil, err } diff --git a/lnwire/onion_error_test.go b/lnwire/onion_error_test.go index 3ec147d1ddb..8c4c131c66a 100644 --- a/lnwire/onion_error_test.go +++ b/lnwire/onion_error_test.go @@ -20,11 +20,12 @@ var ( testOffset = uint16(24) sig, _ = NewSigFromSignature(testSig) testChannelUpdate = ChannelUpdate{ - Signature: sig, - ShortChannelID: NewShortChanIDFromInt(1), - Timestamp: 1, - MessageFlags: 0, - ChannelFlags: 1, + Signature: sig, + ShortChannelID: NewShortChanIDFromInt(1), + Timestamp: 1, + MessageFlags: 0, + ChannelFlags: 1, + ExtraOpaqueData: make([]byte, 0), } ) From ec85a9b81c5f908f8051377f47f68b22684c944a Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Mon, 27 Jan 2020 17:33:07 -0800 Subject: [PATCH 33/43] lnwire: prep ChannelReestablish for TLV extensions --- lnwire/channel_reestablish.go | 49 +++++++++++++++++++---------------- 1 file changed, 26 insertions(+), 23 deletions(-) diff --git a/lnwire/channel_reestablish.go b/lnwire/channel_reestablish.go index 42abcf95d7c..9a689ad4302 100644 --- a/lnwire/channel_reestablish.go +++ b/lnwire/channel_reestablish.go @@ -60,6 +60,11 @@ type ChannelReestablish struct { // LocalUnrevokedCommitPoint is the commitment point used in the // current un-revoked commitment transaction of the sending party. LocalUnrevokedCommitPoint *btcec.PublicKey + + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraData ExtraOpaqueData } // A compile time check to ensure ChannelReestablish implements the @@ -83,12 +88,20 @@ func (a *ChannelReestablish) Encode(w io.Writer, pver uint32) error { // If the commit point wasn't sent, then we won't write out any of the // remaining fields as they're optional. if a.LocalUnrevokedCommitPoint == nil { - return nil + // However, we'll still write out the extra data if it's + // present. + // + // NOTE: This is here primarily for the quickcheck tests, in + // practice, we'll always populate this field. + return WriteElements(w, a.ExtraData) } // Otherwise, we'll write out the remaining elements. - return WriteElements(w, a.LastRemoteCommitSecret[:], - a.LocalUnrevokedCommitPoint) + return WriteElements(w, + a.LastRemoteCommitSecret[:], + a.LocalUnrevokedCommitPoint, + a.ExtraData, + ) } // Decode deserializes a serialized ChannelReestablish stored in the passed @@ -118,6 +131,9 @@ func (a *ChannelReestablish) Decode(r io.Reader, pver uint32) error { var buf [32]byte _, err = io.ReadFull(r, buf[:32]) if err == io.EOF { + // If there aren't any more bytes, then we'll emplace an empty + // extra data to make our quickcheck tests happy. + a.ExtraData = make([]byte, 0) return nil } else if err != nil { return err @@ -127,9 +143,13 @@ func (a *ChannelReestablish) Decode(r io.Reader, pver uint32) error { copy(a.LastRemoteCommitSecret[:], buf[:]) // We'll conclude by parsing out the commitment point. We don't check - // the error in this case, as it hey included the commit secret, then + // the error in this case, as it they included the commit secret, then // they MUST also include the commit point. - return ReadElement(r, &a.LocalUnrevokedCommitPoint) + if err = ReadElement(r, &a.LocalUnrevokedCommitPoint); err != nil { + return err + } + + return a.ExtraData.Decode(r) } // MsgType returns the integer uniquely identifying this message type on the @@ -145,22 +165,5 @@ func (a *ChannelReestablish) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (a *ChannelReestablish) MaxPayloadLength(pver uint32) uint32 { - var length uint32 - - // ChanID - 32 bytes - length += 32 - - // NextLocalCommitHeight - 8 bytes - length += 8 - - // RemoteCommitTailHeight - 8 bytes - length += 8 - - // LastRemoteCommitSecret - 32 bytes - length += 32 - - // LocalUnrevokedCommitPoint - 33 bytes - length += 33 - - return length + return MaxMsgBody } From 0055691404e3c6f9c4944c9c62bd1b4fc9e07f61 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Mon, 27 Jan 2020 17:33:23 -0800 Subject: [PATCH 34/43] lnwire: prep ChannelAnnouncement for TLV extensions --- lnwire/channel_announcement.go | 25 ++++--------------------- 1 file changed, 4 insertions(+), 21 deletions(-) diff --git a/lnwire/channel_announcement.go b/lnwire/channel_announcement.go index 46efeed8807..de4b72b3155 100644 --- a/lnwire/channel_announcement.go +++ b/lnwire/channel_announcement.go @@ -3,7 +3,6 @@ package lnwire import ( "bytes" "io" - "io/ioutil" "github.com/btcsuite/btcd/chaincfg/chainhash" ) @@ -56,7 +55,7 @@ type ChannelAnnouncement struct { // properly validate the set of signatures that cover these new fields, // and ensure we're able to make upgrades to the network in a forwards // compatible manner. - ExtraOpaqueData []byte + ExtraOpaqueData ExtraOpaqueData } // A compile time check to ensure ChannelAnnouncement implements the @@ -68,7 +67,7 @@ var _ Message = (*ChannelAnnouncement)(nil) // // This is part of the lnwire.Message interface. func (a *ChannelAnnouncement) Decode(r io.Reader, pver uint32) error { - err := ReadElements(r, + return ReadElements(r, &a.NodeSig1, &a.NodeSig2, &a.BitcoinSig1, @@ -80,24 +79,8 @@ func (a *ChannelAnnouncement) Decode(r io.Reader, pver uint32) error { &a.NodeID2, &a.BitcoinKey1, &a.BitcoinKey2, + &a.ExtraOpaqueData, ) - if err != nil { - return err - } - - // Now that we've read out all the fields that we explicitly know of, - // we'll collect the remainder into the ExtraOpaqueData field. If there - // aren't any bytes, then we'll snip off the slice to avoid carrying - // around excess capacity. - a.ExtraOpaqueData, err = ioutil.ReadAll(r) - if err != nil { - return err - } - if len(a.ExtraOpaqueData) == 0 { - a.ExtraOpaqueData = nil - } - - return nil } // Encode serializes the target ChannelAnnouncement into the passed io.Writer @@ -134,7 +117,7 @@ func (a *ChannelAnnouncement) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (a *ChannelAnnouncement) MaxPayloadLength(pver uint32) uint32 { - return 65533 + return MaxMsgBody } // DataToSign is used to retrieve part of the announcement message which should From 3741b78de72a50dbd684af51fc9a0468324933a0 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Mon, 27 Jan 2020 17:33:31 -0800 Subject: [PATCH 35/43] lnwire: prep AnnounceSignatures for TLV extensions --- lnwire/announcement_signatures.go | 25 ++++--------------------- 1 file changed, 4 insertions(+), 21 deletions(-) diff --git a/lnwire/announcement_signatures.go b/lnwire/announcement_signatures.go index 0124521354f..9e0d99089b3 100644 --- a/lnwire/announcement_signatures.go +++ b/lnwire/announcement_signatures.go @@ -2,7 +2,6 @@ package lnwire import ( "io" - "io/ioutil" ) // AnnounceSignatures this is a direct message between two endpoints of a @@ -40,7 +39,7 @@ type AnnounceSignatures struct { // properly validate the set of signatures that cover these new fields, // and ensure we're able to make upgrades to the network in a forwards // compatible manner. - ExtraOpaqueData []byte + ExtraOpaqueData ExtraOpaqueData } // A compile time check to ensure AnnounceSignatures implements the @@ -52,29 +51,13 @@ var _ Message = (*AnnounceSignatures)(nil) // // This is part of the lnwire.Message interface. func (a *AnnounceSignatures) Decode(r io.Reader, pver uint32) error { - err := ReadElements(r, + return ReadElements(r, &a.ChannelID, &a.ShortChannelID, &a.NodeSignature, &a.BitcoinSignature, + &a.ExtraOpaqueData, ) - if err != nil { - return err - } - - // Now that we've read out all the fields that we explicitly know of, - // we'll collect the remainder into the ExtraOpaqueData field. If there - // aren't any bytes, then we'll snip off the slice to avoid carrying - // around excess capacity. - a.ExtraOpaqueData, err = ioutil.ReadAll(r) - if err != nil { - return err - } - if len(a.ExtraOpaqueData) == 0 { - a.ExtraOpaqueData = nil - } - - return nil } // Encode serializes the target AnnounceSignatures into the passed io.Writer @@ -104,5 +87,5 @@ func (a *AnnounceSignatures) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (a *AnnounceSignatures) MaxPayloadLength(pver uint32) uint32 { - return 65533 + return MaxMsgBody } From 092f48f3dcb30d90cf9a515887c3ead0582196c6 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Mon, 27 Jan 2020 17:33:45 -0800 Subject: [PATCH 36/43] lnwire: prep AcceptChannel for TLV extensions --- lnwire/accept_channel.go | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/lnwire/accept_channel.go b/lnwire/accept_channel.go index da9daa69b32..dee2c5ef46e 100644 --- a/lnwire/accept_channel.go +++ b/lnwire/accept_channel.go @@ -92,6 +92,11 @@ type AcceptChannel struct { // and has a length prefix, so a zero will be written if it is not set // and its length followed by the script will be written if it is set. UpfrontShutdownScript DeliveryAddress + + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraData ExtraOpaqueData } // A compile time check to ensure AcceptChannel implements the lnwire.Message @@ -104,7 +109,7 @@ var _ Message = (*AcceptChannel)(nil) // // This is part of the lnwire.Message interface. func (a *AcceptChannel) Encode(w io.Writer, pver uint32) error { - return WriteElements(w, + err := WriteElements(w, a.PendingChannelID[:], a.DustLimit, a.MaxValueInFlight, @@ -121,6 +126,11 @@ func (a *AcceptChannel) Encode(w io.Writer, pver uint32) error { a.FirstCommitmentPoint, a.UpfrontShutdownScript, ) + if err != nil { + return err + } + + return a.ExtraData.Encode(w) } // Decode deserializes the serialized AcceptChannel stored in the passed @@ -156,7 +166,8 @@ func (a *AcceptChannel) Decode(r io.Reader, pver uint32) error { if err != nil && err != io.EOF { return err } - return nil + + return a.ExtraData.Decode(r) } // MsgType returns the MessageType code which uniquely identifies this message @@ -172,11 +183,5 @@ func (a *AcceptChannel) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (a *AcceptChannel) MaxPayloadLength(uint32) uint32 { - // 32 + (8 * 4) + (4 * 1) + (2 * 2) + (33 * 6) - var length uint32 = 270 // base length - - // Upfront shutdown script max length. - length += 2 + deliveryAddressMaxSize - - return length + return MaxMsgBody } From 6de68ce0b7f938b9a02b20ba2ecd67c12d95be45 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Wed, 11 Mar 2020 17:13:37 -0700 Subject: [PATCH 37/43] lnwire: prep OpenChannel for TLV extensions --- lnwire/open_channel.go | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/lnwire/open_channel.go b/lnwire/open_channel.go index f78cc26eff5..06ac0e31e39 100644 --- a/lnwire/open_channel.go +++ b/lnwire/open_channel.go @@ -128,6 +128,11 @@ type OpenChannel struct { // and has a length prefix, so a zero will be written if it is not set // and its length followed by the script will be written if it is set. UpfrontShutdownScript DeliveryAddress + + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraData ExtraOpaqueData } // A compile time check to ensure OpenChannel implements the lnwire.Message @@ -160,6 +165,7 @@ func (o *OpenChannel) Encode(w io.Writer, pver uint32) error { o.FirstCommitmentPoint, o.ChannelFlags, o.UpfrontShutdownScript, + o.ExtraData, ) } @@ -199,7 +205,7 @@ func (o *OpenChannel) Decode(r io.Reader, pver uint32) error { return err } - return nil + return ReadElement(r, &o.ExtraData) } // MsgType returns the MessageType code which uniquely identifies this message @@ -215,11 +221,5 @@ func (o *OpenChannel) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (o *OpenChannel) MaxPayloadLength(uint32) uint32 { - // (32 * 2) + (8 * 6) + (4 * 1) + (2 * 2) + (33 * 6) + 1 - var length uint32 = 319 // base length - - // Upfront shutdown script max length. - length += 2 + deliveryAddressMaxSize - - return length + return MaxMsgBody } From 509f1ef88fe13ac09e6180fc671e370c8f4f0e36 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Wed, 11 Mar 2020 17:14:17 -0700 Subject: [PATCH 38/43] lnwire: update quickcheck tests, use constant for Error --- lnwire/error.go | 3 +-- lnwire/lnwire_test.go | 26 ++++++++++++++++++++------ 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/lnwire/error.go b/lnwire/error.go index c9fa39a8a45..4f23d1ef8e6 100644 --- a/lnwire/error.go +++ b/lnwire/error.go @@ -123,8 +123,7 @@ func (c *Error) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (c *Error) MaxPayloadLength(uint32) uint32 { - // 32 + 2 + 65501 - return 65535 + return MaxMsgBody } // isASCII is a helper method that checks whether all bytes in `data` would be diff --git a/lnwire/lnwire_test.go b/lnwire/lnwire_test.go index ea90f3c0cfa..6f19414960d 100644 --- a/lnwire/lnwire_test.go +++ b/lnwire/lnwire_test.go @@ -321,6 +321,7 @@ func TestLightningWireProtocol(t *testing.T) { CsvDelay: uint16(r.Int31()), MaxAcceptedHTLCs: uint16(r.Int31()), ChannelFlags: FundingFlag(uint8(r.Int31())), + ExtraData: make([]byte, 0), } if _, err := r.Read(req.ChainHash[:]); err != nil { @@ -387,6 +388,7 @@ func TestLightningWireProtocol(t *testing.T) { HtlcMinimum: MilliSatoshi(r.Int31()), CsvDelay: uint16(r.Int31()), MaxAcceptedHTLCs: uint16(r.Int31()), + ExtraData: make([]byte, 0), } if _, err := r.Read(req.PendingChannelID[:]); err != nil { @@ -440,7 +442,9 @@ func TestLightningWireProtocol(t *testing.T) { v[0] = reflect.ValueOf(req) }, MsgFundingCreated: func(v []reflect.Value, r *rand.Rand) { - req := FundingCreated{} + req := FundingCreated{ + ExtraData: make([]byte, 0), + } if _, err := r.Read(req.PendingChannelID[:]); err != nil { t.Fatalf("unable to generate pending chan id: %v", err) @@ -471,7 +475,8 @@ func TestLightningWireProtocol(t *testing.T) { } req := FundingSigned{ - ChanID: ChannelID(c), + ChanID: ChannelID(c), + ExtraData: make([]byte, 0), } req.CommitSig, err = NewSigFromSignature(testSig) if err != nil { @@ -502,6 +507,7 @@ func TestLightningWireProtocol(t *testing.T) { MsgClosingSigned: func(v []reflect.Value, r *rand.Rand) { req := ClosingSigned{ FeeSatoshis: btcutil.Amount(r.Int63()), + ExtraData: make([]byte, 0), } var err error req.Signature, err = NewSigFromSignature(testSig) @@ -570,8 +576,9 @@ func TestLightningWireProtocol(t *testing.T) { MsgChannelAnnouncement: func(v []reflect.Value, r *rand.Rand) { var err error req := ChannelAnnouncement{ - ShortChannelID: NewShortChanIDFromInt(uint64(r.Int63())), - Features: randRawFeatureVector(r), + ShortChannelID: NewShortChanIDFromInt(uint64(r.Int63())), + Features: randRawFeatureVector(r), + ExtraOpaqueData: make([]byte, 0), } req.NodeSig1, err = NewSigFromSignature(testSig) if err != nil { @@ -643,6 +650,7 @@ func TestLightningWireProtocol(t *testing.T) { G: uint8(r.Int31()), B: uint8(r.Int31()), }, + ExtraOpaqueData: make([]byte, 0), } req.Signature, err = NewSigFromSignature(testSig) if err != nil { @@ -698,6 +706,7 @@ func TestLightningWireProtocol(t *testing.T) { HtlcMaximumMsat: maxHtlc, BaseFee: uint32(r.Int31()), FeeRate: uint32(r.Int31()), + ExtraOpaqueData: make([]byte, 0), } req.Signature, err = NewSigFromSignature(testSig) if err != nil { @@ -726,7 +735,8 @@ func TestLightningWireProtocol(t *testing.T) { MsgAnnounceSignatures: func(v []reflect.Value, r *rand.Rand) { var err error req := AnnounceSignatures{ - ShortChannelID: NewShortChanIDFromInt(uint64(r.Int63())), + ShortChannelID: NewShortChanIDFromInt(uint64(r.Int63())), + ExtraOpaqueData: make([]byte, 0), } req.NodeSignature, err = NewSigFromSignature(testSig) @@ -763,6 +773,7 @@ func TestLightningWireProtocol(t *testing.T) { req := ChannelReestablish{ NextLocalCommitHeight: uint64(r.Int63()), RemoteCommitTailHeight: uint64(r.Int63()), + ExtraData: make([]byte, 0), } // With a 50/50 probability, we'll include the @@ -785,7 +796,9 @@ func TestLightningWireProtocol(t *testing.T) { v[0] = reflect.ValueOf(req) }, MsgQueryShortChanIDs: func(v []reflect.Value, r *rand.Rand) { - req := QueryShortChanIDs{} + req := QueryShortChanIDs{ + ExtraData: make([]byte, 0), + } // With a 50/50 change, we'll either use zlib encoding, // or regular encoding. @@ -812,6 +825,7 @@ func TestLightningWireProtocol(t *testing.T) { req := ReplyChannelRange{ FirstBlockHeight: uint32(r.Int31()), NumBlocks: uint32(r.Int31()), + ExtraData: make([]byte, 0), } if _, err := rand.Read(req.ChainHash[:]); err != nil { From 2bcdbb0be1480393485c613623182beff37b9f05 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Wed, 11 Mar 2020 18:30:41 -0700 Subject: [PATCH 39/43] lnwire: convert the delivery addr in [Open+Accept]Channel to a TLV type In this commit, we convert the delivery address in the open and accept channel methods to be a TLV type. This works as an "empty" delivery address is encoded using a two zero bytes (uint16 length zero), and a tlv type of 0 is encoded in the same manner (byte for type, byte for zero length). This change allows us to easily extend these messages in the future, in a uniform manner. --- fundingmanager.go | 42 ++++++++++--------- lnwire/accept_channel.go | 26 +++--------- lnwire/accept_channel_test.go | 2 +- lnwire/lnwire.go | 6 +++ lnwire/lnwire_test.go | 4 +- lnwire/open_channel.go | 19 +++------ lnwire/typed_delivery_addr.go | 65 ++++++++++++++++++++++++++++++ lnwire/typed_delivery_addr_test.go | 31 ++++++++++++++ 8 files changed, 140 insertions(+), 55 deletions(-) create mode 100644 lnwire/typed_delivery_addr.go create mode 100644 lnwire/typed_delivery_addr_test.go diff --git a/fundingmanager.go b/fundingmanager.go index 43e3c29d3db..49e6edc5c4a 100644 --- a/fundingmanager.go +++ b/fundingmanager.go @@ -1437,7 +1437,9 @@ func (f *fundingManager) handleFundingOpen(peer lnpeer.Peer, PubKey: copyPubKey(msg.HtlcPoint), }, }, - UpfrontShutdown: msg.UpfrontShutdownScript, + UpfrontShutdown: lnwire.DeliveryAddress( + msg.UpfrontShutdownScript, + ), } err = reservation.ProcessSingleContribution(remoteContribution) if err != nil { @@ -1455,21 +1457,23 @@ func (f *fundingManager) handleFundingOpen(peer lnpeer.Peer, // contribution in the next message of the workflow. ourContribution := reservation.OurContribution() fundingAccept := lnwire.AcceptChannel{ - PendingChannelID: msg.PendingChannelID, - DustLimit: ourContribution.DustLimit, - MaxValueInFlight: remoteMaxValue, - ChannelReserve: chanReserve, - MinAcceptDepth: uint32(numConfsReq), - HtlcMinimum: minHtlc, - CsvDelay: remoteCsvDelay, - MaxAcceptedHTLCs: maxHtlcs, - FundingKey: ourContribution.MultiSigKey.PubKey, - RevocationPoint: ourContribution.RevocationBasePoint.PubKey, - PaymentPoint: ourContribution.PaymentBasePoint.PubKey, - DelayedPaymentPoint: ourContribution.DelayBasePoint.PubKey, - HtlcPoint: ourContribution.HtlcBasePoint.PubKey, - FirstCommitmentPoint: ourContribution.FirstCommitmentPoint, - UpfrontShutdownScript: ourContribution.UpfrontShutdown, + PendingChannelID: msg.PendingChannelID, + DustLimit: ourContribution.DustLimit, + MaxValueInFlight: remoteMaxValue, + ChannelReserve: chanReserve, + MinAcceptDepth: uint32(numConfsReq), + HtlcMinimum: minHtlc, + CsvDelay: remoteCsvDelay, + MaxAcceptedHTLCs: maxHtlcs, + FundingKey: ourContribution.MultiSigKey.PubKey, + RevocationPoint: ourContribution.RevocationBasePoint.PubKey, + PaymentPoint: ourContribution.PaymentBasePoint.PubKey, + DelayedPaymentPoint: ourContribution.DelayBasePoint.PubKey, + HtlcPoint: ourContribution.HtlcBasePoint.PubKey, + FirstCommitmentPoint: ourContribution.FirstCommitmentPoint, + UpfrontShutdownScript: lnwire.TypedDeliveryAddress( + ourContribution.UpfrontShutdown, + ), } if err := peer.SendMessage(true, &fundingAccept); err != nil { @@ -1568,7 +1572,9 @@ func (f *fundingManager) handleFundingAccept(peer lnpeer.Peer, PubKey: copyPubKey(msg.HtlcPoint), }, }, - UpfrontShutdown: msg.UpfrontShutdownScript, + UpfrontShutdown: lnwire.DeliveryAddress( + msg.UpfrontShutdownScript, + ), } err = resCtx.reservation.ProcessContribution(remoteContribution) @@ -3255,7 +3261,7 @@ func (f *fundingManager) handleInitFundingMsg(msg *initFundingMsg) { DelayedPaymentPoint: ourContribution.DelayBasePoint.PubKey, FirstCommitmentPoint: ourContribution.FirstCommitmentPoint, ChannelFlags: channelFlags, - UpfrontShutdownScript: shutdown, + UpfrontShutdownScript: lnwire.TypedDeliveryAddress(shutdown), } if err := msg.peer.SendMessage(true, &fundingOpen); err != nil { e := fmt.Errorf("unable to send funding request message: %v", diff --git a/lnwire/accept_channel.go b/lnwire/accept_channel.go index dee2c5ef46e..0b74be88e40 100644 --- a/lnwire/accept_channel.go +++ b/lnwire/accept_channel.go @@ -91,7 +91,7 @@ type AcceptChannel struct { // be paid when mutually closing the channel. This field is optional, and // and has a length prefix, so a zero will be written if it is not set // and its length followed by the script will be written if it is set. - UpfrontShutdownScript DeliveryAddress + UpfrontShutdownScript TypedDeliveryAddress // ExtraData is the set of data that was appended to this message to // fill out the full maximum transport message size. These fields can @@ -109,7 +109,7 @@ var _ Message = (*AcceptChannel)(nil) // // This is part of the lnwire.Message interface. func (a *AcceptChannel) Encode(w io.Writer, pver uint32) error { - err := WriteElements(w, + return WriteElements(w, a.PendingChannelID[:], a.DustLimit, a.MaxValueInFlight, @@ -125,12 +125,8 @@ func (a *AcceptChannel) Encode(w io.Writer, pver uint32) error { a.HtlcPoint, a.FirstCommitmentPoint, a.UpfrontShutdownScript, + a.ExtraData, ) - if err != nil { - return err - } - - return a.ExtraData.Encode(w) } // Decode deserializes the serialized AcceptChannel stored in the passed @@ -140,7 +136,7 @@ func (a *AcceptChannel) Encode(w io.Writer, pver uint32) error { // This is part of the lnwire.Message interface. func (a *AcceptChannel) Decode(r io.Reader, pver uint32) error { // Read all the mandatory fields in the accept message. - err := ReadElements(r, + return ReadElements(r, a.PendingChannelID[:], &a.DustLimit, &a.MaxValueInFlight, @@ -155,19 +151,9 @@ func (a *AcceptChannel) Decode(r io.Reader, pver uint32) error { &a.DelayedPaymentPoint, &a.HtlcPoint, &a.FirstCommitmentPoint, + &a.UpfrontShutdownScript, + &a.ExtraData, ) - if err != nil { - return err - } - - // Check for the optional upfront shutdown script field. If it is not there, - // silence the EOF error. - err = ReadElement(r, &a.UpfrontShutdownScript) - if err != nil && err != io.EOF { - return err - } - - return a.ExtraData.Decode(r) } // MsgType returns the MessageType code which uniquely identifies this message diff --git a/lnwire/accept_channel_test.go b/lnwire/accept_channel_test.go index a1ab2be48c4..18ad1639ed2 100644 --- a/lnwire/accept_channel_test.go +++ b/lnwire/accept_channel_test.go @@ -12,7 +12,7 @@ import ( func TestDecodeAcceptChannel(t *testing.T) { tests := []struct { name string - shutdownScript DeliveryAddress + shutdownScript TypedDeliveryAddress }{ { name: "no upfront shutdown script", diff --git a/lnwire/lnwire.go b/lnwire/lnwire.go index c180cad3883..b6e70501031 100644 --- a/lnwire/lnwire.go +++ b/lnwire/lnwire.go @@ -429,6 +429,9 @@ func WriteElement(w io.Writer, element interface{}) error { case ExtraOpaqueData: return e.Encode(w) + case TypedDeliveryAddress: + return e.Encode(w) + default: return fmt.Errorf("unknown type in WriteElement: %T", e) } @@ -839,6 +842,9 @@ func ReadElement(r io.Reader, element interface{}) error { case *ExtraOpaqueData: return e.Decode(r) + case *TypedDeliveryAddress: + return e.Decode(r) + default: return fmt.Errorf("unknown type in ReadElement: %T", e) } diff --git a/lnwire/lnwire_test.go b/lnwire/lnwire_test.go index 6f19414960d..a44726a7cd7 100644 --- a/lnwire/lnwire_test.go +++ b/lnwire/lnwire_test.go @@ -67,10 +67,10 @@ func randRawKey() ([33]byte, error) { return n, nil } -func randDeliveryAddress(r *rand.Rand) (DeliveryAddress, error) { +func randDeliveryAddress(r *rand.Rand) (TypedDeliveryAddress, error) { // Generate size minimum one. Empty scripts should be tested specifically. size := r.Intn(deliveryAddressMaxSize) + 1 - da := DeliveryAddress(make([]byte, size)) + da := TypedDeliveryAddress(make([]byte, size)) _, err := r.Read(da) return da, err diff --git a/lnwire/open_channel.go b/lnwire/open_channel.go index 06ac0e31e39..25ea42668cc 100644 --- a/lnwire/open_channel.go +++ b/lnwire/open_channel.go @@ -127,7 +127,7 @@ type OpenChannel struct { // be paid when mutually closing the channel. This field is optional, and // and has a length prefix, so a zero will be written if it is not set // and its length followed by the script will be written if it is set. - UpfrontShutdownScript DeliveryAddress + UpfrontShutdownScript TypedDeliveryAddress // ExtraData is the set of data that was appended to this message to // fill out the full maximum transport message size. These fields can @@ -175,7 +175,7 @@ func (o *OpenChannel) Encode(w io.Writer, pver uint32) error { // // This is part of the lnwire.Message interface. func (o *OpenChannel) Decode(r io.Reader, pver uint32) error { - if err := ReadElements(r, + return ReadElements(r, o.ChainHash[:], o.PendingChannelID[:], &o.FundingAmount, @@ -194,18 +194,9 @@ func (o *OpenChannel) Decode(r io.Reader, pver uint32) error { &o.HtlcPoint, &o.FirstCommitmentPoint, &o.ChannelFlags, - ); err != nil { - return err - } - - // Check for the optional upfront shutdown script field. If it is not there, - // silence the EOF error. - err := ReadElement(r, &o.UpfrontShutdownScript) - if err != nil && err != io.EOF { - return err - } - - return ReadElement(r, &o.ExtraData) + &o.UpfrontShutdownScript, + &o.ExtraData, + ) } // MsgType returns the MessageType code which uniquely identifies this message diff --git a/lnwire/typed_delivery_addr.go b/lnwire/typed_delivery_addr.go new file mode 100644 index 00000000000..fd546afd096 --- /dev/null +++ b/lnwire/typed_delivery_addr.go @@ -0,0 +1,65 @@ +package lnwire + +import ( + "io" + + "github.com/lightningnetwork/lnd/tlv" +) + +const ( + // DeliveryAddrType is the TLV record type for delivery addreses within + // the name space of the OpenChannel and AcceptChannel messages. + DeliveryAddrType = 0 +) + +// TypedDeliveryAddress is similar to the DeliveryAddrType type, but it's +// encoded using a mini TLV stream. This tyupe was intorudced in order to allow +// the OpenChannel/AcceptChannel messages to properly be extended with TLV types. +type TypedDeliveryAddress []byte + +// Encode encodes the target TypedDeliveryAddress into the target io.Writer +// using a TLV stream. +func (t *TypedDeliveryAddress) Encode(w io.Writer) error { + addrBytes := []byte((*t)[:]) + + records := []tlv.Record{ + tlv.MakeDynamicRecord( + DeliveryAddrType, &addrBytes, + func() uint64 { + return uint64(len(addrBytes)) + }, + tlv.EVarBytes, tlv.DVarBytes, + ), + } + tlvStream, err := tlv.NewStream(records...) + if err != nil { + return err + } + + return tlvStream.Encode(w) +} + +// Decode decodes a set of bytes from the targer io.Reader into the target +// TypedDeliveryAddress. +func (t *TypedDeliveryAddress) Decode(r io.Reader) error { + addrBytes := []byte((*t)[:]) + + records := []tlv.Record{ + tlv.MakeDynamicRecord( + DeliveryAddrType, &addrBytes, nil, + tlv.EVarBytes, tlv.DVarBytes, + ), + } + + tlvStream, err := tlv.NewStream(records...) + if err != nil { + return err + } + if err := tlvStream.Decode(r); err != nil { + return err + } + + *t = addrBytes + + return nil +} diff --git a/lnwire/typed_delivery_addr_test.go b/lnwire/typed_delivery_addr_test.go new file mode 100644 index 00000000000..c84915a6beb --- /dev/null +++ b/lnwire/typed_delivery_addr_test.go @@ -0,0 +1,31 @@ +package lnwire + +import ( + "bytes" + "testing" +) + +// TestTypedDeliveryAddressEncodeDecode tests that we're able to properly +// encode and decode typed delivery addresses. +func TestTypedDeliveryAddressEncodeDecode(t *testing.T) { + t.Parallel() + + addr := TypedDeliveryAddress( + bytes.Repeat([]byte("a"), deliveryAddressMaxSize), + ) + + var b bytes.Buffer + if err := addr.Encode(&b); err != nil { + t.Fatalf("unable to encode addr: %v", err) + } + + var addr2 TypedDeliveryAddress + if err := addr2.Decode(&b); err != nil { + t.Fatalf("unable to decode addr: %v", err) + } + + if !bytes.Equal(addr, addr2) { + t.Fatalf("addr mismatch: expected %x, got %x", addr[:], + addr2[:]) + } +} From 66df61718697c6499ed9bee997a828de63c1b979 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Wed, 25 Mar 2020 17:44:07 -0700 Subject: [PATCH 40/43] multi: update unit tests to pass deep equal assertions with messages In this commit, we update a series of unit tests in the code base to now pass due to the new wire message encode/decode logic. In many instances, we'll now manually set the extra bytes to an empty byte slice to avoid comparisons that fail due to one message having an empty byte slice and the other having a nil pointer. --- channeldb/channel_test.go | 22 ++++++++++++------- .../migration_01_to_11/migrations_test.go | 5 ++++- channeldb/waitingproof_test.go | 5 ++++- discovery/message_store_test.go | 6 +++-- htlcswitch/payment_result_test.go | 15 ++++++++----- lnwallet/channel_test.go | 2 ++ 6 files changed, 37 insertions(+), 18 deletions(-) diff --git a/channeldb/channel_test.go b/channeldb/channel_test.go index 656a885bf55..6f319003dc7 100644 --- a/channeldb/channel_test.go +++ b/channeldb/channel_test.go @@ -639,7 +639,8 @@ func TestChannelStateTransition(t *testing.T) { { LogIndex: 2, UpdateMsg: &lnwire.UpdateAddHTLC{ - ChanID: lnwire.ChannelID{1, 2, 3}, + ChanID: lnwire.ChannelID{1, 2, 3}, + ExtraData: make([]byte, 0), }, }, } @@ -660,7 +661,9 @@ func TestChannelStateTransition(t *testing.T) { if !reflect.DeepEqual( dbUnsignedAckedUpdates[0], unsignedAckedUpdates[0], ) { - t.Fatalf("unexpected update") + t.Fatalf("unexpected update: expected %v, got %v", + spew.Sdump(unsignedAckedUpdates[0]), + spew.Sdump(dbUnsignedAckedUpdates)) } // The balances, new update, the HTLCs and the changes to the fake @@ -702,22 +705,25 @@ func TestChannelStateTransition(t *testing.T) { wireSig, wireSig, }, + ExtraData: make([]byte, 0), }, LogUpdates: []LogUpdate{ { LogIndex: 1, UpdateMsg: &lnwire.UpdateAddHTLC{ - ID: 1, - Amount: lnwire.NewMSatFromSatoshis(100), - Expiry: 25, + ID: 1, + Amount: lnwire.NewMSatFromSatoshis(100), + Expiry: 25, + ExtraData: make([]byte, 0), }, }, { LogIndex: 2, UpdateMsg: &lnwire.UpdateAddHTLC{ - ID: 2, - Amount: lnwire.NewMSatFromSatoshis(200), - Expiry: 50, + ID: 2, + Amount: lnwire.NewMSatFromSatoshis(200), + Expiry: 50, + ExtraData: make([]byte, 0), }, }, }, diff --git a/channeldb/migration_01_to_11/migrations_test.go b/channeldb/migration_01_to_11/migrations_test.go index 6cd855e85dd..7fc90855da3 100644 --- a/channeldb/migration_01_to_11/migrations_test.go +++ b/channeldb/migration_01_to_11/migrations_test.go @@ -464,7 +464,10 @@ func TestMigrateGossipMessageStoreKeys(t *testing.T) { // Construct the message which we'll use to test the migration, along // with its old and new key formats. shortChanID := lnwire.ShortChannelID{BlockHeight: 10} - msg := &lnwire.AnnounceSignatures{ShortChannelID: shortChanID} + msg := &lnwire.AnnounceSignatures{ + ShortChannelID: shortChanID, + ExtraOpaqueData: make([]byte, 0), + } var oldMsgKey [33 + 8]byte copy(oldMsgKey[:33], pubKey.SerializeCompressed()) diff --git a/channeldb/waitingproof_test.go b/channeldb/waitingproof_test.go index 12679b69f4a..24207da813e 100644 --- a/channeldb/waitingproof_test.go +++ b/channeldb/waitingproof_test.go @@ -5,6 +5,7 @@ import ( "reflect" + "github.com/davecgh/go-spew/spew" "github.com/go-errors/errors" "github.com/lightningnetwork/lnd/lnwire" ) @@ -23,6 +24,7 @@ func TestWaitingProofStore(t *testing.T) { proof1 := NewWaitingProof(true, &lnwire.AnnounceSignatures{ NodeSignature: wireSig, BitcoinSignature: wireSig, + ExtraOpaqueData: make([]byte, 0), }) store, err := NewWaitingProofStore(db) @@ -40,7 +42,8 @@ func TestWaitingProofStore(t *testing.T) { t.Fatalf("unable retrieve proof from storage: %v", err) } if !reflect.DeepEqual(proof1, proof2) { - t.Fatal("wrong proof retrieved") + t.Fatalf("wrong proof retrieved: expected %v, got %v", + spew.Sdump(proof1), spew.Sdump(proof2)) } if _, err := store.Get(proof1.OppositeKey()); err != ErrWaitingProofNotFound { diff --git a/discovery/message_store_test.go b/discovery/message_store_test.go index fc7ba3360e6..d62dd16050c 100644 --- a/discovery/message_store_test.go +++ b/discovery/message_store_test.go @@ -64,13 +64,15 @@ func randCompressedPubKey(t *testing.T) [33]byte { func randAnnounceSignatures() *lnwire.AnnounceSignatures { return &lnwire.AnnounceSignatures{ - ShortChannelID: lnwire.NewShortChanIDFromInt(rand.Uint64()), + ShortChannelID: lnwire.NewShortChanIDFromInt(rand.Uint64()), + ExtraOpaqueData: make([]byte, 0), } } func randChannelUpdate() *lnwire.ChannelUpdate { return &lnwire.ChannelUpdate{ - ShortChannelID: lnwire.NewShortChanIDFromInt(rand.Uint64()), + ShortChannelID: lnwire.NewShortChanIDFromInt(rand.Uint64()), + ExtraOpaqueData: make([]byte, 0), } } diff --git a/htlcswitch/payment_result_test.go b/htlcswitch/payment_result_test.go index 04ff57d8f72..aa7cbc173ec 100644 --- a/htlcswitch/payment_result_test.go +++ b/htlcswitch/payment_result_test.go @@ -39,18 +39,21 @@ func TestNetworkResultSerialization(t *testing.T) { ChanID: chanID, ID: 2, PaymentPreimage: preimage, + ExtraData: make([]byte, 0), } fail := &lnwire.UpdateFailHTLC{ - ChanID: chanID, - ID: 1, - Reason: []byte{}, + ChanID: chanID, + ID: 1, + Reason: []byte{}, + ExtraData: make([]byte, 0), } fail2 := &lnwire.UpdateFailHTLC{ - ChanID: chanID, - ID: 1, - Reason: reason[:], + ChanID: chanID, + ID: 1, + Reason: reason[:], + ExtraData: make([]byte, 0), } testCases := []*networkResult{ diff --git a/lnwallet/channel_test.go b/lnwallet/channel_test.go index 190e31c1255..038d4f16da6 100644 --- a/lnwallet/channel_test.go +++ b/lnwallet/channel_test.go @@ -3176,6 +3176,7 @@ func TestChanSyncOweCommitment(t *testing.T) { Amount: htlcAmt, Expiry: uint32(10), OnionBlob: fakeOnionBlob, + ExtraData: make([]byte, 0), } htlcIndex, err := bobChannel.AddHTLC(h, nil) @@ -3220,6 +3221,7 @@ func TestChanSyncOweCommitment(t *testing.T) { Amount: htlcAmt, Expiry: uint32(10), OnionBlob: fakeOnionBlob, + ExtraData: make([]byte, 0), } aliceHtlcIndex, err := aliceChannel.AddHTLC(aliceHtlc, nil) if err != nil { From d76bcc3abfd316deb5b07cf6317446fa5436aee5 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Fri, 24 Jul 2020 18:38:10 -0700 Subject: [PATCH 41/43] lnwire: propagate protocol version through entire package, add new version for TLV In this commit, we update all the existing methods for encoding and decoding in the package to also observe the currently unused `pver` (protocol version) argument. With this change, we can enforce the new "TLV everywhere" rules for all the new code, while leaving the existing migrations untouched (using the legacy version). Otherwise, a legacy migration would possibly read too many bytes (as the ExtraOpaqueBytes reads until EOF), causing it to corrupt existing data on disk. The Encode/Decode methods of the ExtraOpaqueData struct will now observe the new `pver` value, and return early (not reading/writing) the extra bytes based on this value. --- lnwire/accept_channel.go | 2 + lnwire/announcement_signatures.go | 2 + lnwire/channel_announcement.go | 6 +++ lnwire/channel_reestablish.go | 9 ++-- lnwire/channel_update.go | 23 ++++++---- lnwire/closing_signed.go | 5 ++- lnwire/commit_sig.go | 2 + lnwire/error.go | 2 + lnwire/extra_bytes.go | 18 ++++++-- lnwire/extra_bytes_test.go | 57 +++++++++++++++++++++++- lnwire/funding_created.go | 4 +- lnwire/funding_locked.go | 2 + lnwire/funding_signed.go | 4 +- lnwire/gossip_timestamp_range.go | 2 + lnwire/init_message.go | 2 + lnwire/lnwire.go | 41 ++++++++++-------- lnwire/lnwire_test.go | 7 +-- lnwire/message.go | 14 ++++++ lnwire/node_announcement.go | 5 +++ lnwire/onion_error.go | 65 ++++++++++++++-------------- lnwire/onion_error_test.go | 17 ++++---- lnwire/open_channel.go | 2 + lnwire/ping.go | 2 + lnwire/pong.go | 4 +- lnwire/query_channel_range.go | 2 + lnwire/query_short_chan_ids.go | 44 +++++++++++-------- lnwire/reply_channel_range.go | 12 +++-- lnwire/reply_channel_range_test.go | 6 ++- lnwire/reply_short_chan_ids_end.go | 2 + lnwire/revoke_and_ack.go | 2 + lnwire/shutdown.go | 4 +- lnwire/update_add_htlc.go | 2 + lnwire/update_fail_htlc.go | 2 + lnwire/update_fail_malformed_htlc.go | 2 + lnwire/update_fee.go | 2 + lnwire/update_fulfill_htlc.go | 2 + 36 files changed, 267 insertions(+), 112 deletions(-) diff --git a/lnwire/accept_channel.go b/lnwire/accept_channel.go index 0b74be88e40..5cbd8b2fe65 100644 --- a/lnwire/accept_channel.go +++ b/lnwire/accept_channel.go @@ -110,6 +110,7 @@ var _ Message = (*AcceptChannel)(nil) // This is part of the lnwire.Message interface. func (a *AcceptChannel) Encode(w io.Writer, pver uint32) error { return WriteElements(w, + pver, a.PendingChannelID[:], a.DustLimit, a.MaxValueInFlight, @@ -137,6 +138,7 @@ func (a *AcceptChannel) Encode(w io.Writer, pver uint32) error { func (a *AcceptChannel) Decode(r io.Reader, pver uint32) error { // Read all the mandatory fields in the accept message. return ReadElements(r, + pver, a.PendingChannelID[:], &a.DustLimit, &a.MaxValueInFlight, diff --git a/lnwire/announcement_signatures.go b/lnwire/announcement_signatures.go index 9e0d99089b3..cb9fe990b73 100644 --- a/lnwire/announcement_signatures.go +++ b/lnwire/announcement_signatures.go @@ -52,6 +52,7 @@ var _ Message = (*AnnounceSignatures)(nil) // This is part of the lnwire.Message interface. func (a *AnnounceSignatures) Decode(r io.Reader, pver uint32) error { return ReadElements(r, + pver, &a.ChannelID, &a.ShortChannelID, &a.NodeSignature, @@ -66,6 +67,7 @@ func (a *AnnounceSignatures) Decode(r io.Reader, pver uint32) error { // This is part of the lnwire.Message interface. func (a *AnnounceSignatures) Encode(w io.Writer, pver uint32) error { return WriteElements(w, + pver, a.ChannelID, a.ShortChannelID, a.NodeSignature, diff --git a/lnwire/channel_announcement.go b/lnwire/channel_announcement.go index de4b72b3155..92ccb520441 100644 --- a/lnwire/channel_announcement.go +++ b/lnwire/channel_announcement.go @@ -68,6 +68,7 @@ var _ Message = (*ChannelAnnouncement)(nil) // This is part of the lnwire.Message interface. func (a *ChannelAnnouncement) Decode(r io.Reader, pver uint32) error { return ReadElements(r, + pver, &a.NodeSig1, &a.NodeSig2, &a.BitcoinSig1, @@ -89,6 +90,7 @@ func (a *ChannelAnnouncement) Decode(r io.Reader, pver uint32) error { // This is part of the lnwire.Message interface. func (a *ChannelAnnouncement) Encode(w io.Writer, pver uint32) error { return WriteElements(w, + pver, a.NodeSig1, a.NodeSig2, a.BitcoinSig1, @@ -126,6 +128,10 @@ func (a *ChannelAnnouncement) DataToSign() ([]byte, error) { // We should not include the signatures itself. var w bytes.Buffer err := WriteElements(&w, + // We always use the modern protocol version here as we always + // need to include any optional data in the signature digest + // for forwards compatibility. + ProtocolVersionTLV, a.Features, a.ChainHash[:], a.ShortChannelID, diff --git a/lnwire/channel_reestablish.go b/lnwire/channel_reestablish.go index 9a689ad4302..aec99d491ec 100644 --- a/lnwire/channel_reestablish.go +++ b/lnwire/channel_reestablish.go @@ -77,6 +77,7 @@ var _ Message = (*ChannelReestablish)(nil) // This is part of the lnwire.Message interface. func (a *ChannelReestablish) Encode(w io.Writer, pver uint32) error { err := WriteElements(w, + pver, a.ChanID, a.NextLocalCommitHeight, a.RemoteCommitTailHeight, @@ -93,11 +94,12 @@ func (a *ChannelReestablish) Encode(w io.Writer, pver uint32) error { // // NOTE: This is here primarily for the quickcheck tests, in // practice, we'll always populate this field. - return WriteElements(w, a.ExtraData) + return WriteElements(w, pver, a.ExtraData) } // Otherwise, we'll write out the remaining elements. return WriteElements(w, + pver, a.LastRemoteCommitSecret[:], a.LocalUnrevokedCommitPoint, a.ExtraData, @@ -110,6 +112,7 @@ func (a *ChannelReestablish) Encode(w io.Writer, pver uint32) error { // This is part of the lnwire.Message interface. func (a *ChannelReestablish) Decode(r io.Reader, pver uint32) error { err := ReadElements(r, + pver, &a.ChanID, &a.NextLocalCommitHeight, &a.RemoteCommitTailHeight, @@ -145,11 +148,11 @@ func (a *ChannelReestablish) Decode(r io.Reader, pver uint32) error { // We'll conclude by parsing out the commitment point. We don't check // the error in this case, as it they included the commit secret, then // they MUST also include the commit point. - if err = ReadElement(r, &a.LocalUnrevokedCommitPoint); err != nil { + if err = ReadElement(r, pver, &a.LocalUnrevokedCommitPoint); err != nil { return err } - return a.ExtraData.Decode(r) + return a.ExtraData.Decode(r, pver) } // MsgType returns the integer uniquely identifying this message type on the diff --git a/lnwire/channel_update.go b/lnwire/channel_update.go index e44cde7a62b..8503959b0f2 100644 --- a/lnwire/channel_update.go +++ b/lnwire/channel_update.go @@ -125,6 +125,7 @@ var _ Message = (*ChannelUpdate)(nil) // This is part of the lnwire.Message interface. func (a *ChannelUpdate) Decode(r io.Reader, pver uint32) error { err := ReadElements(r, + pver, &a.Signature, a.ChainHash[:], &a.ShortChannelID, @@ -142,12 +143,12 @@ func (a *ChannelUpdate) Decode(r io.Reader, pver uint32) error { // Now check whether the max HTLC field is present and read it if so. if a.MessageFlags.HasMaxHtlc() { - if err := ReadElements(r, &a.HtlcMaximumMsat); err != nil { + if err := ReadElements(r, pver, &a.HtlcMaximumMsat); err != nil { return err } } - return a.ExtraOpaqueData.Decode(r) + return a.ExtraOpaqueData.Decode(r, pver) } // Encode serializes the target ChannelUpdate into the passed io.Writer @@ -156,6 +157,7 @@ func (a *ChannelUpdate) Decode(r io.Reader, pver uint32) error { // This is part of the lnwire.Message interface. func (a *ChannelUpdate) Encode(w io.Writer, pver uint32) error { err := WriteElements(w, + pver, a.Signature, a.ChainHash[:], a.ShortChannelID, @@ -174,13 +176,13 @@ func (a *ChannelUpdate) Encode(w io.Writer, pver uint32) error { // Now append optional fields if they are set. Currently, the only // optional field is max HTLC. if a.MessageFlags.HasMaxHtlc() { - if err := WriteElements(w, a.HtlcMaximumMsat); err != nil { + if err := WriteElements(w, pver, a.HtlcMaximumMsat); err != nil { return err } } // Finally, append any extra opaque data. - return a.ExtraOpaqueData.Encode(w) + return a.ExtraOpaqueData.Encode(w, pver) } // MsgType returns the integer uniquely identifying this message type on the @@ -202,10 +204,10 @@ func (a *ChannelUpdate) MaxPayloadLength(pver uint32) uint32 { // DataToSign is used to retrieve part of the announcement message which should // be signed. func (a *ChannelUpdate) DataToSign() ([]byte, error) { - // We should not include the signatures itself. var w bytes.Buffer err := WriteElements(&w, + ProtocolVersionTLV, a.ChainHash[:], a.ShortChannelID, a.Timestamp, @@ -223,13 +225,18 @@ func (a *ChannelUpdate) DataToSign() ([]byte, error) { // Now append optional fields if they are set. Currently, the only // optional field is max HTLC. if a.MessageFlags.HasMaxHtlc() { - if err := WriteElements(&w, a.HtlcMaximumMsat); err != nil { + err := WriteElements( + &w, ProtocolVersionTLV, a.HtlcMaximumMsat, + ) + if err != nil { return nil, err } } - // Finally, append any extra opaque data. - if err := a.ExtraOpaqueData.Encode(&w); err != nil { + // Finally, append any extra opaque data. We always pass in the modern + // protocol version here as we always need to include any extra bytes + // in the signature digest. + if err := a.ExtraOpaqueData.Encode(&w, ProtocolVersionTLV); err != nil { return nil, err } diff --git a/lnwire/closing_signed.go b/lnwire/closing_signed.go index 7732715bbf9..c669b8b8520 100644 --- a/lnwire/closing_signed.go +++ b/lnwire/closing_signed.go @@ -55,7 +55,8 @@ var _ Message = (*ClosingSigned)(nil) // This is part of the lnwire.Message interface. func (c *ClosingSigned) Decode(r io.Reader, pver uint32) error { return ReadElements( - r, &c.ChannelID, &c.FeeSatoshis, &c.Signature, &c.ExtraData, + r, pver, &c.ChannelID, &c.FeeSatoshis, &c.Signature, + &c.ExtraData, ) } @@ -65,7 +66,7 @@ func (c *ClosingSigned) Decode(r io.Reader, pver uint32) error { // This is part of the lnwire.Message interface. func (c *ClosingSigned) Encode(w io.Writer, pver uint32) error { return WriteElements( - w, c.ChannelID, c.FeeSatoshis, c.Signature, c.ExtraData, + w, pver, c.ChannelID, c.FeeSatoshis, c.Signature, c.ExtraData, ) } diff --git a/lnwire/commit_sig.go b/lnwire/commit_sig.go index 2ac71ddc99e..6b469af763e 100644 --- a/lnwire/commit_sig.go +++ b/lnwire/commit_sig.go @@ -58,6 +58,7 @@ var _ Message = (*CommitSig)(nil) // This is part of the lnwire.Message interface. func (c *CommitSig) Decode(r io.Reader, pver uint32) error { return ReadElements(r, + pver, &c.ChanID, &c.CommitSig, &c.HtlcSigs, @@ -71,6 +72,7 @@ func (c *CommitSig) Decode(r io.Reader, pver uint32) error { // This is part of the lnwire.Message interface. func (c *CommitSig) Encode(w io.Writer, pver uint32) error { return WriteElements(w, + pver, c.ChanID, c.CommitSig, c.HtlcSigs, diff --git a/lnwire/error.go b/lnwire/error.go index 4f23d1ef8e6..5ee9881f603 100644 --- a/lnwire/error.go +++ b/lnwire/error.go @@ -94,6 +94,7 @@ func (c *Error) Error() string { // This is part of the lnwire.Message interface. func (c *Error) Decode(r io.Reader, pver uint32) error { return ReadElements(r, + pver, &c.ChanID, &c.Data, ) @@ -105,6 +106,7 @@ func (c *Error) Decode(r io.Reader, pver uint32) error { // This is part of the lnwire.Message interface. func (c *Error) Encode(w io.Writer, pver uint32) error { return WriteElements(w, + pver, c.ChanID, c.Data, ) diff --git a/lnwire/extra_bytes.go b/lnwire/extra_bytes.go index a94a948ee2c..8d3b4267cbe 100644 --- a/lnwire/extra_bytes.go +++ b/lnwire/extra_bytes.go @@ -16,9 +16,15 @@ import ( type ExtraOpaqueData []byte // Encode attempts to encode the raw extra bytes into the passed io.Writer. -func (e *ExtraOpaqueData) Encode(w io.Writer) error { +func (e *ExtraOpaqueData) Encode(w io.Writer, pver uint32) error { + // Only write out the extra data if we're using the new modern protocol + // version. + if pver != ProtocolVersionTLV { + return nil + } + eBytes := []byte((*e)[:]) - if err := WriteElements(w, eBytes); err != nil { + if err := WriteElements(w, pver, eBytes); err != nil { return err } @@ -27,7 +33,13 @@ func (e *ExtraOpaqueData) Encode(w io.Writer) error { // Decode attempts to unpack the raw bytes encoded in the passed io.Reader as a // set of extra opaque data. -func (e *ExtraOpaqueData) Decode(r io.Reader) error { +func (e *ExtraOpaqueData) Decode(r io.Reader, pver uint32) error { + // Only if we're using the modern protocl version will we attempt to + // keep on decoding past the end of the "main message". + if pver != ProtocolVersionTLV { + return nil + } + // First, we'll attempt to read a set of bytes contained within the // passed io.Reader (if any exist). rawBytes, err := ioutil.ReadAll(r) diff --git a/lnwire/extra_bytes_test.go b/lnwire/extra_bytes_test.go index 55acfee61fe..3d1573d94c6 100644 --- a/lnwire/extra_bytes_test.go +++ b/lnwire/extra_bytes_test.go @@ -36,13 +36,13 @@ func TestExtraOpaqueDataEncodeDecode(t *testing.T) { copy(extraData[:], test.inputBytes) - if err := extraData.Encode(&b); err != nil { + if err := extraData.Encode(&b, ProtocolVersionTLV); err != nil { t.Fatalf("unable to encode extra data: %v", err) return false } var newBytes ExtraOpaqueData - if err := newBytes.Decode(&b); err != nil { + if err := newBytes.Decode(&b, ProtocolVersionTLV); err != nil { t.Fatalf("unable to decode extra bytes: %v", err) return false } @@ -145,3 +145,56 @@ func TestExtraOpaqueDataPackUnpackRecords(t *testing.T) { t.Fatalf("type2 not found in typeMap") } } + +// TestExtraOpaqueDataProtocolVersion tests that the encode/decode methods will +// observe the passed protocol version. +func TestExtraOpaqueDataProtocolVersion(t *testing.T) { + t.Parallel() + + extraData := ExtraOpaqueData([]byte("kek")) + + var b bytes.Buffer + if err := extraData.Encode(&b, ProtocolVersionLegacy); err != nil { + t.Fatalf("unable to encode: %v", err) + } + + // The statement above shouldn't have included the extra data since + // we're using the legacy protocol version. + if len(b.Bytes()) != 0 { + t.Fatalf("bytes were encoded using legacy "+ + "protocol version: %x", b.Bytes()) + } + + // If we encode using the proper version, then we should find the same + // data encoded on the other side. + if err := extraData.Encode(&b, ProtocolVersionTLV); err != nil { + t.Fatalf("unable to encode: %v", err) + } + if !bytes.Equal(b.Bytes(), extraData[:]) { + t.Fatalf("encoding mismatch: expected %x, got %x", + b.Bytes(), extraData[:]) + } + + // Now for the other direction, we'll attempt to decode into a fresh + // buffer, but using the legacy version. In the end, no bytes should be + // decoded. + var newExtraData ExtraOpaqueData + if err := newExtraData.Decode(&b, ProtocolVersionLegacy); err != nil { + t.Fatalf("unable to decode data: %v", err) + } + + if len(newExtraData[:]) != 0 { + t.Fatalf("expected not data to be decoded!") + } + + // Finally, if we decode using the proper protocol version, we should get + // the same bytes out that we put in. + if err := newExtraData.Decode(&b, ProtocolVersionTLV); err != nil { + t.Fatalf("unable to decode data: %v", err) + } + if !bytes.Equal(extraData[:], newExtraData[:]) { + t.Fatalf("encoding mismatch: expected %x, got %x", + extraData, newExtraData) + } + +} diff --git a/lnwire/funding_created.go b/lnwire/funding_created.go index 437b1b6a8c5..7eb7b2b2abd 100644 --- a/lnwire/funding_created.go +++ b/lnwire/funding_created.go @@ -42,7 +42,7 @@ var _ Message = (*FundingCreated)(nil) // This is part of the lnwire.Message interface. func (f *FundingCreated) Encode(w io.Writer, pver uint32) error { return WriteElements( - w, f.PendingChannelID[:], f.FundingPoint, f.CommitSig, + w, pver, f.PendingChannelID[:], f.FundingPoint, f.CommitSig, f.ExtraData, ) } @@ -54,7 +54,7 @@ func (f *FundingCreated) Encode(w io.Writer, pver uint32) error { // This is part of the lnwire.Message interface. func (f *FundingCreated) Decode(r io.Reader, pver uint32) error { return ReadElements( - r, f.PendingChannelID[:], &f.FundingPoint, &f.CommitSig, + r, pver, f.PendingChannelID[:], &f.FundingPoint, &f.CommitSig, &f.ExtraData, ) } diff --git a/lnwire/funding_locked.go b/lnwire/funding_locked.go index 1eeddfb6cc6..af857207ba2 100644 --- a/lnwire/funding_locked.go +++ b/lnwire/funding_locked.go @@ -47,6 +47,7 @@ var _ Message = (*FundingLocked)(nil) // This is part of the lnwire.Message interface. func (c *FundingLocked) Decode(r io.Reader, pver uint32) error { return ReadElements(r, + pver, &c.ChanID, &c.NextPerCommitmentPoint, &c.ExtraData, @@ -60,6 +61,7 @@ func (c *FundingLocked) Decode(r io.Reader, pver uint32) error { // This is part of the lnwire.Message interface. func (c *FundingLocked) Encode(w io.Writer, pver uint32) error { return WriteElements(w, + pver, c.ChanID, c.NextPerCommitmentPoint, c.ExtraData, diff --git a/lnwire/funding_signed.go b/lnwire/funding_signed.go index 1ef1556802e..844d53a5f61 100644 --- a/lnwire/funding_signed.go +++ b/lnwire/funding_signed.go @@ -30,7 +30,7 @@ var _ Message = (*FundingSigned)(nil) // // This is part of the lnwire.Message interface. func (f *FundingSigned) Encode(w io.Writer, pver uint32) error { - return WriteElements(w, f.ChanID, f.CommitSig, f.ExtraData) + return WriteElements(w, pver, f.ChanID, f.CommitSig, f.ExtraData) } // Decode deserializes the serialized FundingSigned stored in the passed @@ -39,7 +39,7 @@ func (f *FundingSigned) Encode(w io.Writer, pver uint32) error { // // This is part of the lnwire.Message interface. func (f *FundingSigned) Decode(r io.Reader, pver uint32) error { - return ReadElements(r, &f.ChanID, &f.CommitSig, &f.ExtraData) + return ReadElements(r, pver, &f.ChanID, &f.CommitSig, &f.ExtraData) } // MsgType returns the uint32 code which uniquely identifies this message as a diff --git a/lnwire/gossip_timestamp_range.go b/lnwire/gossip_timestamp_range.go index fb62e27213f..4d4c834c503 100644 --- a/lnwire/gossip_timestamp_range.go +++ b/lnwire/gossip_timestamp_range.go @@ -46,6 +46,7 @@ var _ Message = (*GossipTimestampRange)(nil) // This is part of the lnwire.Message interface. func (g *GossipTimestampRange) Decode(r io.Reader, pver uint32) error { return ReadElements(r, + pver, g.ChainHash[:], &g.FirstTimestamp, &g.TimestampRange, @@ -59,6 +60,7 @@ func (g *GossipTimestampRange) Decode(r io.Reader, pver uint32) error { // This is part of the lnwire.Message interface. func (g *GossipTimestampRange) Encode(w io.Writer, pver uint32) error { return WriteElements(w, + pver, g.ChainHash[:], g.FirstTimestamp, g.TimestampRange, diff --git a/lnwire/init_message.go b/lnwire/init_message.go index 18af1d7da9a..402e6bfabbc 100644 --- a/lnwire/init_message.go +++ b/lnwire/init_message.go @@ -46,6 +46,7 @@ var _ Message = (*Init)(nil) // This is part of the lnwire.Message interface. func (msg *Init) Decode(r io.Reader, pver uint32) error { return ReadElements(r, + pver, &msg.GlobalFeatures, &msg.Features, &msg.ExtraData, @@ -58,6 +59,7 @@ func (msg *Init) Decode(r io.Reader, pver uint32) error { // This is part of the lnwire.Message interface. func (msg *Init) Encode(w io.Writer, pver uint32) error { return WriteElements(w, + pver, msg.GlobalFeatures, msg.Features, msg.ExtraData, diff --git a/lnwire/lnwire.go b/lnwire/lnwire.go index b6e70501031..fdbea7a3cbe 100644 --- a/lnwire/lnwire.go +++ b/lnwire/lnwire.go @@ -77,11 +77,12 @@ func (a addressType) AddrLen() uint16 { // WriteElement is a one-stop shop to write the big endian representation of // any element which is to be serialized for the wire protocol. The passed // io.Writer should be backed by an appropriately sized byte slice, or be able -// to dynamically expand to accommodate additional data. +// to dynamically expand to accommodate additional data. The passed protocol +// version may affect how items are encoded. // // TODO(roasbeef): this should eventually draw from a buffer pool for // serialization. -func WriteElement(w io.Writer, element interface{}) error { +func WriteElement(w io.Writer, pver uint32, element interface{}) error { switch e := element.(type) { case NodeAlias: if _, err := w.Write(e[:]); err != nil { @@ -168,7 +169,7 @@ func WriteElement(w io.Writer, element interface{}) error { } for _, sig := range e { - if err := WriteElement(w, sig); err != nil { + if err := WriteElement(w, pver, sig); err != nil { return err } } @@ -269,7 +270,7 @@ func WriteElement(w io.Writer, element interface{}) error { return err } case FailCode: - if err := WriteElement(w, uint16(e)); err != nil { + if err := WriteElement(w, pver, uint16(e)); err != nil { return err } case ShortChannelID: @@ -383,7 +384,8 @@ func WriteElement(w io.Writer, element interface{}) error { // length of the addresses. var addrBuf bytes.Buffer for _, address := range e { - if err := WriteElement(&addrBuf, address); err != nil { + err := WriteElement(&addrBuf, pver, address) + if err != nil { return err } } @@ -391,7 +393,7 @@ func WriteElement(w io.Writer, element interface{}) error { // With the addresses fully encoded, we can now write out the // number of bytes needed to encode them. addrLen := addrBuf.Len() - if err := WriteElement(w, uint16(addrLen)); err != nil { + if err := WriteElement(w, pver, uint16(addrLen)); err != nil { return err } @@ -403,7 +405,7 @@ func WriteElement(w io.Writer, element interface{}) error { } } case color.RGBA: - if err := WriteElements(w, e.R, e.G, e.B); err != nil { + if err := WriteElements(w, pver, e.R, e.G, e.B); err != nil { return err } @@ -427,7 +429,7 @@ func WriteElement(w io.Writer, element interface{}) error { } case ExtraOpaqueData: - return e.Encode(w) + return e.Encode(w, pver) case TypedDeliveryAddress: return e.Encode(w) @@ -441,9 +443,9 @@ func WriteElement(w io.Writer, element interface{}) error { // WriteElements is writes each element in the elements slice to the passed // io.Writer using WriteElement. -func WriteElements(w io.Writer, elements ...interface{}) error { +func WriteElements(w io.Writer, pver uint32, elements ...interface{}) error { for _, element := range elements { - err := WriteElement(w, element) + err := WriteElement(w, pver, element) if err != nil { return err } @@ -452,8 +454,9 @@ func WriteElements(w io.Writer, elements ...interface{}) error { } // ReadElement is a one-stop utility function to deserialize any datastructure -// encoded using the serialization format of lnwire. -func ReadElement(r io.Reader, element interface{}) error { +// encoded using the serialization format of lnwire. The passed protocol +// version may affect how items are decoded. +func ReadElement(r io.Reader, pver uint32, element interface{}) error { var err error switch e := element.(type) { case *bool: @@ -569,7 +572,8 @@ func ReadElement(r io.Reader, element interface{}) error { if numSigs > 0 { sigs = make([]Sig, numSigs) for i := 0; i < int(numSigs); i++ { - if err := ReadElement(r, &sigs[i]); err != nil { + err := ReadElement(r, pver, &sigs[i]) + if err != nil { return err } } @@ -661,7 +665,7 @@ func ReadElement(r io.Reader, element interface{}) error { Index: uint32(index), } case *FailCode: - if err := ReadElement(r, (*uint16)(e)); err != nil { + if err := ReadElement(r, pver, (*uint16)(e)); err != nil { return err } case *ChannelID: @@ -816,6 +820,7 @@ func ReadElement(r io.Reader, element interface{}) error { *e = addresses case *color.RGBA: err := ReadElements(r, + pver, &e.R, &e.G, &e.B, @@ -840,7 +845,7 @@ func ReadElement(r io.Reader, element interface{}) error { *e = addrBytes[:length] case *ExtraOpaqueData: - return e.Decode(r) + return e.Decode(r, pver) case *TypedDeliveryAddress: return e.Decode(r) @@ -854,10 +859,10 @@ func ReadElement(r io.Reader, element interface{}) error { // ReadElements deserializes a variable number of elements into the passed // io.Reader, with each element being deserialized according to the ReadElement -// function. -func ReadElements(r io.Reader, elements ...interface{}) error { +// function. The passed protocol version may affect how the items are encoded. +func ReadElements(r io.Reader, pver uint32, elements ...interface{}) error { for _, element := range elements { - err := ReadElement(r, element) + err := ReadElement(r, pver, element) if err != nil { return err } diff --git a/lnwire/lnwire_test.go b/lnwire/lnwire_test.go index a44726a7cd7..d6cb76e7d7a 100644 --- a/lnwire/lnwire_test.go +++ b/lnwire/lnwire_test.go @@ -232,7 +232,7 @@ func TestMaxOutPointIndex(t *testing.T) { } var b bytes.Buffer - if err := WriteElement(&b, op); err == nil { + if err := WriteElement(&b, ProtocolVersionTLV, op); err == nil { t.Fatalf("write of outPoint should fail, index exceeds 16-bits") } } @@ -265,7 +265,8 @@ func TestLightningWireProtocol(t *testing.T) { // Give a new message, we'll serialize the message into a new // bytes buffer. var b bytes.Buffer - if _, err := WriteMessage(&b, msg, 0); err != nil { + _, err := WriteMessage(&b, msg, ProtocolVersionTLV) + if err != nil { t.Fatalf("unable to write msg: %v", err) return false } @@ -282,7 +283,7 @@ func TestLightningWireProtocol(t *testing.T) { // Finally, we'll deserialize the message from the written // buffer, and finally assert that the messages are equal. - newMsg, err := ReadMessage(&b, 0) + newMsg, err := ReadMessage(&b, ProtocolVersionTLV) if err != nil { t.Fatalf("unable to read msg: %v", err) return false diff --git a/lnwire/message.go b/lnwire/message.go index b5c27339e9e..58c95ca95fc 100644 --- a/lnwire/message.go +++ b/lnwire/message.go @@ -56,6 +56,20 @@ const ( MsgGossipTimestampRange = 265 ) +const ( + // ProtocolVersionLegacy is the legacy protocol version. When reading + // or writing messages using this protocol version, any optional fields + // appended to the end of the message will be ignored. + ProtocolVersionLegacy uint32 = 0 + + // ProtocolVersionTLV is the current modern protocol version. When + // reading/writing messages with this version, decoding will continue + // until the entire payload has been ready. When writing with this + // version, any optional fields appended to the end of the main message + // will also be written out. + ProtocolVersionTLV uint32 = 1 +) + // String return the string representation of message type. func (t MessageType) String() string { switch t { diff --git a/lnwire/node_announcement.go b/lnwire/node_announcement.go index c794e5b52ef..eb257509547 100644 --- a/lnwire/node_announcement.go +++ b/lnwire/node_announcement.go @@ -110,6 +110,7 @@ var _ Message = (*NodeAnnouncement)(nil) // This is part of the lnwire.Message interface. func (a *NodeAnnouncement) Decode(r io.Reader, pver uint32) error { return ReadElements(r, + pver, &a.Signature, &a.Features, &a.Timestamp, @@ -126,6 +127,7 @@ func (a *NodeAnnouncement) Decode(r io.Reader, pver uint32) error { // func (a *NodeAnnouncement) Encode(w io.Writer, pver uint32) error { return WriteElements(w, + pver, a.Signature, a.Features, a.Timestamp, @@ -159,6 +161,9 @@ func (a *NodeAnnouncement) DataToSign() ([]byte, error) { // We should not include the signatures itself. var w bytes.Buffer err := WriteElements(&w, + // We always use the modern protocol version as we need to + // include all data for forawrds compatability. + ProtocolVersionTLV, a.Features, a.Timestamp, a.NodeID, diff --git a/lnwire/onion_error.go b/lnwire/onion_error.go index c6235552e9c..1ab024de629 100644 --- a/lnwire/onion_error.go +++ b/lnwire/onion_error.go @@ -389,7 +389,7 @@ func (f *FailIncorrectDetails) Error() string { // // NOTE: Part of the Serializable interface. func (f *FailIncorrectDetails) Decode(r io.Reader, pver uint32) error { - err := ReadElement(r, &f.amount) + err := ReadElement(r, pver, &f.amount) switch { // This is an optional tack on that was added later in the protocol. As // a result, older nodes may not include this value. We'll account for @@ -404,7 +404,7 @@ func (f *FailIncorrectDetails) Decode(r io.Reader, pver uint32) error { // At a later stage, the height field was also tacked on. We need to // check for io.EOF here as well. - err = ReadElement(r, &f.height) + err = ReadElement(r, pver, &f.height) switch { case err == io.EOF: return nil @@ -420,7 +420,7 @@ func (f *FailIncorrectDetails) Decode(r io.Reader, pver uint32) error { // // NOTE: Part of the Serializable interface. func (f *FailIncorrectDetails) Encode(w io.Writer, pver uint32) error { - return WriteElements(w, f.amount, f.height) + return WriteElements(w, pver, f.amount, f.height) } // FailFinalExpiryTooSoon is returned if the cltv_expiry is too low, the final @@ -479,14 +479,14 @@ func (f *FailInvalidOnionVersion) Code() FailCode { // // NOTE: Part of the Serializable interface. func (f *FailInvalidOnionVersion) Decode(r io.Reader, pver uint32) error { - return ReadElement(r, f.OnionSHA256[:]) + return ReadElement(r, pver, f.OnionSHA256[:]) } // Encode writes the failure in bytes stream. // // NOTE: Part of the Serializable interface. func (f *FailInvalidOnionVersion) Encode(w io.Writer, pver uint32) error { - return WriteElement(w, f.OnionSHA256[:]) + return WriteElement(w, pver, f.OnionSHA256[:]) } // FailInvalidOnionHmac is return if the onion HMAC is incorrect. @@ -513,14 +513,14 @@ func (f *FailInvalidOnionHmac) Code() FailCode { // // NOTE: Part of the Serializable interface. func (f *FailInvalidOnionHmac) Decode(r io.Reader, pver uint32) error { - return ReadElement(r, f.OnionSHA256[:]) + return ReadElement(r, pver, f.OnionSHA256[:]) } // Encode writes the failure in bytes stream. // // NOTE: Part of the Serializable interface. func (f *FailInvalidOnionHmac) Encode(w io.Writer, pver uint32) error { - return WriteElement(w, f.OnionSHA256[:]) + return WriteElement(w, pver, f.OnionSHA256[:]) } // Returns a human readable string describing the target FailureMessage. @@ -555,14 +555,14 @@ func (f *FailInvalidOnionKey) Code() FailCode { // // NOTE: Part of the Serializable interface. func (f *FailInvalidOnionKey) Decode(r io.Reader, pver uint32) error { - return ReadElement(r, f.OnionSHA256[:]) + return ReadElement(r, pver, f.OnionSHA256[:]) } // Encode writes the failure in bytes stream. // // NOTE: Part of the Serializable interface. func (f *FailInvalidOnionKey) Encode(w io.Writer, pver uint32) error { - return WriteElement(w, f.OnionSHA256[:]) + return WriteElement(w, pver, f.OnionSHA256[:]) } // Returns a human readable string describing the target FailureMessage. @@ -652,7 +652,7 @@ func (f *FailTemporaryChannelFailure) Error() string { // NOTE: Part of the Serializable interface. func (f *FailTemporaryChannelFailure) Decode(r io.Reader, pver uint32) error { var length uint16 - err := ReadElement(r, &length) + err := ReadElement(r, pver, &length) if err != nil { return err } @@ -680,7 +680,7 @@ func (f *FailTemporaryChannelFailure) Encode(w io.Writer, pver uint32) error { payload = bw.Bytes() } - if err := WriteElement(w, uint16(len(payload))); err != nil { + if err := WriteElement(w, pver, uint16(len(payload))); err != nil { return err } @@ -731,12 +731,12 @@ func (f *FailAmountBelowMinimum) Error() string { // // NOTE: Part of the Serializable interface. func (f *FailAmountBelowMinimum) Decode(r io.Reader, pver uint32) error { - if err := ReadElement(r, &f.HtlcMsat); err != nil { + if err := ReadElement(r, pver, &f.HtlcMsat); err != nil { return err } var length uint16 - if err := ReadElement(r, &length); err != nil { + if err := ReadElement(r, pver, &length); err != nil { return err } @@ -750,7 +750,7 @@ func (f *FailAmountBelowMinimum) Decode(r io.Reader, pver uint32) error { // // NOTE: Part of the Serializable interface. func (f *FailAmountBelowMinimum) Encode(w io.Writer, pver uint32) error { - if err := WriteElement(w, f.HtlcMsat); err != nil { + if err := WriteElement(w, pver, f.HtlcMsat); err != nil { return err } @@ -799,12 +799,12 @@ func (f *FailFeeInsufficient) Error() string { // // NOTE: Part of the Serializable interface. func (f *FailFeeInsufficient) Decode(r io.Reader, pver uint32) error { - if err := ReadElement(r, &f.HtlcMsat); err != nil { + if err := ReadElement(r, pver, &f.HtlcMsat); err != nil { return err } var length uint16 - if err := ReadElement(r, &length); err != nil { + if err := ReadElement(r, pver, &length); err != nil { return err } @@ -818,7 +818,7 @@ func (f *FailFeeInsufficient) Decode(r io.Reader, pver uint32) error { // // NOTE: Part of the Serializable interface. func (f *FailFeeInsufficient) Encode(w io.Writer, pver uint32) error { - if err := WriteElement(w, f.HtlcMsat); err != nil { + if err := WriteElement(w, pver, f.HtlcMsat); err != nil { return err } @@ -867,12 +867,12 @@ func (f *FailIncorrectCltvExpiry) Error() string { // // NOTE: Part of the Serializable interface. func (f *FailIncorrectCltvExpiry) Decode(r io.Reader, pver uint32) error { - if err := ReadElement(r, &f.CltvExpiry); err != nil { + if err := ReadElement(r, pver, &f.CltvExpiry); err != nil { return err } var length uint16 - if err := ReadElement(r, &length); err != nil { + if err := ReadElement(r, pver, &length); err != nil { return err } @@ -886,7 +886,7 @@ func (f *FailIncorrectCltvExpiry) Decode(r io.Reader, pver uint32) error { // // NOTE: Part of the Serializable interface. func (f *FailIncorrectCltvExpiry) Encode(w io.Writer, pver uint32) error { - if err := WriteElement(w, f.CltvExpiry); err != nil { + if err := WriteElement(w, pver, f.CltvExpiry); err != nil { return err } @@ -929,7 +929,7 @@ func (f *FailExpiryTooSoon) Error() string { // NOTE: Part of the Serializable interface. func (f *FailExpiryTooSoon) Decode(r io.Reader, pver uint32) error { var length uint16 - if err := ReadElement(r, &length); err != nil { + if err := ReadElement(r, pver, &length); err != nil { return err } @@ -988,12 +988,12 @@ func (f *FailChannelDisabled) Error() string { // // NOTE: Part of the Serializable interface. func (f *FailChannelDisabled) Decode(r io.Reader, pver uint32) error { - if err := ReadElement(r, &f.Flags); err != nil { + if err := ReadElement(r, pver, &f.Flags); err != nil { return err } var length uint16 - if err := ReadElement(r, &length); err != nil { + if err := ReadElement(r, pver, &length); err != nil { return err } @@ -1007,7 +1007,7 @@ func (f *FailChannelDisabled) Decode(r io.Reader, pver uint32) error { // // NOTE: Part of the Serializable interface. func (f *FailChannelDisabled) Encode(w io.Writer, pver uint32) error { - if err := WriteElement(w, f.Flags); err != nil { + if err := WriteElement(w, pver, f.Flags); err != nil { return err } @@ -1050,14 +1050,14 @@ func (f *FailFinalIncorrectCltvExpiry) Code() FailCode { // // NOTE: Part of the Serializable interface. func (f *FailFinalIncorrectCltvExpiry) Decode(r io.Reader, pver uint32) error { - return ReadElement(r, &f.CltvExpiry) + return ReadElement(r, pver, &f.CltvExpiry) } // Encode writes the failure in bytes stream. // // NOTE: Part of the Serializable interface. func (f *FailFinalIncorrectCltvExpiry) Encode(w io.Writer, pver uint32) error { - return WriteElement(w, f.CltvExpiry) + return WriteElement(w, pver, f.CltvExpiry) } // FailFinalIncorrectHtlcAmount is returned if the amt_to_forward is higher @@ -1096,14 +1096,14 @@ func (f *FailFinalIncorrectHtlcAmount) Code() FailCode { // // NOTE: Part of the Serializable interface. func (f *FailFinalIncorrectHtlcAmount) Decode(r io.Reader, pver uint32) error { - return ReadElement(r, &f.IncomingHTLCAmount) + return ReadElement(r, pver, &f.IncomingHTLCAmount) } // Encode writes the failure in bytes stream. // // NOTE: Part of the Serializable interface. func (f *FailFinalIncorrectHtlcAmount) Encode(w io.Writer, pver uint32) error { - return WriteElement(w, f.IncomingHTLCAmount) + return WriteElement(w, pver, f.IncomingHTLCAmount) } // FailExpiryTooFar is returned if the CLTV expiry in the HTLC is too far in the @@ -1171,7 +1171,7 @@ func (f *InvalidOnionPayload) Decode(r io.Reader, pver uint32) error { } f.Type = typ - return ReadElements(r, &f.Offset) + return ReadElements(r, pver, &f.Offset) } // Encode writes the failure in bytes stream. @@ -1183,7 +1183,7 @@ func (f *InvalidOnionPayload) Encode(w io.Writer, pver uint32) error { return err } - return WriteElements(w, f.Offset) + return WriteElements(w, pver, f.Offset) } // FailMPPTimeout is returned if the complete amount for a multi part payment @@ -1212,7 +1212,7 @@ func DecodeFailure(r io.Reader, pver uint32) (FailureMessage, error) { // First, we'll parse out the encapsulated failure message itself. This // is a 2 byte length followed by the payload itself. var failureLength uint16 - if err := ReadElement(r, &failureLength); err != nil { + if err := ReadElement(r, pver, &failureLength); err != nil { return nil, fmt.Errorf("unable to read error len: %v", err) } if failureLength > FailureMessageLength { @@ -1284,6 +1284,7 @@ func EncodeFailure(w io.Writer, failure FailureMessage, pver uint32) error { pad := make([]byte, FailureMessageLength-len(failureMessage)) return WriteElements(w, + pver, uint16(len(failureMessage)), failureMessage, uint16(len(pad)), @@ -1414,7 +1415,7 @@ func writeOnionErrorChanUpdate(w io.Writer, chanUpdate *ChannelUpdate, // Now that we know the size, we can write the length out in the main // writer. updateLen := b.Len() - if err := WriteElement(w, uint16(updateLen)); err != nil { + if err := WriteElement(w, pver, uint16(updateLen)); err != nil { return err } diff --git a/lnwire/onion_error_test.go b/lnwire/onion_error_test.go index 8c4c131c66a..a1ed4fbe1d6 100644 --- a/lnwire/onion_error_test.go +++ b/lnwire/onion_error_test.go @@ -63,12 +63,13 @@ func TestEncodeDecodeCode(t *testing.T) { for _, failure1 := range onionFailures { var b bytes.Buffer - if err := EncodeFailure(&b, failure1, 0); err != nil { + err := EncodeFailure(&b, failure1, ProtocolVersionTLV) + if err != nil { t.Fatalf("unable to encode failure code(%v): %v", failure1.Code(), err) } - failure2, err := DecodeFailure(&b, 0) + failure2, err := DecodeFailure(&b, ProtocolVersionTLV) if err != nil { t.Fatalf("unable to decode failure code(%v): %v", failure1.Code(), err) @@ -90,7 +91,7 @@ func TestChannelUpdateCompatabilityParsing(t *testing.T) { // We'll start by taking out test channel update, and encoding it into // a set of raw bytes. var b bytes.Buffer - if err := testChannelUpdate.Encode(&b, 0); err != nil { + if err := testChannelUpdate.Encode(&b, ProtocolVersionTLV); err != nil { t.Fatalf("unable to encode chan update: %v", err) } @@ -99,7 +100,7 @@ func TestChannelUpdateCompatabilityParsing(t *testing.T) { // encoded channel update message. var newChanUpdate ChannelUpdate err := parseChannelUpdateCompatabilityMode( - bufio.NewReader(&b), &newChanUpdate, 0, + bufio.NewReader(&b), &newChanUpdate, ProtocolVersionTLV, ) if err != nil { t.Fatalf("unable to parse channel update: %v", err) @@ -120,7 +121,7 @@ func TestChannelUpdateCompatabilityParsing(t *testing.T) { var tByte [2]byte binary.BigEndian.PutUint16(tByte[:], MsgChannelUpdate) b.Write(tByte[:]) - if err := testChannelUpdate.Encode(&b, 0); err != nil { + if err := testChannelUpdate.Encode(&b, ProtocolVersionTLV); err != nil { t.Fatalf("unable to encode chan update: %v", err) } @@ -128,7 +129,7 @@ func TestChannelUpdateCompatabilityParsing(t *testing.T) { // message even with the extra two bytes. var newChanUpdate2 ChannelUpdate err = parseChannelUpdateCompatabilityMode( - bufio.NewReader(&b), &newChanUpdate2, 0, + bufio.NewReader(&b), &newChanUpdate2, ProtocolVersionTLV, ) if err != nil { t.Fatalf("unable to parse channel update: %v", err) @@ -165,7 +166,7 @@ func TestWriteOnionErrorChanUpdate(t *testing.T) { // Finally, read the length encoded and ensure that it matches the raw // length. var encodedLen uint16 - if err := ReadElement(&errorBuf, &encodedLen); err != nil { + if err := ReadElement(&errorBuf, ProtocolVersionTLV, &encodedLen); err != nil { t.Fatalf("unable to read len: %v", err) } if uint16(trueUpdateLength) != encodedLen { @@ -276,5 +277,5 @@ func (f *mockFailIncorrectDetailsNoHeight) Decode(r io.Reader, pver uint32) erro } func (f *mockFailIncorrectDetailsNoHeight) Encode(w io.Writer, pver uint32) error { - return WriteElement(w, f.amount) + return WriteElement(w, ProtocolVersionTLV, f.amount) } diff --git a/lnwire/open_channel.go b/lnwire/open_channel.go index 25ea42668cc..bfe25599b2a 100644 --- a/lnwire/open_channel.go +++ b/lnwire/open_channel.go @@ -146,6 +146,7 @@ var _ Message = (*OpenChannel)(nil) // This is part of the lnwire.Message interface. func (o *OpenChannel) Encode(w io.Writer, pver uint32) error { return WriteElements(w, + pver, o.ChainHash[:], o.PendingChannelID[:], o.FundingAmount, @@ -176,6 +177,7 @@ func (o *OpenChannel) Encode(w io.Writer, pver uint32) error { // This is part of the lnwire.Message interface. func (o *OpenChannel) Decode(r io.Reader, pver uint32) error { return ReadElements(r, + pver, o.ChainHash[:], o.PendingChannelID[:], &o.FundingAmount, diff --git a/lnwire/ping.go b/lnwire/ping.go index cf9a83b78ce..cc75c276eca 100644 --- a/lnwire/ping.go +++ b/lnwire/ping.go @@ -36,6 +36,7 @@ var _ Message = (*Ping)(nil) // This is part of the lnwire.Message interface. func (p *Ping) Decode(r io.Reader, pver uint32) error { return ReadElements(r, + pver, &p.NumPongBytes, &p.PaddingBytes) } @@ -46,6 +47,7 @@ func (p *Ping) Decode(r io.Reader, pver uint32) error { // This is part of the lnwire.Message interface. func (p *Ping) Encode(w io.Writer, pver uint32) error { return WriteElements(w, + pver, p.NumPongBytes, p.PaddingBytes) } diff --git a/lnwire/pong.go b/lnwire/pong.go index c3166aaf6d0..3057cd57953 100644 --- a/lnwire/pong.go +++ b/lnwire/pong.go @@ -32,7 +32,7 @@ var _ Message = (*Pong)(nil) // This is part of the lnwire.Message interface. func (p *Pong) Decode(r io.Reader, pver uint32) error { return ReadElements(r, - &p.PongBytes, + pver, &p.PongBytes, ) } @@ -42,7 +42,7 @@ func (p *Pong) Decode(r io.Reader, pver uint32) error { // This is part of the lnwire.Message interface. func (p *Pong) Encode(w io.Writer, pver uint32) error { return WriteElements(w, - p.PongBytes, + pver, p.PongBytes, ) } diff --git a/lnwire/query_channel_range.go b/lnwire/query_channel_range.go index 3bdb30e5eca..6d07763c399 100644 --- a/lnwire/query_channel_range.go +++ b/lnwire/query_channel_range.go @@ -47,6 +47,7 @@ var _ Message = (*QueryChannelRange)(nil) // This is part of the lnwire.Message interface. func (q *QueryChannelRange) Decode(r io.Reader, pver uint32) error { return ReadElements(r, + pver, q.ChainHash[:], &q.FirstBlockHeight, &q.NumBlocks, @@ -60,6 +61,7 @@ func (q *QueryChannelRange) Decode(r io.Reader, pver uint32) error { // This is part of the lnwire.Message interface. func (q *QueryChannelRange) Encode(w io.Writer, pver uint32) error { return WriteElements(w, + pver, q.ChainHash[:], q.FirstBlockHeight, q.NumBlocks, diff --git a/lnwire/query_short_chan_ids.go b/lnwire/query_short_chan_ids.go index 43a271333f1..c6dfab2de37 100644 --- a/lnwire/query_short_chan_ids.go +++ b/lnwire/query_short_chan_ids.go @@ -113,28 +113,30 @@ var _ Message = (*QueryShortChanIDs)(nil) // // This is part of the lnwire.Message interface. func (q *QueryShortChanIDs) Decode(r io.Reader, pver uint32) error { - err := ReadElements(r, q.ChainHash[:]) + err := ReadElements(r, pver, q.ChainHash[:]) if err != nil { return err } - q.EncodingType, q.ShortChanIDs, err = decodeShortChanIDs(r) + q.EncodingType, q.ShortChanIDs, err = decodeShortChanIDs(r, pver) if err != nil { return err } - return q.ExtraData.Decode(r) + return q.ExtraData.Decode(r, pver) } // decodeShortChanIDs decodes a set of short channel ID's that have been // encoded. The first byte of the body details how the short chan ID's were // encoded. We'll use this type to govern exactly how we go about encoding the -// set of short channel ID's. -func decodeShortChanIDs(r io.Reader) (ShortChanIDEncoding, []ShortChannelID, error) { +// set of short channel ID's. The protocol version may affect how the IDs are +// decoded. +func decodeShortChanIDs(r io.Reader, + pver uint32) (ShortChanIDEncoding, []ShortChannelID, error) { // First, we'll attempt to read the number of bytes in the body of the // set of encoded short channel ID's. var numBytesResp uint16 - err := ReadElements(r, &numBytesResp) + err := ReadElements(r, pver, &numBytesResp) if err != nil { return 0, nil, err } @@ -187,7 +189,8 @@ func decodeShortChanIDs(r io.Reader) (ShortChanIDEncoding, []ShortChannelID, err bodyReader := bytes.NewReader(queryBody) var lastChanID ShortChannelID for i := 0; i < numShortChanIDs; i++ { - if err := ReadElements(bodyReader, &shortChanIDs[i]); err != nil { + err := ReadElements(bodyReader, pver, &shortChanIDs[i]) + if err != nil { return 0, nil, fmt.Errorf("unable to parse "+ "short chan ID: %v", err) } @@ -243,7 +246,7 @@ func decodeShortChanIDs(r io.Reader) (ShortChanIDEncoding, []ShortChannelID, err // We'll now attempt to read the next short channel ID // encoded in the payload. var cid ShortChannelID - err := ReadElements(limitedDecompressor, &cid) + err := ReadElements(limitedDecompressor, pver, &cid) switch { // If we get an EOF error, then that either means we've @@ -293,25 +296,28 @@ func decodeShortChanIDs(r io.Reader) (ShortChanIDEncoding, []ShortChannelID, err // This is part of the lnwire.Message interface. func (q *QueryShortChanIDs) Encode(w io.Writer, pver uint32) error { // First, we'll write out the chain hash. - err := WriteElements(w, q.ChainHash[:]) + err := WriteElements(w, pver, q.ChainHash[:]) if err != nil { return err } // Base on our encoding type, we'll write out the set of short channel // ID's. - err = encodeShortChanIDs(w, q.EncodingType, q.ShortChanIDs, q.noSort) + err = encodeShortChanIDs( + w, q.EncodingType, q.ShortChanIDs, q.noSort, pver, + ) if err != nil { return err } - return q.ExtraData.Encode(w) + return q.ExtraData.Encode(w, pver) } // encodeShortChanIDs encodes the passed short channel ID's into the passed -// io.Writer, respecting the specified encoding type. +// io.Writer, respecting the specified encoding type. The protocol version may +// affect how the items are encoded. func encodeShortChanIDs(w io.Writer, encodingType ShortChanIDEncoding, - shortChanIDs []ShortChannelID, noSort bool) error { + shortChanIDs []ShortChannelID, noSort bool, pver uint32) error { // For both of the current encoding types, the channel ID's are to be // sorted in place, so we'll do that now. The sorting is applied unless @@ -332,20 +338,20 @@ func encodeShortChanIDs(w io.Writer, encodingType ShortChanIDEncoding, // body. We add 1 as the response will have the encoding type // prepended to it. numBytesBody := uint16(len(shortChanIDs)*8) + 1 - if err := WriteElements(w, numBytesBody); err != nil { + if err := WriteElements(w, pver, numBytesBody); err != nil { return err } // We'll then write out the encoding that that follows the // actual encoded short channel ID's. - if err := WriteElements(w, encodingType); err != nil { + if err := WriteElements(w, pver, encodingType); err != nil { return err } // Now that we know they're sorted, we can write out each short // channel ID to the buffer. for _, chanID := range shortChanIDs { - if err := WriteElements(w, chanID); err != nil { + if err := WriteElements(w, pver, chanID); err != nil { return fmt.Errorf("unable to write short chan "+ "ID: %v", err) } @@ -376,7 +382,7 @@ func encodeShortChanIDs(w io.Writer, encodingType ShortChanIDEncoding, // into the zlib writer, which will do compressing on // the fly. for _, chanID := range shortChanIDs { - err := WriteElements(zlibWriter, chanID) + err := WriteElements(zlibWriter, pver, chanID) if err != nil { return fmt.Errorf("unable to write short chan "+ "ID: %v", err) @@ -407,10 +413,10 @@ func encodeShortChanIDs(w io.Writer, encodingType ShortChanIDEncoding, // Finally, we can write out the number of bytes, the // compression type, and finally the buffer itself. - if err := WriteElements(w, uint16(numBytesBody)); err != nil { + if err := WriteElements(w, pver, uint16(numBytesBody)); err != nil { return err } - if err := WriteElements(w, encodingType); err != nil { + if err := WriteElements(w, pver, encodingType); err != nil { return err } diff --git a/lnwire/reply_channel_range.go b/lnwire/reply_channel_range.go index 5167cc5a51c..2c49c0b5a82 100644 --- a/lnwire/reply_channel_range.go +++ b/lnwire/reply_channel_range.go @@ -64,6 +64,7 @@ var _ Message = (*ReplyChannelRange)(nil) // This is part of the lnwire.Message interface. func (c *ReplyChannelRange) Decode(r io.Reader, pver uint32) error { err := ReadElements(r, + pver, c.ChainHash[:], &c.FirstBlockHeight, &c.NumBlocks, @@ -73,12 +74,12 @@ func (c *ReplyChannelRange) Decode(r io.Reader, pver uint32) error { return err } - c.EncodingType, c.ShortChanIDs, err = decodeShortChanIDs(r) + c.EncodingType, c.ShortChanIDs, err = decodeShortChanIDs(r, pver) if err != nil { return err } - return c.ExtraData.Decode(r) + return c.ExtraData.Decode(r, pver) } // Encode serializes the target ReplyChannelRange into the passed io.Writer @@ -87,6 +88,7 @@ func (c *ReplyChannelRange) Decode(r io.Reader, pver uint32) error { // This is part of the lnwire.Message interface. func (c *ReplyChannelRange) Encode(w io.Writer, pver uint32) error { err := WriteElements(w, + pver, c.ChainHash[:], c.FirstBlockHeight, c.NumBlocks, @@ -96,12 +98,14 @@ func (c *ReplyChannelRange) Encode(w io.Writer, pver uint32) error { return err } - err = encodeShortChanIDs(w, c.EncodingType, c.ShortChanIDs, c.noSort) + err = encodeShortChanIDs( + w, c.EncodingType, c.ShortChanIDs, c.noSort, pver, + ) if err != nil { return err } - return c.ExtraData.Encode(w) + return c.ExtraData.Encode(w, pver) } // MsgType returns the integer uniquely identifying this message type on the diff --git a/lnwire/reply_channel_range_test.go b/lnwire/reply_channel_range_test.go index ff3414958e3..a40356040d3 100644 --- a/lnwire/reply_channel_range_test.go +++ b/lnwire/reply_channel_range_test.go @@ -80,7 +80,9 @@ func TestReplyChannelRangeEmpty(t *testing.T) { // identical to the one created above. var req2 ReplyChannelRange b, _ := hex.DecodeString(test.encodedHex) - err := req2.Decode(bytes.NewReader(b), 0) + err := req2.Decode( + bytes.NewReader(b), ProtocolVersionTLV, + ) if err != nil { t.Fatalf("unable to decode req: %v", err) } @@ -93,7 +95,7 @@ func TestReplyChannelRangeEmpty(t *testing.T) { // request created above, and assert that it matches // the raw byte encoding. var b2 bytes.Buffer - err = req.Encode(&b2, 0) + err = req.Encode(&b2, ProtocolVersionTLV) if err != nil { t.Fatalf("unable to encode req: %v", err) } diff --git a/lnwire/reply_short_chan_ids_end.go b/lnwire/reply_short_chan_ids_end.go index 1412b50f960..64341ba2e8c 100644 --- a/lnwire/reply_short_chan_ids_end.go +++ b/lnwire/reply_short_chan_ids_end.go @@ -44,6 +44,7 @@ var _ Message = (*ReplyShortChanIDsEnd)(nil) // This is part of the lnwire.Message interface. func (c *ReplyShortChanIDsEnd) Decode(r io.Reader, pver uint32) error { return ReadElements(r, + pver, c.ChainHash[:], &c.Complete, &c.ExtraData, @@ -56,6 +57,7 @@ func (c *ReplyShortChanIDsEnd) Decode(r io.Reader, pver uint32) error { // This is part of the lnwire.Message interface. func (c *ReplyShortChanIDsEnd) Encode(w io.Writer, pver uint32) error { return WriteElements(w, + pver, c.ChainHash[:], c.Complete, c.ExtraData, diff --git a/lnwire/revoke_and_ack.go b/lnwire/revoke_and_ack.go index 6eaf5cafd69..1e288877dcc 100644 --- a/lnwire/revoke_and_ack.go +++ b/lnwire/revoke_and_ack.go @@ -54,6 +54,7 @@ var _ Message = (*RevokeAndAck)(nil) // This is part of the lnwire.Message interface. func (c *RevokeAndAck) Decode(r io.Reader, pver uint32) error { return ReadElements(r, + pver, &c.ChanID, c.Revocation[:], &c.NextRevocationKey, @@ -67,6 +68,7 @@ func (c *RevokeAndAck) Decode(r io.Reader, pver uint32) error { // This is part of the lnwire.Message interface. func (c *RevokeAndAck) Encode(w io.Writer, pver uint32) error { return WriteElements(w, + pver, c.ChanID, c.Revocation[:], c.NextRevocationKey, diff --git a/lnwire/shutdown.go b/lnwire/shutdown.go index e27681e4e60..07270f8ddea 100644 --- a/lnwire/shutdown.go +++ b/lnwire/shutdown.go @@ -53,7 +53,7 @@ var _ Message = (*Shutdown)(nil) // // This is part of the lnwire.Message interface. func (s *Shutdown) Decode(r io.Reader, pver uint32) error { - return ReadElements(r, &s.ChannelID, &s.Address, &s.ExtraData) + return ReadElements(r, pver, &s.ChannelID, &s.Address, &s.ExtraData) } // Encode serializes the target Shutdown into the passed io.Writer observing @@ -61,7 +61,7 @@ func (s *Shutdown) Decode(r io.Reader, pver uint32) error { // // This is part of the lnwire.Message interface. func (s *Shutdown) Encode(w io.Writer, pver uint32) error { - return WriteElements(w, s.ChannelID, s.Address, s.ExtraData) + return WriteElements(w, pver, s.ChannelID, s.Address, s.ExtraData) } // MsgType returns the integer uniquely identifying this message type on the diff --git a/lnwire/update_add_htlc.go b/lnwire/update_add_htlc.go index 9211d39ffb0..691a071d91f 100644 --- a/lnwire/update_add_htlc.go +++ b/lnwire/update_add_htlc.go @@ -74,6 +74,7 @@ var _ Message = (*UpdateAddHTLC)(nil) // This is part of the lnwire.Message interface. func (c *UpdateAddHTLC) Decode(r io.Reader, pver uint32) error { return ReadElements(r, + pver, &c.ChanID, &c.ID, &c.Amount, @@ -90,6 +91,7 @@ func (c *UpdateAddHTLC) Decode(r io.Reader, pver uint32) error { // This is part of the lnwire.Message interface. func (c *UpdateAddHTLC) Encode(w io.Writer, pver uint32) error { return WriteElements(w, + pver, c.ChanID, c.ID, c.Amount, diff --git a/lnwire/update_fail_htlc.go b/lnwire/update_fail_htlc.go index 09666ac25ff..54592636f51 100644 --- a/lnwire/update_fail_htlc.go +++ b/lnwire/update_fail_htlc.go @@ -43,6 +43,7 @@ var _ Message = (*UpdateFailHTLC)(nil) // This is part of the lnwire.Message interface. func (c *UpdateFailHTLC) Decode(r io.Reader, pver uint32) error { return ReadElements(r, + pver, &c.ChanID, &c.ID, &c.Reason, @@ -56,6 +57,7 @@ func (c *UpdateFailHTLC) Decode(r io.Reader, pver uint32) error { // This is part of the lnwire.Message interface. func (c *UpdateFailHTLC) Encode(w io.Writer, pver uint32) error { return WriteElements(w, + pver, c.ChanID, c.ID, c.Reason, diff --git a/lnwire/update_fail_malformed_htlc.go b/lnwire/update_fail_malformed_htlc.go index b28ec29ff4e..60d04c5136f 100644 --- a/lnwire/update_fail_malformed_htlc.go +++ b/lnwire/update_fail_malformed_htlc.go @@ -41,6 +41,7 @@ var _ Message = (*UpdateFailMalformedHTLC)(nil) // This is part of the lnwire.Message interface. func (c *UpdateFailMalformedHTLC) Decode(r io.Reader, pver uint32) error { return ReadElements(r, + pver, &c.ChanID, &c.ID, c.ShaOnionBlob[:], @@ -55,6 +56,7 @@ func (c *UpdateFailMalformedHTLC) Decode(r io.Reader, pver uint32) error { // This is part of the lnwire.Message interface. func (c *UpdateFailMalformedHTLC) Encode(w io.Writer, pver uint32) error { return WriteElements(w, + pver, c.ChanID, c.ID, c.ShaOnionBlob[:], diff --git a/lnwire/update_fee.go b/lnwire/update_fee.go index 25ab180c2df..4953cf6a4fc 100644 --- a/lnwire/update_fee.go +++ b/lnwire/update_fee.go @@ -41,6 +41,7 @@ var _ Message = (*UpdateFee)(nil) // This is part of the lnwire.Message interface. func (c *UpdateFee) Decode(r io.Reader, pver uint32) error { return ReadElements(r, + pver, &c.ChanID, &c.FeePerKw, &c.ExtraData, @@ -53,6 +54,7 @@ func (c *UpdateFee) Decode(r io.Reader, pver uint32) error { // This is part of the lnwire.Message interface. func (c *UpdateFee) Encode(w io.Writer, pver uint32) error { return WriteElements(w, + pver, c.ChanID, c.FeePerKw, c.ExtraData, diff --git a/lnwire/update_fulfill_htlc.go b/lnwire/update_fulfill_htlc.go index 36977b1e928..150ddf3595f 100644 --- a/lnwire/update_fulfill_htlc.go +++ b/lnwire/update_fulfill_htlc.go @@ -49,6 +49,7 @@ var _ Message = (*UpdateFulfillHTLC)(nil) // This is part of the lnwire.Message interface. func (c *UpdateFulfillHTLC) Decode(r io.Reader, pver uint32) error { return ReadElements(r, + pver, &c.ChanID, &c.ID, c.PaymentPreimage[:], @@ -62,6 +63,7 @@ func (c *UpdateFulfillHTLC) Decode(r io.Reader, pver uint32) error { // This is part of the lnwire.Message interface. func (c *UpdateFulfillHTLC) Encode(w io.Writer, pver uint32) error { return WriteElements(w, + pver, c.ChanID, c.ID, c.PaymentPreimage[:], From 79767eb208f01d9327ee4f55b058389fdad7df16 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Fri, 24 Jul 2020 18:39:18 -0700 Subject: [PATCH 42/43] channeldb/migrations: use lnwire.ProtocolVersionTLV In this commit, update the encode/decode code concerning wire messages in the existing migrations to force them to specify that they want the legacy protocol version which doesn't attempt to read the extra bytes for all messages other than gossip messages. --- channeldb/migration_01_to_11/codec.go | 5 +++-- channeldb/migration_01_to_11/migrations.go | 3 ++- channeldb/migration_01_to_11/migrations_test.go | 4 +++- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/channeldb/migration_01_to_11/codec.go b/channeldb/migration_01_to_11/codec.go index 1727c8c997d..cf590cb2890 100644 --- a/channeldb/migration_01_to_11/codec.go +++ b/channeldb/migration_01_to_11/codec.go @@ -172,7 +172,8 @@ func WriteElement(w io.Writer, element interface{}) error { } case lnwire.Message: - if _, err := lnwire.WriteMessage(w, e, 0); err != nil { + _, err := lnwire.WriteMessage(w, e, lnwire.ProtocolVersionLegacy) + if err != nil { return err } @@ -383,7 +384,7 @@ func ReadElement(r io.Reader, element interface{}) error { *e = bytes case *lnwire.Message: - msg, err := lnwire.ReadMessage(r, 0) + msg, err := lnwire.ReadMessage(r, lnwire.ProtocolVersionLegacy) if err != nil { return err } diff --git a/channeldb/migration_01_to_11/migrations.go b/channeldb/migration_01_to_11/migrations.go index 35be510e996..d46711b0791 100644 --- a/channeldb/migration_01_to_11/migrations.go +++ b/channeldb/migration_01_to_11/migrations.go @@ -724,7 +724,8 @@ func MigrateGossipMessageStoreKeys(tx kvdb.RwTx) error { // Serialize the message with its wire encoding. var b bytes.Buffer - if _, err := lnwire.WriteMessage(&b, msg, 0); err != nil { + _, err := lnwire.WriteMessage(&b, msg, lnwire.ProtocolVersionTLV) + if err != nil { return err } diff --git a/channeldb/migration_01_to_11/migrations_test.go b/channeldb/migration_01_to_11/migrations_test.go index 7fc90855da3..3692af93232 100644 --- a/channeldb/migration_01_to_11/migrations_test.go +++ b/channeldb/migration_01_to_11/migrations_test.go @@ -529,7 +529,9 @@ func TestMigrateGossipMessageStoreKeys(t *testing.T) { t.Fatal(err) } - gotMsg, err := lnwire.ReadMessage(bytes.NewReader(rawMsg), 0) + gotMsg, err := lnwire.ReadMessage( + bytes.NewReader(rawMsg), lnwire.ProtocolVersionLegacy, + ) if err != nil { t.Fatalf("unable to deserialize raw message: %v", err) } From afabedd77abb8281aac8559e95caed473329d969 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Fri, 24 Jul 2020 18:39:51 -0700 Subject: [PATCH 43/43] multi: use lnwire.ProtocolVersionTLV for encoding/decoding wire messages --- chanbackup/multi.go | 12 +++++++++--- chanbackup/single.go | 38 +++++++++++++++++++++++++----------- channeldb/channel.go | 5 +++-- channeldb/codec.go | 5 +++-- channeldb/waitingproof.go | 5 +++-- discovery/message_store.go | 11 ++++++++--- htlcswitch/payment_result.go | 2 +- peer/brontide.go | 2 +- routing/ann_validation.go | 5 ++++- 9 files changed, 59 insertions(+), 26 deletions(-) diff --git a/chanbackup/multi.go b/chanbackup/multi.go index e90bd613e41..78e6197f03d 100644 --- a/chanbackup/multi.go +++ b/chanbackup/multi.go @@ -63,7 +63,9 @@ func (m Multi) PackToWriter(w io.Writer, keyRing keychain.KeyRing) error { var multiBackupBuffer bytes.Buffer // First, we'll write out the version of this multi channel baackup. - err := lnwire.WriteElements(&multiBackupBuffer, byte(m.Version)) + err := lnwire.WriteElements( + &multiBackupBuffer, lnwire.ProtocolVersionTLV, byte(m.Version), + ) if err != nil { return err } @@ -111,7 +113,9 @@ func (m *Multi) UnpackFromReader(r io.Reader, keyRing keychain.KeyRing) error { // First, we'll need to read the version of this multi-back up so we // can know how to unpack each of the individual SCB's. var multiVersion byte - err = lnwire.ReadElements(backupReader, &multiVersion) + err = lnwire.ReadElements( + backupReader, lnwire.ProtocolVersionTLV, &multiVersion, + ) if err != nil { return err } @@ -127,7 +131,9 @@ func (m *Multi) UnpackFromReader(r io.Reader, keyRing keychain.KeyRing) error { // backup is the same size, so we can continue until we've // parsed out everything. var numBackups uint32 - err = lnwire.ReadElements(backupReader, &numBackups) + err = lnwire.ReadElements( + backupReader, lnwire.ProtocolVersionTLV, &numBackups, + ) if err != nil { return err } diff --git a/chanbackup/single.go b/chanbackup/single.go index 490657b90dd..5d4c9a20eef 100644 --- a/chanbackup/single.go +++ b/chanbackup/single.go @@ -207,6 +207,7 @@ func (s *Single) Serialize(w io.Writer) error { var singleBytes bytes.Buffer if err := lnwire.WriteElements( &singleBytes, + lnwire.ProtocolVersionTLV, s.IsInitiator, s.ChainHash[:], s.FundingOutpoint, @@ -249,6 +250,7 @@ func (s *Single) Serialize(w io.Writer) error { return lnwire.WriteElements( w, + lnwire.ProtocolVersionTLV, byte(s.Version), uint16(len(singleBytes.Bytes())), singleBytes.Bytes(), @@ -290,12 +292,14 @@ func readLocalKeyDesc(r io.Reader) (keychain.KeyDescriptor, error) { var keyDesc keychain.KeyDescriptor var keyFam uint32 - if err := lnwire.ReadElements(r, &keyFam); err != nil { + err := lnwire.ReadElements(r, lnwire.ProtocolVersionTLV, &keyFam) + if err != nil { return keyDesc, err } keyDesc.Family = keychain.KeyFamily(keyFam) - if err := lnwire.ReadElements(r, &keyDesc.Index); err != nil { + err = lnwire.ReadElements(r, lnwire.ProtocolVersionTLV, &keyDesc.Index) + if err != nil { return keyDesc, err } @@ -333,7 +337,7 @@ func (s *Single) Deserialize(r io.Reader) error { // First, we'll need to read the version of this single-back up so we // can know how to unpack each of the SCB. var version byte - err := lnwire.ReadElements(r, &version) + err := lnwire.ReadElements(r, lnwire.ProtocolVersionTLV, &version) if err != nil { return err } @@ -350,19 +354,23 @@ func (s *Single) Deserialize(r io.Reader) error { } var length uint16 - if err := lnwire.ReadElements(r, &length); err != nil { + err = lnwire.ReadElements(r, lnwire.ProtocolVersionTLV, &length) + if err != nil { return err } err = lnwire.ReadElements( - r, &s.IsInitiator, s.ChainHash[:], &s.FundingOutpoint, - &s.ShortChannelID, &s.RemoteNodePub, &s.Addresses, &s.Capacity, + r, lnwire.ProtocolVersionTLV, &s.IsInitiator, s.ChainHash[:], + &s.FundingOutpoint, &s.ShortChannelID, &s.RemoteNodePub, + &s.Addresses, &s.Capacity, ) if err != nil { return err } - err = lnwire.ReadElements(r, &s.LocalChanCfg.CsvDelay) + err = lnwire.ReadElements( + r, lnwire.ProtocolVersionTLV, &s.LocalChanCfg.CsvDelay, + ) if err != nil { return err } @@ -387,7 +395,9 @@ func (s *Single) Deserialize(r io.Reader) error { return err } - err = lnwire.ReadElements(r, &s.RemoteChanCfg.CsvDelay) + err = lnwire.ReadElements( + r, lnwire.ProtocolVersionTLV, &s.RemoteChanCfg.CsvDelay, + ) if err != nil { return err } @@ -417,7 +427,8 @@ func (s *Single) Deserialize(r io.Reader) error { shaChainPub [33]byte zeroPub [33]byte ) - if err := lnwire.ReadElements(r, shaChainPub[:]); err != nil { + err = lnwire.ReadElements(r, lnwire.ProtocolVersionTLV, shaChainPub[:]) + if err != nil { return err } @@ -433,12 +444,17 @@ func (s *Single) Deserialize(r io.Reader) error { } var shaKeyFam uint32 - if err := lnwire.ReadElements(r, &shaKeyFam); err != nil { + err = lnwire.ReadElements( + r, lnwire.ProtocolVersionTLV, &shaKeyFam, + ) + if err != nil { return err } s.ShaChainRootDesc.KeyLocator.Family = keychain.KeyFamily(shaKeyFam) - return lnwire.ReadElements(r, &s.ShaChainRootDesc.KeyLocator.Index) + return lnwire.ReadElements( + r, lnwire.ProtocolVersionTLV, &s.ShaChainRootDesc.KeyLocator.Index, + ) } // UnpackFromReader is similar to Deserialize method, but it expects the passed diff --git a/channeldb/channel.go b/channeldb/channel.go index 35a0700d472..f34cd43e847 100644 --- a/channeldb/channel.go +++ b/channeldb/channel.go @@ -1871,7 +1871,8 @@ func serializeCommitDiff(w io.Writer, diff *CommitDiff) error { return err } - if err := diff.CommitSig.Encode(w, 0); err != nil { + err := diff.CommitSig.Encode(w, lnwire.ProtocolVersionTLV) + if err != nil { return err } @@ -1918,7 +1919,7 @@ func deserializeCommitDiff(r io.Reader) (*CommitDiff, error) { } d.CommitSig = &lnwire.CommitSig{} - if err := d.CommitSig.Decode(r, 0); err != nil { + if err := d.CommitSig.Decode(r, lnwire.ProtocolVersionTLV); err != nil { return nil, err } diff --git a/channeldb/codec.go b/channeldb/codec.go index f6903175f8d..deb47017886 100644 --- a/channeldb/codec.go +++ b/channeldb/codec.go @@ -178,7 +178,8 @@ func WriteElement(w io.Writer, element interface{}) error { } case lnwire.Message: - if _, err := lnwire.WriteMessage(w, e, 0); err != nil { + _, err := lnwire.WriteMessage(w, e, lnwire.ProtocolVersionTLV) + if err != nil { return err } @@ -394,7 +395,7 @@ func ReadElement(r io.Reader, element interface{}) error { *e = bytes case *lnwire.Message: - msg, err := lnwire.ReadMessage(r, 0) + msg, err := lnwire.ReadMessage(r, lnwire.ProtocolVersionTLV) if err != nil { return err } diff --git a/channeldb/waitingproof.go b/channeldb/waitingproof.go index 2ea706c8405..e69afc05679 100644 --- a/channeldb/waitingproof.go +++ b/channeldb/waitingproof.go @@ -227,7 +227,8 @@ func (p *WaitingProof) Encode(w io.Writer) error { return err } - if err := p.AnnounceSignatures.Encode(w, 0); err != nil { + err := p.AnnounceSignatures.Encode(w, lnwire.ProtocolVersionTLV) + if err != nil { return err } @@ -242,7 +243,7 @@ func (p *WaitingProof) Decode(r io.Reader) error { } msg := &lnwire.AnnounceSignatures{} - if err := msg.Decode(r, 0); err != nil { + if err := msg.Decode(r, lnwire.ProtocolVersionTLV); err != nil { return err } diff --git a/discovery/message_store.go b/discovery/message_store.go index f86ede20860..8399dd49f61 100644 --- a/discovery/message_store.go +++ b/discovery/message_store.go @@ -120,7 +120,8 @@ func (s *MessageStore) AddMessage(msg lnwire.Message, peerPubKey [33]byte) error // Serialize the message with its wire encoding. var b bytes.Buffer - if _, err := lnwire.WriteMessage(&b, msg, 0); err != nil { + _, err = lnwire.WriteMessage(&b, msg, lnwire.ProtocolVersionTLV) + if err != nil { return err } @@ -163,7 +164,9 @@ func (s *MessageStore) DeleteMessage(msg lnwire.Message, return nil } - dbMsg, err := lnwire.ReadMessage(bytes.NewReader(v), 0) + dbMsg, err := lnwire.ReadMessage( + bytes.NewReader(v), lnwire.ProtocolVersionTLV, + ) if err != nil { return err } @@ -182,7 +185,9 @@ func (s *MessageStore) DeleteMessage(msg lnwire.Message, // readMessage reads a message from its serialized form and ensures its // supported by the current version of the message store. func readMessage(msgBytes []byte) (lnwire.Message, error) { - msg, err := lnwire.ReadMessage(bytes.NewReader(msgBytes), 0) + msg, err := lnwire.ReadMessage( + bytes.NewReader(msgBytes), lnwire.ProtocolVersionTLV, + ) if err != nil { return nil, err } diff --git a/htlcswitch/payment_result.go b/htlcswitch/payment_result.go index e6a1e59f103..cef1d4e8434 100644 --- a/htlcswitch/payment_result.go +++ b/htlcswitch/payment_result.go @@ -76,7 +76,7 @@ func deserializeNetworkResult(r io.Reader) (*networkResult, error) { n := &networkResult{} - n.msg, err = lnwire.ReadMessage(r, 0) + n.msg, err = lnwire.ReadMessage(r, lnwire.ProtocolVersionTLV) if err != nil { return nil, err } diff --git a/peer/brontide.go b/peer/brontide.go index c50f4ebae0a..85f32ae0ee0 100644 --- a/peer/brontide.go +++ b/peer/brontide.go @@ -919,7 +919,7 @@ func (p *Brontide) readNextMessage() (lnwire.Message, error) { // Next, create a new io.Reader implementation from the raw message, // and use this to decode the message directly from. msgReader := bytes.NewReader(rawMsg) - nextMsg, err := lnwire.ReadMessage(msgReader, 0) + nextMsg, err := lnwire.ReadMessage(msgReader, lnwire.ProtocolVersionTLV) if err != nil { return nil, err } diff --git a/routing/ann_validation.go b/routing/ann_validation.go index cc8530bb165..76d58cb85a8 100644 --- a/routing/ann_validation.go +++ b/routing/ann_validation.go @@ -110,7 +110,10 @@ func ValidateNodeAnn(a *lnwire.NodeAnnouncement) error { dataHash := chainhash.DoubleHashB(data) if !nodeSig.Verify(dataHash, nodeKey) { var msgBuf bytes.Buffer - if _, err := lnwire.WriteMessage(&msgBuf, a, 0); err != nil { + _, err := lnwire.WriteMessage( + &msgBuf, a, lnwire.ProtocolVersionTLV, + ) + if err != nil { return err }