diff --git a/chainio/blockbeat.go b/chainio/blockbeat.go new file mode 100644 index 00000000000..618b7849184 --- /dev/null +++ b/chainio/blockbeat.go @@ -0,0 +1,347 @@ +package chainio + +import ( + "fmt" + "sync" + "sync/atomic" + "time" + + "github.com/lightningnetwork/lnd/chainntnfs" + "github.com/lightningnetwork/lnd/fn" +) + +// DefaultProcessBlockTimeout is the timeout value used when waiting for one +// consumer to finish processing the new block epoch. +var DefaultProcessBlockTimeout = 30 * time.Second + +// Consumer defines a blockbeat consumer interface. Subsystems that need block +// info should implement it. +type Consumer interface { + // Name returns a human-readable string for this subsystem. + Name() string + + // ProcessBlock takes a beat and processes it. A receive-only error + // chan must be returned. + // + // NOTE: When implementing this, it's very important to send back the + // error or nil to the channel immediately, otherwise BlockBeat will + // timeout and lnd will shutdown. + ProcessBlock(b Beat) <-chan error +} + +// Beat contains the block epoch and a buffer error chan. +// +// TODO(yy): extend this to check for confirmation status - which serves as the +// single source of truth, to avoid the potential race between receiving blocks +// and `GetTransactionDetails/RegisterSpendNtfn/RegisterConfirmationsNtfn`. +type Beat struct { + // Epoch is the current block epoch the blockbeat is aware of. + Epoch chainntnfs.BlockEpoch + + // Err is a buffered chan that receives an error or nil from + // ProcessBlock. + Err chan error +} + +// NewBeat creates a new beat with the specified block epoch and a buffered +// error chan. +func NewBeat(epoch chainntnfs.BlockEpoch) Beat { + return Beat{ + Epoch: epoch, + Err: make(chan error, 1), + } +} + +// NotifySequential takes a list of consumers and notify them about the new +// epoch sequentially. 
+func (b *Beat) NotifySequential(consumers []Consumer) error { + for _, c := range consumers { + // Construct a new beat with a buffered error chan. + beat := NewBeat(b.Epoch) + + // Record the time it takes the consumer to process this block. + start := time.Now() + + log.Tracef("Sending block %v to consumer: %v", b.Epoch.Height, + c.Name()) + + // We expect the consumer to finish processing this block under + // 30s, otherwise a timeout error is returned. + err, timeout := fn.RecvOrTimeout( + c.ProcessBlock(beat), DefaultProcessBlockTimeout, + ) + if err != nil { + return fmt.Errorf("%s: ProcessBlock got: %w", c.Name(), + err) + } + if timeout != nil { + return fmt.Errorf("%s timed out while processing block", + c.Name()) + } + + log.Debugf("Consumer [%s] processed block %d in %v", c.Name(), + b.Epoch.Height, time.Since(start)) + } + + return nil +} + +// NotifyConcurrent notifies each queue concurrently about the latest block +// epoch. +func (b *Beat) NotifyConcurrent(consumers []Consumer, quit chan struct{}) { + // errChans is a map of channels that will be used to receive errors + // returned from notifying the consumers. + errChans := make(map[string]chan error, len(consumers)) + + // Notify each queue in goroutines. + for _, c := range consumers { + log.Tracef("Sending block %v to consumer: %v", b.Epoch.Height, + c.Name()) + + // Create a signal chan. + errChan := make(chan error) + errChans[c.Name()] = errChan + + // Notify each consumer concurrently. + go func(c Consumer, epoch chainntnfs.BlockEpoch) { + // Construct a new beat with a buffered error chan. + beat := NewBeat(epoch) + + // Notify each consumer in this queue sequentially. + errChan <- beat.NotifySequential([]Consumer{c}) + }(c, b.Epoch) + } + + // Wait for all consumers in each queue to finish. 
+	for name, errChan := range errChans {
+		select {
+		case err := <-errChan:
+			// It's critical that the subsystems can process blocks
+			// correctly and timely, if an error returns, we'd
+			// gracefully shutdown lnd to bring attention.
+			if err != nil {
+				log.Criticalf("Consumer=%v failed to process "+
+					"block: %v", name, err)
+
+				return
+			}
+
+			log.Debugf("Notified consumer=%v on block %d", name,
+				b.Epoch.Height)
+
+		case <-quit:
+		}
+	}
+}
+
+// BlockBeat is a service that handles dispatching new blocks to `lnd`'s
+// subsystems. During startup, subsystems that are block-driven should
+// implement the `Consumer` interface and register themselves via
+// `RegisterQueue`. When two subsystems are independent of each other, they
+// should be registered in different queues so blocks are notified concurrently.
+// Otherwise, when living in the same queue, the subsystems are notified of the
+// new blocks sequentially, which means it's critical to understand the
+// relationship of these systems to properly handle the order.
+type BlockBeat struct {
+	wg sync.WaitGroup
+
+	// notifier is used to receive new block epochs.
+	notifier chainntnfs.ChainNotifier
+
+	// blockEpoch is the latest block epoch received.
+	blockEpoch chainntnfs.BlockEpoch
+
+	// consumerQueues is a map of consumers that will receive blocks. Each
+	// queue is notified concurrently, and consumers in the same queue are
+	// notified sequentially.
+	consumerQueues map[uint32][]Consumer
+
+	// counter is used to assign a unique id to each queue.
+	counter atomic.Uint32
+
+	// quit is used to signal the BlockBeat to stop.
+	quit chan struct{}
+}
+
+// NewBlockBeat returns a new blockbeat instance.
+func NewBlockBeat(notifier chainntnfs.ChainNotifier) *BlockBeat {
+	return &BlockBeat{
+		notifier:       notifier,
+		quit:           make(chan struct{}),
+		consumerQueues: make(map[uint32][]Consumer),
+	}
+}
+
+// RegisterQueue takes a list of consumers and registers them in the same queue.
+// +// NOTE: these consumers are notified sequentially. +func (b *BlockBeat) RegisterQueue(consumers []Consumer) { + qid := b.counter.Add(1) + + b.consumerQueues[qid] = append(b.consumerQueues[qid], consumers...) + log.Infof("Registered queue=%d with %d blockbeat consumers", qid, + len(consumers)) + + for _, c := range consumers { + log.Debugf("Consumer [%s] registered in queue %d", c.Name(), + qid) + } +} + +// Start starts the blockbeat - it registers a block notification and monitors +// and dispatches new blocks in a goroutine. It will refuse to start if there +// are no registered consumers. +func (b *BlockBeat) Start() error { + // Make sure consumers are registered. + if len(b.consumerQueues) == 0 { + return fmt.Errorf("no consumers registered") + } + + // Start listening to new block epochs. + blockEpochs, err := b.notifier.RegisterBlockEpochNtfn(nil) + if err != nil { + return fmt.Errorf("register block epoch ntfn: %w", err) + } + + log.Infof("BlockBeat is starting with %d consumer queues", + len(b.consumerQueues)) + defer log.Debug("BlockBeat started") + + b.wg.Add(1) + go b.dispatchBlocks(blockEpochs) + + return nil +} + +// Stop shuts down the blockbeat. +func (b *BlockBeat) Stop() { + log.Info("BlockBeat is stopping") + defer log.Debug("BlockBeat stopped") + + // Signal the dispatchBlocks goroutine to stop. + close(b.quit) + b.wg.Wait() +} + +// dispatchBlocks listens to new block epoch and dispatches it to all the +// consumers. Each queue in BlockBeat is notified concurrently, and the +// consumers in the same queue are notified sequentially. 
+func (b *BlockBeat) dispatchBlocks(blockEpochs *chainntnfs.BlockEpochEvent) {
+	defer b.wg.Done()
+	defer blockEpochs.Cancel()
+
+	for {
+		select {
+		case blockEpoch, ok := <-blockEpochs.Epochs:
+			if !ok {
+				log.Debugf("Block epoch channel closed")
+				return
+			}
+
+			log.Infof("Received new block %v at height %d, "+
+				"notifying consumers...", blockEpoch.Hash,
+				blockEpoch.Height)
+
+			// Update the current block epoch.
+			b.blockEpoch = *blockEpoch
+
+			// Notify all consumers.
+			b.notifyQueues()
+
+			log.Infof("Notified all consumers on block %v at "+
+				"height %d", blockEpoch.Hash, blockEpoch.Height)
+
+		case <-b.quit:
+			log.Debugf("BlockBeat quit signal received")
+			return
+		}
+	}
+}
+
+// notifyQueues notifies each queue concurrently about the latest block epoch.
+func (b *BlockBeat) notifyQueues() {
+	// errChans is a map of channels that will be used to receive errors
+	// returned from notifying the consumers.
+	errChans := make(map[uint32]chan error, len(b.consumerQueues))
+
+	// Notify each queue in goroutines.
+	for qid, consumers := range b.consumerQueues {
+		log.Debugf("Notifying queue=%d on block %d", qid,
+			b.blockEpoch.Height)
+
+		// Create a signal chan.
+		errChan := make(chan error)
+		errChans[qid] = errChan
+
+		// Notify each queue concurrently.
+		b.wg.Add(1)
+		go func(qid uint32, c []Consumer,
+			epoch chainntnfs.BlockEpoch) {
+
+			defer b.wg.Done()
+
+			// Construct a new beat with a buffered error chan.
+			beat := NewBeat(epoch)
+
+			// Notify each consumer in this queue sequentially.
+			errChan <- beat.NotifySequential(c)
+		}(qid, consumers, b.blockEpoch)
+	}
+
+	// Wait for all consumers in each queue to finish.
+	for qid, errChan := range errChans {
+		select {
+		case err := <-errChan:
+			// It's critical that the subsystems can process blocks
+			// correctly and timely, if an error returns, we'd
+			// gracefully shutdown lnd to bring attention.
+ if err != nil { + log.Criticalf("Queue=%d failed to process "+ + "block: %v", qid, err) + + return + } + + log.Debugf("Notified queue=%d on block %d", qid, + b.blockEpoch.Height) + + case <-b.quit: + } + } +} + +// // notifyQueue takes a list of consumers and notify them about the new epoch +// // sequentially. +// func (b *BlockBeat) notifyQueue(queue []Consumer, +// epoch chainntnfs.BlockEpoch) error { + +// for _, c := range queue { +// log.Debugf("Notifying consumer [%s] on block %d", c.Name(), +// epoch.Height) + +// // Construct a new beat with a buffered error chan. +// beat := NewBeat(epoch) + +// // Record the time it takes the consumer to process this block. +// start := time.Now() + +// // We expect the consumer to finish processing this block under +// // 30s, otherwise a timeout error is returned. +// err, timeout := fn.RecvOrTimeout( +// c.ProcessBlock(beat), DefaultProcessBlockTimeout, +// ) +// if err != nil { +// return fmt.Errorf("%s: ProcessBlock got: %w", c.Name(), +// err) +// } +// if timeout != nil { +// return fmt.Errorf("%s timed out while processing block", +// c.Name()) +// } + +// log.Debugf("Consumer [%s] processed block %d in %v", c.Name(), +// epoch.Height, time.Since(start)) +// } + +// return nil +// } diff --git a/chainio/blockbeat_test.go b/chainio/blockbeat_test.go new file mode 100644 index 00000000000..8d034eda785 --- /dev/null +++ b/chainio/blockbeat_test.go @@ -0,0 +1 @@ +package chainio diff --git a/chainio/log.go b/chainio/log.go new file mode 100644 index 00000000000..fb562abdeb8 --- /dev/null +++ b/chainio/log.go @@ -0,0 +1,29 @@ +package chainio + +import ( + "github.com/btcsuite/btclog" + "github.com/lightningnetwork/lnd/build" +) + +// log is a logger that is initialized with no output filters. This +// means the package will not perform any logging by default until the caller +// requests it. +var log btclog.Logger + +// The default amount of logging is none. 
+func init() { + UseLogger(build.NewSubLogger("CHIO", nil)) +} + +// DisableLog disables all library log output. Logging output is disabled +// by default until UseLogger is called. +func DisableLog() { + UseLogger(btclog.Disabled) +} + +// UseLogger uses a specified Logger to output package logging info. +// This should be used in preference to SetLogWriter if the caller is also +// using btclog. +func UseLogger(logger btclog.Logger) { + log = logger +} diff --git a/chainntnfs/best_block_view.go b/chainntnfs/best_block_view.go index c043e68e7e0..ba9295e0530 100644 --- a/chainntnfs/best_block_view.go +++ b/chainntnfs/best_block_view.go @@ -68,7 +68,7 @@ func (t *BestBlockTracker) BestBlockHeader() (*wire.BlockHeader, error) { return nil, errors.New("best block header not yet known") } - return epoch.BlockHeader, nil + return &epoch.Block.Header, nil } // updateLoop is a helper that subscribes to the underlying BlockEpochEvent diff --git a/chainntnfs/best_block_view_test.go b/chainntnfs/best_block_view_test.go index 2a55bc8b032..338a8630aad 100644 --- a/chainntnfs/best_block_view_test.go +++ b/chainntnfs/best_block_view_test.go @@ -26,14 +26,15 @@ func (blockEpoch) Generate(r *rand.Rand, size int) reflect.Value { return reflect.ValueOf(blockEpoch(chainntnfs.BlockEpoch{ Hash: &chainHash, Height: r.Int31n(1000000), - BlockHeader: &wire.BlockHeader{ - Version: 2, - PrevBlock: prevBlockHash, - MerkleRoot: merkleRootHash, - Timestamp: time.Now(), - Bits: r.Uint32(), - Nonce: r.Uint32(), - }, + Block: wire.MsgBlock{ + Header: wire.BlockHeader{ + Version: 2, + PrevBlock: prevBlockHash, + MerkleRoot: merkleRootHash, + Timestamp: time.Now(), + Bits: r.Uint32(), + Nonce: r.Uint32(), + }}, })) } @@ -77,7 +78,7 @@ func TestBestBlockTracker(t *testing.T) { header, _ := tracker.BestBlockHeader() return height == uint32(epoch.Height) && - header == epoch.BlockHeader + *header == epoch.Block.Header } idempotence := func(epochRand blockEpoch) bool { epoch := 
chainntnfs.BlockEpoch(epochRand) diff --git a/chainntnfs/bitcoindnotify/bitcoind.go b/chainntnfs/bitcoindnotify/bitcoind.go index 2bffefdbefd..8addd799367 100644 --- a/chainntnfs/bitcoindnotify/bitcoind.go +++ b/chainntnfs/bitcoindnotify/bitcoind.go @@ -178,7 +178,8 @@ func (b *BitcoindNotifier) startNotifier() error { if err != nil { return err } - blockHeader, err := b.chainConn.GetBlockHeader(currentHash) + + block, err := b.GetBlock(currentHash) if err != nil { return err } @@ -189,9 +190,9 @@ func (b *BitcoindNotifier) startNotifier() error { ) b.bestBlock = chainntnfs.BlockEpoch{ - Height: currentHeight, - Hash: currentHash, - BlockHeader: blockHeader, + Height: currentHeight, + Hash: currentHash, + Block: *block, } b.wg.Add(1) @@ -350,9 +351,7 @@ out: // a notification for the current tip. if msg.bestBlock == nil { b.notifyBlockEpochClient( - msg, b.bestBlock.Height, - b.bestBlock.Hash, - b.bestBlock.BlockHeader, + msg, b.bestBlock, ) msg.errorChan <- nil @@ -372,10 +371,7 @@ out: } for _, block := range missedBlocks { - b.notifyBlockEpochClient( - msg, block.Height, block.Hash, - block.BlockHeader, - ) + b.notifyBlockEpochClient(msg, block) } msg.errorChan <- nil @@ -384,15 +380,14 @@ out: case ntfn := <-b.chainConn.Notifications(): switch item := ntfn.(type) { case chain.BlockConnected: - blockHeader, err := - b.chainConn.GetBlockHeader(&item.Hash) + block, err := b.GetBlock(&item.Hash) if err != nil { chainntnfs.Log.Errorf("Unable to fetch "+ "block header: %v", err) continue } - if blockHeader.PrevBlock != *b.bestBlock.Hash { + if block.Header.PrevBlock != *b.bestBlock.Hash { // Handle the case where the notifier // missed some blocks from its chain // backend. 
@@ -424,9 +419,9 @@ out: } newBlock := chainntnfs.BlockEpoch{ - Height: item.Height, - Hash: &item.Hash, - BlockHeader: blockHeader, + Height: item.Height, + Hash: &item.Hash, + Block: *block, } if err := b.handleBlockConnected(newBlock); err != nil { chainntnfs.Log.Error(err) @@ -631,61 +626,49 @@ func (b *BitcoindNotifier) confDetailsManually(confRequest chainntnfs.ConfReques // handleBlockConnected applies a chain update for a new block. Any watched // transactions included this block will processed to either send notifications // now or after numConfirmations confs. -func (b *BitcoindNotifier) handleBlockConnected(block chainntnfs.BlockEpoch) error { - // First, we'll fetch the raw block as we'll need to gather all the - // transactions to determine whether any are relevant to our registered - // clients. - rawBlock, err := b.GetBlock(block.Hash) - if err != nil { - return fmt.Errorf("unable to get block: %w", err) - } - utilBlock := btcutil.NewBlock(rawBlock) +func (b *BitcoindNotifier) handleBlockConnected( + epoch chainntnfs.BlockEpoch) error { + + utilBlock := btcutil.NewBlock(&epoch.Block) // We'll then extend the txNotifier's height with the information of // this new block, which will handle all of the notification logic for // us. - err = b.txNotifier.ConnectTip(utilBlock, uint32(block.Height)) + err := b.txNotifier.ConnectTip(utilBlock, uint32(epoch.Height)) if err != nil { return fmt.Errorf("unable to connect tip: %w", err) } - chainntnfs.Log.Infof("New block: height=%v, sha=%v", block.Height, - block.Hash) + chainntnfs.Log.Infof("New block: height=%v, sha=%v", epoch.Height, + epoch.Hash) // Now that we've guaranteed the new block extends the txNotifier's // current tip, we'll proceed to dispatch notifications to all of our // registered clients whom have had notifications fulfilled. Before // doing so, we'll make sure update our in memory state in order to // satisfy any client requests based upon the new block. 
- b.bestBlock = block + b.bestBlock = epoch + + b.notifyBlockEpochs(epoch) - b.notifyBlockEpochs(block.Height, block.Hash, block.BlockHeader) - return b.txNotifier.NotifyHeight(uint32(block.Height)) + return b.txNotifier.NotifyHeight(uint32(epoch.Height)) } // notifyBlockEpochs notifies all registered block epoch clients of the newly // connected block to the main chain. -func (b *BitcoindNotifier) notifyBlockEpochs(newHeight int32, newSha *chainhash.Hash, - blockHeader *wire.BlockHeader) { - +func (b *BitcoindNotifier) notifyBlockEpochs(block chainntnfs.BlockEpoch) { for _, client := range b.blockEpochClients { - b.notifyBlockEpochClient(client, newHeight, newSha, blockHeader) + b.notifyBlockEpochClient(client, block) } } // notifyBlockEpochClient sends a registered block epoch client a notification // about a specific block. -func (b *BitcoindNotifier) notifyBlockEpochClient(epochClient *blockEpochRegistration, - height int32, sha *chainhash.Hash, header *wire.BlockHeader) { - - epoch := &chainntnfs.BlockEpoch{ - Height: height, - Hash: sha, - BlockHeader: header, - } +func (b *BitcoindNotifier) notifyBlockEpochClient( + epochClient *blockEpochRegistration, epoch chainntnfs.BlockEpoch) { select { - case epochClient.epochQueue.ChanIn() <- epoch: + case epochClient.epochQueue.ChanIn() <- &epoch: case <-epochClient.cancelChan: case <-b.quit: } diff --git a/chainntnfs/btcdnotify/btcd.go b/chainntnfs/btcdnotify/btcd.go index e865426e9af..228169018c1 100644 --- a/chainntnfs/btcdnotify/btcd.go +++ b/chainntnfs/btcdnotify/btcd.go @@ -231,7 +231,7 @@ func (b *BtcdNotifier) startNotifier() error { return err } - bestBlock, err := b.chainConn.GetBlock(currentHash) + bestBlock, err := b.GetBlock(currentHash) if err != nil { b.txUpdates.Stop() b.chainUpdates.Stop() @@ -244,9 +244,9 @@ func (b *BtcdNotifier) startNotifier() error { ) b.bestBlock = chainntnfs.BlockEpoch{ - Height: currentHeight, - Hash: currentHash, - BlockHeader: &bestBlock.Header, + Height: currentHeight, + 
Hash: currentHash, + Block: *bestBlock, } if err := b.chainConn.NotifyBlocks(); err != nil { @@ -409,9 +409,7 @@ out: // a notification for the current tip. if msg.bestBlock == nil { b.notifyBlockEpochClient( - msg, b.bestBlock.Height, - b.bestBlock.Hash, - b.bestBlock.BlockHeader, + msg, b.bestBlock, ) msg.errorChan <- nil @@ -431,10 +429,7 @@ out: } for _, block := range missedBlocks { - b.notifyBlockEpochClient( - msg, block.Height, block.Hash, - block.BlockHeader, - ) + b.notifyBlockEpochClient(msg, block) } msg.errorChan <- nil @@ -443,16 +438,14 @@ out: case item := <-b.chainUpdates.ChanOut(): update := item.(*chainUpdate) if update.connect { - blockHeader, err := b.chainConn.GetBlockHeader( - update.blockHash, - ) + block, err := b.GetBlock(update.blockHash) if err != nil { chainntnfs.Log.Errorf("Unable to fetch "+ "block header: %v", err) continue } - if blockHeader.PrevBlock != *b.bestBlock.Hash { + if block.Header.PrevBlock != *b.bestBlock.Hash { // Handle the case where the notifier // missed some blocks from its chain // backend @@ -484,9 +477,9 @@ out: } newBlock := chainntnfs.BlockEpoch{ - Height: update.blockHeight, - Hash: update.blockHash, - BlockHeader: blockHeader, + Height: update.blockHeight, + Hash: update.blockHash, + Block: *block, } if err := b.handleBlockConnected(newBlock); err != nil { chainntnfs.Log.Error(err) @@ -730,38 +723,26 @@ func (b *BtcdNotifier) handleBlockConnected(epoch chainntnfs.BlockEpoch) error { // satisfy any client requests based upon the new block. b.bestBlock = epoch - b.notifyBlockEpochs( - epoch.Height, epoch.Hash, epoch.BlockHeader, - ) + b.notifyBlockEpochs(epoch) return b.txNotifier.NotifyHeight(uint32(epoch.Height)) } // notifyBlockEpochs notifies all registered block epoch clients of the newly // connected block to the main chain. 
-func (b *BtcdNotifier) notifyBlockEpochs(newHeight int32, - newSha *chainhash.Hash, blockHeader *wire.BlockHeader) { - +func (b *BtcdNotifier) notifyBlockEpochs(epoch chainntnfs.BlockEpoch) { for _, client := range b.blockEpochClients { - b.notifyBlockEpochClient( - client, newHeight, newSha, blockHeader, - ) + b.notifyBlockEpochClient(client, epoch) } } // notifyBlockEpochClient sends a registered block epoch client a notification // about a specific block. -func (b *BtcdNotifier) notifyBlockEpochClient(epochClient *blockEpochRegistration, - height int32, sha *chainhash.Hash, blockHeader *wire.BlockHeader) { - - epoch := &chainntnfs.BlockEpoch{ - Height: height, - Hash: sha, - BlockHeader: blockHeader, - } +func (b *BtcdNotifier) notifyBlockEpochClient( + epochClient *blockEpochRegistration, epoch chainntnfs.BlockEpoch) { select { - case epochClient.epochQueue.ChanIn() <- epoch: + case epochClient.epochQueue.ChanIn() <- &epoch: case <-epochClient.cancelChan: case <-b.quit: } diff --git a/chainntnfs/interface.go b/chainntnfs/interface.go index 3337f1451a6..282ac2fcbdc 100644 --- a/chainntnfs/interface.go +++ b/chainntnfs/interface.go @@ -366,8 +366,8 @@ type BlockEpoch struct { // the main chain. Height int32 - // BlockHeader is the block header of this new height. - BlockHeader *wire.BlockHeader + // Block is the full block. + Block wire.MsgBlock } // BlockEpochEvent encapsulates an on-going stream of block epoch @@ -471,6 +471,9 @@ type ChainConn interface { // GetBlockHash returns the hash from a block height. GetBlockHash(blockHeight int64) (*chainhash.Hash, error) + + // GetBlock returns a block from the hash. 
+ GetBlock(hash *chainhash.Hash) (*wire.MsgBlock, error) } // GetCommonBlockAncestorHeight takes in: @@ -555,9 +558,9 @@ func RewindChain(chainConn ChainConn, txNotifier *TxNotifier, currBestBlock BlockEpoch, targetHeight int32) (BlockEpoch, error) { newBestBlock := BlockEpoch{ - Height: currBestBlock.Height, - Hash: currBestBlock.Hash, - BlockHeader: currBestBlock.BlockHeader, + Height: currBestBlock.Height, + Hash: currBestBlock.Hash, + Block: currBestBlock.Block, } for height := currBestBlock.Height; height > targetHeight; height-- { @@ -567,7 +570,7 @@ func RewindChain(chainConn ChainConn, txNotifier *TxNotifier, "find blockhash for disconnected height=%d: %v", height, err) } - header, err := chainConn.GetBlockHeader(hash) + block, err := chainConn.GetBlock(hash) if err != nil { return newBestBlock, fmt.Errorf("unable to get block "+ "header for height=%v", height-1) @@ -584,7 +587,7 @@ func RewindChain(chainConn ChainConn, txNotifier *TxNotifier, } newBestBlock.Height = height - 1 newBestBlock.Hash = hash - newBestBlock.BlockHeader = header + newBestBlock.Block = *block } return newBestBlock, nil @@ -665,7 +668,7 @@ func getMissedBlocks(chainConn ChainConn, startingHeight, return nil, fmt.Errorf("unable to find blockhash for "+ "height=%d: %v", height, err) } - header, err := chainConn.GetBlockHeader(hash) + block, err := chainConn.GetBlock(hash) if err != nil { return nil, fmt.Errorf("unable to find block header "+ "for height=%d: %v", height, err) @@ -674,9 +677,9 @@ func getMissedBlocks(chainConn ChainConn, startingHeight, missedBlocks = append( missedBlocks, BlockEpoch{ - Hash: hash, - Height: height, - BlockHeader: header, + Hash: hash, + Height: height, + Block: *block, }, ) } diff --git a/chainntnfs/neutrinonotify/neutrino.go b/chainntnfs/neutrinonotify/neutrino.go index 80e210c2fd8..500c9481ca5 100644 --- a/chainntnfs/neutrinonotify/neutrino.go +++ b/chainntnfs/neutrinonotify/neutrino.go @@ -177,9 +177,7 @@ func (n *NeutrinoNotifier) startNotifier() 
error { n.txUpdates.Stop() return err } - startingHeader, err := n.p2pNode.GetBlockHeader( - &startingPoint.Hash, - ) + startingBlock, err := n.p2pNode.GetBlock(startingPoint.Hash) if err != nil { n.txUpdates.Stop() return err @@ -187,7 +185,7 @@ func (n *NeutrinoNotifier) startNotifier() error { n.bestBlock.Hash = &startingPoint.Hash n.bestBlock.Height = startingPoint.Height - n.bestBlock.BlockHeader = startingHeader + n.bestBlock.Block = *startingBlock.MsgBlock() n.txNotifier = chainntnfs.NewTxNotifier( uint32(n.bestBlock.Height), chainntnfs.ReorgSafetyLimit, @@ -472,9 +470,7 @@ func (n *NeutrinoNotifier) notificationDispatcher() { // a notification for the current tip. if msg.bestBlock == nil { n.notifyBlockEpochClient( - msg, n.bestBlock.Height, - n.bestBlock.Hash, - n.bestBlock.BlockHeader, + msg, n.bestBlock, ) msg.errorChan <- nil @@ -498,10 +494,7 @@ func (n *NeutrinoNotifier) notificationDispatcher() { } for _, block := range missedBlocks { - n.notifyBlockEpochClient( - msg, block.Height, block.Hash, - block.BlockHeader, - ) + n.notifyBlockEpochClient(msg, block) } msg.errorChan <- nil @@ -682,11 +675,9 @@ func (n *NeutrinoNotifier) handleBlockConnected(newBlock *filteredBlock) error { // satisfy any client requests based upon the new block. n.bestBlock.Hash = &newBlock.hash n.bestBlock.Height = int32(newBlock.height) - n.bestBlock.BlockHeader = newBlock.header + n.bestBlock.Block = *rawBlock.MsgBlock() - n.notifyBlockEpochs( - int32(newBlock.height), &newBlock.hash, newBlock.header, - ) + n.notifyBlockEpochs(n.bestBlock) return n.txNotifier.NotifyHeight(newBlock.height) } @@ -711,27 +702,19 @@ func (n *NeutrinoNotifier) getFilteredBlock(epoch chainntnfs.BlockEpoch) (*filte // notifyBlockEpochs notifies all registered block epoch clients of the newly // connected block to the main chain. 
-func (n *NeutrinoNotifier) notifyBlockEpochs(newHeight int32, newSha *chainhash.Hash, - blockHeader *wire.BlockHeader) { - +func (n *NeutrinoNotifier) notifyBlockEpochs(epoch chainntnfs.BlockEpoch) { for _, client := range n.blockEpochClients { - n.notifyBlockEpochClient(client, newHeight, newSha, blockHeader) + n.notifyBlockEpochClient(client, epoch) } } // notifyBlockEpochClient sends a registered block epoch client a notification // about a specific block. -func (n *NeutrinoNotifier) notifyBlockEpochClient(epochClient *blockEpochRegistration, - height int32, sha *chainhash.Hash, blockHeader *wire.BlockHeader) { - - epoch := &chainntnfs.BlockEpoch{ - Height: height, - Hash: sha, - BlockHeader: blockHeader, - } +func (n *NeutrinoNotifier) notifyBlockEpochClient( + epochClient *blockEpochRegistration, epoch chainntnfs.BlockEpoch) { select { - case epochClient.epochQueue.ChanIn() <- epoch: + case epochClient.epochQueue.ChanIn() <- &epoch: case <-epochClient.cancelChan: case <-n.quit: } @@ -1124,6 +1107,18 @@ type NeutrinoChainConn struct { p2pNode *neutrino.ChainService } +// GetBlock returns the block for a hash. +func (n *NeutrinoChainConn) GetBlock( + blockHash *chainhash.Hash) (*wire.MsgBlock, error) { + + utilBlock, err := n.p2pNode.GetBlock(*blockHash) + if err != nil { + return nil, err + } + + return utilBlock.MsgBlock(), nil +} + // GetBlockHeader returns the block header for a hash. func (n *NeutrinoChainConn) GetBlockHeader(blockHash *chainhash.Hash) (*wire.BlockHeader, error) { return n.p2pNode.GetBlockHeader(blockHash) diff --git a/chainntnfs/test/test_interface.go b/chainntnfs/test/test_interface.go index 35e63a45e98..693314b75cb 100644 --- a/chainntnfs/test/test_interface.go +++ b/chainntnfs/test/test_interface.go @@ -410,17 +410,14 @@ func testBlockEpochNotification(miner *rpctest.Harness, // and that header matches the contained header // hash. 
blockEpoch := <-epochClient.Epochs - if blockEpoch.BlockHeader == nil { - t.Logf("%d", i) - clientErrors <- fmt.Errorf("block " + - "header is nil") - return - } - if blockEpoch.BlockHeader.BlockHash() != - *blockEpoch.Hash { - - clientErrors <- fmt.Errorf("block " + - "header hash mismatch") + t.Logf("%d", i) + + got := blockEpoch.Block.Header.BlockHash() + want := *blockEpoch.Hash + if got != want { + clientErrors <- fmt.Errorf("block "+ + "header hash mismatch: "+ + "want=%v, got=%v", want, got) return } diff --git a/config_builder.go b/config_builder.go index 49a0b6c507b..66bda50f9e9 100644 --- a/config_builder.go +++ b/config_builder.go @@ -629,8 +629,7 @@ func proxyBlockEpoch(notifier chainntnfs.ChainNotifier, go func() { for blk := range blockEpoch.Epochs { ntfn := blockntfns.NewBlockConnected( - *blk.BlockHeader, - uint32(blk.Height), + blk.Block.Header, uint32(blk.Height), ) sub.Notifications <- ntfn diff --git a/contractcourt/anchor_resolver.go b/contractcourt/anchor_resolver.go index b4d6877202e..b80b84033d4 100644 --- a/contractcourt/anchor_resolver.go +++ b/contractcourt/anchor_resolver.go @@ -84,7 +84,7 @@ func (c *anchorResolver) ResolverKey() []byte { } // Resolve offers the anchor output to the sweeper and waits for it to be swept. -func (c *anchorResolver) Resolve(_ bool) (ContractResolver, error) { +func (c *anchorResolver) Resolve(_ <-chan int32) (ContractResolver, error) { // Attempt to update the sweep parameters to the post-confirmation // situation. We don't want to force sweep anymore, because the anchor // lost its special purpose to get the commitment confirmed. It is just diff --git a/contractcourt/breach_resolver.go b/contractcourt/breach_resolver.go index 740b4471d5d..57004b299d0 100644 --- a/contractcourt/breach_resolver.go +++ b/contractcourt/breach_resolver.go @@ -47,7 +47,7 @@ func (b *breachResolver) ResolverKey() []byte { // been broadcast. // // TODO(yy): let sweeper handle the breach inputs. 
-func (b *breachResolver) Resolve(_ bool) (ContractResolver, error) { +func (b *breachResolver) Resolve(_ <-chan int32) (ContractResolver, error) { if !b.subscribed { complete, err := b.SubscribeBreachComplete( &b.ChanPoint, b.replyChan, diff --git a/contractcourt/chain_arbitrator.go b/contractcourt/chain_arbitrator.go index 0cc4b111a5a..03cde88b8fc 100644 --- a/contractcourt/chain_arbitrator.go +++ b/contractcourt/chain_arbitrator.go @@ -11,6 +11,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcwallet/walletdb" + "github.com/lightningnetwork/lnd/chainio" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb/models" @@ -250,6 +251,11 @@ type ChainArbitrator struct { // active channels that it must still watch over. chanSource *channeldb.DB + // blockBeatChan is a channel to receive blocks from BlockBeat. The + // received block contains the best known height and the transactions + // confirmed in this block. + blockBeatChan chan chainio.Beat + quit chan struct{} wg sync.WaitGroup @@ -266,6 +272,7 @@ func NewChainArbitrator(cfg ChainArbitratorConfig, activeWatchers: make(map[wire.OutPoint]*chainWatcher), chanSource: db, quit: make(chan struct{}), + blockBeatChan: make(chan chainio.Beat), } } @@ -745,18 +752,11 @@ func (c *ChainArbitrator) Start() error { } } - // Subscribe to a single stream of block epoch notifications that we - // will dispatch to all active arbitrators. - blockEpoch, err := c.cfg.Notifier.RegisterBlockEpochNtfn(nil) - if err != nil { - return err - } - // Start our goroutine which will dispatch blocks to each arbitrator. 
c.wg.Add(1) go func() { defer c.wg.Done() - c.dispatchBlocks(blockEpoch) + c.dispatchBlocks() }() // TODO(roasbeef): eventually move all breach watching here @@ -764,94 +764,28 @@ func (c *ChainArbitrator) Start() error { return nil } -// blockRecipient contains the information we need to dispatch a block to a -// channel arbitrator. -type blockRecipient struct { - // chanPoint is the funding outpoint of the channel. - chanPoint wire.OutPoint - - // blocks is the channel that new block heights are sent into. This - // channel should be sufficiently buffered as to not block the sender. - blocks chan<- int32 - - // quit is closed if the receiving entity is shutting down. - quit chan struct{} -} - // dispatchBlocks consumes a block epoch notification stream and dispatches // blocks to each of the chain arb's active channel arbitrators. This function // must be run in a goroutine. -func (c *ChainArbitrator) dispatchBlocks( - blockEpoch *chainntnfs.BlockEpochEvent) { - - // getRecipients is a helper function which acquires the chain arb - // lock and returns a set of block recipients which can be used to - // dispatch blocks. - getRecipients := func() []blockRecipient { - c.Lock() - blocks := make([]blockRecipient, 0, len(c.activeChannels)) - for _, channel := range c.activeChannels { - blocks = append(blocks, blockRecipient{ - chanPoint: channel.cfg.ChanPoint, - blocks: channel.blocks, - quit: channel.quit, - }) - } - c.Unlock() - - return blocks - } - - // On exit, cancel our blocks subscription and close each block channel - // so that the arbitrators know they will no longer be receiving blocks. - defer func() { - blockEpoch.Cancel() - - recipients := getRecipients() - for _, recipient := range recipients { - close(recipient.blocks) - } - }() - +func (c *ChainArbitrator) dispatchBlocks() { // Consume block epochs until we receive the instruction to shutdown. for { select { // Consume block epochs, exiting if our subscription is // terminated. 
- case block, ok := <-blockEpoch.Epochs: + case beat, ok := <-c.blockBeatChan: if !ok { log.Trace("dispatchBlocks block epoch " + "cancelled") return } - // Get the set of currently active channels block - // subscription channels and dispatch the block to - // each. - for _, recipient := range getRecipients() { - select { - // Deliver the block to the arbitrator. - case recipient.blocks <- block.Height: - - // If the recipient is shutting down, exit - // without delivering the block. This may be - // the case when two blocks are mined in quick - // succession, and the arbitrator resolves - // after the first block, and does not need to - // consume the second block. - case <-recipient.quit: - log.Debugf("channel: %v exit without "+ - "receiving block: %v", - recipient.chanPoint, - block.Height) - - // If the chain arb is shutting down, we don't - // need to deliver any more blocks (everything - // will be shutting down). - case <-c.quit: - return - } - } + // Send this blockbeat to all the active channels and + // wait for them to finish processing it. + c.sendBlockAndWait(beat) + + // Notify the chain arbitrator has processed the block. + fn.SendOrQuit(beat.Err, nil, c.quit) // Exit if the chain arbitrator is shutting down. case <-c.quit: @@ -860,6 +794,58 @@ func (c *ChainArbitrator) dispatchBlocks( } } +// sendBlockAndWait sends the blockbeat to all active channel arbitrator in +// parallel and wait for them to finish processing it. +func (c *ChainArbitrator) sendBlockAndWait(beat chainio.Beat) { + // Read the active channels in a lock. + c.Lock() + + // Create a map to record active channel arbitrator. + channels := make([]chainio.Consumer, 0, len(c.activeChannels)) + + // Create a map of go chans to store the done signals. + doneChans := make( + map[wire.OutPoint]chan struct{}, len(c.activeChannels), + ) + + // Copy the active channels to the map. 
+ for op, channel := range c.activeChannels { + channels = append(channels, channel) + doneChans[op] = make(chan struct{}) + } + + c.Unlock() + + beat.NotifyConsumers(channels) + + // Iterate all the copied channels and send the blockbeat to them. + + for _, channel := range channels { + + // Deliver the block to the channel arbitrator. + go func(ch *ChannelArbitrator, beat chainio.Beat) { + // Send the block to the arbitrator. + c.waitForChanArbProcessBlock(ch, beat) + + // Signal that the arbitrator has finished processing + // the block. + close(doneChans[ch.cfg.ChanPoint]) + }(channel, beat) + } + + // Wait for all channel arbitrators to process the block. + for op, doneChan := range doneChans { + select { + case <-doneChan: + log.Debugf("ChannelArbitrator(%v): processed block %d", + op, beat.Epoch.Height) + + case <-c.quit: + return + } + } +} + // republishClosingTxs will load any stored cooperative or unilateral closing // transactions and republish them. This helps ensure propagation of the // transactions in the event that prior publications failed. @@ -1320,3 +1306,22 @@ func (c *ChainArbitrator) FindOutgoingHTLCDeadline(scid lnwire.ShortChannelID, // TODO(roasbeef): arbitration reports // * types: contested, waiting for success conf, etc + +// NOTE: part of the `chainio.Consumer` interface. +func (c *ChainArbitrator) ProcessBlock(beat chainio.Beat) <-chan error { + select { + case c.blockBeatChan <- beat: + log.Debugf("Received block beat for height=%d", + beat.Epoch.Height) + + case <-c.quit: + return nil + } + + return beat.Err +} + +// NOTE: part of the `chainio.Consumer` interface. 
+func (c *ChainArbitrator) Name() string { + return "chain arbitrator" +} diff --git a/contractcourt/chain_arbitrator_test.go b/contractcourt/chain_arbitrator_test.go index 36f6dad18ba..866052d8dc4 100644 --- a/contractcourt/chain_arbitrator_test.go +++ b/contractcourt/chain_arbitrator_test.go @@ -80,7 +80,6 @@ func TestChainArbitratorRepublishCloses(t *testing.T) { ChainIO: &mock.ChainIO{}, Notifier: &mock.ChainNotifier{ SpendChan: make(chan *chainntnfs.SpendDetail), - EpochChan: make(chan *chainntnfs.BlockEpoch), ConfChan: make(chan *chainntnfs.TxConfirmation), }, PublishTx: func(tx *wire.MsgTx, _ string) error { @@ -165,7 +164,6 @@ func TestResolveContract(t *testing.T) { ChainIO: &mock.ChainIO{}, Notifier: &mock.ChainNotifier{ SpendChan: make(chan *chainntnfs.SpendDetail), - EpochChan: make(chan *chainntnfs.BlockEpoch), ConfChan: make(chan *chainntnfs.TxConfirmation), }, PublishTx: func(tx *wire.MsgTx, _ string) error { diff --git a/contractcourt/channel_arbitrator.go b/contractcourt/channel_arbitrator.go index cda0d4e1f63..e3547f8f344 100644 --- a/contractcourt/channel_arbitrator.go +++ b/contractcourt/channel_arbitrator.go @@ -1,7 +1,6 @@ package contractcourt import ( - "bytes" "context" "errors" "fmt" @@ -15,6 +14,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/davecgh/go-spew/spew" + "github.com/lightningnetwork/lnd/chainio" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/fn" @@ -346,18 +346,18 @@ type ChannelArbitrator struct { // to do its duty. cfg ChannelArbitratorConfig - // blocks is a channel that the arbitrator will receive new blocks on. - // This channel should be buffered by so that it does not block the - // sender. - blocks chan int32 + // blockBeatChan is a channel to receive blocks from BlockBeat. The + // received block contains the best known height and the transactions + // confirmed in this block. 
+ blockBeatChan chan chainio.Beat // signalUpdates is a channel that any new live signals for the channel // we're watching over will be sent. signalUpdates chan *signalUpdateMsg - // activeResolvers is a slice of any active resolvers. This is used to - // be able to signal them for shutdown in the case that we shutdown. - activeResolvers []ContractResolver + // activeResolvers is a map of active resolvers. It uses the resolver + // as the key and a block chan as the value. + activeResolvers map[ContractResolver]chan int32 // activeResolversLock prevents simultaneous read and write to the // resolvers slice. @@ -399,8 +399,10 @@ func NewChannelArbitrator(cfg ChannelArbitratorConfig, } return &ChannelArbitrator{ - log: log, - blocks: make(chan int32, arbitratorBlockBufferSize), + log: log, + blockBeatChan: make( + chan chainio.Beat, arbitratorBlockBufferSize, + ), signalUpdates: make(chan *signalUpdateMsg), resolutionSignal: make(chan struct{}), forceCloseReqs: make(chan *forceCloseReq), @@ -411,6 +413,9 @@ func NewChannelArbitrator(cfg ChannelArbitratorConfig, } } +// Compile-time check to ensure ChannelArbitrator implements the chainio.Consumer interface. +var _ chainio.Consumer = (*ChannelArbitrator)(nil) + // chanArbStartState contains the information from disk that we need to start // up a channel arbitrator. type chanArbStartState struct { @@ -787,7 +792,7 @@ func (c *ChannelArbitrator) relaunchResolvers(commitSet *CommitSet, // TODO(roasbeef): this isn't re-launched? 
} - c.launchResolvers(unresolvedContracts, true) + c.launchResolvers(unresolvedContracts) return nil } @@ -798,7 +803,7 @@ func (c *ChannelArbitrator) Report() []*ContractReport { defer c.activeResolversLock.RUnlock() var reports []*ContractReport - for _, resolver := range c.activeResolvers { + for resolver := range c.activeResolvers { r, ok := resolver.(reportingContractResolver) if !ok { continue @@ -828,7 +833,7 @@ func (c *ChannelArbitrator) Stop() error { } c.activeResolversLock.RLock() - for _, activeResolver := range c.activeResolvers { + for activeResolver := range c.activeResolvers { activeResolver.Stop() } c.activeResolversLock.RUnlock() @@ -1245,7 +1250,7 @@ func (c *ChannelArbitrator) stateStep( // Finally, we'll launch all the required contract resolvers. // Once they're all resolved, we're no longer needed. - c.launchResolvers(resolvers, false) + c.launchResolvers(resolvers) nextState = StateWaitingFullResolution @@ -1576,16 +1581,20 @@ func (c *ChannelArbitrator) findCommitmentDeadlineAndValue(heightHint uint32, } // launchResolvers updates the activeResolvers list and starts the resolvers. -func (c *ChannelArbitrator) launchResolvers(resolvers []ContractResolver, - immediate bool) { - +func (c *ChannelArbitrator) launchResolvers(resolvers []ContractResolver) { c.activeResolversLock.Lock() defer c.activeResolversLock.Unlock() - c.activeResolvers = resolvers + c.activeResolvers = make(map[ContractResolver]chan int32) + for _, contract := range resolvers { c.wg.Add(1) - go c.resolveContract(contract, immediate) + + // Create a block chan for each contract resolver. 
+ blockChan := make(chan int32, arbitratorBlockBufferSize) + c.activeResolvers[contract] = blockChan + + go c.resolveContract(contract, blockChan) } } @@ -1609,8 +1618,8 @@ func (c *ChannelArbitrator) advanceState( for { priorState = c.state log.Debugf("ChannelArbitrator(%v): attempting state step with "+ - "trigger=%v from state=%v", c.cfg.ChanPoint, trigger, - priorState) + "trigger=%v from state=%v at height=%v", + c.cfg.ChanPoint, trigger, priorState, triggerHeight) nextState, closeTx, err := c.stateStep( triggerHeight, trigger, confCommitSet, @@ -2579,15 +2588,15 @@ func (c *ChannelArbitrator) replaceResolver(oldResolver, c.activeResolversLock.Lock() defer c.activeResolversLock.Unlock() - oldKey := oldResolver.ResolverKey() - for i, r := range c.activeResolvers { - if bytes.Equal(r.ResolverKey(), oldKey) { - c.activeResolvers[i] = newResolver - return nil - } + blockChan, ok := c.activeResolvers[oldResolver] + if !ok { + return errors.New("resolver to be replaced not found") } - return errors.New("resolver to be replaced not found") + c.activeResolvers[newResolver] = blockChan + delete(c.activeResolvers, oldResolver) + + return nil } // resolveContract is a goroutine tasked with fully resolving an unresolved @@ -2599,7 +2608,7 @@ func (c *ChannelArbitrator) replaceResolver(oldResolver, // // NOTE: This MUST be run as a goroutine. func (c *ChannelArbitrator) resolveContract(currentContract ContractResolver, - immediate bool) { + blockChan <-chan int32) { defer c.wg.Done() @@ -2621,7 +2630,7 @@ func (c *ChannelArbitrator) resolveContract(currentContract ContractResolver, default: // Otherwise, we'll attempt to resolve the current // contract. 
- nextContract, err := currentContract.Resolve(immediate) + nextContract, err := currentContract.Resolve(blockChan) if err != nil { if err == errResolverShuttingDown { return @@ -2765,6 +2774,32 @@ func (c *ChannelArbitrator) updateActiveHTLCs() { } } +// notifyResolvers will send the block height to all the active resolvers. +func (c *ChannelArbitrator) notifyResolvers(height int32) { + c.activeResolversLock.RLock() + defer c.activeResolversLock.RUnlock() + + log.Debugf("Notifying %v resolvers of new block height %v", + len(c.activeResolvers), height) + + // notifyHeight is a helper closure that sends the block height to a + // single resolver. + notifyHeight := func(height int32, blockChan chan int32) { + select { + case blockChan <- height: + case <-c.quit: + } + } + + // Notify all resolvers in parallel. + for _, blockChan := range c.activeResolvers { + go notifyHeight(height, blockChan) + } + + log.Debugf("Notified %v resolvers of new block height %v", + len(c.activeResolvers), height) +} + // channelAttendant is the primary goroutine that acts at the judicial // arbitrator between our channel state, the remote channel peer, and the // blockchain (Our judge). This goroutine will ensure that we faithfully execute @@ -2790,31 +2825,23 @@ func (c *ChannelArbitrator) channelAttendant(bestHeight int32) { // A new block has arrived, we'll examine all the active HTLC's // to see if any of them have expired, and also update our // track of the best current height. - case blockHeight, ok := <-c.blocks: + case beat, ok := <-c.blockBeatChan: if !ok { return } - bestHeight = blockHeight - // If we're not in the default state, then we can - // ignore this signal as we're waiting for contract - // resolution. - if c.state != StateDefault { - continue - } + log.Debugf("ChannelArbitrator(%v): new block height=%v", + c.cfg.ChanPoint, beat.Epoch.Height) - // Now that a new block has arrived, we'll attempt to - // advance our state forward. 
- nextState, _, err := c.advanceState( - uint32(bestHeight), chainTrigger, nil, - ) + err := c.handleBlockbeat(beat) if err != nil { - log.Errorf("Unable to advance state: %v", err) + log.Errorf("Handle block=%v got err: %v", + beat.Epoch.Height, err) } // If as a result of this trigger, the contract is // fully resolved, then well exit. - if nextState == StateFullyResolved { + if c.state == StateFullyResolved { return } @@ -2863,16 +2890,18 @@ func (c *ChannelArbitrator) channelAttendant(bestHeight int32) { // We have broadcasted our commitment, and it is now confirmed // on-chain. case closeInfo := <-c.cfg.ChainEvents.LocalUnilateralClosure: - log.Infof("ChannelArbitrator(%v): local on-chain "+ - "channel close", c.cfg.ChanPoint) - if c.state != StateCommitmentBroadcasted { log.Errorf("ChannelArbitrator(%v): unexpected "+ "local on-chain channel close", c.cfg.ChanPoint) } + closeTx := closeInfo.CloseTx + log.Infof("ChannelArbitrator(%v): local force close "+ + "tx=%v confirmed", c.cfg.ChanPoint, + closeTx.TxHash()) + contractRes := &ContractResolutions{ CommitHash: closeTx.TxHash(), CommitResolution: closeInfo.CommitResolution, @@ -3140,6 +3169,58 @@ func (c *ChannelArbitrator) channelAttendant(bestHeight int32) { } } +// handleBlockbeat processes a newly received blockbeat. +func (c *ChannelArbitrator) handleBlockbeat(beat chainio.Beat) error { + // Notify we've processed the block. + defer fn.SendOrQuit(beat.Err, nil, c.quit) + + bestHeight := beat.Epoch.Height + + // Notify all resolvers about this new block. + c.notifyResolvers(bestHeight) + + // If we're not in the default state, then we can ignore this signal as + // we're waiting for contract resolution. + if c.state != StateDefault { + return nil + } + + // Now that a new block has arrived, we'll attempt to advance our state + // forward. 
+ _, _, err := c.advanceState( + uint32(bestHeight), chainTrigger, nil, + ) + if err != nil { + return fmt.Errorf("unable to advance state: %w", err) + } + + return nil +} + +// ProcessBlock sends the specified blockbeat to the channel arbitrator's inner +// loop for processing. +// +// NOTE: Part of chainio.Consumer interface. +func (c *ChannelArbitrator) ProcessBlock(beat chainio.Beat) <-chan error { + select { + case c.blockBeatChan <- beat: + log.Debugf("Received block beat for height=%d", + beat.Epoch.Height) + + case <-c.quit: + return nil + } + + return beat.Err +} + +// Name returns a human-readable string for this subsystem. +// +// NOTE: Part of chainio.Consumer interface. +func (c *ChannelArbitrator) Name() string { + return fmt.Sprintf("ChannelArbitrator(%v)", c.cfg.ChanPoint) +} + // checkLegacyBreach returns StateFullyResolved if the channel was closed with // a breach transaction before the channel arbitrator launched its own breach // resolver. StateContractClosed is returned if this is a modern breach close diff --git a/contractcourt/channel_arbitrator_test.go b/contractcourt/channel_arbitrator_test.go index 43238494efc..ec6eaeedadd 100644 --- a/contractcourt/channel_arbitrator_test.go +++ b/contractcourt/channel_arbitrator_test.go @@ -13,6 +13,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/chainio" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb/models" @@ -358,7 +359,6 @@ func createTestChannelArbitrator(t *testing.T, log ArbitratorLog, OutgoingBroadcastDelta: 5, IncomingBroadcastDelta: 5, Notifier: &mock.ChainNotifier{ - EpochChan: make(chan *chainntnfs.BlockEpoch), SpendChan: make(chan *chainntnfs.SpendDetail), ConfChan: make(chan *chainntnfs.TxConfirmation), }, @@ -808,7 +808,7 @@ func TestChannelArbitratorBreachClose(t *testing.T) { 
require.Equal(t, 2, len(chanArb.activeResolvers)) var anchorExists, breachExists bool - for _, resolver := range chanArb.activeResolvers { + for resolver := range chanArb.activeResolvers { switch resolver.(type) { case *anchorResolver: anchorExists = true @@ -1039,9 +1039,13 @@ func TestChannelArbitratorLocalForceClosePendingHtlc(t *testing.T) { len(chanArb.activeResolvers)) } + var resolver ContractResolver + for r := range chanArb.activeResolvers { + resolver = r + } + // We'll now examine the in-memory state of the active resolvers to // ensure t hey were populated properly. - resolver := chanArb.activeResolvers[0] outgoingResolver, ok := resolver.(*htlcOutgoingContestResolver) if !ok { t.Fatalf("expected outgoing contest resolver, got %vT", @@ -1063,7 +1067,10 @@ func TestChannelArbitratorLocalForceClosePendingHtlc(t *testing.T) { } // Send a notification that the expiry height has been reached. - oldNotifier.EpochChan <- &chainntnfs.BlockEpoch{Height: 10} + beat := chainio.NewBeat(chainntnfs.BlockEpoch{ + Height: 10, + }) + chanArb.blockBeatChan <- beat // htlcOutgoingContestResolver is now transforming into a // htlcTimeoutResolver and should send the contract off for incubation. @@ -1900,7 +1907,10 @@ func TestChannelArbitratorDanglingCommitForceClose(t *testing.T) { // now mine a block (height 5), which is 5 blocks away // (our grace delta) from the expiry of that HTLC. case testCase.htlcExpired: - chanArbCtx.chanArb.blocks <- 5 + beat := chainio.NewBeat(chainntnfs.BlockEpoch{ + Height: 5, + }) + chanArbCtx.chanArb.blockBeatChan <- beat // Otherwise, we'll just trigger a regular force close // request. @@ -2004,7 +2014,10 @@ func TestChannelArbitratorDanglingCommitForceClose(t *testing.T) { // so instead, we'll mine another block which'll cause // it to re-examine its state and realize there're no // more HTLCs. 
- chanArbCtx.chanArb.blocks <- 6 + beat := chainio.NewBeat(chainntnfs.BlockEpoch{ + Height: 6, + }) + chanArbCtx.chanArb.blockBeatChan <- beat chanArbCtx.AssertStateTransitions(StateFullyResolved) }) } @@ -2076,13 +2089,19 @@ func TestChannelArbitratorPendingExpiredHTLC(t *testing.T) { // We will advance the uptime to 10 seconds which should be still within // the grace period and should not trigger going to chain. testClock.SetTime(startTime.Add(time.Second * 10)) - chanArbCtx.chanArb.blocks <- 5 + beat := chainio.NewBeat(chainntnfs.BlockEpoch{ + Height: 5, + }) + chanArbCtx.chanArb.blockBeatChan <- beat chanArbCtx.AssertState(StateDefault) // We will advance the uptime to 16 seconds which should trigger going // to chain. testClock.SetTime(startTime.Add(time.Second * 16)) - chanArbCtx.chanArb.blocks <- 6 + beat = chainio.NewBeat(chainntnfs.BlockEpoch{ + Height: 6, + }) + chanArbCtx.chanArb.blockBeatChan <- beat chanArbCtx.AssertStateTransitions( StateBroadcastCommit, StateCommitmentBroadcasted, @@ -2450,7 +2469,10 @@ func TestSweepAnchors(t *testing.T) { // Set current block height. heightHint := uint32(1000) - chanArbCtx.chanArb.blocks <- int32(heightHint) + beat := chainio.NewBeat(chainntnfs.BlockEpoch{ + Height: int32(heightHint), + }) + chanArbCtx.chanArb.blockBeatChan <- beat htlcIndexBase := uint64(99) deadlineDelta := uint32(10) @@ -2651,7 +2673,10 @@ func TestChannelArbitratorAnchors(t *testing.T) { // Set current block height. 
heightHint := uint32(1000) - chanArbCtx.chanArb.blocks <- int32(heightHint) + beat := chainio.NewBeat(chainntnfs.BlockEpoch{ + Height: int32(heightHint), + }) + chanArbCtx.chanArb.blockBeatChan <- beat htlcAmt := lnwire.MilliSatoshi(1_000_000) @@ -2777,7 +2802,11 @@ func TestChannelArbitratorAnchors(t *testing.T) { len(chanArb.activeResolvers)) } - resolver := chanArb.activeResolvers[0] + var resolver ContractResolver + for r := range chanArb.activeResolvers { + resolver = r + } + _, ok := resolver.(*anchorResolver) if !ok { t.Fatalf("expected anchor resolver, got %T", resolver) diff --git a/contractcourt/commit_sweep_resolver.go b/contractcourt/commit_sweep_resolver.go index 296ea38e554..b6ac7514752 100644 --- a/contractcourt/commit_sweep_resolver.go +++ b/contractcourt/commit_sweep_resolver.go @@ -90,36 +90,6 @@ func (c *commitSweepResolver) ResolverKey() []byte { return key[:] } -// waitForHeight registers for block notifications and waits for the provided -// block height to be reached. -func waitForHeight(waitHeight uint32, notifier chainntnfs.ChainNotifier, - quit <-chan struct{}) error { - - // Register for block epochs. After registration, the current height - // will be sent on the channel immediately. - blockEpochs, err := notifier.RegisterBlockEpochNtfn(nil) - if err != nil { - return err - } - defer blockEpochs.Cancel() - - for { - select { - case newBlock, ok := <-blockEpochs.Epochs: - if !ok { - return errResolverShuttingDown - } - height := newBlock.Height - if height >= int32(waitHeight) { - return nil - } - - case <-quit: - return errResolverShuttingDown - } - } -} - // waitForSpend waits for the given outpoint to be spent, and returns the // details of the spending tx. func waitForSpend(op *wire.OutPoint, pkScript []byte, heightHint uint32, @@ -186,7 +156,9 @@ func (c *commitSweepResolver) getCommitTxConfHeight() (uint32, error) { // NOTE: This function MUST be run as a goroutine. 
// //nolint:funlen -func (c *commitSweepResolver) Resolve(_ bool) (ContractResolver, error) { +func (c *commitSweepResolver) Resolve( + _ <-chan int32) (ContractResolver, error) { + // If we're already resolved, then we can exit early. if c.resolved { return nil, nil @@ -214,44 +186,12 @@ func (c *commitSweepResolver) Resolve(_ bool) (ContractResolver, error) { c.currentReport.MaturityHeight = unlockHeight c.reportLock.Unlock() - // If there is a csv/cltv lock, we'll wait for that. - if c.commitResolution.MaturityDelay > 0 || c.hasCLTV() { - // Determine what height we should wait until for the locks to - // expire. - var waitHeight uint32 - switch { - // If we have both a csv and cltv lock, we'll need to look at - // both and see which expires later. - case c.commitResolution.MaturityDelay > 0 && c.hasCLTV(): - c.log.Debugf("waiting for CSV and CLTV lock to expire "+ - "at height %v", unlockHeight) - // If the CSV expires after the CLTV, or there is no - // CLTV, then we can broadcast a sweep a block before. - // Otherwise, we need to broadcast at our expected - // unlock height. - waitHeight = uint32(math.Max( - float64(unlockHeight-1), float64(c.leaseExpiry), - )) - - // If we only have a csv lock, wait for the height before the - // lock expires as the spend path should be unlocked by then. - case c.commitResolution.MaturityDelay > 0: - c.log.Debugf("waiting for CSV lock to expire at "+ - "height %v", unlockHeight) - waitHeight = unlockHeight - 1 - } - - err := waitForHeight(waitHeight, c.Notifier, c.quit) - if err != nil { - return nil, err - } - } - var ( isLocalCommitTx bool signDesc = c.commitResolution.SelfOutputSignDesc ) + switch { // For taproot channels, we'll know if this is the local commit based // on the witness script. 
For local channels, the witness script has an diff --git a/contractcourt/commit_sweep_resolver_test.go b/contractcourt/commit_sweep_resolver_test.go index f2b43b0f80a..1b4f35bd275 100644 --- a/contractcourt/commit_sweep_resolver_test.go +++ b/contractcourt/commit_sweep_resolver_test.go @@ -23,13 +23,13 @@ type commitSweepResolverTestContext struct { sweeper *mockSweeper resolverResultChan chan resolveResult t *testing.T + blockChan chan int32 } func newCommitSweepResolverTestContext(t *testing.T, resolution *lnwallet.CommitOutputResolution) *commitSweepResolverTestContext { notifier := &mock.ChainNotifier{ - EpochChan: make(chan *chainntnfs.BlockEpoch), SpendChan: make(chan *chainntnfs.SpendDetail), ConfChan: make(chan *chainntnfs.TxConfirmation), } @@ -71,10 +71,11 @@ func newCommitSweepResolverTestContext(t *testing.T, ) return &commitSweepResolverTestContext{ - resolver: resolver, - notifier: notifier, - sweeper: sweeper, - t: t, + resolver: resolver, + notifier: notifier, + sweeper: sweeper, + t: t, + blockChan: make(chan int32, 1), } } @@ -82,7 +83,7 @@ func (i *commitSweepResolverTestContext) resolve() { // Start resolver. i.resolverResultChan = make(chan resolveResult, 1) go func() { - nextResolver, err := i.resolver.Resolve(false) + nextResolver, err := i.resolver.Resolve(i.blockChan) i.resolverResultChan <- resolveResult{ nextResolver: nextResolver, err: err, @@ -91,9 +92,7 @@ func (i *commitSweepResolverTestContext) resolve() { } func (i *commitSweepResolverTestContext) notifyEpoch(height int32) { - i.notifier.EpochChan <- &chainntnfs.BlockEpoch{ - Height: height, - } + i.blockChan <- height } func (i *commitSweepResolverTestContext) waitForResult() { diff --git a/contractcourt/contract_resolver.go b/contractcourt/contract_resolver.go index 5acf8006496..c574f8ae577 100644 --- a/contractcourt/contract_resolver.go +++ b/contractcourt/contract_resolver.go @@ -43,7 +43,7 @@ type ContractResolver interface { // resolution, then another resolve is returned. 
// // NOTE: This function MUST be run as a goroutine. - Resolve(immediate bool) (ContractResolver, error) + Resolve(blockChan <-chan int32) (ContractResolver, error) // SupplementState allows the user of a ContractResolver to supplement // it with state required for the proper resolution of a contract. diff --git a/contractcourt/htlc_incoming_contest_resolver.go b/contractcourt/htlc_incoming_contest_resolver.go index b104f8d70a6..e3c4e991a2a 100644 --- a/contractcourt/htlc_incoming_contest_resolver.go +++ b/contractcourt/htlc_incoming_contest_resolver.go @@ -91,7 +91,7 @@ func (h *htlcIncomingContestResolver) processFinalHtlcFail() error { // // NOTE: Part of the ContractResolver interface. func (h *htlcIncomingContestResolver) Resolve( - _ bool) (ContractResolver, error) { + blockChan <-chan int32) (ContractResolver, error) { // If we're already full resolved, then we don't have anything further // to do. @@ -126,21 +126,13 @@ func (h *htlcIncomingContestResolver) Resolve( return nil, h.PutResolverReport(nil, resReport) } - // Register for block epochs. After registration, the current height - // will be sent on the channel immediately. 
- blockEpochs, err := h.Notifier.RegisterBlockEpochNtfn(nil) - if err != nil { - return nil, err - } - defer blockEpochs.Cancel() - var currentHeight int32 select { - case newBlock, ok := <-blockEpochs.Epochs: + case height, ok := <-blockChan: if !ok { return nil, errResolverShuttingDown } - currentHeight = newBlock.Height + currentHeight = height case <-h.quit: return nil, errResolverShuttingDown } @@ -403,7 +395,7 @@ func (h *htlcIncomingContestResolver) Resolve( htlcResolution := hodlItem.(invoices.HtlcResolution) return processHtlcResolution(htlcResolution) - case newBlock, ok := <-blockEpochs.Epochs: + case height, ok := <-blockChan: if !ok { return nil, errResolverShuttingDown } @@ -411,7 +403,7 @@ func (h *htlcIncomingContestResolver) Resolve( // If this new height expires the HTLC, then this means // we never found out the preimage, so we can mark // resolved and exit. - newHeight := uint32(newBlock.Height) + newHeight := uint32(height) if newHeight >= h.htlcExpiry { log.Infof("%T(%v): HTLC has timed out "+ "(expiry=%v, height=%v), abandoning", h, diff --git a/contractcourt/htlc_incoming_contest_resolver_test.go b/contractcourt/htlc_incoming_contest_resolver_test.go index 55d93a6fb37..e3d965e1953 100644 --- a/contractcourt/htlc_incoming_contest_resolver_test.go +++ b/contractcourt/htlc_incoming_contest_resolver_test.go @@ -309,11 +309,11 @@ type incomingResolverTestContext struct { nextResolver ContractResolver finalHtlcOutcomeStored bool t *testing.T + blockChan chan int32 } func newIncomingResolverTestContext(t *testing.T, isExit bool) *incomingResolverTestContext { notifier := &mock.ChainNotifier{ - EpochChan: make(chan *chainntnfs.BlockEpoch), SpendChan: make(chan *chainntnfs.SpendDetail), ConfChan: make(chan *chainntnfs.TxConfirmation), } @@ -332,6 +332,7 @@ func newIncomingResolverTestContext(t *testing.T, isExit bool) *incomingResolver notifier: notifier, onionProcessor: onionProcessor, t: t, + blockChan: make(chan int32, 1), } htlcNotifier := 
&mockHTLCNotifier{} @@ -395,7 +396,7 @@ func (i *incomingResolverTestContext) resolve() { i.resolveErr = make(chan error, 1) go func() { var err error - i.nextResolver, err = i.resolver.Resolve(false) + i.nextResolver, err = i.resolver.Resolve(i.blockChan) i.resolveErr <- err }() @@ -404,9 +405,7 @@ func (i *incomingResolverTestContext) resolve() { } func (i *incomingResolverTestContext) notifyEpoch(height int32) { - i.notifier.EpochChan <- &chainntnfs.BlockEpoch{ - Height: height, - } + i.blockChan <- height } func (i *incomingResolverTestContext) waitForResult(expectSuccessRes bool) { diff --git a/contractcourt/htlc_outgoing_contest_resolver.go b/contractcourt/htlc_outgoing_contest_resolver.go index 2466544c982..c37f89b4a9a 100644 --- a/contractcourt/htlc_outgoing_contest_resolver.go +++ b/contractcourt/htlc_outgoing_contest_resolver.go @@ -50,7 +50,7 @@ func newOutgoingContestResolver(res lnwallet.OutgoingHtlcResolution, // is able to handle the final resolution of the contract. We're only the pivot // point. func (h *htlcOutgoingContestResolver) Resolve( - _ bool) (ContractResolver, error) { + blockChan <-chan int32) (ContractResolver, error) { // If we're already full resolved, then we don't have anything further // to do. @@ -99,18 +99,12 @@ func (h *htlcOutgoingContestResolver) Resolve( // If we reach this point, then we can't fully act yet, so we'll await // either of our signals triggering: the HTLC expires, or we learn of // the preimage. - blockEpochs, err := h.Notifier.RegisterBlockEpochNtfn(nil) - if err != nil { - return nil, err - } - defer blockEpochs.Cancel() - for { select { // A new block has arrived, we'll check to see if this leads to // HTLC expiration. 
- case newBlock, ok := <-blockEpochs.Epochs: + case newBlockHeight, ok := <-blockChan: if !ok { return nil, errResolverShuttingDown } @@ -125,7 +119,7 @@ func (h *htlcOutgoingContestResolver) Resolve( // check doesn't pass, error `transaction is not // finalized` will be returned and the broadcast will // fail. - newHeight := uint32(newBlock.Height) + newHeight := uint32(newBlockHeight) if newHeight >= h.htlcResolution.Expiry { log.Infof("%T(%v): HTLC has expired "+ "(height=%v, expiry=%v), transforming "+ diff --git a/contractcourt/htlc_outgoing_contest_resolver_test.go b/contractcourt/htlc_outgoing_contest_resolver_test.go index f67c34ff4e1..1968a042a2b 100644 --- a/contractcourt/htlc_outgoing_contest_resolver_test.go +++ b/contractcourt/htlc_outgoing_contest_resolver_test.go @@ -122,11 +122,11 @@ type outgoingResolverTestContext struct { resolverResultChan chan resolveResult resolutionChan chan ResolutionMsg t *testing.T + blockChan chan int32 } func newOutgoingResolverTestContext(t *testing.T) *outgoingResolverTestContext { notifier := &mock.ChainNotifier{ - EpochChan: make(chan *chainntnfs.BlockEpoch), SpendChan: make(chan *chainntnfs.SpendDetail), ConfChan: make(chan *chainntnfs.TxConfirmation), } @@ -202,6 +202,7 @@ func newOutgoingResolverTestContext(t *testing.T) *outgoingResolverTestContext { preimageDB: preimageDB, resolutionChan: resolutionChan, t: t, + blockChan: make(chan int32, 1), } } @@ -209,7 +210,7 @@ func (i *outgoingResolverTestContext) resolve() { // Start resolver. 
i.resolverResultChan = make(chan resolveResult, 1) go func() { - nextResolver, err := i.resolver.Resolve(false) + nextResolver, err := i.resolver.Resolve(i.blockChan) i.resolverResultChan <- resolveResult{ nextResolver: nextResolver, err: err, @@ -221,9 +222,7 @@ func (i *outgoingResolverTestContext) resolve() { } func (i *outgoingResolverTestContext) notifyEpoch(height int32) { - i.notifier.EpochChan <- &chainntnfs.BlockEpoch{ - Height: height, - } + i.blockChan <- height } func (i *outgoingResolverTestContext) waitForResult(expectTimeoutRes bool) { diff --git a/contractcourt/htlc_success_resolver.go b/contractcourt/htlc_success_resolver.go index 6eee939eac0..8b1245b4576 100644 --- a/contractcourt/htlc_success_resolver.go +++ b/contractcourt/htlc_success_resolver.go @@ -116,7 +116,7 @@ func (h *htlcSuccessResolver) ResolverKey() []byte { // // NOTE: Part of the ContractResolver interface. func (h *htlcSuccessResolver) Resolve( - immediate bool) (ContractResolver, error) { + _ <-chan int32) (ContractResolver, error) { // If we're already resolved, then we can exit early. if h.resolved { @@ -126,12 +126,12 @@ func (h *htlcSuccessResolver) Resolve( // If we don't have a success transaction, then this means that this is // an output on the remote party's commitment transaction. if h.htlcResolution.SignedSuccessTx == nil { - return h.resolveRemoteCommitOutput(immediate) + return h.resolveRemoteCommitOutput() } // Otherwise this an output on our own commitment, and we must start by // broadcasting the second-level success transaction. - secondLevelOutpoint, err := h.broadcastSuccessTx(immediate) + secondLevelOutpoint, err := h.broadcastSuccessTx() if err != nil { return nil, err } @@ -165,8 +165,8 @@ func (h *htlcSuccessResolver) Resolve( // broadcasting the second-level success transaction. It returns the ultimate // outpoint of the second-level tx, that we must wait to be spent for the // resolver to be fully resolved. 
-func (h *htlcSuccessResolver) broadcastSuccessTx( - immediate bool) (*wire.OutPoint, error) { +func (h *htlcSuccessResolver) broadcastSuccessTx() ( + *wire.OutPoint, error) { // If we have non-nil SignDetails, this means that have a 2nd level // HTLC transaction that is signed using sighash SINGLE|ANYONECANPAY @@ -175,7 +175,7 @@ func (h *htlcSuccessResolver) broadcastSuccessTx( // the checkpointed outputIncubating field to determine if we already // swept the HTLC output into the second level transaction. if h.htlcResolution.SignDetails != nil { - return h.broadcastReSignedSuccessTx(immediate) + return h.broadcastReSignedSuccessTx() } // Otherwise we'll publish the second-level transaction directly and @@ -225,10 +225,8 @@ func (h *htlcSuccessResolver) broadcastSuccessTx( // broadcastReSignedSuccessTx handles the case where we have non-nil // SignDetails, and offers the second level transaction to the Sweeper, that // will re-sign it and attach fees at will. -// -//nolint:funlen -func (h *htlcSuccessResolver) broadcastReSignedSuccessTx(immediate bool) ( - *wire.OutPoint, error) { +func (h *htlcSuccessResolver) broadcastReSignedSuccessTx() (*wire.OutPoint, + error) { // Keep track of the tx spending the HTLC output on the commitment, as // this will be the confirmed second-level tx we'll ultimately sweep. @@ -284,7 +282,6 @@ func (h *htlcSuccessResolver) broadcastReSignedSuccessTx(immediate bool) ( sweep.Params{ Budget: budget, DeadlineHeight: deadline, - Immediate: immediate, }, ) if err != nil { @@ -356,30 +353,6 @@ func (h *htlcSuccessResolver) broadcastReSignedSuccessTx(immediate bool) ( "height %v", h, h.htlc.RHash[:], waitHeight) } - // Deduct one block so this input is offered to the sweeper one block - // earlier since the sweeper will wait for one block to trigger the - // sweeping. - // - // TODO(yy): this is done so the outputs can be aggregated - // properly. 
Suppose CSV locks of five 2nd-level outputs all - // expire at height 840000, there is a race in block digestion - // between contractcourt and sweeper: - // - G1: block 840000 received in contractcourt, it now offers - // the outputs to the sweeper. - // - G2: block 840000 received in sweeper, it now starts to - // sweep the received outputs - there's no guarantee all - // fives have been received. - // To solve this, we either offer the outputs earlier, or - // implement `blockbeat`, and force contractcourt and sweeper - // to consume each block sequentially. - waitHeight-- - - // TODO(yy): let sweeper handles the wait? - err := waitForHeight(waitHeight, h.Notifier, h.quit) - if err != nil { - return nil, err - } - // We'll use this input index to determine the second-level output // index on the transaction, as the signatures requires the indexes to // be the same. We don't look for the second-level output script @@ -418,7 +391,7 @@ func (h *htlcSuccessResolver) broadcastReSignedSuccessTx(immediate bool) ( h.htlc.RHash[:], budget, waitHeight) // TODO(roasbeef): need to update above for leased types - _, err = h.Sweeper.SweepInput( + _, err := h.Sweeper.SweepInput( inp, sweep.Params{ Budget: budget, @@ -440,7 +413,7 @@ func (h *htlcSuccessResolver) broadcastReSignedSuccessTx(immediate bool) ( // resolveRemoteCommitOutput handles sweeping an HTLC output on the remote // commitment with the preimage. In this case we can sweep the output directly, // and don't have to broadcast a second-level transaction. 
-func (h *htlcSuccessResolver) resolveRemoteCommitOutput(immediate bool) ( +func (h *htlcSuccessResolver) resolveRemoteCommitOutput() ( ContractResolver, error) { isTaproot := txscript.IsPayToTaproot( @@ -489,7 +462,6 @@ func (h *htlcSuccessResolver) resolveRemoteCommitOutput(immediate bool) ( sweep.Params{ Budget: budget, DeadlineHeight: deadline, - Immediate: immediate, }, ) if err != nil { diff --git a/contractcourt/htlc_success_resolver_test.go b/contractcourt/htlc_success_resolver_test.go index b9182500bb4..473cd4506a3 100644 --- a/contractcourt/htlc_success_resolver_test.go +++ b/contractcourt/htlc_success_resolver_test.go @@ -37,6 +37,8 @@ type htlcResolverTestContext struct { finalHtlcOutcomeStored bool t *testing.T + + blockChan chan int32 } func newHtlcResolverTestContext(t *testing.T, @@ -44,7 +46,6 @@ func newHtlcResolverTestContext(t *testing.T, cfg ResolverConfig) ContractResolver) *htlcResolverTestContext { notifier := &mock.ChainNotifier{ - EpochChan: make(chan *chainntnfs.BlockEpoch, 1), SpendChan: make(chan *chainntnfs.SpendDetail, 1), ConfChan: make(chan *chainntnfs.TxConfirmation, 1), } @@ -54,6 +55,7 @@ func newHtlcResolverTestContext(t *testing.T, notifier: notifier, resolutionChan: make(chan ResolutionMsg, 1), t: t, + blockChan: make(chan int32, 1), } htlcNotifier := &mockHTLCNotifier{} @@ -134,7 +136,7 @@ func (i *htlcResolverTestContext) resolve() { // Start resolver. 
i.resolverResultChan = make(chan resolveResult, 1) go func() { - nextResolver, err := i.resolver.Resolve(false) + nextResolver, err := i.resolver.Resolve(i.blockChan) i.resolverResultChan <- resolveResult{ nextResolver: nextResolver, err: err, @@ -142,6 +144,10 @@ func (i *htlcResolverTestContext) resolve() { }() } +func (i *htlcResolverTestContext) notifyEpoch(height int32) { + i.blockChan <- height +} + func (i *htlcResolverTestContext) waitForResult() { i.t.Helper() @@ -437,9 +443,7 @@ func TestHtlcSuccessSecondStageResolutionSweeper(t *testing.T) { } } - ctx.notifier.EpochChan <- &chainntnfs.BlockEpoch{ - Height: 13, - } + ctx.notifyEpoch(13) // We expect it to sweep the second-level // transaction we notfied about above. diff --git a/contractcourt/htlc_timeout_resolver.go b/contractcourt/htlc_timeout_resolver.go index 62ff832071c..435c7d7d03e 100644 --- a/contractcourt/htlc_timeout_resolver.go +++ b/contractcourt/htlc_timeout_resolver.go @@ -419,7 +419,7 @@ func checkSizeAndIndex(witness wire.TxWitness, size, index int) bool { // // NOTE: Part of the ContractResolver interface. func (h *htlcTimeoutResolver) Resolve( - immediate bool) (ContractResolver, error) { + _ <-chan int32) (ContractResolver, error) { // If we're already resolved, then we can exit early. if h.resolved { @@ -429,7 +429,7 @@ func (h *htlcTimeoutResolver) Resolve( // Start by spending the HTLC output, either by broadcasting the // second-level timeout transaction, or directly if this is the remote // commitment. - commitSpend, err := h.spendHtlcOutput(immediate) + commitSpend, err := h.spendHtlcOutput() if err != nil { return nil, err } @@ -473,7 +473,7 @@ func (h *htlcTimeoutResolver) Resolve( // sweepSecondLevelTx sends a second level timeout transaction to the sweeper. // This transaction uses the SINLGE|ANYONECANPAY flag. 
-func (h *htlcTimeoutResolver) sweepSecondLevelTx(immediate bool) error { +func (h *htlcTimeoutResolver) sweepSecondLevelTx() error { log.Infof("%T(%x): offering second-layer timeout tx to sweeper: %v", h, h.htlc.RHash[:], spew.Sdump(h.htlcResolution.SignedTimeoutTx)) @@ -531,7 +531,6 @@ func (h *htlcTimeoutResolver) sweepSecondLevelTx(immediate bool) error { sweep.Params{ Budget: budget, DeadlineHeight: h.incomingHTLCExpiryHeight, - Immediate: immediate, }, ) if err != nil { @@ -567,8 +566,8 @@ func (h *htlcTimeoutResolver) sendSecondLevelTxLegacy() error { // used to spend the output into the next stage. If this is the remote // commitment, the output will be swept directly without the timeout // transaction. -func (h *htlcTimeoutResolver) spendHtlcOutput( - immediate bool) (*chainntnfs.SpendDetail, error) { +func (h *htlcTimeoutResolver) spendHtlcOutput() ( + *chainntnfs.SpendDetail, error) { switch { // If we have non-nil SignDetails, this means that have a 2nd level @@ -576,7 +575,7 @@ func (h *htlcTimeoutResolver) spendHtlcOutput( // (the case for anchor type channels). In this case we can re-sign it // and attach fees at will. We let the sweeper handle this job. case h.htlcResolution.SignDetails != nil && !h.outputIncubating: - if err := h.sweepSecondLevelTx(immediate); err != nil { + if err := h.sweepSecondLevelTx(); err != nil { log.Errorf("Sending timeout tx to sweeper: %v", err) return nil, err @@ -712,30 +711,6 @@ func (h *htlcTimeoutResolver) handleCommitSpend( "height %v", h, h.htlc.RHash[:], waitHeight) } - // Deduct one block so this input is offered to the sweeper one - // block earlier since the sweeper will wait for one block to - // trigger the sweeping. - // - // TODO(yy): this is done so the outputs can be aggregated - // properly. 
Suppose CSV locks of five 2nd-level outputs all - // expire at height 840000, there is a race in block digestion - // between contractcourt and sweeper: - // - G1: block 840000 received in contractcourt, it now offers - // the outputs to the sweeper. - // - G2: block 840000 received in sweeper, it now starts to - // sweep the received outputs - there's no guarantee all - // fives have been received. - // To solve this, we either offer the outputs earlier, or - // implement `blockbeat`, and force contractcourt and sweeper - // to consume each block sequentially. - waitHeight-- - - // TODO(yy): let sweeper handles the wait? - err := waitForHeight(waitHeight, h.Notifier, h.quit) - if err != nil { - return nil, err - } - // We'll use this input index to determine the second-level // output index on the transaction, as the signatures requires // the indexes to be the same. We don't look for the @@ -774,7 +749,7 @@ func (h *htlcTimeoutResolver) handleCommitSpend( "sweeper with no deadline and budget=%v at height=%v", h, h.htlc.RHash[:], budget, waitHeight) - _, err = h.Sweeper.SweepInput( + _, err := h.Sweeper.SweepInput( inp, sweep.Params{ Budget: budget, diff --git a/contractcourt/htlc_timeout_resolver_test.go b/contractcourt/htlc_timeout_resolver_test.go index c551a6f1ceb..3e38486fb4c 100644 --- a/contractcourt/htlc_timeout_resolver_test.go +++ b/contractcourt/htlc_timeout_resolver_test.go @@ -267,7 +267,6 @@ func TestHtlcTimeoutResolver(t *testing.T) { } notifier := &mock.ChainNotifier{ - EpochChan: make(chan *chainntnfs.BlockEpoch), SpendChan: make(chan *chainntnfs.SpendDetail), ConfChan: make(chan *chainntnfs.TxConfirmation), } @@ -374,8 +373,7 @@ func TestHtlcTimeoutResolver(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - - _, err := resolver.Resolve(false) + _, err := resolver.Resolve(nil) if err != nil { resolveErr <- err } @@ -1089,9 +1087,7 @@ func TestHtlcTimeoutSecondStageSweeper(t *testing.T) { } // Mimic CSV lock expiring. 
- ctx.notifier.EpochChan <- &chainntnfs.BlockEpoch{ - Height: 13, - } + ctx.notifyEpoch(13) // The timeout tx output should now be given to // the sweeper. diff --git a/itest/lnd_sweep_test.go b/itest/lnd_sweep_test.go index 27ad1326644..e25acf3a321 100644 --- a/itest/lnd_sweep_test.go +++ b/itest/lnd_sweep_test.go @@ -2,7 +2,6 @@ package itest import ( "fmt" - "math" "time" "github.com/btcsuite/btcd/btcutil" @@ -61,10 +60,7 @@ func testSweepCPFPAnchorOutgoingTimeout(ht *lntest.HarnessTest) { // Set up the fee estimator to return the testing fee rate when the // conf target is the deadline. - // - // TODO(yy): switch to conf when `blockbeat` is in place. - // ht.SetFeeEstimateWithConf(startFeeRateAnchor, deadlineDeltaAnchor) - ht.SetFeeEstimate(startFeeRateAnchor) + ht.SetFeeEstimateWithConf(startFeeRateAnchor, deadlineDeltaAnchor) // htlcValue is the outgoing HTLC's value. htlcValue := invoiceAmt @@ -167,52 +163,26 @@ func testSweepCPFPAnchorOutgoingTimeout(ht *lntest.HarnessTest) { )) ht.MineEmptyBlocks(int(numBlocks)) - // Assert Bob's force closing tx has been broadcast. - closeTxid := ht.Miner.AssertNumTxsInMempool(1)[0] + // Assert Bob's force closing tx has been broadcast. We should see two + // txns in the mempool: + // 1. Bob's force closing tx. + // 2. Bob's anchor sweeping tx CPFPing the force close tx. + _, sweepTx := ht.AssertForceCloseAndAnchorTxnsInMempool() // Remember the force close height so we can calculate the deadline // height. _, forceCloseHeight := ht.Miner.GetBestBlock() - // Bob should have two pending sweeps, + // Bob should have one pending sweep, // - anchor sweeping from his local commitment. - // - anchor sweeping from his remote commitment (invalid). - // - // TODO(yy): consider only sweeping the anchor from the local - // commitment. 
Previously we would sweep up to three versions of - // anchors because we don't know which one will be confirmed - if we - // only broadcast the local anchor sweeping, our peer can broadcast - // their commitment tx and replaces ours. With the new fee bumping, we - // should be safe to only sweep our local anchor since we RBF it on - // every new block, which destroys the remote's ability to pin us. - sweeps := ht.AssertNumPendingSweeps(bob, 2) - - // The two anchor sweeping should have the same deadline height. + anchorSweep := ht.AssertNumPendingSweeps(bob, 1)[0] + + // The anchor sweeping should have the expected deadline height. deadlineHeight := uint32(forceCloseHeight) + deadlineDeltaAnchor - require.Equal(ht, deadlineHeight, sweeps[0].DeadlineHeight) - require.Equal(ht, deadlineHeight, sweeps[1].DeadlineHeight) + require.Equal(ht, deadlineHeight, anchorSweep.DeadlineHeight) // Remember the deadline height for the CPFP anchor. - anchorDeadline := sweeps[0].DeadlineHeight - - // Mine a block so Bob's force closing tx stays in the mempool, which - // also triggers the CPFP anchor sweep. - ht.MineEmptyBlocks(1) - - // Bob should still have two pending sweeps, - // - anchor sweeping from his local commitment. - // - anchor sweeping from his remote commitment (invalid). - ht.AssertNumPendingSweeps(bob, 2) - - // We now check the expected fee and fee rate are used for Bob's anchor - // sweeping tx. - // - // We should see Bob's anchor sweeping tx triggered by the above - // block, along with his force close tx. - txns := ht.Miner.GetNumTxsFromMempool(2) - - // Find the sweeping tx. - sweepTx := ht.FindSweepingTxns(txns, 1, *closeTxid)[0] + anchorDeadline := anchorSweep.DeadlineHeight // Get the weight for Bob's anchor sweeping tx. 
txWeight := ht.CalculateTxWeight(sweepTx) @@ -224,11 +194,10 @@ func testSweepCPFPAnchorOutgoingTimeout(ht *lntest.HarnessTest) { fee := uint64(ht.CalculateTxFee(sweepTx)) feeRate := uint64(ht.CalculateTxFeeRate(sweepTx)) - // feeFuncWidth is the width of the fee function. By the time we got - // here, we've already mined one block, and the fee function maxes - // out one block before the deadline, so the width is the original - // deadline minus 2. - feeFuncWidth := deadlineDeltaAnchor - 2 + // feeFuncWidth is the width of the fee function. The fee function + // maxes out one block before the deadline, so the width is the + // original deadline minus 1. + feeFuncWidth := deadlineDeltaAnchor - 1 // Calculate the expected delta increased per block. feeDelta := (cpfpBudget - startFeeAnchor).MulF64( @@ -254,10 +223,10 @@ func testSweepCPFPAnchorOutgoingTimeout(ht *lntest.HarnessTest) { // Bob's fee bumper should increase its fees. ht.MineEmptyBlocks(1) - // Bob should still have two pending sweeps, - // - anchor sweeping from his local commitment. - // - anchor sweeping from his remote commitment (invalid). - ht.AssertNumPendingSweeps(bob, 2) + // Bob should still have the anchor sweeping from his local + // commitment. His anchor sweeping from his remote commitment + // is invalid and should be removed. + ht.AssertNumPendingSweeps(bob, 1) // Make sure Bob's old sweeping tx has been removed from the // mempool. @@ -266,7 +235,7 @@ func testSweepCPFPAnchorOutgoingTimeout(ht *lntest.HarnessTest) { // We expect to see two txns in the mempool, // - Bob's force close tx. // - Bob's anchor sweep tx. - ht.Miner.AssertNumTxsInMempool(2) + _, sweepTx := ht.AssertForceCloseAndAnchorTxnsInMempool() // We expect the fees to increase by i*delta. 
expectedFee := startFeeAnchor + feeDelta.MulF64(float64(i)) @@ -276,11 +245,7 @@ func testSweepCPFPAnchorOutgoingTimeout(ht *lntest.HarnessTest) { // We should see Bob's anchor sweeping tx being fee bumped // since it's not confirmed, along with his force close tx. - txns = ht.Miner.GetNumTxsFromMempool(2) - - // Find the sweeping tx. - sweepTx = ht.FindSweepingTxns(txns, 1, *closeTxid)[0] - + // // Calculate the fee rate of Bob's new sweeping tx. feeRate = uint64(ht.CalculateTxFeeRate(sweepTx)) @@ -315,17 +280,13 @@ func testSweepCPFPAnchorOutgoingTimeout(ht *lntest.HarnessTest) { // Get the last sweeping tx - we should see two txns here, Bob's anchor // sweeping tx and his force close tx. - txns = ht.Miner.GetNumTxsFromMempool(2) - - // Find the sweeping tx. - sweepTx = ht.FindSweepingTxns(txns, 1, *closeTxid)[0] + _, sweepTx = ht.AssertForceCloseAndAnchorTxnsInMempool() // Calculate the fee of Bob's new sweeping tx. fee = uint64(ht.CalculateTxFee(sweepTx)) - // Assert the budget is now used up. - require.InEpsilonf(ht, uint64(cpfpBudget), fee, 0.01, "want %d, got %d", - cpfpBudget, fee) + // Bob should have the anchor sweeping from his local commitment. + ht.AssertNumPendingSweeps(bob, 1) // Mine one more block. Since Bob's budget has been used up, there // won't be any more sweeping attempts. We now assert this by checking @@ -336,10 +297,7 @@ func testSweepCPFPAnchorOutgoingTimeout(ht *lntest.HarnessTest) { // // We expect two txns here, one for the anchor sweeping, the other for // the force close tx. - txns = ht.Miner.GetNumTxsFromMempool(2) - - // Find the sweeping tx. - currentSweepTx := ht.FindSweepingTxns(txns, 1, *closeTxid)[0] + _, currentSweepTx := ht.AssertForceCloseAndAnchorTxnsInMempool() // Assert the anchor sweep tx stays unchanged. 
require.Equal(ht, sweepTx.TxHash(), currentSweepTx.TxHash()) @@ -400,10 +358,7 @@ func testSweepCPFPAnchorIncomingTimeout(ht *lntest.HarnessTest) { // Set up the fee estimator to return the testing fee rate when the // conf target is the deadline. - // - // TODO(yy): switch to conf when `blockbeat` is in place. - // ht.SetFeeEstimateWithConf(startFeeRateAnchor, deadlineDeltaAnchor) - ht.SetFeeEstimate(startFeeRateAnchor) + ht.SetFeeEstimateWithConf(startFeeRateAnchor, deadlineDeltaAnchor) // Create a preimage, that will be held by Carol. var preimage lntypes.Preimage @@ -516,40 +471,22 @@ func testSweepCPFPAnchorIncomingTimeout(ht *lntest.HarnessTest) { numBlocks := forceCloseHeight - uint32(currentHeight) ht.MineEmptyBlocks(int(numBlocks)) - // Assert Bob's force closing tx has been broadcast. - closeTxid := ht.Miner.AssertNumTxsInMempool(1)[0] + // Assert Bob's force closing tx has been broadcast. We should see two + // txns in the mempool: + // 1. Bob's force closing tx. + // 2. Bob's anchor sweeping tx CPFPing the force close tx. + _, sweepTx := ht.AssertForceCloseAndAnchorTxnsInMempool() - // Bob should have two pending sweeps, + // Bob should have one pending sweep, // - anchor sweeping from his local commitment. - // - anchor sweeping from his remote commitment (invalid). - sweeps := ht.AssertNumPendingSweeps(bob, 2) + anchorSweep := ht.AssertNumPendingSweeps(bob, 1)[0] - // The two anchor sweeping should have the same deadline height. + // The anchor sweeping should have the expected deadline height. deadlineHeight := forceCloseHeight + deadlineDeltaAnchor - require.Equal(ht, deadlineHeight, sweeps[0].DeadlineHeight) - require.Equal(ht, deadlineHeight, sweeps[1].DeadlineHeight) + require.Equal(ht, deadlineHeight, anchorSweep.DeadlineHeight) // Remember the deadline height for the CPFP anchor. - anchorDeadline := sweeps[0].DeadlineHeight - - // Mine a block so Bob's force closing tx stays in the mempool, which - // also triggers the CPFP anchor sweep. 
- ht.MineEmptyBlocks(1) - - // Bob should still have two pending sweeps, - // - anchor sweeping from his local commitment. - // - anchor sweeping from his remote commitment (invalid). - ht.AssertNumPendingSweeps(bob, 2) - - // We now check the expected fee and fee rate are used for Bob's anchor - // sweeping tx. - // - // We should see Bob's anchor sweeping tx triggered by the above - // block, along with his force close tx. - txns := ht.Miner.GetNumTxsFromMempool(2) - - // Find the sweeping tx. - sweepTx := ht.FindSweepingTxns(txns, 1, *closeTxid)[0] + anchorDeadline := anchorSweep.DeadlineHeight // Get the weight for Bob's anchor sweeping tx. txWeight := ht.CalculateTxWeight(sweepTx) @@ -561,11 +498,10 @@ func testSweepCPFPAnchorIncomingTimeout(ht *lntest.HarnessTest) { fee := uint64(ht.CalculateTxFee(sweepTx)) feeRate := uint64(ht.CalculateTxFeeRate(sweepTx)) - // feeFuncWidth is the width of the fee function. By the time we got - // here, we've already mined one block, and the fee function maxes - // out one block before the deadline, so the width is the original - // deadline minus 2. - feeFuncWidth := deadlineDeltaAnchor - 2 + // feeFuncWidth is the width of the fee function. The fee function + // maxes out one block before the deadline, so the width is the + // original deadline minus 1. + feeFuncWidth := deadlineDeltaAnchor - 1 // Calculate the expected delta increased per block. feeDelta := (cpfpBudget - startFeeAnchor).MulF64( @@ -591,10 +527,10 @@ func testSweepCPFPAnchorIncomingTimeout(ht *lntest.HarnessTest) { // Bob's fee bumper should increase its fees. ht.MineEmptyBlocks(1) - // Bob should still have two pending sweeps, - // - anchor sweeping from his local commitment. - // - anchor sweeping from his remote commitment (invalid). - ht.AssertNumPendingSweeps(bob, 2) + // Bob should still have the anchor sweeping from his local + // commitment. His anchor sweeping from his remote commitment + // is invalid and should be removed. 
+ ht.AssertNumPendingSweeps(bob, 1) // Make sure Bob's old sweeping tx has been removed from the // mempool. @@ -603,7 +539,7 @@ func testSweepCPFPAnchorIncomingTimeout(ht *lntest.HarnessTest) { // We expect to see two txns in the mempool, // - Bob's force close tx. // - Bob's anchor sweep tx. - ht.Miner.AssertNumTxsInMempool(2) + _, sweepTx := ht.AssertForceCloseAndAnchorTxnsInMempool() // We expect the fees to increase by i*delta. expectedFee := startFeeAnchor + feeDelta.MulF64(float64(i)) @@ -611,13 +547,6 @@ func testSweepCPFPAnchorIncomingTimeout(ht *lntest.HarnessTest) { expectedFee, txWeight, ) - // We should see Bob's anchor sweeping tx being fee bumped - // since it's not confirmed, along with his force close tx. - txns = ht.Miner.GetNumTxsFromMempool(2) - - // Find the sweeping tx. - sweepTx = ht.FindSweepingTxns(txns, 1, *closeTxid)[0] - // Calculate the fee rate of Bob's new sweeping tx. feeRate = uint64(ht.CalculateTxFeeRate(sweepTx)) @@ -652,10 +581,7 @@ func testSweepCPFPAnchorIncomingTimeout(ht *lntest.HarnessTest) { // Get the last sweeping tx - we should see two txns here, Bob's anchor // sweeping tx and his force close tx. - txns = ht.Miner.GetNumTxsFromMempool(2) - - // Find the sweeping tx. - sweepTx = ht.FindSweepingTxns(txns, 1, *closeTxid)[0] + _, sweepTx = ht.AssertForceCloseAndAnchorTxnsInMempool() // Calculate the fee of Bob's new sweeping tx. fee = uint64(ht.CalculateTxFee(sweepTx)) @@ -673,10 +599,7 @@ func testSweepCPFPAnchorIncomingTimeout(ht *lntest.HarnessTest) { // // We expect two txns here, one for the anchor sweeping, the other for // the force close tx. - txns = ht.Miner.GetNumTxsFromMempool(2) - - // Find the sweeping tx. - currentSweepTx := ht.FindSweepingTxns(txns, 1, *closeTxid)[0] + _, currentSweepTx := ht.AssertForceCloseAndAnchorTxnsInMempool() // Assert the anchor sweep tx stays unchanged. 
require.Equal(ht, sweepTx.TxHash(), currentSweepTx.TxHash())
@@ -729,7 +652,7 @@ func testSweepHTLCs(ht *lntest.HarnessTest) {
	// Start tracking the deadline delta of Bob's HTLCs. We need one block
	// for the CSV lock, and another block to trigger the sweeper to sweep.
	outgoingHTLCDeadline := int32(cltvDelta - 2)
-	incomingHTLCDeadline := int32(lncfg.DefaultIncomingBroadcastDelta - 2)
+	incomingHTLCDeadline := int32(lncfg.DefaultIncomingBroadcastDelta - 3)

	// startFeeRate1 and startFeeRate2 are returned by the fee estimator in
	// sat/kw. They will be used as the starting fee rate for the linear
@@ -884,34 +807,33 @@ func testSweepHTLCs(ht *lntest.HarnessTest) {

	// Bob should now have two pending sweeps, one for the anchor on the
	// local commitment, the other on the remote commitment.
-	ht.AssertNumPendingSweeps(bob, 2)
+	ht.AssertNumPendingSweeps(bob, 1)

-	// Assert Bob's force closing tx has been broadcast.
-	ht.Miner.AssertNumTxsInMempool(1)
+	// We expect to see two txns in the mempool:
+	// 1. Bob's force closing tx.
+	// 2. Bob's anchor CPFP sweeping tx.
+	ht.Miner.AssertNumTxsInMempool(2)

-	// Mine the force close tx, which triggers Bob's contractcourt to offer
-	// his outgoing HTLC to his sweeper.
+	// Mine the force close tx and CPFP sweeping tx, which triggers Bob's
+	// contractcourt to offer his outgoing HTLC to his sweeper.
	//
	// NOTE: HTLC outputs are only offered to sweeper when the force close
	// tx is confirmed and the CSV has reached.
-	ht.MineBlocksAndAssertNumTxes(1, 1)
+	ht.MineBlocksAndAssertNumTxes(1, 2)

	// Update the blocks left till Bob force closes Alice->Bob.
	blocksTillIncomingSweep--

-	// Bob should have two pending sweeps, one for the anchor sweeping, the
-	// other for the outgoing HTLC.
-	ht.AssertNumPendingSweeps(bob, 2)
+	// Bob should have one pending sweep for the outgoing HTLC. 
+	ht.AssertNumPendingSweeps(bob, 1)

-	// Mine one block to confirm Bob's anchor sweeping tx, which will
-	// trigger his sweeper to publish the HTLC sweeping tx.
-	ht.MineBlocksAndAssertNumTxes(1, 1)
+	// Mine a block to trigger Bob's sweeper to sweep his outgoing HTLC.
+	ht.MineEmptyBlocks(1)

	// Update the blocks left till Bob force closes Alice->Bob.
	blocksTillIncomingSweep--

-	// Bob should now have one sweep and one sweeping tx in the mempool.
-	ht.AssertNumPendingSweeps(bob, 1)
+	// Bob should have one sweeping tx in the mempool.
	outgoingSweep := ht.Miner.GetNumTxsFromMempool(1)[0]

	// Check the shape of the sweeping tx - we expect it to be
@@ -1009,22 +931,25 @@ func testSweepHTLCs(ht *lntest.HarnessTest) {
	// Update Bob's fee function position.
	outgoingFuncPosition++

-	// Bob should now have three pending sweeps:
+	// Bob should now have two pending sweeps:
	// 1. the outgoing HTLC output.
	// 2. the anchor output from his local commitment.
-	// 3. the anchor output from his remote commitment.
-	ht.AssertNumPendingSweeps(bob, 3)
+	ht.AssertNumPendingSweeps(bob, 2)

-	// We should see two txns in the mempool:
+	// We should see three txns in the mempool:
	// 1. Bob's outgoing HTLC sweeping tx.
	// 2. Bob's force close tx for Alice->Bob.
-	txns := ht.Miner.GetNumTxsFromMempool(2)
+	// 3. Bob's anchor CPFP sweeping tx for Alice->Bob.
+	txns := ht.Miner.GetNumTxsFromMempool(3)

	// Find the force close tx - we expect it to have a single input.
	closeTx := txns[0]
	if len(closeTx.TxIn) != 1 {
		closeTx = txns[1]
	}
+	if len(closeTx.TxIn) != 1 {
+		closeTx = txns[2]
+	}

	// We don't care the behavior of the anchor sweep in this test, so we
	// mine the force close tx to trigger Bob's contractcourt to offer his
@@ -1034,6 +959,10 @@ func testSweepHTLCs(ht *lntest.HarnessTest) {
	// Update Bob's fee function position.
	outgoingFuncPosition++

+	// Mine a block to trigger the sweep.
+	ht.MineEmptyBlocks(1)
+	outgoingFuncPosition++
+
	// Bob should now have three pending sweeps:
	// 1. 
the outgoing HTLC output on Bob->Carol. // 2. the incoming HTLC output on Alice->Bob. @@ -1228,10 +1157,15 @@ func testSweepCommitOutputAndAnchor(ht *lntest.HarnessTest) { // config. deadline := uint32(1000) - // The actual deadline used by the fee function will be one block off - // from the deadline configured as we require one block to be mined to - // trigger the sweep. - deadlineA, deadlineB := deadline-1, deadline-1 + // For Alice, since her commit output is offered to the sweeper at + // CSV-1. With a deadline of 1000, her actual width of her fee func is + // CSV+1000-1. + deadlineA := deadline + 1 + + // For Bob, the actual deadline used by the fee function will be one + // block off from the deadline configured as we require one block to be + // mined to trigger the sweep. + deadlineB := deadline - 1 // startFeeRate is returned by the fee estimator in sat/kw. This // will be used as the starting fee rate for the linear fee func used @@ -1242,7 +1176,7 @@ func testSweepCommitOutputAndAnchor(ht *lntest.HarnessTest) { // Set up the fee estimator to return the testing fee rate when the // conf target is the deadline. - ht.SetFeeEstimateWithConf(startFeeRate, deadlineA) + ht.SetFeeEstimateWithConf(startFeeRate, deadlineB) // toLocalCSV is the CSV delay for Alice's to_local output. We use a // small value to save us from mining blocks. @@ -1250,25 +1184,7 @@ func testSweepCommitOutputAndAnchor(ht *lntest.HarnessTest) { // NOTE: once the force close tx is confirmed, we expect anchor // sweeping starts. Then two more block later the commit output // sweeping starts. - // - // NOTE: The CSV value is chosen to be 3 instead of 2, to reduce the - // possibility of flakes as there is a race between the two goroutines: - // G1 - Alice's sweeper receives the commit output. - // G2 - Alice's sweeper receives the new block mined. 
- // G1 is triggered by the same block being received by Alice's - // contractcourt, deciding the commit output is mature and offering it - // to her sweeper. Normally, we'd expect G2 to be finished before G1 - // because it's the same block processed by both contractcourt and - // sweeper. However, if G2 is delayed (maybe the sweeper is slow in - // finishing its previous round), G1 may finish before G2. This will - // cause the sweeper to add the commit output to its pending inputs, - // and once G2 fires, it will then start sweeping this output, - // resulting a valid sweep tx being created using her commit and anchor - // outputs. - // - // TODO(yy): fix the above issue by making sure subsystems share the - // same view on current block height. - toLocalCSV := 3 + toLocalCSV := 2 // htlcAmt is the amount of the HTLC in sats, this should be Alice's // to_remote amount that goes to Bob. @@ -1362,140 +1278,18 @@ func testSweepCommitOutputAndAnchor(ht *lntest.HarnessTest) { ht.AssertNumPendingSweeps(bob, 2) // Mine one more empty block should trigger Bob's sweeping. Since we - // use a CSV of 3, this means Alice's to_local output is one block away - // from being mature. + // use a CSV of 2, this means Alice's to_local output is now mature. ht.MineEmptyBlocks(1) - // We expect to see one sweeping tx in the mempool: - // - Alice's anchor sweeping tx must have been failed due to the fee - // rate chosen in this test - the anchor sweep tx has no output. - // - Bob's sweeping tx, which sweeps both his anchor and commit outputs. - bobSweepTx := ht.Miner.GetNumTxsFromMempool(1)[0] - // We expect two pending sweeps for Bob - anchor and commit outputs. - pendingSweepBob := ht.AssertNumPendingSweeps(bob, 2)[0] - - // The sweeper may be one block behind contractcourt, so we double - // check the actual deadline. - // - // TODO(yy): assert they are equal once blocks are synced via - // `blockbeat`. 
- _, currentHeight := ht.Miner.GetBestBlock() - actualDeadline := int32(pendingSweepBob.DeadlineHeight) - currentHeight - if actualDeadline != int32(deadlineB) { - ht.Logf("!!! Found unsynced block between sweeper and "+ - "contractcourt, expected deadline=%v, got=%v", - deadlineB, actualDeadline) - - deadlineB = uint32(actualDeadline) - } - - // Alice should still have one pending sweep - the anchor output. - ht.AssertNumPendingSweeps(alice, 1) - - // We now check Bob's sweeping tx. - // - // Bob's sweeping tx should have 2 inputs, one from his commit output, - // the other from his anchor output. - require.Len(ht, bobSweepTx.TxIn, 2) - - // Because Bob is sweeping without deadline pressure, the starting fee - // rate should be the min relay fee rate. - bobStartFeeRate := ht.CalculateTxFeeRate(bobSweepTx) - require.InEpsilonf(ht, uint64(chainfee.FeePerKwFloor), - uint64(bobStartFeeRate), 0.01, "want %v, got %v", - chainfee.FeePerKwFloor, bobStartFeeRate) - - // With Bob's starting fee rate being validated, we now calculate his - // ending fee rate and fee rate delta. - // - // Bob sweeps two inputs - anchor and commit, so the starting budget - // should come from the sum of these two. - bobValue := btcutil.Amount(bobToLocal + 330) - bobBudget := bobValue.MulF64(contractcourt.DefaultBudgetRatio) - - // Calculate the ending fee rate and fee rate delta used in his fee - // function. - bobTxWeight := ht.CalculateTxWeight(bobSweepTx) - bobEndingFeeRate := chainfee.NewSatPerKWeight(bobBudget, bobTxWeight) - bobFeeRateDelta := (bobEndingFeeRate - bobStartFeeRate) / - chainfee.SatPerKWeight(deadlineB-1) - - // Mine an empty block, which should trigger Alice's contractcourt to - // offer her commit output to the sweeper. - ht.MineEmptyBlocks(1) - - // Alice should have both anchor and commit as the pending sweep - // requests. 
- aliceSweeps := ht.AssertNumPendingSweeps(alice, 2) - aliceAnchor, aliceCommit := aliceSweeps[0], aliceSweeps[1] - if aliceAnchor.AmountSat > aliceCommit.AmountSat { - aliceAnchor, aliceCommit = aliceCommit, aliceAnchor - } - - // The sweeper may be one block behind contractcourt, so we double - // check the actual deadline. - // - // TODO(yy): assert they are equal once blocks are synced via - // `blockbeat`. - _, currentHeight = ht.Miner.GetBestBlock() - actualDeadline = int32(aliceCommit.DeadlineHeight) - currentHeight - if actualDeadline != int32(deadlineA) { - ht.Logf("!!! Found unsynced block between Alice's sweeper and "+ - "contractcourt, expected deadline=%v, got=%v", - deadlineA, actualDeadline) - - deadlineA = uint32(actualDeadline) - } - - // We now wait for 30 seconds to overcome the flake - there's a block - // race between contractcourt and sweeper, causing the sweep to be - // broadcast earlier. - // - // TODO(yy): remove this once `blockbeat` is in place. - aliceStartPosition := 0 - var aliceFirstSweepTx *wire.MsgTx - err := wait.NoError(func() error { - mem := ht.Miner.GetRawMempool() - if len(mem) != 2 { - return fmt.Errorf("want 2, got %v in mempool: %v", - len(mem), mem) - } - - // If there are two txns, it means Alice's sweep tx has been - // created and published. - aliceStartPosition = 1 - - txns := ht.Miner.GetNumTxsFromMempool(2) - aliceFirstSweepTx = txns[0] - - // Reassign if the second tx is larger. - if txns[1].TxOut[0].Value > aliceFirstSweepTx.TxOut[0].Value { - aliceFirstSweepTx = txns[1] - } - - return nil - }, wait.DefaultTimeout) - ht.Logf("Checking mempool got: %v", err) - - // Mine an empty block, which should trigger Alice's sweeper to publish - // her commit sweep along with her anchor output. - ht.MineEmptyBlocks(1) + ht.AssertNumPendingSweeps(bob, 2) - // If Alice has already published her initial sweep tx, the above mined - // block would trigger an RBF. We now need to assert the mempool has - // removed the replaced tx. 
- if aliceFirstSweepTx != nil { - ht.Miner.AssertTxNotInMempool(aliceFirstSweepTx.TxHash()) - } + // We expect two pending sweeps for Alice - anchor and commit outputs. + ht.AssertNumPendingSweeps(alice, 2) // We also remember the positions of fee functions used by Alice and // Bob. They will be used to calculate the expected fee rates later. - // - // Alice's sweeping tx has just been created, so she is at the starting - // position. For Bob, due to the above mined blocks, his fee function - // is now at position 2. - alicePosition, bobPosition := uint32(aliceStartPosition), uint32(2) + alicePosition, bobPosition := uint32(0), uint32(0) // We should see two txns in the mempool: // - Alice's sweeping tx, which sweeps her commit output at the @@ -1508,8 +1302,7 @@ func testSweepCommitOutputAndAnchor(ht *lntest.HarnessTest) { // Assume the first tx is Alice's sweeping tx, if the second tx has a // larger output value, then that's Alice's as her to_local value is // much gearter. - aliceSweepTx := txns[0] - bobSweepTx = txns[1] + aliceSweepTx, bobSweepTx := txns[0], txns[1] // Swap them if bobSweepTx is smaller. if bobSweepTx.TxOut[0].Value > aliceSweepTx.TxOut[0].Value { @@ -1523,20 +1316,6 @@ func testSweepCommitOutputAndAnchor(ht *lntest.HarnessTest) { require.Len(ht, aliceSweepTx.TxIn, 1) require.Len(ht, aliceSweepTx.TxOut, 1) - // We now check Alice's sweeping tx to see if it's already published. - // - // TODO(yy): remove this check once we have better block control. - aliceSweeps = ht.AssertNumPendingSweeps(alice, 2) - aliceCommit = aliceSweeps[0] - if aliceCommit.AmountSat < aliceSweeps[1].AmountSat { - aliceCommit = aliceSweeps[1] - } - if aliceCommit.BroadcastAttempts > 1 { - ht.Logf("!!! Alice's commit sweep has already been broadcast, "+ - "broadcast_attempts=%v", aliceCommit.BroadcastAttempts) - alicePosition = aliceCommit.BroadcastAttempts - } - // Alice's sweeping tx should use the min relay fee rate as there's no // deadline pressure. 
aliceStartingFeeRate := chainfee.FeePerKwFloor @@ -1551,7 +1330,7 @@ func testSweepCommitOutputAndAnchor(ht *lntest.HarnessTest) { aliceTxWeight := uint64(ht.CalculateTxWeight(aliceSweepTx)) aliceEndingFeeRate := sweep.DefaultMaxFeeRate.FeePerKWeight() aliceFeeRateDelta := (aliceEndingFeeRate - aliceStartingFeeRate) / - chainfee.SatPerKWeight(deadlineA-1) + chainfee.SatPerKWeight(deadlineA) aliceFeeRate := ht.CalculateTxFeeRate(aliceSweepTx) expectedFeeRateAlice := aliceStartingFeeRate + @@ -1560,111 +1339,35 @@ func testSweepCommitOutputAndAnchor(ht *lntest.HarnessTest) { uint64(aliceFeeRate), 0.02, "want %v, got %v", expectedFeeRateAlice, aliceFeeRate) - // We now check Bob' sweeping tx. - // - // The above mined block will trigger Bob's sweeper to RBF his previous - // sweeping tx, which will fail due to RBF rule#4 - the additional fees - // paid are not sufficient. This happens as our default incremental - // relay fee rate is 1 sat/vb, with the tx size of 771 weight units, or - // 192 vbytes, we need to pay at least 192 sats more to be able to RBF. - // However, since Bob's budget delta is (100_000 + 330) * 0.5 / 1008 = - // 49.77 sats, it means Bob can only perform a successful RBF every 4 - // blocks. - // - // Assert Bob's sweeping tx is not RBFed. - bobFeeRate := ht.CalculateTxFeeRate(bobSweepTx) - expectedFeeRateBob := bobStartFeeRate - require.InEpsilonf(ht, uint64(expectedFeeRateBob), uint64(bobFeeRate), - 0.01, "want %d, got %d", expectedFeeRateBob, bobFeeRate) - - // reloclateAlicePosition is a temp hack to find the actual fee - // function position used for Alice. Due to block sync issue among the - // subsystems, we can end up having this situation: - // - sweeper is at block 2, starts sweeping an input with deadline 100. - // - fee bumper is at block 1, and thinks the conf target is 99. - // - new block 3 arrives, the func now is at position 2. + // We now check Bob's sweeping tx. // - // TODO(yy): fix it using `blockbeat`. 
- reloclateAlicePosition := func() { - // Mine an empty block to trigger the possible RBF attempts. - ht.MineEmptyBlocks(1) - - // Increase the positions for both fee functions. - alicePosition++ - bobPosition++ - - // We expect two pending sweeps for both nodes as we are mining - // empty blocks. - ht.AssertNumPendingSweeps(alice, 2) - ht.AssertNumPendingSweeps(bob, 2) - - // We expect to see both Alice's and Bob's sweeping txns in the - // mempool. - ht.Miner.AssertNumTxsInMempool(2) - - // Make sure Alice's old sweeping tx has been removed from the - // mempool. - ht.Miner.AssertTxNotInMempool(aliceSweepTx.TxHash()) - - // We should see two txns in the mempool: - // - Alice's sweeping tx, which sweeps both her anchor and - // commit outputs, using the increased fee rate. - // - Bob's previous sweeping tx, which sweeps both his anchor - // and commit outputs, at the possible increased fee rate. - txns = ht.Miner.GetNumTxsFromMempool(2) - - // Assume the first tx is Alice's sweeping tx, if the second tx - // has a larger output value, then that's Alice's as her - // to_local value is much gearter. - aliceSweepTx = txns[0] - bobSweepTx = txns[1] - - // Swap them if bobSweepTx is smaller. - if bobSweepTx.TxOut[0].Value > aliceSweepTx.TxOut[0].Value { - aliceSweepTx, bobSweepTx = bobSweepTx, aliceSweepTx - } - - // Alice's sweeping tx should be increased. - aliceFeeRate := ht.CalculateTxFeeRate(aliceSweepTx) - expectedFeeRate := aliceStartingFeeRate + - aliceFeeRateDelta*chainfee.SatPerKWeight(alicePosition) - - ht.Logf("Alice(deadline=%v): txWeight=%v, want feerate=%v, "+ - "got feerate=%v, delta=%v", deadlineA-alicePosition, - aliceTxWeight, expectedFeeRate, aliceFeeRate, - aliceFeeRateDelta) - - nextPosition := alicePosition + 1 - nextFeeRate := aliceStartingFeeRate + - aliceFeeRateDelta*chainfee.SatPerKWeight(nextPosition) - - // Calculate the distances. 
- delta := math.Abs(float64(aliceFeeRate - expectedFeeRate)) - deltaNext := math.Abs(float64(aliceFeeRate - nextFeeRate)) - - // Exit early if the first distance is smaller - it means we - // are at the right fee func position. - if delta < deltaNext { - require.InEpsilonf(ht, uint64(expectedFeeRate), - uint64(aliceFeeRate), 0.02, "want %v, got %v "+ - "in tx=%v", expectedFeeRate, - aliceFeeRate, aliceSweepTx.TxHash()) + // Bob's sweeping tx should have 2 inputs, one from his commit output, + // the other from his anchor output. + require.Len(ht, bobSweepTx.TxIn, 2) - return - } + // Because Bob is sweeping without deadline pressure, the starting fee + // rate should be the min relay fee rate. + bobStartFeeRate := ht.CalculateTxFeeRate(bobSweepTx) + require.InEpsilonf(ht, uint64(chainfee.FeePerKwFloor), + uint64(bobStartFeeRate), 0.01, "want %v, got %v", + chainfee.FeePerKwFloor, bobStartFeeRate) - alicePosition++ - ht.Logf("Jump position for Alice(deadline=%v): txWeight=%v, "+ - "want feerate=%v, got feerate=%v, delta=%v", - deadlineA-alicePosition, aliceTxWeight, nextFeeRate, - aliceFeeRate, aliceFeeRateDelta) + // With Bob's starting fee rate being validated, we now calculate his + // ending fee rate and fee rate delta. + // + // Bob sweeps two inputs - anchor and commit, so the starting budget + // should come from the sum of these two. + bobValue := btcutil.Amount(bobToLocal + 330) + bobBudget := bobValue.MulF64(contractcourt.DefaultBudgetRatio) - require.InEpsilonf(ht, uint64(nextFeeRate), - uint64(aliceFeeRate), 0.02, "want %v, got %v in tx=%v", - nextFeeRate, aliceFeeRate, aliceSweepTx.TxHash()) - } + // Calculate the ending fee rate and fee rate delta used in his fee + // function. 
+ bobTxWeight := ht.CalculateTxWeight(bobSweepTx) + bobEndingFeeRate := chainfee.NewSatPerKWeight(bobBudget, bobTxWeight) + bobFeeRateDelta := (bobEndingFeeRate - bobStartFeeRate) / + chainfee.SatPerKWeight(deadlineB-1) - reloclateAlicePosition() + expectedFeeRateBob := bobStartFeeRate // We now mine 7 empty blocks. For each block mined, we'd see Alice's // sweeping tx being RBFed. For Bob, he performs a fee bump every @@ -1672,7 +1375,7 @@ func testSweepCommitOutputAndAnchor(ht *lntest.HarnessTest) { // the fee bumps is not sufficient to meet the fee requirements // enforced by RBF. Since his fee function is already at position 1, // mining 7 more blocks means he will RBF his sweeping tx twice. - for i := 1; i < 7; i++ { + for i := 1; i < 8; i++ { // Mine an empty block to trigger the possible RBF attempts. ht.MineEmptyBlocks(1) diff --git a/lntest/harness_assertion.go b/lntest/harness_assertion.go index f8c1c9716cb..81dc26e4196 100644 --- a/lntest/harness_assertion.go +++ b/lntest/harness_assertion.go @@ -2633,3 +2633,36 @@ func (h *HarnessTest) FindSweepingTxns(txns []*wire.MsgTx, return sweepTxns } + +// AssertForceCloseAndAnchorTxnsInMempool asserts that the force close and +// anchor sweep txns are found in the mempool and returns the force close tx +// and the anchor sweep tx. +func (h *HarnessTest) AssertForceCloseAndAnchorTxnsInMempool() (*wire.MsgTx, + *wire.MsgTx) { + + // Assert there are two txns in the mempool. + txns := h.Miner.GetNumTxsFromMempool(2) + + // Assume the first is the force close tx. + forceCloseTx, anchorSweepTx := txns[0], txns[1] + + // Get the txid. + closeTxid := forceCloseTx.TxHash() + + // We now check whether there is an anchor input used in the assumed + // anchorSweepTx by checking every input's previous outpoint against + // the assumed closingTxid. If we fail to find one, it means the first + // item from the above txns is the anchor sweeping tx. 
+ for _, inp := range anchorSweepTx.TxIn { + if inp.PreviousOutPoint.Hash == closeTxid { + // Found a match, this is indeed the anchor sweeping tx + // so we return it here. + return forceCloseTx, anchorSweepTx + } + } + + // The assumed order is incorrect so we swap and return. + forceCloseTx, anchorSweepTx = anchorSweepTx, forceCloseTx + + return forceCloseTx, anchorSweepTx +} diff --git a/log.go b/log.go index f6da0235a92..fc3238b509e 100644 --- a/log.go +++ b/log.go @@ -7,6 +7,7 @@ import ( sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/autopilot" "github.com/lightningnetwork/lnd/build" + "github.com/lightningnetwork/lnd/chainio" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/chainreg" "github.com/lightningnetwork/lnd/chanacceptor" @@ -164,6 +165,7 @@ func SetupLoggers(root *build.RotatingLogWriter, interceptor signal.Interceptor) AddSubLogger(root, "CHFD", interceptor, chanfunding.UseLogger) AddSubLogger(root, "PEER", interceptor, peer.UseLogger) AddSubLogger(root, "CHCL", interceptor, chancloser.UseLogger) + AddSubLogger(root, "CHIO", interceptor, chainio.UseLogger) AddSubLogger(root, routing.Subsystem, interceptor, routing.UseLogger) AddSubLogger(root, routerrpc.Subsystem, interceptor, routerrpc.UseLogger) diff --git a/server.go b/server.go index 2b54f81d6a8..49db7435072 100644 --- a/server.go +++ b/server.go @@ -27,6 +27,7 @@ import ( "github.com/lightningnetwork/lnd/aliasmgr" "github.com/lightningnetwork/lnd/autopilot" "github.com/lightningnetwork/lnd/brontide" + "github.com/lightningnetwork/lnd/chainio" "github.com/lightningnetwork/lnd/chainreg" "github.com/lightningnetwork/lnd/chanacceptor" "github.com/lightningnetwork/lnd/chanbackup" @@ -330,6 +331,10 @@ type server struct { // txPublisher is a publisher with fee-bumping capability. txPublisher *sweep.TxPublisher + // blockbeat is a block dispatcher that notifies subscribers of new + // blocks. 
+ blockbeat *chainio.BlockBeat + quit chan struct{} wg sync.WaitGroup @@ -568,6 +573,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, s := &server{ cfg: cfg, + blockbeat: chainio.NewBlockBeat(cc.ChainNotifier), graphDB: dbs.GraphDB.ChannelGraph(), chanStateDB: dbs.ChanStateDB.ChannelStateDB(), addrSource: dbs.ChanStateDB, @@ -1677,9 +1683,31 @@ func newServer(cfg *Config, listenAddrs []net.Addr, } s.connMgr = cmgr + // Finally, register the subsystems in blockbeat. + s.registerBlockConsumers() + return s, nil } +// registerBlockConsumers registers the subsystems that consume block events. +// By calling `RegisterQueue`, a list of subsystems are registered in the +// blockbeat for block notifications. When a new block arrives, the subsystems +// in the same queue are notified sequentially, and different queues are +// notified concurrently. +// +// NOTE: To put a subsystem in a different queue, create a slice and pass it to +// a new `RegisterQueue` call. +func (s *server) registerBlockConsumers() { + // In this queue, when a new block arrives, it will be received and + // processed in this order: chainArb -> sweeper -> txPublisher. + consumers := []chainio.Consumer{ + s.chainArb, + s.sweeper, + s.txPublisher, + } + s.blockbeat.RegisterQueue(consumers) +} + // signAliasUpdate takes a ChannelUpdate and returns the signature. This is // used for option_scid_alias channels where the ChannelUpdate to be sent back // may differ from what is on disk. @@ -2110,6 +2138,17 @@ func (s *server) Start() error { return nil }) + // Start the blockbeat after all other subsystems have been + // started so they are ready to receive new blocks. + if err := s.blockbeat.Start(); err != nil { + startErr = err + return + } + cleanup = cleanup.add(func() error { + s.blockbeat.Stop() + return nil + }) + // If peers are specified as a config option, we'll add those // peers first. 
for _, peerAddrCfg := range s.cfg.AddPeers { @@ -2297,6 +2336,7 @@ func (s *server) Stop() error { } s.txPublisher.Stop() + s.blockbeat.Stop() if err := s.channelNotifier.Stop(); err != nil { srvrLog.Warnf("failed to stop channelNotifier: %v", err) diff --git a/sweep/fee_bumper.go b/sweep/fee_bumper.go index b4d4298943b..6bce81b3b61 100644 --- a/sweep/fee_bumper.go +++ b/sweep/fee_bumper.go @@ -13,6 +13,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcwallet/chain" "github.com/davecgh/go-spew/spew" + "github.com/lightningnetwork/lnd/chainio" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" @@ -52,7 +53,7 @@ type Bumper interface { // and monitors its confirmation status for potential fee bumping. It // returns a chan that the caller can use to receive updates about the // broadcast result and potential RBF attempts. - Broadcast(req *BumpRequest) (<-chan *BumpResult, error) + Broadcast(req *BumpRequest) <-chan *BumpResult } // BumpEvent represents the event of a fee bumping attempt. @@ -71,6 +72,9 @@ const ( // TxConfirmed is sent when the tx is confirmed. TxConfirmed + // TxError is sent when there's an error creating the tx. + TxError + // sentinalEvent is used to check if an event is unknown. sentinalEvent ) @@ -86,6 +90,8 @@ func (e BumpEvent) String() string { return "Replaced" case TxConfirmed: return "Confirmed" + case TxError: + return "Error" default: return "Unknown" } @@ -204,6 +210,16 @@ type BumpResult struct { requestID uint64 } +// String returns a human-readable string for the result. +func (b *BumpResult) String() string { + desc := fmt.Sprintf("Event=%v", b.Event) + if b.Tx != nil { + desc += fmt.Sprintf(", Tx=%v", b.Tx.TxHash()) + } + + return fmt.Sprintf("[%s]", desc) +} + // Validate validates the BumpResult so it's safe to use. func (b *BumpResult) Validate() error { // Every result must have a tx. 
@@ -282,6 +298,11 @@ type TxPublisher struct { // the chan that the publisher sends the fee bump result to. subscriberChans lnutils.SyncMap[uint64, chan *BumpResult] + // blockBeatChan is a channel to receive blocks from BlockBeat. The + // received block contains the best known height and the transactions + // confirmed in this block. + blockBeatChan chan chainio.Beat + // quit is used to signal the publisher to stop. quit chan struct{} } @@ -296,6 +317,7 @@ func NewTxPublisher(cfg TxPublisherConfig) *TxPublisher { records: lnutils.SyncMap[uint64, *monitorRecord]{}, subscriberChans: lnutils.SyncMap[uint64, chan *BumpResult]{}, quit: make(chan struct{}), + blockBeatChan: make(chan chainio.Beat), } } @@ -304,42 +326,57 @@ func (t *TxPublisher) isNeutrinoBackend() bool { return t.cfg.Wallet.BackEnd() == "neutrino" } -// Broadcast is used to publish the tx created from the given inputs. It will, -// 1. init a fee function based on the given strategy. -// 2. create an RBF-compliant tx and monitor it for confirmation. -// 3. notify the initial broadcast result back to the caller. -// The initial broadcast is guaranteed to be RBF-compliant unless the budget -// specified cannot cover the fee. +// Broadcast is used to publish the tx created from the given inputs. It will +// register the broadcast request and return a chan to the caller to subscribe +// the broadcast result. The initial broadcast is guaranteed to be +// RBF-compliant unless the budget specified cannot cover the fee. // // NOTE: part of the Bumper interface. -func (t *TxPublisher) Broadcast(req *BumpRequest) (<-chan *BumpResult, error) { +func (t *TxPublisher) Broadcast(req *BumpRequest) <-chan *BumpResult { log.Tracef("Received broadcast request: %s", newLogClosure( func() string { return spew.Sdump(req) })()) - // Attempt an initial broadcast which is guaranteed to comply with the - // RBF rules. 
- result, err := t.initialBroadcast(req) - if err != nil { - log.Errorf("Initial broadcast failed: %v", err) + // Increase the request counter. + // + // NOTE: this is the only place where we increase the counter. + requestID := t.requestCounter.Add(1) - return nil, err - } + // Register the record. + t.records.Store(requestID, &monitorRecord{req: req}) // Create a chan to send the result to the caller. subscriber := make(chan *BumpResult, 1) - t.subscriberChans.Store(result.requestID, subscriber) + t.subscriberChans.Store(requestID, subscriber) - // Send the initial broadcast result to the caller. - t.handleResult(result) + return subscriber +} - return subscriber, nil +// NOTE: part of the `chainio.Consumer` interface. +func (t *TxPublisher) ProcessBlock(beat chainio.Beat) <-chan error { + select { + case t.blockBeatChan <- beat: + log.Tracef("TxPublisher received block beat for height=%d", + beat.Epoch.Height) + + case <-t.quit: + return nil + } + + return beat.Err +} + +// NOTE: part of the `chainio.Consumer` interface. +func (t *TxPublisher) Name() string { + return "tx publisher" } // initialBroadcast initializes a fee function, creates an RBF-compliant tx and // broadcasts it. -func (t *TxPublisher) initialBroadcast(req *BumpRequest) (*BumpResult, error) { +func (t *TxPublisher) initialBroadcast(requestID uint64, + req *BumpRequest) (*BumpResult, error) { + // Create a fee bumping algorithm to be used for future RBF. feeAlgo, err := t.initializeFeeFunction(req) if err != nil { @@ -348,7 +385,7 @@ func (t *TxPublisher) initialBroadcast(req *BumpRequest) (*BumpResult, error) { // Create the initial tx to be broadcasted. This tx is guaranteed to // comply with the RBF restrictions. 
- requestID, err := t.createRBFCompliantTx(req, feeAlgo) + err = t.createRBFCompliantTx(requestID, req, feeAlgo) if err != nil { return nil, fmt.Errorf("create RBF-compliant tx: %w", err) } @@ -395,8 +432,8 @@ func (t *TxPublisher) initializeFeeFunction( // so by creating a tx, validate it using `TestMempoolAccept`, and bump its fee // and redo the process until the tx is valid, or return an error when non-RBF // related errors occur or the budget has been used up. -func (t *TxPublisher) createRBFCompliantTx(req *BumpRequest, - f FeeFunction) (uint64, error) { +func (t *TxPublisher) createRBFCompliantTx(requestID uint64, req *BumpRequest, + f FeeFunction) error { for { // Create a new tx with the given fee rate and check its @@ -405,15 +442,15 @@ func (t *TxPublisher) createRBFCompliantTx(req *BumpRequest, switch { case err == nil: - // The tx is valid, return the request ID. - requestID := t.storeRecord(tx, req, f, fee) + // The tx is valid, store it. + t.storeRecord(requestID, tx, req, f, fee) - log.Infof("Created tx %v for %v inputs: feerate=%v, "+ - "fee=%v, inputs=%v", tx.TxHash(), + log.Infof("Created initial sweep tx=%v for %v inputs: "+ + "feerate=%v, fee=%v, inputs:\n%v", tx.TxHash(), len(req.Inputs), f.FeeRate(), fee, inputTypeSummary(req.Inputs)) - return requestID, nil + return nil // If the error indicates the fees paid is not enough, we will // ask the fee function to increase the fee rate and retry. @@ -444,7 +481,7 @@ func (t *TxPublisher) createRBFCompliantTx(req *BumpRequest, // cluster these inputs differetly. increased, err = f.Increment() if err != nil { - return 0, err + return err } } @@ -454,20 +491,14 @@ func (t *TxPublisher) createRBFCompliantTx(req *BumpRequest, // mempool acceptance. default: log.Debugf("Failed to create RBF-compliant tx: %v", err) - return 0, err + return err } } } // storeRecord stores the given record in the records map. 
-func (t *TxPublisher) storeRecord(tx *wire.MsgTx, req *BumpRequest, - f FeeFunction, fee btcutil.Amount) uint64 { - - // Increase the request counter. - // - // NOTE: this is the only place where we increase the - // counter. - requestID := t.requestCounter.Add(1) +func (t *TxPublisher) storeRecord(requestID uint64, tx *wire.MsgTx, + req *BumpRequest, f FeeFunction, fee btcutil.Amount) { // Register the record. t.records.Store(requestID, &monitorRecord{ @@ -476,8 +507,6 @@ func (t *TxPublisher) storeRecord(tx *wire.MsgTx, req *BumpRequest, feeFunction: f, fee: fee, }) - - return requestID } // createAndCheckTx creates a tx based on the given inputs, change output @@ -592,8 +621,7 @@ func (t *TxPublisher) notifyResult(result *BumpResult) { return } - log.Debugf("Sending result for requestID=%v, tx=%v", id, - result.Tx.TxHash()) + log.Debugf("Sending result %v for requestID=%v", result, id) select { // Send the result to the subscriber. @@ -622,10 +650,14 @@ func (t *TxPublisher) removeResult(result *BumpResult) { id, result.Tx.TxHash(), result.Err) case TxConfirmed: - // Remove the record is the tx is confirmed. + // Remove the record if the tx is confirmed. log.Debugf("Removing confirmed monitor record=%v, tx=%v", id, result.Tx.TxHash()) + case TxError: + // Remove the record if there's an error. + log.Debugf("Removing monitor record=%v due to error", id) + // Do nothing if it's neither failed or confirmed. default: log.Tracef("Skipping record removal for id=%v, event=%v", id, @@ -671,13 +703,8 @@ func (t *TxPublisher) Start() error { log.Info("TxPublisher starting...") defer log.Debugf("TxPublisher started") - blockEvent, err := t.cfg.Notifier.RegisterBlockEpochNtfn(nil) - if err != nil { - return fmt.Errorf("register block epoch ntfn: %w", err) - } - t.wg.Add(1) - go t.monitor(blockEvent) + go t.monitor() return nil } @@ -697,13 +724,12 @@ func (t *TxPublisher) Stop() { // to be bumped. If so, it will attempt to bump the fee of the tx. 
 //
 // NOTE: Must be run as a goroutine.
-func (t *TxPublisher) monitor(blockEvent *chainntnfs.BlockEpochEvent) {
-	defer blockEvent.Cancel()
+func (t *TxPublisher) monitor() {
 	defer t.wg.Done()
 
 	for {
 		select {
-		case epoch, ok := <-blockEvent.Epochs:
+		case beat, ok := <-t.blockBeatChan:
 			if !ok {
 				// We should stop the publisher before stopping
 				// the chain service. Otherwise it indicates an
@@ -714,6 +740,7 @@
 				return
 			}
 
+			epoch := beat.Epoch
 			log.Debugf("TxPublisher received new block: %v",
 				epoch.Height)
 
@@ -724,6 +751,9 @@
 			// to be bumped.
 			t.processRecords()
 
+			// Notify we've processed the block.
+			fn.SendOrQuit(beat.Err, nil, t.quit)
+
 		case <-t.quit:
 			log.Debug("Fee bumper stopped, exit monitor")
 			return
@@ -738,18 +768,27 @@ func (t *TxPublisher) processRecords() {
 	// confirmed.
 	confirmedRecords := make(map[uint64]*monitorRecord)
 
-	// feeBumpRecords stores a map of the records which need to be bumped.
+	// feeBumpRecords stores a map of records which need to be bumped.
 	feeBumpRecords := make(map[uint64]*monitorRecord)
 
-	// failedRecords stores a map of the records which has inputs being
-	// spent by a third party.
+	// failedRecords stores a map of records which have inputs being spent
+	// by a third party.
 	//
 	// NOTE: this is only used for neutrino backend.
 	failedRecords := make(map[uint64]*monitorRecord)
 
+	// initialRecords stores a map of records which are being created and
+	// published for the first time.
+	initialRecords := make(map[uint64]*monitorRecord)
+
 	// visitor is a helper closure that visits each record and divides them
 	// into two groups.
 	visitor := func(requestID uint64, r *monitorRecord) error {
+		if r.tx == nil {
+			initialRecords[requestID] = r
+			return nil
+		}
+
 		log.Tracef("Checking monitor recordID=%v for tx=%v", requestID,
 			r.tx.TxHash())
 
@@ -777,9 +816,18 @@
 		return nil
 	}
 
-	// Iterate through all the records and divide them into two groups.
+	// Iterate through all the records and divide them into four groups.
 	t.records.ForEach(visitor)
 
+	// Handle the initial broadcast.
+	for requestID, r := range initialRecords {
+		rec := r
+
+		log.Debugf("Initial broadcast for requestID=%v", requestID)
+		t.wg.Add(1)
+		go t.handleInitialBroadcast(rec, requestID)
+	}
+
 	// For records that are confirmed, we'll notify the caller about this
 	// result.
 	for requestID, r := range confirmedRecords {
@@ -835,6 +883,74 @@ func (t *TxPublisher) handleTxConfirmed(r *monitorRecord, requestID uint64) {
 	t.handleResult(result)
 }
 
+// handleInitialBroadcast is called when a new request is received. It will
+// handle the initial tx creation and broadcast. In detail,
+// 1. init a fee function based on the given strategy.
+// 2. create an RBF-compliant tx and monitor it for confirmation.
+// 3. notify the initial broadcast result back to the caller.
+//
+// NOTE: Must be run as a goroutine to avoid blocking on sending the result.
+func (t *TxPublisher) handleInitialBroadcast(r *monitorRecord,
+	requestID uint64) {
+
+	defer t.wg.Done()
+
+	var (
+		result *BumpResult
+		err    error
+	)
+
+	// Attempt an initial broadcast which is guaranteed to comply with the
+	// RBF rules.
+	result, err = t.initialBroadcast(requestID, r.req)
+	if err != nil {
+		log.Errorf("Initial broadcast failed: %v", err)
+
+		// Create a tx so the caller knows which inputs are failed.
+		sweepTx := wire.NewMsgTx(2)
+		for _, o := range r.req.Inputs {
+			sweepTx.AddTxIn(&wire.TxIn{
+				PreviousOutPoint: o.OutPoint(),
+				Sequence:         o.BlocksToMaturity(),
+			})
+		}
+
+		// We now decide what type of event to send.
+ var event BumpEvent + + switch { + // When the error is due to a dust output, we'll send a + // TxFailed so these inputs can be retried with a different + // group in the next block. + case errors.Is(err, ErrTxNoOutput): + event = TxFailed + + // When the error is due to budget being used up, we'll send a + // TxFailed so these inputs can be retried with a different + // group in the next block. + case errors.Is(err, ErrMaxPosition): + event = TxFailed + + // When the error is due to zero fee rate delta, we'll send a + // TxFailed so these inputs can be retried in the next block. + case errors.Is(err, ErrZeroFeeRateDelta): + event = TxFailed + + default: + event = TxError + } + + result = &BumpResult{ + Event: event, + Err: err, + requestID: requestID, + Tx: sweepTx, + } + } + + t.handleResult(result) +} + // handleFeeBumpTx checks if the tx needs to be bumped, and if so, it will // attempt to bump the fee of the tx. // diff --git a/sweep/fee_bumper_test.go b/sweep/fee_bumper_test.go index 63a828654d4..85d883f64da 100644 --- a/sweep/fee_bumper_test.go +++ b/sweep/fee_bumper_test.go @@ -310,13 +310,10 @@ func TestStoreRecord(t *testing.T) { initialCounter := tp.requestCounter.Load() // Call the method under test. - requestID := tp.storeRecord(tx, req, feeFunc, fee) - - // Check the request ID is as expected. - require.Equal(t, initialCounter+1, requestID) + tp.storeRecord(initialCounter, tx, req, feeFunc, fee) // Read the saved record and compare. - record, ok := tp.records.Load(requestID) + record, ok := tp.records.Load(initialCounter) require.True(t, ok) require.Equal(t, tx, record.tx) require.Equal(t, feeFunc, record.feeFunction) @@ -611,23 +608,20 @@ func TestCreateRBFCompliantTx(t *testing.T) { }, } + requestCounter := atomic.Uint64{} + for _, tc := range testCases { tc := tc + rid := requestCounter.Add(1) t.Run(tc.name, func(t *testing.T) { tc.setupMock() // Call the method under test. 
- id, err := tp.createRBFCompliantTx(req, m.feeFunc) + err := tp.createRBFCompliantTx(rid, req, m.feeFunc) // Check the result is as expected. require.ErrorIs(t, err, tc.expectedErr) - - // If there's an error, expect the requestID to be - // empty. - if tc.expectedErr != nil { - require.Zero(t, id) - } }) } } @@ -652,7 +646,8 @@ func TestTxPublisherBroadcast(t *testing.T) { // Create a testing record and put it in the map. fee := btcutil.Amount(1000) - requestID := tp.storeRecord(tx, req, m.feeFunc, fee) + requestID := uint64(1) + tp.storeRecord(requestID, tx, req, m.feeFunc, fee) // Quickly check when the requestID cannot be found, an error is // returned. @@ -739,6 +734,9 @@ func TestRemoveResult(t *testing.T) { // Create a testing record and put it in the map. fee := btcutil.Amount(1000) + // Create a test request ID counter. + requestCounter := atomic.Uint64{} + testCases := []struct { name string setupRecord func() uint64 @@ -750,10 +748,11 @@ func TestRemoveResult(t *testing.T) { // removed. name: "remove on TxConfirmed", setupRecord: func() uint64 { - id := tp.storeRecord(tx, req, m.feeFunc, fee) - tp.subscriberChans.Store(id, nil) + rid := requestCounter.Add(1) + tp.storeRecord(rid, tx, req, m.feeFunc, fee) + tp.subscriberChans.Store(rid, nil) - return id + return rid }, result: &BumpResult{ Event: TxConfirmed, @@ -765,10 +764,11 @@ func TestRemoveResult(t *testing.T) { // When the tx is failed, the records will be removed. name: "remove on TxFailed", setupRecord: func() uint64 { - id := tp.storeRecord(tx, req, m.feeFunc, fee) - tp.subscriberChans.Store(id, nil) + rid := requestCounter.Add(1) + tp.storeRecord(rid, tx, req, m.feeFunc, fee) + tp.subscriberChans.Store(rid, nil) - return id + return rid }, result: &BumpResult{ Event: TxFailed, @@ -781,10 +781,11 @@ func TestRemoveResult(t *testing.T) { // Noop when the tx is neither confirmed or failed. 
name: "noop when tx is not confirmed or failed", setupRecord: func() uint64 { - id := tp.storeRecord(tx, req, m.feeFunc, fee) - tp.subscriberChans.Store(id, nil) + rid := requestCounter.Add(1) + tp.storeRecord(rid, tx, req, m.feeFunc, fee) + tp.subscriberChans.Store(rid, nil) - return id + return rid }, result: &BumpResult{ Event: TxPublished, @@ -831,7 +832,8 @@ func TestNotifyResult(t *testing.T) { // Create a testing record and put it in the map. fee := btcutil.Amount(1000) - requestID := tp.storeRecord(tx, req, m.feeFunc, fee) + requestID := uint64(1) + tp.storeRecord(requestID, tx, req, m.feeFunc, fee) // Create a subscription to the event. subscriber := make(chan *BumpResult, 1) @@ -879,41 +881,17 @@ func TestNotifyResult(t *testing.T) { } } -// TestBroadcastSuccess checks the public `Broadcast` method can successfully -// broadcast a tx based on the request. -func TestBroadcastSuccess(t *testing.T) { +// TestBroadcast checks the public `Broadcast` method can successfully register +// a broadcast request. +func TestBroadcast(t *testing.T) { t.Parallel() // Create a publisher using the mocks. - tp, m := createTestPublisher(t) + tp, _ := createTestPublisher(t) // Create a test feerate. feerate := chainfee.SatPerKWeight(1000) - // Mock the fee estimator to return the testing fee rate. - // - // We are not testing `NewLinearFeeFunction` here, so the actual params - // used are irrelevant. - m.estimator.On("EstimateFeePerKW", mock.Anything).Return( - feerate, nil).Once() - m.estimator.On("RelayFeePerKW").Return(chainfee.FeePerKwFloor).Once() - - // Mock the signer to always return a valid script. - // - // NOTE: we are not testing the utility of creating valid txes here, so - // this is fine to be mocked. This behaves essentially as skipping the - // Signer check and alaways assume the tx has a valid sig. - script := &input.Script{} - m.signer.On("ComputeInputScript", mock.Anything, - mock.Anything).Return(script, nil) - - // Mock the testmempoolaccept to pass. 
- m.wallet.On("CheckMempoolAcceptance", mock.Anything).Return(nil).Once() - - // Mock the wallet to publish successfully. - m.wallet.On("PublishTransaction", - mock.Anything, mock.Anything).Return(nil).Once() - // Create a test request. inp := createTestInput(1000, input.WitnessKeyHash) @@ -927,103 +905,18 @@ func TestBroadcastSuccess(t *testing.T) { } // Send the req and expect no error. - resultChan, err := tp.Broadcast(req) - require.NoError(t, err) - - // Check the result is sent back. - select { - case <-time.After(time.Second): - t.Fatal("timeout waiting for subscriber to receive result") - - case result := <-resultChan: - // We expect the first result to be TxPublished. - require.Equal(t, TxPublished, result.Event) - } + resultChan := tp.Broadcast(req) + require.NotNil(t, resultChan) // Validate the record was stored. require.Equal(t, 1, tp.records.Len()) require.Equal(t, 1, tp.subscriberChans.Len()) -} - -// TestBroadcastFail checks the public `Broadcast` returns the error or a -// failed result when the broadcast fails. -func TestBroadcastFail(t *testing.T) { - t.Parallel() - - // Create a publisher using the mocks. - tp, m := createTestPublisher(t) - - // Create a test feerate. - feerate := chainfee.SatPerKWeight(1000) - - // Create a test request. - inp := createTestInput(1000, input.WitnessKeyHash) - - // Create a testing bump request. - req := &BumpRequest{ - DeliveryAddress: changePkScript, - Inputs: []input.Input{&inp}, - Budget: btcutil.Amount(1000), - MaxFeeRate: feerate * 10, - DeadlineHeight: 10, - } - - // Mock the fee estimator to return the testing fee rate. - // - // We are not testing `NewLinearFeeFunction` here, so the actual params - // used are irrelevant. - m.estimator.On("EstimateFeePerKW", mock.Anything).Return( - feerate, nil).Twice() - m.estimator.On("RelayFeePerKW").Return(chainfee.FeePerKwFloor).Twice() - - // Mock the signer to always return a valid script. 
- // - // NOTE: we are not testing the utility of creating valid txes here, so - // this is fine to be mocked. This behaves essentially as skipping the - // Signer check and alaways assume the tx has a valid sig. - script := &input.Script{} - m.signer.On("ComputeInputScript", mock.Anything, - mock.Anything).Return(script, nil) - - // Mock the testmempoolaccept to return an error. - m.wallet.On("CheckMempoolAcceptance", - mock.Anything).Return(errDummy).Once() - - // Send the req and expect an error returned. - resultChan, err := tp.Broadcast(req) - require.ErrorIs(t, err, errDummy) - require.Nil(t, resultChan) - - // Validate the record was NOT stored. - require.Equal(t, 0, tp.records.Len()) - require.Equal(t, 0, tp.subscriberChans.Len()) - - // Mock the testmempoolaccept again, this time it passes. - m.wallet.On("CheckMempoolAcceptance", mock.Anything).Return(nil).Once() - - // Mock the wallet to fail on publish. - m.wallet.On("PublishTransaction", - mock.Anything, mock.Anything).Return(errDummy).Once() - - // Send the req and expect no error returned. - resultChan, err = tp.Broadcast(req) - require.NoError(t, err) - - // Check the result is sent back. - select { - case <-time.After(time.Second): - t.Fatal("timeout waiting for subscriber to receive result") - case result := <-resultChan: - // We expect the result to be TxFailed and the error is set in - // the result. - require.Equal(t, TxFailed, result.Event) - require.ErrorIs(t, result.Err, errDummy) - } - - // Validate the record was removed. - require.Equal(t, 0, tp.records.Len()) - require.Equal(t, 0, tp.subscriberChans.Len()) + // Validate the record. + rid := tp.requestCounter.Load() + record, found := tp.records.Load(rid) + require.True(t, found) + require.Equal(t, req, record.req) } // TestCreateAnPublishFail checks all the error cases are handled properly in @@ -1188,7 +1081,8 @@ func TestHandleTxConfirmed(t *testing.T) { // Create a testing record and put it in the map. 
fee := btcutil.Amount(1000) - requestID := tp.storeRecord(tx, req, m.feeFunc, fee) + requestID := uint64(1) + tp.storeRecord(requestID, tx, req, m.feeFunc, fee) record, ok := tp.records.Load(requestID) require.True(t, ok) @@ -1260,7 +1154,8 @@ func TestHandleFeeBumpTx(t *testing.T) { // Create a testing record and put it in the map. fee := btcutil.Amount(1000) - requestID := tp.storeRecord(tx, req, m.feeFunc, fee) + requestID := uint64(1) + tp.storeRecord(requestID, tx, req, m.feeFunc, fee) // Create a subscription to the event. subscriber := make(chan *BumpResult, 1) @@ -1461,3 +1356,183 @@ func TestProcessRecords(t *testing.T) { require.Equal(t, requestID2, result.requestID) } } + +// TestHandleInitialBroadcastSuccess checks `handleInitialBroadcast` method can +// successfully broadcast a tx based on the request. +func TestHandleInitialBroadcastSuccess(t *testing.T) { + t.Parallel() + + // Create a publisher using the mocks. + tp, m := createTestPublisher(t) + + // Create a test feerate. + feerate := chainfee.SatPerKWeight(1000) + + // Mock the fee estimator to return the testing fee rate. + // + // We are not testing `NewLinearFeeFunction` here, so the actual params + // used are irrelevant. + m.estimator.On("EstimateFeePerKW", mock.Anything).Return( + feerate, nil).Once() + m.estimator.On("RelayFeePerKW").Return(chainfee.FeePerKwFloor).Once() + + // Mock the signer to always return a valid script. + // + // NOTE: we are not testing the utility of creating valid txes here, so + // this is fine to be mocked. This behaves essentially as skipping the + // Signer check and alaways assume the tx has a valid sig. + script := &input.Script{} + m.signer.On("ComputeInputScript", mock.Anything, + mock.Anything).Return(script, nil) + + // Mock the testmempoolaccept to pass. + m.wallet.On("CheckMempoolAcceptance", mock.Anything).Return(nil).Once() + + // Mock the wallet to publish successfully. 
+ m.wallet.On("PublishTransaction", + mock.Anything, mock.Anything).Return(nil).Once() + + // Create a test request. + inp := createTestInput(1000, input.WitnessKeyHash) + + // Create a testing bump request. + req := &BumpRequest{ + DeliveryAddress: changePkScript, + Inputs: []input.Input{&inp}, + Budget: btcutil.Amount(1000), + MaxFeeRate: feerate * 10, + DeadlineHeight: 10, + } + + // Register the testing record use `Broadcast`. + resultChan := tp.Broadcast(req) + + // Grab the monitor record from the map. + rid := tp.requestCounter.Load() + rec, ok := tp.records.Load(rid) + require.True(t, ok) + + // Call the method under test. + tp.wg.Add(1) + tp.handleInitialBroadcast(rec, rid) + + // Check the result is sent back. + select { + case <-time.After(time.Second): + t.Fatal("timeout waiting for subscriber to receive result") + + case result := <-resultChan: + // We expect the first result to be TxPublished. + require.Equal(t, TxPublished, result.Event) + } + + // Validate the record was stored. + require.Equal(t, 1, tp.records.Len()) + require.Equal(t, 1, tp.subscriberChans.Len()) +} + +// TestHandleInitialBroadcastFail checks `handleInitialBroadcast` returns the +// error or a failed result when the broadcast fails. +func TestHandleInitialBroadcastFail(t *testing.T) { + t.Parallel() + + // Create a publisher using the mocks. + tp, m := createTestPublisher(t) + + // Create a test feerate. + feerate := chainfee.SatPerKWeight(1000) + + // Create a test request. + inp := createTestInput(1000, input.WitnessKeyHash) + + // Create a testing bump request. + req := &BumpRequest{ + DeliveryAddress: changePkScript, + Inputs: []input.Input{&inp}, + Budget: btcutil.Amount(1000), + MaxFeeRate: feerate * 10, + DeadlineHeight: 10, + } + + // Mock the fee estimator to return the testing fee rate. + // + // We are not testing `NewLinearFeeFunction` here, so the actual params + // used are irrelevant. 
+ m.estimator.On("EstimateFeePerKW", mock.Anything).Return( + feerate, nil).Twice() + m.estimator.On("RelayFeePerKW").Return(chainfee.FeePerKwFloor).Twice() + + // Mock the signer to always return a valid script. + // + // NOTE: we are not testing the utility of creating valid txes here, so + // this is fine to be mocked. This behaves essentially as skipping the + // Signer check and alaways assume the tx has a valid sig. + script := &input.Script{} + m.signer.On("ComputeInputScript", mock.Anything, + mock.Anything).Return(script, nil) + + // Mock the testmempoolaccept to return an error. + m.wallet.On("CheckMempoolAcceptance", + mock.Anything).Return(errDummy).Once() + + // Register the testing record use `Broadcast`. + resultChan := tp.Broadcast(req) + + // Grab the monitor record from the map. + rid := tp.requestCounter.Load() + rec, ok := tp.records.Load(rid) + require.True(t, ok) + + // Call the method under test and expect an error returned. + tp.wg.Add(1) + tp.handleInitialBroadcast(rec, rid) + + // Check the result is sent back. + select { + case <-time.After(time.Second): + t.Fatal("timeout waiting for subscriber to receive result") + + case result := <-resultChan: + // We expect the first result to be TxError. + require.Equal(t, TxError, result.Event) + } + + // Validate the record was NOT stored. + require.Equal(t, 0, tp.records.Len()) + require.Equal(t, 0, tp.subscriberChans.Len()) + + // Mock the testmempoolaccept again, this time it passes. + m.wallet.On("CheckMempoolAcceptance", mock.Anything).Return(nil).Once() + + // Mock the wallet to fail on publish. + m.wallet.On("PublishTransaction", + mock.Anything, mock.Anything).Return(errDummy).Once() + + // Register the testing record use `Broadcast`. + resultChan = tp.Broadcast(req) + + // Grab the monitor record from the map. + rid = tp.requestCounter.Load() + rec, ok = tp.records.Load(rid) + require.True(t, ok) + + // Call the method under test. 
+ tp.wg.Add(1) + tp.handleInitialBroadcast(rec, rid) + + // Check the result is sent back. + select { + case <-time.After(time.Second): + t.Fatal("timeout waiting for subscriber to receive result") + + case result := <-resultChan: + // We expect the result to be TxFailed and the error is set in + // the result. + require.Equal(t, TxFailed, result.Event) + require.ErrorIs(t, result.Err, errDummy) + } + + // Validate the record was removed. + require.Equal(t, 0, tp.records.Len()) + require.Equal(t, 0, tp.subscriberChans.Len()) +} diff --git a/sweep/fee_function.go b/sweep/fee_function.go index cbf283e37d7..15d44ed6160 100644 --- a/sweep/fee_function.go +++ b/sweep/fee_function.go @@ -14,6 +14,9 @@ var ( // ErrMaxPosition is returned when trying to increase the position of // the fee function while it's already at its max. ErrMaxPosition = errors.New("position already at max") + + // ErrZeroFeeRateDelta is returned when the fee rate delta is zero. + ErrZeroFeeRateDelta = errors.New("fee rate delta is zero") ) // mSatPerKWeight represents a fee rate in msat/kw. @@ -169,7 +172,7 @@ func NewLinearFeeFunction(maxFeeRate chainfee.SatPerKWeight, "endingFeeRate=%v, width=%v, delta=%v", start, end, l.width, l.deltaFeeRate) - return nil, fmt.Errorf("fee rate delta is zero") + return nil, ErrZeroFeeRateDelta } // Attach the calculated values to the fee function. diff --git a/sweep/mock_test.go b/sweep/mock_test.go index 605b8d14ec4..f79ccea290e 100644 --- a/sweep/mock_test.go +++ b/sweep/mock_test.go @@ -276,14 +276,14 @@ type MockBumper struct { var _ Bumper = (*MockBumper)(nil) // Broadcast broadcasts the transaction to the network. 
-func (m *MockBumper) Broadcast(req *BumpRequest) (<-chan *BumpResult, error) { +func (m *MockBumper) Broadcast(req *BumpRequest) <-chan *BumpResult { args := m.Called(req) if args.Get(0) == nil { - return nil, args.Error(1) + return nil } - return args.Get(0).(chan *BumpResult), args.Error(1) + return args.Get(0).(chan *BumpResult) } // MockFeeFunction is a mock implementation of the FeeFunction interface. diff --git a/sweep/sweeper.go b/sweep/sweeper.go index 39a03228d07..6fd0b7d1c67 100644 --- a/sweep/sweeper.go +++ b/sweep/sweeper.go @@ -10,6 +10,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/davecgh/go-spew/spew" + "github.com/lightningnetwork/lnd/chainio" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" @@ -221,6 +222,30 @@ func (p *SweeperInput) terminated() bool { } } +// isMature returns a boolean indicating whether the input has a timelock that +// has been reached or not. The locktime found is also returned. +func (p *SweeperInput) isMature(currentHeight uint32) (bool, uint32) { + locktime, _ := p.RequiredLockTime() + if currentHeight < locktime { + log.Debugf("Input %v has locktime=%v, current height is %v", + p.OutPoint(), locktime, currentHeight) + + return false, locktime + } + + // If the input has a CSV that's not yet reached, we will skip + // this input and wait for the expiry. + locktime = p.BlocksToMaturity() + p.HeightHint() + if currentHeight+1 < locktime { + log.Debugf("Input %v has CSV expiry=%v, current height is %v", + p.OutPoint(), locktime, currentHeight) + + return false, locktime + } + + return true, locktime +} + // InputsMap is a type alias for a set of pending inputs. type InputsMap = map[wire.OutPoint]*SweeperInput @@ -311,6 +336,11 @@ type UtxoSweeper struct { // bumpResultChan is a channel that receives broadcast results from the // TxPublisher. 
bumpResultChan chan *BumpResult + + // blockBeatChan is a channel to receive blocks from BlockBeat. The + // received block contains the best known height and the transactions + // confirmed in this block. + blockBeatChan chan chainio.Beat } // UtxoSweeperConfig contains dependencies of UtxoSweeper. @@ -395,6 +425,7 @@ func New(cfg *UtxoSweeperConfig) *UtxoSweeper { quit: make(chan struct{}), inputs: make(InputsMap), bumpResultChan: make(chan *BumpResult, 100), + blockBeatChan: make(chan chainio.Beat), } } @@ -410,21 +441,12 @@ func (s *UtxoSweeper) Start() error { // not change from here on. s.relayFeeRate = s.cfg.FeeEstimator.RelayFeePerKW() - // We need to register for block epochs and retry sweeping every block. - // We should get a notification with the current best block immediately - // if we don't provide any epoch. We'll wait for that in the collector. - blockEpochs, err := s.cfg.Notifier.RegisterBlockEpochNtfn(nil) - if err != nil { - return fmt.Errorf("register block epoch ntfn: %w", err) - } - // Start sweeper main loop. s.wg.Add(1) go func() { - defer blockEpochs.Cancel() defer s.wg.Done() - s.collector(blockEpochs.Epochs) + s.collector() // The collector exited and won't longer handle incoming // requests. This can happen on shutdown, when the block @@ -479,6 +501,25 @@ func (s *UtxoSweeper) Stop() error { return nil } +// NOTE: part of the `chainio.Consumer` interface. +func (s *UtxoSweeper) ProcessBlock(beat chainio.Beat) <-chan error { + select { + case s.blockBeatChan <- beat: + log.Debugf("Received block beat for height=%d", + beat.Epoch.Height) + + case <-s.quit: + return nil + } + + return beat.Err +} + +// NOTE: part of the `chainio.Consumer` interface. +func (s *UtxoSweeper) Name() string { + return "sweeper" +} + // SweepInput sweeps inputs back into the wallet. The inputs will be batched and // swept after the batch time window ends. A custom fee preference can be // provided to determine what fee rate should be used for the input. 
Note that @@ -501,7 +542,7 @@ func (s *UtxoSweeper) SweepInput(inp input.Input, } absoluteTimeLock, _ := inp.RequiredLockTime() - log.Infof("Sweep request received: out_point=%v, witness_type=%v, "+ + log.Debugf("Sweep request received: out_point=%v, witness_type=%v, "+ "relative_time_lock=%v, absolute_time_lock=%v, amount=%v, "+ "parent=(%v), params=(%v)", inp.OutPoint(), inp.WitnessType(), inp.BlocksToMaturity(), absoluteTimeLock, @@ -610,18 +651,7 @@ func (s *UtxoSweeper) removeConflictSweepDescendants( // collector is the sweeper main loop. It processes new inputs, spend // notifications and counts down to publication of the sweep tx. -func (s *UtxoSweeper) collector(blockEpochs <-chan *chainntnfs.BlockEpoch) { - // We registered for the block epochs with a nil request. The notifier - // should send us the current best block immediately. So we need to wait - // for it here because we need to know the current best height. - select { - case bestBlock := <-blockEpochs: - s.currentHeight = bestBlock.Height - - case <-s.quit: - return - } - +func (s *UtxoSweeper) collector() { for { // Clean inputs, which will remove inputs that are swept, // failed, or excluded from the sweeper and return inputs that @@ -684,7 +714,7 @@ func (s *UtxoSweeper) collector(blockEpochs <-chan *chainntnfs.BlockEpoch) { // A new block comes in, update the bestHeight, perform a check // over all pending inputs and publish sweeping txns if needed. - case epoch, ok := <-blockEpochs: + case beat, ok := <-s.blockBeatChan: if !ok { // We should stop the sweeper before stopping // the chain service. Otherwise it indicates an @@ -694,6 +724,8 @@ func (s *UtxoSweeper) collector(blockEpochs <-chan *chainntnfs.BlockEpoch) { return } + epoch := beat.Epoch + // Update the sweeper to the best height. 
s.currentHeight = epoch.Height @@ -701,11 +733,24 @@ func (s *UtxoSweeper) collector(blockEpochs <-chan *chainntnfs.BlockEpoch) { inputs := s.updateSweeperInputs() log.Debugf("Received new block: height=%v, attempt "+ - "sweeping %d inputs", epoch.Height, len(inputs)) + "sweeping %d inputs:\n%s", epoch.Height, + len(inputs), newLogClosure(func() string { + inps := make( + []input.Input, 0, len(inputs), + ) + for _, in := range inputs { + inps = append(inps, in) + } + + return inputTypeSummary(inps) + })()) // Attempt to sweep any pending inputs. s.sweepPendingInputs(inputs) + // Notify we've processed the block. + fn.SendOrQuit(beat.Err, nil, s.quit) + case <-s.quit: return } @@ -823,21 +868,7 @@ func (s *UtxoSweeper) sweep(set InputSet) error { // Broadcast will return a read-only chan that we will listen to for // this publish result and future RBF attempt. - resp, err := s.cfg.Publisher.Broadcast(req) - if err != nil { - outpoints := make([]wire.OutPoint, len(set.Inputs())) - for i, inp := range set.Inputs() { - outpoints[i] = inp.OutPoint() - } - - log.Errorf("Initial broadcast failed: %v, inputs=\n%v", err, - inputTypeSummary(set.Inputs())) - - // TODO(yy): find out which input is causing the failure. - s.markInputsPublishFailed(outpoints) - - return err - } + resp := s.cfg.Publisher.Broadcast(req) // Successfully sent the broadcast attempt, we now handle the result by // subscribing to the result chan and listen for future updates about @@ -1040,6 +1071,12 @@ func (s *UtxoSweeper) handlePendingSweepsReq( resps := make(map[wire.OutPoint]*PendingInputResponse, len(s.inputs)) for _, inp := range s.inputs { + // Skip immature inputs for compatibility. + mature, _ := inp.isMature(uint32(s.currentHeight)) + if !mature { + continue + } + // Only the exported fields are set, as we expect the response // to only be consumed externally. 
op := inp.OutPoint()
@@ -1175,13 +1212,29 @@ func (s *UtxoSweeper) mempoolLookup(op wire.OutPoint) fn.Option[wire.MsgTx] {
 return s.cfg.Mempool.LookupInputMempoolSpend(op)
 }
-// handleNewInput processes a new input by registering spend notification and
-// scheduling sweeping for it.
-func (s *UtxoSweeper) handleNewInput(input *sweepInputMessage) error {
+// calculateDefaultDeadline calculates the default deadline height for a sweep
+// request that has no deadline height specified.
+func (s *UtxoSweeper) calculateDefaultDeadline(pi *SweeperInput) int32 {
 // Create a default deadline height, which will be used when there's no
 // DeadlineHeight specified for a given input.
 defaultDeadline := s.currentHeight + int32(s.cfg.NoDeadlineConfTarget)
+ // If the input is immature and has a locktime, we'll use the locktime
+ // height as the starting height.
+ matured, locktime := pi.isMature(uint32(s.currentHeight))
+ if !matured {
+ defaultDeadline = int32(locktime + s.cfg.NoDeadlineConfTarget)
+ log.Debugf("Input %v is immature, using locktime=%v instead "+
+ "of current height=%d", pi.OutPoint(), locktime,
+ s.currentHeight)
+ }
+
+ return defaultDeadline
+}
+
+// handleNewInput processes a new input by registering spend notification and
+// scheduling sweeping for it.
+func (s *UtxoSweeper) handleNewInput(input *sweepInputMessage) error {
 outpoint := input.input.OutPoint()
 pi, pending := s.inputs[outpoint]
 if pending {
@@ -1206,15 +1259,22 @@ func (s *UtxoSweeper) handleNewInput(input *sweepInputMessage) error {
 Input: input.input,
 params: input.params,
 rbf: rbfInfo,
- // Set the acutal deadline height.
- DeadlineHeight: input.params.DeadlineHeight.UnwrapOr(
- defaultDeadline,
- ),
 }
+ // Set the actual deadline height. 
+ pi.DeadlineHeight = input.params.DeadlineHeight.UnwrapOr( + s.calculateDefaultDeadline(pi), + ) + s.inputs[outpoint] = pi log.Tracef("input %v, state=%v, added to inputs", outpoint, pi.state) + log.Infof("Registered sweep request at block %d: out_point=%v, "+ + "witness_type=%v, amount=%v, deadline=%d, params=(%v)", + s.currentHeight, pi.OutPoint(), pi.WitnessType(), + btcutil.Amount(pi.SignDesc().Output.Value), pi.DeadlineHeight, + pi.params) + // Start watching for spend of this input, either by us or the remote // party. cancel, err := s.monitorSpend( @@ -1446,11 +1506,6 @@ func (s *UtxoSweeper) markInputFailed(pi *SweeperInput, err error) { pi.state = Failed - // Remove all other inputs in this exclusive group. - if pi.params.ExclusiveGroup != nil { - s.removeExclusiveGroup(*pi.params.ExclusiveGroup) - } - s.signalResult(pi, Result{Err: err}) } @@ -1495,20 +1550,9 @@ func (s *UtxoSweeper) updateSweeperInputs() InputsMap { // If the input has a locktime that's not yet reached, we will // skip this input and wait for the locktime to be reached. - locktime, _ := input.RequiredLockTime() - if uint32(s.currentHeight) < locktime { - log.Warnf("Skipping input %v due to locktime=%v not "+ - "reached, current height is %v", op, locktime, - s.currentHeight) - - continue - } - - // If the input has a CSV that's not yet reached, we will skip - // this input and wait for the expiry. 
- locktime = input.BlocksToMaturity() + input.HeightHint() - if s.currentHeight < int32(locktime)-1 { - log.Infof("Skipping input %v due to CSV expiry=%v not "+ + mature, locktime := input.isMature(uint32(s.currentHeight)) + if !mature { + log.Infof("Skipping input %v due to locktime=%v not "+ "reached, current height is %v", op, locktime, s.currentHeight) @@ -1626,7 +1670,7 @@ func (s *UtxoSweeper) monitorFeeBumpResult(resultChan <-chan *BumpResult) { func (s *UtxoSweeper) handleBumpEventTxFailed(r *BumpResult) error { tx, err := r.Tx, r.Err - log.Errorf("Fee bump attempt failed for tx=%v: %v", tx.TxHash(), err) + log.Warnf("Fee bump attempt failed for tx=%v: %v", tx.TxHash(), err) outpoints := make([]wire.OutPoint, 0, len(tx.TxIn)) for _, inp := range tx.TxIn { @@ -1636,7 +1680,7 @@ func (s *UtxoSweeper) handleBumpEventTxFailed(r *BumpResult) error { // TODO(yy): should we also remove the failed tx from db? s.markInputsPublishFailed(outpoints) - return err + return nil } // handleBumpEventTxReplaced handles the case where the sweeping tx has been @@ -1712,13 +1756,67 @@ func (s *UtxoSweeper) handleBumpEventTxPublished(r *BumpResult) error { return nil } +// handleBumpEventError handles the case where there's an unexpected error when +// creating or publishing the sweeping tx. In this case, the tx will be removed +// from the sweeper store and the inputs will be marked as `Failed`. +func (s *UtxoSweeper) handleBumpEventError(r *BumpResult) error { + txid := r.Tx.TxHash() + log.Infof("Tx=%v failed with unexpected error: %v", txid, r.Err) + + // Remove the tx from the sweeper db if it exists. + if err := s.cfg.Store.DeleteTx(txid); err != nil { + return fmt.Errorf("delete tx record for %v: %w", txid, err) + } + + // Mark the inputs as failed. + s.markInputsFailed(r.Tx, r.Err) + + return nil +} + +// markInputsFailed marks all inputs found in the tx as failed. It will also +// notify all the subscribers of these inputs. 
+func (s *UtxoSweeper) markInputsFailed(tx *wire.MsgTx, err error) { + for _, txIn := range tx.TxIn { + outpoint := txIn.PreviousOutPoint + + input, ok := s.inputs[outpoint] + if !ok { + // It's very likely that a spending tx contains inputs + // that we don't know. + log.Tracef("Skipped marking input as failed: %v not "+ + "found in pending inputs", outpoint) + + continue + } + + // If the input is already in a terminal state, we don't want + // to rewrite it, which also indicates an error as we only get + // an error event during the initial broadcast. + if input.terminated() { + log.Errorf("Skipped marking input=%v as failed due to "+ + "unexpected state=%v", outpoint, input.state) + + continue + } + + input.state = Failed + + // Signal result channels. + s.signalResult(input, Result{ + Tx: tx, + Err: err, + }) + } +} + // handleBumpEvent handles the result sent from the bumper based on its event // type. // // NOTE: TxConfirmed event is not handled, since we already subscribe to the // input's spending event, we don't need to do anything here. func (s *UtxoSweeper) handleBumpEvent(r *BumpResult) error { - log.Debugf("Received bump event [%v] for tx %v", r.Event, r.Tx.TxHash()) + log.Debugf("Received bump result %v", r) switch r.Event { // The tx has been published, we update the inputs' state and create a @@ -1734,6 +1832,12 @@ func (s *UtxoSweeper) handleBumpEvent(r *BumpResult) error { // with the new one. case TxReplaced: return s.handleBumpEventTxReplaced(r) + + // There's an unexpected error in creating or publishing the tx, we + // will remove the tx from the sweeper db and mark the inputs as + // failed. + case TxError: + return s.handleBumpEventError(r) } return nil diff --git a/sweep/sweeper_test.go b/sweep/sweeper_test.go index c8d9fc510bf..39e3ff895c2 100644 --- a/sweep/sweeper_test.go +++ b/sweep/sweeper_test.go @@ -490,6 +490,7 @@ func TestUpdateSweeperInputs(t *testing.T) { // returned. 
inp2.On("RequiredLockTime").Return( uint32(s.currentHeight+1), true).Once() + inp2.On("OutPoint").Return(wire.OutPoint{Index: 2}).Maybe() input7 := &SweeperInput{state: Init, Input: inp2} // Mock the input to have a CSV expiry in the future so it will NOT be @@ -498,6 +499,7 @@ func TestUpdateSweeperInputs(t *testing.T) { uint32(s.currentHeight), false).Once() inp3.On("BlocksToMaturity").Return(uint32(2)).Once() inp3.On("HeightHint").Return(uint32(s.currentHeight)).Once() + inp3.On("OutPoint").Return(wire.OutPoint{Index: 3}).Maybe() input8 := &SweeperInput{state: Init, Input: inp3} // Add the inputs to the sweeper. After the update, we should see the @@ -715,13 +717,8 @@ func TestSweepPendingInputs(t *testing.T) { setNeedWallet, normalSet, }) - // Mock `Broadcast` to return an error. This should cause the - // `createSweepTx` inside `sweep` to fail. This is done so we can - // terminate the method early as we are only interested in testing the - // workflow in `sweepPendingInputs`. We don't need to test `sweep` here - // as it should be tested in its own unit test. - dummyErr := errors.New("dummy error") - publisher.On("Broadcast", mock.Anything).Return(nil, dummyErr).Twice() + // Mock `Broadcast` to return a result. + publisher.On("Broadcast", mock.Anything).Return(nil).Twice() // Call the method under test. s.sweepPendingInputs(pis) @@ -778,7 +775,7 @@ func TestHandleBumpEventTxFailed(t *testing.T) { // Call the method under test. err := s.handleBumpEvent(br) - require.ErrorIs(t, err, errDummy) + require.NoError(t, err) // Assert the states of the first two inputs are updated. require.Equal(t, PublishFailed, s.inputs[op1].state) @@ -1107,3 +1104,179 @@ func TestMonitorFeeBumpResult(t *testing.T) { }) } } + +// TestMarkInputsFailed checks that given a list of inputs with different +// states, the method `markInputsFailed` correctly marks the inputs as failed. 
+func TestMarkInputsFailed(t *testing.T) { + t.Parallel() + + require := require.New(t) + + // Create a test sweeper. + s := New(&UtxoSweeperConfig{}) + + // Create a mock input. + mockInput := &input.MockInput{} + defer mockInput.AssertExpectations(t) + + // Mock the `OutPoint` to return a dummy outpoint. + mockInput.On("OutPoint").Return(wire.OutPoint{Hash: chainhash.Hash{1}}) + + // Create testing inputs for each state. + // + // inputNotExist specifies an input that's not found in the sweeper's + // `inputs` map. + inputNotExist := &wire.TxIn{ + PreviousOutPoint: wire.OutPoint{Index: 1}, + } + + // inputInit specifies a newly created input. When marking this as + // published, we should see an error log as this input hasn't been + // published yet. + inputInit := &wire.TxIn{ + PreviousOutPoint: wire.OutPoint{Index: 2}, + } + s.inputs[inputInit.PreviousOutPoint] = &SweeperInput{ + state: Init, + Input: mockInput, + } + + // inputPendingPublish specifies an input that's about to be published. + inputPendingPublish := &wire.TxIn{ + PreviousOutPoint: wire.OutPoint{Index: 3}, + } + s.inputs[inputPendingPublish.PreviousOutPoint] = &SweeperInput{ + state: PendingPublish, + Input: mockInput, + } + + // inputPublished specifies an input that's published. + inputPublished := &wire.TxIn{ + PreviousOutPoint: wire.OutPoint{Index: 4}, + } + s.inputs[inputPublished.PreviousOutPoint] = &SweeperInput{ + state: Published, + Input: mockInput, + } + + // inputPublishFailed specifies an input that's failed to be published. + inputPublishFailed := &wire.TxIn{ + PreviousOutPoint: wire.OutPoint{Index: 5}, + } + s.inputs[inputPublishFailed.PreviousOutPoint] = &SweeperInput{ + state: PublishFailed, + Input: mockInput, + } + + // inputSwept specifies an input that's swept. 
+ inputSwept := &wire.TxIn{
+ PreviousOutPoint: wire.OutPoint{Index: 6},
+ }
+ s.inputs[inputSwept.PreviousOutPoint] = &SweeperInput{
+ state: Swept,
+ Input: mockInput,
+ }
+
+ // inputExcluded specifies an input that's excluded.
+ inputExcluded := &wire.TxIn{
+ PreviousOutPoint: wire.OutPoint{Index: 7},
+ }
+ s.inputs[inputExcluded.PreviousOutPoint] = &SweeperInput{
+ state: Excluded,
+ Input: mockInput,
+ }
+
+ // inputFailed specifies an input that's failed.
+ inputFailed := &wire.TxIn{
+ PreviousOutPoint: wire.OutPoint{Index: 8},
+ }
+ s.inputs[inputFailed.PreviousOutPoint] = &SweeperInput{
+ state: Failed,
+ Input: mockInput,
+ }
+
+ // Create a test tx.
+ tx := &wire.MsgTx{
+ TxIn: []*wire.TxIn{
+ inputNotExist,
+ inputInit,
+ inputPendingPublish,
+ inputPublished,
+ inputPublishFailed,
+ inputSwept,
+ inputExcluded,
+ inputFailed,
+ },
+ }
+
+ // Mark the test inputs. We expect the non-exist input and
+ // inputSwept/inputExcluded/inputFailed to be skipped.
+ s.markInputsFailed(tx, errDummy)
+
+ // We expect unchanged number of pending inputs.
+ require.Len(s.inputs, 7)
+
+ // We expect the init input to be marked as failed.
+ require.Equal(Failed, s.inputs[inputInit.PreviousOutPoint].state)
+
+ // We expect the pending-publish input to be marked as failed.
+ require.Equal(Failed,
+ s.inputs[inputPendingPublish.PreviousOutPoint].state)
+
+ // We expect the published input to be marked as failed.
+ require.Equal(Failed, s.inputs[inputPublished.PreviousOutPoint].state)
+
+ // We expect the publish failed input to be marked as failed.
+ require.Equal(Failed,
+ s.inputs[inputPublishFailed.PreviousOutPoint].state)
+
+ // We expect the swept input to stay unchanged.
+ require.Equal(Swept, s.inputs[inputSwept.PreviousOutPoint].state)
+
+ // We expect the excluded input to stay unchanged.
+ require.Equal(Excluded, s.inputs[inputExcluded.PreviousOutPoint].state)
+
+ // We expect the failed input to stay unchanged. 
+ require.Equal(Failed, s.inputs[inputFailed.PreviousOutPoint].state)
+}
+
+// TestHandleBumpEventError checks that `handleBumpEventError` correctly
+// handles a `TxError` event.
+func TestHandleBumpEventError(t *testing.T) {
+ t.Parallel()
+
+ rt := require.New(t)
+
+ // Create a mock store.
+ store := &MockSweeperStore{}
+ defer store.AssertExpectations(t)
+
+ // Create a test sweeper.
+ s := New(&UtxoSweeperConfig{
+ Store: store,
+ })
+
+ // Create a testing tx.
+ //
+ // We are not testing `markInputsFailed` here, so the actual tx doesn't
+ // matter.
+ tx := &wire.MsgTx{}
+ result := &BumpResult{
+ Tx: tx,
+ Err: errDummy,
+ }
+
+ // Mock the store to return an error.
+ store.On("DeleteTx", mock.Anything).Return(errDummy).Once()
+
+ // Call the method under test and assert the error is returned.
+ err := s.handleBumpEventError(result)
+ rt.ErrorIs(err, errDummy)
+
+ // Mock the store to return nil.
+ store.On("DeleteTx", mock.Anything).Return(nil).Once()
+
+ // Call the method under test and assert no error is returned.
+ err = s.handleBumpEventError(result)
+ rt.NoError(err)
+}