diff --git a/chainio/README.md b/chainio/README.md new file mode 100644 index 00000000000..b11e38157c2 --- /dev/null +++ b/chainio/README.md @@ -0,0 +1,152 @@ +# Chainio + +`chainio` is a package designed to provide blockchain data access to various +subsystems within `lnd`. When a new block is received, it is encapsulated in a +`Blockbeat` object and disseminated to all registered consumers. Consumers may +receive these updates either concurrently or sequentially, based on their +registration configuration, ensuring that each subsystem maintains a +synchronized view of the current block state. + +The main components include: + +- `Blockbeat`: An interface that provides information about the block. + +- `Consumer`: An interface that specifies how subsystems handle the blockbeat. + +- `BlockbeatDispatcher`: The core service responsible for receiving each block + and distributing it to all consumers. + +Additionally, the `BeatConsumer` struct provides a partial implementation of +the `Consumer` interface. This struct helps reduce code duplication, allowing +subsystems to avoid re-implementing the `ProcessBlock` method and provides a +commonly used `NotifyBlockProcessed` method. + + +### Register a Consumer + +Consumers within the same queue are notified **sequentially**, while all queues +are notified **concurrently**. A queue consists of a slice of consumers, which +are notified in left-to-right order. Developers are responsible for determining +dependencies in block consumption across subsystems: independent subsystems +should be notified concurrently, whereas dependent subsystems should be +notified sequentially. + +To notify the consumers concurrently, put them in different queues, +```go +// consumer1 and consumer2 will be notified concurrently. +queue1 := []chainio.Consumer{consumer1} +blockbeatDispatcher.RegisterQueue(consumer1) + +queue2 := []chainio.Consumer{consumer2} +blockbeatDispatcher.RegisterQueue(consumer2) +``` + +To notify the consumers sequentially, put them in the same queue, +```go +// consumers will be notified sequentially via, +// consumer1 -> consumer2 -> consumer3 +queue := []chainio.Consumer{ + consumer1, + consumer2, + consumer3, +} +blockbeatDispatcher.RegisterQueue(queue) +``` + +### Implement the `Consumer` Interface + +Implementing the `Consumer` interface is straightforward. Below is an example +of how +[`sweep.TxPublisher`](https://github.com/lightningnetwork/lnd/blob/5cec466fad44c582a64cfaeb91f6d5fd302fcf85/sweep/fee_bumper.go#L310) +implements this interface. + +To start, embed the partial implementation `chainio.BeatConsumer`, which +already provides the `ProcessBlock` implementation and commonly used +`NotifyBlockProcessed` method, and exposes `BlockbeatChan` for the consumer to +receive blockbeats. + +```go +type TxPublisher struct { + started atomic.Bool + stopped atomic.Bool + + chainio.BeatConsumer + + ... +``` + +We should also remember to initialize this `BeatConsumer`, + +```go +... +// Mount the block consumer. +tp.BeatConsumer = chainio.NewBeatConsumer(tp.quit, tp.Name()) +``` + +Finally, in the main event loop, read from `BlockbeatChan`, process the +received blockbeat, and, crucially, call `tp.NotifyBlockProcessed` to inform +the blockbeat dispatcher that processing is complete. + +```go +for { + select { + case beat := <-tp.BlockbeatChan: + // Consume this blockbeat, usually it means updating the subsystem + // using the new block data. + + // Notify we've processed the block. + tp.NotifyBlockProcessed(beat, nil) + + ... +``` + +### Existing Queues + +Currently, we have a single queue of consumers dedicated to handling force +closures. This queue includes `ChainArbitrator`, `UtxoSweeper`, and +`TxPublisher`, with `ChainArbitrator` managing two internal consumers: +`chainWatcher` and `ChannelArbitrator`. The blockbeat flows sequentially +through the chain as follows: `ChainArbitrator => chainWatcher => +ChannelArbitrator => UtxoSweeper => TxPublisher`. The following diagram +illustrates the flow within the public subsystems. + +```mermaid +sequenceDiagram + autonumber + participant bb as BlockBeat + participant cc as ChainArb + participant us as UtxoSweeper + participant tp as TxPublisher + + note left of bb: 0. received block x,
dispatching... + + note over bb,cc: 1. send block x to ChainArb,
wait for its done signal + bb->>cc: block x + rect rgba(165, 0, 85, 0.8) + critical signal processed + cc->>bb: processed block + option Process error or timeout + bb->>bb: error and exit + end + end + + note over bb,us: 2. send block x to UtxoSweeper, wait for its done signal + bb->>us: block x + rect rgba(165, 0, 85, 0.8) + critical signal processed + us->>bb: processed block + option Process error or timeout + bb->>bb: error and exit + end + end + + note over bb,tp: 3. send block x to TxPublisher, wait for its done signal + bb->>tp: block x + rect rgba(165, 0, 85, 0.8) + critical signal processed + tp->>bb: processed block + option Process error or timeout + bb->>bb: error and exit + end + end +``` diff --git a/chainio/blockbeat.go b/chainio/blockbeat.go new file mode 100644 index 00000000000..5df1cad7772 --- /dev/null +++ b/chainio/blockbeat.go @@ -0,0 +1,55 @@ +package chainio + +import ( + "fmt" + + "github.com/btcsuite/btclog/v2" + "github.com/lightningnetwork/lnd/build" + "github.com/lightningnetwork/lnd/chainntnfs" +) + +// Beat implements the Blockbeat interface. It contains the block epoch and a +// customized logger. +// +// TODO(yy): extend this to check for confirmation status - which serves as the +// single source of truth, to avoid the potential race between receiving blocks +// and `GetTransactionDetails/RegisterSpendNtfn/RegisterConfirmationsNtfn`. +type Beat struct { + // epoch is the current block epoch the blockbeat is aware of. + epoch chainntnfs.BlockEpoch + + // log is the customized logger for the blockbeat which prints the + // block height. + log btclog.Logger +} + +// Compile-time check to ensure Beat satisfies the Blockbeat interface. +var _ Blockbeat = (*Beat)(nil) + +// NewBeat creates a new beat with the specified block epoch and a customized +// logger. +func NewBeat(epoch chainntnfs.BlockEpoch) *Beat { + b := &Beat{ + epoch: epoch, + } + + // Create a customized logger for the blockbeat. + logPrefix := fmt.Sprintf("Height[%6d]:", b.Height()) + b.log = build.NewPrefixLog(logPrefix, clog) + + return b +} + +// Height returns the height of the block epoch. +// +// NOTE: Part of the Blockbeat interface. +func (b *Beat) Height() int32 { + return b.epoch.Height +} + +// logger returns the logger for the blockbeat. +// +// NOTE: Part of the private blockbeat interface. +func (b *Beat) logger() btclog.Logger { + return b.log +} diff --git a/chainio/blockbeat_test.go b/chainio/blockbeat_test.go new file mode 100644 index 00000000000..9326651b38f --- /dev/null +++ b/chainio/blockbeat_test.go @@ -0,0 +1,28 @@ +package chainio + +import ( + "errors" + "testing" + + "github.com/lightningnetwork/lnd/chainntnfs" + "github.com/stretchr/testify/require" +) + +var errDummy = errors.New("dummy error") + +// TestNewBeat tests the NewBeat and Height functions. +func TestNewBeat(t *testing.T) { + t.Parallel() + + // Create a testing epoch. + epoch := chainntnfs.BlockEpoch{ + Height: 1, + } + + // Create the beat and check the internal state. + beat := NewBeat(epoch) + require.Equal(t, epoch, beat.epoch) + + // Check the height function. + require.Equal(t, epoch.Height, beat.Height()) +} diff --git a/chainio/consumer.go b/chainio/consumer.go new file mode 100644 index 00000000000..a9ec25745ba --- /dev/null +++ b/chainio/consumer.go @@ -0,0 +1,113 @@ +package chainio + +// BeatConsumer defines a supplementary component that should be used by +// subsystems which implement the `Consumer` interface. It partially implements +// the `Consumer` interface by providing the method `ProcessBlock` such that +// subsystems don't need to re-implement it. +// +// While inheritance is not commonly used in Go, subsystems embedding this +// struct cannot pass the interface check for `Consumer` because the `Name` +// method is not implemented, which gives us a "mortise and tenon" structure. +// In addition to reducing code duplication, this design allows `ProcessBlock` +// to work on the concrete type `Beat` to access its internal states. +type BeatConsumer struct { + // BlockbeatChan is a channel to receive blocks from Blockbeat. The + // received block contains the best known height and the txns confirmed + // in this block. + BlockbeatChan chan Blockbeat + + // name is the name of the consumer which embeds the BlockConsumer. + name string + + // quit is a channel that closes when the BlockConsumer is shutting + // down. + // + // NOTE: this quit channel should be mounted to the same quit channel + // used by the subsystem. + quit chan struct{} + + // errChan is a buffered chan that receives an error returned from + // processing this block. + errChan chan error +} + +// NewBeatConsumer creates a new BlockConsumer. +func NewBeatConsumer(quit chan struct{}, name string) BeatConsumer { + // Refuse to start `lnd` if the quit channel is not initialized. We + // treat this case as if we are facing a nil pointer dereference, as + // there's no point to return an error here, which will cause the node + // to fail to be started anyway. + if quit == nil { + panic("quit channel is nil") + } + + b := BeatConsumer{ + BlockbeatChan: make(chan Blockbeat), + name: name, + errChan: make(chan error, 1), + quit: quit, + } + + return b +} + +// ProcessBlock takes a blockbeat and sends it to the consumer's blockbeat +// channel. It will send it to the subsystem's BlockbeatChan, and block until +// the processed result is received from the subsystem. The subsystem must call +// `NotifyBlockProcessed` after it has finished processing the block. +// +// NOTE: part of the `chainio.Consumer` interface. +func (b *BeatConsumer) ProcessBlock(beat Blockbeat) error { + // Update the current height. + beat.logger().Tracef("set current height for [%s]", b.name) + + select { + // Send the beat to the blockbeat channel. It's expected that the + // consumer will read from this channel and process the block. Once + // processed, it should return the error or nil to the beat.Err chan. + case b.BlockbeatChan <- beat: + beat.logger().Tracef("Sent blockbeat to [%s]", b.name) + + case <-b.quit: + beat.logger().Debugf("[%s] received shutdown before sending "+ + "beat", b.name) + + return nil + } + + // Check the consumer's err chan. We expect the consumer to call + // `beat.NotifyBlockProcessed` to send the error back here. + select { + case err := <-b.errChan: + beat.logger().Debugf("[%s] processed beat: err=%v", b.name, err) + + return err + + case <-b.quit: + beat.logger().Debugf("[%s] received shutdown", b.name) + } + + return nil +} + +// NotifyBlockProcessed signals that the block has been processed. It takes the +// blockbeat being processed and an error resulted from processing it. This +// error is then sent back to the consumer's err chan to unblock +// `ProcessBlock`. +// +// NOTE: This method must be called by the subsystem after it has finished +// processing the block. +func (b *BeatConsumer) NotifyBlockProcessed(beat Blockbeat, err error) { + // Update the current height. + beat.logger().Debugf("[%s]: notifying beat processed", b.name) + + select { + case b.errChan <- err: + beat.logger().Debugf("[%s]: notified beat processed, err=%v", + b.name, err) + + case <-b.quit: + beat.logger().Debugf("[%s] received shutdown before notifying "+ + "beat processed", b.name) + } +} diff --git a/chainio/consumer_test.go b/chainio/consumer_test.go new file mode 100644 index 00000000000..3ef79b61b4b --- /dev/null +++ b/chainio/consumer_test.go @@ -0,0 +1,202 @@ +package chainio + +import ( + "testing" + "time" + + "github.com/lightningnetwork/lnd/fn" + "github.com/stretchr/testify/require" +) + +// TestNewBeatConsumer tests the NewBeatConsumer function. +func TestNewBeatConsumer(t *testing.T) { + t.Parallel() + + quitChan := make(chan struct{}) + name := "test" + + // Test the NewBeatConsumer function. + b := NewBeatConsumer(quitChan, name) + + // Assert the state. + require.Equal(t, quitChan, b.quit) + require.Equal(t, name, b.name) + require.NotNil(t, b.BlockbeatChan) +} + +// TestProcessBlockSuccess tests when the block is processed successfully, no +// error is returned. +func TestProcessBlockSuccess(t *testing.T) { + t.Parallel() + + // Create a test consumer. + quitChan := make(chan struct{}) + b := NewBeatConsumer(quitChan, "test") + + // Create a mock beat. + mockBeat := &MockBlockbeat{} + defer mockBeat.AssertExpectations(t) + mockBeat.On("logger").Return(clog) + + // Mock the consumer's err chan. + consumerErrChan := make(chan error, 1) + b.errChan = consumerErrChan + + // Call the method under test. + resultChan := make(chan error, 1) + go func() { + resultChan <- b.ProcessBlock(mockBeat) + }() + + // Assert the beat is sent to the blockbeat channel. + beat, err := fn.RecvOrTimeout(b.BlockbeatChan, time.Second) + require.NoError(t, err) + require.Equal(t, mockBeat, beat) + + // Send nil to the consumer's error channel. + consumerErrChan <- nil + + // Assert the result of ProcessBlock is nil. + result, err := fn.RecvOrTimeout(resultChan, time.Second) + require.NoError(t, err) + require.Nil(t, result) +} + +// TestProcessBlockConsumerQuitBeforeSend tests when the consumer is quit +// before sending the beat, the method returns immediately. +func TestProcessBlockConsumerQuitBeforeSend(t *testing.T) { + t.Parallel() + + // Create a test consumer. + quitChan := make(chan struct{}) + b := NewBeatConsumer(quitChan, "test") + + // Create a mock beat. + mockBeat := &MockBlockbeat{} + defer mockBeat.AssertExpectations(t) + mockBeat.On("logger").Return(clog) + + // Call the method under test. + resultChan := make(chan error, 1) + go func() { + resultChan <- b.ProcessBlock(mockBeat) + }() + + // Instead of reading the BlockbeatChan, close the quit channel. + close(quitChan) + + // Assert ProcessBlock returned nil. + result, err := fn.RecvOrTimeout(resultChan, time.Second) + require.NoError(t, err) + require.Nil(t, result) +} + +// TestProcessBlockConsumerQuitAfterSend tests when the consumer is quit after +// sending the beat, the method returns immediately. +func TestProcessBlockConsumerQuitAfterSend(t *testing.T) { + t.Parallel() + + // Create a test consumer. + quitChan := make(chan struct{}) + b := NewBeatConsumer(quitChan, "test") + + // Create a mock beat. + mockBeat := &MockBlockbeat{} + defer mockBeat.AssertExpectations(t) + mockBeat.On("logger").Return(clog) + + // Mock the consumer's err chan. + consumerErrChan := make(chan error, 1) + b.errChan = consumerErrChan + + // Call the method under test. + resultChan := make(chan error, 1) + go func() { + resultChan <- b.ProcessBlock(mockBeat) + }() + + // Assert the beat is sent to the blockbeat channel. + beat, err := fn.RecvOrTimeout(b.BlockbeatChan, time.Second) + require.NoError(t, err) + require.Equal(t, mockBeat, beat) + + // Instead of sending nil to the consumer's error channel, close the + // quit chanel. + close(quitChan) + + // Assert ProcessBlock returned nil. + result, err := fn.RecvOrTimeout(resultChan, time.Second) + require.NoError(t, err) + require.Nil(t, result) +} + +// TestNotifyBlockProcessedSendErr asserts the error can be sent and read by +// the beat via NotifyBlockProcessed. +func TestNotifyBlockProcessedSendErr(t *testing.T) { + t.Parallel() + + // Create a test consumer. + quitChan := make(chan struct{}) + b := NewBeatConsumer(quitChan, "test") + + // Create a mock beat. + mockBeat := &MockBlockbeat{} + defer mockBeat.AssertExpectations(t) + mockBeat.On("logger").Return(clog) + + // Mock the consumer's err chan. + consumerErrChan := make(chan error, 1) + b.errChan = consumerErrChan + + // Call the method under test. + done := make(chan error) + go func() { + defer close(done) + b.NotifyBlockProcessed(mockBeat, errDummy) + }() + + // Assert the error is sent to the beat's err chan. + result, err := fn.RecvOrTimeout(consumerErrChan, time.Second) + require.NoError(t, err) + require.ErrorIs(t, result, errDummy) + + // Assert the done channel is closed. + result, err = fn.RecvOrTimeout(done, time.Second) + require.NoError(t, err) + require.Nil(t, result) +} + +// TestNotifyBlockProcessedOnQuit asserts NotifyBlockProcessed exits +// immediately when the quit channel is closed. +func TestNotifyBlockProcessedOnQuit(t *testing.T) { + t.Parallel() + + // Create a test consumer. + quitChan := make(chan struct{}) + b := NewBeatConsumer(quitChan, "test") + + // Create a mock beat. + mockBeat := &MockBlockbeat{} + defer mockBeat.AssertExpectations(t) + mockBeat.On("logger").Return(clog) + + // Mock the consumer's err chan - we don't buffer it so it will block + // on sending the error. + consumerErrChan := make(chan error) + b.errChan = consumerErrChan + + // Call the method under test. + done := make(chan error) + go func() { + defer close(done) + b.NotifyBlockProcessed(mockBeat, errDummy) + }() + + // Close the quit channel so the method will return. + close(b.quit) + + // Assert the done channel is closed. + result, err := fn.RecvOrTimeout(done, time.Second) + require.NoError(t, err) + require.Nil(t, result) +} diff --git a/chainio/dispatcher.go b/chainio/dispatcher.go new file mode 100644 index 00000000000..87bc21fbaac --- /dev/null +++ b/chainio/dispatcher.go @@ -0,0 +1,296 @@ +package chainio + +import ( + "errors" + "fmt" + "sync" + "sync/atomic" + "time" + + "github.com/btcsuite/btclog/v2" + "github.com/lightningnetwork/lnd/chainntnfs" + "golang.org/x/sync/errgroup" +) + +// DefaultProcessBlockTimeout is the timeout value used when waiting for one +// consumer to finish processing the new block epoch. +var DefaultProcessBlockTimeout = 60 * time.Second + +// ErrProcessBlockTimeout is the error returned when a consumer takes too long +// to process the block. +var ErrProcessBlockTimeout = errors.New("process block timeout") + +// BlockbeatDispatcher is a service that handles dispatching new blocks to +// `lnd`'s subsystems. During startup, subsystems that are block-driven should +// implement the `Consumer` interface and register themselves via +// `RegisterQueue`. When two subsystems are independent of each other, they +// should be registered in different queues so blocks are notified concurrently. +// Otherwise, when living in the same queue, the subsystems are notified of the +// new blocks sequentially, which means it's critical to understand the +// relationship of these systems to properly handle the order. +type BlockbeatDispatcher struct { + wg sync.WaitGroup + + // notifier is used to receive new block epochs. + notifier chainntnfs.ChainNotifier + + // beat is the latest blockbeat received. + beat Blockbeat + + // consumerQueues is a map of consumers that will receive blocks. Its + // key is a unique counter and its value is a queue of consumers. Each + // queue is notified concurrently, and consumers in the same queue is + // notified sequentially. + consumerQueues map[uint32][]Consumer + + // counter is used to assign a unique id to each queue. + counter atomic.Uint32 + + // quit is used to signal the BlockbeatDispatcher to stop. + quit chan struct{} +} + +// NewBlockbeatDispatcher returns a new blockbeat dispatcher instance. +func NewBlockbeatDispatcher(n chainntnfs.ChainNotifier) *BlockbeatDispatcher { + return &BlockbeatDispatcher{ + notifier: n, + quit: make(chan struct{}), + consumerQueues: make(map[uint32][]Consumer), + } +} + +// RegisterQueue takes a list of consumers and registers them in the same +// queue. +// +// NOTE: these consumers are notified sequentially. +func (b *BlockbeatDispatcher) RegisterQueue(consumers []Consumer) { + qid := b.counter.Add(1) + + b.consumerQueues[qid] = append(b.consumerQueues[qid], consumers...) + clog.Infof("Registered queue=%d with %d blockbeat consumers", qid, + len(consumers)) + + for _, c := range consumers { + clog.Debugf("Consumer [%s] registered in queue %d", c.Name(), + qid) + } +} + +// Start starts the blockbeat dispatcher - it registers a block notification +// and monitors and dispatches new blocks in a goroutine. It will refuse to +// start if there are no registered consumers. +func (b *BlockbeatDispatcher) Start() error { + // Make sure consumers are registered. + if len(b.consumerQueues) == 0 { + return fmt.Errorf("no consumers registered") + } + + // Start listening to new block epochs. We should get a notification + // with the current best block immediately. + blockEpochs, err := b.notifier.RegisterBlockEpochNtfn(nil) + if err != nil { + return fmt.Errorf("register block epoch ntfn: %w", err) + } + + clog.Infof("BlockbeatDispatcher is starting with %d consumer queues", + len(b.consumerQueues)) + defer clog.Debug("BlockbeatDispatcher started") + + b.wg.Add(1) + go b.dispatchBlocks(blockEpochs) + + return nil +} + +// Stop shuts down the blockbeat dispatcher. +func (b *BlockbeatDispatcher) Stop() { + clog.Info("BlockbeatDispatcher is stopping") + defer clog.Debug("BlockbeatDispatcher stopped") + + // Signal the dispatchBlocks goroutine to stop. + close(b.quit) + b.wg.Wait() +} + +func (b *BlockbeatDispatcher) log() btclog.Logger { + return b.beat.logger() +} + +// dispatchBlocks listens to new block epoch and dispatches it to all the +// consumers. Each queue is notified concurrently, and the consumers in the +// same queue are notified sequentially. +// +// NOTE: Must be run as a goroutine. +func (b *BlockbeatDispatcher) dispatchBlocks( + blockEpochs *chainntnfs.BlockEpochEvent) { + + defer b.wg.Done() + defer blockEpochs.Cancel() + + for { + select { + case blockEpoch, ok := <-blockEpochs.Epochs: + if !ok { + clog.Debugf("Block epoch channel closed") + + return + } + + clog.Infof("Received new block %v at height %d, "+ + "notifying consumers...", blockEpoch.Hash, + blockEpoch.Height) + + // Record the time it takes the consumer to process + // this block. + start := time.Now() + + // Update the current block epoch. + b.beat = NewBeat(*blockEpoch) + + // Notify all consumers. + err := b.notifyQueues() + if err != nil { + b.log().Errorf("Notify block failed: %v", err) + } + + b.log().Infof("Notified all consumers on new block "+ + "in %v", time.Since(start)) + + case <-b.quit: + b.log().Debugf("BlockbeatDispatcher quit signal " + + "received") + + return + } + } +} + +// notifyQueues notifies each queue concurrently about the latest block epoch. +func (b *BlockbeatDispatcher) notifyQueues() error { + // errChans is a map of channels that will be used to receive errors + // returned from notifying the consumers. + errChans := make(map[uint32]chan error, len(b.consumerQueues)) + + // Notify each queue in goroutines. + for qid, consumers := range b.consumerQueues { + b.log().Debugf("Notifying queue=%d with %d consumers", qid, + len(consumers)) + + // Create a signal chan. + errChan := make(chan error, 1) + errChans[qid] = errChan + + // Notify each queue concurrently. + go func(qid uint32, c []Consumer, beat Blockbeat) { + // Notify each consumer in this queue sequentially. + errChan <- DispatchSequential(beat, c) + }(qid, consumers, b.beat) + } + + // Wait for all consumers in each queue to finish. + for qid, errChan := range errChans { + select { + case err := <-errChan: + if err != nil { + return fmt.Errorf("queue=%d got err: %w", qid, + err) + } + + b.log().Debugf("Notified queue=%d", qid) + + case <-b.quit: + b.log().Debugf("BlockbeatDispatcher quit signal " + + "received, exit notifyQueues") + + return nil + } + } + + return nil +} + +// DispatchSequential takes a list of consumers and notify them about the new +// epoch sequentially. It requires the consumer to finish processing the block +// within the specified time, otherwise a timeout error is returned. +func DispatchSequential(b Blockbeat, consumers []Consumer) error { + for _, c := range consumers { + // Send the beat to the consumer. + err := notifyAndWait(b, c, DefaultProcessBlockTimeout) + if err != nil { + b.logger().Errorf("Failed to process block: %v", err) + + return err + } + } + + return nil +} + +// DispatchConcurrent notifies each consumer concurrently about the blockbeat. +// It requires the consumer to finish processing the block within the specified +// time, otherwise a timeout error is returned. +func DispatchConcurrent(b Blockbeat, consumers []Consumer) error { + eg := &errgroup.Group{} + + // Notify each queue in goroutines. + for _, c := range consumers { + // Notify each consumer concurrently. + eg.Go(func() error { + // Send the beat to the consumer. + err := notifyAndWait(b, c, DefaultProcessBlockTimeout) + + // Exit early if there's no error. + if err == nil { + return nil + } + + b.logger().Errorf("Consumer=%v failed to process "+ + "block: %v", c.Name(), err) + + return err + }) + } + + // Wait for all consumers in each queue to finish. + if err := eg.Wait(); err != nil { + return err + } + + return nil +} + +// notifyAndWait sends the blockbeat to the specified consumer. It requires the +// consumer to finish processing the block within the specified time, otherwise +// a timeout error is returned. +func notifyAndWait(b Blockbeat, c Consumer, timeout time.Duration) error { + b.logger().Debugf("Waiting for consumer[%s] to process it", c.Name()) + + // Record the time it takes the consumer to process this block. + start := time.Now() + + errChan := make(chan error, 1) + go func() { + errChan <- c.ProcessBlock(b) + }() + + // We expect the consumer to finish processing this block under 30s, + // otherwise a timeout error is returned. + select { + case err := <-errChan: + if err == nil { + break + } + + return fmt.Errorf("%s got err in ProcessBlock: %w", c.Name(), + err) + + case <-time.After(timeout): + return fmt.Errorf("consumer %s: %w", c.Name(), + ErrProcessBlockTimeout) + } + + b.logger().Debugf("Consumer[%s] processed block in %v", c.Name(), + time.Since(start)) + + return nil +} diff --git a/chainio/dispatcher_test.go b/chainio/dispatcher_test.go new file mode 100644 index 00000000000..88044c0201d --- /dev/null +++ b/chainio/dispatcher_test.go @@ -0,0 +1,383 @@ +package chainio + +import ( + "testing" + "time" + + "github.com/lightningnetwork/lnd/chainntnfs" + "github.com/lightningnetwork/lnd/fn" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +// TestNotifyAndWaitOnConsumerErr asserts when the consumer returns an error, +// it's returned by notifyAndWait. +func TestNotifyAndWaitOnConsumerErr(t *testing.T) { + t.Parallel() + + // Create a mock consumer. + consumer := &MockConsumer{} + defer consumer.AssertExpectations(t) + consumer.On("Name").Return("mocker") + + // Create a mock beat. + mockBeat := &MockBlockbeat{} + defer mockBeat.AssertExpectations(t) + mockBeat.On("logger").Return(clog) + + // Mock ProcessBlock to return an error. + consumer.On("ProcessBlock", mockBeat).Return(errDummy).Once() + + // Call the method under test. + err := notifyAndWait(mockBeat, consumer, DefaultProcessBlockTimeout) + + // We expect the error to be returned. + require.ErrorIs(t, err, errDummy) +} + +// TestNotifyAndWaitOnConsumerErr asserts when the consumer successfully +// processed the beat, no error is returned. +func TestNotifyAndWaitOnConsumerSuccess(t *testing.T) { + t.Parallel() + + // Create a mock consumer. + consumer := &MockConsumer{} + defer consumer.AssertExpectations(t) + consumer.On("Name").Return("mocker") + + // Create a mock beat. + mockBeat := &MockBlockbeat{} + defer mockBeat.AssertExpectations(t) + mockBeat.On("logger").Return(clog) + + // Mock ProcessBlock to return nil. + consumer.On("ProcessBlock", mockBeat).Return(nil).Once() + + // Call the method under test. + err := notifyAndWait(mockBeat, consumer, DefaultProcessBlockTimeout) + + // We expect a nil error to be returned. + require.NoError(t, err) +} + +// TestNotifyAndWaitOnConsumerTimeout asserts when the consumer times out +// processing the block, the timeout error is returned. +func TestNotifyAndWaitOnConsumerTimeout(t *testing.T) { + t.Parallel() + + // Set timeout to be 10ms. + processBlockTimeout := 10 * time.Millisecond + + // Create a mock consumer. + consumer := &MockConsumer{} + defer consumer.AssertExpectations(t) + consumer.On("Name").Return("mocker") + + // Create a mock beat. + mockBeat := &MockBlockbeat{} + defer mockBeat.AssertExpectations(t) + mockBeat.On("logger").Return(clog) + + // Mock ProcessBlock to return nil but blocks on returning. + consumer.On("ProcessBlock", mockBeat).Return(nil).Run( + func(args mock.Arguments) { + // Sleep one second to block on the method. + time.Sleep(processBlockTimeout * 100) + }).Once() + + // Call the method under test. + err := notifyAndWait(mockBeat, consumer, processBlockTimeout) + + // We expect a timeout error to be returned. + require.ErrorIs(t, err, ErrProcessBlockTimeout) +} + +// TestDispatchSequential checks that the beat is sent to the consumers +// sequentially. +func TestDispatchSequential(t *testing.T) { + t.Parallel() + + // Create three mock consumers. + consumer1 := &MockConsumer{} + defer consumer1.AssertExpectations(t) + consumer1.On("Name").Return("mocker1") + + consumer2 := &MockConsumer{} + defer consumer2.AssertExpectations(t) + consumer2.On("Name").Return("mocker2") + + consumer3 := &MockConsumer{} + defer consumer3.AssertExpectations(t) + consumer3.On("Name").Return("mocker3") + + consumers := []Consumer{consumer1, consumer2, consumer3} + + // Create a mock beat. + mockBeat := &MockBlockbeat{} + defer mockBeat.AssertExpectations(t) + mockBeat.On("logger").Return(clog) + + // prevConsumer specifies the previous consumer that was called. + var prevConsumer string + + // Mock the ProcessBlock on consumers to reutrn immediately. + consumer1.On("ProcessBlock", mockBeat).Return(nil).Run( + func(args mock.Arguments) { + // Check the order of the consumers. + // + // The first consumer should have no previous consumer. + require.Empty(t, prevConsumer) + + // Set the consumer as the previous consumer. + prevConsumer = consumer1.Name() + }).Once() + + consumer2.On("ProcessBlock", mockBeat).Return(nil).Run( + func(args mock.Arguments) { + // Check the order of the consumers. + // + // The second consumer should see consumer1. + require.Equal(t, consumer1.Name(), prevConsumer) + + // Set the consumer as the previous consumer. + prevConsumer = consumer2.Name() + }).Once() + + consumer3.On("ProcessBlock", mockBeat).Return(nil).Run( + func(args mock.Arguments) { + // Check the order of the consumers. + // + // The third consumer should see consumer2. + require.Equal(t, consumer2.Name(), prevConsumer) + + // Set the consumer as the previous consumer. + prevConsumer = consumer3.Name() + }).Once() + + // Call the method under test. + err := DispatchSequential(mockBeat, consumers) + require.NoError(t, err) + + // Check the previous consumer is the last consumer. + require.Equal(t, consumer3.Name(), prevConsumer) +} + +// TestRegisterQueue tests the RegisterQueue function. +func TestRegisterQueue(t *testing.T) { + t.Parallel() + + // Create two mock consumers. + consumer1 := &MockConsumer{} + defer consumer1.AssertExpectations(t) + consumer1.On("Name").Return("mocker1") + + consumer2 := &MockConsumer{} + defer consumer2.AssertExpectations(t) + consumer2.On("Name").Return("mocker2") + + consumers := []Consumer{consumer1, consumer2} + + // Create a mock chain notifier. + mockNotifier := &chainntnfs.MockChainNotifier{} + defer mockNotifier.AssertExpectations(t) + + // Create a new dispatcher. + b := NewBlockbeatDispatcher(mockNotifier) + + // Register the consumers. + b.RegisterQueue(consumers) + + // Assert that the consumers have been registered. + // + // We should have one queue. + require.Len(t, b.consumerQueues, 1) + + // The queue should have two consumers. + queue, ok := b.consumerQueues[1] + require.True(t, ok) + require.Len(t, queue, 2) +} + +// TestStartDispatcher tests the Start method. +func TestStartDispatcher(t *testing.T) { + t.Parallel() + + // Create a mock chain notifier. + mockNotifier := &chainntnfs.MockChainNotifier{} + defer mockNotifier.AssertExpectations(t) + + // Create a new dispatcher. + b := NewBlockbeatDispatcher(mockNotifier) + + // Start the dispatcher without consumers should return an error. + err := b.Start() + require.Error(t, err) + + // Create a consumer and register it. + consumer := &MockConsumer{} + defer consumer.AssertExpectations(t) + consumer.On("Name").Return("mocker1") + b.RegisterQueue([]Consumer{consumer}) + + // Mock the chain notifier to return an error. + mockNotifier.On("RegisterBlockEpochNtfn", + mock.Anything).Return(nil, errDummy).Once() + + // Start the dispatcher now should return the error. + err = b.Start() + require.ErrorIs(t, err, errDummy) + + // Mock the chain notifier to return a valid notifier. + blockEpochs := &chainntnfs.BlockEpochEvent{} + mockNotifier.On("RegisterBlockEpochNtfn", + mock.Anything).Return(blockEpochs, nil).Once() + + // Start the dispatcher now should not return an error. + err = b.Start() + require.NoError(t, err) +} + +// TestDispatchBlocks asserts the blocks are properly dispatched to the queues. +func TestDispatchBlocks(t *testing.T) { + t.Parallel() + + // Create a mock chain notifier. + mockNotifier := &chainntnfs.MockChainNotifier{} + defer mockNotifier.AssertExpectations(t) + + // Create a new dispatcher. + b := NewBlockbeatDispatcher(mockNotifier) + + // Create the beat and attach it to the dispatcher. + epoch := chainntnfs.BlockEpoch{Height: 1} + beat := NewBeat(epoch) + b.beat = beat + + // Create a consumer and register it. + consumer := &MockConsumer{} + defer consumer.AssertExpectations(t) + consumer.On("Name").Return("mocker1") + b.RegisterQueue([]Consumer{consumer}) + + // Mock the consumer to return nil error on ProcessBlock. This + // implictly asserts that the step `notifyQueues` is successfully + // reached in the `dispatchBlocks` method. + consumer.On("ProcessBlock", mock.Anything).Return(nil).Once() + + // Create a test epoch chan. + epochChan := make(chan *chainntnfs.BlockEpoch, 1) + blockEpochs := &chainntnfs.BlockEpochEvent{ + Epochs: epochChan, + Cancel: func() {}, + } + + // Call the method in a goroutine. + done := make(chan struct{}) + b.wg.Add(1) + go func() { + defer close(done) + b.dispatchBlocks(blockEpochs) + }() + + // Send an epoch. + epoch = chainntnfs.BlockEpoch{Height: 2} + epochChan <- &epoch + + // Wait for the dispatcher to process the epoch. + time.Sleep(100 * time.Millisecond) + + // Stop the dispatcher. + b.Stop() + + // We expect the dispatcher to stop immediately. + _, err := fn.RecvOrTimeout(done, time.Second) + require.NoError(t, err) +} + +// TestNotifyQueuesSuccess checks when the dispatcher successfully notifies all +// the queues, no error is returned. +func TestNotifyQueuesSuccess(t *testing.T) { + t.Parallel() + + // Create two mock consumers. + consumer1 := &MockConsumer{} + defer consumer1.AssertExpectations(t) + consumer1.On("Name").Return("mocker1") + + consumer2 := &MockConsumer{} + defer consumer2.AssertExpectations(t) + consumer2.On("Name").Return("mocker2") + + // Create two queues. + queue1 := []Consumer{consumer1} + queue2 := []Consumer{consumer2} + + // Create a mock chain notifier. + mockNotifier := &chainntnfs.MockChainNotifier{} + defer mockNotifier.AssertExpectations(t) + + // Create a mock beat. + mockBeat := &MockBlockbeat{} + defer mockBeat.AssertExpectations(t) + mockBeat.On("logger").Return(clog) + + // Create a new dispatcher. + b := NewBlockbeatDispatcher(mockNotifier) + + // Register the queues. + b.RegisterQueue(queue1) + b.RegisterQueue(queue2) + + // Attach the blockbeat. + b.beat = mockBeat + + // Mock the consumers to return nil error on ProcessBlock for + // both calls. + consumer1.On("ProcessBlock", mockBeat).Return(nil).Once() + consumer2.On("ProcessBlock", mockBeat).Return(nil).Once() + + // Notify the queues. The mockers will be asserted in the end to + // validate the calls. + err := b.notifyQueues() + require.NoError(t, err) +} + +// TestNotifyQueuesError checks when one of the queue returns an error, this +// error is returned by the method. +func TestNotifyQueuesError(t *testing.T) { + t.Parallel() + + // Create a mock consumer. + consumer := &MockConsumer{} + defer consumer.AssertExpectations(t) + consumer.On("Name").Return("mocker1") + + // Create one queue. + queue := []Consumer{consumer} + + // Create a mock chain notifier. + mockNotifier := &chainntnfs.MockChainNotifier{} + defer mockNotifier.AssertExpectations(t) + + // Create a mock beat. + mockBeat := &MockBlockbeat{} + defer mockBeat.AssertExpectations(t) + mockBeat.On("logger").Return(clog) + + // Create a new dispatcher. + b := NewBlockbeatDispatcher(mockNotifier) + + // Register the queues. + b.RegisterQueue(queue) + + // Attach the blockbeat. + b.beat = mockBeat + + // Mock the consumer to return an error on ProcessBlock. + consumer.On("ProcessBlock", mockBeat).Return(errDummy).Once() + + // Notify the queues. The mockers will be asserted in the end to + // validate the calls. + err := b.notifyQueues() + require.ErrorIs(t, err, errDummy) +} diff --git a/chainio/interface.go b/chainio/interface.go new file mode 100644 index 00000000000..03c09faf7c0 --- /dev/null +++ b/chainio/interface.go @@ -0,0 +1,53 @@ +package chainio + +import "github.com/btcsuite/btclog/v2" + +// Blockbeat defines an interface that can be used by subsystems to retrieve +// block data. It is sent by the BlockbeatDispatcher to all the registered +// consumers whenever a new block is received. Once the consumer finishes +// processing the block, it must signal it by calling `NotifyBlockProcessed`. +// +// The blockchain is a state machine - whenever there's a state change, it's +// manifested in a block. The blockbeat is a way to notify subsystems of this +// state change, and to provide them with the data they need to process it. In +// other words, subsystems must react to this state change and should consider +// being driven by the blockbeat in their own state machines. +type Blockbeat interface { + // blockbeat is a private interface that's only used in this package. + blockbeat + + // Height returns the current block height. + Height() int32 +} + +// blockbeat defines a set of private methods used in this package to make +// interaction with the blockbeat easier. +type blockbeat interface { + // logger returns the internal logger used by the blockbeat which has a + // block height prefix. + logger() btclog.Logger +} + +// Consumer defines a blockbeat consumer interface. Subsystems that need block +// info must implement it. +type Consumer interface { + // TODO(yy): We should also define the start methods used by the + // consumers such that when implementing the interface, the consumer + // will always be started with a blockbeat. This cannot be enforced at + // the moment as we need refactor all the start methods to only take a + // beat. + // + // Start(beat Blockbeat) error + + // Name returns a human-readable string for this subsystem. + Name() string + + // ProcessBlock takes a blockbeat and processes it. It should not + // return until the subsystem has updated its state based on the block + // data. + // + // NOTE: The consumer must try its best to NOT return an error. If an + // error is returned from processing the block, it means the subsystem + // cannot react to onchain state changes and lnd will shutdown. + ProcessBlock(b Blockbeat) error +} diff --git a/chainio/log.go b/chainio/log.go new file mode 100644 index 00000000000..2d8c26f7a59 --- /dev/null +++ b/chainio/log.go @@ -0,0 +1,32 @@ +package chainio + +import ( + "github.com/btcsuite/btclog/v2" + "github.com/lightningnetwork/lnd/build" +) + +// Subsystem defines the logging code for this subsystem. +const Subsystem = "CHIO" + +// clog is a logger that is initialized with no output filters. This means the +// package will not perform any logging by default until the caller requests +// it. +var clog btclog.Logger + +// The default amount of logging is none. +func init() { + UseLogger(build.NewSubLogger(Subsystem, nil)) +} + +// DisableLog disables all library log output. Logging output is disabled by +// default until UseLogger is called. +func DisableLog() { + UseLogger(btclog.Disabled) +} + +// UseLogger uses a specified Logger to output package logging info. This +// should be used in preference to SetLogWriter if the caller is also using +// btclog. +func UseLogger(logger btclog.Logger) { + clog = logger +} diff --git a/chainio/mocks.go b/chainio/mocks.go new file mode 100644 index 00000000000..5677734e1dd --- /dev/null +++ b/chainio/mocks.go @@ -0,0 +1,50 @@ +package chainio + +import ( + "github.com/btcsuite/btclog/v2" + "github.com/stretchr/testify/mock" +) + +// MockConsumer is a mock implementation of the Consumer interface. +type MockConsumer struct { + mock.Mock +} + +// Compile-time constraint to ensure MockConsumer implements Consumer. +var _ Consumer = (*MockConsumer)(nil) + +// Name returns a human-readable string for this subsystem. +func (m *MockConsumer) Name() string { + args := m.Called() + return args.String(0) +} + +// ProcessBlock takes a blockbeat and processes it. A receive-only error chan +// must be returned. +func (m *MockConsumer) ProcessBlock(b Blockbeat) error { + args := m.Called(b) + + return args.Error(0) +} + +// MockBlockbeat is a mock implementation of the Blockbeat interface. +type MockBlockbeat struct { + mock.Mock +} + +// Compile-time constraint to ensure MockBlockbeat implements Blockbeat. +var _ Blockbeat = (*MockBlockbeat)(nil) + +// Height returns the current block height. +func (m *MockBlockbeat) Height() int32 { + args := m.Called() + + return args.Get(0).(int32) +} + +// logger returns the logger for the blockbeat. +func (m *MockBlockbeat) logger() btclog.Logger { + args := m.Called() + + return args.Get(0).(btclog.Logger) +} diff --git a/contractcourt/anchor_resolver.go b/contractcourt/anchor_resolver.go index b4d6877202e..09c325f1bf9 100644 --- a/contractcourt/anchor_resolver.go +++ b/contractcourt/anchor_resolver.go @@ -84,7 +84,7 @@ func (c *anchorResolver) ResolverKey() []byte { } // Resolve offers the anchor output to the sweeper and waits for it to be swept. -func (c *anchorResolver) Resolve(_ bool) (ContractResolver, error) { +func (c *anchorResolver) Resolve() (ContractResolver, error) { // Attempt to update the sweep parameters to the post-confirmation // situation. We don't want to force sweep anymore, because the anchor // lost its special purpose to get the commitment confirmed. It is just diff --git a/contractcourt/breach_arbitrator_test.go b/contractcourt/breach_arbitrator_test.go index bd4ad856831..2001431c79c 100644 --- a/contractcourt/breach_arbitrator_test.go +++ b/contractcourt/breach_arbitrator_test.go @@ -36,7 +36,7 @@ import ( ) var ( - defaultTimeout = 30 * time.Second + defaultTimeout = 10 * time.Second breachOutPoints = []wire.OutPoint{ { diff --git a/contractcourt/breach_resolver.go b/contractcourt/breach_resolver.go index 740b4471d5d..63395651cc3 100644 --- a/contractcourt/breach_resolver.go +++ b/contractcourt/breach_resolver.go @@ -47,7 +47,7 @@ func (b *breachResolver) ResolverKey() []byte { // been broadcast. // // TODO(yy): let sweeper handle the breach inputs. -func (b *breachResolver) Resolve(_ bool) (ContractResolver, error) { +func (b *breachResolver) Resolve() (ContractResolver, error) { if !b.subscribed { complete, err := b.SubscribeBreachComplete( &b.ChanPoint, b.replyChan, diff --git a/contractcourt/chain_arbitrator.go b/contractcourt/chain_arbitrator.go index d7d10ba252c..71a45457079 100644 --- a/contractcourt/chain_arbitrator.go +++ b/contractcourt/chain_arbitrator.go @@ -11,6 +11,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcwallet/walletdb" + "github.com/lightningnetwork/lnd/chainio" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb/models" @@ -244,6 +245,10 @@ type ChainArbitrator struct { started int32 // To be used atomically. stopped int32 // To be used atomically. + // Embed the blockbeat consumer struct to get access to the method + // `NotifyBlockProcessed` and the `BlockbeatChan`. + chainio.BeatConsumer + sync.Mutex // activeChannels is a map of all the active contracts that are still @@ -262,6 +267,9 @@ type ChainArbitrator struct { // active channels that it must still watch over. chanSource *channeldb.DB + // beat is the current best known blockbeat. + beat chainio.Blockbeat + quit chan struct{} wg sync.WaitGroup @@ -272,15 +280,23 @@ type ChainArbitrator struct { func NewChainArbitrator(cfg ChainArbitratorConfig, db *channeldb.DB) *ChainArbitrator { - return &ChainArbitrator{ + c := &ChainArbitrator{ cfg: cfg, activeChannels: make(map[wire.OutPoint]*ChannelArbitrator), activeWatchers: make(map[wire.OutPoint]*chainWatcher), chanSource: db, quit: make(chan struct{}), } + + // Mount the block consumer. + c.BeatConsumer = chainio.NewBeatConsumer(c.quit, c.Name()) + + return c } +// Compile-time check for the chainio.Consumer interface. +var _ chainio.Consumer = (*ChainArbitrator)(nil) + // arbChannel is a wrapper around an open channel that channel arbitrators // interact with. type arbChannel struct { @@ -558,147 +574,30 @@ func (c *ChainArbitrator) ResolveContract(chanPoint wire.OutPoint) error { } // Start launches all goroutines that the ChainArbitrator needs to operate. -func (c *ChainArbitrator) Start() error { +func (c *ChainArbitrator) Start(beat chainio.Blockbeat) error { if !atomic.CompareAndSwapInt32(&c.started, 0, 1) { return nil } - log.Infof("ChainArbitrator starting with config: budget=[%v]", - &c.cfg.Budget) + // Set the current beat. + c.beat = beat + + log.Infof("ChainArbitrator starting at height %d with budget=[%v]", + &c.cfg.Budget, c.beat.Height()) // First, we'll fetch all the channels that are still open, in order to // collect them within our set of active contracts. - openChannels, err := c.chanSource.ChannelStateDB().FetchAllChannels() - if err != nil { + if err := c.loadOpenChannels(); err != nil { return err } - if len(openChannels) > 0 { - log.Infof("Creating ChannelArbitrators for %v active channels", - len(openChannels)) - } - - // For each open channel, we'll configure then launch a corresponding - // ChannelArbitrator. - for _, channel := range openChannels { - chanPoint := channel.FundingOutpoint - channel := channel - - // First, we'll create an active chainWatcher for this channel - // to ensure that we detect any relevant on chain events. - breachClosure := func(ret *lnwallet.BreachRetribution) error { - return c.cfg.ContractBreach(chanPoint, ret) - } - - chainWatcher, err := newChainWatcher( - chainWatcherConfig{ - chanState: channel, - notifier: c.cfg.Notifier, - signer: c.cfg.Signer, - isOurAddr: c.cfg.IsOurAddress, - contractBreach: breachClosure, - extractStateNumHint: lnwallet.GetStateNumHint, - auxLeafStore: c.cfg.AuxLeafStore, - auxResolver: c.cfg.AuxResolver, - }, - ) - if err != nil { - return err - } - - c.activeWatchers[chanPoint] = chainWatcher - channelArb, err := newActiveChannelArbitrator( - channel, c, chainWatcher.SubscribeChannelEvents(), - ) - if err != nil { - return err - } - - c.activeChannels[chanPoint] = channelArb - - // Republish any closing transactions for this channel. - err = c.republishClosingTxs(channel) - if err != nil { - log.Errorf("Failed to republish closing txs for "+ - "channel %v", chanPoint) - } - } - // In addition to the channels that we know to be open, we'll also // launch arbitrators to finishing resolving any channels that are in // the pending close state. - closingChannels, err := c.chanSource.ChannelStateDB().FetchClosedChannels( - true, - ) - if err != nil { + if err := c.loadPendingCloseChannels(); err != nil { return err } - if len(closingChannels) > 0 { - log.Infof("Creating ChannelArbitrators for %v closing channels", - len(closingChannels)) - } - - // Next, for each channel is the closing state, we'll launch a - // corresponding more restricted resolver, as we don't have to watch - // the chain any longer, only resolve the contracts on the confirmed - // commitment. - //nolint:lll - for _, closeChanInfo := range closingChannels { - // We can leave off the CloseContract and ForceCloseChan - // methods as the channel is already closed at this point. - chanPoint := closeChanInfo.ChanPoint - arbCfg := ChannelArbitratorConfig{ - ChanPoint: chanPoint, - ShortChanID: closeChanInfo.ShortChanID, - ChainArbitratorConfig: c.cfg, - ChainEvents: &ChainEventSubscription{}, - IsPendingClose: true, - ClosingHeight: closeChanInfo.CloseHeight, - CloseType: closeChanInfo.CloseType, - PutResolverReport: func(tx kvdb.RwTx, - report *channeldb.ResolverReport) error { - - return c.chanSource.PutResolverReport( - tx, c.cfg.ChainHash, &chanPoint, report, - ) - }, - FetchHistoricalChannel: func() (*channeldb.OpenChannel, error) { - chanStateDB := c.chanSource.ChannelStateDB() - return chanStateDB.FetchHistoricalChannel(&chanPoint) - }, - FindOutgoingHTLCDeadline: func( - htlc channeldb.HTLC) fn.Option[int32] { - - return c.FindOutgoingHTLCDeadline( - closeChanInfo.ShortChanID, htlc, - ) - }, - } - chanLog, err := newBoltArbitratorLog( - c.chanSource.Backend, arbCfg, c.cfg.ChainHash, chanPoint, - ) - if err != nil { - return err - } - arbCfg.MarkChannelResolved = func() error { - if c.cfg.NotifyFullyResolvedChannel != nil { - c.cfg.NotifyFullyResolvedChannel(chanPoint) - } - - return c.ResolveContract(chanPoint) - } - - // We create an empty map of HTLC's here since it's possible - // that the channel is in StateDefault and updateActiveHTLCs is - // called. We want to avoid writing to an empty map. Since the - // channel is already in the process of being resolved, no new - // HTLCs will be added. - c.activeChannels[chanPoint] = NewChannelArbitrator( - arbCfg, make(map[HtlcSetKey]htlcSet), chanLog, - ) - } - // Now, we'll start all chain watchers in parallel to shorten start up // duration. In neutrino mode, this allows spend registrations to take // advantage of batch spend reporting, instead of doing a single rescan @@ -750,7 +649,7 @@ func (c *ChainArbitrator) Start() error { // transaction. var startStates map[wire.OutPoint]*chanArbStartState - err = kvdb.View(c.chanSource, func(tx walletdb.ReadTx) error { + err := kvdb.View(c.chanSource, func(tx walletdb.ReadTx) error { for _, arbitrator := range c.activeChannels { startState, err := arbitrator.getStartState(tx) if err != nil { @@ -782,24 +681,17 @@ func (c *ChainArbitrator) Start() error { arbitrator.cfg.ChanPoint) } - if err := arbitrator.Start(startState); err != nil { + if err := arbitrator.Start(startState, c.beat); err != nil { stopAndLog() return err } } - // Subscribe to a single stream of block epoch notifications that we - // will dispatch to all active arbitrators. - blockEpoch, err := c.cfg.Notifier.RegisterBlockEpochNtfn(nil) - if err != nil { - return err - } - // Start our goroutine which will dispatch blocks to each arbitrator. c.wg.Add(1) go func() { defer c.wg.Done() - c.dispatchBlocks(blockEpoch) + c.dispatchBlocks() }() // TODO(roasbeef): eventually move all breach watching here @@ -807,94 +699,22 @@ func (c *ChainArbitrator) Start() error { return nil } -// blockRecipient contains the information we need to dispatch a block to a -// channel arbitrator. -type blockRecipient struct { - // chanPoint is the funding outpoint of the channel. - chanPoint wire.OutPoint - - // blocks is the channel that new block heights are sent into. This - // channel should be sufficiently buffered as to not block the sender. - blocks chan<- int32 - - // quit is closed if the receiving entity is shutting down. - quit chan struct{} -} - // dispatchBlocks consumes a block epoch notification stream and dispatches // blocks to each of the chain arb's active channel arbitrators. This function // must be run in a goroutine. -func (c *ChainArbitrator) dispatchBlocks( - blockEpoch *chainntnfs.BlockEpochEvent) { - - // getRecipients is a helper function which acquires the chain arb - // lock and returns a set of block recipients which can be used to - // dispatch blocks. - getRecipients := func() []blockRecipient { - c.Lock() - blocks := make([]blockRecipient, 0, len(c.activeChannels)) - for _, channel := range c.activeChannels { - blocks = append(blocks, blockRecipient{ - chanPoint: channel.cfg.ChanPoint, - blocks: channel.blocks, - quit: channel.quit, - }) - } - c.Unlock() - - return blocks - } - - // On exit, cancel our blocks subscription and close each block channel - // so that the arbitrators know they will no longer be receiving blocks. - defer func() { - blockEpoch.Cancel() - - recipients := getRecipients() - for _, recipient := range recipients { - close(recipient.blocks) - } - }() - +func (c *ChainArbitrator) dispatchBlocks() { // Consume block epochs until we receive the instruction to shutdown. for { select { // Consume block epochs, exiting if our subscription is // terminated. - case block, ok := <-blockEpoch.Epochs: - if !ok { - log.Trace("dispatchBlocks block epoch " + - "cancelled") - return - } + case beat := <-c.BlockbeatChan: + // Set the current blockbeat. + c.beat = beat - // Get the set of currently active channels block - // subscription channels and dispatch the block to - // each. - for _, recipient := range getRecipients() { - select { - // Deliver the block to the arbitrator. - case recipient.blocks <- block.Height: - - // If the recipient is shutting down, exit - // without delivering the block. This may be - // the case when two blocks are mined in quick - // succession, and the arbitrator resolves - // after the first block, and does not need to - // consume the second block. - case <-recipient.quit: - log.Debugf("channel: %v exit without "+ - "receiving block: %v", - recipient.chanPoint, - block.Height) - - // If the chain arb is shutting down, we don't - // need to deliver any more blocks (everything - // will be shutting down). - case <-c.quit: - return - } - } + // Send this blockbeat to all the active channels and + // wait for them to finish processing it. + c.handleBlockbeat(beat) // Exit if the chain arbitrator is shutting down. case <-c.quit: @@ -903,6 +723,32 @@ func (c *ChainArbitrator) dispatchBlocks( } } +// handleBlockbeat sends the blockbeat to all active channel arbitrator in +// parallel and wait for them to finish processing it. +func (c *ChainArbitrator) handleBlockbeat(beat chainio.Blockbeat) { + // Read the active channels in a lock. + c.Lock() + + // Create a slice to record active channel arbitrator. + channels := make([]chainio.Consumer, 0, len(c.activeChannels)) + + // Copy the active channels to the slice. + for _, channel := range c.activeChannels { + channels = append(channels, channel) + } + + c.Unlock() + + // Iterate all the copied channels and send the blockbeat to them. + // + // NOTE: This method will timeout if the processing of blocks of the + // subsystems is too long (60s). + err := chainio.DispatchConcurrent(beat, channels) + + // Notify the chain arbitrator has processed the block. + c.NotifyBlockProcessed(beat, err) +} + // republishClosingTxs will load any stored cooperative or unilateral closing // transactions and republish them. This helps ensure propagation of the // transactions in the event that prior publications failed. @@ -1252,7 +1098,7 @@ func (c *ChainArbitrator) WatchNewChannel(newChan *channeldb.OpenChannel) error // arbitrators, then launch it. c.activeChannels[chanPoint] = channelArb - if err := channelArb.Start(nil); err != nil { + if err := channelArb.Start(nil, c.beat); err != nil { return err } @@ -1365,3 +1211,152 @@ func (c *ChainArbitrator) FindOutgoingHTLCDeadline(scid lnwire.ShortChannelID, // TODO(roasbeef): arbitration reports // * types: contested, waiting for success conf, etc + +// NOTE: part of the `chainio.Consumer` interface. +func (c *ChainArbitrator) Name() string { + return "ChainArbitrator" +} + +// loadOpenChannels loads all channels that are currently open in the database +// and registers them with the chainWatcher for future notification. +func (c *ChainArbitrator) loadOpenChannels() error { + openChannels, err := c.chanSource.ChannelStateDB().FetchAllChannels() + if err != nil { + return err + } + + if len(openChannels) == 0 { + return nil + } + + log.Infof("Creating ChannelArbitrators for %v active channels", + len(openChannels)) + + // For each open channel, we'll configure then launch a corresponding + // ChannelArbitrator. + for _, channel := range openChannels { + chanPoint := channel.FundingOutpoint + channel := channel + + // First, we'll create an active chainWatcher for this channel + // to ensure that we detect any relevant on chain events. + breachClosure := func(ret *lnwallet.BreachRetribution) error { + return c.cfg.ContractBreach(chanPoint, ret) + } + + chainWatcher, err := newChainWatcher( + chainWatcherConfig{ + chanState: channel, + notifier: c.cfg.Notifier, + signer: c.cfg.Signer, + isOurAddr: c.cfg.IsOurAddress, + contractBreach: breachClosure, + extractStateNumHint: lnwallet.GetStateNumHint, + auxLeafStore: c.cfg.AuxLeafStore, + auxResolver: c.cfg.AuxResolver, + }, + ) + if err != nil { + return err + } + + c.activeWatchers[chanPoint] = chainWatcher + channelArb, err := newActiveChannelArbitrator( + channel, c, chainWatcher.SubscribeChannelEvents(), + ) + if err != nil { + return err + } + + c.activeChannels[chanPoint] = channelArb + + // Republish any closing transactions for this channel. + err = c.republishClosingTxs(channel) + if err != nil { + log.Errorf("Failed to republish closing txs for "+ + "channel %v", chanPoint) + } + } + + return nil +} + +// loadPendingCloseChannels loads all channels that are currently pending +// closure in the database and registers them with the ChannelArbitrator to +// continue the resolution process. +func (c *ChainArbitrator) loadPendingCloseChannels() error { + chanStateDB := c.chanSource.ChannelStateDB() + + closingChannels, err := chanStateDB.FetchClosedChannels(true) + if err != nil { + return err + } + + if len(closingChannels) == 0 { + return nil + } + + log.Infof("Creating ChannelArbitrators for %v closing channels", + len(closingChannels)) + + // Next, for each channel is the closing state, we'll launch a + // corresponding more restricted resolver, as we don't have to watch + // the chain any longer, only resolve the contracts on the confirmed + // commitment. + //nolint:lll + for _, closeChanInfo := range closingChannels { + // We can leave off the CloseContract and ForceCloseChan + // methods as the channel is already closed at this point. + chanPoint := closeChanInfo.ChanPoint + arbCfg := ChannelArbitratorConfig{ + ChanPoint: chanPoint, + ShortChanID: closeChanInfo.ShortChanID, + ChainArbitratorConfig: c.cfg, + ChainEvents: &ChainEventSubscription{}, + IsPendingClose: true, + ClosingHeight: closeChanInfo.CloseHeight, + CloseType: closeChanInfo.CloseType, + PutResolverReport: func(tx kvdb.RwTx, + report *channeldb.ResolverReport) error { + + return c.chanSource.PutResolverReport( + tx, c.cfg.ChainHash, &chanPoint, report, + ) + }, + FetchHistoricalChannel: func() (*channeldb.OpenChannel, error) { + return chanStateDB.FetchHistoricalChannel(&chanPoint) + }, + FindOutgoingHTLCDeadline: func( + htlc channeldb.HTLC) fn.Option[int32] { + + return c.FindOutgoingHTLCDeadline( + closeChanInfo.ShortChanID, htlc, + ) + }, + } + chanLog, err := newBoltArbitratorLog( + c.chanSource.Backend, arbCfg, c.cfg.ChainHash, chanPoint, + ) + if err != nil { + return err + } + arbCfg.MarkChannelResolved = func() error { + if c.cfg.NotifyFullyResolvedChannel != nil { + c.cfg.NotifyFullyResolvedChannel(chanPoint) + } + + return c.ResolveContract(chanPoint) + } + + // We create an empty map of HTLC's here since it's possible + // that the channel is in StateDefault and updateActiveHTLCs is + // called. We want to avoid writing to an empty map. Since the + // channel is already in the process of being resolved, no new + // HTLCs will be added. + c.activeChannels[chanPoint] = NewChannelArbitrator( + arbCfg, make(map[HtlcSetKey]htlcSet), chanLog, + ) + } + + return nil +} diff --git a/contractcourt/chain_arbitrator_test.go b/contractcourt/chain_arbitrator_test.go index abaca5c2bab..a6b60a9a21a 100644 --- a/contractcourt/chain_arbitrator_test.go +++ b/contractcourt/chain_arbitrator_test.go @@ -83,7 +83,6 @@ func TestChainArbitratorRepublishCloses(t *testing.T) { ChainIO: &mock.ChainIO{}, Notifier: &mock.ChainNotifier{ SpendChan: make(chan *chainntnfs.SpendDetail), - EpochChan: make(chan *chainntnfs.BlockEpoch), ConfChan: make(chan *chainntnfs.TxConfirmation), }, PublishTx: func(tx *wire.MsgTx, _ string) error { @@ -97,7 +96,8 @@ func TestChainArbitratorRepublishCloses(t *testing.T) { chainArbCfg, db, ) - if err := chainArb.Start(); err != nil { + beat := newBeatFromHeight(0) + if err := chainArb.Start(beat); err != nil { t.Fatal(err) } t.Cleanup(func() { @@ -168,7 +168,6 @@ func TestResolveContract(t *testing.T) { ChainIO: &mock.ChainIO{}, Notifier: &mock.ChainNotifier{ SpendChan: make(chan *chainntnfs.SpendDetail), - EpochChan: make(chan *chainntnfs.BlockEpoch), ConfChan: make(chan *chainntnfs.TxConfirmation), }, PublishTx: func(tx *wire.MsgTx, _ string) error { @@ -185,7 +184,8 @@ func TestResolveContract(t *testing.T) { chainArb := NewChainArbitrator( chainArbCfg, db, ) - if err := chainArb.Start(); err != nil { + beat := newBeatFromHeight(0) + if err := chainArb.Start(beat); err != nil { t.Fatal(err) } t.Cleanup(func() { diff --git a/contractcourt/channel_arbitrator.go b/contractcourt/channel_arbitrator.go index ffa4a5d6e29..9afb8c149ba 100644 --- a/contractcourt/channel_arbitrator.go +++ b/contractcourt/channel_arbitrator.go @@ -14,6 +14,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/chainio" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/fn" @@ -330,6 +331,10 @@ type ChannelArbitrator struct { started int32 // To be used atomically. stopped int32 // To be used atomically. + // Embed the blockbeat consumer struct to get access to the method + // `NotifyBlockProcessed` and the `BlockbeatChan`. + chainio.BeatConsumer + // startTimestamp is the time when this ChannelArbitrator was started. startTimestamp time.Time @@ -352,11 +357,6 @@ type ChannelArbitrator struct { // to do its duty. cfg ChannelArbitratorConfig - // blocks is a channel that the arbitrator will receive new blocks on. - // This channel should be buffered by so that it does not block the - // sender. - blocks chan int32 - // signalUpdates is a channel that any new live signals for the channel // we're watching over will be sent. signalUpdates chan *signalUpdateMsg @@ -404,9 +404,8 @@ func NewChannelArbitrator(cfg ChannelArbitratorConfig, unmerged[RemotePendingHtlcSet] = htlcSets[RemotePendingHtlcSet] } - return &ChannelArbitrator{ + c := &ChannelArbitrator{ log: log, - blocks: make(chan int32, arbitratorBlockBufferSize), signalUpdates: make(chan *signalUpdateMsg), resolutionSignal: make(chan struct{}), forceCloseReqs: make(chan *forceCloseReq), @@ -415,8 +414,16 @@ func NewChannelArbitrator(cfg ChannelArbitratorConfig, cfg: cfg, quit: make(chan struct{}), } + + // Mount the block consumer. + c.BeatConsumer = chainio.NewBeatConsumer(c.quit, c.Name()) + + return c } +// Compile-time check for the chainio.Consumer interface. +var _ chainio.Consumer = (*ChannelArbitrator)(nil) + // chanArbStartState contains the information from disk that we need to start // up a channel arbitrator. type chanArbStartState struct { @@ -455,7 +462,9 @@ func (c *ChannelArbitrator) getStartState(tx kvdb.RTx) (*chanArbStartState, // Start starts all the goroutines that the ChannelArbitrator needs to operate. // If takes a start state, which will be looked up on disk if it is not // provided. -func (c *ChannelArbitrator) Start(state *chanArbStartState) error { +func (c *ChannelArbitrator) Start(state *chanArbStartState, + beat chainio.Blockbeat) error { + if !atomic.CompareAndSwapInt32(&c.started, 0, 1) { return nil } @@ -477,10 +486,8 @@ func (c *ChannelArbitrator) Start(state *chanArbStartState) error { // Set our state from our starting state. c.state = state.currentState - _, bestHeight, err := c.cfg.ChainIO.GetBestBlock() - if err != nil { - return err - } + // Get the starting height. + bestHeight := beat.Height() // If the channel has been marked pending close in the database, and we // haven't transitioned the state machine to StateContractClosed (or a @@ -797,7 +804,7 @@ func (c *ChannelArbitrator) relaunchResolvers(commitSet *CommitSet, // TODO(roasbeef): this isn't re-launched? } - c.launchResolvers(unresolvedContracts, true) + c.launchResolvers(unresolvedContracts) return nil } @@ -1336,7 +1343,7 @@ func (c *ChannelArbitrator) stateStep( // Finally, we'll launch all the required contract resolvers. // Once they're all resolved, we're no longer needed. - c.launchResolvers(resolvers, false) + c.launchResolvers(resolvers) nextState = StateWaitingFullResolution @@ -1560,16 +1567,14 @@ func (c *ChannelArbitrator) findCommitmentDeadlineAndValue(heightHint uint32, } // launchResolvers updates the activeResolvers list and starts the resolvers. -func (c *ChannelArbitrator) launchResolvers(resolvers []ContractResolver, - immediate bool) { - +func (c *ChannelArbitrator) launchResolvers(resolvers []ContractResolver) { c.activeResolversLock.Lock() - defer c.activeResolversLock.Unlock() - c.activeResolvers = resolvers + c.activeResolversLock.Unlock() + for _, contract := range resolvers { c.wg.Add(1) - go c.resolveContract(contract, immediate) + go c.resolveContract(contract) } } @@ -1593,8 +1598,8 @@ func (c *ChannelArbitrator) advanceState( for { priorState = c.state log.Debugf("ChannelArbitrator(%v): attempting state step with "+ - "trigger=%v from state=%v", c.cfg.ChanPoint, trigger, - priorState) + "trigger=%v from state=%v at height=%v", + c.cfg.ChanPoint, trigger, priorState, triggerHeight) nextState, closeTx, err := c.stateStep( triggerHeight, trigger, confCommitSet, @@ -2541,9 +2546,7 @@ func (c *ChannelArbitrator) replaceResolver(oldResolver, // contracts. // // NOTE: This MUST be run as a goroutine. -func (c *ChannelArbitrator) resolveContract(currentContract ContractResolver, - immediate bool) { - +func (c *ChannelArbitrator) resolveContract(currentContract ContractResolver) { defer c.wg.Done() log.Debugf("ChannelArbitrator(%v): attempting to resolve %T", @@ -2564,7 +2567,7 @@ func (c *ChannelArbitrator) resolveContract(currentContract ContractResolver, default: // Otherwise, we'll attempt to resolve the current // contract. - nextContract, err := currentContract.Resolve(immediate) + nextContract, err := currentContract.Resolve() if err != nil { if err == errResolverShuttingDown { return @@ -2729,31 +2732,21 @@ func (c *ChannelArbitrator) channelAttendant(bestHeight int32) { // A new block has arrived, we'll examine all the active HTLC's // to see if any of them have expired, and also update our // track of the best current height. - case blockHeight, ok := <-c.blocks: - if !ok { - return - } - bestHeight = blockHeight + case beat := <-c.BlockbeatChan: + bestHeight = beat.Height() - // If we're not in the default state, then we can - // ignore this signal as we're waiting for contract - // resolution. - if c.state != StateDefault { - continue - } + log.Debugf("ChannelArbitrator(%v): new block height=%v", + c.cfg.ChanPoint, bestHeight) - // Now that a new block has arrived, we'll attempt to - // advance our state forward. - nextState, _, err := c.advanceState( - uint32(bestHeight), chainTrigger, nil, - ) + err := c.handleBlockbeat(beat) if err != nil { - log.Errorf("Unable to advance state: %v", err) + log.Errorf("Handle block=%v got err: %v", + bestHeight, err) } // If as a result of this trigger, the contract is // fully resolved, then well exit. - if nextState == StateFullyResolved { + if c.state == StateFullyResolved { return } @@ -2802,14 +2795,12 @@ func (c *ChannelArbitrator) channelAttendant(bestHeight int32) { // We have broadcasted our commitment, and it is now confirmed // on-chain. case closeInfo := <-c.cfg.ChainEvents.LocalUnilateralClosure: - log.Infof("ChannelArbitrator(%v): local on-chain "+ - "channel close", c.cfg.ChanPoint) - if c.state != StateCommitmentBroadcasted { log.Errorf("ChannelArbitrator(%v): unexpected "+ "local on-chain channel close", c.cfg.ChanPoint) } + closeTx := closeInfo.CloseTx resolutions, err := closeInfo.ContractResolutions. @@ -2837,6 +2828,10 @@ func (c *ChannelArbitrator) channelAttendant(bestHeight int32) { return } + log.Infof("ChannelArbitrator(%v): local force close "+ + "tx=%v confirmed", c.cfg.ChanPoint, + closeTx.TxHash()) + contractRes := &ContractResolutions{ CommitHash: closeTx.TxHash(), CommitResolution: resolutions.CommitResolution, @@ -3104,6 +3099,34 @@ func (c *ChannelArbitrator) channelAttendant(bestHeight int32) { } } +// handleBlockbeat processes a newly received blockbeat by advancing the +// arbitrator's internal state using the received block height. +func (c *ChannelArbitrator) handleBlockbeat(beat chainio.Blockbeat) error { + // Notify we've processed the block. + defer c.NotifyBlockProcessed(beat, nil) + + // Try to advance the state if we are in StateDefault. + if c.state == StateDefault { + // Now that a new block has arrived, we'll attempt to advance + // our state forward. + _, _, err := c.advanceState( + uint32(beat.Height()), chainTrigger, nil, + ) + if err != nil { + return fmt.Errorf("unable to advance state: %w", err) + } + } + + return nil +} + +// Name returns a human-readable string for this subsystem. +// +// NOTE: Part of chainio.Consumer interface. +func (c *ChannelArbitrator) Name() string { + return fmt.Sprintf("ChannelArbitrator(%v)", c.cfg.ChanPoint) +} + // checkLegacyBreach returns StateFullyResolved if the channel was closed with // a breach transaction before the channel arbitrator launched its own breach // resolver. StateContractClosed is returned if this is a modern breach close diff --git a/contractcourt/channel_arbitrator_test.go b/contractcourt/channel_arbitrator_test.go index ac5253787d3..0ced525dca8 100644 --- a/contractcourt/channel_arbitrator_test.go +++ b/contractcourt/channel_arbitrator_test.go @@ -13,6 +13,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/chainio" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb/models" @@ -226,6 +227,15 @@ func (c *chanArbTestCtx) CleanUp() { } } +// receiveBlockbeat mocks the behavior of a blockbeat being sent by the +// BlockbeatDispatcher, which essentially mocks the method `ProcessBlock`. +func (c *chanArbTestCtx) receiveBlockbeat(height int) { + go func() { + beat := newBeatFromHeight(int32(height)) + c.chanArb.BlockbeatChan <- beat + }() +} + // AssertStateTransitions asserts that the state machine steps through the // passed states in order. func (c *chanArbTestCtx) AssertStateTransitions(expectedStates ...ArbitratorState) { @@ -285,7 +295,8 @@ func (c *chanArbTestCtx) Restart(restartClosure func(*chanArbTestCtx)) (*chanArb restartClosure(newCtx) } - if err := newCtx.chanArb.Start(nil); err != nil { + beat := newBeatFromHeight(0) + if err := newCtx.chanArb.Start(nil, beat); err != nil { return nil, err } @@ -512,7 +523,8 @@ func TestChannelArbitratorCooperativeClose(t *testing.T) { chanArbCtx, err := createTestChannelArbitrator(t, log) require.NoError(t, err, "unable to create ChannelArbitrator") - if err := chanArbCtx.chanArb.Start(nil); err != nil { + beat := newBeatFromHeight(0) + if err := chanArbCtx.chanArb.Start(nil, beat); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } t.Cleanup(func() { @@ -570,7 +582,8 @@ func TestChannelArbitratorRemoteForceClose(t *testing.T) { require.NoError(t, err, "unable to create ChannelArbitrator") chanArb := chanArbCtx.chanArb - if err := chanArb.Start(nil); err != nil { + beat := newBeatFromHeight(0) + if err := chanArb.Start(nil, beat); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } defer chanArb.Stop() @@ -623,7 +636,8 @@ func TestChannelArbitratorLocalForceClose(t *testing.T) { require.NoError(t, err, "unable to create ChannelArbitrator") chanArb := chanArbCtx.chanArb - if err := chanArb.Start(nil); err != nil { + beat := newBeatFromHeight(0) + if err := chanArb.Start(nil, beat); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } defer chanArb.Stop() @@ -735,7 +749,8 @@ func TestChannelArbitratorBreachClose(t *testing.T) { chanArb.cfg.PreimageDB = newMockWitnessBeacon() chanArb.cfg.Registry = &mockRegistry{} - if err := chanArb.Start(nil); err != nil { + beat := newBeatFromHeight(0) + if err := chanArb.Start(nil, beat); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } t.Cleanup(func() { @@ -862,7 +877,8 @@ func TestChannelArbitratorLocalForceClosePendingHtlc(t *testing.T) { chanArb.cfg.PreimageDB = newMockWitnessBeacon() chanArb.cfg.Registry = &mockRegistry{} - if err := chanArb.Start(nil); err != nil { + beat := newBeatFromHeight(0) + if err := chanArb.Start(nil, beat); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } defer chanArb.Stop() @@ -1036,7 +1052,7 @@ func TestChannelArbitratorLocalForceClosePendingHtlc(t *testing.T) { } require.Equal(t, expectedFinalHtlcs, chanArbCtx.finalHtlcs) - // We'll no re-create the resolver, notice that we use the existing + // We'll now re-create the resolver, notice that we use the existing // arbLog so it carries over the same on-disk state. chanArbCtxNew, err := chanArbCtx.Restart(nil) require.NoError(t, err, "unable to create ChannelArbitrator") @@ -1074,7 +1090,12 @@ func TestChannelArbitratorLocalForceClosePendingHtlc(t *testing.T) { } // Send a notification that the expiry height has been reached. + // + // TODO(yy): remove the EpochChan and use the blockbeat below once + // resolvers are hooked with the blockbeat. oldNotifier.EpochChan <- &chainntnfs.BlockEpoch{Height: 10} + // beat := chainio.NewBlockbeatFromHeight(10) + // chanArb.BlockbeatChan <- beat // htlcOutgoingContestResolver is now transforming into a // htlcTimeoutResolver and should send the contract off for incubation. @@ -1138,7 +1159,8 @@ func TestChannelArbitratorLocalForceCloseRemoteConfirmed(t *testing.T) { require.NoError(t, err, "unable to create ChannelArbitrator") chanArb := chanArbCtx.chanArb - if err := chanArb.Start(nil); err != nil { + beat := newBeatFromHeight(0) + if err := chanArb.Start(nil, beat); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } defer chanArb.Stop() @@ -1245,7 +1267,8 @@ func TestChannelArbitratorLocalForceDoubleSpend(t *testing.T) { require.NoError(t, err, "unable to create ChannelArbitrator") chanArb := chanArbCtx.chanArb - if err := chanArb.Start(nil); err != nil { + beat := newBeatFromHeight(0) + if err := chanArb.Start(nil, beat); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } defer chanArb.Stop() @@ -1351,7 +1374,8 @@ func TestChannelArbitratorPersistence(t *testing.T) { require.NoError(t, err, "unable to create ChannelArbitrator") chanArb := chanArbCtx.chanArb - if err := chanArb.Start(nil); err != nil { + beat := newBeatFromHeight(0) + if err := chanArb.Start(nil, beat); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } @@ -1469,7 +1493,8 @@ func TestChannelArbitratorForceCloseBreachedChannel(t *testing.T) { require.NoError(t, err, "unable to create ChannelArbitrator") chanArb := chanArbCtx.chanArb - if err := chanArb.Start(nil); err != nil { + beat := newBeatFromHeight(0) + if err := chanArb.Start(nil, beat); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } @@ -1656,7 +1681,8 @@ func TestChannelArbitratorCommitFailure(t *testing.T) { } chanArb := chanArbCtx.chanArb - if err := chanArb.Start(nil); err != nil { + beat := newBeatFromHeight(0) + if err := chanArb.Start(nil, beat); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } @@ -1740,7 +1766,8 @@ func TestChannelArbitratorEmptyResolutions(t *testing.T) { chanArb.cfg.ClosingHeight = 100 chanArb.cfg.CloseType = channeldb.RemoteForceClose - if err := chanArb.Start(nil); err != nil { + beat := newBeatFromHeight(100) + if err := chanArb.Start(nil, beat); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } @@ -1770,7 +1797,8 @@ func TestChannelArbitratorAlreadyForceClosed(t *testing.T) { chanArbCtx, err := createTestChannelArbitrator(t, log) require.NoError(t, err, "unable to create ChannelArbitrator") chanArb := chanArbCtx.chanArb - if err := chanArb.Start(nil); err != nil { + beat := newBeatFromHeight(0) + if err := chanArb.Start(nil, beat); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } defer chanArb.Stop() @@ -1868,9 +1896,10 @@ func TestChannelArbitratorDanglingCommitForceClose(t *testing.T) { t.Fatalf("unable to create ChannelArbitrator: %v", err) } chanArb := chanArbCtx.chanArb - if err := chanArb.Start(nil); err != nil { - t.Fatalf("unable to start ChannelArbitrator: %v", err) - } + beat := newBeatFromHeight(0) + err = chanArb.Start(nil, beat) + require.NoError(t, err) + defer chanArb.Stop() // Now that our channel arb has started, we'll set up @@ -1914,7 +1943,8 @@ func TestChannelArbitratorDanglingCommitForceClose(t *testing.T) { // now mine a block (height 5), which is 5 blocks away // (our grace delta) from the expiry of that HTLC. case testCase.htlcExpired: - chanArbCtx.chanArb.blocks <- 5 + beat := newBeatFromHeight(5) + chanArbCtx.chanArb.BlockbeatChan <- beat // Otherwise, we'll just trigger a regular force close // request. @@ -2026,8 +2056,7 @@ func TestChannelArbitratorDanglingCommitForceClose(t *testing.T) { // so instead, we'll mine another block which'll cause // it to re-examine its state and realize there're no // more HTLCs. - chanArbCtx.chanArb.blocks <- 6 - chanArbCtx.AssertStateTransitions(StateFullyResolved) + chanArbCtx.receiveBlockbeat(6) }) } } @@ -2064,7 +2093,8 @@ func TestChannelArbitratorPendingExpiredHTLC(t *testing.T) { return false } - if err := chanArb.Start(nil); err != nil { + beat := newBeatFromHeight(0) + if err := chanArb.Start(nil, beat); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } t.Cleanup(func() { @@ -2098,13 +2128,15 @@ func TestChannelArbitratorPendingExpiredHTLC(t *testing.T) { // We will advance the uptime to 10 seconds which should be still within // the grace period and should not trigger going to chain. testClock.SetTime(startTime.Add(time.Second * 10)) - chanArbCtx.chanArb.blocks <- 5 + beat = newBeatFromHeight(5) + chanArbCtx.chanArb.BlockbeatChan <- beat chanArbCtx.AssertState(StateDefault) // We will advance the uptime to 16 seconds which should trigger going // to chain. testClock.SetTime(startTime.Add(time.Second * 16)) - chanArbCtx.chanArb.blocks <- 6 + beat = newBeatFromHeight(6) + chanArbCtx.chanArb.BlockbeatChan <- beat chanArbCtx.AssertStateTransitions( StateBroadcastCommit, StateCommitmentBroadcasted, @@ -2217,8 +2249,8 @@ func TestRemoteCloseInitiator(t *testing.T) { "ChannelArbitrator: %v", err) } chanArb := chanArbCtx.chanArb - - if err := chanArb.Start(nil); err != nil { + beat := newBeatFromHeight(0) + if err := chanArb.Start(nil, beat); err != nil { t.Fatalf("unable to start "+ "ChannelArbitrator: %v", err) } @@ -2472,7 +2504,7 @@ func TestSweepAnchors(t *testing.T) { // Set current block height. heightHint := uint32(1000) - chanArbCtx.chanArb.blocks <- int32(heightHint) + chanArbCtx.receiveBlockbeat(int(heightHint)) htlcIndexBase := uint64(99) deadlineDelta := uint32(10) @@ -2635,7 +2667,7 @@ func TestSweepLocalAnchor(t *testing.T) { // Set current block height. heightHint := uint32(1000) - chanArbCtx.chanArb.blocks <- int32(heightHint) + chanArbCtx.receiveBlockbeat(int(heightHint)) htlcIndex := uint64(99) deadlineDelta := uint32(10) @@ -2769,7 +2801,9 @@ func TestChannelArbitratorAnchors(t *testing.T) { }, } - if err := chanArb.Start(nil); err != nil { + heightHint := uint32(1000) + beat := newBeatFromHeight(int32(heightHint)) + if err := chanArb.Start(nil, beat); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } t.Cleanup(func() { @@ -2782,8 +2816,8 @@ func TestChannelArbitratorAnchors(t *testing.T) { chanArb.UpdateContractSignals(signals) // Set current block height. - heightHint := uint32(1000) - chanArbCtx.chanArb.blocks <- int32(heightHint) + beat = newBeatFromHeight(int32(heightHint)) + chanArbCtx.chanArb.BlockbeatChan <- beat htlcAmt := lnwire.MilliSatoshi(1_000_000) @@ -2950,10 +2984,14 @@ func TestChannelArbitratorAnchors(t *testing.T) { // to htlcWithPreimage's CLTV. require.Equal(t, 2, len(chanArbCtx.sweeper.deadlines)) require.EqualValues(t, + heightHint+deadlinePreimageDelta/2, + chanArbCtx.sweeper.deadlines[0], "want %d, got %d", heightHint+deadlinePreimageDelta/2, chanArbCtx.sweeper.deadlines[0], ) require.EqualValues(t, + heightHint+deadlinePreimageDelta/2, + chanArbCtx.sweeper.deadlines[1], "want %d, got %d", heightHint+deadlinePreimageDelta/2, chanArbCtx.sweeper.deadlines[1], ) @@ -3054,7 +3092,8 @@ func TestChannelArbitratorStartForceCloseFail(t *testing.T) { return test.broadcastErr } - err = chanArb.Start(nil) + beat := newBeatFromHeight(0) + err = chanArb.Start(nil, beat) if !test.expectedStartup { require.ErrorIs(t, err, test.broadcastErr) @@ -3133,3 +3172,11 @@ func (m *mockChannel) ForceCloseChan() (*wire.MsgTx, error) { return &wire.MsgTx{}, nil } + +func newBeatFromHeight(height int32) *chainio.Beat { + epoch := chainntnfs.BlockEpoch{ + Height: height, + } + + return chainio.NewBeat(epoch) +} diff --git a/contractcourt/commit_sweep_resolver.go b/contractcourt/commit_sweep_resolver.go index 4b47a342948..47ad4b81054 100644 --- a/contractcourt/commit_sweep_resolver.go +++ b/contractcourt/commit_sweep_resolver.go @@ -101,36 +101,6 @@ func (c *commitSweepResolver) ResolverKey() []byte { return key[:] } -// waitForHeight registers for block notifications and waits for the provided -// block height to be reached. -func waitForHeight(waitHeight uint32, notifier chainntnfs.ChainNotifier, - quit <-chan struct{}) error { - - // Register for block epochs. After registration, the current height - // will be sent on the channel immediately. - blockEpochs, err := notifier.RegisterBlockEpochNtfn(nil) - if err != nil { - return err - } - defer blockEpochs.Cancel() - - for { - select { - case newBlock, ok := <-blockEpochs.Epochs: - if !ok { - return errResolverShuttingDown - } - height := newBlock.Height - if height >= int32(waitHeight) { - return nil - } - - case <-quit: - return errResolverShuttingDown - } - } -} - // waitForSpend waits for the given outpoint to be spent, and returns the // details of the spending tx. func waitForSpend(op *wire.OutPoint, pkScript []byte, heightHint uint32, @@ -195,9 +165,11 @@ func (c *commitSweepResolver) getCommitTxConfHeight() (uint32, error) { // returned. // // NOTE: This function MUST be run as a goroutine. + +// TODO(yy): fix the funlen in the next PR. // //nolint:funlen -func (c *commitSweepResolver) Resolve(_ bool) (ContractResolver, error) { +func (c *commitSweepResolver) Resolve() (ContractResolver, error) { // If we're already resolved, then we can exit early. if c.resolved { return nil, nil @@ -225,39 +197,6 @@ func (c *commitSweepResolver) Resolve(_ bool) (ContractResolver, error) { c.currentReport.MaturityHeight = unlockHeight c.reportLock.Unlock() - // If there is a csv/cltv lock, we'll wait for that. - if c.commitResolution.MaturityDelay > 0 || c.hasCLTV() { - // Determine what height we should wait until for the locks to - // expire. - var waitHeight uint32 - switch { - // If we have both a csv and cltv lock, we'll need to look at - // both and see which expires later. - case c.commitResolution.MaturityDelay > 0 && c.hasCLTV(): - c.log.Debugf("waiting for CSV and CLTV lock to expire "+ - "at height %v", unlockHeight) - // If the CSV expires after the CLTV, or there is no - // CLTV, then we can broadcast a sweep a block before. - // Otherwise, we need to broadcast at our expected - // unlock height. - waitHeight = uint32(math.Max( - float64(unlockHeight-1), float64(c.leaseExpiry), - )) - - // If we only have a csv lock, wait for the height before the - // lock expires as the spend path should be unlocked by then. - case c.commitResolution.MaturityDelay > 0: - c.log.Debugf("waiting for CSV lock to expire at "+ - "height %v", unlockHeight) - waitHeight = unlockHeight - 1 - } - - err := waitForHeight(waitHeight, c.Notifier, c.quit) - if err != nil { - return nil, err - } - } - var ( isLocalCommitTx bool diff --git a/contractcourt/commit_sweep_resolver_test.go b/contractcourt/commit_sweep_resolver_test.go index f2b43b0f80a..15c92344cce 100644 --- a/contractcourt/commit_sweep_resolver_test.go +++ b/contractcourt/commit_sweep_resolver_test.go @@ -82,7 +82,7 @@ func (i *commitSweepResolverTestContext) resolve() { // Start resolver. i.resolverResultChan = make(chan resolveResult, 1) go func() { - nextResolver, err := i.resolver.Resolve(false) + nextResolver, err := i.resolver.Resolve() i.resolverResultChan <- resolveResult{ nextResolver: nextResolver, err: err, @@ -90,12 +90,6 @@ func (i *commitSweepResolverTestContext) resolve() { }() } -func (i *commitSweepResolverTestContext) notifyEpoch(height int32) { - i.notifier.EpochChan <- &chainntnfs.BlockEpoch{ - Height: height, - } -} - func (i *commitSweepResolverTestContext) waitForResult() { i.t.Helper() @@ -292,22 +286,10 @@ func testCommitSweepResolverDelay(t *testing.T, sweepErr error) { t.Fatal("report maturity height incorrect") } - // Notify initial block height. The csv lock is still in effect, so we - // don't expect any sweep to happen yet. - ctx.notifyEpoch(testInitialBlockHeight) - - select { - case <-ctx.sweeper.sweptInputs: - t.Fatal("no sweep expected") - case <-time.After(sweepProcessInterval): - } - - // A new block arrives. The commit tx confirmed at height -1 and the csv - // is 3, so a spend will be valid in the first block after height +1. - ctx.notifyEpoch(testInitialBlockHeight + 1) - - <-ctx.sweeper.sweptInputs - + // Notify initial block height. Although the csv lock is still in + // effect, we expect the input being sent to the sweeper before the csv + // lock expires. + // // Set the resolution report outcome based on whether our sweep // succeeded. outcome := channeldb.ResolverOutcomeClaimed diff --git a/contractcourt/contract_resolver.go b/contractcourt/contract_resolver.go index 691822610a4..cdf6a76c32e 100644 --- a/contractcourt/contract_resolver.go +++ b/contractcourt/contract_resolver.go @@ -43,7 +43,7 @@ type ContractResolver interface { // resolution, then another resolve is returned. // // NOTE: This function MUST be run as a goroutine. - Resolve(immediate bool) (ContractResolver, error) + Resolve() (ContractResolver, error) // SupplementState allows the user of a ContractResolver to supplement // it with state required for the proper resolution of a contract. diff --git a/contractcourt/htlc_incoming_contest_resolver.go b/contractcourt/htlc_incoming_contest_resolver.go index 6bda4e398b0..2addc91c119 100644 --- a/contractcourt/htlc_incoming_contest_resolver.go +++ b/contractcourt/htlc_incoming_contest_resolver.go @@ -90,9 +90,7 @@ func (h *htlcIncomingContestResolver) processFinalHtlcFail() error { // as we have no remaining actions left at our disposal. // // NOTE: Part of the ContractResolver interface. -func (h *htlcIncomingContestResolver) Resolve( - _ bool) (ContractResolver, error) { - +func (h *htlcIncomingContestResolver) Resolve() (ContractResolver, error) { // If we're already full resolved, then we don't have anything further // to do. if h.resolved { diff --git a/contractcourt/htlc_incoming_contest_resolver_test.go b/contractcourt/htlc_incoming_contest_resolver_test.go index 55d93a6fb37..649f0adf337 100644 --- a/contractcourt/htlc_incoming_contest_resolver_test.go +++ b/contractcourt/htlc_incoming_contest_resolver_test.go @@ -395,7 +395,7 @@ func (i *incomingResolverTestContext) resolve() { i.resolveErr = make(chan error, 1) go func() { var err error - i.nextResolver, err = i.resolver.Resolve(false) + i.nextResolver, err = i.resolver.Resolve() i.resolveErr <- err }() diff --git a/contractcourt/htlc_lease_resolver.go b/contractcourt/htlc_lease_resolver.go index 53fa8935534..c904f21d1b8 100644 --- a/contractcourt/htlc_lease_resolver.go +++ b/contractcourt/htlc_lease_resolver.go @@ -57,10 +57,10 @@ func (h *htlcLeaseResolver) makeSweepInput(op *wire.OutPoint, signDesc *input.SignDescriptor, csvDelay, broadcastHeight uint32, payHash [32]byte, resBlob fn.Option[tlv.Blob]) *input.BaseInput { - if h.hasCLTV() { - log.Infof("%T(%x): CSV and CLTV locks expired, offering "+ - "second-layer output to sweeper: %v", h, payHash, op) + log.Infof("%T(%x): offering second-layer output to sweeper: %v", h, + payHash, op) + if h.hasCLTV() { return input.NewCsvInputWithCltv( op, cltvWtype, signDesc, broadcastHeight, csvDelay, diff --git a/contractcourt/htlc_outgoing_contest_resolver.go b/contractcourt/htlc_outgoing_contest_resolver.go index 2466544c982..874d26ab9cb 100644 --- a/contractcourt/htlc_outgoing_contest_resolver.go +++ b/contractcourt/htlc_outgoing_contest_resolver.go @@ -49,9 +49,7 @@ func newOutgoingContestResolver(res lnwallet.OutgoingHtlcResolution, // When either of these two things happens, we'll create a new resolver which // is able to handle the final resolution of the contract. We're only the pivot // point. -func (h *htlcOutgoingContestResolver) Resolve( - _ bool) (ContractResolver, error) { - +func (h *htlcOutgoingContestResolver) Resolve() (ContractResolver, error) { // If we're already full resolved, then we don't have anything further // to do. if h.resolved { diff --git a/contractcourt/htlc_outgoing_contest_resolver_test.go b/contractcourt/htlc_outgoing_contest_resolver_test.go index f67c34ff4e1..e4a3aaee0d3 100644 --- a/contractcourt/htlc_outgoing_contest_resolver_test.go +++ b/contractcourt/htlc_outgoing_contest_resolver_test.go @@ -209,7 +209,7 @@ func (i *outgoingResolverTestContext) resolve() { // Start resolver. i.resolverResultChan = make(chan resolveResult, 1) go func() { - nextResolver, err := i.resolver.Resolve(false) + nextResolver, err := i.resolver.Resolve() i.resolverResultChan <- resolveResult{ nextResolver: nextResolver, err: err, diff --git a/contractcourt/htlc_success_resolver.go b/contractcourt/htlc_success_resolver.go index 4c9d2b200bb..39adae88fef 100644 --- a/contractcourt/htlc_success_resolver.go +++ b/contractcourt/htlc_success_resolver.go @@ -115,9 +115,7 @@ func (h *htlcSuccessResolver) ResolverKey() []byte { // TODO(roasbeef): create multi to batch // // NOTE: Part of the ContractResolver interface. -func (h *htlcSuccessResolver) Resolve( - immediate bool) (ContractResolver, error) { - +func (h *htlcSuccessResolver) Resolve() (ContractResolver, error) { // If we're already resolved, then we can exit early. if h.resolved { return nil, nil @@ -126,12 +124,12 @@ func (h *htlcSuccessResolver) Resolve( // If we don't have a success transaction, then this means that this is // an output on the remote party's commitment transaction. if h.htlcResolution.SignedSuccessTx == nil { - return h.resolveRemoteCommitOutput(immediate) + return h.resolveRemoteCommitOutput() } // Otherwise this an output on our own commitment, and we must start by // broadcasting the second-level success transaction. - secondLevelOutpoint, err := h.broadcastSuccessTx(immediate) + secondLevelOutpoint, err := h.broadcastSuccessTx() if err != nil { return nil, err } @@ -165,8 +163,8 @@ func (h *htlcSuccessResolver) Resolve( // broadcasting the second-level success transaction. It returns the ultimate // outpoint of the second-level tx, that we must wait to be spent for the // resolver to be fully resolved. -func (h *htlcSuccessResolver) broadcastSuccessTx( - immediate bool) (*wire.OutPoint, error) { +func (h *htlcSuccessResolver) broadcastSuccessTx() ( + *wire.OutPoint, error) { // If we have non-nil SignDetails, this means that have a 2nd level // HTLC transaction that is signed using sighash SINGLE|ANYONECANPAY @@ -175,7 +173,7 @@ func (h *htlcSuccessResolver) broadcastSuccessTx( // the checkpointed outputIncubating field to determine if we already // swept the HTLC output into the second level transaction. if h.htlcResolution.SignDetails != nil { - return h.broadcastReSignedSuccessTx(immediate) + return h.broadcastReSignedSuccessTx() } // Otherwise we'll publish the second-level transaction directly and @@ -225,10 +223,8 @@ func (h *htlcSuccessResolver) broadcastSuccessTx( // broadcastReSignedSuccessTx handles the case where we have non-nil // SignDetails, and offers the second level transaction to the Sweeper, that // will re-sign it and attach fees at will. -// -//nolint:funlen -func (h *htlcSuccessResolver) broadcastReSignedSuccessTx(immediate bool) ( - *wire.OutPoint, error) { +func (h *htlcSuccessResolver) broadcastReSignedSuccessTx() (*wire.OutPoint, + error) { // Keep track of the tx spending the HTLC output on the commitment, as // this will be the confirmed second-level tx we'll ultimately sweep. @@ -287,7 +283,6 @@ func (h *htlcSuccessResolver) broadcastReSignedSuccessTx(immediate bool) ( sweep.Params{ Budget: budget, DeadlineHeight: deadline, - Immediate: immediate, }, ) if err != nil { @@ -359,30 +354,6 @@ func (h *htlcSuccessResolver) broadcastReSignedSuccessTx(immediate bool) ( "height %v", h, h.htlc.RHash[:], waitHeight) } - // Deduct one block so this input is offered to the sweeper one block - // earlier since the sweeper will wait for one block to trigger the - // sweeping. - // - // TODO(yy): this is done so the outputs can be aggregated - // properly. Suppose CSV locks of five 2nd-level outputs all - // expire at height 840000, there is a race in block digestion - // between contractcourt and sweeper: - // - G1: block 840000 received in contractcourt, it now offers - // the outputs to the sweeper. - // - G2: block 840000 received in sweeper, it now starts to - // sweep the received outputs - there's no guarantee all - // fives have been received. - // To solve this, we either offer the outputs earlier, or - // implement `blockbeat`, and force contractcourt and sweeper - // to consume each block sequentially. - waitHeight-- - - // TODO(yy): let sweeper handles the wait? - err := waitForHeight(waitHeight, h.Notifier, h.quit) - if err != nil { - return nil, err - } - // We'll use this input index to determine the second-level output // index on the transaction, as the signatures requires the indexes to // be the same. We don't look for the second-level output script @@ -421,7 +392,7 @@ func (h *htlcSuccessResolver) broadcastReSignedSuccessTx(immediate bool) ( h.htlc.RHash[:], budget, waitHeight) // TODO(roasbeef): need to update above for leased types - _, err = h.Sweeper.SweepInput( + _, err := h.Sweeper.SweepInput( inp, sweep.Params{ Budget: budget, @@ -443,7 +414,7 @@ func (h *htlcSuccessResolver) broadcastReSignedSuccessTx(immediate bool) ( // resolveRemoteCommitOutput handles sweeping an HTLC output on the remote // commitment with the preimage. In this case we can sweep the output directly, // and don't have to broadcast a second-level transaction. -func (h *htlcSuccessResolver) resolveRemoteCommitOutput(immediate bool) ( +func (h *htlcSuccessResolver) resolveRemoteCommitOutput() ( ContractResolver, error) { isTaproot := txscript.IsPayToTaproot( @@ -495,7 +466,6 @@ func (h *htlcSuccessResolver) resolveRemoteCommitOutput(immediate bool) ( sweep.Params{ Budget: budget, DeadlineHeight: deadline, - Immediate: immediate, }, ) if err != nil { diff --git a/contractcourt/htlc_success_resolver_test.go b/contractcourt/htlc_success_resolver_test.go index b9182500bb4..b962a55a6ea 100644 --- a/contractcourt/htlc_success_resolver_test.go +++ b/contractcourt/htlc_success_resolver_test.go @@ -134,7 +134,7 @@ func (i *htlcResolverTestContext) resolve() { // Start resolver. i.resolverResultChan = make(chan resolveResult, 1) go func() { - nextResolver, err := i.resolver.Resolve(false) + nextResolver, err := i.resolver.Resolve() i.resolverResultChan <- resolveResult{ nextResolver: nextResolver, err: err, @@ -437,10 +437,6 @@ func TestHtlcSuccessSecondStageResolutionSweeper(t *testing.T) { } } - ctx.notifier.EpochChan <- &chainntnfs.BlockEpoch{ - Height: 13, - } - // We expect it to sweep the second-level // transaction we notfied about above. resolver := ctx.resolver.(*htlcSuccessResolver) diff --git a/contractcourt/htlc_timeout_resolver.go b/contractcourt/htlc_timeout_resolver.go index e7ab4216917..fbca316cfd5 100644 --- a/contractcourt/htlc_timeout_resolver.go +++ b/contractcourt/htlc_timeout_resolver.go @@ -418,9 +418,7 @@ func checkSizeAndIndex(witness wire.TxWitness, size, index int) bool { // see a direct sweep via the timeout clause. // // NOTE: Part of the ContractResolver interface. -func (h *htlcTimeoutResolver) Resolve( - immediate bool) (ContractResolver, error) { - +func (h *htlcTimeoutResolver) Resolve() (ContractResolver, error) { // If we're already resolved, then we can exit early. if h.resolved { return nil, nil @@ -429,7 +427,7 @@ func (h *htlcTimeoutResolver) Resolve( // Start by spending the HTLC output, either by broadcasting the // second-level timeout transaction, or directly if this is the remote // commitment. - commitSpend, err := h.spendHtlcOutput(immediate) + commitSpend, err := h.spendHtlcOutput() if err != nil { return nil, err } @@ -477,7 +475,7 @@ func (h *htlcTimeoutResolver) Resolve( // sweepSecondLevelTx sends a second level timeout transaction to the sweeper. // This transaction uses the SINLGE|ANYONECANPAY flag. -func (h *htlcTimeoutResolver) sweepSecondLevelTx(immediate bool) error { +func (h *htlcTimeoutResolver) sweepSecondLevelTx() error { log.Infof("%T(%x): offering second-layer timeout tx to sweeper: %v", h, h.htlc.RHash[:], spew.Sdump(h.htlcResolution.SignedTimeoutTx)) @@ -538,7 +536,6 @@ func (h *htlcTimeoutResolver) sweepSecondLevelTx(immediate bool) error { sweep.Params{ Budget: budget, DeadlineHeight: h.incomingHTLCExpiryHeight, - Immediate: immediate, }, ) if err != nil { @@ -572,7 +569,7 @@ func (h *htlcTimeoutResolver) sendSecondLevelTxLegacy() error { // sweeper. This is used when the remote party goes on chain, and we're able to // sweep an HTLC we offered after a timeout. Only the CLTV encumbered outputs // are resolved via this path. -func (h *htlcTimeoutResolver) sweepDirectHtlcOutput(immediate bool) error { +func (h *htlcTimeoutResolver) sweepDirectHtlcOutput() error { var htlcWitnessType input.StandardWitnessType if h.isTaproot() { htlcWitnessType = input.TaprootHtlcOfferedRemoteTimeout @@ -612,7 +609,6 @@ func (h *htlcTimeoutResolver) sweepDirectHtlcOutput(immediate bool) error { // This is an outgoing HTLC, so we want to make sure // that we sweep it before the incoming HTLC expires. DeadlineHeight: h.incomingHTLCExpiryHeight, - Immediate: immediate, }, ) if err != nil { @@ -627,8 +623,8 @@ func (h *htlcTimeoutResolver) sweepDirectHtlcOutput(immediate bool) error { // used to spend the output into the next stage. If this is the remote // commitment, the output will be swept directly without the timeout // transaction. -func (h *htlcTimeoutResolver) spendHtlcOutput( - immediate bool) (*chainntnfs.SpendDetail, error) { +func (h *htlcTimeoutResolver) spendHtlcOutput() ( + *chainntnfs.SpendDetail, error) { switch { // If we have non-nil SignDetails, this means that have a 2nd level @@ -636,7 +632,7 @@ func (h *htlcTimeoutResolver) spendHtlcOutput( // (the case for anchor type channels). In this case we can re-sign it // and attach fees at will. We let the sweeper handle this job. case h.htlcResolution.SignDetails != nil && !h.outputIncubating: - if err := h.sweepSecondLevelTx(immediate); err != nil { + if err := h.sweepSecondLevelTx(); err != nil { log.Errorf("Sending timeout tx to sweeper: %v", err) return nil, err @@ -645,7 +641,7 @@ func (h *htlcTimeoutResolver) spendHtlcOutput( // If this is a remote commitment there's no second level timeout txn, // and we can just send this directly to the sweeper. case h.htlcResolution.SignedTimeoutTx == nil && !h.outputIncubating: - if err := h.sweepDirectHtlcOutput(immediate); err != nil { + if err := h.sweepDirectHtlcOutput(); err != nil { log.Errorf("Sending direct spend to sweeper: %v", err) return nil, err @@ -789,30 +785,6 @@ func (h *htlcTimeoutResolver) handleCommitSpend( "height %v", h, h.htlc.RHash[:], waitHeight) } - // Deduct one block so this input is offered to the sweeper one - // block earlier since the sweeper will wait for one block to - // trigger the sweeping. - // - // TODO(yy): this is done so the outputs can be aggregated - // properly. Suppose CSV locks of five 2nd-level outputs all - // expire at height 840000, there is a race in block digestion - // between contractcourt and sweeper: - // - G1: block 840000 received in contractcourt, it now offers - // the outputs to the sweeper. - // - G2: block 840000 received in sweeper, it now starts to - // sweep the received outputs - there's no guarantee all - // fives have been received. - // To solve this, we either offer the outputs earlier, or - // implement `blockbeat`, and force contractcourt and sweeper - // to consume each block sequentially. - waitHeight-- - - // TODO(yy): let sweeper handles the wait? - err := waitForHeight(waitHeight, h.Notifier, h.quit) - if err != nil { - return nil, err - } - // We'll use this input index to determine the second-level // output index on the transaction, as the signatures requires // the indexes to be the same. We don't look for the @@ -853,7 +825,7 @@ func (h *htlcTimeoutResolver) handleCommitSpend( "sweeper with no deadline and budget=%v at height=%v", h, h.htlc.RHash[:], budget, waitHeight) - _, err = h.Sweeper.SweepInput( + _, err := h.Sweeper.SweepInput( inp, sweep.Params{ Budget: budget, diff --git a/contractcourt/htlc_timeout_resolver_test.go b/contractcourt/htlc_timeout_resolver_test.go index 47be71d3ec1..63d0cf7d5c0 100644 --- a/contractcourt/htlc_timeout_resolver_test.go +++ b/contractcourt/htlc_timeout_resolver_test.go @@ -390,7 +390,7 @@ func testHtlcTimeoutResolver(t *testing.T, testCase htlcTimeoutTestCase) { go func() { defer wg.Done() - _, err := resolver.Resolve(false) + _, err := resolver.Resolve() if err != nil { resolveErr <- err } @@ -1120,11 +1120,6 @@ func TestHtlcTimeoutSecondStageSweeper(t *testing.T) { t.Fatalf("resolution not sent") } - // Mimic CSV lock expiring. - ctx.notifier.EpochChan <- &chainntnfs.BlockEpoch{ - Height: 13, - } - // The timeout tx output should now be given to // the sweeper. resolver := ctx.resolver.(*htlcTimeoutResolver) diff --git a/contractcourt/utxonursery.go b/contractcourt/utxonursery.go index b7b4d33a8b1..a920699a7bb 100644 --- a/contractcourt/utxonursery.go +++ b/contractcourt/utxonursery.go @@ -793,7 +793,7 @@ func (u *UtxoNursery) graduateClass(classHeight uint32) error { return err } - utxnLog.Infof("Attempting to graduate height=%v: num_kids=%v, "+ + utxnLog.Debugf("Attempting to graduate height=%v: num_kids=%v, "+ "num_babies=%v", classHeight, len(kgtnOutputs), len(cribOutputs)) // Offer the outputs to the sweeper and set up notifications that will diff --git a/htlcswitch/switch.go b/htlcswitch/switch.go index cbc2a16dae2..cc345e461b6 100644 --- a/htlcswitch/switch.go +++ b/htlcswitch/switch.go @@ -1605,7 +1605,7 @@ out: } } - log.Infof("Received outside contract resolution, "+ + log.Debugf("Received outside contract resolution, "+ "mapping to: %v", spew.Sdump(pkt)) // We don't check the error, as the only failure we can diff --git a/lnwallet/wallet.go b/lnwallet/wallet.go index a9018437d5c..a236894e09a 100644 --- a/lnwallet/wallet.go +++ b/lnwallet/wallet.go @@ -733,7 +733,7 @@ func (l *LightningWallet) RegisterFundingIntent(expectedID [32]byte, } if _, ok := l.fundingIntents[expectedID]; ok { - return fmt.Errorf("%w: already has intent registered: %v", + return fmt.Errorf("%w: already has intent registered: %x", ErrDuplicatePendingChanID, expectedID[:]) } diff --git a/log.go b/log.go index c88208ef3b2..89047b2eedf 100644 --- a/log.go +++ b/log.go @@ -9,6 +9,7 @@ import ( sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/autopilot" "github.com/lightningnetwork/lnd/build" + "github.com/lightningnetwork/lnd/chainio" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/chainreg" "github.com/lightningnetwork/lnd/chanacceptor" @@ -192,6 +193,7 @@ func SetupLoggers(root *build.SubLoggerManager, interceptor signal.Interceptor) AddSubLogger( root, blindedpath.Subsystem, interceptor, blindedpath.UseLogger, ) + AddSubLogger(root, chainio.Subsystem, interceptor, chainio.UseLogger) } // AddSubLogger is a helper method to conveniently create and register the diff --git a/server.go b/server.go index c1317591248..d4f24eee22d 100644 --- a/server.go +++ b/server.go @@ -28,6 +28,7 @@ import ( "github.com/lightningnetwork/lnd/aliasmgr" "github.com/lightningnetwork/lnd/autopilot" "github.com/lightningnetwork/lnd/brontide" + "github.com/lightningnetwork/lnd/chainio" "github.com/lightningnetwork/lnd/chainreg" "github.com/lightningnetwork/lnd/chanacceptor" "github.com/lightningnetwork/lnd/chanbackup" @@ -349,6 +350,10 @@ type server struct { // txPublisher is a publisher with fee-bumping capability. txPublisher *sweep.TxPublisher + // blockbeatDispatcher is a block dispatcher that notifies subscribers + // of new blocks. + blockbeatDispatcher *chainio.BlockbeatDispatcher + quit chan struct{} wg sync.WaitGroup @@ -612,6 +617,9 @@ func newServer(cfg *Config, listenAddrs []net.Addr, readPool: readPool, chansToRestore: chansToRestore, + blockbeatDispatcher: chainio.NewBlockbeatDispatcher( + cc.ChainNotifier, + ), channelNotifier: channelnotifier.New( dbs.ChanStateDB.ChannelStateDB(), ), @@ -654,6 +662,17 @@ func newServer(cfg *Config, listenAddrs []net.Addr, quit: make(chan struct{}), } + // Start the low-level services once they are initialized. + // + // TODO(yy): break the server startup into four steps, + // 1. init the low-level services. + // 2. start the low-level services. + // 3. init the high-level services. + // 4. start the high-level services. + if err := s.startLowLevelServices(); err != nil { + return nil, err + } + currentHash, currentHeight, err := s.cc.ChainIO.GetBestBlock() if err != nil { return nil, err @@ -1787,6 +1806,9 @@ func newServer(cfg *Config, listenAddrs []net.Addr, } s.connMgr = cmgr + // Finally, register the subsystems in blockbeat. + s.registerBlockConsumers() + return s, nil } @@ -1819,6 +1841,25 @@ func (s *server) UpdateRoutingConfig(cfg *routing.MissionControlConfig) { routerCfg.MaxMcHistory = cfg.MaxMcHistory } +// registerBlockConsumers registers the subsystems that consume block events. +// By calling `RegisterQueue`, a list of subsystems are registered in the +// blockbeat for block notifications. When a new block arrives, the subsystems +// in the same queue are notified sequentially, and different queues are +// notified concurrently. +// +// NOTE: To put a subsystem in a different queue, create a slice and pass it to +// a new `RegisterQueue` call. +func (s *server) registerBlockConsumers() { + // In this queue, when a new block arrives, it will be received and + // processed in this order: chainArb -> sweeper -> txPublisher. + consumers := []chainio.Consumer{ + s.chainArb, + s.sweeper, + s.txPublisher, + } + s.blockbeatDispatcher.RegisterQueue(consumers) +} + // signAliasUpdate takes a ChannelUpdate and returns the signature. This is // used for option_scid_alias channels where the ChannelUpdate to be sent back // may differ from what is on disk. @@ -2041,12 +2082,41 @@ func (c cleaner) run() { } } +// startLowLevelServices starts the low-level services of the server. These +// services must be started successfully before running the main server. The +// services are, +// 1. the chain notifier. +// +// TODO(yy): identify and add more low-level services here. +func (s *server) startLowLevelServices() error { + var startErr error + + cleanup := cleaner{} + + cleanup = cleanup.add(s.cc.ChainNotifier.Stop) + if err := s.cc.ChainNotifier.Start(); err != nil { + startErr = err + } + + if startErr != nil { + cleanup.run() + } + + return startErr +} + // Start starts the main daemon server, all requested listeners, and any helper // goroutines. // NOTE: This function is safe for concurrent access. // //nolint:funlen func (s *server) Start() error { + // Get the current blockbeat. + beat, err := s.getStartingBeat() + if err != nil { + return err + } + var startErr error // If one sub system fails to start, the following code ensures that the @@ -2100,12 +2170,6 @@ func (s *server) Start() error { return } - cleanup = cleanup.add(s.cc.ChainNotifier.Stop) - if err := s.cc.ChainNotifier.Start(); err != nil { - startErr = err - return - } - cleanup = cleanup.add(s.cc.BestBlockTracker.Stop) if err := s.cc.BestBlockTracker.Start(); err != nil { startErr = err @@ -2141,13 +2205,13 @@ func (s *server) Start() error { } cleanup = cleanup.add(s.txPublisher.Stop) - if err := s.txPublisher.Start(); err != nil { + if err := s.txPublisher.Start(beat); err != nil { startErr = err return } cleanup = cleanup.add(s.sweeper.Stop) - if err := s.sweeper.Start(); err != nil { + if err := s.sweeper.Start(beat); err != nil { startErr = err return } @@ -2192,7 +2256,7 @@ func (s *server) Start() error { } cleanup = cleanup.add(s.chainArb.Stop) - if err := s.chainArb.Start(); err != nil { + if err := s.chainArb.Start(beat); err != nil { startErr = err return } @@ -2433,6 +2497,17 @@ func (s *server) Start() error { srvrLog.Infof("Auto peer bootstrapping is disabled") } + // Start the blockbeat after all other subsystems have been + // started so they are ready to receive new blocks. + cleanup = cleanup.add(func() error { + s.blockbeatDispatcher.Stop() + return nil + }) + if err := s.blockbeatDispatcher.Start(); err != nil { + startErr = err + return + } + // Set the active flag now that we've completed the full // startup. atomic.StoreInt32(&s.active, 1) @@ -2457,6 +2532,9 @@ func (s *server) Stop() error { // Shutdown connMgr first to prevent conns during shutdown. s.connMgr.Stop() + // Stop dispatching blocks to other systems immediately. + s.blockbeatDispatcher.Stop() + // Shutdown the wallet, funding manager, and the rpc server. if err := s.chanStatusMgr.Stop(); err != nil { srvrLog.Warnf("failed to stop chanStatusMgr: %v", err) @@ -5115,3 +5193,35 @@ func (s *server) fetchClosedChannelSCIDs() map[lnwire.ShortChannelID]struct{} { return closedSCIDs } + +// getStartingBeat returns the current beat. This is used during the startup to +// initialize blockbeat consumers. +func (s *server) getStartingBeat() (*chainio.Beat, error) { + // beat is the current blockbeat. + var beat *chainio.Beat + + // We should get a notification with the current best block immediately + // by passing a nil block. + blockEpochs, err := s.cc.ChainNotifier.RegisterBlockEpochNtfn(nil) + if err != nil { + return beat, fmt.Errorf("register block epoch ntfn: %w", err) + } + defer blockEpochs.Cancel() + + // We registered for the block epochs with a nil request. The notifier + // should send us the current best block immediately. So we need to + // wait for it here because we need to know the current best height. + select { + case bestBlock := <-blockEpochs.Epochs: + srvrLog.Infof("Received initial block %v at height %d", + bestBlock.Hash, bestBlock.Height) + + // Update the current blockbeat. + beat = chainio.NewBeat(*bestBlock) + + case <-s.quit: + srvrLog.Debug("LND shutting down") + } + + return beat, nil +} diff --git a/sweep/fee_bumper.go b/sweep/fee_bumper.go index 8731c6b7ad8..115ae15bdfd 100644 --- a/sweep/fee_bumper.go +++ b/sweep/fee_bumper.go @@ -12,6 +12,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcwallet/chain" + "github.com/lightningnetwork/lnd/chainio" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" @@ -344,6 +345,10 @@ type TxPublisher struct { started atomic.Bool stopped atomic.Bool + // Embed the blockbeat consumer struct to get access to the method + // `NotifyBlockProcessed` and the `BlockbeatChan`. + chainio.BeatConsumer + wg sync.WaitGroup // cfg specifies the configuration of the TxPublisher. @@ -371,14 +376,22 @@ type TxPublisher struct { // Compile-time constraint to ensure TxPublisher implements Bumper. var _ Bumper = (*TxPublisher)(nil) +// Compile-time check for the chainio.Consumer interface. +var _ chainio.Consumer = (*TxPublisher)(nil) + // NewTxPublisher creates a new TxPublisher. func NewTxPublisher(cfg TxPublisherConfig) *TxPublisher { - return &TxPublisher{ + tp := &TxPublisher{ cfg: &cfg, records: lnutils.SyncMap[uint64, *monitorRecord]{}, subscriberChans: lnutils.SyncMap[uint64, chan *BumpResult]{}, quit: make(chan struct{}), } + + // Mount the block consumer. + tp.BeatConsumer = chainio.NewBeatConsumer(tp.quit, tp.Name()) + + return tp } // isNeutrinoBackend checks if the wallet backend is neutrino. @@ -427,6 +440,11 @@ func (t *TxPublisher) storeInitialRecord(req *BumpRequest) ( return requestID, record } +// NOTE: part of the `chainio.Consumer` interface. +func (t *TxPublisher) Name() string { + return "TxPublisher" +} + // initializeTx initializes a fee function and creates an RBF-compliant tx. If // succeeded, the initial tx is stored in the records map. func (t *TxPublisher) initializeTx(requestID uint64, req *BumpRequest) error { @@ -777,20 +795,18 @@ type monitorRecord struct { // Start starts the publisher by subscribing to block epoch updates and kicking // off the monitor loop. -func (t *TxPublisher) Start() error { +func (t *TxPublisher) Start(beat chainio.Blockbeat) error { log.Info("TxPublisher starting...") if t.started.Swap(true) { return fmt.Errorf("TxPublisher started more than once") } - blockEvent, err := t.cfg.Notifier.RegisterBlockEpochNtfn(nil) - if err != nil { - return fmt.Errorf("register block epoch ntfn: %w", err) - } + // Set the current height. + t.currentHeight.Store(beat.Height()) t.wg.Add(1) - go t.monitor(blockEvent) + go t.monitor() log.Debugf("TxPublisher started") @@ -818,33 +834,25 @@ func (t *TxPublisher) Stop() error { // to be bumped. If so, it will attempt to bump the fee of the tx. // // NOTE: Must be run as a goroutine. -func (t *TxPublisher) monitor(blockEvent *chainntnfs.BlockEpochEvent) { - defer blockEvent.Cancel() +func (t *TxPublisher) monitor() { defer t.wg.Done() for { select { - case epoch, ok := <-blockEvent.Epochs: - if !ok { - // We should stop the publisher before stopping - // the chain service. Otherwise it indicates an - // error. - log.Error("Block epoch channel closed, exit " + - "monitor") - - return - } - - log.Debugf("TxPublisher received new block: %v", - epoch.Height) + case beat := <-t.BlockbeatChan: + height := beat.Height() + log.Debugf("TxPublisher received new block: %v", height) // Update the best known height for the publisher. - t.currentHeight.Store(epoch.Height) + t.currentHeight.Store(height) // Check all monitored txns to see if any of them needs // to be bumped. t.processRecords() + // Notify we've processed the block. + t.NotifyBlockProcessed(beat, nil) + case <-t.quit: log.Debug("Fee bumper stopped, exit monitor") return diff --git a/sweep/sweeper.go b/sweep/sweeper.go index 16fb81dedbd..500fcdc2dfe 100644 --- a/sweep/sweeper.go +++ b/sweep/sweeper.go @@ -10,6 +10,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/davecgh/go-spew/spew" + "github.com/lightningnetwork/lnd/chainio" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" @@ -241,8 +242,9 @@ func (p *SweeperInput) isMature(currentHeight uint32) (bool, uint32) { // currentHeight plus one. locktime = p.BlocksToMaturity() + p.HeightHint() if currentHeight+1 < locktime { - log.Debugf("Input %v has CSV expiry=%v, current height is %v", - p.OutPoint(), locktime, currentHeight) + log.Debugf("Input %v has CSV expiry=%v, current height is %v, "+ + "skipped sweeping", p.OutPoint(), locktime, + currentHeight) return false, locktime } @@ -308,6 +310,10 @@ type UtxoSweeper struct { started uint32 // To be used atomically. stopped uint32 // To be used atomically. + // Embed the blockbeat consumer struct to get access to the method + // `NotifyBlockProcessed` and the `BlockbeatChan`. + chainio.BeatConsumer + cfg *UtxoSweeperConfig newInputs chan *sweepInputMessage @@ -342,6 +348,9 @@ type UtxoSweeper struct { bumpRespChan chan *bumpResp } +// Compile-time check for the chainio.Consumer interface. +var _ chainio.Consumer = (*UtxoSweeper)(nil) + // UtxoSweeperConfig contains dependencies of UtxoSweeper. type UtxoSweeperConfig struct { // GenSweepScript generates a P2WKH script belonging to the wallet where @@ -415,7 +424,7 @@ type sweepInputMessage struct { // New returns a new Sweeper instance. func New(cfg *UtxoSweeperConfig) *UtxoSweeper { - return &UtxoSweeper{ + s := &UtxoSweeper{ cfg: cfg, newInputs: make(chan *sweepInputMessage), spendChan: make(chan *chainntnfs.SpendDetail), @@ -425,10 +434,15 @@ func New(cfg *UtxoSweeperConfig) *UtxoSweeper { inputs: make(InputsMap), bumpRespChan: make(chan *bumpResp, 100), } + + // Mount the block consumer. + s.BeatConsumer = chainio.NewBeatConsumer(s.quit, s.Name()) + + return s } // Start starts the process of constructing and publish sweep txes. -func (s *UtxoSweeper) Start() error { +func (s *UtxoSweeper) Start(beat chainio.Blockbeat) error { if !atomic.CompareAndSwapUint32(&s.started, 0, 1) { return nil } @@ -439,49 +453,12 @@ func (s *UtxoSweeper) Start() error { // not change from here on. s.relayFeeRate = s.cfg.FeeEstimator.RelayFeePerKW() - // We need to register for block epochs and retry sweeping every block. - // We should get a notification with the current best block immediately - // if we don't provide any epoch. We'll wait for that in the collector. - blockEpochs, err := s.cfg.Notifier.RegisterBlockEpochNtfn(nil) - if err != nil { - return fmt.Errorf("register block epoch ntfn: %w", err) - } + // Set the current height. + s.currentHeight = beat.Height() // Start sweeper main loop. s.wg.Add(1) - go func() { - defer blockEpochs.Cancel() - defer s.wg.Done() - - s.collector(blockEpochs.Epochs) - - // The collector exited and won't longer handle incoming - // requests. This can happen on shutdown, when the block - // notifier shuts down before the sweeper and its clients. In - // order to not deadlock the clients waiting for their requests - // being handled, we handle them here and immediately return an - // error. When the sweeper finally is shut down we can exit as - // the clients will be notified. - for { - select { - case inp := <-s.newInputs: - inp.resultChan <- Result{ - Err: ErrSweeperShuttingDown, - } - - case req := <-s.pendingSweepsReqs: - req.errChan <- ErrSweeperShuttingDown - - case req := <-s.updateReqs: - req.responseChan <- &updateResp{ - err: ErrSweeperShuttingDown, - } - - case <-s.quit: - return - } - } - }() + go s.collector() return nil } @@ -508,6 +485,11 @@ func (s *UtxoSweeper) Stop() error { return nil } +// NOTE: part of the `chainio.Consumer` interface. +func (s *UtxoSweeper) Name() string { + return "UtxoSweeper" +} + // SweepInput sweeps inputs back into the wallet. The inputs will be batched and // swept after the batch time window ends. A custom fee preference can be // provided to determine what fee rate should be used for the input. Note that @@ -639,17 +621,8 @@ func (s *UtxoSweeper) removeConflictSweepDescendants( // collector is the sweeper main loop. It processes new inputs, spend // notifications and counts down to publication of the sweep tx. -func (s *UtxoSweeper) collector(blockEpochs <-chan *chainntnfs.BlockEpoch) { - // We registered for the block epochs with a nil request. The notifier - // should send us the current best block immediately. So we need to wait - // for it here because we need to know the current best height. - select { - case bestBlock := <-blockEpochs: - s.currentHeight = bestBlock.Height - - case <-s.quit: - return - } +func (s *UtxoSweeper) collector() { + defer s.wg.Done() for { // Clean inputs, which will remove inputs that are swept, @@ -719,25 +692,16 @@ func (s *UtxoSweeper) collector(blockEpochs <-chan *chainntnfs.BlockEpoch) { // A new block comes in, update the bestHeight, perform a check // over all pending inputs and publish sweeping txns if needed. - case epoch, ok := <-blockEpochs: - if !ok { - // We should stop the sweeper before stopping - // the chain service. Otherwise it indicates an - // error. - log.Error("Block epoch channel closed") - - return - } - + case beat := <-s.BlockbeatChan: // Update the sweeper to the best height. - s.currentHeight = epoch.Height + s.currentHeight = beat.Height() // Update the inputs with the latest height. inputs := s.updateSweeperInputs() log.Debugf("Received new block: height=%v, attempt "+ "sweeping %d inputs:\n%s", - epoch.Height, len(inputs), + s.currentHeight, len(inputs), lnutils.NewLogClosure(func() string { inps := make( []input.Input, 0, len(inputs), @@ -752,6 +716,9 @@ func (s *UtxoSweeper) collector(blockEpochs <-chan *chainntnfs.BlockEpoch) { // Attempt to sweep any pending inputs. s.sweepPendingInputs(inputs) + // Notify we've processed the block. + s.NotifyBlockProcessed(beat, nil) + case <-s.quit: return } @@ -1231,8 +1198,8 @@ func (s *UtxoSweeper) calculateDefaultDeadline(pi *SweeperInput) int32 { if !matured { defaultDeadline = int32(locktime + s.cfg.NoDeadlineConfTarget) log.Debugf("Input %v is immature, using locktime=%v instead "+ - "of current height=%d", pi.OutPoint(), locktime, - s.currentHeight) + "of current height=%d as starting height", + pi.OutPoint(), locktime, s.currentHeight) } return defaultDeadline @@ -1244,7 +1211,8 @@ func (s *UtxoSweeper) handleNewInput(input *sweepInputMessage) error { outpoint := input.input.OutPoint() pi, pending := s.inputs[outpoint] if pending { - log.Debugf("Already has pending input %v received", outpoint) + log.Infof("Already has pending input %v received, old params: "+ + "%v, new params %v", outpoint, pi.params, input.params) s.handleExistingInput(input, pi) @@ -1526,6 +1494,8 @@ func (s *UtxoSweeper) updateSweeperInputs() InputsMap { // turn this inputs map into a SyncMap in case we wanna add concurrent // access to the map in the future. for op, input := range s.inputs { + log.Tracef("Checking input: %s, state=%v", input, input.state) + // If the input has reached a final state, that it's either // been swept, or failed, or excluded, we will remove it from // our sweeper. @@ -1555,7 +1525,7 @@ func (s *UtxoSweeper) updateSweeperInputs() InputsMap { // skip this input and wait for the locktime to be reached. mature, locktime := input.isMature(uint32(s.currentHeight)) if !mature { - log.Infof("Skipping input %v due to locktime=%v not "+ + log.Debugf("Skipping input %v due to locktime=%v not "+ "reached, current height is %v", op, locktime, s.currentHeight)