Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ profile.tmp
.DS_Store

.vscode
*.code-workspace

# Coverage test
coverage.txt
Expand Down
5 changes: 4 additions & 1 deletion contractcourt/chain_watcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package contractcourt

import (
"bytes"
"context"
"crypto/sha256"
"fmt"
"testing"
Expand Down Expand Up @@ -145,7 +146,9 @@ func TestChainWatcherRemoteUnilateralClosePendingCommit(t *testing.T) {

// With the HTLC added, we'll now manually initiate a state transition
// from Alice to Bob.
_, err = aliceChannel.SignNextCommitment()
testQuit, testQuitFunc := context.WithCancel(context.Background())
_ = testQuitFunc
_, err = aliceChannel.SignNextCommitment(testQuit)
if err != nil {
t.Fatal(err)
}
Expand Down
6 changes: 3 additions & 3 deletions htlcswitch/interceptable_switch.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ type InterceptableSwitch struct {

type interceptedPackets struct {
packets []*htlcPacket
linkQuit chan struct{}
linkQuit <-chan struct{}
isReplay bool
}

Expand Down Expand Up @@ -442,8 +442,8 @@ func (s *InterceptableSwitch) Resolve(res *FwdResolution) error {
// interceptor. If the interceptor signals the resume action, the htlcs are
// forwarded to the switch. The link's quit signal should be provided to allow
// cancellation of forwarding during link shutdown.
func (s *InterceptableSwitch) ForwardPackets(linkQuit chan struct{}, isReplay bool,
packets ...*htlcPacket) error {
func (s *InterceptableSwitch) ForwardPackets(linkQuit <-chan struct{},
isReplay bool, packets ...*htlcPacket) error {

// Synchronize with the main event loop. This should be light in the
// case where there is no interceptor.
Expand Down
41 changes: 24 additions & 17 deletions htlcswitch/link.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package htlcswitch

import (
"bytes"
"context"
crand "crypto/rand"
"crypto/sha256"
"fmt"
Expand Down Expand Up @@ -101,7 +102,7 @@ type ChannelLinkConfig struct {
// switch. The function returns and error in case it fails to send one or
// more packets. The link's quit signal should be provided to allow
// cancellation of forwarding during link shutdown.
ForwardPackets func(chan struct{}, bool, ...*htlcPacket) error
ForwardPackets func(<-chan struct{}, bool, ...*htlcPacket) error

// DecodeHopIterators facilitates batched decoding of HTLC Sphinx onion
// blobs, which are then used to inform how to forward an HTLC.
Expand Down Expand Up @@ -382,8 +383,9 @@ type channelLink struct {
// our next CommitSig.
incomingCommitHooks hookMap

wg sync.WaitGroup
quit chan struct{}
wg sync.WaitGroup
quit context.Context //nolint:containedctx
quitFunc context.CancelFunc
}

// hookMap is a data structure that is used to track the hooks that need to be
Expand Down Expand Up @@ -448,6 +450,10 @@ func NewChannelLink(cfg ChannelLinkConfig,
channel *lnwallet.LightningChannel) ChannelLink {

logPrefix := fmt.Sprintf("ChannelLink(%v):", channel.ChannelPoint())
quitCtx, quitFunc := context.WithCancel(context.Background())

// Initialize the Done channel for our quit context.
_ = quitCtx.Done()
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's doesn't really have to be, but if we wanted to save that channel and then use it later (call Done() once) we'd want to explicitly initialize it as here.


return &channelLink{
cfg: cfg,
Expand All @@ -458,7 +464,8 @@ func NewChannelLink(cfg ChannelLinkConfig,
flushHooks: newHookMap(),
outgoingCommitHooks: newHookMap(),
incomingCommitHooks: newHookMap(),
quit: make(chan struct{}),
quit: quitCtx,
quitFunc: quitFunc,
}
}

Expand Down Expand Up @@ -573,7 +580,7 @@ func (l *channelLink) Stop() {

l.hodlQueue.Stop()

close(l.quit)
l.quitFunc()
l.wg.Wait()

// Now that the htlcManager has completely exited, reset the packet
Expand Down Expand Up @@ -660,7 +667,7 @@ func (l *channelLink) IsFlushing(linkDirection LinkDirection) bool {
func (l *channelLink) OnFlushedOnce(hook func()) {
select {
case l.flushHooks.newTransients <- hook:
case <-l.quit:
case <-l.quit.Done():
}
}

Expand All @@ -679,7 +686,7 @@ func (l *channelLink) OnCommitOnce(direction LinkDirection, hook func()) {

select {
case queue <- hook:
case <-l.quit:
case <-l.quit.Done():
}
}

Expand Down Expand Up @@ -889,7 +896,7 @@ func (l *channelLink) syncChanStates() error {
// party, so we'll process the message in order to determine
// if we need to re-transmit any messages to the remote party.
msgsToReSend, openedCircuits, closedCircuits, err =
l.channel.ProcessChanSyncMsg(remoteChanSyncMsg)
l.channel.ProcessChanSyncMsg(l.quit, remoteChanSyncMsg)
if err != nil {
return err
}
Expand Down Expand Up @@ -918,7 +925,7 @@ func (l *channelLink) syncChanStates() error {
l.cfg.Peer.SendMessage(false, msg)
}

case <-l.quit:
case <-l.quit.Done():
return ErrLinkShuttingDown
}

Expand Down Expand Up @@ -1041,7 +1048,7 @@ func (l *channelLink) fwdPkgGarbager() {
err)
continue
}
case <-l.quit:
case <-l.quit.Done():
return
}
}
Expand Down Expand Up @@ -1442,7 +1449,7 @@ func (l *channelLink) htlcManager() {
)
}

case <-l.quit:
case <-l.quit.Done():
return
}
}
Expand Down Expand Up @@ -2272,7 +2279,7 @@ func (l *channelLink) handleUpstreamMsg(msg lnwire.Message) {
}

select {
case <-l.quit:
case <-l.quit.Done():
return
default:
}
Expand Down Expand Up @@ -2334,7 +2341,7 @@ func (l *channelLink) handleUpstreamMsg(msg lnwire.Message) {
}

select {
case <-l.quit:
case <-l.quit.Done():
return
default:
}
Expand Down Expand Up @@ -2541,7 +2548,7 @@ func (l *channelLink) updateCommitTx() error {
return nil
}

newCommit, err := l.channel.SignNextCommitment()
newCommit, err := l.channel.SignNextCommitment(l.quit)
if err == lnwallet.ErrNoWindow {
l.cfg.PendingCommitTicker.Resume()
l.log.Trace("PendingCommitTicker resumed")
Expand Down Expand Up @@ -2582,7 +2589,7 @@ func (l *channelLink) updateCommitTx() error {
}

select {
case <-l.quit:
case <-l.quit.Done():
return ErrLinkShuttingDown
default:
}
Expand Down Expand Up @@ -3057,7 +3064,7 @@ func (l *channelLink) handleSwitchPacket(pkt *htlcPacket) error {
// NOTE: Part of the ChannelLink interface.
func (l *channelLink) HandleChannelUpdate(message lnwire.Message) {
select {
case <-l.quit:
case <-l.quit.Done():
// Return early if the link is already in the process of
// quitting. It doesn't make sense to hand the message to the
// mailbox here.
Expand Down Expand Up @@ -3744,7 +3751,7 @@ func (l *channelLink) forwardBatch(replay bool, packets ...*htlcPacket) {
filteredPkts = append(filteredPkts, pkt)
}

err := l.cfg.ForwardPackets(l.quit, replay, filteredPkts...)
err := l.cfg.ForwardPackets(l.quit.Done(), replay, filteredPkts...)
if err != nil {
log.Errorf("Unhandled error while reforwarding htlc "+
"settle/fail over htlcswitch: %v", err)
Expand Down
5 changes: 4 additions & 1 deletion htlcswitch/link_isolated_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package htlcswitch

import (
"context"
"crypto/sha256"
"testing"
"time"
Expand Down Expand Up @@ -94,7 +95,9 @@ func (l *linkTestContext) receiveHtlcAliceToBob() {
func (l *linkTestContext) sendCommitSigBobToAlice(expHtlcs int) {
l.t.Helper()

sigs, err := l.bobChannel.SignNextCommitment()
testQuit, testQuitFunc := context.WithCancel(context.Background())
_ = testQuitFunc
sigs, err := l.bobChannel.SignNextCommitment(testQuit)
if err != nil {
l.t.Fatalf("error signing commitment: %v", err)
}
Expand Down
82 changes: 49 additions & 33 deletions htlcswitch/link_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2189,17 +2189,21 @@ func newSingleLinkTestHarness(t *testing.T, chanAmt,
return nil
}

forwardPackets := func(linkQuit <-chan struct{}, _ bool,
packets ...*htlcPacket) error {

return aliceSwitch.ForwardPackets(linkQuit, packets...)
}

// Instantiate with a long interval, so that we can precisely control
// the firing via force feeding.
bticker := ticker.NewForce(time.Hour)
aliceCfg := ChannelLinkConfig{
FwrdingPolicy: globalPolicy,
Peer: alicePeer,
BestHeight: aliceSwitch.BestHeight,
Circuits: aliceSwitch.CircuitModifier(),
ForwardPackets: func(linkQuit chan struct{}, _ bool, packets ...*htlcPacket) error {
return aliceSwitch.ForwardPackets(linkQuit, packets...)
},
FwrdingPolicy: globalPolicy,
Peer: alicePeer,
BestHeight: aliceSwitch.BestHeight,
Circuits: aliceSwitch.CircuitModifier(),
ForwardPackets: forwardPackets,
DecodeHopIterators: decoder.DecodeHopIterators,
ExtractErrorEncrypter: func(*btcec.PublicKey) (
hop.ErrorEncrypter, lnwire.FailCode) {
Expand Down Expand Up @@ -2240,12 +2244,14 @@ func newSingleLinkTestHarness(t *testing.T, chanAmt,
return aliceSwitch.AddLink(aliceLink)
}
go func() {
for {
select {
case <-notifyUpdateChan:
case <-aliceLink.(*channelLink).quit:
close(doneChan)
return
if chanLink, ok := aliceLink.(*channelLink); ok {
for {
select {
case <-notifyUpdateChan:
case <-chanLink.quit.Done():
close(doneChan)
return
}
}
}
}()
Expand Down Expand Up @@ -2312,7 +2318,7 @@ func handleStateUpdate(link *channelLink,
}
link.HandleChannelUpdate(remoteRev)

remoteSigs, err := remoteChannel.SignNextCommitment()
remoteSigs, err := remoteChannel.SignNextCommitment(link.quit)
if err != nil {
return err
}
Expand Down Expand Up @@ -2355,15 +2361,15 @@ func updateState(batchTick chan time.Time, link *channelLink,
// Trigger update by ticking the batchTicker.
select {
case batchTick <- time.Now():
case <-link.quit:
case <-link.quit.Done():
return fmt.Errorf("link shutting down")
}
return handleStateUpdate(link, remoteChannel)
}

// The remote is triggering the state update, emulate this by
// signing and sending CommitSig to the link.
remoteSigs, err := remoteChannel.SignNextCommitment()
remoteSigs, err := remoteChannel.SignNextCommitment(link.quit)
if err != nil {
return err
}
Expand Down Expand Up @@ -4849,17 +4855,21 @@ func (h *persistentLinkHarness) restartLink(
return nil
}

forwardPackets := func(linkQuit <-chan struct{}, _ bool,
packets ...*htlcPacket) error {

return h.hSwitch.ForwardPackets(linkQuit, packets...)
}

// Instantiate with a long interval, so that we can precisely control
// the firing via force feeding.
bticker := ticker.NewForce(time.Hour)
aliceCfg := ChannelLinkConfig{
FwrdingPolicy: globalPolicy,
Peer: alicePeer,
BestHeight: h.hSwitch.BestHeight,
Circuits: h.hSwitch.CircuitModifier(),
ForwardPackets: func(linkQuit chan struct{}, _ bool, packets ...*htlcPacket) error {
return h.hSwitch.ForwardPackets(linkQuit, packets...)
},
FwrdingPolicy: globalPolicy,
Peer: alicePeer,
BestHeight: h.hSwitch.BestHeight,
Circuits: h.hSwitch.CircuitModifier(),
ForwardPackets: forwardPackets,
DecodeHopIterators: decoder.DecodeHopIterators,
ExtractErrorEncrypter: func(*btcec.PublicKey) (
hop.ErrorEncrypter, lnwire.FailCode) {
Expand Down Expand Up @@ -4904,12 +4914,14 @@ func (h *persistentLinkHarness) restartLink(
return nil, nil, err
}
go func() {
for {
select {
case <-notifyUpdateChan:
case <-aliceLink.(*channelLink).quit:
close(doneChan)
return
if chanLink, ok := aliceLink.(*channelLink); ok {
for {
select {
case <-notifyUpdateChan:
case <-chanLink.quit.Done():
close(doneChan)
return
}
}
}
}()
Expand Down Expand Up @@ -5892,7 +5904,9 @@ func TestChannelLinkFail(t *testing.T) {

// Sign a commitment that will include
// signature for the HTLC just sent.
sigs, err := remoteChannel.SignNextCommitment()
sigs, err := remoteChannel.SignNextCommitment(
c.quit,
)
if err != nil {
t.Fatalf("error signing commitment: %v",
err)
Expand Down Expand Up @@ -5934,7 +5948,9 @@ func TestChannelLinkFail(t *testing.T) {

// Sign a commitment that will include
// signature for the HTLC just sent.
sigs, err := remoteChannel.SignNextCommitment()
sigs, err := remoteChannel.SignNextCommitment(
c.quit,
)
if err != nil {
t.Fatalf("error signing commitment: %v",
err)
Expand Down Expand Up @@ -7018,7 +7034,7 @@ func TestPipelineSettle(t *testing.T) {
// erroneously forwarded. If the forwardChan is closed before the last
// step, then the test will fail.
forwardChan := make(chan struct{})
fwdPkts := func(c chan struct{}, _ bool, hp ...*htlcPacket) error {
fwdPkts := func(c <-chan struct{}, _ bool, hp ...*htlcPacket) error {
close(forwardChan)
return nil
}
Expand Down Expand Up @@ -7204,7 +7220,7 @@ func TestChannelLinkShortFailureRelay(t *testing.T) {
aliceMsgs := mockPeer.sentMsgs
switchChan := make(chan *htlcPacket)

coreLink.cfg.ForwardPackets = func(linkQuit chan struct{}, _ bool,
coreLink.cfg.ForwardPackets = func(linkQuit <-chan struct{}, _ bool,
packets ...*htlcPacket) error {

for _, p := range packets {
Expand Down
Loading