From a517cfb73461edca94d689e0a0bdc52a5431e595 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Sat, 20 Apr 2024 05:48:22 +0800 Subject: [PATCH 01/19] chainntnfs+lnd: replace `BlockHeader` with `Block` So the block data can be used by subsystems without calling `GetBlock`. --- chainntnfs/best_block_view.go | 2 +- chainntnfs/best_block_view_test.go | 19 +++---- chainntnfs/bitcoindnotify/bitcoind.go | 73 ++++++++++----------------- chainntnfs/btcdnotify/btcd.go | 53 +++++++------------ chainntnfs/interface.go | 25 +++++---- chainntnfs/neutrinonotify/neutrino.go | 51 +++++++++---------- chainntnfs/test/test_interface.go | 19 +++---- config_builder.go | 3 +- 8 files changed, 102 insertions(+), 143 deletions(-) diff --git a/chainntnfs/best_block_view.go b/chainntnfs/best_block_view.go index c043e68e7e0..ba9295e0530 100644 --- a/chainntnfs/best_block_view.go +++ b/chainntnfs/best_block_view.go @@ -68,7 +68,7 @@ func (t *BestBlockTracker) BestBlockHeader() (*wire.BlockHeader, error) { return nil, errors.New("best block header not yet known") } - return epoch.BlockHeader, nil + return &epoch.Block.Header, nil } // updateLoop is a helper that subscribes to the underlying BlockEpochEvent diff --git a/chainntnfs/best_block_view_test.go b/chainntnfs/best_block_view_test.go index 2a55bc8b032..338a8630aad 100644 --- a/chainntnfs/best_block_view_test.go +++ b/chainntnfs/best_block_view_test.go @@ -26,14 +26,15 @@ func (blockEpoch) Generate(r *rand.Rand, size int) reflect.Value { return reflect.ValueOf(blockEpoch(chainntnfs.BlockEpoch{ Hash: &chainHash, Height: r.Int31n(1000000), - BlockHeader: &wire.BlockHeader{ - Version: 2, - PrevBlock: prevBlockHash, - MerkleRoot: merkleRootHash, - Timestamp: time.Now(), - Bits: r.Uint32(), - Nonce: r.Uint32(), - }, + Block: wire.MsgBlock{ + Header: wire.BlockHeader{ + Version: 2, + PrevBlock: prevBlockHash, + MerkleRoot: merkleRootHash, + Timestamp: time.Now(), + Bits: r.Uint32(), + Nonce: r.Uint32(), + }}, })) } @@ -77,7 +78,7 @@ func TestBestBlockTracker(t *testing.T) { header, _ := tracker.BestBlockHeader() return height == uint32(epoch.Height) && - header == epoch.BlockHeader + *header == epoch.Block.Header } idempotence := func(epochRand blockEpoch) bool { epoch := chainntnfs.BlockEpoch(epochRand) diff --git a/chainntnfs/bitcoindnotify/bitcoind.go b/chainntnfs/bitcoindnotify/bitcoind.go index 2bffefdbefd..8addd799367 100644 --- a/chainntnfs/bitcoindnotify/bitcoind.go +++ b/chainntnfs/bitcoindnotify/bitcoind.go @@ -178,7 +178,8 @@ func (b *BitcoindNotifier) startNotifier() error { if err != nil { return err } - blockHeader, err := b.chainConn.GetBlockHeader(currentHash) + + block, err := b.GetBlock(currentHash) if err != nil { return err } @@ -189,9 +190,9 @@ func (b *BitcoindNotifier) startNotifier() error { ) b.bestBlock = chainntnfs.BlockEpoch{ - Height: currentHeight, - Hash: currentHash, - BlockHeader: blockHeader, + Height: currentHeight, + Hash: currentHash, + Block: *block, } b.wg.Add(1) @@ -350,9 +351,7 @@ out: // a notification for the current tip. if msg.bestBlock == nil { b.notifyBlockEpochClient( - msg, b.bestBlock.Height, - b.bestBlock.Hash, - b.bestBlock.BlockHeader, + msg, b.bestBlock, ) msg.errorChan <- nil @@ -372,10 +371,7 @@ out: } for _, block := range missedBlocks { - b.notifyBlockEpochClient( - msg, block.Height, block.Hash, - block.BlockHeader, - ) + b.notifyBlockEpochClient(msg, block) } msg.errorChan <- nil @@ -384,15 +380,14 @@ out: case ntfn := <-b.chainConn.Notifications(): switch item := ntfn.(type) { case chain.BlockConnected: - blockHeader, err := - b.chainConn.GetBlockHeader(&item.Hash) + block, err := b.GetBlock(&item.Hash) if err != nil { chainntnfs.Log.Errorf("Unable to fetch "+ "block header: %v", err) continue } - if blockHeader.PrevBlock != *b.bestBlock.Hash { + if block.Header.PrevBlock != *b.bestBlock.Hash { // Handle the case where the notifier // missed some blocks from its chain // backend. @@ -424,9 +419,9 @@ out: } newBlock := chainntnfs.BlockEpoch{ - Height: item.Height, - Hash: &item.Hash, - BlockHeader: blockHeader, + Height: item.Height, + Hash: &item.Hash, + Block: *block, } if err := b.handleBlockConnected(newBlock); err != nil { chainntnfs.Log.Error(err) @@ -631,61 +626,49 @@ func (b *BitcoindNotifier) confDetailsManually(confRequest chainntnfs.ConfReques // handleBlockConnected applies a chain update for a new block. Any watched // transactions included this block will processed to either send notifications // now or after numConfirmations confs. -func (b *BitcoindNotifier) handleBlockConnected(block chainntnfs.BlockEpoch) error { - // First, we'll fetch the raw block as we'll need to gather all the - // transactions to determine whether any are relevant to our registered - // clients. - rawBlock, err := b.GetBlock(block.Hash) - if err != nil { - return fmt.Errorf("unable to get block: %w", err) - } - utilBlock := btcutil.NewBlock(rawBlock) +func (b *BitcoindNotifier) handleBlockConnected( + epoch chainntnfs.BlockEpoch) error { + + utilBlock := btcutil.NewBlock(&epoch.Block) // We'll then extend the txNotifier's height with the information of // this new block, which will handle all of the notification logic for // us. - err = b.txNotifier.ConnectTip(utilBlock, uint32(block.Height)) + err := b.txNotifier.ConnectTip(utilBlock, uint32(epoch.Height)) if err != nil { return fmt.Errorf("unable to connect tip: %w", err) } - chainntnfs.Log.Infof("New block: height=%v, sha=%v", block.Height, - block.Hash) + chainntnfs.Log.Infof("New block: height=%v, sha=%v", epoch.Height, + epoch.Hash) // Now that we've guaranteed the new block extends the txNotifier's // current tip, we'll proceed to dispatch notifications to all of our // registered clients whom have had notifications fulfilled. Before // doing so, we'll make sure update our in memory state in order to // satisfy any client requests based upon the new block. - b.bestBlock = block + b.bestBlock = epoch + + b.notifyBlockEpochs(epoch) - b.notifyBlockEpochs(block.Height, block.Hash, block.BlockHeader) - return b.txNotifier.NotifyHeight(uint32(block.Height)) + return b.txNotifier.NotifyHeight(uint32(epoch.Height)) } // notifyBlockEpochs notifies all registered block epoch clients of the newly // connected block to the main chain. -func (b *BitcoindNotifier) notifyBlockEpochs(newHeight int32, newSha *chainhash.Hash, - blockHeader *wire.BlockHeader) { - +func (b *BitcoindNotifier) notifyBlockEpochs(block chainntnfs.BlockEpoch) { for _, client := range b.blockEpochClients { - b.notifyBlockEpochClient(client, newHeight, newSha, blockHeader) + b.notifyBlockEpochClient(client, block) } } // notifyBlockEpochClient sends a registered block epoch client a notification // about a specific block. -func (b *BitcoindNotifier) notifyBlockEpochClient(epochClient *blockEpochRegistration, - height int32, sha *chainhash.Hash, header *wire.BlockHeader) { - - epoch := &chainntnfs.BlockEpoch{ - Height: height, - Hash: sha, - BlockHeader: header, - } +func (b *BitcoindNotifier) notifyBlockEpochClient( + epochClient *blockEpochRegistration, epoch chainntnfs.BlockEpoch) { select { - case epochClient.epochQueue.ChanIn() <- epoch: + case epochClient.epochQueue.ChanIn() <- &epoch: case <-epochClient.cancelChan: case <-b.quit: } diff --git a/chainntnfs/btcdnotify/btcd.go b/chainntnfs/btcdnotify/btcd.go index e865426e9af..228169018c1 100644 --- a/chainntnfs/btcdnotify/btcd.go +++ b/chainntnfs/btcdnotify/btcd.go @@ -231,7 +231,7 @@ func (b *BtcdNotifier) startNotifier() error { return err } - bestBlock, err := b.chainConn.GetBlock(currentHash) + bestBlock, err := b.GetBlock(currentHash) if err != nil { b.txUpdates.Stop() b.chainUpdates.Stop() @@ -244,9 +244,9 @@ func (b *BtcdNotifier) startNotifier() error { ) b.bestBlock = chainntnfs.BlockEpoch{ - Height: currentHeight, - Hash: currentHash, - BlockHeader: &bestBlock.Header, + Height: currentHeight, + Hash: currentHash, + Block: *bestBlock, } if err := b.chainConn.NotifyBlocks(); err != nil { @@ -409,9 +409,7 @@ out: // a notification for the current tip. if msg.bestBlock == nil { b.notifyBlockEpochClient( - msg, b.bestBlock.Height, - b.bestBlock.Hash, - b.bestBlock.BlockHeader, + msg, b.bestBlock, ) msg.errorChan <- nil @@ -431,10 +429,7 @@ out: } for _, block := range missedBlocks { - b.notifyBlockEpochClient( - msg, block.Height, block.Hash, - block.BlockHeader, - ) + b.notifyBlockEpochClient(msg, block) } msg.errorChan <- nil @@ -443,16 +438,14 @@ out: case item := <-b.chainUpdates.ChanOut(): update := item.(*chainUpdate) if update.connect { - blockHeader, err := b.chainConn.GetBlockHeader( - update.blockHash, - ) + block, err := b.GetBlock(update.blockHash) if err != nil { chainntnfs.Log.Errorf("Unable to fetch "+ "block header: %v", err) continue } - if blockHeader.PrevBlock != *b.bestBlock.Hash { + if block.Header.PrevBlock != *b.bestBlock.Hash { // Handle the case where the notifier // missed some blocks from its chain // backend @@ -484,9 +477,9 @@ out: } newBlock := chainntnfs.BlockEpoch{ - Height: update.blockHeight, - Hash: update.blockHash, - BlockHeader: blockHeader, + Height: update.blockHeight, + Hash: update.blockHash, + Block: *block, } if err := b.handleBlockConnected(newBlock); err != nil { chainntnfs.Log.Error(err) @@ -730,38 +723,26 @@ func (b *BtcdNotifier) handleBlockConnected(epoch chainntnfs.BlockEpoch) error { // satisfy any client requests based upon the new block. b.bestBlock = epoch - b.notifyBlockEpochs( - epoch.Height, epoch.Hash, epoch.BlockHeader, - ) + b.notifyBlockEpochs(epoch) return b.txNotifier.NotifyHeight(uint32(epoch.Height)) } // notifyBlockEpochs notifies all registered block epoch clients of the newly // connected block to the main chain. -func (b *BtcdNotifier) notifyBlockEpochs(newHeight int32, - newSha *chainhash.Hash, blockHeader *wire.BlockHeader) { - +func (b *BtcdNotifier) notifyBlockEpochs(epoch chainntnfs.BlockEpoch) { for _, client := range b.blockEpochClients { - b.notifyBlockEpochClient( - client, newHeight, newSha, blockHeader, - ) + b.notifyBlockEpochClient(client, epoch) } } // notifyBlockEpochClient sends a registered block epoch client a notification // about a specific block. -func (b *BtcdNotifier) notifyBlockEpochClient(epochClient *blockEpochRegistration, - height int32, sha *chainhash.Hash, blockHeader *wire.BlockHeader) { - - epoch := &chainntnfs.BlockEpoch{ - Height: height, - Hash: sha, - BlockHeader: blockHeader, - } +func (b *BtcdNotifier) notifyBlockEpochClient( + epochClient *blockEpochRegistration, epoch chainntnfs.BlockEpoch) { select { - case epochClient.epochQueue.ChanIn() <- epoch: + case epochClient.epochQueue.ChanIn() <- &epoch: case <-epochClient.cancelChan: case <-b.quit: } diff --git a/chainntnfs/interface.go b/chainntnfs/interface.go index 3337f1451a6..282ac2fcbdc 100644 --- a/chainntnfs/interface.go +++ b/chainntnfs/interface.go @@ -366,8 +366,8 @@ type BlockEpoch struct { // the main chain. Height int32 - // BlockHeader is the block header of this new height. - BlockHeader *wire.BlockHeader + // Block is the full block. + Block wire.MsgBlock } // BlockEpochEvent encapsulates an on-going stream of block epoch @@ -471,6 +471,9 @@ type ChainConn interface { // GetBlockHash returns the hash from a block height. GetBlockHash(blockHeight int64) (*chainhash.Hash, error) + + // GetBlock returns a block from the hash. + GetBlock(hash *chainhash.Hash) (*wire.MsgBlock, error) } // GetCommonBlockAncestorHeight takes in: @@ -555,9 +558,9 @@ func RewindChain(chainConn ChainConn, txNotifier *TxNotifier, currBestBlock BlockEpoch, targetHeight int32) (BlockEpoch, error) { newBestBlock := BlockEpoch{ - Height: currBestBlock.Height, - Hash: currBestBlock.Hash, - BlockHeader: currBestBlock.BlockHeader, + Height: currBestBlock.Height, + Hash: currBestBlock.Hash, + Block: currBestBlock.Block, } for height := currBestBlock.Height; height > targetHeight; height-- { @@ -567,7 +570,7 @@ func RewindChain(chainConn ChainConn, txNotifier *TxNotifier, "find blockhash for disconnected height=%d: %v", height, err) } - header, err := chainConn.GetBlockHeader(hash) + block, err := chainConn.GetBlock(hash) if err != nil { return newBestBlock, fmt.Errorf("unable to get block "+ "header for height=%v", height-1) @@ -584,7 +587,7 @@ func RewindChain(chainConn ChainConn, txNotifier *TxNotifier, } newBestBlock.Height = height - 1 newBestBlock.Hash = hash - newBestBlock.BlockHeader = header + newBestBlock.Block = *block } return newBestBlock, nil @@ -665,7 +668,7 @@ func getMissedBlocks(chainConn ChainConn, startingHeight, return nil, fmt.Errorf("unable to find blockhash for "+ "height=%d: %v", height, err) } - header, err := chainConn.GetBlockHeader(hash) + block, err := chainConn.GetBlock(hash) if err != nil { return nil, fmt.Errorf("unable to find block header "+ "for height=%d: %v", height, err) @@ -674,9 +677,9 @@ func getMissedBlocks(chainConn ChainConn, startingHeight, missedBlocks = append( missedBlocks, BlockEpoch{ - Hash: hash, - Height: height, - BlockHeader: header, + Hash: hash, + Height: height, + Block: *block, }, ) } diff --git a/chainntnfs/neutrinonotify/neutrino.go b/chainntnfs/neutrinonotify/neutrino.go index 80e210c2fd8..500c9481ca5 100644 --- a/chainntnfs/neutrinonotify/neutrino.go +++ b/chainntnfs/neutrinonotify/neutrino.go @@ -177,9 +177,7 @@ func (n *NeutrinoNotifier) startNotifier() error { n.txUpdates.Stop() return err } - startingHeader, err := n.p2pNode.GetBlockHeader( - &startingPoint.Hash, - ) + startingBlock, err := n.p2pNode.GetBlock(startingPoint.Hash) if err != nil { n.txUpdates.Stop() return err @@ -187,7 +185,7 @@ func (n *NeutrinoNotifier) startNotifier() error { n.bestBlock.Hash = &startingPoint.Hash n.bestBlock.Height = startingPoint.Height - n.bestBlock.BlockHeader = startingHeader + n.bestBlock.Block = *startingBlock.MsgBlock() n.txNotifier = chainntnfs.NewTxNotifier( uint32(n.bestBlock.Height), chainntnfs.ReorgSafetyLimit, @@ -472,9 +470,7 @@ func (n *NeutrinoNotifier) notificationDispatcher() { // a notification for the current tip. if msg.bestBlock == nil { n.notifyBlockEpochClient( - msg, n.bestBlock.Height, - n.bestBlock.Hash, - n.bestBlock.BlockHeader, + msg, n.bestBlock, ) msg.errorChan <- nil @@ -498,10 +494,7 @@ func (n *NeutrinoNotifier) notificationDispatcher() { } for _, block := range missedBlocks { - n.notifyBlockEpochClient( - msg, block.Height, block.Hash, - block.BlockHeader, - ) + n.notifyBlockEpochClient(msg, block) } msg.errorChan <- nil @@ -682,11 +675,9 @@ func (n *NeutrinoNotifier) handleBlockConnected(newBlock *filteredBlock) error { // satisfy any client requests based upon the new block. n.bestBlock.Hash = &newBlock.hash n.bestBlock.Height = int32(newBlock.height) - n.bestBlock.BlockHeader = newBlock.header + n.bestBlock.Block = *rawBlock.MsgBlock() - n.notifyBlockEpochs( - int32(newBlock.height), &newBlock.hash, newBlock.header, - ) + n.notifyBlockEpochs(n.bestBlock) return n.txNotifier.NotifyHeight(newBlock.height) } @@ -711,27 +702,19 @@ func (n *NeutrinoNotifier) getFilteredBlock(epoch chainntnfs.BlockEpoch) (*filte // notifyBlockEpochs notifies all registered block epoch clients of the newly // connected block to the main chain. -func (n *NeutrinoNotifier) notifyBlockEpochs(newHeight int32, newSha *chainhash.Hash, - blockHeader *wire.BlockHeader) { - +func (n *NeutrinoNotifier) notifyBlockEpochs(epoch chainntnfs.BlockEpoch) { for _, client := range n.blockEpochClients { - n.notifyBlockEpochClient(client, newHeight, newSha, blockHeader) + n.notifyBlockEpochClient(client, epoch) } } // notifyBlockEpochClient sends a registered block epoch client a notification // about a specific block. -func (n *NeutrinoNotifier) notifyBlockEpochClient(epochClient *blockEpochRegistration, - height int32, sha *chainhash.Hash, blockHeader *wire.BlockHeader) { - - epoch := &chainntnfs.BlockEpoch{ - Height: height, - Hash: sha, - BlockHeader: blockHeader, - } +func (n *NeutrinoNotifier) notifyBlockEpochClient( + epochClient *blockEpochRegistration, epoch chainntnfs.BlockEpoch) { select { - case epochClient.epochQueue.ChanIn() <- epoch: + case epochClient.epochQueue.ChanIn() <- &epoch: case <-epochClient.cancelChan: case <-n.quit: } @@ -1124,6 +1107,18 @@ type NeutrinoChainConn struct { p2pNode *neutrino.ChainService } +// GetBlock returns the block for a hash. +func (n *NeutrinoChainConn) GetBlock( + blockHash *chainhash.Hash) (*wire.MsgBlock, error) { + + utilBlock, err := n.p2pNode.GetBlock(*blockHash) + if err != nil { + return nil, err + } + + return utilBlock.MsgBlock(), nil +} + // GetBlockHeader returns the block header for a hash. func (n *NeutrinoChainConn) GetBlockHeader(blockHash *chainhash.Hash) (*wire.BlockHeader, error) { return n.p2pNode.GetBlockHeader(blockHash) diff --git a/chainntnfs/test/test_interface.go b/chainntnfs/test/test_interface.go index 35e63a45e98..693314b75cb 100644 --- a/chainntnfs/test/test_interface.go +++ b/chainntnfs/test/test_interface.go @@ -410,17 +410,14 @@ func testBlockEpochNotification(miner *rpctest.Harness, // and that header matches the contained header // hash. blockEpoch := <-epochClient.Epochs - if blockEpoch.BlockHeader == nil { - t.Logf("%d", i) - clientErrors <- fmt.Errorf("block " + - "header is nil") - return - } - if blockEpoch.BlockHeader.BlockHash() != - *blockEpoch.Hash { - - clientErrors <- fmt.Errorf("block " + - "header hash mismatch") + t.Logf("%d", i) + + got := blockEpoch.Block.Header.BlockHash() + want := *blockEpoch.Hash + if got != want { + clientErrors <- fmt.Errorf("block "+ + "header hash mismatch: "+ + "want=%v, got=%v", want, got) return } diff --git a/config_builder.go b/config_builder.go index 49a0b6c507b..66bda50f9e9 100644 --- a/config_builder.go +++ b/config_builder.go @@ -629,8 +629,7 @@ func proxyBlockEpoch(notifier chainntnfs.ChainNotifier, go func() { for blk := range blockEpoch.Epochs { ntfn := blockntfns.NewBlockConnected( - *blk.BlockHeader, - uint32(blk.Height), + blk.Block.Header, uint32(blk.Height), ) sub.Notifications <- ntfn From 5eb0e6dafc44ebbf2ba72482a0117213a35b9a89 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Tue, 30 Apr 2024 17:13:57 +0800 Subject: [PATCH 02/19] sweep: add new state `TxError` for erroneous sweepings Also updated the loggings. This new state will be used in the following commit. --- sweep/fee_bumper.go | 24 +++++++++++++++++++++--- sweep/sweeper.go | 5 ++++- 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/sweep/fee_bumper.go b/sweep/fee_bumper.go index b4d4298943b..fac2248000a 100644 --- a/sweep/fee_bumper.go +++ b/sweep/fee_bumper.go @@ -71,6 +71,9 @@ const ( // TxConfirmed is sent when the tx is confirmed. TxConfirmed + // TxError is sent when there's an error creating the tx. + TxError + // sentinalEvent is used to check if an event is unknown. sentinalEvent ) @@ -86,6 +89,8 @@ func (e BumpEvent) String() string { return "Replaced" case TxConfirmed: return "Confirmed" + case TxError: + return "Error" default: return "Unknown" } @@ -204,6 +209,16 @@ type BumpResult struct { requestID uint64 } +// String returns a human-readable string for the result. +func (b *BumpResult) String() string { + desc := fmt.Sprintf("Event=%v", b.Event) + if b.Tx != nil { + desc += fmt.Sprintf(", Tx=%v", b.Tx.TxHash()) + } + + return fmt.Sprintf("[%s]", desc) +} + // Validate validates the BumpResult so it's safe to use. func (b *BumpResult) Validate() error { // Every result must have a tx. @@ -592,8 +607,7 @@ func (t *TxPublisher) notifyResult(result *BumpResult) { return } - log.Debugf("Sending result for requestID=%v, tx=%v", id, - result.Tx.TxHash()) + log.Debugf("Sending result %v for requestID=%v", result, id) select { // Send the result to the subscriber. @@ -622,10 +636,14 @@ func (t *TxPublisher) removeResult(result *BumpResult) { id, result.Tx.TxHash(), result.Err) case TxConfirmed: - // Remove the record is the tx is confirmed. + // Remove the record if the tx is confirmed. log.Debugf("Removing confirmed monitor record=%v, tx=%v", id, result.Tx.TxHash()) + case TxError: + // Remove the record if there's an error. + log.Debugf("Removing monitor record=%v due to error", id) + // Do nothing if it's neither failed or confirmed. default: log.Tracef("Skipping record removal for id=%v, event=%v", id, diff --git a/sweep/sweeper.go b/sweep/sweeper.go index 39a03228d07..2ea4183237e 100644 --- a/sweep/sweeper.go +++ b/sweep/sweeper.go @@ -1718,7 +1718,7 @@ func (s *UtxoSweeper) handleBumpEventTxPublished(r *BumpResult) error { // NOTE: TxConfirmed event is not handled, since we already subscribe to the // input's spending event, we don't need to do anything here. func (s *UtxoSweeper) handleBumpEvent(r *BumpResult) error { - log.Debugf("Received bump event [%v] for tx %v", r.Event, r.Tx.TxHash()) + log.Debugf("Received bump result %v", r) switch r.Event { // The tx has been published, we update the inputs' state and create a @@ -1734,6 +1734,9 @@ func (s *UtxoSweeper) handleBumpEvent(r *BumpResult) error { // with the new one. case TxReplaced: return s.handleBumpEventTxReplaced(r) + + case TxError: + // TODO(yy): create a method to remove this input. } return nil From 8ba581046b1aa6e97e60216a548eb7b7f8a446af Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Tue, 30 Apr 2024 17:15:35 +0800 Subject: [PATCH 03/19] sweep: add `handleInitialBroadcast` to handle initial broadcast This commit adds a new method `handleInitialBroadcast` to handle the initial broadcast. Previously we'd broadcast immediately inside `Broadcast`, which soon will not work giving the blockbeat being used in the following commit. --- sweep/fee_bumper.go | 158 ++++++++++++----- sweep/fee_bumper_test.go | 365 ++++++++++++++++++++++++--------------- sweep/fee_function.go | 5 +- 3 files changed, 342 insertions(+), 186 deletions(-) diff --git a/sweep/fee_bumper.go b/sweep/fee_bumper.go index fac2248000a..7326af360d8 100644 --- a/sweep/fee_bumper.go +++ b/sweep/fee_bumper.go @@ -319,12 +319,10 @@ func (t *TxPublisher) isNeutrinoBackend() bool { return t.cfg.Wallet.BackEnd() == "neutrino" } -// Broadcast is used to publish the tx created from the given inputs. It will, -// 1. init a fee function based on the given strategy. -// 2. create an RBF-compliant tx and monitor it for confirmation. -// 3. notify the initial broadcast result back to the caller. -// The initial broadcast is guaranteed to be RBF-compliant unless the budget -// specified cannot cover the fee. +// Broadcast is used to publish the tx created from the given inputs. It will +// register the broadcast request and return a chan to the caller to subscribe +// the broadcast result. The initial broadcast is guaranteed to be +// RBF-compliant unless the budget specified cannot cover the fee. // // NOTE: part of the Bumper interface. func (t *TxPublisher) Broadcast(req *BumpRequest) (<-chan *BumpResult, error) { @@ -333,28 +331,26 @@ func (t *TxPublisher) Broadcast(req *BumpRequest) (<-chan *BumpResult, error) { return spew.Sdump(req) })()) - // Attempt an initial broadcast which is guaranteed to comply with the - // RBF rules. - result, err := t.initialBroadcast(req) - if err != nil { - log.Errorf("Initial broadcast failed: %v", err) + // Increase the request counter. + // + // NOTE: this is the only place where we increase the counter. + requestID := t.requestCounter.Add(1) - return nil, err - } + // Register the record. + t.records.Store(requestID, &monitorRecord{req: req}) // Create a chan to send the result to the caller. subscriber := make(chan *BumpResult, 1) - t.subscriberChans.Store(result.requestID, subscriber) - - // Send the initial broadcast result to the caller. - t.handleResult(result) + t.subscriberChans.Store(requestID, subscriber) return subscriber, nil } // initialBroadcast initializes a fee function, creates an RBF-compliant tx and // broadcasts it. -func (t *TxPublisher) initialBroadcast(req *BumpRequest) (*BumpResult, error) { +func (t *TxPublisher) initialBroadcast(requestID uint64, + req *BumpRequest) (*BumpResult, error) { + // Create a fee bumping algorithm to be used for future RBF. feeAlgo, err := t.initializeFeeFunction(req) if err != nil { @@ -363,7 +359,7 @@ func (t *TxPublisher) initialBroadcast(req *BumpRequest) (*BumpResult, error) { // Create the initial tx to be broadcasted. This tx is guaranteed to // comply with the RBF restrictions. - requestID, err := t.createRBFCompliantTx(req, feeAlgo) + err = t.createRBFCompliantTx(requestID, req, feeAlgo) if err != nil { return nil, fmt.Errorf("create RBF-compliant tx: %w", err) } @@ -410,8 +406,8 @@ func (t *TxPublisher) initializeFeeFunction( // so by creating a tx, validate it using `TestMempoolAccept`, and bump its fee // and redo the process until the tx is valid, or return an error when non-RBF // related errors occur or the budget has been used up. -func (t *TxPublisher) createRBFCompliantTx(req *BumpRequest, - f FeeFunction) (uint64, error) { +func (t *TxPublisher) createRBFCompliantTx(requestID uint64, req *BumpRequest, + f FeeFunction) error { for { // Create a new tx with the given fee rate and check its @@ -420,15 +416,15 @@ func (t *TxPublisher) createRBFCompliantTx(req *BumpRequest, switch { case err == nil: - // The tx is valid, return the request ID. - requestID := t.storeRecord(tx, req, f, fee) + // The tx is valid, store it. + t.storeRecord(requestID, tx, req, f, fee) - log.Infof("Created tx %v for %v inputs: feerate=%v, "+ - "fee=%v, inputs=%v", tx.TxHash(), + log.Infof("Created initial sweep tx=%v for %v inputs: "+ + "feerate=%v, fee=%v, inputs:\n%v", tx.TxHash(), len(req.Inputs), f.FeeRate(), fee, inputTypeSummary(req.Inputs)) - return requestID, nil + return nil // If the error indicates the fees paid is not enough, we will // ask the fee function to increase the fee rate and retry. @@ -459,7 +455,7 @@ func (t *TxPublisher) createRBFCompliantTx(req *BumpRequest, // cluster these inputs differetly. increased, err = f.Increment() if err != nil { - return 0, err + return err } } @@ -469,20 +465,14 @@ func (t *TxPublisher) createRBFCompliantTx(req *BumpRequest, // mempool acceptance. default: log.Debugf("Failed to create RBF-compliant tx: %v", err) - return 0, err + return err } } } // storeRecord stores the given record in the records map. -func (t *TxPublisher) storeRecord(tx *wire.MsgTx, req *BumpRequest, - f FeeFunction, fee btcutil.Amount) uint64 { - - // Increase the request counter. - // - // NOTE: this is the only place where we increase the - // counter. - requestID := t.requestCounter.Add(1) +func (t *TxPublisher) storeRecord(requestID uint64, tx *wire.MsgTx, + req *BumpRequest, f FeeFunction, fee btcutil.Amount) { // Register the record. t.records.Store(requestID, &monitorRecord{ @@ -491,8 +481,6 @@ func (t *TxPublisher) storeRecord(tx *wire.MsgTx, req *BumpRequest, feeFunction: f, fee: fee, }) - - return requestID } // createAndCheckTx creates a tx based on the given inputs, change output @@ -756,18 +744,27 @@ func (t *TxPublisher) processRecords() { // confirmed. confirmedRecords := make(map[uint64]*monitorRecord) - // feeBumpRecords stores a map of the records which need to be bumped. + // feeBumpRecords stores a map of records which need to be bumped. feeBumpRecords := make(map[uint64]*monitorRecord) - // failedRecords stores a map of the records which has inputs being - // spent by a third party. + // failedRecords stores a map of records which has inputs being spent + // by a third party. // // NOTE: this is only used for neutrino backend. failedRecords := make(map[uint64]*monitorRecord) + // initialRecords stores a map of records which are being created and + // published for the first time. + initialRecords := make(map[uint64]*monitorRecord) + // visitor is a helper closure that visits each record and divides them // into two groups. visitor := func(requestID uint64, r *monitorRecord) error { + if r.tx == nil { + initialRecords[requestID] = r + return nil + } + log.Tracef("Checking monitor recordID=%v for tx=%v", requestID, r.tx.TxHash()) @@ -795,9 +792,18 @@ func (t *TxPublisher) processRecords() { return nil } - // Iterate through all the records and divide them into two groups. + // Iterate through all the records and divide them into four groups. t.records.ForEach(visitor) + // Handle the initial broadcast. + for requestID, r := range initialRecords { + rec := r + + log.Debugf("Initial broadcast for requestID=%v", requestID) + t.wg.Add(1) + go t.handleInitialBroadcast(rec, requestID) + } + // For records that are confirmed, we'll notify the caller about this // result. for requestID, r := range confirmedRecords { @@ -853,6 +859,74 @@ func (t *TxPublisher) handleTxConfirmed(r *monitorRecord, requestID uint64) { t.handleResult(result) } +// handleInitialBroadcast is called when a new request is received. It will +// handle the initial tx creation and broadcast. In details, +// 1. init a fee function based on the given strategy. +// 2. create an RBF-compliant tx and monitor it for confirmation. +// 3. notify the initial broadcast result back to the caller. +// +// NOTE: Must be run as a goroutine to avoid blocking on sending the result. +func (t *TxPublisher) handleInitialBroadcast(r *monitorRecord, + requestID uint64) { + + defer t.wg.Done() + + var ( + result *BumpResult + err error + ) + + // Attempt an initial broadcast which is guaranteed to comply with the + // RBF rules. + result, err = t.initialBroadcast(requestID, r.req) + if err != nil { + log.Errorf("Initial broadcast failed: %v", err) + + // Create a tx so the caller knowns which inputs are failed. + sweepTx := wire.NewMsgTx(2) + for _, o := range r.req.Inputs { + sweepTx.AddTxIn(&wire.TxIn{ + PreviousOutPoint: o.OutPoint(), + Sequence: o.BlocksToMaturity(), + }) + } + + // We now decide what type of event to send. + var event BumpEvent + + switch { + // When the error is due to a dust output, we'll send a + // TxFailed so these inputs can be retried with a different + // group in the next block. + case errors.Is(err, ErrTxNoOutput): + event = TxFailed + + // When the error is due to budget being used up, we'll send a + // TxFailed so these inputs can be retried with a different + // group in the next block. + case errors.Is(err, ErrMaxPosition): + event = TxFailed + + // When the error is due to zero fee rate delta, we'll send a + // TxFailed so these inputs can be retried in the next block. + case errors.Is(err, ErrZeroFeeRateDelta): + event = TxFailed + + default: + event = TxError + } + + result = &BumpResult{ + Event: event, + Err: err, + requestID: requestID, + Tx: sweepTx, + } + } + + t.handleResult(result) +} + // handleFeeBumpTx checks if the tx needs to be bumped, and if so, it will // attempt to bump the fee of the tx. // diff --git a/sweep/fee_bumper_test.go b/sweep/fee_bumper_test.go index 63a828654d4..0ab89a2db4f 100644 --- a/sweep/fee_bumper_test.go +++ b/sweep/fee_bumper_test.go @@ -310,13 +310,10 @@ func TestStoreRecord(t *testing.T) { initialCounter := tp.requestCounter.Load() // Call the method under test. - requestID := tp.storeRecord(tx, req, feeFunc, fee) - - // Check the request ID is as expected. - require.Equal(t, initialCounter+1, requestID) + tp.storeRecord(initialCounter, tx, req, feeFunc, fee) // Read the saved record and compare. - record, ok := tp.records.Load(requestID) + record, ok := tp.records.Load(initialCounter) require.True(t, ok) require.Equal(t, tx, record.tx) require.Equal(t, feeFunc, record.feeFunction) @@ -611,23 +608,20 @@ func TestCreateRBFCompliantTx(t *testing.T) { }, } + requestCounter := atomic.Uint64{} + for _, tc := range testCases { tc := tc + rid := requestCounter.Add(1) t.Run(tc.name, func(t *testing.T) { tc.setupMock() // Call the method under test. - id, err := tp.createRBFCompliantTx(req, m.feeFunc) + err := tp.createRBFCompliantTx(rid, req, m.feeFunc) // Check the result is as expected. require.ErrorIs(t, err, tc.expectedErr) - - // If there's an error, expect the requestID to be - // empty. - if tc.expectedErr != nil { - require.Zero(t, id) - } }) } } @@ -652,7 +646,8 @@ func TestTxPublisherBroadcast(t *testing.T) { // Create a testing record and put it in the map. fee := btcutil.Amount(1000) - requestID := tp.storeRecord(tx, req, m.feeFunc, fee) + requestID := uint64(1) + tp.storeRecord(requestID, tx, req, m.feeFunc, fee) // Quickly check when the requestID cannot be found, an error is // returned. @@ -739,6 +734,9 @@ func TestRemoveResult(t *testing.T) { // Create a testing record and put it in the map. fee := btcutil.Amount(1000) + // Create a test request ID counter. + requestCounter := atomic.Uint64{} + testCases := []struct { name string setupRecord func() uint64 @@ -750,10 +748,11 @@ func TestRemoveResult(t *testing.T) { // removed. name: "remove on TxConfirmed", setupRecord: func() uint64 { - id := tp.storeRecord(tx, req, m.feeFunc, fee) - tp.subscriberChans.Store(id, nil) + rid := requestCounter.Add(1) + tp.storeRecord(rid, tx, req, m.feeFunc, fee) + tp.subscriberChans.Store(rid, nil) - return id + return rid }, result: &BumpResult{ Event: TxConfirmed, @@ -765,10 +764,11 @@ func TestRemoveResult(t *testing.T) { // When the tx is failed, the records will be removed. name: "remove on TxFailed", setupRecord: func() uint64 { - id := tp.storeRecord(tx, req, m.feeFunc, fee) - tp.subscriberChans.Store(id, nil) + rid := requestCounter.Add(1) + tp.storeRecord(rid, tx, req, m.feeFunc, fee) + tp.subscriberChans.Store(rid, nil) - return id + return rid }, result: &BumpResult{ Event: TxFailed, @@ -781,10 +781,11 @@ func TestRemoveResult(t *testing.T) { // Noop when the tx is neither confirmed or failed. name: "noop when tx is not confirmed or failed", setupRecord: func() uint64 { - id := tp.storeRecord(tx, req, m.feeFunc, fee) - tp.subscriberChans.Store(id, nil) + rid := requestCounter.Add(1) + tp.storeRecord(rid, tx, req, m.feeFunc, fee) + tp.subscriberChans.Store(rid, nil) - return id + return rid }, result: &BumpResult{ Event: TxPublished, @@ -831,7 +832,8 @@ func TestNotifyResult(t *testing.T) { // Create a testing record and put it in the map. fee := btcutil.Amount(1000) - requestID := tp.storeRecord(tx, req, m.feeFunc, fee) + requestID := uint64(1) + tp.storeRecord(requestID, tx, req, m.feeFunc, fee) // Create a subscription to the event. subscriber := make(chan *BumpResult, 1) @@ -879,41 +881,17 @@ func TestNotifyResult(t *testing.T) { } } -// TestBroadcastSuccess checks the public `Broadcast` method can successfully -// broadcast a tx based on the request. -func TestBroadcastSuccess(t *testing.T) { +// TestBroadcast checks the public `Broadcast` method can successfully register +// a broadcast request. +func TestBroadcast(t *testing.T) { t.Parallel() // Create a publisher using the mocks. - tp, m := createTestPublisher(t) + tp, _ := createTestPublisher(t) // Create a test feerate. feerate := chainfee.SatPerKWeight(1000) - // Mock the fee estimator to return the testing fee rate. - // - // We are not testing `NewLinearFeeFunction` here, so the actual params - // used are irrelevant. - m.estimator.On("EstimateFeePerKW", mock.Anything).Return( - feerate, nil).Once() - m.estimator.On("RelayFeePerKW").Return(chainfee.FeePerKwFloor).Once() - - // Mock the signer to always return a valid script. - // - // NOTE: we are not testing the utility of creating valid txes here, so - // this is fine to be mocked. This behaves essentially as skipping the - // Signer check and alaways assume the tx has a valid sig. - script := &input.Script{} - m.signer.On("ComputeInputScript", mock.Anything, - mock.Anything).Return(script, nil) - - // Mock the testmempoolaccept to pass. - m.wallet.On("CheckMempoolAcceptance", mock.Anything).Return(nil).Once() - - // Mock the wallet to publish successfully. - m.wallet.On("PublishTransaction", - mock.Anything, mock.Anything).Return(nil).Once() - // Create a test request. inp := createTestInput(1000, input.WitnessKeyHash) @@ -929,101 +907,17 @@ func TestBroadcastSuccess(t *testing.T) { // Send the req and expect no error. resultChan, err := tp.Broadcast(req) require.NoError(t, err) - - // Check the result is sent back. - select { - case <-time.After(time.Second): - t.Fatal("timeout waiting for subscriber to receive result") - - case result := <-resultChan: - // We expect the first result to be TxPublished. - require.Equal(t, TxPublished, result.Event) - } + require.NotNil(t, resultChan) // Validate the record was stored. require.Equal(t, 1, tp.records.Len()) require.Equal(t, 1, tp.subscriberChans.Len()) -} - -// TestBroadcastFail checks the public `Broadcast` returns the error or a -// failed result when the broadcast fails. -func TestBroadcastFail(t *testing.T) { - t.Parallel() - - // Create a publisher using the mocks. - tp, m := createTestPublisher(t) - - // Create a test feerate. - feerate := chainfee.SatPerKWeight(1000) - - // Create a test request. - inp := createTestInput(1000, input.WitnessKeyHash) - - // Create a testing bump request. - req := &BumpRequest{ - DeliveryAddress: changePkScript, - Inputs: []input.Input{&inp}, - Budget: btcutil.Amount(1000), - MaxFeeRate: feerate * 10, - DeadlineHeight: 10, - } - - // Mock the fee estimator to return the testing fee rate. - // - // We are not testing `NewLinearFeeFunction` here, so the actual params - // used are irrelevant. - m.estimator.On("EstimateFeePerKW", mock.Anything).Return( - feerate, nil).Twice() - m.estimator.On("RelayFeePerKW").Return(chainfee.FeePerKwFloor).Twice() - - // Mock the signer to always return a valid script. - // - // NOTE: we are not testing the utility of creating valid txes here, so - // this is fine to be mocked. This behaves essentially as skipping the - // Signer check and alaways assume the tx has a valid sig. - script := &input.Script{} - m.signer.On("ComputeInputScript", mock.Anything, - mock.Anything).Return(script, nil) - - // Mock the testmempoolaccept to return an error. - m.wallet.On("CheckMempoolAcceptance", - mock.Anything).Return(errDummy).Once() - - // Send the req and expect an error returned. - resultChan, err := tp.Broadcast(req) - require.ErrorIs(t, err, errDummy) - require.Nil(t, resultChan) - - // Validate the record was NOT stored. - require.Equal(t, 0, tp.records.Len()) - require.Equal(t, 0, tp.subscriberChans.Len()) - - // Mock the testmempoolaccept again, this time it passes. - m.wallet.On("CheckMempoolAcceptance", mock.Anything).Return(nil).Once() - - // Mock the wallet to fail on publish. - m.wallet.On("PublishTransaction", - mock.Anything, mock.Anything).Return(errDummy).Once() - - // Send the req and expect no error returned. - resultChan, err = tp.Broadcast(req) - require.NoError(t, err) - - // Check the result is sent back. - select { - case <-time.After(time.Second): - t.Fatal("timeout waiting for subscriber to receive result") - case result := <-resultChan: - // We expect the result to be TxFailed and the error is set in - // the result. - require.Equal(t, TxFailed, result.Event) - require.ErrorIs(t, result.Err, errDummy) - } - - // Validate the record was removed. - require.Equal(t, 0, tp.records.Len()) - require.Equal(t, 0, tp.subscriberChans.Len()) + // Validate the record. + rid := tp.requestCounter.Load() + record, found := tp.records.Load(rid) + require.True(t, found) + require.Equal(t, req, record.req) } // TestCreateAnPublishFail checks all the error cases are handled properly in @@ -1188,7 +1082,8 @@ func TestHandleTxConfirmed(t *testing.T) { // Create a testing record and put it in the map. fee := btcutil.Amount(1000) - requestID := tp.storeRecord(tx, req, m.feeFunc, fee) + requestID := uint64(1) + tp.storeRecord(requestID, tx, req, m.feeFunc, fee) record, ok := tp.records.Load(requestID) require.True(t, ok) @@ -1260,7 +1155,8 @@ func TestHandleFeeBumpTx(t *testing.T) { // Create a testing record and put it in the map. fee := btcutil.Amount(1000) - requestID := tp.storeRecord(tx, req, m.feeFunc, fee) + requestID := uint64(1) + tp.storeRecord(requestID, tx, req, m.feeFunc, fee) // Create a subscription to the event. subscriber := make(chan *BumpResult, 1) @@ -1461,3 +1357,186 @@ func TestProcessRecords(t *testing.T) { require.Equal(t, requestID2, result.requestID) } } + +// TestHandleInitialBroadcastSuccess checks `handleInitialBroadcast` method can +// successfully broadcast a tx based on the request. +func TestHandleInitialBroadcastSuccess(t *testing.T) { + t.Parallel() + + // Create a publisher using the mocks. + tp, m := createTestPublisher(t) + + // Create a test feerate. + feerate := chainfee.SatPerKWeight(1000) + + // Mock the fee estimator to return the testing fee rate. + // + // We are not testing `NewLinearFeeFunction` here, so the actual params + // used are irrelevant. + m.estimator.On("EstimateFeePerKW", mock.Anything).Return( + feerate, nil).Once() + m.estimator.On("RelayFeePerKW").Return(chainfee.FeePerKwFloor).Once() + + // Mock the signer to always return a valid script. + // + // NOTE: we are not testing the utility of creating valid txes here, so + // this is fine to be mocked. This behaves essentially as skipping the + // Signer check and alaways assume the tx has a valid sig. + script := &input.Script{} + m.signer.On("ComputeInputScript", mock.Anything, + mock.Anything).Return(script, nil) + + // Mock the testmempoolaccept to pass. + m.wallet.On("CheckMempoolAcceptance", mock.Anything).Return(nil).Once() + + // Mock the wallet to publish successfully. + m.wallet.On("PublishTransaction", + mock.Anything, mock.Anything).Return(nil).Once() + + // Create a test request. + inp := createTestInput(1000, input.WitnessKeyHash) + + // Create a testing bump request. + req := &BumpRequest{ + DeliveryAddress: changePkScript, + Inputs: []input.Input{&inp}, + Budget: btcutil.Amount(1000), + MaxFeeRate: feerate * 10, + DeadlineHeight: 10, + } + + // Register the testing record use `Broadcast`. + resultChan, err := tp.Broadcast(req) + require.NoError(t, err) + + // Grab the monitor record from the map. + rid := tp.requestCounter.Load() + rec, ok := tp.records.Load(rid) + require.True(t, ok) + + // Call the method under test. + tp.wg.Add(1) + tp.handleInitialBroadcast(rec, rid) + + // Check the result is sent back. + select { + case <-time.After(time.Second): + t.Fatal("timeout waiting for subscriber to receive result") + + case result := <-resultChan: + // We expect the first result to be TxPublished. + require.Equal(t, TxPublished, result.Event) + } + + // Validate the record was stored. + require.Equal(t, 1, tp.records.Len()) + require.Equal(t, 1, tp.subscriberChans.Len()) +} + +// TestHandleInitialBroadcastFail checks `handleInitialBroadcast` returns the +// error or a failed result when the broadcast fails. +func TestHandleInitialBroadcastFail(t *testing.T) { + t.Parallel() + + // Create a publisher using the mocks. + tp, m := createTestPublisher(t) + + // Create a test feerate. + feerate := chainfee.SatPerKWeight(1000) + + // Create a test request. + inp := createTestInput(1000, input.WitnessKeyHash) + + // Create a testing bump request. + req := &BumpRequest{ + DeliveryAddress: changePkScript, + Inputs: []input.Input{&inp}, + Budget: btcutil.Amount(1000), + MaxFeeRate: feerate * 10, + DeadlineHeight: 10, + } + + // Mock the fee estimator to return the testing fee rate. + // + // We are not testing `NewLinearFeeFunction` here, so the actual params + // used are irrelevant. + m.estimator.On("EstimateFeePerKW", mock.Anything).Return( + feerate, nil).Twice() + m.estimator.On("RelayFeePerKW").Return(chainfee.FeePerKwFloor).Twice() + + // Mock the signer to always return a valid script. + // + // NOTE: we are not testing the utility of creating valid txes here, so + // this is fine to be mocked. This behaves essentially as skipping the + // Signer check and alaways assume the tx has a valid sig. + script := &input.Script{} + m.signer.On("ComputeInputScript", mock.Anything, + mock.Anything).Return(script, nil) + + // Mock the testmempoolaccept to return an error. + m.wallet.On("CheckMempoolAcceptance", + mock.Anything).Return(errDummy).Once() + + // Register the testing record use `Broadcast`. + resultChan, err := tp.Broadcast(req) + require.NoError(t, err) + + // Grab the monitor record from the map. + rid := tp.requestCounter.Load() + rec, ok := tp.records.Load(rid) + require.True(t, ok) + + // Call the method under test and expect an error returned. + tp.wg.Add(1) + tp.handleInitialBroadcast(rec, rid) + + // Check the result is sent back. + select { + case <-time.After(time.Second): + t.Fatal("timeout waiting for subscriber to receive result") + + case result := <-resultChan: + // We expect the first result to be TxError. + require.Equal(t, TxError, result.Event) + } + + // Validate the record was NOT stored. + require.Equal(t, 0, tp.records.Len()) + require.Equal(t, 0, tp.subscriberChans.Len()) + + // Mock the testmempoolaccept again, this time it passes. + m.wallet.On("CheckMempoolAcceptance", mock.Anything).Return(nil).Once() + + // Mock the wallet to fail on publish. + m.wallet.On("PublishTransaction", + mock.Anything, mock.Anything).Return(errDummy).Once() + + // Register the testing record use `Broadcast`. + resultChan, err = tp.Broadcast(req) + require.NoError(t, err) + + // Grab the monitor record from the map. + rid = tp.requestCounter.Load() + rec, ok = tp.records.Load(rid) + require.True(t, ok) + + // Call the method under test. + tp.wg.Add(1) + tp.handleInitialBroadcast(rec, rid) + + // Check the result is sent back. + select { + case <-time.After(time.Second): + t.Fatal("timeout waiting for subscriber to receive result") + + case result := <-resultChan: + // We expect the result to be TxFailed and the error is set in + // the result. + require.Equal(t, TxFailed, result.Event) + require.ErrorIs(t, result.Err, errDummy) + } + + // Validate the record was removed. + require.Equal(t, 0, tp.records.Len()) + require.Equal(t, 0, tp.subscriberChans.Len()) +} diff --git a/sweep/fee_function.go b/sweep/fee_function.go index cbf283e37d7..15d44ed6160 100644 --- a/sweep/fee_function.go +++ b/sweep/fee_function.go @@ -14,6 +14,9 @@ var ( // ErrMaxPosition is returned when trying to increase the position of // the fee function while it's already at its max. ErrMaxPosition = errors.New("position already at max") + + // ErrZeroFeeRateDelta is returned when the fee rate delta is zero. + ErrZeroFeeRateDelta = errors.New("fee rate delta is zero") ) // mSatPerKWeight represents a fee rate in msat/kw. @@ -169,7 +172,7 @@ func NewLinearFeeFunction(maxFeeRate chainfee.SatPerKWeight, "endingFeeRate=%v, width=%v, delta=%v", start, end, l.width, l.deltaFeeRate) - return nil, fmt.Errorf("fee rate delta is zero") + return nil, ErrZeroFeeRateDelta } // Attach the calculated values to the fee function. From df0190ecb9d29147235ac08e8eb3c4368b4585ee Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Tue, 30 Apr 2024 17:39:45 +0800 Subject: [PATCH 04/19] sweep: remove redundant error from `Broadcast` --- sweep/fee_bumper.go | 6 +++--- sweep/fee_bumper_test.go | 12 ++++-------- sweep/mock_test.go | 6 +++--- sweep/sweeper.go | 16 +--------------- sweep/sweeper_test.go | 9 ++------- 5 files changed, 13 insertions(+), 36 deletions(-) diff --git a/sweep/fee_bumper.go b/sweep/fee_bumper.go index 7326af360d8..bd336246286 100644 --- a/sweep/fee_bumper.go +++ b/sweep/fee_bumper.go @@ -52,7 +52,7 @@ type Bumper interface { // and monitors its confirmation status for potential fee bumping. It // returns a chan that the caller can use to receive updates about the // broadcast result and potential RBF attempts. - Broadcast(req *BumpRequest) (<-chan *BumpResult, error) + Broadcast(req *BumpRequest) <-chan *BumpResult } // BumpEvent represents the event of a fee bumping attempt. @@ -325,7 +325,7 @@ func (t *TxPublisher) isNeutrinoBackend() bool { // RBF-compliant unless the budget specified cannot cover the fee. // // NOTE: part of the Bumper interface. -func (t *TxPublisher) Broadcast(req *BumpRequest) (<-chan *BumpResult, error) { +func (t *TxPublisher) Broadcast(req *BumpRequest) <-chan *BumpResult { log.Tracef("Received broadcast request: %s", newLogClosure( func() string { return spew.Sdump(req) @@ -343,7 +343,7 @@ func (t *TxPublisher) Broadcast(req *BumpRequest) (<-chan *BumpResult, error) { subscriber := make(chan *BumpResult, 1) t.subscriberChans.Store(requestID, subscriber) - return subscriber, nil + return subscriber } // initialBroadcast initializes a fee function, creates an RBF-compliant tx and diff --git a/sweep/fee_bumper_test.go b/sweep/fee_bumper_test.go index 0ab89a2db4f..85d883f64da 100644 --- a/sweep/fee_bumper_test.go +++ b/sweep/fee_bumper_test.go @@ -905,8 +905,7 @@ func TestBroadcast(t *testing.T) { } // Send the req and expect no error. - resultChan, err := tp.Broadcast(req) - require.NoError(t, err) + resultChan := tp.Broadcast(req) require.NotNil(t, resultChan) // Validate the record was stored. @@ -1406,8 +1405,7 @@ func TestHandleInitialBroadcastSuccess(t *testing.T) { } // Register the testing record use `Broadcast`. - resultChan, err := tp.Broadcast(req) - require.NoError(t, err) + resultChan := tp.Broadcast(req) // Grab the monitor record from the map. rid := tp.requestCounter.Load() @@ -1478,8 +1476,7 @@ func TestHandleInitialBroadcastFail(t *testing.T) { mock.Anything).Return(errDummy).Once() // Register the testing record use `Broadcast`. - resultChan, err := tp.Broadcast(req) - require.NoError(t, err) + resultChan := tp.Broadcast(req) // Grab the monitor record from the map. rid := tp.requestCounter.Load() @@ -1512,8 +1509,7 @@ func TestHandleInitialBroadcastFail(t *testing.T) { mock.Anything, mock.Anything).Return(errDummy).Once() // Register the testing record use `Broadcast`. - resultChan, err = tp.Broadcast(req) - require.NoError(t, err) + resultChan = tp.Broadcast(req) // Grab the monitor record from the map. rid = tp.requestCounter.Load() diff --git a/sweep/mock_test.go b/sweep/mock_test.go index 605b8d14ec4..f79ccea290e 100644 --- a/sweep/mock_test.go +++ b/sweep/mock_test.go @@ -276,14 +276,14 @@ type MockBumper struct { var _ Bumper = (*MockBumper)(nil) // Broadcast broadcasts the transaction to the network. -func (m *MockBumper) Broadcast(req *BumpRequest) (<-chan *BumpResult, error) { +func (m *MockBumper) Broadcast(req *BumpRequest) <-chan *BumpResult { args := m.Called(req) if args.Get(0) == nil { - return nil, args.Error(1) + return nil } - return args.Get(0).(chan *BumpResult), args.Error(1) + return args.Get(0).(chan *BumpResult) } // MockFeeFunction is a mock implementation of the FeeFunction interface. diff --git a/sweep/sweeper.go b/sweep/sweeper.go index 2ea4183237e..3134747e343 100644 --- a/sweep/sweeper.go +++ b/sweep/sweeper.go @@ -823,21 +823,7 @@ func (s *UtxoSweeper) sweep(set InputSet) error { // Broadcast will return a read-only chan that we will listen to for // this publish result and future RBF attempt. - resp, err := s.cfg.Publisher.Broadcast(req) - if err != nil { - outpoints := make([]wire.OutPoint, len(set.Inputs())) - for i, inp := range set.Inputs() { - outpoints[i] = inp.OutPoint() - } - - log.Errorf("Initial broadcast failed: %v, inputs=\n%v", err, - inputTypeSummary(set.Inputs())) - - // TODO(yy): find out which input is causing the failure. - s.markInputsPublishFailed(outpoints) - - return err - } + resp := s.cfg.Publisher.Broadcast(req) // Successfully sent the broadcast attempt, we now handle the result by // subscribing to the result chan and listen for future updates about diff --git a/sweep/sweeper_test.go b/sweep/sweeper_test.go index c8d9fc510bf..20c1506dd62 100644 --- a/sweep/sweeper_test.go +++ b/sweep/sweeper_test.go @@ -715,13 +715,8 @@ func TestSweepPendingInputs(t *testing.T) { setNeedWallet, normalSet, }) - // Mock `Broadcast` to return an error. This should cause the - // `createSweepTx` inside `sweep` to fail. This is done so we can - // terminate the method early as we are only interested in testing the - // workflow in `sweepPendingInputs`. We don't need to test `sweep` here - // as it should be tested in its own unit test. - dummyErr := errors.New("dummy error") - publisher.On("Broadcast", mock.Anything).Return(nil, dummyErr).Twice() + // Mock `Broadcast` to return a result. + publisher.On("Broadcast", mock.Anything).Return(nil).Twice() // Call the method under test. s.sweepPendingInputs(pis) From f601893d140b9eb3d22b7372c82a6e3591c52987 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Tue, 30 Apr 2024 19:25:06 +0800 Subject: [PATCH 05/19] sweep: add method `handleBumpEventError` and fix `markInputFailed` Previously in `markInputFailed`, we'd remove all inputs under the same group via `removeExclusiveGroup`. This is wrong as when the current sweep fails for this input, it shouldn't affect other inputs. --- sweep/sweeper.go | 64 +++++++++++++-- sweep/sweeper_test.go | 176 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 234 insertions(+), 6 deletions(-) diff --git a/sweep/sweeper.go b/sweep/sweeper.go index 3134747e343..a0f9828a390 100644 --- a/sweep/sweeper.go +++ b/sweep/sweeper.go @@ -1432,11 +1432,6 @@ func (s *UtxoSweeper) markInputFailed(pi *SweeperInput, err error) { pi.state = Failed - // Remove all other inputs in this exclusive group. - if pi.params.ExclusiveGroup != nil { - s.removeExclusiveGroup(*pi.params.ExclusiveGroup) - } - s.signalResult(pi, Result{Err: err}) } @@ -1698,6 +1693,60 @@ func (s *UtxoSweeper) handleBumpEventTxPublished(r *BumpResult) error { return nil } +// handleBumpEventError handles the case where there's an unexpected error when +// creating or publishing the sweeping tx. In this case, the tx will be removed +// from the sweeper store and the inputs will be marked as `Failed`. +func (s *UtxoSweeper) handleBumpEventError(r *BumpResult) error { + txid := r.Tx.TxHash() + log.Infof("Tx=%v failed with unexpected error: %v", txid, r.Err) + + // Remove the tx from the sweeper db if it exists. + if err := s.cfg.Store.DeleteTx(txid); err != nil { + return fmt.Errorf("delete tx record for %v: %w", txid, err) + } + + // Mark the inputs as failed. + s.markInputsFailed(r.Tx, r.Err) + + return nil +} + +// markInputsFailed marks all inputs found in the tx as failed. It will also +// notify all the subscribers of these inputs. +func (s *UtxoSweeper) markInputsFailed(tx *wire.MsgTx, err error) { + for _, txIn := range tx.TxIn { + outpoint := txIn.PreviousOutPoint + + input, ok := s.inputs[outpoint] + if !ok { + // It's very likely that a spending tx contains inputs + // that we don't know. + log.Tracef("Skipped marking input as failed: %v not "+ + "found in pending inputs", outpoint) + + continue + } + + // If the input is already in a terminal state, we don't want + // to rewrite it, which also indicates an error as we only get + // an error event during the initial broadcast. + if input.terminated() { + log.Errorf("Skipped marking input=%v as failed due to "+ + "unexpected state=%v", outpoint, input.state) + + continue + } + + input.state = Failed + + // Signal result channels. + s.signalResult(input, Result{ + Tx: tx, + Err: err, + }) + } +} + // handleBumpEvent handles the result sent from the bumper based on its event // type. // @@ -1721,8 +1770,11 @@ func (s *UtxoSweeper) handleBumpEvent(r *BumpResult) error { case TxReplaced: return s.handleBumpEventTxReplaced(r) + // There's an unexpected error in creating or publishing the tx, we + // will remove the tx from the sweeper db and mark the inputs as + // failed. case TxError: - // TODO(yy): create a method to remove this input. + return s.handleBumpEventError(r) } return nil diff --git a/sweep/sweeper_test.go b/sweep/sweeper_test.go index 20c1506dd62..4169d88bf79 100644 --- a/sweep/sweeper_test.go +++ b/sweep/sweeper_test.go @@ -1102,3 +1102,179 @@ func TestMonitorFeeBumpResult(t *testing.T) { }) } } + +// TestMarkInputsFailed checks that given a list of inputs with different +// states, the method `markInputsFailed` correctly marks the inputs as failed. +func TestMarkInputsFailed(t *testing.T) { + t.Parallel() + + require := require.New(t) + + // Create a test sweeper. + s := New(&UtxoSweeperConfig{}) + + // Create a mock input. + mockInput := &input.MockInput{} + defer mockInput.AssertExpectations(t) + + // Mock the `OutPoint` to return a dummy outpoint. + mockInput.On("OutPoint").Return(wire.OutPoint{Hash: chainhash.Hash{1}}) + + // Create testing inputs for each state. + // + // inputNotExist specifies an input that's not found in the sweeper's + // `inputs` map. + inputNotExist := &wire.TxIn{ + PreviousOutPoint: wire.OutPoint{Index: 1}, + } + + // inputInit specifies a newly created input. When marking this as + // published, we should see an error log as this input hasn't been + // published yet. + inputInit := &wire.TxIn{ + PreviousOutPoint: wire.OutPoint{Index: 2}, + } + s.inputs[inputInit.PreviousOutPoint] = &SweeperInput{ + state: Init, + Input: mockInput, + } + + // inputPendingPublish specifies an input that's about to be published. + inputPendingPublish := &wire.TxIn{ + PreviousOutPoint: wire.OutPoint{Index: 3}, + } + s.inputs[inputPendingPublish.PreviousOutPoint] = &SweeperInput{ + state: PendingPublish, + Input: mockInput, + } + + // inputPublished specifies an input that's published. + inputPublished := &wire.TxIn{ + PreviousOutPoint: wire.OutPoint{Index: 4}, + } + s.inputs[inputPublished.PreviousOutPoint] = &SweeperInput{ + state: Published, + Input: mockInput, + } + + // inputPublishFailed specifies an input that's failed to be published. + inputPublishFailed := &wire.TxIn{ + PreviousOutPoint: wire.OutPoint{Index: 5}, + } + s.inputs[inputPublishFailed.PreviousOutPoint] = &SweeperInput{ + state: PublishFailed, + Input: mockInput, + } + + // inputSwept specifies an input that's swept. + inputSwept := &wire.TxIn{ + PreviousOutPoint: wire.OutPoint{Index: 6}, + } + s.inputs[inputSwept.PreviousOutPoint] = &SweeperInput{ + state: Swept, + Input: mockInput, + } + + // inputExcluded specifies an input that's excluded. + inputExcluded := &wire.TxIn{ + PreviousOutPoint: wire.OutPoint{Index: 7}, + } + s.inputs[inputExcluded.PreviousOutPoint] = &SweeperInput{ + state: Excluded, + Input: mockInput, + } + + // inputFailed specifies an input that's failed. + inputFailed := &wire.TxIn{ + PreviousOutPoint: wire.OutPoint{Index: 8}, + } + s.inputs[inputFailed.PreviousOutPoint] = &SweeperInput{ + state: Failed, + Input: mockInput, + } + + // Create a test tx. + tx := &wire.MsgTx{ + TxIn: []*wire.TxIn{ + inputNotExist, + inputInit, + inputPendingPublish, + inputPublished, + inputPublishFailed, + inputSwept, + inputExcluded, + inputFailed, + }, + } + + // Mark the test inputs. We expect the non-exist input and + // inputSwept/inputExcluded/inputFailed to be skipped. + s.markInputsFailed(tx, errDummy) + + // We expect unchanged number of pending inputs. + require.Len(s.inputs, 7) + + // We expect the init input's to be marked as failed. + require.Equal(Failed, s.inputs[inputInit.PreviousOutPoint].state) + + // We expect the pending-publish input to be marked as failed. + require.Equal(Failed, + s.inputs[inputPendingPublish.PreviousOutPoint].state) + + // We expect the published input to be marked as failed. + require.Equal(Failed, s.inputs[inputPublished.PreviousOutPoint].state) + + // We expect the publish failed input to be markd as failed. + require.Equal(Failed, + s.inputs[inputPublishFailed.PreviousOutPoint].state) + + // We expect the swept input to stay unchanged. + require.Equal(Swept, s.inputs[inputSwept.PreviousOutPoint].state) + + // We expect the excluded input to stay unchanged. + require.Equal(Excluded, s.inputs[inputExcluded.PreviousOutPoint].state) + + // We expect the failed input to stay unchanged. + require.Equal(Failed, s.inputs[inputFailed.PreviousOutPoint].state) +} + +// TestHandleBumpEventError checks that `handleBumpEventError` correctly +// handles a `TxError` event. +func TestHandleBumpEventError(t *testing.T) { + t.Parallel() + + rt := require.New(t) + + // Create a mock store. + store := &MockSweeperStore{} + defer store.AssertExpectations(t) + + // Create a test sweeper. + s := New(&UtxoSweeperConfig{ + Store: store, + }) + + // Create a testing tx. + // + // We are not testing `markInputFailed` here, so the actual tx doesn't + // matter. + tx := &wire.MsgTx{} + result := &BumpResult{ + Tx: tx, + Err: errDummy, + } + + // Mock the store to return an error. + store.On("DeleteTx", mock.Anything).Return(errDummy).Once() + + // Call the method under test and assert the error is returned. + err := s.handleBumpEventError(result) + rt.ErrorIs(err, errDummy) + + // Mock the store to return nil. + store.On("DeleteTx", mock.Anything).Return(nil).Once() + + // Call the method under test and assert no error is returned. + err = s.handleBumpEventError(result) + rt.NoError(err) +} From a7addc3aa9ece079d6db89c27a404f5f5419ae01 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Tue, 30 Apr 2024 19:28:01 +0800 Subject: [PATCH 06/19] sweep: add method `isMature` on `SweeperInput` Also updated `handlePendingSweepsReq` to skip immature inputs so the returned results are the same as those in pre-0.18.0. --- sweep/sweeper.go | 47 ++++++++++++++++++++++++++++++------------- sweep/sweeper_test.go | 2 ++ 2 files changed, 35 insertions(+), 14 deletions(-) diff --git a/sweep/sweeper.go b/sweep/sweeper.go index a0f9828a390..e8efeb37b40 100644 --- a/sweep/sweeper.go +++ b/sweep/sweeper.go @@ -221,6 +221,30 @@ func (p *SweeperInput) terminated() bool { } } +// isMature returns a boolean indicating whether the input has a timelock that +// has been reached or not. The locktime found is also returned. +func (p *SweeperInput) isMature(currentHeight uint32) (bool, uint32) { + locktime, _ := p.RequiredLockTime() + if currentHeight < locktime { + log.Debugf("Input %v has locktime=%v, current height is %v", + p.OutPoint(), locktime, currentHeight) + + return false, locktime + } + + // If the input has a CSV that's not yet reached, we will skip + // this input and wait for the expiry. + locktime = p.BlocksToMaturity() + p.HeightHint() + if currentHeight+1 < locktime { + log.Debugf("Input %v has CSV expiry=%v, current height is %v", + p.OutPoint(), locktime, currentHeight) + + return false, locktime + } + + return true, locktime +} + // InputsMap is a type alias for a set of pending inputs. type InputsMap = map[wire.OutPoint]*SweeperInput @@ -1026,6 +1050,12 @@ func (s *UtxoSweeper) handlePendingSweepsReq( resps := make(map[wire.OutPoint]*PendingInputResponse, len(s.inputs)) for _, inp := range s.inputs { + // Skip immature inputs for compatibility. + mature, _ := inp.isMature(uint32(s.currentHeight)) + if !mature { + continue + } + // Only the exported fields are set, as we expect the response // to only be consumed externally. op := inp.OutPoint() @@ -1476,20 +1506,9 @@ func (s *UtxoSweeper) updateSweeperInputs() InputsMap { // If the input has a locktime that's not yet reached, we will // skip this input and wait for the locktime to be reached. - locktime, _ := input.RequiredLockTime() - if uint32(s.currentHeight) < locktime { - log.Warnf("Skipping input %v due to locktime=%v not "+ - "reached, current height is %v", op, locktime, - s.currentHeight) - - continue - } - - // If the input has a CSV that's not yet reached, we will skip - // this input and wait for the expiry. - locktime = input.BlocksToMaturity() + input.HeightHint() - if s.currentHeight < int32(locktime)-1 { - log.Infof("Skipping input %v due to CSV expiry=%v not "+ + mature, locktime := input.isMature(uint32(s.currentHeight)) + if !mature { + log.Infof("Skipping input %v due to locktime=%v not "+ "reached, current height is %v", op, locktime, s.currentHeight) diff --git a/sweep/sweeper_test.go b/sweep/sweeper_test.go index 4169d88bf79..8ea25c073ed 100644 --- a/sweep/sweeper_test.go +++ b/sweep/sweeper_test.go @@ -490,6 +490,7 @@ func TestUpdateSweeperInputs(t *testing.T) { // returned. inp2.On("RequiredLockTime").Return( uint32(s.currentHeight+1), true).Once() + inp2.On("OutPoint").Return(wire.OutPoint{Index: 2}).Maybe() input7 := &SweeperInput{state: Init, Input: inp2} // Mock the input to have a CSV expiry in the future so it will NOT be @@ -498,6 +499,7 @@ func TestUpdateSweeperInputs(t *testing.T) { uint32(s.currentHeight), false).Once() inp3.On("BlocksToMaturity").Return(uint32(2)).Once() inp3.On("HeightHint").Return(uint32(s.currentHeight)).Once() + inp3.On("OutPoint").Return(wire.OutPoint{Index: 3}).Maybe() input8 := &SweeperInput{state: Init, Input: inp3} // Add the inputs to the sweeper. After the update, we should see the From 01a576a909dacd5d7c88bcfbc4d617784636cc58 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Wed, 1 May 2024 02:58:27 +0800 Subject: [PATCH 07/19] sweep: make sure defaultDeadline is derived from the mature height --- sweep/sweeper.go | 55 ++++++++++++++++++++++++++++++++++--------- sweep/sweeper_test.go | 2 +- 2 files changed, 45 insertions(+), 12 deletions(-) diff --git a/sweep/sweeper.go b/sweep/sweeper.go index e8efeb37b40..e0206b83fdd 100644 --- a/sweep/sweeper.go +++ b/sweep/sweeper.go @@ -525,7 +525,7 @@ func (s *UtxoSweeper) SweepInput(inp input.Input, } absoluteTimeLock, _ := inp.RequiredLockTime() - log.Infof("Sweep request received: out_point=%v, witness_type=%v, "+ + log.Debugf("Sweep request received: out_point=%v, witness_type=%v, "+ "relative_time_lock=%v, absolute_time_lock=%v, amount=%v, "+ "parent=(%v), params=(%v)", inp.OutPoint(), inp.WitnessType(), inp.BlocksToMaturity(), absoluteTimeLock, @@ -725,7 +725,17 @@ func (s *UtxoSweeper) collector(blockEpochs <-chan *chainntnfs.BlockEpoch) { inputs := s.updateSweeperInputs() log.Debugf("Received new block: height=%v, attempt "+ - "sweeping %d inputs", epoch.Height, len(inputs)) + "sweeping %d inputs:\n%s", epoch.Height, + len(inputs), newLogClosure(func() string { + inps := make( + []input.Input, 0, len(inputs), + ) + for _, in := range inputs { + inps = append(inps, in) + } + + return inputTypeSummary(inps) + })()) // Attempt to sweep any pending inputs. s.sweepPendingInputs(inputs) @@ -1191,13 +1201,29 @@ func (s *UtxoSweeper) mempoolLookup(op wire.OutPoint) fn.Option[wire.MsgTx] { return s.cfg.Mempool.LookupInputMempoolSpend(op) } -// handleNewInput processes a new input by registering spend notification and -// scheduling sweeping for it. -func (s *UtxoSweeper) handleNewInput(input *sweepInputMessage) error { +// calculateDefaultDeadline calculates the default deadline height for a sweep +// request that has no deadline height specified. +func (s *UtxoSweeper) calculateDefaultDeadline(pi *SweeperInput) int32 { // Create a default deadline height, which will be used when there's no // DeadlineHeight specified for a given input. defaultDeadline := s.currentHeight + int32(s.cfg.NoDeadlineConfTarget) + // If the input is immature and has a locktime, we'll use the locktime + // height as the starting height. + matured, locktime := pi.isMature(uint32(s.currentHeight)) + if !matured { + defaultDeadline = int32(locktime + s.cfg.NoDeadlineConfTarget) + log.Debugf("Input %v is immature, using locktime=%v instead "+ + "of current height=%d", pi.OutPoint(), locktime, + s.currentHeight) + } + + return defaultDeadline +} + +// handleNewInput processes a new input by registering spend notification and +// scheduling sweeping for it. +func (s *UtxoSweeper) handleNewInput(input *sweepInputMessage) error { outpoint := input.input.OutPoint() pi, pending := s.inputs[outpoint] if pending { @@ -1222,15 +1248,22 @@ func (s *UtxoSweeper) handleNewInput(input *sweepInputMessage) error { Input: input.input, params: input.params, rbf: rbfInfo, - // Set the acutal deadline height. - DeadlineHeight: input.params.DeadlineHeight.UnwrapOr( - defaultDeadline, - ), } + // Set the acutal deadline height. + pi.DeadlineHeight = input.params.DeadlineHeight.UnwrapOr( + s.calculateDefaultDeadline(pi), + ) + s.inputs[outpoint] = pi log.Tracef("input %v, state=%v, added to inputs", outpoint, pi.state) + log.Infof("Registered sweep request at block %d: out_point=%v, "+ + "witness_type=%v, amount=%v, deadline=%d, params=(%v)", + s.currentHeight, pi.OutPoint(), pi.WitnessType(), + btcutil.Amount(pi.SignDesc().Output.Value), pi.DeadlineHeight, + pi.params) + // Start watching for spend of this input, either by us or the remote // party. cancel, err := s.monitorSpend( @@ -1626,7 +1659,7 @@ func (s *UtxoSweeper) monitorFeeBumpResult(resultChan <-chan *BumpResult) { func (s *UtxoSweeper) handleBumpEventTxFailed(r *BumpResult) error { tx, err := r.Tx, r.Err - log.Errorf("Fee bump attempt failed for tx=%v: %v", tx.TxHash(), err) + log.Warnf("Fee bump attempt failed for tx=%v: %v", tx.TxHash(), err) outpoints := make([]wire.OutPoint, 0, len(tx.TxIn)) for _, inp := range tx.TxIn { @@ -1636,7 +1669,7 @@ func (s *UtxoSweeper) handleBumpEventTxFailed(r *BumpResult) error { // TODO(yy): should we also remove the failed tx from db? s.markInputsPublishFailed(outpoints) - return err + return nil } // handleBumpEventTxReplaced handles the case where the sweeping tx has been diff --git a/sweep/sweeper_test.go b/sweep/sweeper_test.go index 8ea25c073ed..39e3ff895c2 100644 --- a/sweep/sweeper_test.go +++ b/sweep/sweeper_test.go @@ -775,7 +775,7 @@ func TestHandleBumpEventTxFailed(t *testing.T) { // Call the method under test. err := s.handleBumpEvent(br) - require.ErrorIs(t, err, errDummy) + require.NoError(t, err) // Assert the states of the first two inputs are updated. require.Equal(t, PublishFailed, s.inputs[op1].state) From 845867d36bf1b7acf6cad35d2b82e16a6173794d Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Fri, 26 Apr 2024 17:43:13 +0800 Subject: [PATCH 08/19] chainio: add `blockbeat` to handle block synchronization among subsystems In this commit, a minimal version of `BlockBeat` is added to synchronize block heights, which will be used in `ChainArb`, `Sweeper`, and `TxPublisher` so blocks are processed sequentially among them. --- chainio/blockbeat.go | 262 ++++++++++++++++++++++++++++++++++++++ chainio/blockbeat_test.go | 1 + chainio/log.go | 29 +++++ 3 files changed, 292 insertions(+) create mode 100644 chainio/blockbeat.go create mode 100644 chainio/blockbeat_test.go create mode 100644 chainio/log.go diff --git a/chainio/blockbeat.go b/chainio/blockbeat.go new file mode 100644 index 00000000000..07087445f5d --- /dev/null +++ b/chainio/blockbeat.go @@ -0,0 +1,262 @@ +package chainio + +import ( + "fmt" + "sync" + "sync/atomic" + "time" + + "github.com/lightningnetwork/lnd/chainntnfs" + "github.com/lightningnetwork/lnd/fn" +) + +// DefaultProcessBlockTimeout is the timeout value used when waiting for one +// consumer to finish processing the new block epoch. +var DefaultProcessBlockTimeout = 30 * time.Second + +// Consumer defines a blockbeat consumer interface. Subsystems that need block +// info should implement it. +type Consumer interface { + // Name returns a human-readable string for this subsystem. + Name() string + + // ProcessBlock takes a beat and processes it. A receive-only error + // chan must be returned. + // + // NOTE: When implementing this, it's very important to send back the + // error or nil to the channel immediately, otherwise BlockBeat will + // timeout and lnd will shutdown. + ProcessBlock(b Beat) <-chan error +} + +// Beat contains the block epoch and a buffer error chan. +// +// TODO(yy): extend this to check for confirmation status - which serves as the +// single source of truth, to avoid the potential race between receiving blocks +// and `GetTransactionDetails/RegisterSpendNtfn/RegisterConfirmationsNtfn`. +type Beat struct { + // Epoch is the current block epoch the blockbeat is aware of. + Epoch chainntnfs.BlockEpoch + + // Err is a buffered chan that receives an error or nil from + // ProcessBlock. + Err chan error +} + +// NewBeat creates a new beat with the specified block epoch and a buffered +// error chan. +func NewBeat(epoch chainntnfs.BlockEpoch) Beat { + return Beat{ + Epoch: epoch, + Err: make(chan error, 1), + } +} + +// BlockBeat is a service that handles dispatching new blocks to `lnd`'s +// subsystems. During startup, subsystems that are block-driven should +// implement the `Consumer` interface and register themselves via +// `RegisterQueue`. When two subsystems are independent of each other, they +// should be registered in differet queues so blocks are notified concurrently. +// Otherwise, when living in the same queue, the subsystems are notified of the +// new blocks sequentially, which means it's critical to understand the +// relationship of these systems to properly handle the order. +type BlockBeat struct { + wg sync.WaitGroup + + // notifier is used to receive new block epochs. + notifier chainntnfs.ChainNotifier + + // blockEpoch is the latest block epoch received . + blockEpoch chainntnfs.BlockEpoch + + // consumerQueues is a map of consumers that will receive blocks. Each + // queue is notified concurrently, and consumers in the same queue is + // notified sequentially. + consumerQueues map[uint32][]Consumer + + // counter is used to assign a unique id to each queue. + counter atomic.Uint32 + + // quit is used to signal the BlockBeat to stop. + quit chan struct{} +} + +// NewBlockBeat returns a new blockbeat instance. +func NewBlockBeat(notifier chainntnfs.ChainNotifier) *BlockBeat { + return &BlockBeat{ + notifier: notifier, + quit: make(chan struct{}), + consumerQueues: make(map[uint32][]Consumer), + } +} + +// RegisterQueue takes a list of consumers and register them in the same queue. +// +// NOTE: these consumers are notified sequentially. +func (b *BlockBeat) RegisterQueue(consumers []Consumer) { + qid := b.counter.Add(1) + + b.consumerQueues[qid] = append(b.consumerQueues[qid], consumers...) + log.Infof("Registered queue=%d with %d blockbeat consumers", qid, + len(consumers)) + + for _, c := range consumers { + log.Debugf("Consumer [%s] registered in queue %d", c.Name(), + qid) + } +} + +// Start starts the blockbeat - it registers a block notification and monitors +// and dispatches new blocks in a goroutine. It will refuse to start if there +// are no registered consumers. +func (b *BlockBeat) Start() error { + // Make sure consumers are registered. + if len(b.consumerQueues) == 0 { + return fmt.Errorf("no consumers registered") + } + + // Start listening to new block epochs. + blockEpochs, err := b.notifier.RegisterBlockEpochNtfn(nil) + if err != nil { + return fmt.Errorf("register block epoch ntfn: %w", err) + } + + log.Infof("BlockBeat is starting with %d consumer queues", + len(b.consumerQueues)) + defer log.Debug("BlockBeat started") + + b.wg.Add(1) + go b.dispatchBlocks(blockEpochs) + + return nil +} + +// Stop shuts down the blockbeat. +func (b *BlockBeat) Stop() { + log.Info("BlockBeat is stopping") + defer log.Debug("BlockBeat stopped") + + // Signal the dispatchBlocks goroutine to stop. + close(b.quit) + b.wg.Wait() +} + +// dispatchBlocks listens to new block epoch and dispatches it to all the +// consumers. Each queue in BlockBeat is notified concurrently, and the +// consumers in the same queue are notified sequentially. +func (b *BlockBeat) dispatchBlocks(blockEpochs *chainntnfs.BlockEpochEvent) { + defer b.wg.Done() + defer blockEpochs.Cancel() + + for { + select { + case blockEpoch, ok := <-blockEpochs.Epochs: + if !ok { + log.Debugf("Block epoch channel closed") + return + } + + log.Infof("Received new block %v at height %d, "+ + "notifying consumers...", blockEpoch.Hash, + blockEpoch.Height) + + // Update the current block epoch. + b.blockEpoch = *blockEpoch + + // Notify all consumers. + b.notifyQueues() + + log.Infof("Notified all consumers on block %v at "+ + "height %d", blockEpoch.Hash, blockEpoch.Height) + + case <-b.quit: + log.Debugf("BlockBeat quit signal received") + return + } + } +} + +// notifyQueues notifies each queue concurrently about the latest block epoch. +func (b *BlockBeat) notifyQueues() { + // errChans is a map of channels that will be used to receive errors + // returned from notifying the consumers. + errChans := make(map[uint32]chan error, len(b.consumerQueues)) + + // Notify each queue in goroutines. + for qid, consumers := range b.consumerQueues { + log.Debugf("Notifying queue=%d on block %d", qid, + b.blockEpoch.Height) + + // Create a signal chan. + errChan := make(chan error) + errChans[qid] = errChan + + // Notify each queue concurrently. + b.wg.Add(1) + go func(qid uint32, c []Consumer, + epoch chainntnfs.BlockEpoch) { + + defer b.wg.Done() + + // Notify each consumer in this queue sequentially. + errChan <- b.notifyQueue(c, epoch) + }(qid, consumers, b.blockEpoch) + } + + // Wait for all consumers in each queue to finish. + for qid, errChan := range errChans { + select { + case err := <-errChan: + // It's critical that the subsystems can process blocks + // correctly and timely, if an error returns, we'd + // gracefully shutdown lnd to bring attentions. + if err != nil { + log.Criticalf("Queue=%d failed to process "+ + "block: %v", qid, err) + + return + } + + log.Debugf("Notified queue=%d on block %d", qid, + b.blockEpoch.Height) + + case <-b.quit: + } + } +} + +// notifyQueue takes a list of consumers and notify them about the new epoch +// sequentially. +func (b *BlockBeat) notifyQueue(queue []Consumer, + epoch chainntnfs.BlockEpoch) error { + + for _, c := range queue { + log.Debugf("Notifying consumer [%s] on block %d", c.Name(), + epoch.Height) + + // Construct a new beat with a buffered error chan. + beat := NewBeat(epoch) + + // Record the time it takes the consumer to process this block. + start := time.Now() + + // We expect the consumer to finish processing this block under + // 30s, otherwise a timeout error is returned. + err, timeout := fn.RecvOrTimeout( + c.ProcessBlock(beat), DefaultProcessBlockTimeout, + ) + if err != nil { + return fmt.Errorf("%s: ProcessBlock got: %w", c.Name(), + err) + } + if timeout != nil { + return fmt.Errorf("%s timed out while processing block", + c.Name()) + } + + log.Debugf("Consumer [%s] processed block %d in %v", c.Name(), + epoch.Height, time.Since(start)) + } + + return nil +} diff --git a/chainio/blockbeat_test.go b/chainio/blockbeat_test.go new file mode 100644 index 00000000000..8d034eda785 --- /dev/null +++ b/chainio/blockbeat_test.go @@ -0,0 +1 @@ +package chainio diff --git a/chainio/log.go b/chainio/log.go new file mode 100644 index 00000000000..fb562abdeb8 --- /dev/null +++ b/chainio/log.go @@ -0,0 +1,29 @@ +package chainio + +import ( + "github.com/btcsuite/btclog" + "github.com/lightningnetwork/lnd/build" +) + +// log is a logger that is initialized with no output filters. This +// means the package will not perform any logging by default until the caller +// requests it. +var log btclog.Logger + +// The default amount of logging is none. +func init() { + UseLogger(build.NewSubLogger("CHIO", nil)) +} + +// DisableLog disables all library log output. Logging output is disabled +// by default until UseLogger is called. +func DisableLog() { + UseLogger(btclog.Disabled) +} + +// UseLogger uses a specified Logger to output package logging info. +// This should be used in preference to SetLogWriter if the caller is also +// using btclog. +func UseLogger(logger btclog.Logger) { + log = logger +} From 7090c9e77f225cb05510ad158c6a378fbd2dfb0e Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Fri, 26 Apr 2024 18:16:49 +0800 Subject: [PATCH 09/19] sweep+contractcourt: update `UtxoSweeper`, `TxPublisher` and `ChainArb` to use blockbeat In this commit, we replace the individual block subscription with the implementation of the interface method `ProcessBlock` so they share a single block notifier. --- contractcourt/chain_arbitrator.go | 47 ++++++++++++++++++------- sweep/fee_bumper.go | 42 ++++++++++++++++++----- sweep/sweeper.go | 57 ++++++++++++++++++------------- 3 files changed, 101 insertions(+), 45 deletions(-) diff --git a/contractcourt/chain_arbitrator.go b/contractcourt/chain_arbitrator.go index 0cc4b111a5a..2ec894f51bf 100644 --- a/contractcourt/chain_arbitrator.go +++ b/contractcourt/chain_arbitrator.go @@ -11,6 +11,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcwallet/walletdb" + "github.com/lightningnetwork/lnd/chainio" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb/models" @@ -250,6 +251,11 @@ type ChainArbitrator struct { // active channels that it must still watch over. chanSource *channeldb.DB + // blockBeatChan is a channel to receive blocks from BlockBeat. The + // received block contains the best known height and the transactions + // confirmed in this block. + blockBeatChan chan chainio.Beat + quit chan struct{} wg sync.WaitGroup @@ -266,6 +272,7 @@ func NewChainArbitrator(cfg ChainArbitratorConfig, activeWatchers: make(map[wire.OutPoint]*chainWatcher), chanSource: db, quit: make(chan struct{}), + blockBeatChan: make(chan chainio.Beat), } } @@ -745,18 +752,11 @@ func (c *ChainArbitrator) Start() error { } } - // Subscribe to a single stream of block epoch notifications that we - // will dispatch to all active arbitrators. - blockEpoch, err := c.cfg.Notifier.RegisterBlockEpochNtfn(nil) - if err != nil { - return err - } - // Start our goroutine which will dispatch blocks to each arbitrator. c.wg.Add(1) go func() { defer c.wg.Done() - c.dispatchBlocks(blockEpoch) + c.dispatchBlocks() }() // TODO(roasbeef): eventually move all breach watching here @@ -781,8 +781,7 @@ type blockRecipient struct { // dispatchBlocks consumes a block epoch notification stream and dispatches // blocks to each of the chain arb's active channel arbitrators. This function // must be run in a goroutine. -func (c *ChainArbitrator) dispatchBlocks( - blockEpoch *chainntnfs.BlockEpochEvent) { +func (c *ChainArbitrator) dispatchBlocks() { // getRecipients is a helper function which acquires the chain arb // lock and returns a set of block recipients which can be used to @@ -805,8 +804,6 @@ func (c *ChainArbitrator) dispatchBlocks( // On exit, cancel our blocks subscription and close each block channel // so that the arbitrators know they will no longer be receiving blocks. defer func() { - blockEpoch.Cancel() - recipients := getRecipients() for _, recipient := range recipients { close(recipient.blocks) @@ -818,13 +815,15 @@ func (c *ChainArbitrator) dispatchBlocks( select { // Consume block epochs, exiting if our subscription is // terminated. - case block, ok := <-blockEpoch.Epochs: + case beat, ok := <-c.blockBeatChan: if !ok { log.Trace("dispatchBlocks block epoch " + "cancelled") return } + block := beat.Epoch + // Get the set of currently active channels block // subscription channels and dispatch the block to // each. @@ -853,6 +852,9 @@ func (c *ChainArbitrator) dispatchBlocks( } } + // Notify we've processed the block. + fn.SendOrQuit(beat.Err, nil, c.quit) + // Exit if the chain arbitrator is shutting down. case <-c.quit: return @@ -1320,3 +1322,22 @@ func (c *ChainArbitrator) FindOutgoingHTLCDeadline(scid lnwire.ShortChannelID, // TODO(roasbeef): arbitration reports // * types: contested, waiting for success conf, etc + +// NOTE: part of the `chainio.Consumer` interface. +func (c *ChainArbitrator) ProcessBlock(beat chainio.Beat) <-chan error { + select { + case c.blockBeatChan <- beat: + log.Debugf("Received block beat for height=%d", + beat.Epoch.Height) + + case <-c.quit: + return nil + } + + return beat.Err +} + +// NOTE: part of the `chainio.Consumer` interface. +func (c *ChainArbitrator) Name() string { + return "chain arbitrator" +} diff --git a/sweep/fee_bumper.go b/sweep/fee_bumper.go index bd336246286..6bce81b3b61 100644 --- a/sweep/fee_bumper.go +++ b/sweep/fee_bumper.go @@ -13,6 +13,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcwallet/chain" "github.com/davecgh/go-spew/spew" + "github.com/lightningnetwork/lnd/chainio" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" @@ -297,6 +298,11 @@ type TxPublisher struct { // the chan that the publisher sends the fee bump result to. subscriberChans lnutils.SyncMap[uint64, chan *BumpResult] + // blockBeatChan is a channel to receive blocks from BlockBeat. The + // received block contains the best known height and the transactions + // confirmed in this block. + blockBeatChan chan chainio.Beat + // quit is used to signal the publisher to stop. quit chan struct{} } @@ -311,6 +317,7 @@ func NewTxPublisher(cfg TxPublisherConfig) *TxPublisher { records: lnutils.SyncMap[uint64, *monitorRecord]{}, subscriberChans: lnutils.SyncMap[uint64, chan *BumpResult]{}, quit: make(chan struct{}), + blockBeatChan: make(chan chainio.Beat), } } @@ -346,6 +353,25 @@ func (t *TxPublisher) Broadcast(req *BumpRequest) <-chan *BumpResult { return subscriber } +// NOTE: part of the `chainio.Consumer` interface. +func (t *TxPublisher) ProcessBlock(beat chainio.Beat) <-chan error { + select { + case t.blockBeatChan <- beat: + log.Tracef("TxPublisher received block beat for height=%d", + beat.Epoch.Height) + + case <-t.quit: + return nil + } + + return beat.Err +} + +// NOTE: part of the `chainio.Consumer` interface. +func (t *TxPublisher) Name() string { + return "tx publisher" +} + // initialBroadcast initializes a fee function, creates an RBF-compliant tx and // broadcasts it. func (t *TxPublisher) initialBroadcast(requestID uint64, @@ -677,13 +703,8 @@ func (t *TxPublisher) Start() error { log.Info("TxPublisher starting...") defer log.Debugf("TxPublisher started") - blockEvent, err := t.cfg.Notifier.RegisterBlockEpochNtfn(nil) - if err != nil { - return fmt.Errorf("register block epoch ntfn: %w", err) - } - t.wg.Add(1) - go t.monitor(blockEvent) + go t.monitor() return nil } @@ -703,13 +724,12 @@ func (t *TxPublisher) Stop() { // to be bumped. If so, it will attempt to bump the fee of the tx. // // NOTE: Must be run as a goroutine. -func (t *TxPublisher) monitor(blockEvent *chainntnfs.BlockEpochEvent) { - defer blockEvent.Cancel() +func (t *TxPublisher) monitor() { defer t.wg.Done() for { select { - case epoch, ok := <-blockEvent.Epochs: + case beat, ok := <-t.blockBeatChan: if !ok { // We should stop the publisher before stopping // the chain service. Otherwise it indicates an @@ -720,6 +740,7 @@ func (t *TxPublisher) monitor(blockEvent *chainntnfs.BlockEpochEvent) { return } + epoch := beat.Epoch log.Debugf("TxPublisher received new block: %v", epoch.Height) @@ -730,6 +751,9 @@ func (t *TxPublisher) monitor(blockEvent *chainntnfs.BlockEpochEvent) { // to be bumped. t.processRecords() + // Notify we've processed the block. + fn.SendOrQuit(beat.Err, nil, t.quit) + case <-t.quit: log.Debug("Fee bumper stopped, exit monitor") return diff --git a/sweep/sweeper.go b/sweep/sweeper.go index e0206b83fdd..6fd0b7d1c67 100644 --- a/sweep/sweeper.go +++ b/sweep/sweeper.go @@ -10,6 +10,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/davecgh/go-spew/spew" + "github.com/lightningnetwork/lnd/chainio" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" @@ -335,6 +336,11 @@ type UtxoSweeper struct { // bumpResultChan is a channel that receives broadcast results from the // TxPublisher. bumpResultChan chan *BumpResult + + // blockBeatChan is a channel to receive blocks from BlockBeat. The + // received block contains the best known height and the transactions + // confirmed in this block. + blockBeatChan chan chainio.Beat } // UtxoSweeperConfig contains dependencies of UtxoSweeper. @@ -419,6 +425,7 @@ func New(cfg *UtxoSweeperConfig) *UtxoSweeper { quit: make(chan struct{}), inputs: make(InputsMap), bumpResultChan: make(chan *BumpResult, 100), + blockBeatChan: make(chan chainio.Beat), } } @@ -434,21 +441,12 @@ func (s *UtxoSweeper) Start() error { // not change from here on. s.relayFeeRate = s.cfg.FeeEstimator.RelayFeePerKW() - // We need to register for block epochs and retry sweeping every block. - // We should get a notification with the current best block immediately - // if we don't provide any epoch. We'll wait for that in the collector. - blockEpochs, err := s.cfg.Notifier.RegisterBlockEpochNtfn(nil) - if err != nil { - return fmt.Errorf("register block epoch ntfn: %w", err) - } - // Start sweeper main loop. s.wg.Add(1) go func() { - defer blockEpochs.Cancel() defer s.wg.Done() - s.collector(blockEpochs.Epochs) + s.collector() // The collector exited and won't longer handle incoming // requests. This can happen on shutdown, when the block @@ -503,6 +501,25 @@ func (s *UtxoSweeper) Stop() error { return nil } +// NOTE: part of the `chainio.Consumer` interface. +func (s *UtxoSweeper) ProcessBlock(beat chainio.Beat) <-chan error { + select { + case s.blockBeatChan <- beat: + log.Debugf("Received block beat for height=%d", + beat.Epoch.Height) + + case <-s.quit: + return nil + } + + return beat.Err +} + +// NOTE: part of the `chainio.Consumer` interface. +func (s *UtxoSweeper) Name() string { + return "sweeper" +} + // SweepInput sweeps inputs back into the wallet. The inputs will be batched and // swept after the batch time window ends. A custom fee preference can be // provided to determine what fee rate should be used for the input. Note that @@ -634,18 +651,7 @@ func (s *UtxoSweeper) removeConflictSweepDescendants( // collector is the sweeper main loop. It processes new inputs, spend // notifications and counts down to publication of the sweep tx. -func (s *UtxoSweeper) collector(blockEpochs <-chan *chainntnfs.BlockEpoch) { - // We registered for the block epochs with a nil request. The notifier - // should send us the current best block immediately. So we need to wait - // for it here because we need to know the current best height. - select { - case bestBlock := <-blockEpochs: - s.currentHeight = bestBlock.Height - - case <-s.quit: - return - } - +func (s *UtxoSweeper) collector() { for { // Clean inputs, which will remove inputs that are swept, // failed, or excluded from the sweeper and return inputs that @@ -708,7 +714,7 @@ func (s *UtxoSweeper) collector(blockEpochs <-chan *chainntnfs.BlockEpoch) { // A new block comes in, update the bestHeight, perform a check // over all pending inputs and publish sweeping txns if needed. - case epoch, ok := <-blockEpochs: + case beat, ok := <-s.blockBeatChan: if !ok { // We should stop the sweeper before stopping // the chain service. Otherwise it indicates an @@ -718,6 +724,8 @@ func (s *UtxoSweeper) collector(blockEpochs <-chan *chainntnfs.BlockEpoch) { return } + epoch := beat.Epoch + // Update the sweeper to the best height. s.currentHeight = epoch.Height @@ -740,6 +748,9 @@ func (s *UtxoSweeper) collector(blockEpochs <-chan *chainntnfs.BlockEpoch) { // Attempt to sweep any pending inputs. s.sweepPendingInputs(inputs) + // Notify we've processed the block. + fn.SendOrQuit(beat.Err, nil, s.quit) + case <-s.quit: return } From faac0054666d421e876713c6657318bb7d3df78d Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Mon, 29 Apr 2024 20:31:14 +0800 Subject: [PATCH 10/19] contractcourt: use blockbeat in `ChannelArbitrator` This commit refactors the block dispatching logic in `ChainArbitrator` so the blocks are sent concurrently to all active channel arbitrators. It also makes sure the blockbeat is now sent to channel arbitrators. --- contractcourt/chain_arbitrator.go | 152 ++++++++++++----------- contractcourt/channel_arbitrator.go | 77 ++++++++---- contractcourt/channel_arbitrator_test.go | 36 +++++- 3 files changed, 166 insertions(+), 99 deletions(-) diff --git a/contractcourt/chain_arbitrator.go b/contractcourt/chain_arbitrator.go index 2ec894f51bf..c4951aa6478 100644 --- a/contractcourt/chain_arbitrator.go +++ b/contractcourt/chain_arbitrator.go @@ -764,52 +764,10 @@ func (c *ChainArbitrator) Start() error { return nil } -// blockRecipient contains the information we need to dispatch a block to a -// channel arbitrator. -type blockRecipient struct { - // chanPoint is the funding outpoint of the channel. - chanPoint wire.OutPoint - - // blocks is the channel that new block heights are sent into. This - // channel should be sufficiently buffered as to not block the sender. - blocks chan<- int32 - - // quit is closed if the receiving entity is shutting down. - quit chan struct{} -} - // dispatchBlocks consumes a block epoch notification stream and dispatches // blocks to each of the chain arb's active channel arbitrators. This function // must be run in a goroutine. func (c *ChainArbitrator) dispatchBlocks() { - - // getRecipients is a helper function which acquires the chain arb - // lock and returns a set of block recipients which can be used to - // dispatch blocks. - getRecipients := func() []blockRecipient { - c.Lock() - blocks := make([]blockRecipient, 0, len(c.activeChannels)) - for _, channel := range c.activeChannels { - blocks = append(blocks, blockRecipient{ - chanPoint: channel.cfg.ChanPoint, - blocks: channel.blocks, - quit: channel.quit, - }) - } - c.Unlock() - - return blocks - } - - // On exit, cancel our blocks subscription and close each block channel - // so that the arbitrators know they will no longer be receiving blocks. - defer func() { - recipients := getRecipients() - for _, recipient := range recipients { - close(recipient.blocks) - } - }() - // Consume block epochs until we receive the instruction to shutdown. for { select { @@ -822,37 +780,11 @@ func (c *ChainArbitrator) dispatchBlocks() { return } - block := beat.Epoch - - // Get the set of currently active channels block - // subscription channels and dispatch the block to - // each. - for _, recipient := range getRecipients() { - select { - // Deliver the block to the arbitrator. - case recipient.blocks <- block.Height: - - // If the recipient is shutting down, exit - // without delivering the block. This may be - // the case when two blocks are mined in quick - // succession, and the arbitrator resolves - // after the first block, and does not need to - // consume the second block. - case <-recipient.quit: - log.Debugf("channel: %v exit without "+ - "receiving block: %v", - recipient.chanPoint, - block.Height) - - // If the chain arb is shutting down, we don't - // need to deliver any more blocks (everything - // will be shutting down). - case <-c.quit: - return - } - } + // Send this blockbeat to all the active channels and + // wait for them to finish processing it. + c.sendBlockAndWait(beat) - // Notify we've processed the block. + // Notify the chain arbitrator has processed the block. fn.SendOrQuit(beat.Err, nil, c.quit) // Exit if the chain arbitrator is shutting down. @@ -862,6 +794,82 @@ func (c *ChainArbitrator) dispatchBlocks() { } } +// sendBlockAndWait sends the blockbeat to all active channel arbitrator in +// parallel and wait for them to finish processing it. +func (c *ChainArbitrator) sendBlockAndWait(beat chainio.Beat) { + // Read the active channels in a lock. + c.Lock() + + // Create a map to record active channel arbitrator. + channels := make( + map[wire.OutPoint]*ChannelArbitrator, len(c.activeChannels), + ) + + // Create a map of go chans to store the done signals. + doneChans := make( + map[wire.OutPoint]chan struct{}, len(c.activeChannels), + ) + + // Copy the active channels to the map. + for op, channel := range c.activeChannels { + channels[op] = channel + doneChans[op] = make(chan struct{}) + } + + c.Unlock() + + // Iterate all the copied channels and send the blockbeat to them. + for _, channel := range channels { + beat := chainio.NewBeat(beat.Epoch) + + // Deliver the block to the channel arbitrator. + go func(ch *ChannelArbitrator, beat chainio.Beat) { + // Send the block to the arbitrator. + c.waitForChanArbProcessBlock(ch, beat) + + // Signal that the arbitrator has finished processing + // the block. + close(doneChans[ch.cfg.ChanPoint]) + }(channel, beat) + } + + // Wait for all channel arbitrators to process the block. + for op, doneChan := range doneChans { + select { + case <-doneChan: + log.Debugf("ChannelArbitrator(%v): processed block %d", + op, beat.Epoch.Height) + + case <-c.quit: + return + } + } +} + +// waitForChanArbProcessBlock waits for the channel arbitrator to process the +// block. It will log an error if there's an error returned from the channel +// arbitrator or the processing timed out. +func (c *ChainArbitrator) waitForChanArbProcessBlock(chanArb *ChannelArbitrator, + beat chainio.Beat) { + + log.Debugf("Sending block=%d to ChannelArbitrator(%v)", + beat.Epoch.Height, chanArb.cfg.ChanPoint) + + // We expect the channel arbitrator to finish processing this block + // under 30s, otherwise a timeout error is returned. + err, timeout := fn.RecvOrTimeout( + chanArb.processBlock(beat), chainio.DefaultProcessBlockTimeout, + ) + if err != nil { + log.Errorf("ChannelArbitrator(%v): process block got: %v", + chanArb.cfg.ChanPoint, err) + } + if timeout != nil { + log.Errorf("ChannelArbitrator(%v): process block timeout: %v", + chanArb.cfg.ChanPoint, err) + } +} + // republishClosingTxs will load any stored cooperative or unilateral closing // transactions and republish them. This helps ensure propagation of the // transactions in the event that prior publications failed. diff --git a/contractcourt/channel_arbitrator.go b/contractcourt/channel_arbitrator.go index cda0d4e1f63..dca1f428ecc 100644 --- a/contractcourt/channel_arbitrator.go +++ b/contractcourt/channel_arbitrator.go @@ -15,6 +15,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/davecgh/go-spew/spew" + "github.com/lightningnetwork/lnd/chainio" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/fn" @@ -346,10 +347,10 @@ type ChannelArbitrator struct { // to do its duty. cfg ChannelArbitratorConfig - // blocks is a channel that the arbitrator will receive new blocks on. - // This channel should be buffered by so that it does not block the - // sender. - blocks chan int32 + // blockBeatChan is a channel to receive blocks from BlockBeat. The + // received block contains the best known height and the transactions + // confirmed in this block. + blockBeatChan chan chainio.Beat // signalUpdates is a channel that any new live signals for the channel // we're watching over will be sent. @@ -399,8 +400,10 @@ func NewChannelArbitrator(cfg ChannelArbitratorConfig, } return &ChannelArbitrator{ - log: log, - blocks: make(chan int32, arbitratorBlockBufferSize), + log: log, + blockBeatChan: make( + chan chainio.Beat, arbitratorBlockBufferSize, + ), signalUpdates: make(chan *signalUpdateMsg), resolutionSignal: make(chan struct{}), forceCloseReqs: make(chan *forceCloseReq), @@ -2790,31 +2793,23 @@ func (c *ChannelArbitrator) channelAttendant(bestHeight int32) { // A new block has arrived, we'll examine all the active HTLC's // to see if any of them have expired, and also update our // track of the best current height. - case blockHeight, ok := <-c.blocks: + case beat, ok := <-c.blockBeatChan: if !ok { return } - bestHeight = blockHeight - // If we're not in the default state, then we can - // ignore this signal as we're waiting for contract - // resolution. - if c.state != StateDefault { - continue - } + log.Debugf("ChannelArbitrator(%v): new block height=%v", + c.cfg.ChanPoint, beat.Epoch.Height) - // Now that a new block has arrived, we'll attempt to - // advance our state forward. - nextState, _, err := c.advanceState( - uint32(bestHeight), chainTrigger, nil, - ) + err := c.handleBlockbeat(beat) if err != nil { - log.Errorf("Unable to advance state: %v", err) + log.Errorf("Handle block=%v got err: %v", + beat.Epoch.Height, err) } // If as a result of this trigger, the contract is // fully resolved, then well exit. - if nextState == StateFullyResolved { + if c.state == StateFullyResolved { return } @@ -3140,6 +3135,46 @@ func (c *ChannelArbitrator) channelAttendant(bestHeight int32) { } } +// handleBlockbeat processes a newly received blockbeat. +func (c *ChannelArbitrator) handleBlockbeat(beat chainio.Beat) error { + // Notify we've processed the block. + defer fn.SendOrQuit(beat.Err, nil, c.quit) + + bestHeight := beat.Epoch.Height + + // If we're not in the default state, then we can ignore this signal as + // we're waiting for contract resolution. + if c.state != StateDefault { + return nil + } + + // Now that a new block has arrived, we'll attempt to advance our state + // forward. + _, _, err := c.advanceState( + uint32(bestHeight), chainTrigger, nil, + ) + if err != nil { + return fmt.Errorf("unable to advance state: %w", err) + } + + return nil +} + +// processBlock sends the specified blockbeat to the channel arbitrator's inner +// loop for processing. +func (c *ChannelArbitrator) processBlock(beat chainio.Beat) <-chan error { + select { + case c.blockBeatChan <- beat: + log.Debugf("Received block beat for height=%d", + beat.Epoch.Height) + + case <-c.quit: + return nil + } + + return beat.Err +} + // checkLegacyBreach returns StateFullyResolved if the channel was closed with // a breach transaction before the channel arbitrator launched its own breach // resolver. StateContractClosed is returned if this is a modern breach close diff --git a/contractcourt/channel_arbitrator_test.go b/contractcourt/channel_arbitrator_test.go index 43238494efc..9f66701e2df 100644 --- a/contractcourt/channel_arbitrator_test.go +++ b/contractcourt/channel_arbitrator_test.go @@ -13,6 +13,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/chainio" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb/models" @@ -1065,6 +1066,11 @@ func TestChannelArbitratorLocalForceClosePendingHtlc(t *testing.T) { // Send a notification that the expiry height has been reached. oldNotifier.EpochChan <- &chainntnfs.BlockEpoch{Height: 10} + beat := chainio.NewBeat(chainntnfs.BlockEpoch{ + Height: 10, + }) + chanArbCtx.chanArb.blockBeatChan <- beat + // htlcOutgoingContestResolver is now transforming into a // htlcTimeoutResolver and should send the contract off for incubation. select { @@ -1900,7 +1906,10 @@ func TestChannelArbitratorDanglingCommitForceClose(t *testing.T) { // now mine a block (height 5), which is 5 blocks away // (our grace delta) from the expiry of that HTLC. case testCase.htlcExpired: - chanArbCtx.chanArb.blocks <- 5 + beat := chainio.NewBeat(chainntnfs.BlockEpoch{ + Height: 5, + }) + chanArbCtx.chanArb.blockBeatChan <- beat // Otherwise, we'll just trigger a regular force close // request. @@ -2004,7 +2013,10 @@ func TestChannelArbitratorDanglingCommitForceClose(t *testing.T) { // so instead, we'll mine another block which'll cause // it to re-examine its state and realize there're no // more HTLCs. - chanArbCtx.chanArb.blocks <- 6 + beat := chainio.NewBeat(chainntnfs.BlockEpoch{ + Height: 6, + }) + chanArbCtx.chanArb.blockBeatChan <- beat chanArbCtx.AssertStateTransitions(StateFullyResolved) }) } @@ -2076,13 +2088,19 @@ func TestChannelArbitratorPendingExpiredHTLC(t *testing.T) { // We will advance the uptime to 10 seconds which should be still within // the grace period and should not trigger going to chain. testClock.SetTime(startTime.Add(time.Second * 10)) - chanArbCtx.chanArb.blocks <- 5 + beat := chainio.NewBeat(chainntnfs.BlockEpoch{ + Height: 5, + }) + chanArbCtx.chanArb.blockBeatChan <- beat chanArbCtx.AssertState(StateDefault) // We will advance the uptime to 16 seconds which should trigger going // to chain. testClock.SetTime(startTime.Add(time.Second * 16)) - chanArbCtx.chanArb.blocks <- 6 + beat = chainio.NewBeat(chainntnfs.BlockEpoch{ + Height: 6, + }) + chanArbCtx.chanArb.blockBeatChan <- beat chanArbCtx.AssertStateTransitions( StateBroadcastCommit, StateCommitmentBroadcasted, @@ -2450,7 +2468,10 @@ func TestSweepAnchors(t *testing.T) { // Set current block height. heightHint := uint32(1000) - chanArbCtx.chanArb.blocks <- int32(heightHint) + beat := chainio.NewBeat(chainntnfs.BlockEpoch{ + Height: int32(heightHint), + }) + chanArbCtx.chanArb.blockBeatChan <- beat htlcIndexBase := uint64(99) deadlineDelta := uint32(10) @@ -2651,7 +2672,10 @@ func TestChannelArbitratorAnchors(t *testing.T) { // Set current block height. heightHint := uint32(1000) - chanArbCtx.chanArb.blocks <- int32(heightHint) + beat := chainio.NewBeat(chainntnfs.BlockEpoch{ + Height: int32(heightHint), + }) + chanArbCtx.chanArb.blockBeatChan <- beat htlcAmt := lnwire.MilliSatoshi(1_000_000) From 6ee2cd79cbfb311cb36a7630cf6b7b23c25ccec6 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Tue, 30 Apr 2024 12:26:24 +0800 Subject: [PATCH 11/19] contractcourt: init `activeResolvers` with a block chan This commit changes the `activeResolvers` map so each active resolver now has a block chan to receive new blocks. --- contractcourt/channel_arbitrator.go | 55 ++++++++++++++++++------ contractcourt/channel_arbitrator_test.go | 14 ++++-- 2 files changed, 52 insertions(+), 17 deletions(-) diff --git a/contractcourt/channel_arbitrator.go b/contractcourt/channel_arbitrator.go index dca1f428ecc..845b0b757a8 100644 --- a/contractcourt/channel_arbitrator.go +++ b/contractcourt/channel_arbitrator.go @@ -1,7 +1,6 @@ package contractcourt import ( - "bytes" "context" "errors" "fmt" @@ -356,9 +355,9 @@ type ChannelArbitrator struct { // we're watching over will be sent. signalUpdates chan *signalUpdateMsg - // activeResolvers is a slice of any active resolvers. This is used to - // be able to signal them for shutdown in the case that we shutdown. - activeResolvers []ContractResolver + // activeResolvers is a map of active resolvers. It uses the resolver + // as the key and a block chan as the value. + activeResolvers map[ContractResolver]chan int32 // activeResolversLock prevents simultaneous read and write to the // resolvers slice. @@ -801,7 +800,7 @@ func (c *ChannelArbitrator) Report() []*ContractReport { defer c.activeResolversLock.RUnlock() var reports []*ContractReport - for _, resolver := range c.activeResolvers { + for resolver := range c.activeResolvers { r, ok := resolver.(reportingContractResolver) if !ok { continue @@ -831,7 +830,7 @@ func (c *ChannelArbitrator) Stop() error { } c.activeResolversLock.RLock() - for _, activeResolver := range c.activeResolvers { + for activeResolver := range c.activeResolvers { activeResolver.Stop() } c.activeResolversLock.RUnlock() @@ -1585,9 +1584,15 @@ func (c *ChannelArbitrator) launchResolvers(resolvers []ContractResolver, c.activeResolversLock.Lock() defer c.activeResolversLock.Unlock() - c.activeResolvers = resolvers + c.activeResolvers = make(map[ContractResolver]chan int32) + for _, contract := range resolvers { c.wg.Add(1) + + // Create a block chan for each contract resolver. + blockChan := make(chan int32, arbitratorBlockBufferSize) + c.activeResolvers[contract] = blockChan + go c.resolveContract(contract, immediate) } } @@ -2582,15 +2587,15 @@ func (c *ChannelArbitrator) replaceResolver(oldResolver, c.activeResolversLock.Lock() defer c.activeResolversLock.Unlock() - oldKey := oldResolver.ResolverKey() - for i, r := range c.activeResolvers { - if bytes.Equal(r.ResolverKey(), oldKey) { - c.activeResolvers[i] = newResolver - return nil - } + blockChan, ok := c.activeResolvers[oldResolver] + if !ok { + return errors.New("resolver to be replaced not found") } - return errors.New("resolver to be replaced not found") + c.activeResolvers[newResolver] = blockChan + delete(c.activeResolvers, oldResolver) + + return nil } // resolveContract is a goroutine tasked with fully resolving an unresolved @@ -2768,6 +2773,25 @@ func (c *ChannelArbitrator) updateActiveHTLCs() { } } +// notifyResolvers will send the block height to all the active resolvers. +func (c *ChannelArbitrator) notifyResolvers(height int32) { + c.activeResolversLock.RLock() + defer c.activeResolversLock.RUnlock() + + log.Debugf("Notifying %v resolvers of new block height %v", + len(c.activeResolvers), height) + + for _, blockChan := range c.activeResolvers { + select { + case blockChan <- height: + case <-c.quit: + } + } + + log.Debugf("Notified %v resolvers of new block height %v", + len(c.activeResolvers), height) +} + // channelAttendant is the primary goroutine that acts at the judicial // arbitrator between our channel state, the remote channel peer, and the // blockchain (Our judge). This goroutine will ensure that we faithfully execute @@ -3142,6 +3166,9 @@ func (c *ChannelArbitrator) handleBlockbeat(beat chainio.Beat) error { bestHeight := beat.Epoch.Height + // Notify all resolvers about this new block. + c.notifyResolvers(bestHeight) + // If we're not in the default state, then we can ignore this signal as // we're waiting for contract resolution. if c.state != StateDefault { diff --git a/contractcourt/channel_arbitrator_test.go b/contractcourt/channel_arbitrator_test.go index 9f66701e2df..84a63fa8683 100644 --- a/contractcourt/channel_arbitrator_test.go +++ b/contractcourt/channel_arbitrator_test.go @@ -809,7 +809,7 @@ func TestChannelArbitratorBreachClose(t *testing.T) { require.Equal(t, 2, len(chanArb.activeResolvers)) var anchorExists, breachExists bool - for _, resolver := range chanArb.activeResolvers { + for resolver := range chanArb.activeResolvers { switch resolver.(type) { case *anchorResolver: anchorExists = true @@ -1040,9 +1040,13 @@ func TestChannelArbitratorLocalForceClosePendingHtlc(t *testing.T) { len(chanArb.activeResolvers)) } + var resolver ContractResolver + for r := range chanArb.activeResolvers { + resolver = r + } + // We'll now examine the in-memory state of the active resolvers to // ensure t hey were populated properly. - resolver := chanArb.activeResolvers[0] outgoingResolver, ok := resolver.(*htlcOutgoingContestResolver) if !ok { t.Fatalf("expected outgoing contest resolver, got %vT", @@ -2801,7 +2805,11 @@ func TestChannelArbitratorAnchors(t *testing.T) { len(chanArb.activeResolvers)) } - resolver := chanArb.activeResolvers[0] + var resolver ContractResolver + for r := range chanArb.activeResolvers { + resolver = r + } + _, ok := resolver.(*anchorResolver) if !ok { t.Fatalf("expected anchor resolver, got %T", resolver) From d96d4c5e42f3f5475077b2773c3e3c44b1180a63 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Tue, 30 Apr 2024 12:27:16 +0800 Subject: [PATCH 12/19] contractcourt: remove block subscription used in resolvers After this commit, when a new block comes, it will be passed through chainArb -> ChannelArbitrator -> resolvers. --- contractcourt/anchor_resolver.go | 4 +++- contractcourt/breach_resolver.go | 4 +++- contractcourt/chain_arbitrator_test.go | 2 -- contractcourt/channel_arbitrator.go | 8 ++++--- contractcourt/channel_arbitrator_test.go | 5 +--- contractcourt/commit_sweep_resolver.go | 24 +++++++++---------- contractcourt/commit_sweep_resolver_test.go | 17 +++++++------ contractcourt/contract_resolver.go | 3 ++- .../htlc_incoming_contest_resolver.go | 20 +++++----------- .../htlc_incoming_contest_resolver_test.go | 9 ++++--- .../htlc_outgoing_contest_resolver.go | 14 ++++------- .../htlc_outgoing_contest_resolver_test.go | 9 ++++--- contractcourt/htlc_success_resolver.go | 18 +++++++------- contractcourt/htlc_success_resolver_test.go | 14 +++++++---- contractcourt/htlc_timeout_resolver.go | 10 ++++---- contractcourt/htlc_timeout_resolver_test.go | 8 +++---- 16 files changed, 77 insertions(+), 92 deletions(-) diff --git a/contractcourt/anchor_resolver.go b/contractcourt/anchor_resolver.go index b4d6877202e..36e2970ce10 100644 --- a/contractcourt/anchor_resolver.go +++ b/contractcourt/anchor_resolver.go @@ -84,7 +84,9 @@ func (c *anchorResolver) ResolverKey() []byte { } // Resolve offers the anchor output to the sweeper and waits for it to be swept. -func (c *anchorResolver) Resolve(_ bool) (ContractResolver, error) { +func (c *anchorResolver) Resolve(_ bool, + _ <-chan int32) (ContractResolver, error) { + // Attempt to update the sweep parameters to the post-confirmation // situation. We don't want to force sweep anymore, because the anchor // lost its special purpose to get the commitment confirmed. It is just diff --git a/contractcourt/breach_resolver.go b/contractcourt/breach_resolver.go index 740b4471d5d..56ef51cc8e6 100644 --- a/contractcourt/breach_resolver.go +++ b/contractcourt/breach_resolver.go @@ -47,7 +47,9 @@ func (b *breachResolver) ResolverKey() []byte { // been broadcast. // // TODO(yy): let sweeper handle the breach inputs. -func (b *breachResolver) Resolve(_ bool) (ContractResolver, error) { +func (b *breachResolver) Resolve(_ bool, + _ <-chan int32) (ContractResolver, error) { + if !b.subscribed { complete, err := b.SubscribeBreachComplete( &b.ChanPoint, b.replyChan, diff --git a/contractcourt/chain_arbitrator_test.go b/contractcourt/chain_arbitrator_test.go index 36f6dad18ba..866052d8dc4 100644 --- a/contractcourt/chain_arbitrator_test.go +++ b/contractcourt/chain_arbitrator_test.go @@ -80,7 +80,6 @@ func TestChainArbitratorRepublishCloses(t *testing.T) { ChainIO: &mock.ChainIO{}, Notifier: &mock.ChainNotifier{ SpendChan: make(chan *chainntnfs.SpendDetail), - EpochChan: make(chan *chainntnfs.BlockEpoch), ConfChan: make(chan *chainntnfs.TxConfirmation), }, PublishTx: func(tx *wire.MsgTx, _ string) error { @@ -165,7 +164,6 @@ func TestResolveContract(t *testing.T) { ChainIO: &mock.ChainIO{}, Notifier: &mock.ChainNotifier{ SpendChan: make(chan *chainntnfs.SpendDetail), - EpochChan: make(chan *chainntnfs.BlockEpoch), ConfChan: make(chan *chainntnfs.TxConfirmation), }, PublishTx: func(tx *wire.MsgTx, _ string) error { diff --git a/contractcourt/channel_arbitrator.go b/contractcourt/channel_arbitrator.go index 845b0b757a8..09ddec032d1 100644 --- a/contractcourt/channel_arbitrator.go +++ b/contractcourt/channel_arbitrator.go @@ -1593,7 +1593,7 @@ func (c *ChannelArbitrator) launchResolvers(resolvers []ContractResolver, blockChan := make(chan int32, arbitratorBlockBufferSize) c.activeResolvers[contract] = blockChan - go c.resolveContract(contract, immediate) + go c.resolveContract(contract, blockChan, immediate) } } @@ -2607,7 +2607,7 @@ func (c *ChannelArbitrator) replaceResolver(oldResolver, // // NOTE: This MUST be run as a goroutine. func (c *ChannelArbitrator) resolveContract(currentContract ContractResolver, - immediate bool) { + blockChan <-chan int32, immediate bool) { defer c.wg.Done() @@ -2629,7 +2629,9 @@ func (c *ChannelArbitrator) resolveContract(currentContract ContractResolver, default: // Otherwise, we'll attempt to resolve the current // contract. - nextContract, err := currentContract.Resolve(immediate) + nextContract, err := currentContract.Resolve( + immediate, blockChan, + ) if err != nil { if err == errResolverShuttingDown { return diff --git a/contractcourt/channel_arbitrator_test.go b/contractcourt/channel_arbitrator_test.go index 84a63fa8683..ec6eaeedadd 100644 --- a/contractcourt/channel_arbitrator_test.go +++ b/contractcourt/channel_arbitrator_test.go @@ -359,7 +359,6 @@ func createTestChannelArbitrator(t *testing.T, log ArbitratorLog, OutgoingBroadcastDelta: 5, IncomingBroadcastDelta: 5, Notifier: &mock.ChainNotifier{ - EpochChan: make(chan *chainntnfs.BlockEpoch), SpendChan: make(chan *chainntnfs.SpendDetail), ConfChan: make(chan *chainntnfs.TxConfirmation), }, @@ -1068,12 +1067,10 @@ func TestChannelArbitratorLocalForceClosePendingHtlc(t *testing.T) { } // Send a notification that the expiry height has been reached. - oldNotifier.EpochChan <- &chainntnfs.BlockEpoch{Height: 10} - beat := chainio.NewBeat(chainntnfs.BlockEpoch{ Height: 10, }) - chanArbCtx.chanArb.blockBeatChan <- beat + chanArb.blockBeatChan <- beat // htlcOutgoingContestResolver is now transforming into a // htlcTimeoutResolver and should send the contract off for incubation. diff --git a/contractcourt/commit_sweep_resolver.go b/contractcourt/commit_sweep_resolver.go index 296ea38e554..b64dd71cbd9 100644 --- a/contractcourt/commit_sweep_resolver.go +++ b/contractcourt/commit_sweep_resolver.go @@ -92,24 +92,20 @@ func (c *commitSweepResolver) ResolverKey() []byte { // waitForHeight registers for block notifications and waits for the provided // block height to be reached. -func waitForHeight(waitHeight uint32, notifier chainntnfs.ChainNotifier, +// +// TODO(yy): There's no need to wait for height in the resolvers, instead, we +// can offer these immature inputs to the sweeper immediately since the sweeper +// will handle the waiting. +func waitForHeight(waitHeight uint32, blockChan <-chan int32, quit <-chan struct{}) error { - // Register for block epochs. After registration, the current height - // will be sent on the channel immediately. - blockEpochs, err := notifier.RegisterBlockEpochNtfn(nil) - if err != nil { - return err - } - defer blockEpochs.Cancel() - for { select { - case newBlock, ok := <-blockEpochs.Epochs: + case newBlockHeight, ok := <-blockChan: if !ok { return errResolverShuttingDown } - height := newBlock.Height + height := newBlockHeight if height >= int32(waitHeight) { return nil } @@ -186,7 +182,9 @@ func (c *commitSweepResolver) getCommitTxConfHeight() (uint32, error) { // NOTE: This function MUST be run as a goroutine. // //nolint:funlen -func (c *commitSweepResolver) Resolve(_ bool) (ContractResolver, error) { +func (c *commitSweepResolver) Resolve(_ bool, + blockChan <-chan int32) (ContractResolver, error) { + // If we're already resolved, then we can exit early. if c.resolved { return nil, nil @@ -241,7 +239,7 @@ func (c *commitSweepResolver) Resolve(_ bool) (ContractResolver, error) { waitHeight = unlockHeight - 1 } - err := waitForHeight(waitHeight, c.Notifier, c.quit) + err := waitForHeight(waitHeight, blockChan, c.quit) if err != nil { return nil, err } diff --git a/contractcourt/commit_sweep_resolver_test.go b/contractcourt/commit_sweep_resolver_test.go index f2b43b0f80a..c6ee1d8fd6c 100644 --- a/contractcourt/commit_sweep_resolver_test.go +++ b/contractcourt/commit_sweep_resolver_test.go @@ -23,13 +23,13 @@ type commitSweepResolverTestContext struct { sweeper *mockSweeper resolverResultChan chan resolveResult t *testing.T + blockChan chan int32 } func newCommitSweepResolverTestContext(t *testing.T, resolution *lnwallet.CommitOutputResolution) *commitSweepResolverTestContext { notifier := &mock.ChainNotifier{ - EpochChan: make(chan *chainntnfs.BlockEpoch), SpendChan: make(chan *chainntnfs.SpendDetail), ConfChan: make(chan *chainntnfs.TxConfirmation), } @@ -71,10 +71,11 @@ func newCommitSweepResolverTestContext(t *testing.T, ) return &commitSweepResolverTestContext{ - resolver: resolver, - notifier: notifier, - sweeper: sweeper, - t: t, + resolver: resolver, + notifier: notifier, + sweeper: sweeper, + t: t, + blockChan: make(chan int32, 1), } } @@ -82,7 +83,7 @@ func (i *commitSweepResolverTestContext) resolve() { // Start resolver. i.resolverResultChan = make(chan resolveResult, 1) go func() { - nextResolver, err := i.resolver.Resolve(false) + nextResolver, err := i.resolver.Resolve(false, i.blockChan) i.resolverResultChan <- resolveResult{ nextResolver: nextResolver, err: err, @@ -91,9 +92,7 @@ func (i *commitSweepResolverTestContext) resolve() { } func (i *commitSweepResolverTestContext) notifyEpoch(height int32) { - i.notifier.EpochChan <- &chainntnfs.BlockEpoch{ - Height: height, - } + i.blockChan <- height } func (i *commitSweepResolverTestContext) waitForResult() { diff --git a/contractcourt/contract_resolver.go b/contractcourt/contract_resolver.go index 5acf8006496..67a694f81c5 100644 --- a/contractcourt/contract_resolver.go +++ b/contractcourt/contract_resolver.go @@ -43,7 +43,8 @@ type ContractResolver interface { // resolution, then another resolve is returned. // // NOTE: This function MUST be run as a goroutine. - Resolve(immediate bool) (ContractResolver, error) + Resolve(immediate bool, + blockChan <-chan int32) (ContractResolver, error) // SupplementState allows the user of a ContractResolver to supplement // it with state required for the proper resolution of a contract. diff --git a/contractcourt/htlc_incoming_contest_resolver.go b/contractcourt/htlc_incoming_contest_resolver.go index b104f8d70a6..10f8ec5e1bc 100644 --- a/contractcourt/htlc_incoming_contest_resolver.go +++ b/contractcourt/htlc_incoming_contest_resolver.go @@ -90,8 +90,8 @@ func (h *htlcIncomingContestResolver) processFinalHtlcFail() error { // as we have no remaining actions left at our disposal. // // NOTE: Part of the ContractResolver interface. -func (h *htlcIncomingContestResolver) Resolve( - _ bool) (ContractResolver, error) { +func (h *htlcIncomingContestResolver) Resolve(_ bool, + blockChan <-chan int32) (ContractResolver, error) { // If we're already full resolved, then we don't have anything further // to do. @@ -126,21 +126,13 @@ func (h *htlcIncomingContestResolver) Resolve( return nil, h.PutResolverReport(nil, resReport) } - // Register for block epochs. After registration, the current height - // will be sent on the channel immediately. - blockEpochs, err := h.Notifier.RegisterBlockEpochNtfn(nil) - if err != nil { - return nil, err - } - defer blockEpochs.Cancel() - var currentHeight int32 select { - case newBlock, ok := <-blockEpochs.Epochs: + case height, ok := <-blockChan: if !ok { return nil, errResolverShuttingDown } - currentHeight = newBlock.Height + currentHeight = height case <-h.quit: return nil, errResolverShuttingDown } @@ -403,7 +395,7 @@ func (h *htlcIncomingContestResolver) Resolve( htlcResolution := hodlItem.(invoices.HtlcResolution) return processHtlcResolution(htlcResolution) - case newBlock, ok := <-blockEpochs.Epochs: + case height, ok := <-blockChan: if !ok { return nil, errResolverShuttingDown } @@ -411,7 +403,7 @@ func (h *htlcIncomingContestResolver) Resolve( // If this new height expires the HTLC, then this means // we never found out the preimage, so we can mark // resolved and exit. - newHeight := uint32(newBlock.Height) + newHeight := uint32(height) if newHeight >= h.htlcExpiry { log.Infof("%T(%v): HTLC has timed out "+ "(expiry=%v, height=%v), abandoning", h, diff --git a/contractcourt/htlc_incoming_contest_resolver_test.go b/contractcourt/htlc_incoming_contest_resolver_test.go index 55d93a6fb37..c6c0448c1e1 100644 --- a/contractcourt/htlc_incoming_contest_resolver_test.go +++ b/contractcourt/htlc_incoming_contest_resolver_test.go @@ -309,11 +309,11 @@ type incomingResolverTestContext struct { nextResolver ContractResolver finalHtlcOutcomeStored bool t *testing.T + blockChan chan int32 } func newIncomingResolverTestContext(t *testing.T, isExit bool) *incomingResolverTestContext { notifier := &mock.ChainNotifier{ - EpochChan: make(chan *chainntnfs.BlockEpoch), SpendChan: make(chan *chainntnfs.SpendDetail), ConfChan: make(chan *chainntnfs.TxConfirmation), } @@ -332,6 +332,7 @@ func newIncomingResolverTestContext(t *testing.T, isExit bool) *incomingResolver notifier: notifier, onionProcessor: onionProcessor, t: t, + blockChan: make(chan int32, 1), } htlcNotifier := &mockHTLCNotifier{} @@ -395,7 +396,7 @@ func (i *incomingResolverTestContext) resolve() { i.resolveErr = make(chan error, 1) go func() { var err error - i.nextResolver, err = i.resolver.Resolve(false) + i.nextResolver, err = i.resolver.Resolve(false, i.blockChan) i.resolveErr <- err }() @@ -404,9 +405,7 @@ func (i *incomingResolverTestContext) resolve() { } func (i *incomingResolverTestContext) notifyEpoch(height int32) { - i.notifier.EpochChan <- &chainntnfs.BlockEpoch{ - Height: height, - } + i.blockChan <- height } func (i *incomingResolverTestContext) waitForResult(expectSuccessRes bool) { diff --git a/contractcourt/htlc_outgoing_contest_resolver.go b/contractcourt/htlc_outgoing_contest_resolver.go index 2466544c982..64b9da8d578 100644 --- a/contractcourt/htlc_outgoing_contest_resolver.go +++ b/contractcourt/htlc_outgoing_contest_resolver.go @@ -49,8 +49,8 @@ func newOutgoingContestResolver(res lnwallet.OutgoingHtlcResolution, // When either of these two things happens, we'll create a new resolver which // is able to handle the final resolution of the contract. We're only the pivot // point. -func (h *htlcOutgoingContestResolver) Resolve( - _ bool) (ContractResolver, error) { +func (h *htlcOutgoingContestResolver) Resolve(_ bool, + blockChan <-chan int32) (ContractResolver, error) { // If we're already full resolved, then we don't have anything further // to do. @@ -99,18 +99,12 @@ func (h *htlcOutgoingContestResolver) Resolve( // If we reach this point, then we can't fully act yet, so we'll await // either of our signals triggering: the HTLC expires, or we learn of // the preimage. - blockEpochs, err := h.Notifier.RegisterBlockEpochNtfn(nil) - if err != nil { - return nil, err - } - defer blockEpochs.Cancel() - for { select { // A new block has arrived, we'll check to see if this leads to // HTLC expiration. - case newBlock, ok := <-blockEpochs.Epochs: + case newBlockHeight, ok := <-blockChan: if !ok { return nil, errResolverShuttingDown } @@ -125,7 +119,7 @@ func (h *htlcOutgoingContestResolver) Resolve( // check doesn't pass, error `transaction is not // finalized` will be returned and the broadcast will // fail. - newHeight := uint32(newBlock.Height) + newHeight := uint32(newBlockHeight) if newHeight >= h.htlcResolution.Expiry { log.Infof("%T(%v): HTLC has expired "+ "(height=%v, expiry=%v), transforming "+ diff --git a/contractcourt/htlc_outgoing_contest_resolver_test.go b/contractcourt/htlc_outgoing_contest_resolver_test.go index f67c34ff4e1..cd59e194825 100644 --- a/contractcourt/htlc_outgoing_contest_resolver_test.go +++ b/contractcourt/htlc_outgoing_contest_resolver_test.go @@ -122,11 +122,11 @@ type outgoingResolverTestContext struct { resolverResultChan chan resolveResult resolutionChan chan ResolutionMsg t *testing.T + blockChan chan int32 } func newOutgoingResolverTestContext(t *testing.T) *outgoingResolverTestContext { notifier := &mock.ChainNotifier{ - EpochChan: make(chan *chainntnfs.BlockEpoch), SpendChan: make(chan *chainntnfs.SpendDetail), ConfChan: make(chan *chainntnfs.TxConfirmation), } @@ -202,6 +202,7 @@ func newOutgoingResolverTestContext(t *testing.T) *outgoingResolverTestContext { preimageDB: preimageDB, resolutionChan: resolutionChan, t: t, + blockChan: make(chan int32, 1), } } @@ -209,7 +210,7 @@ func (i *outgoingResolverTestContext) resolve() { // Start resolver. i.resolverResultChan = make(chan resolveResult, 1) go func() { - nextResolver, err := i.resolver.Resolve(false) + nextResolver, err := i.resolver.Resolve(false, i.blockChan) i.resolverResultChan <- resolveResult{ nextResolver: nextResolver, err: err, @@ -221,9 +222,7 @@ func (i *outgoingResolverTestContext) resolve() { } func (i *outgoingResolverTestContext) notifyEpoch(height int32) { - i.notifier.EpochChan <- &chainntnfs.BlockEpoch{ - Height: height, - } + i.blockChan <- height } func (i *outgoingResolverTestContext) waitForResult(expectTimeoutRes bool) { diff --git a/contractcourt/htlc_success_resolver.go b/contractcourt/htlc_success_resolver.go index 6eee939eac0..6792ea6ca59 100644 --- a/contractcourt/htlc_success_resolver.go +++ b/contractcourt/htlc_success_resolver.go @@ -115,8 +115,8 @@ func (h *htlcSuccessResolver) ResolverKey() []byte { // TODO(roasbeef): create multi to batch // // NOTE: Part of the ContractResolver interface. -func (h *htlcSuccessResolver) Resolve( - immediate bool) (ContractResolver, error) { +func (h *htlcSuccessResolver) Resolve(immediate bool, + blockChan <-chan int32) (ContractResolver, error) { // If we're already resolved, then we can exit early. if h.resolved { @@ -131,7 +131,7 @@ func (h *htlcSuccessResolver) Resolve( // Otherwise this an output on our own commitment, and we must start by // broadcasting the second-level success transaction. - secondLevelOutpoint, err := h.broadcastSuccessTx(immediate) + secondLevelOutpoint, err := h.broadcastSuccessTx(immediate, blockChan) if err != nil { return nil, err } @@ -165,8 +165,8 @@ func (h *htlcSuccessResolver) Resolve( // broadcasting the second-level success transaction. It returns the ultimate // outpoint of the second-level tx, that we must wait to be spent for the // resolver to be fully resolved. -func (h *htlcSuccessResolver) broadcastSuccessTx( - immediate bool) (*wire.OutPoint, error) { +func (h *htlcSuccessResolver) broadcastSuccessTx(immediate bool, + blockChan <-chan int32) (*wire.OutPoint, error) { // If we have non-nil SignDetails, this means that have a 2nd level // HTLC transaction that is signed using sighash SINGLE|ANYONECANPAY @@ -175,7 +175,7 @@ func (h *htlcSuccessResolver) broadcastSuccessTx( // the checkpointed outputIncubating field to determine if we already // swept the HTLC output into the second level transaction. if h.htlcResolution.SignDetails != nil { - return h.broadcastReSignedSuccessTx(immediate) + return h.broadcastReSignedSuccessTx(immediate, blockChan) } // Otherwise we'll publish the second-level transaction directly and @@ -227,8 +227,8 @@ func (h *htlcSuccessResolver) broadcastSuccessTx( // will re-sign it and attach fees at will. // //nolint:funlen -func (h *htlcSuccessResolver) broadcastReSignedSuccessTx(immediate bool) ( - *wire.OutPoint, error) { +func (h *htlcSuccessResolver) broadcastReSignedSuccessTx(immediate bool, + blockChan <-chan int32) (*wire.OutPoint, error) { // Keep track of the tx spending the HTLC output on the commitment, as // this will be the confirmed second-level tx we'll ultimately sweep. @@ -375,7 +375,7 @@ func (h *htlcSuccessResolver) broadcastReSignedSuccessTx(immediate bool) ( waitHeight-- // TODO(yy): let sweeper handles the wait? - err := waitForHeight(waitHeight, h.Notifier, h.quit) + err := waitForHeight(waitHeight, blockChan, h.quit) if err != nil { return nil, err } diff --git a/contractcourt/htlc_success_resolver_test.go b/contractcourt/htlc_success_resolver_test.go index b9182500bb4..4a39dd2f2f5 100644 --- a/contractcourt/htlc_success_resolver_test.go +++ b/contractcourt/htlc_success_resolver_test.go @@ -37,6 +37,8 @@ type htlcResolverTestContext struct { finalHtlcOutcomeStored bool t *testing.T + + blockChan chan int32 } func newHtlcResolverTestContext(t *testing.T, @@ -44,7 +46,6 @@ func newHtlcResolverTestContext(t *testing.T, cfg ResolverConfig) ContractResolver) *htlcResolverTestContext { notifier := &mock.ChainNotifier{ - EpochChan: make(chan *chainntnfs.BlockEpoch, 1), SpendChan: make(chan *chainntnfs.SpendDetail, 1), ConfChan: make(chan *chainntnfs.TxConfirmation, 1), } @@ -54,6 +55,7 @@ func newHtlcResolverTestContext(t *testing.T, notifier: notifier, resolutionChan: make(chan ResolutionMsg, 1), t: t, + blockChan: make(chan int32, 1), } htlcNotifier := &mockHTLCNotifier{} @@ -134,7 +136,7 @@ func (i *htlcResolverTestContext) resolve() { // Start resolver. i.resolverResultChan = make(chan resolveResult, 1) go func() { - nextResolver, err := i.resolver.Resolve(false) + nextResolver, err := i.resolver.Resolve(false, i.blockChan) i.resolverResultChan <- resolveResult{ nextResolver: nextResolver, err: err, @@ -142,6 +144,10 @@ func (i *htlcResolverTestContext) resolve() { }() } +func (i *htlcResolverTestContext) notifyEpoch(height int32) { + i.blockChan <- height +} + func (i *htlcResolverTestContext) waitForResult() { i.t.Helper() @@ -437,9 +443,7 @@ func TestHtlcSuccessSecondStageResolutionSweeper(t *testing.T) { } } - ctx.notifier.EpochChan <- &chainntnfs.BlockEpoch{ - Height: 13, - } + ctx.notifyEpoch(13) // We expect it to sweep the second-level // transaction we notfied about above. diff --git a/contractcourt/htlc_timeout_resolver.go b/contractcourt/htlc_timeout_resolver.go index 62ff832071c..1c49fc05f2b 100644 --- a/contractcourt/htlc_timeout_resolver.go +++ b/contractcourt/htlc_timeout_resolver.go @@ -418,8 +418,8 @@ func checkSizeAndIndex(witness wire.TxWitness, size, index int) bool { // see a direct sweep via the timeout clause. // // NOTE: Part of the ContractResolver interface. -func (h *htlcTimeoutResolver) Resolve( - immediate bool) (ContractResolver, error) { +func (h *htlcTimeoutResolver) Resolve(immediate bool, + blockChan <-chan int32) (ContractResolver, error) { // If we're already resolved, then we can exit early. if h.resolved { @@ -468,7 +468,7 @@ func (h *htlcTimeoutResolver) Resolve( // Depending on whether this was a local or remote commit, we must // handle the spending transaction accordingly. - return h.handleCommitSpend(commitSpend) + return h.handleCommitSpend(blockChan, commitSpend) } // sweepSecondLevelTx sends a second level timeout transaction to the sweeper. @@ -668,7 +668,7 @@ func (h *htlcTimeoutResolver) checkPointSecondLevelTx() error { // confirmed second-level timeout transaction, and we'll sweep that into our // wallet. If the was a remote commitment, the resolver will resolve // immetiately. -func (h *htlcTimeoutResolver) handleCommitSpend( +func (h *htlcTimeoutResolver) handleCommitSpend(blockChan <-chan int32, commitSpend *chainntnfs.SpendDetail) (ContractResolver, error) { var ( @@ -731,7 +731,7 @@ func (h *htlcTimeoutResolver) handleCommitSpend( waitHeight-- // TODO(yy): let sweeper handles the wait? - err := waitForHeight(waitHeight, h.Notifier, h.quit) + err := waitForHeight(waitHeight, blockChan, h.quit) if err != nil { return nil, err } diff --git a/contractcourt/htlc_timeout_resolver_test.go b/contractcourt/htlc_timeout_resolver_test.go index c551a6f1ceb..c854b75f723 100644 --- a/contractcourt/htlc_timeout_resolver_test.go +++ b/contractcourt/htlc_timeout_resolver_test.go @@ -267,7 +267,6 @@ func TestHtlcTimeoutResolver(t *testing.T) { } notifier := &mock.ChainNotifier{ - EpochChan: make(chan *chainntnfs.BlockEpoch), SpendChan: make(chan *chainntnfs.SpendDetail), ConfChan: make(chan *chainntnfs.TxConfirmation), } @@ -375,7 +374,8 @@ func TestHtlcTimeoutResolver(t *testing.T) { go func() { defer wg.Done() - _, err := resolver.Resolve(false) + // TODO(yy): fix. + _, err := resolver.Resolve(false, nil) if err != nil { resolveErr <- err } @@ -1089,9 +1089,7 @@ func TestHtlcTimeoutSecondStageSweeper(t *testing.T) { } // Mimic CSV lock expiring. - ctx.notifier.EpochChan <- &chainntnfs.BlockEpoch{ - Height: 13, - } + ctx.notifyEpoch(13) // The timeout tx output should now be given to // the sweeper. From 06a9c13446f918780d4c43415a8f75f7c719dbfb Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Fri, 26 Apr 2024 18:28:25 +0800 Subject: [PATCH 13/19] lnd: start blockbeat service and register subsystems --- log.go | 2 ++ server.go | 40 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/log.go b/log.go index f6da0235a92..fc3238b509e 100644 --- a/log.go +++ b/log.go @@ -7,6 +7,7 @@ import ( sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/autopilot" "github.com/lightningnetwork/lnd/build" + "github.com/lightningnetwork/lnd/chainio" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/chainreg" "github.com/lightningnetwork/lnd/chanacceptor" @@ -164,6 +165,7 @@ func SetupLoggers(root *build.RotatingLogWriter, interceptor signal.Interceptor) AddSubLogger(root, "CHFD", interceptor, chanfunding.UseLogger) AddSubLogger(root, "PEER", interceptor, peer.UseLogger) AddSubLogger(root, "CHCL", interceptor, chancloser.UseLogger) + AddSubLogger(root, "CHIO", interceptor, chainio.UseLogger) AddSubLogger(root, routing.Subsystem, interceptor, routing.UseLogger) AddSubLogger(root, routerrpc.Subsystem, interceptor, routerrpc.UseLogger) diff --git a/server.go b/server.go index 2b54f81d6a8..49db7435072 100644 --- a/server.go +++ b/server.go @@ -27,6 +27,7 @@ import ( "github.com/lightningnetwork/lnd/aliasmgr" "github.com/lightningnetwork/lnd/autopilot" "github.com/lightningnetwork/lnd/brontide" + "github.com/lightningnetwork/lnd/chainio" "github.com/lightningnetwork/lnd/chainreg" "github.com/lightningnetwork/lnd/chanacceptor" "github.com/lightningnetwork/lnd/chanbackup" @@ -330,6 +331,10 @@ type server struct { // txPublisher is a publisher with fee-bumping capability. txPublisher *sweep.TxPublisher + // blockbeat is a block dispatcher that notifies subscribers of new + // blocks. + blockbeat *chainio.BlockBeat + quit chan struct{} wg sync.WaitGroup @@ -568,6 +573,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, s := &server{ cfg: cfg, + blockbeat: chainio.NewBlockBeat(cc.ChainNotifier), graphDB: dbs.GraphDB.ChannelGraph(), chanStateDB: dbs.ChanStateDB.ChannelStateDB(), addrSource: dbs.ChanStateDB, @@ -1677,9 +1683,31 @@ func newServer(cfg *Config, listenAddrs []net.Addr, } s.connMgr = cmgr + // Finally, register the subsystems in blockbeat. + s.registerBlockConsumers() + return s, nil } +// registerBlockConsumers registers the subsystems that consume block events. +// By calling `RegisterQueue`, a list of subsystems are registered in the +// blockbeat for block notifications. When a new block arrives, the subsystems +// in the same queue are notified sequentially, and different queues are +// notified concurrently. +// +// NOTE: To put a subsystem in a different queue, create a slice and pass it to +// a new `RegisterQueue` call. +func (s *server) registerBlockConsumers() { + // In this queue, when a new block arrives, it will be received and + // processed in this order: chainArb -> sweeper -> txPublisher. + consumers := []chainio.Consumer{ + s.chainArb, + s.sweeper, + s.txPublisher, + } + s.blockbeat.RegisterQueue(consumers) +} + // signAliasUpdate takes a ChannelUpdate and returns the signature. This is // used for option_scid_alias channels where the ChannelUpdate to be sent back // may differ from what is on disk. @@ -2110,6 +2138,17 @@ func (s *server) Start() error { return nil }) + // Start the blockbeat after all other subsystems have been + // started so they are ready to receive new blocks. + if err := s.blockbeat.Start(); err != nil { + startErr = err + return + } + cleanup = cleanup.add(func() error { + s.blockbeat.Stop() + return nil + }) + // If peers are specified as a config option, we'll add those // peers first. for _, peerAddrCfg := range s.cfg.AddPeers { @@ -2297,6 +2336,7 @@ func (s *server) Stop() error { } s.txPublisher.Stop() + s.blockbeat.Stop() if err := s.channelNotifier.Stop(); err != nil { srvrLog.Warnf("failed to stop channelNotifier: %v", err) From dd2b11027dac7e21fe5388fddaacfd78442f0a56 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Tue, 30 Apr 2024 13:21:20 +0800 Subject: [PATCH 14/19] contractcourt: remove the immediate param used in `Resolve` This `immediate` flag was added as a hack so during a restart, the pending resolvers would offer the inputs to the sweeper and ask it to sweep them immediately. This is no longer need due to `blockbeat`, as now during restart, a block is always sent to all subsystems via the flow `ChainArb` -> `ChannelArb` -> resolvers -> sweeper. Thus, when there are pending inputs offered, they will be processed by the sweeper immediately. --- contractcourt/anchor_resolver.go | 4 +--- contractcourt/breach_resolver.go | 4 +--- contractcourt/channel_arbitrator.go | 16 ++++++---------- contractcourt/commit_sweep_resolver.go | 2 +- contractcourt/commit_sweep_resolver_test.go | 2 +- contractcourt/contract_resolver.go | 3 +-- .../htlc_incoming_contest_resolver.go | 2 +- .../htlc_incoming_contest_resolver_test.go | 2 +- .../htlc_outgoing_contest_resolver.go | 2 +- .../htlc_outgoing_contest_resolver_test.go | 2 +- contractcourt/htlc_success_resolver.go | 18 ++++++++---------- contractcourt/htlc_success_resolver_test.go | 2 +- contractcourt/htlc_timeout_resolver.go | 13 ++++++------- contractcourt/htlc_timeout_resolver_test.go | 4 +--- 14 files changed, 31 insertions(+), 45 deletions(-) diff --git a/contractcourt/anchor_resolver.go b/contractcourt/anchor_resolver.go index 36e2970ce10..b80b84033d4 100644 --- a/contractcourt/anchor_resolver.go +++ b/contractcourt/anchor_resolver.go @@ -84,9 +84,7 @@ func (c *anchorResolver) ResolverKey() []byte { } // Resolve offers the anchor output to the sweeper and waits for it to be swept. -func (c *anchorResolver) Resolve(_ bool, - _ <-chan int32) (ContractResolver, error) { - +func (c *anchorResolver) Resolve(_ <-chan int32) (ContractResolver, error) { // Attempt to update the sweep parameters to the post-confirmation // situation. We don't want to force sweep anymore, because the anchor // lost its special purpose to get the commitment confirmed. It is just diff --git a/contractcourt/breach_resolver.go b/contractcourt/breach_resolver.go index 56ef51cc8e6..57004b299d0 100644 --- a/contractcourt/breach_resolver.go +++ b/contractcourt/breach_resolver.go @@ -47,9 +47,7 @@ func (b *breachResolver) ResolverKey() []byte { // been broadcast. // // TODO(yy): let sweeper handle the breach inputs. -func (b *breachResolver) Resolve(_ bool, - _ <-chan int32) (ContractResolver, error) { - +func (b *breachResolver) Resolve(_ <-chan int32) (ContractResolver, error) { if !b.subscribed { complete, err := b.SubscribeBreachComplete( &b.ChanPoint, b.replyChan, diff --git a/contractcourt/channel_arbitrator.go b/contractcourt/channel_arbitrator.go index 09ddec032d1..2192f794300 100644 --- a/contractcourt/channel_arbitrator.go +++ b/contractcourt/channel_arbitrator.go @@ -789,7 +789,7 @@ func (c *ChannelArbitrator) relaunchResolvers(commitSet *CommitSet, // TODO(roasbeef): this isn't re-launched? } - c.launchResolvers(unresolvedContracts, true) + c.launchResolvers(unresolvedContracts) return nil } @@ -1247,7 +1247,7 @@ func (c *ChannelArbitrator) stateStep( // Finally, we'll launch all the required contract resolvers. // Once they're all resolved, we're no longer needed. - c.launchResolvers(resolvers, false) + c.launchResolvers(resolvers) nextState = StateWaitingFullResolution @@ -1578,9 +1578,7 @@ func (c *ChannelArbitrator) findCommitmentDeadlineAndValue(heightHint uint32, } // launchResolvers updates the activeResolvers list and starts the resolvers. -func (c *ChannelArbitrator) launchResolvers(resolvers []ContractResolver, - immediate bool) { - +func (c *ChannelArbitrator) launchResolvers(resolvers []ContractResolver) { c.activeResolversLock.Lock() defer c.activeResolversLock.Unlock() @@ -1593,7 +1591,7 @@ func (c *ChannelArbitrator) launchResolvers(resolvers []ContractResolver, blockChan := make(chan int32, arbitratorBlockBufferSize) c.activeResolvers[contract] = blockChan - go c.resolveContract(contract, blockChan, immediate) + go c.resolveContract(contract, blockChan) } } @@ -2607,7 +2605,7 @@ func (c *ChannelArbitrator) replaceResolver(oldResolver, // // NOTE: This MUST be run as a goroutine. func (c *ChannelArbitrator) resolveContract(currentContract ContractResolver, - blockChan <-chan int32, immediate bool) { + blockChan <-chan int32) { defer c.wg.Done() @@ -2629,9 +2627,7 @@ func (c *ChannelArbitrator) resolveContract(currentContract ContractResolver, default: // Otherwise, we'll attempt to resolve the current // contract. - nextContract, err := currentContract.Resolve( - immediate, blockChan, - ) + nextContract, err := currentContract.Resolve(blockChan) if err != nil { if err == errResolverShuttingDown { return diff --git a/contractcourt/commit_sweep_resolver.go b/contractcourt/commit_sweep_resolver.go index b64dd71cbd9..75f3ecc6a1d 100644 --- a/contractcourt/commit_sweep_resolver.go +++ b/contractcourt/commit_sweep_resolver.go @@ -182,7 +182,7 @@ func (c *commitSweepResolver) getCommitTxConfHeight() (uint32, error) { // NOTE: This function MUST be run as a goroutine. // //nolint:funlen -func (c *commitSweepResolver) Resolve(_ bool, +func (c *commitSweepResolver) Resolve( blockChan <-chan int32) (ContractResolver, error) { // If we're already resolved, then we can exit early. diff --git a/contractcourt/commit_sweep_resolver_test.go b/contractcourt/commit_sweep_resolver_test.go index c6ee1d8fd6c..1b4f35bd275 100644 --- a/contractcourt/commit_sweep_resolver_test.go +++ b/contractcourt/commit_sweep_resolver_test.go @@ -83,7 +83,7 @@ func (i *commitSweepResolverTestContext) resolve() { // Start resolver. i.resolverResultChan = make(chan resolveResult, 1) go func() { - nextResolver, err := i.resolver.Resolve(false, i.blockChan) + nextResolver, err := i.resolver.Resolve(i.blockChan) i.resolverResultChan <- resolveResult{ nextResolver: nextResolver, err: err, diff --git a/contractcourt/contract_resolver.go b/contractcourt/contract_resolver.go index 67a694f81c5..c574f8ae577 100644 --- a/contractcourt/contract_resolver.go +++ b/contractcourt/contract_resolver.go @@ -43,8 +43,7 @@ type ContractResolver interface { // resolution, then another resolve is returned. // // NOTE: This function MUST be run as a goroutine. - Resolve(immediate bool, - blockChan <-chan int32) (ContractResolver, error) + Resolve(blockChan <-chan int32) (ContractResolver, error) // SupplementState allows the user of a ContractResolver to supplement // it with state required for the proper resolution of a contract. diff --git a/contractcourt/htlc_incoming_contest_resolver.go b/contractcourt/htlc_incoming_contest_resolver.go index 10f8ec5e1bc..e3c4e991a2a 100644 --- a/contractcourt/htlc_incoming_contest_resolver.go +++ b/contractcourt/htlc_incoming_contest_resolver.go @@ -90,7 +90,7 @@ func (h *htlcIncomingContestResolver) processFinalHtlcFail() error { // as we have no remaining actions left at our disposal. // // NOTE: Part of the ContractResolver interface. -func (h *htlcIncomingContestResolver) Resolve(_ bool, +func (h *htlcIncomingContestResolver) Resolve( blockChan <-chan int32) (ContractResolver, error) { // If we're already full resolved, then we don't have anything further diff --git a/contractcourt/htlc_incoming_contest_resolver_test.go b/contractcourt/htlc_incoming_contest_resolver_test.go index c6c0448c1e1..e3d965e1953 100644 --- a/contractcourt/htlc_incoming_contest_resolver_test.go +++ b/contractcourt/htlc_incoming_contest_resolver_test.go @@ -396,7 +396,7 @@ func (i *incomingResolverTestContext) resolve() { i.resolveErr = make(chan error, 1) go func() { var err error - i.nextResolver, err = i.resolver.Resolve(false, i.blockChan) + i.nextResolver, err = i.resolver.Resolve(i.blockChan) i.resolveErr <- err }() diff --git a/contractcourt/htlc_outgoing_contest_resolver.go b/contractcourt/htlc_outgoing_contest_resolver.go index 64b9da8d578..c37f89b4a9a 100644 --- a/contractcourt/htlc_outgoing_contest_resolver.go +++ b/contractcourt/htlc_outgoing_contest_resolver.go @@ -49,7 +49,7 @@ func newOutgoingContestResolver(res lnwallet.OutgoingHtlcResolution, // When either of these two things happens, we'll create a new resolver which // is able to handle the final resolution of the contract. We're only the pivot // point. -func (h *htlcOutgoingContestResolver) Resolve(_ bool, +func (h *htlcOutgoingContestResolver) Resolve( blockChan <-chan int32) (ContractResolver, error) { // If we're already full resolved, then we don't have anything further diff --git a/contractcourt/htlc_outgoing_contest_resolver_test.go b/contractcourt/htlc_outgoing_contest_resolver_test.go index cd59e194825..1968a042a2b 100644 --- a/contractcourt/htlc_outgoing_contest_resolver_test.go +++ b/contractcourt/htlc_outgoing_contest_resolver_test.go @@ -210,7 +210,7 @@ func (i *outgoingResolverTestContext) resolve() { // Start resolver. i.resolverResultChan = make(chan resolveResult, 1) go func() { - nextResolver, err := i.resolver.Resolve(false, i.blockChan) + nextResolver, err := i.resolver.Resolve(i.blockChan) i.resolverResultChan <- resolveResult{ nextResolver: nextResolver, err: err, diff --git a/contractcourt/htlc_success_resolver.go b/contractcourt/htlc_success_resolver.go index 6792ea6ca59..554acfcdc77 100644 --- a/contractcourt/htlc_success_resolver.go +++ b/contractcourt/htlc_success_resolver.go @@ -115,7 +115,7 @@ func (h *htlcSuccessResolver) ResolverKey() []byte { // TODO(roasbeef): create multi to batch // // NOTE: Part of the ContractResolver interface. -func (h *htlcSuccessResolver) Resolve(immediate bool, +func (h *htlcSuccessResolver) Resolve( blockChan <-chan int32) (ContractResolver, error) { // If we're already resolved, then we can exit early. @@ -126,12 +126,12 @@ func (h *htlcSuccessResolver) Resolve(immediate bool, // If we don't have a success transaction, then this means that this is // an output on the remote party's commitment transaction. if h.htlcResolution.SignedSuccessTx == nil { - return h.resolveRemoteCommitOutput(immediate) + return h.resolveRemoteCommitOutput() } // Otherwise this an output on our own commitment, and we must start by // broadcasting the second-level success transaction. - secondLevelOutpoint, err := h.broadcastSuccessTx(immediate, blockChan) + secondLevelOutpoint, err := h.broadcastSuccessTx(blockChan) if err != nil { return nil, err } @@ -165,8 +165,8 @@ func (h *htlcSuccessResolver) Resolve(immediate bool, // broadcasting the second-level success transaction. It returns the ultimate // outpoint of the second-level tx, that we must wait to be spent for the // resolver to be fully resolved. -func (h *htlcSuccessResolver) broadcastSuccessTx(immediate bool, - blockChan <-chan int32) (*wire.OutPoint, error) { +func (h *htlcSuccessResolver) broadcastSuccessTx(blockChan <-chan int32) ( + *wire.OutPoint, error) { // If we have non-nil SignDetails, this means that have a 2nd level // HTLC transaction that is signed using sighash SINGLE|ANYONECANPAY @@ -175,7 +175,7 @@ func (h *htlcSuccessResolver) broadcastSuccessTx(immediate bool, // the checkpointed outputIncubating field to determine if we already // swept the HTLC output into the second level transaction. if h.htlcResolution.SignDetails != nil { - return h.broadcastReSignedSuccessTx(immediate, blockChan) + return h.broadcastReSignedSuccessTx(blockChan) } // Otherwise we'll publish the second-level transaction directly and @@ -227,7 +227,7 @@ func (h *htlcSuccessResolver) broadcastSuccessTx(immediate bool, // will re-sign it and attach fees at will. // //nolint:funlen -func (h *htlcSuccessResolver) broadcastReSignedSuccessTx(immediate bool, +func (h *htlcSuccessResolver) broadcastReSignedSuccessTx( blockChan <-chan int32) (*wire.OutPoint, error) { // Keep track of the tx spending the HTLC output on the commitment, as @@ -284,7 +284,6 @@ func (h *htlcSuccessResolver) broadcastReSignedSuccessTx(immediate bool, sweep.Params{ Budget: budget, DeadlineHeight: deadline, - Immediate: immediate, }, ) if err != nil { @@ -440,7 +439,7 @@ func (h *htlcSuccessResolver) broadcastReSignedSuccessTx(immediate bool, // resolveRemoteCommitOutput handles sweeping an HTLC output on the remote // commitment with the preimage. In this case we can sweep the output directly, // and don't have to broadcast a second-level transaction. -func (h *htlcSuccessResolver) resolveRemoteCommitOutput(immediate bool) ( +func (h *htlcSuccessResolver) resolveRemoteCommitOutput() ( ContractResolver, error) { isTaproot := txscript.IsPayToTaproot( @@ -489,7 +488,6 @@ func (h *htlcSuccessResolver) resolveRemoteCommitOutput(immediate bool) ( sweep.Params{ Budget: budget, DeadlineHeight: deadline, - Immediate: immediate, }, ) if err != nil { diff --git a/contractcourt/htlc_success_resolver_test.go b/contractcourt/htlc_success_resolver_test.go index 4a39dd2f2f5..473cd4506a3 100644 --- a/contractcourt/htlc_success_resolver_test.go +++ b/contractcourt/htlc_success_resolver_test.go @@ -136,7 +136,7 @@ func (i *htlcResolverTestContext) resolve() { // Start resolver. i.resolverResultChan = make(chan resolveResult, 1) go func() { - nextResolver, err := i.resolver.Resolve(false, i.blockChan) + nextResolver, err := i.resolver.Resolve(i.blockChan) i.resolverResultChan <- resolveResult{ nextResolver: nextResolver, err: err, diff --git a/contractcourt/htlc_timeout_resolver.go b/contractcourt/htlc_timeout_resolver.go index 1c49fc05f2b..f0a167a878a 100644 --- a/contractcourt/htlc_timeout_resolver.go +++ b/contractcourt/htlc_timeout_resolver.go @@ -418,7 +418,7 @@ func checkSizeAndIndex(witness wire.TxWitness, size, index int) bool { // see a direct sweep via the timeout clause. // // NOTE: Part of the ContractResolver interface. -func (h *htlcTimeoutResolver) Resolve(immediate bool, +func (h *htlcTimeoutResolver) Resolve( blockChan <-chan int32) (ContractResolver, error) { // If we're already resolved, then we can exit early. @@ -429,7 +429,7 @@ func (h *htlcTimeoutResolver) Resolve(immediate bool, // Start by spending the HTLC output, either by broadcasting the // second-level timeout transaction, or directly if this is the remote // commitment. - commitSpend, err := h.spendHtlcOutput(immediate) + commitSpend, err := h.spendHtlcOutput() if err != nil { return nil, err } @@ -473,7 +473,7 @@ func (h *htlcTimeoutResolver) Resolve(immediate bool, // sweepSecondLevelTx sends a second level timeout transaction to the sweeper. // This transaction uses the SINLGE|ANYONECANPAY flag. -func (h *htlcTimeoutResolver) sweepSecondLevelTx(immediate bool) error { +func (h *htlcTimeoutResolver) sweepSecondLevelTx() error { log.Infof("%T(%x): offering second-layer timeout tx to sweeper: %v", h, h.htlc.RHash[:], spew.Sdump(h.htlcResolution.SignedTimeoutTx)) @@ -531,7 +531,6 @@ func (h *htlcTimeoutResolver) sweepSecondLevelTx(immediate bool) error { sweep.Params{ Budget: budget, DeadlineHeight: h.incomingHTLCExpiryHeight, - Immediate: immediate, }, ) if err != nil { @@ -567,8 +566,8 @@ func (h *htlcTimeoutResolver) sendSecondLevelTxLegacy() error { // used to spend the output into the next stage. If this is the remote // commitment, the output will be swept directly without the timeout // transaction. -func (h *htlcTimeoutResolver) spendHtlcOutput( - immediate bool) (*chainntnfs.SpendDetail, error) { +func (h *htlcTimeoutResolver) spendHtlcOutput() ( + *chainntnfs.SpendDetail, error) { switch { // If we have non-nil SignDetails, this means that have a 2nd level @@ -576,7 +575,7 @@ func (h *htlcTimeoutResolver) spendHtlcOutput( // (the case for anchor type channels). In this case we can re-sign it // and attach fees at will. We let the sweeper handle this job. case h.htlcResolution.SignDetails != nil && !h.outputIncubating: - if err := h.sweepSecondLevelTx(immediate); err != nil { + if err := h.sweepSecondLevelTx(); err != nil { log.Errorf("Sending timeout tx to sweeper: %v", err) return nil, err diff --git a/contractcourt/htlc_timeout_resolver_test.go b/contractcourt/htlc_timeout_resolver_test.go index c854b75f723..3e38486fb4c 100644 --- a/contractcourt/htlc_timeout_resolver_test.go +++ b/contractcourt/htlc_timeout_resolver_test.go @@ -373,9 +373,7 @@ func TestHtlcTimeoutResolver(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - - // TODO(yy): fix. - _, err := resolver.Resolve(false, nil) + _, err := resolver.Resolve(nil) if err != nil { resolveErr <- err } From ce7cba6e6c2df8b73c34d502889545038944a9d3 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Tue, 30 Apr 2024 19:29:58 +0800 Subject: [PATCH 15/19] contractcourt: remove `waitForHeight` in resolvers The sweeper can handle the waiting so there's no need to wait for blocks inside the resolvers. By offering the inputs prior to their mature heights also guarantees the inputs with the same deadline are aggregated. --- contractcourt/commit_sweep_resolver.go | 62 +------------------------- contractcourt/htlc_success_resolver.go | 40 +++-------------- contractcourt/htlc_timeout_resolver.go | 32 ++----------- 3 files changed, 13 insertions(+), 121 deletions(-) diff --git a/contractcourt/commit_sweep_resolver.go b/contractcourt/commit_sweep_resolver.go index 75f3ecc6a1d..b6ac7514752 100644 --- a/contractcourt/commit_sweep_resolver.go +++ b/contractcourt/commit_sweep_resolver.go @@ -90,32 +90,6 @@ func (c *commitSweepResolver) ResolverKey() []byte { return key[:] } -// waitForHeight registers for block notifications and waits for the provided -// block height to be reached. -// -// TODO(yy): There's no need to wait for height in the resolvers, instead, we -// can offer these immature inputs to the sweeper immediately since the sweeper -// will handle the waiting. -func waitForHeight(waitHeight uint32, blockChan <-chan int32, - quit <-chan struct{}) error { - - for { - select { - case newBlockHeight, ok := <-blockChan: - if !ok { - return errResolverShuttingDown - } - height := newBlockHeight - if height >= int32(waitHeight) { - return nil - } - - case <-quit: - return errResolverShuttingDown - } - } -} - // waitForSpend waits for the given outpoint to be spent, and returns the // details of the spending tx. func waitForSpend(op *wire.OutPoint, pkScript []byte, heightHint uint32, @@ -183,7 +157,7 @@ func (c *commitSweepResolver) getCommitTxConfHeight() (uint32, error) { // //nolint:funlen func (c *commitSweepResolver) Resolve( - blockChan <-chan int32) (ContractResolver, error) { + _ <-chan int32) (ContractResolver, error) { // If we're already resolved, then we can exit early. if c.resolved { @@ -212,44 +186,12 @@ func (c *commitSweepResolver) Resolve( c.currentReport.MaturityHeight = unlockHeight c.reportLock.Unlock() - // If there is a csv/cltv lock, we'll wait for that. - if c.commitResolution.MaturityDelay > 0 || c.hasCLTV() { - // Determine what height we should wait until for the locks to - // expire. - var waitHeight uint32 - switch { - // If we have both a csv and cltv lock, we'll need to look at - // both and see which expires later. - case c.commitResolution.MaturityDelay > 0 && c.hasCLTV(): - c.log.Debugf("waiting for CSV and CLTV lock to expire "+ - "at height %v", unlockHeight) - // If the CSV expires after the CLTV, or there is no - // CLTV, then we can broadcast a sweep a block before. - // Otherwise, we need to broadcast at our expected - // unlock height. - waitHeight = uint32(math.Max( - float64(unlockHeight-1), float64(c.leaseExpiry), - )) - - // If we only have a csv lock, wait for the height before the - // lock expires as the spend path should be unlocked by then. - case c.commitResolution.MaturityDelay > 0: - c.log.Debugf("waiting for CSV lock to expire at "+ - "height %v", unlockHeight) - waitHeight = unlockHeight - 1 - } - - err := waitForHeight(waitHeight, blockChan, c.quit) - if err != nil { - return nil, err - } - } - var ( isLocalCommitTx bool signDesc = c.commitResolution.SelfOutputSignDesc ) + switch { // For taproot channels, we'll know if this is the local commit based // on the witness script. For local channels, the witness script has an diff --git a/contractcourt/htlc_success_resolver.go b/contractcourt/htlc_success_resolver.go index 554acfcdc77..8b1245b4576 100644 --- a/contractcourt/htlc_success_resolver.go +++ b/contractcourt/htlc_success_resolver.go @@ -116,7 +116,7 @@ func (h *htlcSuccessResolver) ResolverKey() []byte { // // NOTE: Part of the ContractResolver interface. func (h *htlcSuccessResolver) Resolve( - blockChan <-chan int32) (ContractResolver, error) { + _ <-chan int32) (ContractResolver, error) { // If we're already resolved, then we can exit early. if h.resolved { @@ -131,7 +131,7 @@ func (h *htlcSuccessResolver) Resolve( // Otherwise this an output on our own commitment, and we must start by // broadcasting the second-level success transaction. - secondLevelOutpoint, err := h.broadcastSuccessTx(blockChan) + secondLevelOutpoint, err := h.broadcastSuccessTx() if err != nil { return nil, err } @@ -165,7 +165,7 @@ func (h *htlcSuccessResolver) Resolve( // broadcasting the second-level success transaction. It returns the ultimate // outpoint of the second-level tx, that we must wait to be spent for the // resolver to be fully resolved. -func (h *htlcSuccessResolver) broadcastSuccessTx(blockChan <-chan int32) ( +func (h *htlcSuccessResolver) broadcastSuccessTx() ( *wire.OutPoint, error) { // If we have non-nil SignDetails, this means that have a 2nd level @@ -175,7 +175,7 @@ func (h *htlcSuccessResolver) broadcastSuccessTx(blockChan <-chan int32) ( // the checkpointed outputIncubating field to determine if we already // swept the HTLC output into the second level transaction. if h.htlcResolution.SignDetails != nil { - return h.broadcastReSignedSuccessTx(blockChan) + return h.broadcastReSignedSuccessTx() } // Otherwise we'll publish the second-level transaction directly and @@ -225,10 +225,8 @@ func (h *htlcSuccessResolver) broadcastSuccessTx(blockChan <-chan int32) ( // broadcastReSignedSuccessTx handles the case where we have non-nil // SignDetails, and offers the second level transaction to the Sweeper, that // will re-sign it and attach fees at will. -// -//nolint:funlen -func (h *htlcSuccessResolver) broadcastReSignedSuccessTx( - blockChan <-chan int32) (*wire.OutPoint, error) { +func (h *htlcSuccessResolver) broadcastReSignedSuccessTx() (*wire.OutPoint, + error) { // Keep track of the tx spending the HTLC output on the commitment, as // this will be the confirmed second-level tx we'll ultimately sweep. @@ -355,30 +353,6 @@ func (h *htlcSuccessResolver) broadcastReSignedSuccessTx( "height %v", h, h.htlc.RHash[:], waitHeight) } - // Deduct one block so this input is offered to the sweeper one block - // earlier since the sweeper will wait for one block to trigger the - // sweeping. - // - // TODO(yy): this is done so the outputs can be aggregated - // properly. Suppose CSV locks of five 2nd-level outputs all - // expire at height 840000, there is a race in block digestion - // between contractcourt and sweeper: - // - G1: block 840000 received in contractcourt, it now offers - // the outputs to the sweeper. - // - G2: block 840000 received in sweeper, it now starts to - // sweep the received outputs - there's no guarantee all - // fives have been received. - // To solve this, we either offer the outputs earlier, or - // implement `blockbeat`, and force contractcourt and sweeper - // to consume each block sequentially. - waitHeight-- - - // TODO(yy): let sweeper handles the wait? - err := waitForHeight(waitHeight, blockChan, h.quit) - if err != nil { - return nil, err - } - // We'll use this input index to determine the second-level output // index on the transaction, as the signatures requires the indexes to // be the same. We don't look for the second-level output script @@ -417,7 +391,7 @@ func (h *htlcSuccessResolver) broadcastReSignedSuccessTx( h.htlc.RHash[:], budget, waitHeight) // TODO(roasbeef): need to update above for leased types - _, err = h.Sweeper.SweepInput( + _, err := h.Sweeper.SweepInput( inp, sweep.Params{ Budget: budget, diff --git a/contractcourt/htlc_timeout_resolver.go b/contractcourt/htlc_timeout_resolver.go index f0a167a878a..435c7d7d03e 100644 --- a/contractcourt/htlc_timeout_resolver.go +++ b/contractcourt/htlc_timeout_resolver.go @@ -419,7 +419,7 @@ func checkSizeAndIndex(witness wire.TxWitness, size, index int) bool { // // NOTE: Part of the ContractResolver interface. func (h *htlcTimeoutResolver) Resolve( - blockChan <-chan int32) (ContractResolver, error) { + _ <-chan int32) (ContractResolver, error) { // If we're already resolved, then we can exit early. if h.resolved { @@ -468,7 +468,7 @@ func (h *htlcTimeoutResolver) Resolve( // Depending on whether this was a local or remote commit, we must // handle the spending transaction accordingly. - return h.handleCommitSpend(blockChan, commitSpend) + return h.handleCommitSpend(commitSpend) } // sweepSecondLevelTx sends a second level timeout transaction to the sweeper. @@ -667,7 +667,7 @@ func (h *htlcTimeoutResolver) checkPointSecondLevelTx() error { // confirmed second-level timeout transaction, and we'll sweep that into our // wallet. If the was a remote commitment, the resolver will resolve // immetiately. -func (h *htlcTimeoutResolver) handleCommitSpend(blockChan <-chan int32, +func (h *htlcTimeoutResolver) handleCommitSpend( commitSpend *chainntnfs.SpendDetail) (ContractResolver, error) { var ( @@ -711,30 +711,6 @@ func (h *htlcTimeoutResolver) handleCommitSpend(blockChan <-chan int32, "height %v", h, h.htlc.RHash[:], waitHeight) } - // Deduct one block so this input is offered to the sweeper one - // block earlier since the sweeper will wait for one block to - // trigger the sweeping. - // - // TODO(yy): this is done so the outputs can be aggregated - // properly. Suppose CSV locks of five 2nd-level outputs all - // expire at height 840000, there is a race in block digestion - // between contractcourt and sweeper: - // - G1: block 840000 received in contractcourt, it now offers - // the outputs to the sweeper. - // - G2: block 840000 received in sweeper, it now starts to - // sweep the received outputs - there's no guarantee all - // fives have been received. - // To solve this, we either offer the outputs earlier, or - // implement `blockbeat`, and force contractcourt and sweeper - // to consume each block sequentially. - waitHeight-- - - // TODO(yy): let sweeper handles the wait? - err := waitForHeight(waitHeight, blockChan, h.quit) - if err != nil { - return nil, err - } - // We'll use this input index to determine the second-level // output index on the transaction, as the signatures requires // the indexes to be the same. We don't look for the @@ -773,7 +749,7 @@ func (h *htlcTimeoutResolver) handleCommitSpend(blockChan <-chan int32, "sweeper with no deadline and budget=%v at height=%v", h, h.htlc.RHash[:], budget, waitHeight) - _, err = h.Sweeper.SweepInput( + _, err := h.Sweeper.SweepInput( inp, sweep.Params{ Budget: budget, From bc1861469e916b25d2c82bd7511d1eb0ee47b651 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Tue, 30 Apr 2024 19:31:41 +0800 Subject: [PATCH 16/19] itest: fix sweep tests and remove hacks --- itest/lnd_sweep_test.go | 531 ++++++++---------------------------- lntest/harness_assertion.go | 33 +++ 2 files changed, 150 insertions(+), 414 deletions(-) diff --git a/itest/lnd_sweep_test.go b/itest/lnd_sweep_test.go index 27ad1326644..e25acf3a321 100644 --- a/itest/lnd_sweep_test.go +++ b/itest/lnd_sweep_test.go @@ -2,7 +2,6 @@ package itest import ( "fmt" - "math" "time" "github.com/btcsuite/btcd/btcutil" @@ -61,10 +60,7 @@ func testSweepCPFPAnchorOutgoingTimeout(ht *lntest.HarnessTest) { // Set up the fee estimator to return the testing fee rate when the // conf target is the deadline. - // - // TODO(yy): switch to conf when `blockbeat` is in place. - // ht.SetFeeEstimateWithConf(startFeeRateAnchor, deadlineDeltaAnchor) - ht.SetFeeEstimate(startFeeRateAnchor) + ht.SetFeeEstimateWithConf(startFeeRateAnchor, deadlineDeltaAnchor) // htlcValue is the outgoing HTLC's value. htlcValue := invoiceAmt @@ -167,52 +163,26 @@ func testSweepCPFPAnchorOutgoingTimeout(ht *lntest.HarnessTest) { )) ht.MineEmptyBlocks(int(numBlocks)) - // Assert Bob's force closing tx has been broadcast. - closeTxid := ht.Miner.AssertNumTxsInMempool(1)[0] + // Assert Bob's force closing tx has been broadcast. We should see two + // txns in the mempool: + // 1. Bob's force closing tx. + // 2. Bob's anchor sweeping tx CPFPing the force close tx. + _, sweepTx := ht.AssertForceCloseAndAnchorTxnsInMempool() // Remember the force close height so we can calculate the deadline // height. _, forceCloseHeight := ht.Miner.GetBestBlock() - // Bob should have two pending sweeps, + // Bob should have one pending sweep, // - anchor sweeping from his local commitment. - // - anchor sweeping from his remote commitment (invalid). - // - // TODO(yy): consider only sweeping the anchor from the local - // commitment. Previously we would sweep up to three versions of - // anchors because we don't know which one will be confirmed - if we - // only broadcast the local anchor sweeping, our peer can broadcast - // their commitment tx and replaces ours. With the new fee bumping, we - // should be safe to only sweep our local anchor since we RBF it on - // every new block, which destroys the remote's ability to pin us. - sweeps := ht.AssertNumPendingSweeps(bob, 2) - - // The two anchor sweeping should have the same deadline height. + anchorSweep := ht.AssertNumPendingSweeps(bob, 1)[0] + + // The anchor sweeping should have the expected deadline height. deadlineHeight := uint32(forceCloseHeight) + deadlineDeltaAnchor - require.Equal(ht, deadlineHeight, sweeps[0].DeadlineHeight) - require.Equal(ht, deadlineHeight, sweeps[1].DeadlineHeight) + require.Equal(ht, deadlineHeight, anchorSweep.DeadlineHeight) // Remember the deadline height for the CPFP anchor. - anchorDeadline := sweeps[0].DeadlineHeight - - // Mine a block so Bob's force closing tx stays in the mempool, which - // also triggers the CPFP anchor sweep. - ht.MineEmptyBlocks(1) - - // Bob should still have two pending sweeps, - // - anchor sweeping from his local commitment. - // - anchor sweeping from his remote commitment (invalid). - ht.AssertNumPendingSweeps(bob, 2) - - // We now check the expected fee and fee rate are used for Bob's anchor - // sweeping tx. - // - // We should see Bob's anchor sweeping tx triggered by the above - // block, along with his force close tx. - txns := ht.Miner.GetNumTxsFromMempool(2) - - // Find the sweeping tx. - sweepTx := ht.FindSweepingTxns(txns, 1, *closeTxid)[0] + anchorDeadline := anchorSweep.DeadlineHeight // Get the weight for Bob's anchor sweeping tx. txWeight := ht.CalculateTxWeight(sweepTx) @@ -224,11 +194,10 @@ func testSweepCPFPAnchorOutgoingTimeout(ht *lntest.HarnessTest) { fee := uint64(ht.CalculateTxFee(sweepTx)) feeRate := uint64(ht.CalculateTxFeeRate(sweepTx)) - // feeFuncWidth is the width of the fee function. By the time we got - // here, we've already mined one block, and the fee function maxes - // out one block before the deadline, so the width is the original - // deadline minus 2. - feeFuncWidth := deadlineDeltaAnchor - 2 + // feeFuncWidth is the width of the fee function. The fee function + // maxes out one block before the deadline, so the width is the + // original deadline minus 1. + feeFuncWidth := deadlineDeltaAnchor - 1 // Calculate the expected delta increased per block. feeDelta := (cpfpBudget - startFeeAnchor).MulF64( @@ -254,10 +223,10 @@ func testSweepCPFPAnchorOutgoingTimeout(ht *lntest.HarnessTest) { // Bob's fee bumper should increase its fees. ht.MineEmptyBlocks(1) - // Bob should still have two pending sweeps, - // - anchor sweeping from his local commitment. - // - anchor sweeping from his remote commitment (invalid). - ht.AssertNumPendingSweeps(bob, 2) + // Bob should still have the anchor sweeping from his local + // commitment. His anchor sweeping from his remote commitment + // is invalid and should be removed. + ht.AssertNumPendingSweeps(bob, 1) // Make sure Bob's old sweeping tx has been removed from the // mempool. @@ -266,7 +235,7 @@ func testSweepCPFPAnchorOutgoingTimeout(ht *lntest.HarnessTest) { // We expect to see two txns in the mempool, // - Bob's force close tx. // - Bob's anchor sweep tx. - ht.Miner.AssertNumTxsInMempool(2) + _, sweepTx := ht.AssertForceCloseAndAnchorTxnsInMempool() // We expect the fees to increase by i*delta. expectedFee := startFeeAnchor + feeDelta.MulF64(float64(i)) @@ -276,11 +245,7 @@ func testSweepCPFPAnchorOutgoingTimeout(ht *lntest.HarnessTest) { // We should see Bob's anchor sweeping tx being fee bumped // since it's not confirmed, along with his force close tx. - txns = ht.Miner.GetNumTxsFromMempool(2) - - // Find the sweeping tx. - sweepTx = ht.FindSweepingTxns(txns, 1, *closeTxid)[0] - + // // Calculate the fee rate of Bob's new sweeping tx. feeRate = uint64(ht.CalculateTxFeeRate(sweepTx)) @@ -315,17 +280,13 @@ func testSweepCPFPAnchorOutgoingTimeout(ht *lntest.HarnessTest) { // Get the last sweeping tx - we should see two txns here, Bob's anchor // sweeping tx and his force close tx. - txns = ht.Miner.GetNumTxsFromMempool(2) - - // Find the sweeping tx. - sweepTx = ht.FindSweepingTxns(txns, 1, *closeTxid)[0] + _, sweepTx = ht.AssertForceCloseAndAnchorTxnsInMempool() // Calculate the fee of Bob's new sweeping tx. fee = uint64(ht.CalculateTxFee(sweepTx)) - // Assert the budget is now used up. - require.InEpsilonf(ht, uint64(cpfpBudget), fee, 0.01, "want %d, got %d", - cpfpBudget, fee) + // Bob should have the anchor sweeping from his local commitment. + ht.AssertNumPendingSweeps(bob, 1) // Mine one more block. Since Bob's budget has been used up, there // won't be any more sweeping attempts. We now assert this by checking @@ -336,10 +297,7 @@ func testSweepCPFPAnchorOutgoingTimeout(ht *lntest.HarnessTest) { // // We expect two txns here, one for the anchor sweeping, the other for // the force close tx. - txns = ht.Miner.GetNumTxsFromMempool(2) - - // Find the sweeping tx. - currentSweepTx := ht.FindSweepingTxns(txns, 1, *closeTxid)[0] + _, currentSweepTx := ht.AssertForceCloseAndAnchorTxnsInMempool() // Assert the anchor sweep tx stays unchanged. require.Equal(ht, sweepTx.TxHash(), currentSweepTx.TxHash()) @@ -400,10 +358,7 @@ func testSweepCPFPAnchorIncomingTimeout(ht *lntest.HarnessTest) { // Set up the fee estimator to return the testing fee rate when the // conf target is the deadline. - // - // TODO(yy): switch to conf when `blockbeat` is in place. - // ht.SetFeeEstimateWithConf(startFeeRateAnchor, deadlineDeltaAnchor) - ht.SetFeeEstimate(startFeeRateAnchor) + ht.SetFeeEstimateWithConf(startFeeRateAnchor, deadlineDeltaAnchor) // Create a preimage, that will be held by Carol. var preimage lntypes.Preimage @@ -516,40 +471,22 @@ func testSweepCPFPAnchorIncomingTimeout(ht *lntest.HarnessTest) { numBlocks := forceCloseHeight - uint32(currentHeight) ht.MineEmptyBlocks(int(numBlocks)) - // Assert Bob's force closing tx has been broadcast. - closeTxid := ht.Miner.AssertNumTxsInMempool(1)[0] + // Assert Bob's force closing tx has been broadcast. We should see two + // txns in the mempool: + // 1. Bob's force closing tx. + // 2. Bob's anchor sweeping tx CPFPing the force close tx. + _, sweepTx := ht.AssertForceCloseAndAnchorTxnsInMempool() - // Bob should have two pending sweeps, + // Bob should have one pending sweep, // - anchor sweeping from his local commitment. - // - anchor sweeping from his remote commitment (invalid). - sweeps := ht.AssertNumPendingSweeps(bob, 2) + anchorSweep := ht.AssertNumPendingSweeps(bob, 1)[0] - // The two anchor sweeping should have the same deadline height. + // The anchor sweeping should have the expected deadline height. deadlineHeight := forceCloseHeight + deadlineDeltaAnchor - require.Equal(ht, deadlineHeight, sweeps[0].DeadlineHeight) - require.Equal(ht, deadlineHeight, sweeps[1].DeadlineHeight) + require.Equal(ht, deadlineHeight, anchorSweep.DeadlineHeight) // Remember the deadline height for the CPFP anchor. - anchorDeadline := sweeps[0].DeadlineHeight - - // Mine a block so Bob's force closing tx stays in the mempool, which - // also triggers the CPFP anchor sweep. - ht.MineEmptyBlocks(1) - - // Bob should still have two pending sweeps, - // - anchor sweeping from his local commitment. - // - anchor sweeping from his remote commitment (invalid). - ht.AssertNumPendingSweeps(bob, 2) - - // We now check the expected fee and fee rate are used for Bob's anchor - // sweeping tx. - // - // We should see Bob's anchor sweeping tx triggered by the above - // block, along with his force close tx. - txns := ht.Miner.GetNumTxsFromMempool(2) - - // Find the sweeping tx. - sweepTx := ht.FindSweepingTxns(txns, 1, *closeTxid)[0] + anchorDeadline := anchorSweep.DeadlineHeight // Get the weight for Bob's anchor sweeping tx. txWeight := ht.CalculateTxWeight(sweepTx) @@ -561,11 +498,10 @@ func testSweepCPFPAnchorIncomingTimeout(ht *lntest.HarnessTest) { fee := uint64(ht.CalculateTxFee(sweepTx)) feeRate := uint64(ht.CalculateTxFeeRate(sweepTx)) - // feeFuncWidth is the width of the fee function. By the time we got - // here, we've already mined one block, and the fee function maxes - // out one block before the deadline, so the width is the original - // deadline minus 2. - feeFuncWidth := deadlineDeltaAnchor - 2 + // feeFuncWidth is the width of the fee function. The fee function + // maxes out one block before the deadline, so the width is the + // original deadline minus 1. + feeFuncWidth := deadlineDeltaAnchor - 1 // Calculate the expected delta increased per block. feeDelta := (cpfpBudget - startFeeAnchor).MulF64( @@ -591,10 +527,10 @@ func testSweepCPFPAnchorIncomingTimeout(ht *lntest.HarnessTest) { // Bob's fee bumper should increase its fees. ht.MineEmptyBlocks(1) - // Bob should still have two pending sweeps, - // - anchor sweeping from his local commitment. - // - anchor sweeping from his remote commitment (invalid). - ht.AssertNumPendingSweeps(bob, 2) + // Bob should still have the anchor sweeping from his local + // commitment. His anchor sweeping from his remote commitment + // is invalid and should be removed. + ht.AssertNumPendingSweeps(bob, 1) // Make sure Bob's old sweeping tx has been removed from the // mempool. @@ -603,7 +539,7 @@ func testSweepCPFPAnchorIncomingTimeout(ht *lntest.HarnessTest) { // We expect to see two txns in the mempool, // - Bob's force close tx. // - Bob's anchor sweep tx. - ht.Miner.AssertNumTxsInMempool(2) + _, sweepTx := ht.AssertForceCloseAndAnchorTxnsInMempool() // We expect the fees to increase by i*delta. expectedFee := startFeeAnchor + feeDelta.MulF64(float64(i)) @@ -611,13 +547,6 @@ func testSweepCPFPAnchorIncomingTimeout(ht *lntest.HarnessTest) { expectedFee, txWeight, ) - // We should see Bob's anchor sweeping tx being fee bumped - // since it's not confirmed, along with his force close tx. - txns = ht.Miner.GetNumTxsFromMempool(2) - - // Find the sweeping tx. - sweepTx = ht.FindSweepingTxns(txns, 1, *closeTxid)[0] - // Calculate the fee rate of Bob's new sweeping tx. feeRate = uint64(ht.CalculateTxFeeRate(sweepTx)) @@ -652,10 +581,7 @@ func testSweepCPFPAnchorIncomingTimeout(ht *lntest.HarnessTest) { // Get the last sweeping tx - we should see two txns here, Bob's anchor // sweeping tx and his force close tx. - txns = ht.Miner.GetNumTxsFromMempool(2) - - // Find the sweeping tx. - sweepTx = ht.FindSweepingTxns(txns, 1, *closeTxid)[0] + _, sweepTx = ht.AssertForceCloseAndAnchorTxnsInMempool() // Calculate the fee of Bob's new sweeping tx. fee = uint64(ht.CalculateTxFee(sweepTx)) @@ -673,10 +599,7 @@ func testSweepCPFPAnchorIncomingTimeout(ht *lntest.HarnessTest) { // // We expect two txns here, one for the anchor sweeping, the other for // the force close tx. - txns = ht.Miner.GetNumTxsFromMempool(2) - - // Find the sweeping tx. - currentSweepTx := ht.FindSweepingTxns(txns, 1, *closeTxid)[0] + _, currentSweepTx := ht.AssertForceCloseAndAnchorTxnsInMempool() // Assert the anchor sweep tx stays unchanged. require.Equal(ht, sweepTx.TxHash(), currentSweepTx.TxHash()) @@ -729,7 +652,7 @@ func testSweepHTLCs(ht *lntest.HarnessTest) { // Start tracking the deadline delta of Bob's HTLCs. We need one block // for the CSV lock, and another block to trigger the sweeper to sweep. outgoingHTLCDeadline := int32(cltvDelta - 2) - incomingHTLCDeadline := int32(lncfg.DefaultIncomingBroadcastDelta - 2) + incomingHTLCDeadline := int32(lncfg.DefaultIncomingBroadcastDelta - 3) // startFeeRate1 and startFeeRate2 are returned by the fee estimator in // sat/kw. They will be used as the starting fee rate for the linear @@ -884,34 +807,33 @@ func testSweepHTLCs(ht *lntest.HarnessTest) { // Bob should now have two pending sweeps, one for the anchor on the // local commitment, the other on the remote commitment. - ht.AssertNumPendingSweeps(bob, 2) + ht.AssertNumPendingSweeps(bob, 1) - // Assert Bob's force closing tx has been broadcast. - ht.Miner.AssertNumTxsInMempool(1) + // We expect to see two txns in the mempool: + // 1. Bob's force closing tx. + // 1. Bob's anchor CPFP sweeping tx. + ht.Miner.AssertNumTxsInMempool(2) - // Mine the force close tx, which triggers Bob's contractcourt to offer - // his outgoing HTLC to his sweeper. + // Mine the force close tx and CPFP sweeping tx, which triggers Bob's + // contractcourt to offer his outgoing HTLC to his sweeper. // // NOTE: HTLC outputs are only offered to sweeper when the force close // tx is confirmed and the CSV has reached. - ht.MineBlocksAndAssertNumTxes(1, 1) + ht.MineBlocksAndAssertNumTxes(1, 2) // Update the blocks left till Bob force closes Alice->Bob. blocksTillIncomingSweep-- - // Bob should have two pending sweeps, one for the anchor sweeping, the - // other for the outgoing HTLC. - ht.AssertNumPendingSweeps(bob, 2) + // Bob should have one pending sweep for the outgoing HTLC. + ht.AssertNumPendingSweeps(bob, 1) - // Mine one block to confirm Bob's anchor sweeping tx, which will - // trigger his sweeper to publish the HTLC sweeping tx. - ht.MineBlocksAndAssertNumTxes(1, 1) + // Mine a block to trigger Bob's sweeper to sweep his outgoing HTLC. + ht.MineEmptyBlocks(1) // Update the blocks left till Bob force closes Alice->Bob. blocksTillIncomingSweep-- - // Bob should now have one sweep and one sweeping tx in the mempool. - ht.AssertNumPendingSweeps(bob, 1) + // Bob should have one sweeping tx in the mempool. outgoingSweep := ht.Miner.GetNumTxsFromMempool(1)[0] // Check the shape of the sweeping tx - we expect it to be @@ -1009,22 +931,25 @@ func testSweepHTLCs(ht *lntest.HarnessTest) { // Update Bob's fee function position. outgoingFuncPosition++ - // Bob should now have three pending sweeps: + // Bob should now have two pending sweeps: // 1. the outgoing HTLC output. // 2. the anchor output from his local commitment. - // 3. the anchor output from his remote commitment. - ht.AssertNumPendingSweeps(bob, 3) + ht.AssertNumPendingSweeps(bob, 2) - // We should see two txns in the mempool: + // We should see three txns in the mempool: // 1. Bob's outgoing HTLC sweeping tx. // 2. Bob's force close tx for Alice->Bob. - txns := ht.Miner.GetNumTxsFromMempool(2) + // 2. Bob's anchor CPFP sweeping tx for Alice->Bob. + txns := ht.Miner.GetNumTxsFromMempool(3) // Find the force close tx - we expect it to have a single input. closeTx := txns[0] if len(closeTx.TxIn) != 1 { closeTx = txns[1] } + if len(closeTx.TxIn) != 1 { + closeTx = txns[2] + } // We don't care the behavior of the anchor sweep in this test, so we // mine the force close tx to trigger Bob's contractcourt to offer his @@ -1034,6 +959,10 @@ func testSweepHTLCs(ht *lntest.HarnessTest) { // Update Bob's fee function position. outgoingFuncPosition++ + // Mine a block to trigger the sweep. + ht.MineEmptyBlocks(1) + outgoingFuncPosition++ + // Bob should now have three pending sweeps: // 1. the outgoing HTLC output on Bob->Carol. // 2. the incoming HTLC output on Alice->Bob. @@ -1228,10 +1157,15 @@ func testSweepCommitOutputAndAnchor(ht *lntest.HarnessTest) { // config. deadline := uint32(1000) - // The actual deadline used by the fee function will be one block off - // from the deadline configured as we require one block to be mined to - // trigger the sweep. - deadlineA, deadlineB := deadline-1, deadline-1 + // For Alice, since her commit output is offered to the sweeper at + // CSV-1. With a deadline of 1000, her actual width of her fee func is + // CSV+1000-1. + deadlineA := deadline + 1 + + // For Bob, the actual deadline used by the fee function will be one + // block off from the deadline configured as we require one block to be + // mined to trigger the sweep. + deadlineB := deadline - 1 // startFeeRate is returned by the fee estimator in sat/kw. This // will be used as the starting fee rate for the linear fee func used @@ -1242,7 +1176,7 @@ func testSweepCommitOutputAndAnchor(ht *lntest.HarnessTest) { // Set up the fee estimator to return the testing fee rate when the // conf target is the deadline. - ht.SetFeeEstimateWithConf(startFeeRate, deadlineA) + ht.SetFeeEstimateWithConf(startFeeRate, deadlineB) // toLocalCSV is the CSV delay for Alice's to_local output. We use a // small value to save us from mining blocks. @@ -1250,25 +1184,7 @@ func testSweepCommitOutputAndAnchor(ht *lntest.HarnessTest) { // NOTE: once the force close tx is confirmed, we expect anchor // sweeping starts. Then two more block later the commit output // sweeping starts. - // - // NOTE: The CSV value is chosen to be 3 instead of 2, to reduce the - // possibility of flakes as there is a race between the two goroutines: - // G1 - Alice's sweeper receives the commit output. - // G2 - Alice's sweeper receives the new block mined. - // G1 is triggered by the same block being received by Alice's - // contractcourt, deciding the commit output is mature and offering it - // to her sweeper. Normally, we'd expect G2 to be finished before G1 - // because it's the same block processed by both contractcourt and - // sweeper. However, if G2 is delayed (maybe the sweeper is slow in - // finishing its previous round), G1 may finish before G2. This will - // cause the sweeper to add the commit output to its pending inputs, - // and once G2 fires, it will then start sweeping this output, - // resulting a valid sweep tx being created using her commit and anchor - // outputs. - // - // TODO(yy): fix the above issue by making sure subsystems share the - // same view on current block height. - toLocalCSV := 3 + toLocalCSV := 2 // htlcAmt is the amount of the HTLC in sats, this should be Alice's // to_remote amount that goes to Bob. @@ -1362,140 +1278,18 @@ func testSweepCommitOutputAndAnchor(ht *lntest.HarnessTest) { ht.AssertNumPendingSweeps(bob, 2) // Mine one more empty block should trigger Bob's sweeping. Since we - // use a CSV of 3, this means Alice's to_local output is one block away - // from being mature. + // use a CSV of 2, this means Alice's to_local output is now mature. ht.MineEmptyBlocks(1) - // We expect to see one sweeping tx in the mempool: - // - Alice's anchor sweeping tx must have been failed due to the fee - // rate chosen in this test - the anchor sweep tx has no output. - // - Bob's sweeping tx, which sweeps both his anchor and commit outputs. - bobSweepTx := ht.Miner.GetNumTxsFromMempool(1)[0] - // We expect two pending sweeps for Bob - anchor and commit outputs. - pendingSweepBob := ht.AssertNumPendingSweeps(bob, 2)[0] - - // The sweeper may be one block behind contractcourt, so we double - // check the actual deadline. - // - // TODO(yy): assert they are equal once blocks are synced via - // `blockbeat`. - _, currentHeight := ht.Miner.GetBestBlock() - actualDeadline := int32(pendingSweepBob.DeadlineHeight) - currentHeight - if actualDeadline != int32(deadlineB) { - ht.Logf("!!! Found unsynced block between sweeper and "+ - "contractcourt, expected deadline=%v, got=%v", - deadlineB, actualDeadline) - - deadlineB = uint32(actualDeadline) - } - - // Alice should still have one pending sweep - the anchor output. - ht.AssertNumPendingSweeps(alice, 1) - - // We now check Bob's sweeping tx. - // - // Bob's sweeping tx should have 2 inputs, one from his commit output, - // the other from his anchor output. - require.Len(ht, bobSweepTx.TxIn, 2) - - // Because Bob is sweeping without deadline pressure, the starting fee - // rate should be the min relay fee rate. - bobStartFeeRate := ht.CalculateTxFeeRate(bobSweepTx) - require.InEpsilonf(ht, uint64(chainfee.FeePerKwFloor), - uint64(bobStartFeeRate), 0.01, "want %v, got %v", - chainfee.FeePerKwFloor, bobStartFeeRate) - - // With Bob's starting fee rate being validated, we now calculate his - // ending fee rate and fee rate delta. - // - // Bob sweeps two inputs - anchor and commit, so the starting budget - // should come from the sum of these two. - bobValue := btcutil.Amount(bobToLocal + 330) - bobBudget := bobValue.MulF64(contractcourt.DefaultBudgetRatio) - - // Calculate the ending fee rate and fee rate delta used in his fee - // function. - bobTxWeight := ht.CalculateTxWeight(bobSweepTx) - bobEndingFeeRate := chainfee.NewSatPerKWeight(bobBudget, bobTxWeight) - bobFeeRateDelta := (bobEndingFeeRate - bobStartFeeRate) / - chainfee.SatPerKWeight(deadlineB-1) - - // Mine an empty block, which should trigger Alice's contractcourt to - // offer her commit output to the sweeper. - ht.MineEmptyBlocks(1) - - // Alice should have both anchor and commit as the pending sweep - // requests. - aliceSweeps := ht.AssertNumPendingSweeps(alice, 2) - aliceAnchor, aliceCommit := aliceSweeps[0], aliceSweeps[1] - if aliceAnchor.AmountSat > aliceCommit.AmountSat { - aliceAnchor, aliceCommit = aliceCommit, aliceAnchor - } - - // The sweeper may be one block behind contractcourt, so we double - // check the actual deadline. - // - // TODO(yy): assert they are equal once blocks are synced via - // `blockbeat`. - _, currentHeight = ht.Miner.GetBestBlock() - actualDeadline = int32(aliceCommit.DeadlineHeight) - currentHeight - if actualDeadline != int32(deadlineA) { - ht.Logf("!!! Found unsynced block between Alice's sweeper and "+ - "contractcourt, expected deadline=%v, got=%v", - deadlineA, actualDeadline) - - deadlineA = uint32(actualDeadline) - } - - // We now wait for 30 seconds to overcome the flake - there's a block - // race between contractcourt and sweeper, causing the sweep to be - // broadcast earlier. - // - // TODO(yy): remove this once `blockbeat` is in place. - aliceStartPosition := 0 - var aliceFirstSweepTx *wire.MsgTx - err := wait.NoError(func() error { - mem := ht.Miner.GetRawMempool() - if len(mem) != 2 { - return fmt.Errorf("want 2, got %v in mempool: %v", - len(mem), mem) - } - - // If there are two txns, it means Alice's sweep tx has been - // created and published. - aliceStartPosition = 1 - - txns := ht.Miner.GetNumTxsFromMempool(2) - aliceFirstSweepTx = txns[0] - - // Reassign if the second tx is larger. - if txns[1].TxOut[0].Value > aliceFirstSweepTx.TxOut[0].Value { - aliceFirstSweepTx = txns[1] - } - - return nil - }, wait.DefaultTimeout) - ht.Logf("Checking mempool got: %v", err) - - // Mine an empty block, which should trigger Alice's sweeper to publish - // her commit sweep along with her anchor output. - ht.MineEmptyBlocks(1) + ht.AssertNumPendingSweeps(bob, 2) - // If Alice has already published her initial sweep tx, the above mined - // block would trigger an RBF. We now need to assert the mempool has - // removed the replaced tx. - if aliceFirstSweepTx != nil { - ht.Miner.AssertTxNotInMempool(aliceFirstSweepTx.TxHash()) - } + // We expect two pending sweeps for Alice - anchor and commit outputs. + ht.AssertNumPendingSweeps(alice, 2) // We also remember the positions of fee functions used by Alice and // Bob. They will be used to calculate the expected fee rates later. - // - // Alice's sweeping tx has just been created, so she is at the starting - // position. For Bob, due to the above mined blocks, his fee function - // is now at position 2. - alicePosition, bobPosition := uint32(aliceStartPosition), uint32(2) + alicePosition, bobPosition := uint32(0), uint32(0) // We should see two txns in the mempool: // - Alice's sweeping tx, which sweeps her commit output at the @@ -1508,8 +1302,7 @@ func testSweepCommitOutputAndAnchor(ht *lntest.HarnessTest) { // Assume the first tx is Alice's sweeping tx, if the second tx has a // larger output value, then that's Alice's as her to_local value is // much gearter. - aliceSweepTx := txns[0] - bobSweepTx = txns[1] + aliceSweepTx, bobSweepTx := txns[0], txns[1] // Swap them if bobSweepTx is smaller. if bobSweepTx.TxOut[0].Value > aliceSweepTx.TxOut[0].Value { @@ -1523,20 +1316,6 @@ func testSweepCommitOutputAndAnchor(ht *lntest.HarnessTest) { require.Len(ht, aliceSweepTx.TxIn, 1) require.Len(ht, aliceSweepTx.TxOut, 1) - // We now check Alice's sweeping tx to see if it's already published. - // - // TODO(yy): remove this check once we have better block control. - aliceSweeps = ht.AssertNumPendingSweeps(alice, 2) - aliceCommit = aliceSweeps[0] - if aliceCommit.AmountSat < aliceSweeps[1].AmountSat { - aliceCommit = aliceSweeps[1] - } - if aliceCommit.BroadcastAttempts > 1 { - ht.Logf("!!! Alice's commit sweep has already been broadcast, "+ - "broadcast_attempts=%v", aliceCommit.BroadcastAttempts) - alicePosition = aliceCommit.BroadcastAttempts - } - // Alice's sweeping tx should use the min relay fee rate as there's no // deadline pressure. aliceStartingFeeRate := chainfee.FeePerKwFloor @@ -1551,7 +1330,7 @@ func testSweepCommitOutputAndAnchor(ht *lntest.HarnessTest) { aliceTxWeight := uint64(ht.CalculateTxWeight(aliceSweepTx)) aliceEndingFeeRate := sweep.DefaultMaxFeeRate.FeePerKWeight() aliceFeeRateDelta := (aliceEndingFeeRate - aliceStartingFeeRate) / - chainfee.SatPerKWeight(deadlineA-1) + chainfee.SatPerKWeight(deadlineA) aliceFeeRate := ht.CalculateTxFeeRate(aliceSweepTx) expectedFeeRateAlice := aliceStartingFeeRate + @@ -1560,111 +1339,35 @@ func testSweepCommitOutputAndAnchor(ht *lntest.HarnessTest) { uint64(aliceFeeRate), 0.02, "want %v, got %v", expectedFeeRateAlice, aliceFeeRate) - // We now check Bob' sweeping tx. - // - // The above mined block will trigger Bob's sweeper to RBF his previous - // sweeping tx, which will fail due to RBF rule#4 - the additional fees - // paid are not sufficient. This happens as our default incremental - // relay fee rate is 1 sat/vb, with the tx size of 771 weight units, or - // 192 vbytes, we need to pay at least 192 sats more to be able to RBF. - // However, since Bob's budget delta is (100_000 + 330) * 0.5 / 1008 = - // 49.77 sats, it means Bob can only perform a successful RBF every 4 - // blocks. - // - // Assert Bob's sweeping tx is not RBFed. - bobFeeRate := ht.CalculateTxFeeRate(bobSweepTx) - expectedFeeRateBob := bobStartFeeRate - require.InEpsilonf(ht, uint64(expectedFeeRateBob), uint64(bobFeeRate), - 0.01, "want %d, got %d", expectedFeeRateBob, bobFeeRate) - - // reloclateAlicePosition is a temp hack to find the actual fee - // function position used for Alice. Due to block sync issue among the - // subsystems, we can end up having this situation: - // - sweeper is at block 2, starts sweeping an input with deadline 100. - // - fee bumper is at block 1, and thinks the conf target is 99. - // - new block 3 arrives, the func now is at position 2. + // We now check Bob's sweeping tx. // - // TODO(yy): fix it using `blockbeat`. - reloclateAlicePosition := func() { - // Mine an empty block to trigger the possible RBF attempts. - ht.MineEmptyBlocks(1) - - // Increase the positions for both fee functions. - alicePosition++ - bobPosition++ - - // We expect two pending sweeps for both nodes as we are mining - // empty blocks. - ht.AssertNumPendingSweeps(alice, 2) - ht.AssertNumPendingSweeps(bob, 2) - - // We expect to see both Alice's and Bob's sweeping txns in the - // mempool. - ht.Miner.AssertNumTxsInMempool(2) - - // Make sure Alice's old sweeping tx has been removed from the - // mempool. - ht.Miner.AssertTxNotInMempool(aliceSweepTx.TxHash()) - - // We should see two txns in the mempool: - // - Alice's sweeping tx, which sweeps both her anchor and - // commit outputs, using the increased fee rate. - // - Bob's previous sweeping tx, which sweeps both his anchor - // and commit outputs, at the possible increased fee rate. - txns = ht.Miner.GetNumTxsFromMempool(2) - - // Assume the first tx is Alice's sweeping tx, if the second tx - // has a larger output value, then that's Alice's as her - // to_local value is much gearter. - aliceSweepTx = txns[0] - bobSweepTx = txns[1] - - // Swap them if bobSweepTx is smaller. - if bobSweepTx.TxOut[0].Value > aliceSweepTx.TxOut[0].Value { - aliceSweepTx, bobSweepTx = bobSweepTx, aliceSweepTx - } - - // Alice's sweeping tx should be increased. - aliceFeeRate := ht.CalculateTxFeeRate(aliceSweepTx) - expectedFeeRate := aliceStartingFeeRate + - aliceFeeRateDelta*chainfee.SatPerKWeight(alicePosition) - - ht.Logf("Alice(deadline=%v): txWeight=%v, want feerate=%v, "+ - "got feerate=%v, delta=%v", deadlineA-alicePosition, - aliceTxWeight, expectedFeeRate, aliceFeeRate, - aliceFeeRateDelta) - - nextPosition := alicePosition + 1 - nextFeeRate := aliceStartingFeeRate + - aliceFeeRateDelta*chainfee.SatPerKWeight(nextPosition) - - // Calculate the distances. - delta := math.Abs(float64(aliceFeeRate - expectedFeeRate)) - deltaNext := math.Abs(float64(aliceFeeRate - nextFeeRate)) - - // Exit early if the first distance is smaller - it means we - // are at the right fee func position. - if delta < deltaNext { - require.InEpsilonf(ht, uint64(expectedFeeRate), - uint64(aliceFeeRate), 0.02, "want %v, got %v "+ - "in tx=%v", expectedFeeRate, - aliceFeeRate, aliceSweepTx.TxHash()) + // Bob's sweeping tx should have 2 inputs, one from his commit output, + // the other from his anchor output. + require.Len(ht, bobSweepTx.TxIn, 2) - return - } + // Because Bob is sweeping without deadline pressure, the starting fee + // rate should be the min relay fee rate. + bobStartFeeRate := ht.CalculateTxFeeRate(bobSweepTx) + require.InEpsilonf(ht, uint64(chainfee.FeePerKwFloor), + uint64(bobStartFeeRate), 0.01, "want %v, got %v", + chainfee.FeePerKwFloor, bobStartFeeRate) - alicePosition++ - ht.Logf("Jump position for Alice(deadline=%v): txWeight=%v, "+ - "want feerate=%v, got feerate=%v, delta=%v", - deadlineA-alicePosition, aliceTxWeight, nextFeeRate, - aliceFeeRate, aliceFeeRateDelta) + // With Bob's starting fee rate being validated, we now calculate his + // ending fee rate and fee rate delta. + // + // Bob sweeps two inputs - anchor and commit, so the starting budget + // should come from the sum of these two. + bobValue := btcutil.Amount(bobToLocal + 330) + bobBudget := bobValue.MulF64(contractcourt.DefaultBudgetRatio) - require.InEpsilonf(ht, uint64(nextFeeRate), - uint64(aliceFeeRate), 0.02, "want %v, got %v in tx=%v", - nextFeeRate, aliceFeeRate, aliceSweepTx.TxHash()) - } + // Calculate the ending fee rate and fee rate delta used in his fee + // function. + bobTxWeight := ht.CalculateTxWeight(bobSweepTx) + bobEndingFeeRate := chainfee.NewSatPerKWeight(bobBudget, bobTxWeight) + bobFeeRateDelta := (bobEndingFeeRate - bobStartFeeRate) / + chainfee.SatPerKWeight(deadlineB-1) - reloclateAlicePosition() + expectedFeeRateBob := bobStartFeeRate // We now mine 7 empty blocks. For each block mined, we'd see Alice's // sweeping tx being RBFed. For Bob, he performs a fee bump every @@ -1672,7 +1375,7 @@ func testSweepCommitOutputAndAnchor(ht *lntest.HarnessTest) { // the fee bumps is not sufficient to meet the fee requirements // enforced by RBF. Since his fee function is already at position 1, // mining 7 more blocks means he will RBF his sweeping tx twice. - for i := 1; i < 7; i++ { + for i := 1; i < 8; i++ { // Mine an empty block to trigger the possible RBF attempts. ht.MineEmptyBlocks(1) diff --git a/lntest/harness_assertion.go b/lntest/harness_assertion.go index f8c1c9716cb..81dc26e4196 100644 --- a/lntest/harness_assertion.go +++ b/lntest/harness_assertion.go @@ -2633,3 +2633,36 @@ func (h *HarnessTest) FindSweepingTxns(txns []*wire.MsgTx, return sweepTxns } + +// AssertForceCloseAndAnchorTxnsInMempool asserts that the force close and +// anchor sweep txns are found in the mempool and returns the force close tx +// and the anchor sweep tx. +func (h *HarnessTest) AssertForceCloseAndAnchorTxnsInMempool() (*wire.MsgTx, + *wire.MsgTx) { + + // Assert there are two txns in the mempool. + txns := h.Miner.GetNumTxsFromMempool(2) + + // Assume the first is the force close tx. + forceCloseTx, anchorSweepTx := txns[0], txns[1] + + // Get the txid. + closeTxid := forceCloseTx.TxHash() + + // We now check whether there is an anchor input used in the assumed + // anchorSweepTx by checking every input's previous outpoint against + // the assumed closingTxid. If we fail to find one, it means the first + // item from the above txns is the anchor sweeping tx. + for _, inp := range anchorSweepTx.TxIn { + if inp.PreviousOutPoint.Hash == closeTxid { + // Found a match, this is indeed the anchor sweeping tx + // so we return it here. + return forceCloseTx, anchorSweepTx + } + } + + // The assumed order is incorrect so we swap and return. + forceCloseTx, anchorSweepTx = anchorSweepTx, forceCloseTx + + return forceCloseTx, anchorSweepTx +} From 1378022093c640175333f2aa4a26677ae12a217e Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Wed, 22 May 2024 16:51:36 +0800 Subject: [PATCH 17/19] contractcourt+sweep: improve loggings --- contractcourt/channel_arbitrator.go | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/contractcourt/channel_arbitrator.go b/contractcourt/channel_arbitrator.go index 2192f794300..3a6519da10b 100644 --- a/contractcourt/channel_arbitrator.go +++ b/contractcourt/channel_arbitrator.go @@ -1615,8 +1615,8 @@ func (c *ChannelArbitrator) advanceState( for { priorState = c.state log.Debugf("ChannelArbitrator(%v): attempting state step with "+ - "trigger=%v from state=%v", c.cfg.ChanPoint, trigger, - priorState) + "trigger=%v from state=%v at height=%v", + c.cfg.ChanPoint, trigger, priorState, triggerHeight) nextState, closeTx, err := c.stateStep( triggerHeight, trigger, confCommitSet, @@ -2880,16 +2880,18 @@ func (c *ChannelArbitrator) channelAttendant(bestHeight int32) { // We have broadcasted our commitment, and it is now confirmed // on-chain. case closeInfo := <-c.cfg.ChainEvents.LocalUnilateralClosure: - log.Infof("ChannelArbitrator(%v): local on-chain "+ - "channel close", c.cfg.ChanPoint) - if c.state != StateCommitmentBroadcasted { log.Errorf("ChannelArbitrator(%v): unexpected "+ "local on-chain channel close", c.cfg.ChanPoint) } + closeTx := closeInfo.CloseTx + log.Infof("ChannelArbitrator(%v): local force close "+ + "tx=%v confirmed", c.cfg.ChanPoint, + closeTx.TxHash()) + contractRes := &ContractResolutions{ CommitHash: closeTx.TxHash(), CommitResolution: closeInfo.CommitResolution, From 3b5f718a281792cd94e560bca9c1d1fee89bcd87 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Wed, 22 May 2024 17:21:58 +0800 Subject: [PATCH 18/19] x --- contractcourt/channel_arbitrator.go | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/contractcourt/channel_arbitrator.go b/contractcourt/channel_arbitrator.go index 3a6519da10b..1889699b926 100644 --- a/contractcourt/channel_arbitrator.go +++ b/contractcourt/channel_arbitrator.go @@ -2779,13 +2779,20 @@ func (c *ChannelArbitrator) notifyResolvers(height int32) { log.Debugf("Notifying %v resolvers of new block height %v", len(c.activeResolvers), height) - for _, blockChan := range c.activeResolvers { + // notifyHeight is a helper closure that sends the block height to a + // single resolver. + notifyHeight := func(height int32, blockChan chan int32) { select { case blockChan <- height: case <-c.quit: } } + // Notify all resolvers in parallel. + for _, blockChan := range c.activeResolvers { + go notifyHeight(height, blockChan) + } + log.Debugf("Notified %v resolvers of new block height %v", len(c.activeResolvers), height) } From 6ec05ab4c94544cb9bdb497763b6c058ff442e60 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Thu, 23 May 2024 22:44:58 +0800 Subject: [PATCH 19/19] x --- chainio/blockbeat.go | 157 +++++++++++++++++++++------- contractcourt/chain_arbitrator.go | 34 +----- contractcourt/channel_arbitrator.go | 16 ++- 3 files changed, 140 insertions(+), 67 deletions(-) diff --git a/chainio/blockbeat.go b/chainio/blockbeat.go index 07087445f5d..618b7849184 100644 --- a/chainio/blockbeat.go +++ b/chainio/blockbeat.go @@ -52,6 +52,88 @@ func NewBeat(epoch chainntnfs.BlockEpoch) Beat { } } +// NotifySequential takes a list of consumers and notify them about the new +// epoch sequentially. +func (b *Beat) NotifySequential(consumers []Consumer) error { + for _, c := range consumers { + // Construct a new beat with a buffered error chan. + beat := NewBeat(b.Epoch) + + // Record the time it takes the consumer to process this block. + start := time.Now() + + log.Tracef("Sending block %v to consumer: %v", b.Epoch.Height, + c.Name()) + + // We expect the consumer to finish processing this block under + // 30s, otherwise a timeout error is returned. + err, timeout := fn.RecvOrTimeout( + c.ProcessBlock(beat), DefaultProcessBlockTimeout, + ) + if err != nil { + return fmt.Errorf("%s: ProcessBlock got: %w", c.Name(), + err) + } + if timeout != nil { + return fmt.Errorf("%s timed out while processing block", + c.Name()) + } + + log.Debugf("Consumer [%s] processed block %d in %v", c.Name(), + b.Epoch.Height, time.Since(start)) + } + + return nil +} + +// NotifyConcurrent notifies each queue concurrently about the latest block +// epoch. +func (b *Beat) NotifyConcurrent(consumers []Consumer, quit chan struct{}) { + // errChans is a map of channels that will be used to receive errors + // returned from notifying the consumers. + errChans := make(map[string]chan error, len(consumers)) + + // Notify each queue in goroutines. + for _, c := range consumers { + log.Tracef("Sending block %v to consumer: %v", b.Epoch.Height, + c.Name()) + + // Create a signal chan. + errChan := make(chan error) + errChans[c.Name()] = errChan + + // Notify each consumer concurrently. + go func(c Consumer, epoch chainntnfs.BlockEpoch) { + // Construct a new beat with a buffered error chan. + beat := NewBeat(epoch) + + // Notify each consumer in this queue sequentially. + errChan <- beat.NotifySequential([]Consumer{c}) + }(c, b.Epoch) + } + + // Wait for all consumers in each queue to finish. + for name, errChan := range errChans { + select { + case err := <-errChan: + // It's critical that the subsystems can process blocks + // correctly and timely, if an error returns, we'd + // gracefully shutdown lnd to bring attentions. + if err != nil { + log.Criticalf("Consumer=%v failed to process "+ + "block: %v", name, err) + + return + } + + log.Debugf("Notified consumer=%v on block %d", name, + b.Epoch.Height) + + case <-quit: + } + } +} + // BlockBeat is a service that handles dispatching new blocks to `lnd`'s // subsystems. During startup, subsystems that are block-driven should // implement the `Consumer` interface and register themselves via @@ -198,8 +280,11 @@ func (b *BlockBeat) notifyQueues() { defer b.wg.Done() + // Construct a new beat with a buffered error chan. + beat := NewBeat(epoch) + // Notify each consumer in this queue sequentially. - errChan <- b.notifyQueue(c, epoch) + errChan <- beat.NotifySequential(c) }(qid, consumers, b.blockEpoch) } @@ -225,38 +310,38 @@ func (b *BlockBeat) notifyQueues() { } } -// notifyQueue takes a list of consumers and notify them about the new epoch -// sequentially. -func (b *BlockBeat) notifyQueue(queue []Consumer, - epoch chainntnfs.BlockEpoch) error { - - for _, c := range queue { - log.Debugf("Notifying consumer [%s] on block %d", c.Name(), - epoch.Height) - - // Construct a new beat with a buffered error chan. - beat := NewBeat(epoch) - - // Record the time it takes the consumer to process this block. - start := time.Now() - - // We expect the consumer to finish processing this block under - // 30s, otherwise a timeout error is returned. - err, timeout := fn.RecvOrTimeout( - c.ProcessBlock(beat), DefaultProcessBlockTimeout, - ) - if err != nil { - return fmt.Errorf("%s: ProcessBlock got: %w", c.Name(), - err) - } - if timeout != nil { - return fmt.Errorf("%s timed out while processing block", - c.Name()) - } - - log.Debugf("Consumer [%s] processed block %d in %v", c.Name(), - epoch.Height, time.Since(start)) - } - - return nil -} +// // notifyQueue takes a list of consumers and notify them about the new epoch +// // sequentially. +// func (b *BlockBeat) notifyQueue(queue []Consumer, +// epoch chainntnfs.BlockEpoch) error { + +// for _, c := range queue { +// log.Debugf("Notifying consumer [%s] on block %d", c.Name(), +// epoch.Height) + +// // Construct a new beat with a buffered error chan. +// beat := NewBeat(epoch) + +// // Record the time it takes the consumer to process this block. +// start := time.Now() + +// // We expect the consumer to finish processing this block under +// // 30s, otherwise a timeout error is returned. +// err, timeout := fn.RecvOrTimeout( +// c.ProcessBlock(beat), DefaultProcessBlockTimeout, +// ) +// if err != nil { +// return fmt.Errorf("%s: ProcessBlock got: %w", c.Name(), +// err) +// } +// if timeout != nil { +// return fmt.Errorf("%s timed out while processing block", +// c.Name()) +// } + +// log.Debugf("Consumer [%s] processed block %d in %v", c.Name(), +// epoch.Height, time.Since(start)) +// } + +// return nil +// } diff --git a/contractcourt/chain_arbitrator.go b/contractcourt/chain_arbitrator.go index c4951aa6478..03cde88b8fc 100644 --- a/contractcourt/chain_arbitrator.go +++ b/contractcourt/chain_arbitrator.go @@ -801,9 +801,7 @@ func (c *ChainArbitrator) sendBlockAndWait(beat chainio.Beat) { c.Lock() // Create a map to record active channel arbitrator. - channels := make( - map[wire.OutPoint]*ChannelArbitrator, len(c.activeChannels), - ) + channels := make([]chainio.Consumer, 0, len(c.activeChannels)) // Create a map of go chans to store the done signals. doneChans := make( @@ -812,15 +810,17 @@ func (c *ChainArbitrator) sendBlockAndWait(beat chainio.Beat) { // Copy the active channels to the map. for op, channel := range c.activeChannels { - channels[op] = channel + channels = append(channels, channel) doneChans[op] = make(chan struct{}) } c.Unlock() + beat.NotifyConsumers(channels) + // Iterate all the copied channels and send the blockbeat to them. + for _, channel := range channels { - beat := chainio.NewBeat(beat.Epoch) // Deliver the block to the channel arbitrator. go func(ch *ChannelArbitrator, beat chainio.Beat) { @@ -846,30 +846,6 @@ func (c *ChainArbitrator) sendBlockAndWait(beat chainio.Beat) { } } -// waitForChanArbProcessBlock waits for the channel arbitrator to process the -// block. It will log an error if there's an error returned from the channel -// arbitrator or the processing timed out. -func (c *ChainArbitrator) waitForChanArbProcessBlock(chanArb *ChannelArbitrator, - beat chainio.Beat) { - - log.Debugf("Sending block=%d to ChannelArbitrator(%v)", - beat.Epoch.Height, chanArb.cfg.ChanPoint) - - // We expect the channel arbitrator to finish processing this block - // under 30s, otherwise a timeout error is returned. - err, timeout := fn.RecvOrTimeout( - chanArb.processBlock(beat), chainio.DefaultProcessBlockTimeout, - ) - if err != nil { - log.Errorf("ChannelArbitrator(%v): process block got: %v", - chanArb.cfg.ChanPoint, err) - } - if timeout != nil { - log.Errorf("ChannelArbitrator(%v): process block timeout: %v", - chanArb.cfg.ChanPoint, err) - } -} - // republishClosingTxs will load any stored cooperative or unilateral closing // transactions and republish them. This helps ensure propagation of the // transactions in the event that prior publications failed. diff --git a/contractcourt/channel_arbitrator.go b/contractcourt/channel_arbitrator.go index 1889699b926..e3547f8f344 100644 --- a/contractcourt/channel_arbitrator.go +++ b/contractcourt/channel_arbitrator.go @@ -413,6 +413,9 @@ func NewChannelArbitrator(cfg ChannelArbitratorConfig, } } +// Compile-time check for the WebFeeService interface. +var _ chainio.Consumer = (*ChannelArbitrator)(nil) + // chanArbStartState contains the information from disk that we need to start // up a channel arbitrator. type chanArbStartState struct { @@ -3194,9 +3197,11 @@ func (c *ChannelArbitrator) handleBlockbeat(beat chainio.Beat) error { return nil } -// processBlock sends the specified blockbeat to the channel arbitrator's inner +// ProcessBlock sends the specified blockbeat to the channel arbitrator's inner // loop for processing. -func (c *ChannelArbitrator) processBlock(beat chainio.Beat) <-chan error { +// +// NOTE: Part of chainio.Consumer interface. +func (c *ChannelArbitrator) ProcessBlock(beat chainio.Beat) <-chan error { select { case c.blockBeatChan <- beat: log.Debugf("Received block beat for height=%d", @@ -3209,6 +3214,13 @@ func (c *ChannelArbitrator) processBlock(beat chainio.Beat) <-chan error { return beat.Err } +// Name returns a human-readable string for this subsystem. +// +// NOTE: Part of chainio.Consumer interface. +func (c *ChannelArbitrator) Name() string { + return fmt.Sprint("ChannelArbitrator(%v)", c.cfg.ChanPoint) +} + // checkLegacyBreach returns StateFullyResolved if the channel was closed with // a breach transaction before the channel arbitrator launched its own breach // resolver. StateContractClosed is returned if this is a modern breach close