diff --git a/channeldb/channel.go b/channeldb/channel.go index 6221208a8c3..4a035032988 100644 --- a/channeldb/channel.go +++ b/channeldb/channel.go @@ -20,6 +20,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcwallet/walletdb" "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" @@ -121,6 +122,12 @@ var ( // broadcasted when moving the channel to state CoopBroadcasted. coopCloseTxKey = []byte("coop-closing-tx-key") + // shutdownInfoKey points to the serialised shutdown info that has been + // persisted for a channel. The existence of this info means that we + // have sent the Shutdown message before and so should re-initiate the + // shutdown on re-establish. + shutdownInfoKey = []byte("shutdown-info-key") + // commitDiffKey stores the current pending commitment state we've // extended to the remote party (if any). Each time we propose a new // state, we store the information necessary to reconstruct this state @@ -188,6 +195,10 @@ var ( // in the state CommitBroadcasted. ErrNoCloseTx = fmt.Errorf("no closing tx found") + // ErrNoShutdownInfo is returned when no shutdown info has been + // persisted for a channel. + ErrNoShutdownInfo = errors.New("no shutdown info") + // ErrNoRestoredChannelMutation is returned when a caller attempts to // mutate a channel that's been recovered. ErrNoRestoredChannelMutation = fmt.Errorf("cannot mutate restored " + @@ -1575,6 +1586,79 @@ func (c *OpenChannel) ChanSyncMsg() (*lnwire.ChannelReestablish, error) { }, nil } +// MarkShutdownSent serialises and persist the given ShutdownInfo for this +// channel. Persisting this info represents the fact that we have sent the +// Shutdown message to the remote side and hence that we should re-transmit the +// same Shutdown message on re-establish. +func (c *OpenChannel) MarkShutdownSent(info *ShutdownInfo) error { + c.Lock() + defer c.Unlock() + + return c.storeShutdownInfo(info) +} + +// storeShutdownInfo serialises the ShutdownInfo and persists it under the +// shutdownInfoKey. +func (c *OpenChannel) storeShutdownInfo(info *ShutdownInfo) error { + var b bytes.Buffer + err := info.encode(&b) + if err != nil { + return err + } + + return kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error { + chanBucket, err := fetchChanBucketRw( + tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, + ) + if err != nil { + return err + } + + return chanBucket.Put(shutdownInfoKey, b.Bytes()) + }, func() {}) +} + +// ShutdownInfo decodes the shutdown info stored for this channel and returns +// the result. If no shutdown info has been persisted for this channel then the +// ErrNoShutdownInfo error is returned. +func (c *OpenChannel) ShutdownInfo() (fn.Option[ShutdownInfo], error) { + c.RLock() + defer c.RUnlock() + + var shutdownInfo *ShutdownInfo + err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error { + chanBucket, err := fetchChanBucket( + tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, + ) + switch { + case err == nil: + case errors.Is(err, ErrNoChanDBExists), + errors.Is(err, ErrNoActiveChannels), + errors.Is(err, ErrChannelNotFound): + + return ErrNoShutdownInfo + default: + return err + } + + shutdownInfoBytes := chanBucket.Get(shutdownInfoKey) + if shutdownInfoBytes == nil { + return ErrNoShutdownInfo + } + + shutdownInfo, err = decodeShutdownInfo(shutdownInfoBytes) + + return err + }, func() { + shutdownInfo = nil + }) + if err != nil { + return fn.None[ShutdownInfo](), err + } + + return fn.Some[ShutdownInfo](*shutdownInfo), nil +} + // isBorked returns true if the channel has been marked as borked in the // database. This requires an existing database transaction to already be // active. @@ -4294,3 +4378,59 @@ func MakeScidRecord(typ tlv.Type, scid *lnwire.ShortChannelID) tlv.Record { typ, scid, 8, lnwire.EShortChannelID, lnwire.DShortChannelID, ) } + +// ShutdownInfo contains various info about the shutdown initiation of a +// channel. +type ShutdownInfo struct { + // DeliveryScript is the address that we have included in any previous + // Shutdown message for a particular channel and so should include in + // any future re-sends of the Shutdown message. + DeliveryScript tlv.RecordT[tlv.TlvType0, lnwire.DeliveryAddress] + + // LocalInitiator is true if we sent a Shutdown message before ever + // receiving a Shutdown message from the remote peer. + LocalInitiator tlv.RecordT[tlv.TlvType1, bool] +} + +// NewShutdownInfo constructs a new ShutdownInfo object. +func NewShutdownInfo(deliveryScript lnwire.DeliveryAddress, + locallyInitiated bool) *ShutdownInfo { + + return &ShutdownInfo{ + DeliveryScript: tlv.NewRecordT[tlv.TlvType0](deliveryScript), + LocalInitiator: tlv.NewPrimitiveRecord[tlv.TlvType1]( + locallyInitiated, + ), + } +} + +// encode serialises the ShutdownInfo to the given io.Writer. +func (s *ShutdownInfo) encode(w io.Writer) error { + records := []tlv.Record{ + s.DeliveryScript.Record(), + s.LocalInitiator.Record(), + } + + stream, err := tlv.NewStream(records...) + if err != nil { + return err + } + + return stream.Encode(w) +} + +// decodeShutdownInfo constructs a ShutdownInfo struct by decoding the given +// byte slice. +func decodeShutdownInfo(b []byte) (*ShutdownInfo, error) { + tlvStream := lnwire.ExtraOpaqueData(b) + + var info ShutdownInfo + records := []tlv.RecordProducer{ + &info.DeliveryScript, + &info.LocalInitiator, + } + + _, err := tlvStream.ExtractRecords(records...) + + return &info, err +} diff --git a/channeldb/channel_test.go b/channeldb/channel_test.go index bfebad824bc..6047a1e67ed 100644 --- a/channeldb/channel_test.go +++ b/channeldb/channel_test.go @@ -1158,6 +1158,70 @@ func TestFetchWaitingCloseChannels(t *testing.T) { } } +// TestShutdownInfo tests that a channel's shutdown info can correctly be +// persisted and retrieved. +func TestShutdownInfo(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + localInit bool + }{ + { + name: "local node initiated", + localInit: true, + }, + { + name: "remote node initiated", + localInit: false, + }, + } + + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + testShutdownInfo(t, test.localInit) + }) + } +} + +func testShutdownInfo(t *testing.T, locallyInitiated bool) { + fullDB, err := MakeTestDB(t) + require.NoError(t, err, "unable to make test database") + + cdb := fullDB.ChannelStateDB() + + // First a test channel. + channel := createTestChannel(t, cdb) + + // We haven't persisted any shutdown info for this channel yet. + _, err = channel.ShutdownInfo() + require.Error(t, err, ErrNoShutdownInfo) + + // Construct a new delivery script and create a new ShutdownInfo object. + script := []byte{1, 3, 4, 5} + + // Create a ShutdownInfo struct. + shutdownInfo := NewShutdownInfo(script, locallyInitiated) + + // Persist the shutdown info. + require.NoError(t, channel.MarkShutdownSent(shutdownInfo)) + + // We should now be able to retrieve the shutdown info. + info, err := channel.ShutdownInfo() + require.NoError(t, err) + require.True(t, info.IsSome()) + + // Assert that the decoded values of the shutdown info are correct. + info.WhenSome(func(info ShutdownInfo) { + require.EqualValues(t, script, info.DeliveryScript.Val) + require.Equal(t, locallyInitiated, info.LocalInitiator.Val) + }) +} + // TestRefresh asserts that Refresh updates the in-memory state of another // OpenChannel to reflect a preceding call to MarkOpen on a different // OpenChannel. diff --git a/docs/release-notes/release-notes-0.18.0.md b/docs/release-notes/release-notes-0.18.0.md index 5ccc18f3d45..b38ae042c05 100644 --- a/docs/release-notes/release-notes-0.18.0.md +++ b/docs/release-notes/release-notes-0.18.0.md @@ -73,6 +73,11 @@ a `shutdown` message if there were currently HTLCs on the channel. After this change, the shutdown procedure should be compliant with BOLT2 requirements. +* If HTLCs are in-flight at the same time that a `shutdown` is sent and then + a re-connect happens before the coop-close is completed we now [ensure that + we re-init the `shutdown` + exchange](https://github.com/lightningnetwork/lnd/pull/8464) + * The AMP struct in payment hops will [now be populated](https://github.com/lightningnetwork/lnd/pull/7976) when the AMP TLV is set. * [Add Taproot witness types diff --git a/htlcswitch/interfaces.go b/htlcswitch/interfaces.go index 4b3b5d9b415..75866bebadd 100644 --- a/htlcswitch/interfaces.go +++ b/htlcswitch/interfaces.go @@ -135,14 +135,16 @@ type ChannelUpdateHandler interface { MayAddOutgoingHtlc(lnwire.MilliSatoshi) error // EnableAdds sets the ChannelUpdateHandler state to allow - // UpdateAddHtlc's in the specified direction. It returns an error if - // the state already allowed those adds. - EnableAdds(direction LinkDirection) error - - // DiableAdds sets the ChannelUpdateHandler state to allow - // UpdateAddHtlc's in the specified direction. It returns an error if - // the state already disallowed those adds. - DisableAdds(direction LinkDirection) error + // UpdateAddHtlc's in the specified direction. It returns true if the + // state was changed and false if the desired state was already set + // before the method was called. + EnableAdds(direction LinkDirection) bool + + // DisableAdds sets the ChannelUpdateHandler state to allow + // UpdateAddHtlc's in the specified direction. It returns true if the + // state was changed and false if the desired state was already set + // before the method was called. + DisableAdds(direction LinkDirection) bool // IsFlushing returns true when UpdateAddHtlc's are disabled in the // direction of the argument. diff --git a/htlcswitch/link.go b/htlcswitch/link.go index d8d947952c2..a700c9714b6 100644 --- a/htlcswitch/link.go +++ b/htlcswitch/link.go @@ -19,6 +19,7 @@ import ( "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/contractcourt" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/htlcswitch/hodl" "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/invoices" @@ -271,6 +272,14 @@ type ChannelLinkConfig struct { // GetAliases is used by the link and switch to fetch the set of // aliases for a given link. GetAliases func(base lnwire.ShortChannelID) []lnwire.ShortChannelID + + // PreviouslySentShutdown is an optional value that is set if, at the + // time of the link being started, persisted shutdown info was found for + // the channel. This value being set means that we previously sent a + // Shutdown message to our peer, and so we should do so again on + // re-establish and should not allow anymore HTLC adds on the outgoing + // direction of the link. + PreviouslySentShutdown fn.Option[lnwire.Shutdown] } // channelLink is the service which drives a channel's commitment update @@ -618,41 +627,25 @@ func (l *channelLink) EligibleToUpdate() bool { } // EnableAdds sets the ChannelUpdateHandler state to allow UpdateAddHtlc's in -// the specified direction. It returns an error if the state already allowed -// those adds. -func (l *channelLink) EnableAdds(linkDirection LinkDirection) error { +// the specified direction. It returns true if the state was changed and false +// if the desired state was already set before the method was called. +func (l *channelLink) EnableAdds(linkDirection LinkDirection) bool { if linkDirection == Outgoing { - if !l.isOutgoingAddBlocked.Swap(false) { - return errors.New("outgoing adds already enabled") - } - } - - if linkDirection == Incoming { - if !l.isIncomingAddBlocked.Swap(false) { - return errors.New("incoming adds already enabled") - } + return l.isOutgoingAddBlocked.Swap(false) } - return nil + return l.isIncomingAddBlocked.Swap(false) } -// DiableAdds sets the ChannelUpdateHandler state to allow UpdateAddHtlc's in -// the specified direction. It returns an error if the state already disallowed -// those adds. -func (l *channelLink) DisableAdds(linkDirection LinkDirection) error { +// DisableAdds sets the ChannelUpdateHandler state to allow UpdateAddHtlc's in +// the specified direction. It returns true if the state was changed and false +// if the desired state was already set before the method was called. +func (l *channelLink) DisableAdds(linkDirection LinkDirection) bool { if linkDirection == Outgoing { - if l.isOutgoingAddBlocked.Swap(true) { - return errors.New("outgoing adds already disabled") - } + return !l.isOutgoingAddBlocked.Swap(true) } - if linkDirection == Incoming { - if l.isIncomingAddBlocked.Swap(true) { - return errors.New("incoming adds already disabled") - } - } - - return nil + return !l.isIncomingAddBlocked.Swap(true) } // IsFlushing returns true when UpdateAddHtlc's are disabled in the direction of @@ -1206,6 +1199,25 @@ func (l *channelLink) htlcManager() { } } + // If a shutdown message has previously been sent on this link, then we + // need to make sure that we have disabled any HTLC adds on the outgoing + // direction of the link and that we re-resend the same shutdown message + // that we previously sent. + l.cfg.PreviouslySentShutdown.WhenSome(func(shutdown lnwire.Shutdown) { + // Immediately disallow any new outgoing HTLCs. + if !l.DisableAdds(Outgoing) { + l.log.Warnf("Outgoing link adds already disabled") + } + + // Re-send the shutdown message the peer. Since syncChanStates + // would have sent any outstanding CommitSig, it is fine for us + // to immediately queue the shutdown message now. + err := l.cfg.Peer.SendMessage(false, &shutdown) + if err != nil { + l.log.Warnf("Error sending shutdown message: %v", err) + } + }) + // We've successfully reestablished the channel, mark it as such to // allow the switch to forward HTLCs in the outbound direction. l.markReestablished() diff --git a/htlcswitch/link_test.go b/htlcswitch/link_test.go index cac26a6de86..d9e583876b3 100644 --- a/htlcswitch/link_test.go +++ b/htlcswitch/link_test.go @@ -6969,27 +6969,22 @@ func TestLinkFlushApiDirectionIsolation(t *testing.T) { for i := 0; i < 10; i++ { if prand.Uint64()%2 == 0 { - //nolint:errcheck aliceLink.EnableAdds(Outgoing) require.False(t, aliceLink.IsFlushing(Outgoing)) } else { - //nolint:errcheck aliceLink.DisableAdds(Outgoing) require.True(t, aliceLink.IsFlushing(Outgoing)) } require.False(t, aliceLink.IsFlushing(Incoming)) } - //nolint:errcheck aliceLink.EnableAdds(Outgoing) for i := 0; i < 10; i++ { if prand.Uint64()%2 == 0 { - //nolint:errcheck aliceLink.EnableAdds(Incoming) require.False(t, aliceLink.IsFlushing(Incoming)) } else { - //nolint:errcheck aliceLink.DisableAdds(Incoming) require.True(t, aliceLink.IsFlushing(Incoming)) } @@ -7010,16 +7005,16 @@ func TestLinkFlushApiGateStateIdempotence(t *testing.T) { ) for _, dir := range []LinkDirection{Incoming, Outgoing} { - require.Nil(t, aliceLink.DisableAdds(dir)) + require.True(t, aliceLink.DisableAdds(dir)) require.True(t, aliceLink.IsFlushing(dir)) - require.NotNil(t, aliceLink.DisableAdds(dir)) + require.False(t, aliceLink.DisableAdds(dir)) require.True(t, aliceLink.IsFlushing(dir)) - require.Nil(t, aliceLink.EnableAdds(dir)) + require.True(t, aliceLink.EnableAdds(dir)) require.False(t, aliceLink.IsFlushing(dir)) - require.NotNil(t, aliceLink.EnableAdds(dir)) + require.False(t, aliceLink.EnableAdds(dir)) require.False(t, aliceLink.IsFlushing(dir)) } } diff --git a/htlcswitch/mock.go b/htlcswitch/mock.go index 5c7722a54a2..ab6fbe76af6 100644 --- a/htlcswitch/mock.go +++ b/htlcswitch/mock.go @@ -906,13 +906,14 @@ func (f *mockChannelLink) UpdateShortChanID() (lnwire.ShortChannelID, error) { return f.shortChanID, nil } -func (f *mockChannelLink) EnableAdds(linkDirection LinkDirection) error { +func (f *mockChannelLink) EnableAdds(linkDirection LinkDirection) bool { // TODO(proofofkeags): Implement - return nil + return true } -func (f *mockChannelLink) DisableAdds(linkDirection LinkDirection) error { + +func (f *mockChannelLink) DisableAdds(linkDirection LinkDirection) bool { // TODO(proofofkeags): Implement - return nil + return true } func (f *mockChannelLink) IsFlushing(linkDirection LinkDirection) bool { // TODO(proofofkeags): Implement diff --git a/itest/lnd_coop_close_with_htlcs_test.go b/itest/lnd_coop_close_with_htlcs_test.go index 3978e119268..4c437cd32cc 100644 --- a/itest/lnd_coop_close_with_htlcs_test.go +++ b/itest/lnd_coop_close_with_htlcs_test.go @@ -1,24 +1,44 @@ package itest import ( + "testing" + "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lnrpc/invoicesrpc" "github.com/lightningnetwork/lnd/lnrpc/routerrpc" + "github.com/lightningnetwork/lnd/lnrpc/walletrpc" "github.com/lightningnetwork/lnd/lntest" + "github.com/lightningnetwork/lnd/lntest/wait" "github.com/lightningnetwork/lnd/lntypes" "github.com/stretchr/testify/require" ) -// testCoopCloseWithHtlcs tests whether or not we can successfully issue a coop -// close request whilt there are still active htlcs on the link. Here we will -// set up an HODL invoice to suspend settlement. Then we will attempt to close -// the channel which should appear as a noop for the time being. Then we will -// have the receiver settle the invoice and observe that the channel gets torn -// down after settlement. +// testCoopCloseWithHtlcs tests whether we can successfully issue a coop close +// request while there are still active htlcs on the link. In all the tests, we +// will set up an HODL invoice to suspend settlement. Then we will attempt to +// close the channel which should appear as a noop for the time being. Then we +// will have the receiver settle the invoice and observe that the channel gets +// torn down after settlement. func testCoopCloseWithHtlcs(ht *lntest.HarnessTest) { + ht.Run("no restart", func(t *testing.T) { + tt := ht.Subtest(t) + coopCloseWithHTLCs(tt) + }) + + ht.Run("with restart", func(t *testing.T) { + tt := ht.Subtest(t) + coopCloseWithHTLCsWithRestart(tt) + }) +} + +// coopCloseWithHTLCs tests the basic coop close scenario which occurs when one +// channel party initiates a channel shutdown while an HTLC is still pending on +// the channel. +func coopCloseWithHTLCs(ht *lntest.HarnessTest) { alice, bob := ht.Alice, ht.Bob + ht.ConnectNodes(alice, bob) // Here we set up a channel between Alice and Bob, beginning with a // balance on Bob's side. @@ -101,3 +121,123 @@ func testCoopCloseWithHtlcs(ht *lntest.HarnessTest) { // Wait for it to get mined and finish tearing down. ht.AssertStreamChannelCoopClosed(alice, chanPoint, false, closeClient) } + +// coopCloseWithHTLCsWithRestart also tests the coop close flow when an HTLC +// is still pending on the channel but this time it ensures that the shutdown +// process continues as expected even if a channel re-establish happens after +// one party has already initiated the shutdown. +func coopCloseWithHTLCsWithRestart(ht *lntest.HarnessTest) { + alice, bob := ht.Alice, ht.Bob + ht.ConnectNodes(alice, bob) + + // Open a channel between Alice and Bob with the balance split equally. + // We do this to ensure that the close transaction will have 2 outputs + // so that we can assert that the correct delivery address gets used by + // the channel close initiator. + chanPoint := ht.OpenChannel(bob, alice, lntest.OpenChannelParams{ + Amt: btcutil.Amount(1000000), + PushAmt: btcutil.Amount(1000000 / 2), + }) + + // Wait for Bob to understand that the channel is ready to use. + ht.AssertTopologyChannelOpen(bob, chanPoint) + + // Set up a HODL invoice so that we can be sure that an HTLC is pending + // on the channel at the time that shutdown is requested. + var preimage lntypes.Preimage + copy(preimage[:], ht.Random32Bytes()) + payHash := preimage.Hash() + + invoiceReq := &invoicesrpc.AddHoldInvoiceRequest{ + Memo: "testing close", + Value: 400, + Hash: payHash[:], + } + resp := alice.RPC.AddHoldInvoice(invoiceReq) + invoiceStream := alice.RPC.SubscribeSingleInvoice(payHash[:]) + + // Wait for the invoice to be ready and payable. + ht.AssertInvoiceState(invoiceStream, lnrpc.Invoice_OPEN) + + // Now that the invoice is ready to be paid, let's have Bob open an HTLC + // for it. + req := &routerrpc.SendPaymentRequest{ + PaymentRequest: resp.PaymentRequest, + TimeoutSeconds: 60, + FeeLimitSat: 1000000, + } + ht.SendPaymentAndAssertStatus(bob, req, lnrpc.Payment_IN_FLIGHT) + ht.AssertNumActiveHtlcs(bob, 1) + + // Assert at this point that the HTLC is open but not yet settled. + ht.AssertInvoiceState(invoiceStream, lnrpc.Invoice_ACCEPTED) + + // We will now let Alice initiate the closure of the channel. We will + // also let her specify a specific delivery address to be used since we + // want to test that this same address is used in the Shutdown message + // on reconnection. + newAddr := alice.RPC.NewAddress(&lnrpc.NewAddressRequest{ + Type: AddrTypeWitnessPubkeyHash, + }) + + _ = alice.RPC.CloseChannel(&lnrpc.CloseChannelRequest{ + ChannelPoint: chanPoint, + NoWait: true, + DeliveryAddress: newAddr.Address, + }) + + // Assert that both nodes see the channel as waiting for close. + ht.AssertChannelInactive(bob, chanPoint) + ht.AssertChannelInactive(alice, chanPoint) + + // Now restart Alice and Bob. + ht.RestartNode(alice) + ht.RestartNode(bob) + + ht.AssertConnected(alice, bob) + + // Show that both nodes still see the channel as waiting for close after + // the restart. + ht.AssertChannelInactive(bob, chanPoint) + ht.AssertChannelInactive(alice, chanPoint) + + // Settle the invoice. + alice.RPC.SettleInvoice(preimage[:]) + + // Wait for the channel to appear in the waiting closed list. + err := wait.Predicate(func() bool { + pendingChansResp := alice.RPC.PendingChannels() + waitingClosed := pendingChansResp.WaitingCloseChannels + + return len(waitingClosed) == 1 + }, defaultTimeout) + require.NoError(ht, err) + + // Wait for the close tx to be in the Mempool and then mine 6 blocks + // to confirm the close. + closingTx := ht.AssertClosingTxInMempool( + chanPoint, lnrpc.CommitmentType_LEGACY, + ) + ht.MineBlocksAndAssertNumTxes(6, 1) + + // Finally, we inspect the closing transaction here to show that the + // delivery address that Alice specified in her original close request + // is the one that ended up being used in the final closing transaction. + tx := alice.RPC.GetTransaction(&walletrpc.GetTransactionRequest{ + Txid: closingTx.TxHash().String(), + }) + require.Len(ht, tx.OutputDetails, 2) + + // Find Alice's output in the coop-close transaction. + var outputDetail *lnrpc.OutputDetail + for _, output := range tx.OutputDetails { + if output.IsOurAddress { + outputDetail = output + break + } + } + require.NotNil(ht, outputDetail) + + // Show that the address used is the one she requested. + require.Equal(ht, outputDetail.Address, newAddr.Address) +} diff --git a/lnwallet/chancloser/chancloser.go b/lnwallet/chancloser/chancloser.go index d33bfc4f0b9..97f5f4a63a8 100644 --- a/lnwallet/chancloser/chancloser.go +++ b/lnwallet/chancloser/chancloser.go @@ -21,13 +21,13 @@ import ( ) var ( - // ErrChanAlreadyClosing is returned when a channel shutdown is attempted - // more than once. + // ErrChanAlreadyClosing is returned when a channel shutdown is + // attempted more than once. ErrChanAlreadyClosing = fmt.Errorf("channel shutdown already initiated") // ErrChanCloseNotFinished is returned when a caller attempts to access - // a field or function that is contingent on the channel closure negotiation - // already being completed. + // a field or function that is contingent on the channel closure + // negotiation already being completed. ErrChanCloseNotFinished = fmt.Errorf("close negotiation not finished") // ErrInvalidState is returned when the closing state machine receives a @@ -79,16 +79,16 @@ const ( // closeFeeNegotiation is the third, and most persistent state. Both // parties enter this state after they've sent and received a shutdown // message. During this phase, both sides will send monotonically - // increasing fee requests until one side accepts the last fee rate offered - // by the other party. In this case, the party will broadcast the closing - // transaction, and send the accepted fee to the remote party. This then - // causes a shift into the closeFinished state. + // increasing fee requests until one side accepts the last fee rate + // offered by the other party. In this case, the party will broadcast + // the closing transaction, and send the accepted fee to the remote + // party. This then causes a shift into the closeFinished state. closeFeeNegotiation - // closeFinished is the final state of the state machine. In this state, a - // side has accepted a fee offer and has broadcast the valid closing - // transaction to the network. During this phase, the closing transaction - // becomes available for examination. + // closeFinished is the final state of the state machine. In this state, + // a side has accepted a fee offer and has broadcast the valid closing + // transaction to the network. During this phase, the closing + // transaction becomes available for examination. closeFinished ) @@ -156,8 +156,9 @@ type ChanCloser struct { // negotiationHeight is the height that the fee negotiation begun at. negotiationHeight uint32 - // closingTx is the final, fully signed closing transaction. This will only - // be populated once the state machine shifts to the closeFinished state. + // closingTx is the final, fully signed closing transaction. This will + // only be populated once the state machine shifts to the closeFinished + // state. closingTx *wire.MsgTx // idealFeeSat is the ideal fee that the state machine should initially @@ -173,22 +174,22 @@ type ChanCloser struct { idealFeeRate chainfee.SatPerKWeight // lastFeeProposal is the last fee that we proposed to the remote party. - // We'll use this as a pivot point to ratchet our next offer up, down, or - // simply accept the remote party's prior offer. + // We'll use this as a pivot point to ratchet our next offer up, down, + // or simply accept the remote party's prior offer. lastFeeProposal btcutil.Amount - // priorFeeOffers is a map that keeps track of all the proposed fees that - // we've offered during the fee negotiation. We use this map to cut the - // negotiation early if the remote party ever sends an offer that we've - // sent in the past. Once negotiation terminates, we can extract the prior - // signature of our accepted offer from this map. + // priorFeeOffers is a map that keeps track of all the proposed fees + // that we've offered during the fee negotiation. We use this map to cut + // the negotiation early if the remote party ever sends an offer that + // we've sent in the past. Once negotiation terminates, we can extract + // the prior signature of our accepted offer from this map. // // TODO(roasbeef): need to ensure if they broadcast w/ any of our prior // sigs, we are aware of priorFeeOffers map[btcutil.Amount]*lnwire.ClosingSigned - // closeReq is the initial closing request. This will only be populated if - // we're the initiator of this closing negotiation. + // closeReq is the initial closing request. This will only be populated + // if we're the initiator of this closing negotiation. // // TODO(roasbeef): abstract away closeReq *htlcswitch.ChanClose @@ -273,8 +274,10 @@ func NewChanCloser(cfg ChanCloseCfg, deliveryScript []byte, negotiationHeight: negotiationHeight, idealFeeRate: idealFeePerKw, localDeliveryScript: deliveryScript, - priorFeeOffers: make(map[btcutil.Amount]*lnwire.ClosingSigned), - locallyInitiated: locallyInitiated, + priorFeeOffers: make( + map[btcutil.Amount]*lnwire.ClosingSigned, + ), + locallyInitiated: locallyInitiated, } } @@ -321,9 +324,9 @@ func (c *ChanCloser) initFeeBaseline() { // initChanShutdown begins the shutdown process by un-registering the channel, // and creating a valid shutdown message to our target delivery address. func (c *ChanCloser) initChanShutdown() (*lnwire.Shutdown, error) { - // With both items constructed we'll now send the shutdown message for this - // particular channel, advertising a shutdown request to our desired - // closing script. + // With both items constructed we'll now send the shutdown message for + // this particular channel, advertising a shutdown request to our + // desired closing script. shutdown := lnwire.NewShutdown(c.cid, c.localDeliveryScript) // If this is a taproot channel, then we'll need to also generate a @@ -353,6 +356,17 @@ func (c *ChanCloser) initChanShutdown() (*lnwire.Shutdown, error) { chancloserLog.Infof("ChannelPoint(%v): sending shutdown message", c.chanPoint) + // At this point, we persist any relevant info regarding the Shutdown + // message we are about to send in order to ensure that if a + // re-establish occurs then we will re-send the same Shutdown message. + shutdownInfo := channeldb.NewShutdownInfo( + c.localDeliveryScript, c.locallyInitiated, + ) + err := c.cfg.Channel.MarkShutdownSent(shutdownInfo) + if err != nil { + return nil, err + } + return shutdown, nil } @@ -375,12 +389,12 @@ func (c *ChanCloser) ShutdownChan() (*lnwire.Shutdown, error) { } // With the opening steps complete, we'll transition into the - // closeShutdownInitiated state. In this state, we'll wait until the other - // party sends their version of the shutdown message. + // closeShutdownInitiated state. In this state, we'll wait until the + // other party sends their version of the shutdown message. c.state = closeShutdownInitiated - // Finally, we'll return the shutdown message to the caller so it can send - // it to the remote peer. + // Finally, we'll return the shutdown message to the caller so it can + // send it to the remote peer. return shutdownMsg, nil } @@ -476,9 +490,8 @@ func validateShutdownScript(disconnect func() error, upfrontScript, // If appropriate, it will also generate a Shutdown message of its own to send // out to the peer. It is possible for this method to return None when no error // occurred. -func (c *ChanCloser) ReceiveShutdown( - msg lnwire.Shutdown, -) (fn.Option[lnwire.Shutdown], error) { +func (c *ChanCloser) ReceiveShutdown(msg lnwire.Shutdown) ( + fn.Option[lnwire.Shutdown], error) { noShutdown := fn.None[lnwire.Shutdown]() @@ -610,9 +623,8 @@ func (c *ChanCloser) ReceiveShutdown( // it will not. In either case it will transition the ChanCloser state machine // to the negotiation phase wherein ClosingSigned messages are exchanged until // a mutually agreeable result is achieved. -func (c *ChanCloser) BeginNegotiation() ( - fn.Option[lnwire.ClosingSigned], error, -) { +func (c *ChanCloser) BeginNegotiation() (fn.Option[lnwire.ClosingSigned], + error) { noClosingSigned := fn.None[lnwire.ClosingSigned]() @@ -673,11 +685,8 @@ func (c *ChanCloser) BeginNegotiation() ( // ReceiveClosingSigned is a method that should be called whenever we receive a // ClosingSigned message from the wire. It may or may not return a ClosingSigned // of our own to send back to the remote. -// -//nolint:funlen -func (c *ChanCloser) ReceiveClosingSigned( - msg lnwire.ClosingSigned, -) (fn.Option[lnwire.ClosingSigned], error) { +func (c *ChanCloser) ReceiveClosingSigned(msg lnwire.ClosingSigned) ( + fn.Option[lnwire.ClosingSigned], error) { noClosing := fn.None[lnwire.ClosingSigned]() @@ -882,7 +891,9 @@ func (c *ChanCloser) ReceiveClosingSigned( // proposeCloseSigned attempts to propose a new signature for the closing // transaction for a channel based on the prior fee negotiations and our current // compromise fee. -func (c *ChanCloser) proposeCloseSigned(fee btcutil.Amount) (*lnwire.ClosingSigned, error) { +func (c *ChanCloser) proposeCloseSigned(fee btcutil.Amount) ( + *lnwire.ClosingSigned, error) { + var ( closeOpts []lnwallet.ChanCloseOpt err error @@ -956,8 +967,8 @@ func (c *ChanCloser) proposeCloseSigned(fee btcutil.Amount) (*lnwire.ClosingSign // compromise and to ensure that the fee negotiation has a stopping point. We // consider their fee acceptable if it's within 30% of our fee. func feeInAcceptableRange(localFee, remoteFee btcutil.Amount) bool { - // If our offer is lower than theirs, then we'll accept their offer if it's - // no more than 30% *greater* than our current offer. + // If our offer is lower than theirs, then we'll accept their offer if + // it's no more than 30% *greater* than our current offer. if localFee < remoteFee { acceptableRange := localFee + ((localFee * 3) / 10) return remoteFee <= acceptableRange @@ -991,51 +1002,59 @@ func calcCompromiseFee(chanPoint wire.OutPoint, ourIdealFee, lastSentFee, // TODO(roasbeef): take in number of rounds as well? - chancloserLog.Infof("ChannelPoint(%v): computing fee compromise, ideal="+ - "%v, last_sent=%v, remote_offer=%v", chanPoint, int64(ourIdealFee), - int64(lastSentFee), int64(remoteFee)) + chancloserLog.Infof("ChannelPoint(%v): computing fee compromise, "+ + "ideal=%v, last_sent=%v, remote_offer=%v", chanPoint, + int64(ourIdealFee), int64(lastSentFee), int64(remoteFee)) - // Otherwise, we'll need to attempt to make a fee compromise if this is the - // second round, and neither side has agreed on fees. + // Otherwise, we'll need to attempt to make a fee compromise if this is + // the second round, and neither side has agreed on fees. switch { - // If their proposed fee is identical to our ideal fee, then we'll go with - // that as we can short circuit the fee negotiation. Similarly, if we - // haven't sent an offer yet, we'll default to our ideal fee. + // If their proposed fee is identical to our ideal fee, then we'll go + // with that as we can short circuit the fee negotiation. Similarly, if + // we haven't sent an offer yet, we'll default to our ideal fee. case ourIdealFee == remoteFee || lastSentFee == 0: return ourIdealFee // If the last fee we sent, is equal to the fee the remote party is - // offering, then we can simply return this fee as the negotiation is over. + // offering, then we can simply return this fee as the negotiation is + // over. case remoteFee == lastSentFee: return lastSentFee // If the fee the remote party is offering is less than the last one we - // sent, then we'll need to ratchet down in order to move our offer closer - // to theirs. + // sent, then we'll need to ratchet down in order to move our offer + // closer to theirs. case remoteFee < lastSentFee: - // If the fee is lower, but still acceptable, then we'll just return - // this fee and end the negotiation. + // If the fee is lower, but still acceptable, then we'll just + // return this fee and end the negotiation. if feeInAcceptableRange(lastSentFee, remoteFee) { - chancloserLog.Infof("ChannelPoint(%v): proposed remote fee is "+ - "close enough, capitulating", chanPoint) + chancloserLog.Infof("ChannelPoint(%v): proposed "+ + "remote fee is close enough, capitulating", + chanPoint) + return remoteFee } - // Otherwise, we'll ratchet the fee *down* using our current algorithm. + // Otherwise, we'll ratchet the fee *down* using our current + // algorithm. return ratchetFee(lastSentFee, false) - // If the fee the remote party is offering is greater than the last one we - // sent, then we'll ratchet up in order to ensure we terminate eventually. + // If the fee the remote party is offering is greater than the last one + // we sent, then we'll ratchet up in order to ensure we terminate + // eventually. case remoteFee > lastSentFee: - // If the fee is greater, but still acceptable, then we'll just return - // this fee in order to put an end to the negotiation. + // If the fee is greater, but still acceptable, then we'll just + // return this fee in order to put an end to the negotiation. if feeInAcceptableRange(lastSentFee, remoteFee) { - chancloserLog.Infof("ChannelPoint(%v): proposed remote fee is "+ - "close enough, capitulating", chanPoint) + chancloserLog.Infof("ChannelPoint(%v): proposed "+ + "remote fee is close enough, capitulating", + chanPoint) + return remoteFee } - // Otherwise, we'll ratchet the fee up using our current algorithm. + // Otherwise, we'll ratchet the fee up using our current + // algorithm. return ratchetFee(lastSentFee, true) default: diff --git a/lnwallet/chancloser/chancloser_test.go b/lnwallet/chancloser/chancloser_test.go index 807668a1956..53c0fb6baf2 100644 --- a/lnwallet/chancloser/chancloser_test.go +++ b/lnwallet/chancloser/chancloser_test.go @@ -154,6 +154,10 @@ func (m *mockChannel) MarkCoopBroadcasted(*wire.MsgTx, bool) error { return nil } +func (m *mockChannel) MarkShutdownSent(*channeldb.ShutdownInfo) error { + return nil +} + func (m *mockChannel) IsInitiator() bool { return m.initiator } diff --git a/lnwallet/chancloser/interface.go b/lnwallet/chancloser/interface.go index 9d588d521ad..4daf5ac34ec 100644 --- a/lnwallet/chancloser/interface.go +++ b/lnwallet/chancloser/interface.go @@ -35,6 +35,11 @@ type Channel interface { //nolint:interfacebloat // transaction has been broadcast. MarkCoopBroadcasted(*wire.MsgTx, bool) error + // MarkShutdownSent persists the given ShutdownInfo. The existence of + // the ShutdownInfo represents the fact that the Shutdown message has + // been sent by us and so should be re-sent on re-establish. + MarkShutdownSent(info *channeldb.ShutdownInfo) error + // IsInitiator returns true we are the initiator of the channel. IsInitiator() bool diff --git a/lnwallet/channel.go b/lnwallet/channel.go index 028b06e46e3..3871ec780e4 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -8823,6 +8823,18 @@ func (lc *LightningChannel) MarkCoopBroadcasted(tx *wire.MsgTx, return lc.channelState.MarkCoopBroadcasted(tx, localInitiated) } +// MarkShutdownSent persists the given ShutdownInfo. The existence of the +// ShutdownInfo represents the fact that the Shutdown message has been sent by +// us and so should be re-sent on re-establish. +func (lc *LightningChannel) MarkShutdownSent( + info *channeldb.ShutdownInfo) error { + + lc.Lock() + defer lc.Unlock() + + return lc.channelState.MarkShutdownSent(info) +} + // MarkDataLoss marks sets the channel status to LocalDataLoss and stores the // passed commitPoint for use to retrieve funds in case the remote force closes // the channel. diff --git a/peer/brontide.go b/peer/brontide.go index 4d633479c74..6692ee3c14d 100644 --- a/peer/brontide.go +++ b/peer/brontide.go @@ -27,6 +27,7 @@ import ( "github.com/lightningnetwork/lnd/contractcourt" "github.com/lightningnetwork/lnd/discovery" "github.com/lightningnetwork/lnd/feature" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/funding" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/htlcswitch/hodl" @@ -975,17 +976,70 @@ func (p *Brontide) loadActiveChannels(chans []*channeldb.OpenChannel) ( spew.Sdump(forwardingPolicy)) // If the channel is pending, set the value to nil in the - // activeChannels map. This is done to signify that the channel is - // pending. We don't add the link to the switch here - it's the funding - // manager's responsibility to spin up pending channels. Adding them - // here would just be extra work as we'll tear them down when creating - // + adding the final link. + // activeChannels map. This is done to signify that the channel + // is pending. We don't add the link to the switch here - it's + // the funding manager's responsibility to spin up pending + // channels. Adding them here would just be extra work as we'll + // tear them down when creating + adding the final link. if lnChan.IsPending() { p.activeChannels.Store(chanID, nil) continue } + shutdownInfo, err := lnChan.State().ShutdownInfo() + if err != nil && !errors.Is(err, channeldb.ErrNoShutdownInfo) { + return nil, err + } + + var ( + shutdownMsg fn.Option[lnwire.Shutdown] + shutdownInfoErr error + ) + shutdownInfo.WhenSome(func(info channeldb.ShutdownInfo) { + // Compute an ideal fee. + feePerKw, err := p.cfg.FeeEstimator.EstimateFeePerKW( + p.cfg.CoopCloseTargetConfs, + ) + if err != nil { + shutdownInfoErr = fmt.Errorf("unable to "+ + "estimate fee: %w", err) + + return + } + + chanCloser, err := p.createChanCloser( + lnChan, info.DeliveryScript.Val, feePerKw, nil, + info.LocalInitiator.Val, + ) + if err != nil { + shutdownInfoErr = fmt.Errorf("unable to "+ + "create chan closer: %w", err) + + return + } + + chanID := lnwire.NewChanIDFromOutPoint( + &lnChan.State().FundingOutpoint, + ) + + p.activeChanCloses[chanID] = chanCloser + + // Create the Shutdown message. + shutdown, err := chanCloser.ShutdownChan() + if err != nil { + delete(p.activeChanCloses, chanID) + shutdownInfoErr = err + + return + } + + shutdownMsg = fn.Some[lnwire.Shutdown](*shutdown) + }) + if shutdownInfoErr != nil { + return nil, shutdownInfoErr + } + // Subscribe to the set of on-chain events for this channel. chainEvents, err := p.cfg.ChainArb.SubscribeChannelEvents( *chanPoint, @@ -996,7 +1050,7 @@ func (p *Brontide) loadActiveChannels(chans []*channeldb.OpenChannel) ( err = p.addLink( chanPoint, lnChan, forwardingPolicy, chainEvents, - true, + true, shutdownMsg, ) if err != nil { return nil, fmt.Errorf("unable to add link %v to "+ @@ -1014,7 +1068,7 @@ func (p *Brontide) addLink(chanPoint *wire.OutPoint, lnChan *lnwallet.LightningChannel, forwardingPolicy *models.ForwardingPolicy, chainEvents *contractcourt.ChainEventSubscription, - syncStates bool) error { + syncStates bool, shutdownMsg fn.Option[lnwire.Shutdown]) error { // onChannelFailure will be called by the link in case the channel // fails for some reason. @@ -1083,6 +1137,7 @@ func (p *Brontide) addLink(chanPoint *wire.OutPoint, NotifyInactiveLinkEvent: p.cfg.ChannelNotifier.NotifyInactiveLinkEvent, HtlcNotifier: p.cfg.HtlcNotifier, GetAliases: p.cfg.GetAliases, + PreviouslySentShutdown: shutdownMsg, } // Before adding our new link, purge the switch of any pending or live @@ -2802,15 +2857,32 @@ func (p *Brontide) restartCoopClose(lnChan *lnwallet.LightningChannel) ( return nil, nil } - // As mentioned above, we don't re-create the delivery script. - deliveryScript := c.LocalShutdownScript - if len(deliveryScript) == 0 { - var err error - deliveryScript, err = p.genDeliveryScript() - if err != nil { - p.log.Errorf("unable to gen delivery script: %v", - err) - return nil, fmt.Errorf("close addr unavailable") + var deliveryScript []byte + + shutdownInfo, err := c.ShutdownInfo() + switch { + // We have previously stored the delivery script that we need to use + // in the shutdown message. Re-use this script. + case err == nil: + shutdownInfo.WhenSome(func(info channeldb.ShutdownInfo) { + deliveryScript = info.DeliveryScript.Val + }) + + // An error other than ErrNoShutdownInfo was returned + case err != nil && !errors.Is(err, channeldb.ErrNoShutdownInfo): + return nil, err + + case errors.Is(err, channeldb.ErrNoShutdownInfo): + deliveryScript = c.LocalShutdownScript + if len(deliveryScript) == 0 { + var err error + deliveryScript, err = p.genDeliveryScript() + if err != nil { + p.log.Errorf("unable to gen delivery script: "+ + "%v", err) + + return nil, fmt.Errorf("close addr unavailable") + } } } @@ -2990,13 +3062,12 @@ func (p *Brontide) handleLocalCloseReq(req *htlcswitch.ChanClose) { return } - link.OnCommitOnce(htlcswitch.Outgoing, func() { - err := link.DisableAdds(htlcswitch.Outgoing) - if err != nil { - p.log.Warnf("outgoing link adds already "+ - "disabled: %v", link.ChanID()) - } + if !link.DisableAdds(htlcswitch.Outgoing) { + p.log.Warnf("Outgoing link adds already "+ + "disabled: %v", link.ChanID()) + } + link.OnCommitOnce(htlcswitch.Outgoing, func() { p.queueMsg(shutdownMsg, nil) }) @@ -3619,12 +3690,9 @@ func (p *Brontide) handleCloseMsg(msg *closeMsg) { switch typed := msg.msg.(type) { case *lnwire.Shutdown: // Disable incoming adds immediately. - if link != nil { - err := link.DisableAdds(htlcswitch.Incoming) - if err != nil { - p.log.Warnf("incoming link adds already "+ - "disabled: %v", link.ChanID()) - } + if link != nil && !link.DisableAdds(htlcswitch.Incoming) { + p.log.Warnf("Incoming link adds already disabled: %v", + link.ChanID()) } oShutdown, err := chanCloser.ReceiveShutdown(*typed) @@ -3634,7 +3702,7 @@ func (p *Brontide) handleCloseMsg(msg *closeMsg) { } oShutdown.WhenSome(func(msg lnwire.Shutdown) { - // if the link is nil it means we can immediately queue + // If the link is nil it means we can immediately queue // the Shutdown message since we don't have to wait for // commitment transaction synchronization. if link == nil { @@ -3642,14 +3710,17 @@ func (p *Brontide) handleCloseMsg(msg *closeMsg) { return } + // Immediately disallow any new HTLC's from being added + // in the outgoing direction. + if !link.DisableAdds(htlcswitch.Outgoing) { + p.log.Warnf("Outgoing link adds already "+ + "disabled: %v", link.ChanID()) + } + // When we have a Shutdown to send, we defer it till the // next time we send a CommitSig to remain spec // compliant. link.OnCommitOnce(htlcswitch.Outgoing, func() { - err := link.DisableAdds(htlcswitch.Outgoing) - if err != nil { - p.log.Warn(err.Error()) - } p.queueMsg(&msg, nil) }) }) @@ -3906,7 +3977,7 @@ func (p *Brontide) addActiveChannel(c *lnpeer.NewChannel) error { // Create the link and add it to the switch. err = p.addLink( chanPoint, lnChan, initialPolicy, chainEvents, - shouldReestablish, + shouldReestablish, fn.None[lnwire.Shutdown](), ) if err != nil { return fmt.Errorf("can't register new channel link(%v) with "+ diff --git a/peer/test_utils.go b/peer/test_utils.go index e4b5d6086df..05bfe6ad4cc 100644 --- a/peer/test_utils.go +++ b/peer/test_utils.go @@ -4,7 +4,6 @@ import ( "bytes" crand "crypto/rand" "encoding/binary" - "fmt" "io" "math/rand" "net" @@ -510,34 +509,20 @@ type mockMessageConn struct { readRaceDetectingCounter int } -func (m *mockUpdateHandler) EnableAdds(dir htlcswitch.LinkDirection) error { - switch dir { - case htlcswitch.Outgoing: - if !m.isOutgoingAddBlocked.Swap(false) { - return fmt.Errorf("%v adds already enabled", dir) - } - case htlcswitch.Incoming: - if !m.isIncomingAddBlocked.Swap(false) { - return fmt.Errorf("%v adds already enabled", dir) - } +func (m *mockUpdateHandler) EnableAdds(dir htlcswitch.LinkDirection) bool { + if dir == htlcswitch.Outgoing { + return m.isOutgoingAddBlocked.Swap(false) } - return nil + return m.isIncomingAddBlocked.Swap(false) } -func (m *mockUpdateHandler) DisableAdds(dir htlcswitch.LinkDirection) error { - switch dir { - case htlcswitch.Outgoing: - if m.isOutgoingAddBlocked.Swap(true) { - return fmt.Errorf("%v adds already disabled", dir) - } - case htlcswitch.Incoming: - if m.isIncomingAddBlocked.Swap(true) { - return fmt.Errorf("%v adds already disabled", dir) - } +func (m *mockUpdateHandler) DisableAdds(dir htlcswitch.LinkDirection) bool { + if dir == htlcswitch.Outgoing { + return !m.isOutgoingAddBlocked.Swap(true) } - return nil + return !m.isIncomingAddBlocked.Swap(true) } func (m *mockUpdateHandler) IsFlushing(dir htlcswitch.LinkDirection) bool {