From f0dc5d23d2509f44028c0ec2bbb2e9f0c151454f Mon Sep 17 00:00:00 2001 From: Joost Jager Date: Mon, 15 Oct 2018 14:14:37 +0200 Subject: [PATCH] htlcswitch: settle invoices behind switch --- htlcswitch/invoice_settler.go | 243 ++++++++++++++++++++ htlcswitch/link.go | 414 +++++++++------------------------- htlcswitch/link_test.go | 5 +- htlcswitch/mock.go | 11 +- htlcswitch/switch.go | 81 ++++--- htlcswitch/switch_test.go | 38 ++-- 6 files changed, 424 insertions(+), 368 deletions(-) create mode 100644 htlcswitch/invoice_settler.go diff --git a/htlcswitch/invoice_settler.go b/htlcswitch/invoice_settler.go new file mode 100644 index 00000000000..7c5c465ece3 --- /dev/null +++ b/htlcswitch/invoice_settler.go @@ -0,0 +1,243 @@ +package htlcswitch + +import ( + "github.com/lightningnetwork/lnd/contractcourt" +) + +// InvoiceSettler handles settling and failing of invoices. +type InvoiceSettler struct { + // Registry is a sub-system which responsible for managing the invoices + // in thread-safe manner. + Registry InvoiceDatabase + + // ResolutionMsgs is the channel that the switch is listening to for + // invoice resolutions. InvoiceSettler isn't calling switch directly + // because this would create a circular dependency. + ResolutionMsgs chan resolutionMsg +} + +// NewInvoiceSettler returns a new invoice settler instance. +func NewInvoiceSettler(Registry InvoiceDatabase) *InvoiceSettler { + return &InvoiceSettler{ + Registry: Registry, + ResolutionMsgs: make(chan resolutionMsg), + } +} + +// Settle settles the (hold) invoice corresponding to the given preimage. +// TO BE IMPLEMENTED. +func (i *InvoiceSettler) Settle(preimage [32]byte) error { + + return nil +} + +// Fail fails the (hold) invoice corresponding to the given hash. +// TO BE IMPLEMENTED. +func (i *InvoiceSettler) Fail(hash []byte) error { + return nil +} + +// handleIncoming is called from switch when a htlc comes in for which we are +// the exit hop. +func (i *InvoiceSettler) handleIncoming(pkt *htlcPacket) error { + // We're the designated payment destination. Therefore + // we attempt to see if we have an invoice locally + // which'll allow us to settle this htlc. + invoiceHash := pkt.circuit.PaymentHash + invoice, _, err := i.Registry.LookupInvoice( + invoiceHash, + ) + + // TODO: Return errors below synchronously or use async flow for + // all cases? + + /*if err != nil { + log.Errorf("unable to query invoice registry: "+ + " %v", err) + failure := lnwire.FailUnknownPaymentHash{} + + i.ResolutionMsgs <- contractcourt.ResolutionMsg{ + SourceChan: exitHop, + Failure: failure, + } + l.sendHTLCError( + pd.HtlcIndex, failure, obfuscator, pd.SourceRef, + ) + + needUpdate = true + return + }*/ + + // If the invoice is already settled, we choose to + // accept the payment to simplify failure recovery. + // + // NOTE: Though our recovery and forwarding logic is + // predominately batched, settling invoices happens + // iteratively. We may reject one of two payments + // for the same rhash at first, but then restart and + // reject both after seeing that the invoice has been + // settled. Without any record of which one settles + // first, it is ambiguous as to which one actually + // settled the invoice. Thus, by accepting all + // payments, we eliminate the race condition that can + // lead to this inconsistency. + // + // TODO(conner): track ownership of settlements to + // properly recover from failures? or add batch invoice + // settlement + + /*if invoice.Terms.Settled { + log.Warnf("Accepting duplicate payment for "+ + "hash=%x", invoiceHash) + }*/ + + // If we're not currently in debug mode, and the + // extended htlc doesn't meet the value requested, then + // we'll fail the htlc. Otherwise, we settle this htlc + // within our local state update log, then send the + // update entry to the remote party. + // + // NOTE: We make an exception when the value requested + // by the invoice is zero. This means the invoice + // allows the payee to specify the amount of satoshis + // they wish to send. So since we expect the htlc to + // have a different amount, we should not fail. + /*if !l.cfg.DebugHTLC && invoice.Terms.Value > 0 && + pd.Amount < invoice.Terms.Value { + + log.Errorf("rejecting htlc due to incorrect "+ + "amount: expected %v, received %v", + invoice.Terms.Value, pd.Amount) + + failure := lnwire.FailIncorrectPaymentAmount{} + l.sendHTLCError( + pd.HtlcIndex, failure, obfuscator, pd.SourceRef, + ) + + needUpdate = true + return + }*/ + + // As we're the exit hop, we'll double check the + // hop-payload included in the HTLC to ensure that it + // was crafted correctly by the sender and matches the + // HTLC we were extended. + // + // NOTE: We make an exception when the value requested + // by the invoice is zero. This means the invoice + // allows the payee to specify the amount of satoshis + // they wish to send. So since we expect the htlc to + // have a different amount, we should not fail. + /*if !l.cfg.DebugHTLC && invoice.Terms.Value > 0 && + fwdInfo.AmountToForward < invoice.Terms.Value { + + log.Errorf("Onion payload of incoming htlc(%x) "+ + "has incorrect value: expected %v, "+ + "got %v", pd.RHash, invoice.Terms.Value, + fwdInfo.AmountToForward) + + failure := lnwire.FailIncorrectPaymentAmount{} + l.sendHTLCError( + pd.HtlcIndex, failure, obfuscator, pd.SourceRef, + ) + + needUpdate = true + return + }*/ + + // We'll also ensure that our time-lock value has been + // computed correctly. + + /*expectedHeight := heightNow + minCltvDelta + switch { + + case !l.cfg.DebugHTLC && fwdInfo.OutgoingCTLV < expectedHeight: + log.Errorf("Onion payload of incoming "+ + "htlc(%x) has incorrect time-lock: "+ + "expected %v, got %v", + pd.RHash[:], expectedHeight, + fwdInfo.OutgoingCTLV) + + failure := lnwire.NewFinalIncorrectCltvExpiry( + fwdInfo.OutgoingCTLV, + ) + l.sendHTLCError( + pd.HtlcIndex, failure, obfuscator, pd.SourceRef, + ) + + needUpdate = true + return + + case !l.cfg.DebugHTLC && pd.Timeout != fwdInfo.OutgoingCTLV: + log.Errorf("HTLC(%x) has incorrect "+ + "time-lock: expected %v, got %v", + pd.RHash[:], pd.Timeout, + fwdInfo.OutgoingCTLV) + + failure := lnwire.NewFinalIncorrectCltvExpiry( + fwdInfo.OutgoingCTLV, + ) + l.sendHTLCError( + pd.HtlcIndex, failure, obfuscator, pd.SourceRef, + ) + + needUpdate = true + return + } + */ + preimage := invoice.Terms.PaymentPreimage + + // TODO: Mark the invoice as accepted here + + // Execute sending resolution message in a go routine to prevent + // deadlock. Eventually InvoiceSettler may need its own main loop to + // receive events from the switch and rpcserver. + // + // Resolution is only possible when the preimage is known. Otherwise do + // nothing yet and wait for InvoiceSettler.Settle to be called with the + // preimage. + go func() { + done := make(chan struct{}) + + // TODO: This does not work, because the switch cannot look up + // the incoming channel. Outgoing HtlcIndex hasn't been + // committed to the circuit map. + i.ResolutionMsgs <- resolutionMsg{ + ResolutionMsg: contractcourt.ResolutionMsg{ + SourceChan: exitHop, + HtlcIndex: invoice.AddIndex, + PreImage: &preimage, + }, + doneChan: done, + } + + <-done + + // Notify the invoiceRegistry of the invoices we just + // settled (with the amount accepted at settle time) + // with this latest commitment update. + err = i.Registry.SettleInvoice( + invoiceHash, pkt.incomingAmount, + ) + if err != nil { + log.Errorf("unable to settle invoice: %v", err) + } + + log.Infof("settling %x as exit hop", invoiceHash) + + // If the link is in hodl.BogusSettle mode, replace the + // preimage with a fake one before sending it to the + // peer. + // + // TODO: This isn't the place anymore? + + /*if l.cfg.DebugHTLC && + l.cfg.HodlMask.Active(hodl.BogusSettle) { + l.warnf(hodl.BogusSettle.Warning()) + preimage = [32]byte{} + copy(preimage[:], bytes.Repeat([]byte{2}, 32)) + }*/ + }() + + return nil +} diff --git a/htlcswitch/link.go b/htlcswitch/link.go index cf6c1a34491..80874ecc32a 100644 --- a/htlcswitch/link.go +++ b/htlcswitch/link.go @@ -9,7 +9,6 @@ import ( "sync/atomic" "time" - "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/davecgh/go-spew/spew" "github.com/go-errors/errors" "github.com/lightningnetwork/lnd/channeldb" @@ -2253,340 +2252,131 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg, continue } - heightNow := l.cfg.Switch.BestHeight() - fwdInfo := chanIterator.ForwardingInstructions() - switch fwdInfo.NextHop { - case exitHop: - // If hodl.ExitSettle is requested, we will not validate - // the final hop's ADD, nor will we settle the - // corresponding invoice or respond with the preimage. - if l.cfg.DebugHTLC && - l.cfg.HodlMask.Active(hodl.ExitSettle) { - l.warnf(hodl.ExitSettle.Warning()) - continue - } - - // First, we'll check the expiry of the HTLC itself - // against, the current block height. If the timeout is - // too soon, then we'll reject the HTLC. - if pd.Timeout-expiryGraceDelta <= heightNow { - log.Errorf("htlc(%x) has an expiry that's too "+ - "soon: expiry=%v, best_height=%v", - pd.RHash[:], pd.Timeout, heightNow) - - failure := lnwire.FailFinalExpiryTooSoon{} - l.sendHTLCError( - pd.HtlcIndex, &failure, obfuscator, pd.SourceRef, - ) - needUpdate = true - continue - } - // We're the designated payment destination. Therefore - // we attempt to see if we have an invoice locally - // which'll allow us to settle this htlc. - invoiceHash := chainhash.Hash(pd.RHash) - invoice, minCltvDelta, err := l.cfg.Registry.LookupInvoice( - invoiceHash, - ) - if err != nil { - log.Errorf("unable to query invoice registry: "+ - " %v", err) - failure := lnwire.FailUnknownPaymentHash{} - l.sendHTLCError( - pd.HtlcIndex, failure, obfuscator, pd.SourceRef, - ) + // If hodl.AddIncoming is requested, we will not + // validate the forwarded ADD, nor will we send the + // packet to the htlc switch. + if l.cfg.DebugHTLC && + l.cfg.HodlMask.Active(hodl.AddIncoming) { + l.warnf(hodl.AddIncoming.Warning()) + continue + } - needUpdate = true - continue + switch fwdPkg.State { + case channeldb.FwdStateProcessed: + // This add was not forwarded on the previous + // processing phase, run it through our + // validation pipeline to reproduce an error. + // This may trigger a different error due to + // expiring timelocks, but we expect that an + // error will be reproduced. + if !fwdPkg.FwdFilter.Contains(idx) { + break } - // If the invoice is already settled, we choose to - // accept the payment to simplify failure recovery. - // - // NOTE: Though our recovery and forwarding logic is - // predominately batched, settling invoices happens - // iteratively. We may reject one of two payments - // for the same rhash at first, but then restart and - // reject both after seeing that the invoice has been - // settled. Without any record of which one settles - // first, it is ambiguous as to which one actually - // settled the invoice. Thus, by accepting all - // payments, we eliminate the race condition that can - // lead to this inconsistency. - // - // TODO(conner): track ownership of settlements to - // properly recover from failures? or add batch invoice - // settlement - if invoice.Terms.Settled { - log.Warnf("Accepting duplicate payment for "+ - "hash=%x", pd.RHash[:]) + // Otherwise, it was already processed, we can + // can collect it and continue. + addMsg := &lnwire.UpdateAddHTLC{ + Expiry: fwdInfo.OutgoingCTLV, + Amount: fwdInfo.AmountToForward, + PaymentHash: pd.RHash, } - // If we're not currently in debug mode, and the - // extended htlc doesn't meet the value requested, then - // we'll fail the htlc. Otherwise, we settle this htlc - // within our local state update log, then send the - // update entry to the remote party. - // - // NOTE: We make an exception when the value requested - // by the invoice is zero. This means the invoice - // allows the payee to specify the amount of satoshis - // they wish to send. So since we expect the htlc to - // have a different amount, we should not fail. - if !l.cfg.DebugHTLC && invoice.Terms.Value > 0 && - pd.Amount < invoice.Terms.Value { - - log.Errorf("rejecting htlc due to incorrect "+ - "amount: expected %v, received %v", - invoice.Terms.Value, pd.Amount) - - failure := lnwire.FailIncorrectPaymentAmount{} - l.sendHTLCError( - pd.HtlcIndex, failure, obfuscator, pd.SourceRef, - ) + // Finally, we'll encode the onion packet for + // the _next_ hop using the hop iterator + // decoded for the current hop. + buf := bytes.NewBuffer(addMsg.OnionBlob[0:0]) - needUpdate = true - continue + // We know this cannot fail, as this ADD + // was marked forwarded in a previous + // round of processing. + chanIterator.EncodeNextHop(buf) + + updatePacket := &htlcPacket{ + incomingChanID: l.ShortChanID(), + incomingHTLCID: pd.HtlcIndex, + outgoingChanID: fwdInfo.NextHop, + sourceRef: pd.SourceRef, + incomingAmount: pd.Amount, + amount: addMsg.Amount, + htlc: addMsg, + obfuscator: obfuscator, + incomingTimeout: pd.Timeout, + outgoingTimeout: fwdInfo.OutgoingCTLV, } + switchPackets = append( + switchPackets, updatePacket, + ) - // As we're the exit hop, we'll double check the - // hop-payload included in the HTLC to ensure that it - // was crafted correctly by the sender and matches the - // HTLC we were extended. - // - // NOTE: We make an exception when the value requested - // by the invoice is zero. This means the invoice - // allows the payee to specify the amount of satoshis - // they wish to send. So since we expect the htlc to - // have a different amount, we should not fail. - if !l.cfg.DebugHTLC && invoice.Terms.Value > 0 && - fwdInfo.AmountToForward < invoice.Terms.Value { - - log.Errorf("Onion payload of incoming htlc(%x) "+ - "has incorrect value: expected %v, "+ - "got %v", pd.RHash, invoice.Terms.Value, - fwdInfo.AmountToForward) - - failure := lnwire.FailIncorrectPaymentAmount{} - l.sendHTLCError( - pd.HtlcIndex, failure, obfuscator, pd.SourceRef, - ) - - needUpdate = true - continue - } - - // We'll also ensure that our time-lock value has been - // computed correctly. - expectedHeight := heightNow + minCltvDelta - switch { - - case !l.cfg.DebugHTLC && fwdInfo.OutgoingCTLV < expectedHeight: - log.Errorf("Onion payload of incoming "+ - "htlc(%x) has incorrect time-lock: "+ - "expected %v, got %v", - pd.RHash[:], expectedHeight, - fwdInfo.OutgoingCTLV) - - failure := lnwire.NewFinalIncorrectCltvExpiry( - fwdInfo.OutgoingCTLV, - ) - l.sendHTLCError( - pd.HtlcIndex, failure, obfuscator, pd.SourceRef, - ) - - needUpdate = true - continue + continue + } - case !l.cfg.DebugHTLC && pd.Timeout != fwdInfo.OutgoingCTLV: - log.Errorf("HTLC(%x) has incorrect "+ - "time-lock: expected %v, got %v", - pd.RHash[:], pd.Timeout, - fwdInfo.OutgoingCTLV) + // TODO(roasbeef): ensure don't accept outrageous + // timeout for htlc - failure := lnwire.NewFinalIncorrectCltvExpiry( - fwdInfo.OutgoingCTLV, - ) - l.sendHTLCError( - pd.HtlcIndex, failure, obfuscator, pd.SourceRef, - ) + // With all our forwarding constraints met, we'll + // create the outgoing HTLC using the parameters as + // specified in the forwarding info. + addMsg := &lnwire.UpdateAddHTLC{ + Expiry: fwdInfo.OutgoingCTLV, + Amount: fwdInfo.AmountToForward, + PaymentHash: pd.RHash, + } - needUpdate = true - continue - } + // Finally, we'll encode the onion packet for the + // _next_ hop using the hop iterator decoded for the + // current hop. + buf := bytes.NewBuffer(addMsg.OnionBlob[0:0]) + err := chanIterator.EncodeNextHop(buf) + if err != nil { + log.Errorf("unable to encode the "+ + "remaining route %v", err) - preimage := invoice.Terms.PaymentPreimage - err = l.channel.SettleHTLC( - preimage, pd.HtlcIndex, pd.SourceRef, nil, nil, + var failure lnwire.FailureMessage + update, err := l.cfg.FetchLastChannelUpdate( + l.ShortChanID(), ) if err != nil { - l.fail(LinkFailureError{code: ErrInternalError}, - "unable to settle htlc: %v", err) - return false + failure = &lnwire.FailTemporaryNodeFailure{} + } else { + failure = lnwire.NewTemporaryChannelFailure( + update, + ) } - // Notify the invoiceRegistry of the invoices we just - // settled (with the amount accepted at settle time) - // with this latest commitment update. - err = l.cfg.Registry.SettleInvoice( - invoiceHash, pd.Amount, + l.sendHTLCError( + pd.HtlcIndex, failure, obfuscator, pd.SourceRef, ) - if err != nil { - l.fail(LinkFailureError{code: ErrInternalError}, - "unable to settle invoice: %v", err) - return false - } - - l.infof("settling %x as exit hop", pd.RHash) - - // If the link is in hodl.BogusSettle mode, replace the - // preimage with a fake one before sending it to the - // peer. - if l.cfg.DebugHTLC && - l.cfg.HodlMask.Active(hodl.BogusSettle) { - l.warnf(hodl.BogusSettle.Warning()) - preimage = [32]byte{} - copy(preimage[:], bytes.Repeat([]byte{2}, 32)) - } - - // HTLC was successfully settled locally send - // notification about it remote peer. - l.cfg.Peer.SendMessage(false, &lnwire.UpdateFulfillHTLC{ - ChanID: l.ChanID(), - ID: pd.HtlcIndex, - PaymentPreimage: preimage, - }) needUpdate = true + continue + } - // There are additional channels left within this route. So - // we'll simply do some forwarding package book-keeping. - default: - // If hodl.AddIncoming is requested, we will not - // validate the forwarded ADD, nor will we send the - // packet to the htlc switch. - if l.cfg.DebugHTLC && - l.cfg.HodlMask.Active(hodl.AddIncoming) { - l.warnf(hodl.AddIncoming.Warning()) - continue - } - - switch fwdPkg.State { - case channeldb.FwdStateProcessed: - // This add was not forwarded on the previous - // processing phase, run it through our - // validation pipeline to reproduce an error. - // This may trigger a different error due to - // expiring timelocks, but we expect that an - // error will be reproduced. - if !fwdPkg.FwdFilter.Contains(idx) { - break - } - - // Otherwise, it was already processed, we can - // can collect it and continue. - addMsg := &lnwire.UpdateAddHTLC{ - Expiry: fwdInfo.OutgoingCTLV, - Amount: fwdInfo.AmountToForward, - PaymentHash: pd.RHash, - } - - // Finally, we'll encode the onion packet for - // the _next_ hop using the hop iterator - // decoded for the current hop. - buf := bytes.NewBuffer(addMsg.OnionBlob[0:0]) - - // We know this cannot fail, as this ADD - // was marked forwarded in a previous - // round of processing. - chanIterator.EncodeNextHop(buf) - - updatePacket := &htlcPacket{ - incomingChanID: l.ShortChanID(), - incomingHTLCID: pd.HtlcIndex, - outgoingChanID: fwdInfo.NextHop, - sourceRef: pd.SourceRef, - incomingAmount: pd.Amount, - amount: addMsg.Amount, - htlc: addMsg, - obfuscator: obfuscator, - incomingTimeout: pd.Timeout, - outgoingTimeout: fwdInfo.OutgoingCTLV, - } - switchPackets = append( - switchPackets, updatePacket, - ) - - continue - } - - // TODO(roasbeef): ensure don't accept outrageous - // timeout for htlc - - // With all our forwarding constraints met, we'll - // create the outgoing HTLC using the parameters as - // specified in the forwarding info. - addMsg := &lnwire.UpdateAddHTLC{ - Expiry: fwdInfo.OutgoingCTLV, - Amount: fwdInfo.AmountToForward, - PaymentHash: pd.RHash, - } - - // Finally, we'll encode the onion packet for the - // _next_ hop using the hop iterator decoded for the - // current hop. - buf := bytes.NewBuffer(addMsg.OnionBlob[0:0]) - err := chanIterator.EncodeNextHop(buf) - if err != nil { - log.Errorf("unable to encode the "+ - "remaining route %v", err) - - var failure lnwire.FailureMessage - update, err := l.cfg.FetchLastChannelUpdate( - l.ShortChanID(), - ) - if err != nil { - failure = &lnwire.FailTemporaryNodeFailure{} - } else { - failure = lnwire.NewTemporaryChannelFailure( - update, - ) - } - - l.sendHTLCError( - pd.HtlcIndex, failure, obfuscator, pd.SourceRef, - ) - needUpdate = true - continue + // Now that this add has been reprocessed, only append + // it to our list of packets to forward to the switch + // this is the first time processing the add. If the + // fwd pkg has already been processed, then we entered + // the above section to recreate a previous error. If + // the packet had previously been forwarded, it would + // have been added to switchPackets at the top of this + // section. + if fwdPkg.State == channeldb.FwdStateLockedIn { + updatePacket := &htlcPacket{ + incomingChanID: l.ShortChanID(), + incomingHTLCID: pd.HtlcIndex, + outgoingChanID: fwdInfo.NextHop, + sourceRef: pd.SourceRef, + incomingAmount: pd.Amount, + amount: addMsg.Amount, + htlc: addMsg, + obfuscator: obfuscator, + incomingTimeout: pd.Timeout, + outgoingTimeout: fwdInfo.OutgoingCTLV, } - // Now that this add has been reprocessed, only append - // it to our list of packets to forward to the switch - // this is the first time processing the add. If the - // fwd pkg has already been processed, then we entered - // the above section to recreate a previous error. If - // the packet had previously been forwarded, it would - // have been added to switchPackets at the top of this - // section. - if fwdPkg.State == channeldb.FwdStateLockedIn { - updatePacket := &htlcPacket{ - incomingChanID: l.ShortChanID(), - incomingHTLCID: pd.HtlcIndex, - outgoingChanID: fwdInfo.NextHop, - sourceRef: pd.SourceRef, - incomingAmount: pd.Amount, - amount: addMsg.Amount, - htlc: addMsg, - obfuscator: obfuscator, - incomingTimeout: pd.Timeout, - outgoingTimeout: fwdInfo.OutgoingCTLV, - } - - fwdPkg.FwdFilter.Set(idx) - switchPackets = append(switchPackets, - updatePacket) - } + fwdPkg.FwdFilter.Set(idx) + switchPackets = append(switchPackets, + updatePacket) } } diff --git a/htlcswitch/link_test.go b/htlcswitch/link_test.go index 2fe1dabf4ff..f2321b7f07a 100644 --- a/htlcswitch/link_test.go +++ b/htlcswitch/link_test.go @@ -1535,7 +1535,7 @@ func newSingleLinkTestHarness(chanAmt, chanReserve btcutil.Amount) ( } aliceDb := aliceChannel.State().Db - aliceSwitch, err := initSwitchWithDB(testStartingHeight, aliceDb) + aliceSwitch, err := initSwitchWithDB(testStartingHeight, aliceDb, invoiceRegistry) if err != nil { return nil, nil, nil, nil, nil, nil, err } @@ -4044,7 +4044,8 @@ func restartLink(aliceChannel *lnwallet.LightningChannel, aliceSwitch *Switch, if aliceSwitch == nil { var err error - aliceSwitch, err = initSwitchWithDB(testStartingHeight, aliceDb) + aliceSwitch, err = initSwitchWithDB(testStartingHeight, aliceDb, + invoiceRegistry) if err != nil { return nil, nil, nil, err } diff --git a/htlcswitch/mock.go b/htlcswitch/mock.go index 58ccc8b18e9..fa0d5be0014 100644 --- a/htlcswitch/mock.go +++ b/htlcswitch/mock.go @@ -138,7 +138,9 @@ func initDB() (*channeldb.DB, error) { return db, err } -func initSwitchWithDB(startingHeight uint32, db *channeldb.DB) (*Switch, error) { +func initSwitchWithDB(startingHeight uint32, db *channeldb.DB, + registry InvoiceDatabase) (*Switch, error) { + var err error if db == nil { @@ -160,6 +162,7 @@ func initSwitchWithDB(startingHeight uint32, db *channeldb.DB) (*Switch, error) Notifier: &mockNotifier{}, FwdEventTicker: ticker.MockNew(DefaultFwdEventInterval), LogEventTicker: ticker.MockNew(DefaultLogInterval), + InvoiceSettler: NewInvoiceSettler(registry), } return New(cfg, startingHeight) @@ -172,7 +175,9 @@ func newMockServer(t testing.TB, name string, startingHeight uint32, h := sha256.Sum256([]byte(name)) copy(id[:], h[:]) - htlcSwitch, err := initSwitchWithDB(startingHeight, db) + registry := newMockRegistry(defaultDelta) + + htlcSwitch, err := initSwitchWithDB(startingHeight, db, registry) if err != nil { return nil, err } @@ -183,7 +188,7 @@ func newMockServer(t testing.TB, name string, startingHeight uint32, name: name, messages: make(chan lnwire.Message, 3000), quit: make(chan struct{}), - registry: newMockRegistry(defaultDelta), + registry: registry, htlcSwitch: htlcSwitch, interceptorFuncs: make([]messageInterceptor, 0), }, nil diff --git a/htlcswitch/switch.go b/htlcswitch/switch.go index 7c12f621948..20b88c8178c 100644 --- a/htlcswitch/switch.go +++ b/htlcswitch/switch.go @@ -176,6 +176,8 @@ type Config struct { // LogEventTicker is a signal instructing the htlcswitch to log // aggregate stats about it's forwarding during the last interval. LogEventTicker ticker.Ticker + + InvoiceSettler *InvoiceSettler } // Switch is the central messaging bus for all incoming/outgoing HTLCs. @@ -1004,6 +1006,12 @@ func (s *Switch) parseFailedPayment(payment *pendingPayment, pkt *htlcPacket, return failure } +func (s *Switch) handleLocalForward(pkt *htlcPacket) error { + err := s.cfg.InvoiceSettler.handleIncoming(pkt) + + return err +} + // handlePacketForward is used in cases when we need forward the htlc update // from one channel link to another and be able to propagate the settle/fail // updates back. This behaviour is achieved by creation of payment circuits. @@ -1019,6 +1027,9 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error { // a pending user-initiated payment. return s.handleLocalDispatch(packet) } + if packet.outgoingChanID == exitHop { + return s.handleLocalForward(packet) + } s.indexMtx.RLock() targetLink, err := s.getLinkByShortID(packet.outgoingChanID) @@ -1516,6 +1527,41 @@ func (s *Switch) htlcForwarder() { s.cfg.FwdEventTicker.Resume() defer s.cfg.FwdEventTicker.Stop() + handleResolutionMsg := func(resolutionMsg *resolutionMsg) { + pkt := &htlcPacket{ + outgoingChanID: resolutionMsg.SourceChan, + outgoingHTLCID: resolutionMsg.HtlcIndex, + isResolution: true, + } + + // Resolution messages will either be cancelling + // backwards an existing HTLC, or settling a previously + // outgoing HTLC. Based on this, we'll map the message + // to the proper htlcPacket. + if resolutionMsg.Failure != nil { + pkt.htlc = &lnwire.UpdateFailHTLC{} + } else { + pkt.htlc = &lnwire.UpdateFulfillHTLC{ + PaymentPreimage: *resolutionMsg.PreImage, + } + } + + log.Infof("Received outside contract resolution, "+ + "mapping to: %v", spew.Sdump(pkt)) + + // We don't check the error, as the only failure we can + // encounter is due to the circuit already being + // closed. This is fine, as processing this message is + // meant to be idempotent. + err := s.handlePacketForward(pkt) + if err != nil { + log.Errorf("Unable to forward resolution msg: %v", err) + } + + // With the message processed, we'll now close out + close(resolutionMsg.doneChan) + } + out: for { select { @@ -1550,38 +1596,9 @@ out: go s.cfg.LocalChannelClose(peerPub[:], req) case resolutionMsg := <-s.resolutionMsgs: - pkt := &htlcPacket{ - outgoingChanID: resolutionMsg.SourceChan, - outgoingHTLCID: resolutionMsg.HtlcIndex, - isResolution: true, - } - - // Resolution messages will either be cancelling - // backwards an existing HTLC, or settling a previously - // outgoing HTLC. Based on this, we'll map the message - // to the proper htlcPacket. - if resolutionMsg.Failure != nil { - pkt.htlc = &lnwire.UpdateFailHTLC{} - } else { - pkt.htlc = &lnwire.UpdateFulfillHTLC{ - PaymentPreimage: *resolutionMsg.PreImage, - } - } - - log.Infof("Received outside contract resolution, "+ - "mapping to: %v", spew.Sdump(pkt)) - - // We don't check the error, as the only failure we can - // encounter is due to the circuit already being - // closed. This is fine, as processing this message is - // meant to be idempotent. - err := s.handlePacketForward(pkt) - if err != nil { - log.Errorf("Unable to forward resolution msg: %v", err) - } - - // With the message processed, we'll now close out - close(resolutionMsg.doneChan) + handleResolutionMsg(resolutionMsg) + case resolutionMsg := <-s.cfg.InvoiceSettler.ResolutionMsgs: + handleResolutionMsg(&resolutionMsg) // A new packet has arrived for forwarding, we'll interpret the // packet concretely, then either forward it along, or diff --git a/htlcswitch/switch_test.go b/htlcswitch/switch_test.go index 6dd8980fee5..99a85f35619 100644 --- a/htlcswitch/switch_test.go +++ b/htlcswitch/switch_test.go @@ -36,7 +36,7 @@ func TestSwitchAddDuplicateLink(t *testing.T) { t.Fatalf("unable to create alice server: %v", err) } - s, err := initSwitchWithDB(testStartingHeight, nil) + s, err := initSwitchWithDB(testStartingHeight, nil, nil) if err != nil { t.Fatalf("unable to init switch: %v", err) } @@ -92,7 +92,7 @@ func TestSwitchSendPending(t *testing.T) { t.Fatalf("unable to create alice server: %v", err) } - s, err := initSwitchWithDB(testStartingHeight, nil) + s, err := initSwitchWithDB(testStartingHeight, nil, alicePeer.registry) if err != nil { t.Fatalf("unable to init switch: %v", err) } @@ -191,7 +191,7 @@ func TestSwitchForward(t *testing.T) { t.Fatalf("unable to create bob server: %v", err) } - s, err := initSwitchWithDB(testStartingHeight, nil) + s, err := initSwitchWithDB(testStartingHeight, nil, nil) if err != nil { t.Fatalf("unable to init switch: %v", err) } @@ -306,7 +306,7 @@ func TestSwitchForwardFailAfterFullAdd(t *testing.T) { t.Fatalf("unable to open channeldb: %v", err) } - s, err := initSwitchWithDB(testStartingHeight, cdb) + s, err := initSwitchWithDB(testStartingHeight, cdb, nil) if err != nil { t.Fatalf("unable to init switch: %v", err) } @@ -401,7 +401,7 @@ func TestSwitchForwardFailAfterFullAdd(t *testing.T) { t.Fatalf("unable to reopen channeldb: %v", err) } - s2, err := initSwitchWithDB(testStartingHeight, cdb2) + s2, err := initSwitchWithDB(testStartingHeight, cdb2, nil) if err != nil { t.Fatalf("unable reinit switch: %v", err) } @@ -497,7 +497,7 @@ func TestSwitchForwardSettleAfterFullAdd(t *testing.T) { t.Fatalf("unable to open channeldb: %v", err) } - s, err := initSwitchWithDB(testStartingHeight, cdb) + s, err := initSwitchWithDB(testStartingHeight, cdb, nil) if err != nil { t.Fatalf("unable to init switch: %v", err) } @@ -592,7 +592,7 @@ func TestSwitchForwardSettleAfterFullAdd(t *testing.T) { t.Fatalf("unable to reopen channeldb: %v", err) } - s2, err := initSwitchWithDB(testStartingHeight, cdb2) + s2, err := initSwitchWithDB(testStartingHeight, cdb2, nil) if err != nil { t.Fatalf("unable reinit switch: %v", err) } @@ -691,7 +691,7 @@ func TestSwitchForwardDropAfterFullAdd(t *testing.T) { t.Fatalf("unable to open channeldb: %v", err) } - s, err := initSwitchWithDB(testStartingHeight, cdb) + s, err := initSwitchWithDB(testStartingHeight, cdb, nil) if err != nil { t.Fatalf("unable to init switch: %v", err) } @@ -778,7 +778,7 @@ func TestSwitchForwardDropAfterFullAdd(t *testing.T) { t.Fatalf("unable to reopen channeldb: %v", err) } - s2, err := initSwitchWithDB(testStartingHeight, cdb2) + s2, err := initSwitchWithDB(testStartingHeight, cdb2, nil) if err != nil { t.Fatalf("unable reinit switch: %v", err) } @@ -854,7 +854,7 @@ func TestSwitchForwardFailAfterHalfAdd(t *testing.T) { t.Fatalf("unable to open channeldb: %v", err) } - s, err := initSwitchWithDB(testStartingHeight, cdb) + s, err := initSwitchWithDB(testStartingHeight, cdb, nil) if err != nil { t.Fatalf("unable to init switch: %v", err) } @@ -936,7 +936,7 @@ func TestSwitchForwardFailAfterHalfAdd(t *testing.T) { t.Fatalf("unable to reopen channeldb: %v", err) } - s2, err := initSwitchWithDB(testStartingHeight, cdb2) + s2, err := initSwitchWithDB(testStartingHeight, cdb2, nil) if err != nil { t.Fatalf("unable reinit switch: %v", err) } @@ -1012,7 +1012,7 @@ func TestSwitchForwardCircuitPersistence(t *testing.T) { t.Fatalf("unable to open channeldb: %v", err) } - s, err := initSwitchWithDB(testStartingHeight, cdb) + s, err := initSwitchWithDB(testStartingHeight, cdb, nil) if err != nil { t.Fatalf("unable to init switch: %v", err) } @@ -1093,7 +1093,7 @@ func TestSwitchForwardCircuitPersistence(t *testing.T) { t.Fatalf("unable to reopen channeldb: %v", err) } - s2, err := initSwitchWithDB(testStartingHeight, cdb2) + s2, err := initSwitchWithDB(testStartingHeight, cdb2, nil) if err != nil { t.Fatalf("unable reinit switch: %v", err) } @@ -1186,7 +1186,7 @@ func TestSwitchForwardCircuitPersistence(t *testing.T) { t.Fatalf("unable to reopen channeldb: %v", err) } - s3, err := initSwitchWithDB(testStartingHeight, cdb3) + s3, err := initSwitchWithDB(testStartingHeight, cdb3, nil) if err != nil { t.Fatalf("unable reinit switch: %v", err) } @@ -1233,7 +1233,7 @@ func TestSkipIneligibleLinksMultiHopForward(t *testing.T) { t.Fatalf("unable to create bob server: %v", err) } - s, err := initSwitchWithDB(testStartingHeight, nil) + s, err := initSwitchWithDB(testStartingHeight, nil, nil) if err != nil { t.Fatalf("unable to init switch: %v", err) } @@ -1299,7 +1299,7 @@ func TestSkipIneligibleLinksLocalForward(t *testing.T) { t.Fatalf("unable to create alice server: %v", err) } - s, err := initSwitchWithDB(testStartingHeight, nil) + s, err := initSwitchWithDB(testStartingHeight, nil, nil) if err != nil { t.Fatalf("unable to init switch: %v", err) } @@ -1354,7 +1354,7 @@ func TestSwitchCancel(t *testing.T) { t.Fatalf("unable to create bob server: %v", err) } - s, err := initSwitchWithDB(testStartingHeight, nil) + s, err := initSwitchWithDB(testStartingHeight, nil, nil) if err != nil { t.Fatalf("unable to init switch: %v", err) } @@ -1467,7 +1467,7 @@ func TestSwitchAddSamePayment(t *testing.T) { t.Fatalf("unable to create bob server: %v", err) } - s, err := initSwitchWithDB(testStartingHeight, nil) + s, err := initSwitchWithDB(testStartingHeight, nil, nil) if err != nil { t.Fatalf("unable to init switch: %v", err) } @@ -1622,7 +1622,7 @@ func TestSwitchSendPayment(t *testing.T) { t.Fatalf("unable to create alice server: %v", err) } - s, err := initSwitchWithDB(testStartingHeight, nil) + s, err := initSwitchWithDB(testStartingHeight, nil, nil) if err != nil { t.Fatalf("unable to init switch: %v", err) }