Skip to content
Closed
222 changes: 208 additions & 14 deletions lntest/itest/lnd_multi-hop_htlc_local_chain_claim_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@ package itest
import (
"context"
"fmt"
"testing"
"time"

"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcutil"
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd"
"github.com/lightningnetwork/lnd/lnrpc"
Expand All @@ -17,18 +19,119 @@ import (
"github.com/lightningnetwork/lnd/lntypes"
)

func testMultiHopHtlcClaims(net *lntest.NetworkHarness, t *harnessTest) {

type testCase struct {
name string
test func(net *lntest.NetworkHarness, t *harnessTest, alice,
bob *lntest.HarnessNode, c commitType)
}

subTests := []testCase{
{
// bob: outgoing our commit timeout
// carol: incoming their commit watch and see timeout
name: "local force close immediate expiry",
test: testMultiHopHtlcLocalTimeout,
},
{
// bob: outgoing watch and see, they sweep on chain
// carol: incoming our commit, know preimage
name: "receiver chain claim",
test: testMultiHopReceiverChainClaim,
},
{
// bob: outgoing our commit watch and see timeout
// carol: incoming their commit watch and see timeout
name: "local force close on-chain htlc timeout",
test: testMultiHopLocalForceCloseOnChainHtlcTimeout,
},
{
// bob: outgoing their commit watch and see timeout
// carol: incoming our commit watch and see timeout
name: "remote force close on-chain htlc timeout",
test: testMultiHopRemoteForceCloseOnChainHtlcTimeout,
},
{
// bob: outgoing our commit watch and see, they sweep on chain
// bob: incoming our commit watch and learn preimage
// carol: incoming their commit know preimage
name: "local chain claim",
test: testMultiHopHtlcLocalChainClaim,
},
{
// bob: outgoing their commit watch and see, they sweep on chain
// bob: incoming their commit watch and learn preimage
// carol: incoming our commit know preimage
name: "remote chain claim",
test: testMultiHopHtlcRemoteChainClaim,
},
}

allTypes := []commitType{
commitTypeLegacy,
commitTypeTweakless,
commitTypeAnchors,
}

for _, commitType := range allTypes {
testName := fmt.Sprintf("committype=%v", commitType.String())

success := t.t.Run(testName, func(t *testing.T) {
ht := newHarnessTest(t, net)

args := commitType.Args()
alice, err := net.NewNode("Alice", args)
if err != nil {
t.Fatalf("unable to create new node: %v", err)
}
defer shutdownAndAssert(net, ht, alice)

bob, err := net.NewNode("Bob", args)
if err != nil {
t.Fatalf("unable to create new node: %v", err)
}
defer shutdownAndAssert(net, ht, bob)

ctxb := context.Background()
ctxt, _ := context.WithTimeout(ctxb, defaultTimeout)
if err := net.ConnectNodes(ctxt, alice, bob); err != nil {
t.Fatalf("unable to connect alice to bob: %v", err)
}

for _, subTest := range subTests {
subTest := subTest

success := ht.t.Run(subTest.name, func(t *testing.T) {
ht := newHarnessTest(t, net)

subTest.test(net, ht, alice, bob, commitType)
})
if !success {
return
}
}
})
if !success {
return
}
}
}

// testMultiHopHtlcLocalChainClaim tests that in a multi-hop HTLC scenario, if
// we're forced to go to chain with an incoming HTLC, then when we find out the
// preimage via the witness beacon, we properly settle the HTLC on-chain in
// order to ensure we don't lose any funds.
func testMultiHopHtlcLocalChainClaim(net *lntest.NetworkHarness, t *harnessTest) {
// we force close a channel with an incoming HTLC, and later find out the
// preimage via the witness beacon, we properly settle the HTLC on-chain using
// the HTLC success transaction in order to ensure we don't lose any funds.
func testMultiHopHtlcLocalChainClaim(net *lntest.NetworkHarness, t *harnessTest,
alice, bob *lntest.HarnessNode, c commitType) {

ctxb := context.Background()

// First, we'll create a three hop network: Alice -> Bob -> Carol, with
// Carol refusing to actually settle or directly cancel any HTLC's
// self.
aliceChanPoint, bobChanPoint, carol := createThreeHopNetwork(
t, net, false,
t, net, alice, bob, false, c,
)

// Clean up carol's node when the test finishes.
Expand Down Expand Up @@ -59,7 +162,7 @@ func testMultiHopHtlcLocalChainClaim(net *lntest.NetworkHarness, t *harnessTest)
ctx, cancel := context.WithCancel(ctxb)
defer cancel()

alicePayStream, err := net.Alice.SendPayment(ctx)
alicePayStream, err := alice.SendPayment(ctx)
if err != nil {
t.Fatalf("unable to create payment stream for alice: %v", err)
}
Expand All @@ -73,7 +176,7 @@ func testMultiHopHtlcLocalChainClaim(net *lntest.NetworkHarness, t *harnessTest)
// At this point, all 3 nodes should now have an active channel with
// the created HTLC pending on all of them.
var predErr error
nodes := []*lntest.HarnessNode{net.Alice, net.Bob, carol}
nodes := []*lntest.HarnessNode{alice, bob, carol}
err = wait.Predicate(func() bool {
predErr = assertActiveHtlcs(nodes, payHash[:])
if predErr != nil {
Expand All @@ -94,7 +197,7 @@ func testMultiHopHtlcLocalChainClaim(net *lntest.NetworkHarness, t *harnessTest)
// At this point, Bob decides that he wants to exit the channel
// immediately, so he force closes his commitment transaction.
ctxt, _ = context.WithTimeout(ctxb, channelCloseTimeout)
bobForceClose := closeChannelAndAssert(ctxt, t, net, net.Bob,
bobForceClose := closeChannelAndAssert(ctxt, t, net, bob,
aliceChanPoint, true)

// Alice will sweep her output immediately.
Expand All @@ -105,7 +208,7 @@ func testMultiHopHtlcLocalChainClaim(net *lntest.NetworkHarness, t *harnessTest)
}

// Suspend Bob to force Carol to go to chain.
restartBob, err := net.SuspendNode(net.Bob)
restartBob, err := net.SuspendNode(bob)
if err != nil {
t.Fatalf("unable to suspend bob: %v", err)
}
Expand Down Expand Up @@ -197,7 +300,7 @@ func testMultiHopHtlcLocalChainClaim(net *lntest.NetworkHarness, t *harnessTest)

// At this point we suspend Alice to make sure she'll handle the
// on-chain settle after a restart.
restartAlice, err := net.SuspendNode(net.Alice)
restartAlice, err := net.SuspendNode(alice)
if err != nil {
t.Fatalf("unable to suspend alice: %v", err)
}
Expand Down Expand Up @@ -239,7 +342,7 @@ func testMultiHopHtlcLocalChainClaim(net *lntest.NetworkHarness, t *harnessTest)
pendingChansRequest := &lnrpc.PendingChannelsRequest{}
err = wait.Predicate(func() bool {
ctxt, _ = context.WithTimeout(ctxb, defaultTimeout)
pendingChanResp, err := net.Bob.PendingChannels(
pendingChanResp, err := bob.PendingChannels(
ctxt, pendingChansRequest,
)
if err != nil {
Expand Down Expand Up @@ -337,7 +440,7 @@ func testMultiHopHtlcLocalChainClaim(net *lntest.NetworkHarness, t *harnessTest)

err = wait.Predicate(func() bool {
ctxt, _ = context.WithTimeout(ctxb, defaultTimeout)
pendingChanResp, err := net.Bob.PendingChannels(
pendingChanResp, err := bob.PendingChannels(
ctxt, pendingChansRequest,
)
if err != nil {
Expand All @@ -352,7 +455,7 @@ func testMultiHopHtlcLocalChainClaim(net *lntest.NetworkHarness, t *harnessTest)
}
req := &lnrpc.ListChannelsRequest{}
ctxt, _ = context.WithTimeout(ctxb, defaultTimeout)
chanInfo, err := net.Bob.ListChannels(ctxt, req)
chanInfo, err := bob.ListChannels(ctxt, req)
if err != nil {
predErr = fmt.Errorf("unable to query for open "+
"channels: %v", err)
Expand Down Expand Up @@ -411,7 +514,7 @@ func testMultiHopHtlcLocalChainClaim(net *lntest.NetworkHarness, t *harnessTest)
// succeeded.
ctxt, _ = context.WithTimeout(ctxt, defaultTimeout)
err = checkPaymentStatus(
ctxt, net.Alice, preimage, lnrpc.Payment_SUCCEEDED,
ctxt, alice, preimage, lnrpc.Payment_SUCCEEDED,
)
if err != nil {
t.Fatalf(err.Error())
Expand Down Expand Up @@ -498,3 +601,94 @@ func checkPaymentStatus(ctxt context.Context, node *lntest.HarnessNode,

return nil
}

func createThreeHopNetwork(t *harnessTest, net *lntest.NetworkHarness,
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

move in separate (trivial) PR?

alice, bob *lntest.HarnessNode, carolHodl bool, c commitType) (
*lnrpc.ChannelPoint, *lnrpc.ChannelPoint, *lntest.HarnessNode) {

ctxb := context.Background()

ctxt, _ := context.WithTimeout(ctxb, defaultTimeout)
err := net.EnsureConnected(ctxt, alice, bob)
if err != nil {
t.Fatalf("unable to connect peers: %v", err)
}

ctxt, _ = context.WithTimeout(context.Background(), defaultTimeout)
err = net.SendCoins(ctxt, btcutil.SatoshiPerBitcoin, alice)
if err != nil {
t.Fatalf("unable to send coins to Alice: %v", err)
}

ctxt, _ = context.WithTimeout(context.Background(), defaultTimeout)
err = net.SendCoins(ctxt, btcutil.SatoshiPerBitcoin, bob)
if err != nil {
t.Fatalf("unable to send coins to Bob: %v", err)
}

// We'll start the test by creating a channel between Alice and Bob,
// which will act as the first leg for out multi-hop HTLC.
const chanAmt = 1000000
ctxt, _ = context.WithTimeout(ctxb, channelOpenTimeout)
aliceChanPoint := openChannelAndAssert(
ctxt, t, net, alice, bob,
lntest.OpenChannelParams{
Amt: chanAmt,
},
)

ctxt, _ = context.WithTimeout(ctxb, defaultTimeout)
err = alice.WaitForNetworkChannelOpen(ctxt, aliceChanPoint)
if err != nil {
t.Fatalf("alice didn't report channel: %v", err)
}

ctxt, _ = context.WithTimeout(ctxb, defaultTimeout)
err = bob.WaitForNetworkChannelOpen(ctxt, aliceChanPoint)
if err != nil {
t.Fatalf("bob didn't report channel: %v", err)
}

// Next, we'll create a new node "carol" and have Bob connect to her. If
// the carolHodl flag is set, we'll make carol always hold onto the
// HTLC, this way it'll force Bob to go to chain to resolve the HTLC.
carolFlags := c.Args()
if carolHodl {
carolFlags = append(carolFlags, "--hodl.exit-settle")
}
carol, err := net.NewNode("Carol", carolFlags)
if err != nil {
t.Fatalf("unable to create new node: %v", err)
}
ctxt, _ = context.WithTimeout(ctxb, defaultTimeout)
if err := net.ConnectNodes(ctxt, bob, carol); err != nil {
t.Fatalf("unable to connect bob to carol: %v", err)
}

// We'll then create a channel from Bob to Carol. After this channel is
// open, our topology looks like: A -> B -> C.
ctxt, _ = context.WithTimeout(ctxb, channelOpenTimeout)
bobChanPoint := openChannelAndAssert(
ctxt, t, net, bob, carol,
lntest.OpenChannelParams{
Amt: chanAmt,
},
)
ctxt, _ = context.WithTimeout(ctxb, defaultTimeout)
err = bob.WaitForNetworkChannelOpen(ctxt, bobChanPoint)
if err != nil {
t.Fatalf("alice didn't report channel: %v", err)
}
ctxt, _ = context.WithTimeout(ctxb, defaultTimeout)
err = carol.WaitForNetworkChannelOpen(ctxt, bobChanPoint)
if err != nil {
t.Fatalf("bob didn't report channel: %v", err)
}
ctxt, _ = context.WithTimeout(ctxb, defaultTimeout)
err = alice.WaitForNetworkChannelOpen(ctxt, bobChanPoint)
if err != nil {
t.Fatalf("bob didn't report channel: %v", err)
}

return aliceChanPoint, bobChanPoint, carol
}
Loading