Skip to content
13 changes: 13 additions & 0 deletions breacharbiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -866,6 +866,19 @@ func (bo *breachedOutput) OutPoint() *wire.OutPoint {
return &bo.outpoint
}

// RequiredTxOut returns a non-nil TxOut if input commits to a certain
// transaction output. This is used in the SINGLE|ANYONECANPAY case to make
// sure any presigned input is still valid by including the output.
func (bo *breachedOutput) RequiredTxOut() *wire.TxOut {
Comment thread
Roasbeef marked this conversation as resolved.
Outdated
return nil
}

// RequiredLockTime returns whether this input commits to a tx locktime that
// must be used in the transaction including it.
func (bo *breachedOutput) RequiredLockTime() (uint32, bool) {
return 0, false
}

// WitnessType returns the type of witness that must be generated to spend the
// breached output.
func (bo *breachedOutput) WitnessType() input.WitnessType {
Expand Down
22 changes: 22 additions & 0 deletions input/input.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,16 @@ type Input interface {
// construct the corresponding transaction input.
OutPoint() *wire.OutPoint

// RequiredTxOut returns a non-nil TxOut if input commits to a certain
// transaction output. This is used in the SINGLE|ANYONECANPAY case to
// make sure any presigned input is still valid by including the
// output.
RequiredTxOut() *wire.TxOut

// RequiredLockTime returns whether this input commits to a tx locktime
// that must be used in the transaction including it.
RequiredLockTime() (uint32, bool)

// WitnessType returns an enum specifying the type of witness that must
// be generated in order to spend this output.
WitnessType() WitnessType
Expand Down Expand Up @@ -75,6 +85,18 @@ func (i *inputKit) OutPoint() *wire.OutPoint {
return &i.outpoint
}

// RequiredTxOut returns a nil for the base input type.
func (i *inputKit) RequiredTxOut() *wire.TxOut {
return nil
}

// RequiredLockTime returns whether this input commits to a tx locktime that
// must be used in the transaction including it. This will be false for the
// base input type since we can re-sign for any lock time.
func (i *inputKit) RequiredLockTime() (uint32, bool) {
return 0, false
}

// WitnessType returns the type of witness that must be generated to spend the
// breached output.
func (i *inputKit) WitnessType() WitnessType {
Expand Down
8 changes: 8 additions & 0 deletions input/size.go
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,14 @@ func (twe *TxWeightEstimator) AddNestedP2WSHInput(witnessSize int) *TxWeightEsti
return twe
}

// AddTxOutput adds a known TxOut to the weight estimator.
func (twe *TxWeightEstimator) AddTxOutput(txOut *wire.TxOut) *TxWeightEstimator {
twe.outputSize += txOut.SerializeSize()
Comment thread
Roasbeef marked this conversation as resolved.
Outdated
twe.outputCount++

return twe
}

// AddP2PKHOutput updates the weight estimate to account for an additional P2PKH
// output.
func (twe *TxWeightEstimator) AddP2PKHOutput() *TxWeightEstimator {
Expand Down
3 changes: 2 additions & 1 deletion rpcserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -1254,7 +1254,8 @@ func (r *rpcServer) SendCoins(ctx context.Context,
// single transaction. This will be generated in a concurrent
// safe manner, so no need to worry about locking.
sweepTxPkg, err := sweep.CraftSweepAllTx(
feePerKw, uint32(bestHeight), targetAddr, wallet,
feePerKw, lnwallet.DefaultDustLimit(),
uint32(bestHeight), targetAddr, wallet,
wallet.WalletController, wallet.WalletController,
r.server.cc.FeeEstimator, r.server.cc.Signer,
)
Expand Down
11 changes: 11 additions & 0 deletions sweep/backend_mock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ type mockBackend struct {
publishChan chan wire.MsgTx

walletUtxos []*lnwallet.Utxo
utxoCnt int
}

func newMockBackend(t *testing.T, notifier *MockNotifier) *mockBackend {
Expand Down Expand Up @@ -88,6 +89,16 @@ func (b *mockBackend) PublishTransaction(tx *wire.MsgTx, _ string) error {

func (b *mockBackend) ListUnspentWitness(minconfirms, maxconfirms int32) (
[]*lnwallet.Utxo, error) {
b.lock.Lock()
defer b.lock.Unlock()

// Each time we list output, we increment the utxo counter, to
// ensure we don't return the same outpoint every time.
b.utxoCnt++

for i := range b.walletUtxos {
b.walletUtxos[i].OutPoint.Hash[0] = byte(b.utxoCnt)
}

return b.walletUtxos, nil
}
Expand Down
190 changes: 184 additions & 6 deletions sweep/sweeper.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ type pendingInputs = map[wire.OutPoint]*pendingInput
// inputCluster is a helper struct to gather a set of pending inputs that should
// be swept with the specified fee rate.
type inputCluster struct {
lockTime *uint32
sweepFeeRate chainfee.SatPerKWeight
inputs pendingInputs
}
Expand Down Expand Up @@ -647,7 +648,7 @@ func (s *UtxoSweeper) collector(blockEpochs <-chan *chainntnfs.BlockEpoch) {
// this to ensure any inputs which have had their fee
// rate bumped are broadcast first in order enforce the
// RBF policy.
inputClusters := s.clusterBySweepFeeRate()
inputClusters := s.createInputClusters()
sort.Slice(inputClusters, func(i, j int) bool {
return inputClusters[i].sweepFeeRate >
inputClusters[j].sweepFeeRate
Expand Down Expand Up @@ -750,17 +751,100 @@ func (s *UtxoSweeper) bucketForFeeRate(
return 1 + int(feeRate-s.relayFeeRate)/s.cfg.FeeRateBucketSize
}

// createInputClusters creates a list of input clusters from the set of pending
// inputs known by the UtxoSweeper. It clusters inputs by
// 1) Required tx locktime
// 2) Similar fee rates
func (s *UtxoSweeper) createInputClusters() []inputCluster {
inputs := s.pendingInputs

// We start by getting the inputs clusters by locktime. Since the
// inputs commit to the locktime, they can only be clustered together
// if the locktime is equal.
lockTimeClusters, nonLockTimeInputs := s.clusterByLockTime(inputs)

// Cluster the the remaining inputs by sweep fee rate.
feeClusters := s.clusterBySweepFeeRate(nonLockTimeInputs)

// Since the inputs that we clustered by fee rate don't commit to a
// specific locktime, we can try to merge a locktime cluster with a fee
// cluster.
return zipClusters(lockTimeClusters, feeClusters)
}

// clusterByLockTime takes the given set of pending inputs and clusters those
// with equal locktime together. Each cluster contains a sweep fee rate, which
// is determined by calculating the average fee rate of all inputs within that
// cluster. In addition to the created clusters, inputs that did not specify a
// required lock time are returned.
func (s *UtxoSweeper) clusterByLockTime(inputs pendingInputs) ([]inputCluster,
pendingInputs) {

locktimes := make(map[uint32]pendingInputs)
inputFeeRates := make(map[wire.OutPoint]chainfee.SatPerKWeight)
rem := make(pendingInputs)

// Go through all inputs and check if they require a certain locktime.
for op, input := range inputs {
lt, ok := input.RequiredLockTime()
if !ok {
rem[op] = input
continue
}

// Check if we already have inputs with this locktime.
p, ok := locktimes[lt]
if !ok {
p = make(pendingInputs)
}

p[op] = input
locktimes[lt] = p

// We also get the preferred fee rate for this input.
feeRate, err := s.feeRateForPreference(input.params.Fee)
if err != nil {
log.Warnf("Skipping input %v: %v", op, err)
continue
}

input.lastFeeRate = feeRate
Comment thread
Roasbeef marked this conversation as resolved.
Outdated
inputFeeRates[op] = feeRate
}

// We'll then determine the sweep fee rate for each set of inputs by
// calculating the average fee rate of the inputs within each set.
inputClusters := make([]inputCluster, 0, len(locktimes))
for lt, inputs := range locktimes {
lt := lt

var sweepFeeRate chainfee.SatPerKWeight
for op := range inputs {
sweepFeeRate += inputFeeRates[op]
}

sweepFeeRate /= chainfee.SatPerKWeight(len(inputs))
Comment thread
Roasbeef marked this conversation as resolved.
Outdated
inputClusters = append(inputClusters, inputCluster{
lockTime: &lt,
sweepFeeRate: sweepFeeRate,
inputs: inputs,
})
}

return inputClusters, rem
}

// clusterBySweepFeeRate takes the set of pending inputs within the UtxoSweeper
// and clusters those together with similar fee rates. Each cluster contains a
// sweep fee rate, which is determined by calculating the average fee rate of
// all inputs within that cluster.
func (s *UtxoSweeper) clusterBySweepFeeRate() []inputCluster {
func (s *UtxoSweeper) clusterBySweepFeeRate(inputs pendingInputs) []inputCluster {
bucketInputs := make(map[int]*bucketList)
inputFeeRates := make(map[wire.OutPoint]chainfee.SatPerKWeight)

// First, we'll group together all inputs with similar fee rates. This
// is done by determining the fee rate bucket they should belong in.
for op, input := range s.pendingInputs {
for op, input := range inputs {
feeRate, err := s.feeRateForPreference(input.params.Fee)
if err != nil {
log.Warnf("Skipping input %v: %v", op, err)
Expand Down Expand Up @@ -824,6 +908,99 @@ func (s *UtxoSweeper) clusterBySweepFeeRate() []inputCluster {
return inputClusters
}

// zipClusters merges pairwise clusters from as and bs such that cluster a from
// as is merged with a cluster from bs that has at least the fee rate of a.
// This to ensure we don't delay confirmation by decreasing the fee rate (the
// lock time inputs are typically second level HTLC transactions, that are time
// sensitive).
func zipClusters(as, bs []inputCluster) []inputCluster {
// Sort the clusters by decreasing fee rates.
sort.Slice(as, func(i, j int) bool {
return as[i].sweepFeeRate >
as[j].sweepFeeRate
})
sort.Slice(bs, func(i, j int) bool {
return bs[i].sweepFeeRate >
bs[j].sweepFeeRate
})

var (
finalClusters []inputCluster
j int
)

// Go through each cluster in as, and merge with the next one from bs
// if it has at least the fee rate needed.
for i := range as {
a := as[i]

switch {

// If the fee rate for the next one from bs is at least a's, we
// merge.
case j < len(bs) && bs[j].sweepFeeRate >= a.sweepFeeRate:
merged := mergeClusters(a, bs[j])
finalClusters = append(finalClusters, merged...)

// Increment j for the next round.
j++

// We did not merge, meaning all the remining clusters from bs
// have lower fee rate. Instead we add a directly to the final
// clusters.
default:
finalClusters = append(finalClusters, a)
}
}

// Add any remaining clusters from bs.
for ; j < len(bs); j++ {
b := bs[j]
finalClusters = append(finalClusters, b)
}

return finalClusters
}

// mergeClusters attempts to merge cluster a and b if they are compatible. The
// new cluster will have the locktime set if a or b had a locktime set, and a
// sweep fee rate that is the maximum of a and b's. If the two clusters are not
// compatible, they will be returned unchanged.
func mergeClusters(a, b inputCluster) []inputCluster {
newCluster := inputCluster{}

switch {

// Incompatible locktimes, return the sets without merging them.
case a.lockTime != nil && b.lockTime != nil && *a.lockTime != *b.lockTime:
return []inputCluster{a, b}

case a.lockTime != nil:
newCluster.lockTime = a.lockTime

case b.lockTime != nil:
newCluster.lockTime = b.lockTime
}

if a.sweepFeeRate > b.sweepFeeRate {
newCluster.sweepFeeRate = a.sweepFeeRate
} else {
newCluster.sweepFeeRate = b.sweepFeeRate
}

newCluster.inputs = make(pendingInputs)

for op, in := range a.inputs {
newCluster.inputs[op] = in
}

for op, in := range b.inputs {
newCluster.inputs[op] = in
}

return []inputCluster{newCluster}
}

// scheduleSweep starts the sweep timer to create an opportunity for more inputs
// to be added.
func (s *UtxoSweeper) scheduleSweep(currentHeight int32) error {
Expand All @@ -836,7 +1013,7 @@ func (s *UtxoSweeper) scheduleSweep(currentHeight int32) error {

// We'll only start our timer once we have inputs we're able to sweep.
startTimer := false
for _, cluster := range s.clusterBySweepFeeRate() {
for _, cluster := range s.createInputClusters() {
// Examine pending inputs and try to construct lists of inputs.
// We don't need to obtain the coin selection lock, because we
// just need an indication as to whether we can sweep. More
Expand Down Expand Up @@ -988,7 +1165,7 @@ func (s *UtxoSweeper) sweep(inputs inputSet, feeRate chainfee.SatPerKWeight,
// Create sweep tx.
tx, err := createSweepTx(
inputs, s.currentOutputScript, uint32(currentHeight), feeRate,
s.cfg.Signer,
dustLimit(s.relayFeeRate), s.cfg.Signer,
)
if err != nil {
return fmt.Errorf("create sweep tx: %v", err)
Expand Down Expand Up @@ -1278,7 +1455,8 @@ func (s *UtxoSweeper) CreateSweepTx(inputs []input.Input, feePref FeePreference,
}

return createSweepTx(
inputs, pkScript, currentBlockHeight, feePerKw, s.cfg.Signer,
inputs, pkScript, currentBlockHeight, feePerKw,
dustLimit(s.relayFeeRate), s.cfg.Signer,
)
}

Expand Down
Loading