From aed568b6ed558a68ffafa9023a3afc7c9710f3b7 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Fri, 23 Aug 2024 18:19:09 -0700 Subject: [PATCH 01/10] lnwallet: properly set aux HTLC blob on retransmission Before this commit, we weren't properly setting the aux HTLC blob when we went to retransmit a signature. We fix this by setting the `ExtraData` field as expected in the `CommitSig` message. --- lnwallet/channel.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/lnwallet/channel.go b/lnwallet/channel.go index fc279c264a9..7f8724c0278 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -4946,6 +4946,7 @@ func (lc *LightningChannel) ProcessChanSyncMsg( CommitSig: newCommit.CommitSig, HtlcSigs: newCommit.HtlcSigs, PartialSig: newCommit.PartialSig, + ExtraData: newCommit.AuxSigBlob, } updates = append(updates, commitSig) @@ -5500,9 +5501,9 @@ func genHtlcSigValidationJobs(chanState *channeldb.OpenChannel, // store in the custom records map so we can write to // disk later. sigType := htlcCustomSigType.TypeVal() - htlc.CustomRecords[uint64(sigType)] = auxSig.UnwrapOr( - nil, - ) + auxSig.WhenSome(func(sigB tlv.Blob) { + htlc.CustomRecords[uint64(sigType)] = sigB + }) auxVerifyJobs = append(auxVerifyJobs, auxVerifyJob) } From f03a6d5a0bb3feccc940cfa13ab8156e7dd0519c Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Wed, 28 Aug 2024 18:23:14 -0500 Subject: [PATCH 02/10] lnwallet: ensure we re-sign retransmitted commits for taproot channels In this commit, we fix an existing bug with the taproot channel type that can cause force closes if a peer disconnects while attempting to send the commitment signature. Before this commit, since the `PartialSig` we send is never committed to disk, the version read wouldn't contain the musig2 partial sig. We never write these signatures to disk, as each time we make a new session, we need to generate fresh nonces to avoid nonce-reuse. Due to the above interaction, if we went to re-send a signature after a disconnection, the `CommitSig` message we sent wouldn't actualy contain a `PartialSigWithNonce`, causing a protocol error. --- lnwallet/channel.go | 38 +++++++++++++++++++++++++++++++++++--- 1 file changed, 35 insertions(+), 3 deletions(-) diff --git a/lnwallet/channel.go b/lnwallet/channel.go index 7f8724c0278..0c3e4c9ce1b 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -4748,6 +4748,27 @@ func (lc *LightningChannel) SignNextCommitment() (*NewCommitState, error) { }, nil } +// resignMusigCommit is used to resign a commitment transaction for taproot +// channels when we need to retransmit a signature after a channel reestablish +// message. Taproot channels use musig2, which means we must use fresh nonces +// each time. After we receive the channel reestablish message, we learn the +// nonce we need to use for the remote party. As a result, we need to generate +// the partial signature again with the new nonce. +func (lc *LightningChannel) resignMusigCommit(commitTx *wire.MsgTx, +) (lnwire.OptPartialSigWithNonceTLV, error) { + + remoteSession := lc.musigSessions.RemoteSession + musig, err := remoteSession.SignCommit(commitTx) + if err != nil { + var none lnwire.OptPartialSigWithNonceTLV + return none, err + } + + partialSig := lnwire.MaybePartialSigWithNonce(musig.ToWireSig()) + + return partialSig, nil +} + // ProcessChanSyncMsg processes a ChannelReestablish message sent by the remote // connection upon re establishment of our connection with them. This method // will return a single message if we are currently out of sync, otherwise a @@ -5026,12 +5047,23 @@ func (lc *LightningChannel) ProcessChanSyncMsg( commitUpdates = append(commitUpdates, logUpdate.UpdateMsg) } + // If this is a taproot channel, then we need to regenerate the + // musig2 signature for the remote party, using their fresh + // nonce. + if lc.channelState.ChanType.IsTaproot() { + partialSig, err := lc.resignMusigCommit( + commitDiff.Commitment.CommitTx, + ) + if err != nil { + return nil, nil, nil, err + } + + commitDiff.CommitSig.PartialSig = partialSig + } + // With the batch of updates accumulated, we'll now re-send the // original CommitSig message required to re-sync their remote // commitment chain with our local version of their chain. - // - // TODO(roasbeef): need to re-sign commitment states w/ - // fresh nonce commitUpdates = append(commitUpdates, commitDiff.CommitSig) // NOTE: If a revocation is not owed, then updates is empty. From f5b57e8d0895530362b6060dc651b9548034f8cf Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Wed, 28 Aug 2024 18:24:20 -0500 Subject: [PATCH 03/10] lnwallet: extract initMusigNonce from initRevocationWindows This'll be useful later to make some enhancements to the existing unit tests. --- lnwallet/test_utils.go | 33 ++++++++++++++++++++------------- 1 file changed, 20 insertions(+), 13 deletions(-) diff --git a/lnwallet/test_utils.go b/lnwallet/test_utils.go index d7ac5df3e48..8ff8860e969 100644 --- a/lnwallet/test_utils.go +++ b/lnwallet/test_utils.go @@ -431,6 +431,25 @@ func CreateTestChannels(t *testing.T, chanType channeldb.ChannelType, return channelAlice, channelBob, nil } +// initMusigNonce is used to manually setup musig2 nonces for a new channel, +// outside the normal chan-reest flow. +func initMusigNonce(chanA, chanB *LightningChannel) error { + chanANonces, err := chanA.GenMusigNonces() + if err != nil { + return err + } + chanBNonces, err := chanB.GenMusigNonces() + if err != nil { + return err + } + + if err := chanA.InitRemoteMusigNonces(chanBNonces); err != nil { + return err + } + + return chanB.InitRemoteMusigNonces(chanANonces) +} + // initRevocationWindows simulates a new channel being opened within the p2p // network by populating the initial revocation windows of the passed // commitment state machines. @@ -439,19 +458,7 @@ func initRevocationWindows(chanA, chanB *LightningChannel) error { // either FundingLocked or ChannelReestablish by calling // InitRemoteMusigNonces for both sides. if chanA.channelState.ChanType.IsTaproot() { - chanANonces, err := chanA.GenMusigNonces() - if err != nil { - return err - } - chanBNonces, err := chanB.GenMusigNonces() - if err != nil { - return err - } - - if err := chanA.InitRemoteMusigNonces(chanBNonces); err != nil { - return err - } - if err := chanB.InitRemoteMusigNonces(chanANonces); err != nil { + if err := initMusigNonce(chanA, chanB); err != nil { return err } } From 819de822628304b95a98907b423073bec8305612 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Wed, 28 Aug 2024 18:26:51 -0500 Subject: [PATCH 04/10] lnwallet: expand chan sync tests to cover taproot channels In this commit, we expand some of the existing chan sync tests to cover taproot channels (the others already did). Along the way, we always assert that the `PartialSig` is populated on retransmission. In addition, we now send the new commit sig rather than the existing in-memory one to test the new logic that re-signs the commitment. --- lnwallet/channel_test.go | 275 +++++++++++++++++++++++++++++++++++---- 1 file changed, 250 insertions(+), 25 deletions(-) diff --git a/lnwallet/channel_test.go b/lnwallet/channel_test.go index 9b93ef6ac99..eec3515459b 100644 --- a/lnwallet/channel_test.go +++ b/lnwallet/channel_test.go @@ -3025,19 +3025,11 @@ func restartChannel(channelOld *LightningChannel) (*LightningChannel, error) { return channelNew, nil } -// TestChanSyncOweCommitment tests that if Bob restarts (and then Alice) before -// he receives Alice's CommitSig message, then Alice concludes that she needs -// to re-send the CommitDiff. After the diff has been sent, both nodes should -// resynchronize and be able to complete the dangling commit. -func TestChanSyncOweCommitment(t *testing.T) { - t.Parallel() - +func testChanSyncOweCommitment(t *testing.T, chanType channeldb.ChannelType) { // Create a test channel which will be used for the duration of this // unittest. The channel will be funded evenly with Alice having 5 BTC, // and Bob having 5 BTC. - aliceChannel, bobChannel, err := CreateTestChannels( - t, channeldb.SingleFunderTweaklessBit, - ) + aliceChannel, bobChannel, err := CreateTestChannels(t, chanType) require.NoError(t, err, "unable to create test channels") var fakeOnionBlob [lnwire.OnionPacketSize]byte @@ -3112,6 +3104,15 @@ func TestChanSyncOweCommitment(t *testing.T) { aliceNewCommit, err := aliceChannel.SignNextCommitment() require.NoError(t, err, "unable to sign commitment") + // If this is a taproot channel, then we'll generate fresh verification + // nonce for both sides. + if chanType.IsTaproot() { + _, err = aliceChannel.GenMusigNonces() + require.NoError(t, err) + _, err = bobChannel.GenMusigNonces() + require.NoError(t, err) + } + // Bob doesn't get this message so upon reconnection, they need to // synchronize. Alice should conclude that she owes Bob a commitment, // while Bob should think he's properly synchronized. @@ -3123,7 +3124,7 @@ func TestChanSyncOweCommitment(t *testing.T) { // This is a helper function that asserts Alice concludes that she // needs to retransmit the exact commitment that we failed to send // above. - assertAliceCommitRetransmit := func() { + assertAliceCommitRetransmit := func() *lnwire.CommitSig { aliceMsgsToSend, _, _, err := aliceChannel.ProcessChanSyncMsg( bobSyncMsg, ) @@ -3188,12 +3189,25 @@ func TestChanSyncOweCommitment(t *testing.T) { len(commitSigMsg.HtlcSigs)) } for i, htlcSig := range commitSigMsg.HtlcSigs { - if htlcSig != aliceNewCommit.HtlcSigs[i] { + if !bytes.Equal(htlcSig.RawBytes(), + aliceNewCommit.HtlcSigs[i].RawBytes()) { + t.Fatalf("htlc sig msgs don't match: "+ - "expected %x got %x", - aliceNewCommit.HtlcSigs[i], htlcSig) + "expected %v got %v", + spew.Sdump(aliceNewCommit.HtlcSigs[i]), + spew.Sdump(htlcSig)) } } + + // If this is a taproot channel, then partial sig information + // should be present in the commit sig sent over. This + // signature will be re-regenerated, so we can't compare it + // with the old one. + if chanType.IsTaproot() { + require.True(t, commitSigMsg.PartialSig.IsSome()) + } + + return commitSigMsg } // Alice should detect that she needs to re-send 5 messages: the 3 @@ -3214,14 +3228,19 @@ func TestChanSyncOweCommitment(t *testing.T) { // send the exact same set of messages. aliceChannel, err = restartChannel(aliceChannel) require.NoError(t, err, "unable to restart alice") - assertAliceCommitRetransmit() - // TODO(roasbeef): restart bob as well??? + // To properly simulate a restart, we'll use the *new* signature that + // would send in an actual p2p setting. + aliceReCommitSig := assertAliceCommitRetransmit() // At this point, we should be able to resume the prior state update // without any issues, resulting in Alice settling the 3 htlc's, and // adding one of her own. - err = bobChannel.ReceiveNewCommitment(aliceNewCommit.CommitSigs) + err = bobChannel.ReceiveNewCommitment(&CommitSigs{ + CommitSig: aliceReCommitSig.CommitSig, + HtlcSigs: aliceReCommitSig.HtlcSigs, + PartialSig: aliceReCommitSig.PartialSig, + }) require.NoError(t, err, "bob unable to process alice's commitment") bobRevocation, _, _, err := bobChannel.RevokeCurrentCommitment() require.NoError(t, err, "unable to revoke bob commitment") @@ -3308,16 +3327,134 @@ func TestChanSyncOweCommitment(t *testing.T) { } } -// TestChanSyncOweCommitmentPendingRemote asserts that local updates are applied -// to the remote commit across restarts. -func TestChanSyncOweCommitmentPendingRemote(t *testing.T) { +// TestChanSyncOweCommitment tests that if Bob restarts (and then Alice) before +// he receives Alice's CommitSig message, then Alice concludes that she needs +// to re-send the CommitDiff. After the diff has been sent, both nodes should +// resynchronize and be able to complete the dangling commit. +func TestChanSyncOweCommitment(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + chanType channeldb.ChannelType + }{ + { + name: "tweakless", + chanType: channeldb.SingleFunderTweaklessBit, + }, + { + name: "anchors", + chanType: channeldb.SingleFunderTweaklessBit | + channeldb.AnchorOutputsBit, + }, + { + name: "taproot", + chanType: channeldb.SingleFunderTweaklessBit | + channeldb.AnchorOutputsBit | + channeldb.SimpleTaprootFeatureBit, + }, + { + name: "taproot with tapscript root", + chanType: channeldb.SingleFunderTweaklessBit | + channeldb.AnchorOutputsBit | + channeldb.SimpleTaprootFeatureBit | + channeldb.TapscriptRootBit, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + testChanSyncOweCommitment(t, tc.chanType) + }) + } +} + +// TestChanSyncOweCommitmentAuxSigner tests that when one party owes a +// signature after a channel reest, if an aux signer is present, then the +// signature message sent includes the additional aux sigs as extra data. +func TestChanSyncOweCommitmentAuxSigner(t *testing.T) { t.Parallel() // Create a test channel which will be used for the duration of this - // unittest. - aliceChannel, bobChannel, err := CreateTestChannels( - t, channeldb.SingleFunderTweaklessBit, + // unittest. The channel will be funded evenly with Alice having 5 BTC, + // and Bob having 5 BTC. + chanType := channeldb.SingleFunderTweaklessBit | + channeldb.AnchorOutputsBit | channeldb.SimpleTaprootFeatureBit | + channeldb.TapscriptRootBit + + aliceChannel, bobChannel, err := CreateTestChannels(t, chanType) + require.NoError(t, err, "unable to create test channels") + + // We'll now manually attach an aux signer to Alice's channel. + auxSigner := &auxSignerMock{} + aliceChannel.auxSigner = fn.Some[AuxSigner](auxSigner) + + var fakeOnionBlob [lnwire.OnionPacketSize]byte + copy( + fakeOnionBlob[:], + bytes.Repeat([]byte{0x05}, lnwire.OnionPacketSize), ) + + // To kick things off, we'll have Alice send a single HTLC to Bob. + htlcAmt := lnwire.NewMSatFromSatoshis(20000) + var bobPreimage [32]byte + copy(bobPreimage[:], bytes.Repeat([]byte{0}, 32)) + rHash := sha256.Sum256(bobPreimage[:]) + h := &lnwire.UpdateAddHTLC{ + PaymentHash: rHash, + Amount: htlcAmt, + Expiry: uint32(10), + OnionBlob: fakeOnionBlob, + } + + _, err = aliceChannel.AddHTLC(h, nil) + require.NoError(t, err, "unable to recv bob's htlc: %v", err) + + // We'll set up the mock to expect calls to PackSigs and also + // SubmitSubmitSecondLevelSigBatch. + sigBlobs := bytes.Repeat([]byte{0x01}, 64) + auxSigner.On( + "SubmitSecondLevelSigBatch", mock.Anything, mock.Anything, + mock.Anything, + ).Return(nil).Twice() + auxSigner.On("PackSigs", mock.Anything).Return(fn.Some(sigBlobs), nil) + + _, err = aliceChannel.SignNextCommitment() + require.NoError(t, err, "unable to sign commitment") + + _, err = aliceChannel.GenMusigNonces() + require.NoError(t, err, "unable to generate musig nonces") + + // Next we'll simulate a restart, by having Bob send over a chan sync + // message to Alice. + bobSyncMsg, err := bobChannel.channelState.ChanSyncMsg() + require.NoError(t, err, "unable to produce chan sync msg") + + aliceMsgsToSend, _, _, err := aliceChannel.ProcessChanSyncMsg( + bobSyncMsg, + ) + require.NoError(t, err) + require.Len(t, aliceMsgsToSend, 2) + + // The first message should be an update add HTLC. + require.IsType(t, &lnwire.UpdateAddHTLC{}, aliceMsgsToSend[0]) + + // The second should be a commit sig message. + sigMsg, ok := aliceMsgsToSend[1].(*lnwire.CommitSig) + require.True(t, ok) + require.True(t, sigMsg.PartialSig.IsSome()) + + // The signature should have the ExtraData field set. + require.NotNil(t, sigMsg.ExtraData) + + // TODO(roasbeef): also make one for owe revocation +} + +func testChanSyncOweCommitmentPendingRemote(t *testing.T, + chanType channeldb.ChannelType) { + + // Create a test channel which will be used for the duration of this + // unittest. + aliceChannel, bobChannel, err := CreateTestChannels(t, chanType) require.NoError(t, err, "unable to create test channels") var fakeOnionBlob [lnwire.OnionPacketSize]byte @@ -3400,6 +3537,12 @@ func TestChanSyncOweCommitmentPendingRemote(t *testing.T) { bobChannel, err = restartChannel(bobChannel) require.NoError(t, err, "unable to restart bob") + // If this is a taproot channel, then since Bob just restarted, we need + // to exchange nonces once again. + if chanType.IsTaproot() { + require.NoError(t, initMusigNonce(aliceChannel, bobChannel)) + } + // Bob signs the commitment he owes. bobNewCommit, err := bobChannel.SignNextCommitment() require.NoError(t, err, "unable to sign commitment") @@ -3425,6 +3568,45 @@ func TestChanSyncOweCommitmentPendingRemote(t *testing.T) { } } +// TestChanSyncOweCommitmentPendingRemote asserts that local updates are applied +// to the remote commit across restarts. +func TestChanSyncOweCommitmentPendingRemote(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + chanType channeldb.ChannelType + }{ + { + name: "tweakless", + chanType: channeldb.SingleFunderTweaklessBit, + }, + { + name: "anchors", + chanType: channeldb.SingleFunderTweaklessBit | + channeldb.AnchorOutputsBit, + }, + { + name: "taproot", + chanType: channeldb.SingleFunderTweaklessBit | + channeldb.AnchorOutputsBit | + channeldb.SimpleTaprootFeatureBit, + }, + { + name: "taproot with tapscript root", + chanType: channeldb.SingleFunderTweaklessBit | + channeldb.AnchorOutputsBit | + channeldb.SimpleTaprootFeatureBit | + channeldb.TapscriptRootBit, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + testChanSyncOweCommitmentPendingRemote(t, tc.chanType) + }) + } +} + // testChanSyncOweRevocation is the internal version of // TestChanSyncOweRevocation that is parameterized based on the type of channel // being used in the test. @@ -3574,8 +3756,6 @@ func testChanSyncOweRevocation(t *testing.T, chanType channeldb.ChannelType) { assertAliceOwesRevoke() - // TODO(roasbeef): restart bob too??? - // We'll continue by then allowing bob to process Alice's revocation // message. _, _, _, _, err = bobChannel.ReceiveRevocation(aliceRevocation) @@ -3624,6 +3804,23 @@ func TestChanSyncOweRevocation(t *testing.T) { testChanSyncOweRevocation(t, taprootBits) }) + t.Run("taproot", func(t *testing.T) { + taprootBits := channeldb.SimpleTaprootFeatureBit | + channeldb.AnchorOutputsBit | + channeldb.ZeroHtlcTxFeeBit | + channeldb.SingleFunderTweaklessBit + + testChanSyncOweRevocation(t, taprootBits) + }) + t.Run("taproot with tapscript root", func(t *testing.T) { + taprootBits := channeldb.SimpleTaprootFeatureBit | + channeldb.AnchorOutputsBit | + channeldb.ZeroHtlcTxFeeBit | + channeldb.SingleFunderTweaklessBit | + channeldb.TapscriptRootBit + + testChanSyncOweRevocation(t, taprootBits) + }) } func testChanSyncOweRevocationAndCommit(t *testing.T, @@ -3753,6 +3950,14 @@ func testChanSyncOweRevocationAndCommit(t *testing.T, bobNewCommit.HtlcSigs[i]) } } + + // If this is a taproot channel, then partial sig information + // should be present in the commit sig sent over. This + // signature will be re-regenerated, so we can't compare it + // with the old one. + if chanType.IsTaproot() { + require.True(t, bobReCommitSigMsg.PartialSig.IsSome()) + } } // We expect Bob to send exactly two messages: first his revocation @@ -3809,6 +4014,15 @@ func TestChanSyncOweRevocationAndCommit(t *testing.T) { testChanSyncOweRevocationAndCommit(t, taprootBits) }) + t.Run("taproot with tapscript root", func(t *testing.T) { + taprootBits := channeldb.SimpleTaprootFeatureBit | + channeldb.AnchorOutputsBit | + channeldb.ZeroHtlcTxFeeBit | + channeldb.SingleFunderTweaklessBit | + channeldb.TapscriptRootBit + + testChanSyncOweRevocationAndCommit(t, taprootBits) + }) } func testChanSyncOweRevocationAndCommitForceTransition(t *testing.T, @@ -4040,6 +4254,17 @@ func TestChanSyncOweRevocationAndCommitForceTransition(t *testing.T) { t, taprootBits, ) }) + t.Run("taproot with tapscript root", func(t *testing.T) { + taprootBits := channeldb.SimpleTaprootFeatureBit | + channeldb.AnchorOutputsBit | + channeldb.ZeroHtlcTxFeeBit | + channeldb.SingleFunderTweaklessBit | + channeldb.TapscriptRootBit + + testChanSyncOweRevocationAndCommitForceTransition( + t, taprootBits, + ) + }) } // TestChanSyncFailure tests the various scenarios during channel sync where we From e4c97d23c8b0fe3a43c8b7f7706c050c2278da67 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Thu, 29 Aug 2024 18:44:24 -0500 Subject: [PATCH 05/10] build: add rapid as a new dep --- go.mod | 1 + go.sum | 2 ++ 2 files changed, 3 insertions(+) diff --git a/go.mod b/go.mod index 88741ad8d75..739e2746dce 100644 --- a/go.mod +++ b/go.mod @@ -62,6 +62,7 @@ require ( google.golang.org/protobuf v1.33.0 gopkg.in/macaroon-bakery.v2 v2.0.1 gopkg.in/macaroon.v2 v2.0.0 + pgregory.net/rapid v1.1.0 ) require ( diff --git a/go.sum b/go.sum index 09e515c0689..4931c5e435b 100644 --- a/go.sum +++ b/go.sum @@ -1070,6 +1070,8 @@ modernc.org/strutil v1.2.0 h1:agBi9dp1I+eOnxXeiZawM8F4LawKv4NzGWSaLfyeNZA= modernc.org/strutil v1.2.0/go.mod h1:/mdcBmfOibveCTBxUl5B5l6W+TTH1FXPLHZE6bTosX0= modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= +pgregory.net/rapid v1.1.0 h1:CMa0sjHSru3puNx+J0MIAuiiEV4N0qj8/cMWGBBCsjw= +pgregory.net/rapid v1.1.0/go.mod h1:PY5XlDGj0+V1FCq0o192FdRhpKHGTRIWBgqjDBTrq04= rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8= rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0= rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA= From 78f31da2fb3dff11023d539d4844684679115556 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Thu, 29 Aug 2024 18:45:20 -0500 Subject: [PATCH 06/10] lnwire: refactor custom vs official tlv parsing into new func We'll use this to update CommitSig in the next commit. --- lnwire/extra_bytes.go | 60 ++++++++++++++++++++++++++++++++++++++ lnwire/extra_bytes_test.go | 47 +++++++++++++++++++++++++++++ lnwire/update_add_htlc.go | 35 +++++++--------------- 3 files changed, 118 insertions(+), 24 deletions(-) diff --git a/lnwire/extra_bytes.go b/lnwire/extra_bytes.go index b90988a7711..4b6a953a04d 100644 --- a/lnwire/extra_bytes.go +++ b/lnwire/extra_bytes.go @@ -1,9 +1,15 @@ package lnwire +// For some reason golangci-lint has a false positive on the sort order of the +// imports for the new "maps" package... We need the nolint directive here to +// ignore that. +// +//nolint:gci import ( "bytes" "fmt" "io" + "maps" "github.com/lightningnetwork/lnd/tlv" ) @@ -194,3 +200,57 @@ func EncodeMessageExtraData(extraData *ExtraOpaqueData, // are all properly sorted. return extraData.PackRecords(recordProducers...) } + +// wireTlvMap is a struct that holds the official records and custom records in +// a TLV type map. This is useful for ensuring that the set of custom TLV +// records are handled properly and don't overlap with the official records. +type wireTlvMap struct { + // officialTypes is the set of official records that are defined in the + // spec. + officialTypes tlv.TypeMap + + // customTypes is the set of custom records that are not defined in + // spec, and are used by higher level applications. + customTypes tlv.TypeMap +} + +// newWireTlvMap creates a new tlv.TypeMap from the given set of parsed TLV +// records. A struct with two maps are returned: +// +// 1. officialTypes: the set of official records that are defined in the +// spec. +// +// 2. customTypes: the set of custom records that are not defined in +// the spec. +func newWireTlvMap(typeMap tlv.TypeMap) wireTlvMap { + officialRecords := maps.Clone(typeMap) + + // Any records from the extra data TLV map which are in the custom + // records TLV type range will be included in the custom records field + // and removed from the extra data field. + customRecordsTlvMap := make(tlv.TypeMap, len(typeMap)) + for k, v := range typeMap { + // Skip records that are not in the custom records TLV type + // range. + if k < MinCustomRecordsTlvType { + continue + } + + // Include the record in the custom records map. + customRecordsTlvMap[k] = v + + // Now that the record is included in the custom records map, + // we can remove it from the extra data TLV map. + delete(officialRecords, k) + } + + return wireTlvMap{ + officialTypes: officialRecords, + customTypes: customRecordsTlvMap, + } +} + +// Len returns the total number of records in the wireTlvMap. +func (w *wireTlvMap) Len() int { + return len(w.officialTypes) + len(w.customTypes) +} diff --git a/lnwire/extra_bytes_test.go b/lnwire/extra_bytes_test.go index b05b19db5f8..98c7eeefca1 100644 --- a/lnwire/extra_bytes_test.go +++ b/lnwire/extra_bytes_test.go @@ -7,8 +7,11 @@ import ( "testing" "testing/quick" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/tlv" "github.com/stretchr/testify/require" + "golang.org/x/exp/maps" + "pgregory.net/rapid" ) // TestExtraOpaqueDataEncodeDecode tests that we're able to encode/decode @@ -206,3 +209,47 @@ func TestPackRecords(t *testing.T) { require.Equal(t, recordBytes2, extractedRecords[tlvType2.TypeVal()]) require.Equal(t, recordBytes3, extractedRecords[tlvType3.TypeVal()]) } + +// TestNewWireTlvMap tests the newWireTlvMap function using property-based +// testing. +func TestNewWireTlvMap(t *testing.T) { + rapid.Check(t, func(t *rapid.T) { + // Make a random type map, using the generic Make which'll + // figure out what type to generate. + tlvTypeMap := rapid.Make[tlv.TypeMap]().Draw(t, "typeMap") + + // Create a wireTlvMap from the generated type map, this'll + // operate on our random input. + result := newWireTlvMap(tlvTypeMap) + + // Property 1: The sum of lengths of officialTypes and + // customTypes should equal the length of the input typeMap. + require.Equal(t, len(tlvTypeMap), result.Len()) + + // Property 2: All types in customTypes should be >= + // MinCustomRecordsTlvType. + require.True(t, fn.All(func(k tlv.Type) bool { + return uint64(k) >= uint64(MinCustomRecordsTlvType) + }, maps.Keys(result.customTypes))) + + // Property 3: All types in officialTypes should be < + // MinCustomRecordsTlvType. + require.True(t, fn.All(func(k tlv.Type) bool { + return uint64(k) < uint64(MinCustomRecordsTlvType) + }, maps.Keys(result.officialTypes))) + + // Property 4: The union of officialTypes and customTypes + // should equal the input typeMap. + unionMap := make(tlv.TypeMap) + maps.Copy(unionMap, result.officialTypes) + maps.Copy(unionMap, result.customTypes) + require.Equal(t, tlvTypeMap, unionMap) + + // Property 5: No type should appear in both officialTypes and + // customTypes. + require.True(t, fn.All(func(k tlv.Type) bool { + _, exists := result.officialTypes[k] + return !exists + }, maps.Keys(result.customTypes))) + }) +} diff --git a/lnwire/update_add_htlc.go b/lnwire/update_add_htlc.go index 3669f81e89a..982e66b8330 100644 --- a/lnwire/update_add_htlc.go +++ b/lnwire/update_add_htlc.go @@ -131,29 +131,14 @@ func (c *UpdateAddHTLC) Decode(r io.Reader, pver uint32) error { delete(extraDataTlvMap, c.BlindingPoint.TlvType()) } - // Any records from the extra data TLV map which are in the custom - // records TLV type range will be included in the custom records field - // and removed from the extra data field. - customRecordsTlvMap := make(tlv.TypeMap, len(extraDataTlvMap)) - for k, v := range extraDataTlvMap { - // Skip records that are not in the custom records TLV type - // range. - if k < MinCustomRecordsTlvType { - continue - } - - // Include the record in the custom records map. - customRecordsTlvMap[k] = v - - // Now that the record is included in the custom records map, - // we can remove it from the extra data TLV map. - delete(extraDataTlvMap, k) - } + // Parse through the remaining extra data map to separate the custom + // records, from the set of official records. + tlvTypes := newWireTlvMap(extraDataTlvMap) // Set the custom records field to the custom records specific TLV // record map. customRecords, err := NewCustomRecordsFromTlvTypeMap( - customRecordsTlvMap, + tlvTypes.customTypes, ) if err != nil { return err @@ -162,21 +147,23 @@ func (c *UpdateAddHTLC) Decode(r io.Reader, pver uint32) error { // Set custom records to nil if we didn't parse anything out of it so // that we can use assert.Equal in tests. - if len(customRecordsTlvMap) == 0 { + if len(customRecords) == 0 { c.CustomRecords = nil } // Set extra data to nil if we didn't parse anything out of it so that // we can use assert.Equal in tests. - if len(extraDataTlvMap) == 0 { + if len(tlvTypes.officialTypes) == 0 { c.ExtraData = nil return nil } // Encode the remaining records back into the extra data field. These - // records are not in the custom records TLV type range and do not - // have associated fields in the UpdateAddHTLC struct. - c.ExtraData, err = NewExtraOpaqueDataFromTlvTypeMap(extraDataTlvMap) + // records are not in the custom records TLV type range and do not have + // associated fields in the UpdateAddHTLC struct. + c.ExtraData, err = NewExtraOpaqueDataFromTlvTypeMap( + tlvTypes.officialTypes, + ) if err != nil { return err } From 681f44fd16d6b045d87a25a17e950d3030d6962f Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Thu, 29 Aug 2024 20:51:09 -0500 Subject: [PATCH 07/10] lnwire: modify TestLightningWireProtocol to use sub-tests This way, it's possible to run induvidual tests to target failures. --- lnwire/lnwire_test.go | 36 +++++++++++++++++++++--------------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/lnwire/lnwire_test.go b/lnwire/lnwire_test.go index c9413d4c2f4..f326cf10a85 100644 --- a/lnwire/lnwire_test.go +++ b/lnwire/lnwire_test.go @@ -887,7 +887,7 @@ func TestLightningWireProtocol(t *testing.T) { v[0] = reflect.ValueOf(ks) }, MsgCommitSig: func(v []reflect.Value, r *rand.Rand) { - req := NewCommitSig() + req := &CommitSig{} if _, err := r.Read(req.ChanID[:]); err != nil { t.Fatalf("unable to generate chan id: %v", err) return @@ -1653,22 +1653,28 @@ func TestLightningWireProtocol(t *testing.T) { }, } for _, test := range tests { - var config *quick.Config - - // If the type defined is within the custom type gen map above, - // then we'll modify the default config to use this Value - // function that knows how to generate the proper types. - if valueGen, ok := customTypeGen[test.msgType]; ok { - config = &quick.Config{ - Values: valueGen, + t.Run(test.msgType.String(), func(t *testing.T) { + var config *quick.Config + + // If the type defined is within the custom type gen + // map above, then we'll modify the default config to + // use this Value function that knows how to generate + // the proper types. + if valueGen, ok := customTypeGen[test.msgType]; ok { + config = &quick.Config{ + Values: valueGen, + } } - } - t.Logf("Running fuzz tests for msgType=%v", test.msgType) - if err := quick.Check(test.scenario, config); err != nil { - t.Fatalf("fuzz checks for msg=%v failed: %v", - test.msgType, err) - } + t.Logf("Running fuzz tests for msgType=%v", + test.msgType) + + err := quick.Check(test.scenario, config) + if err != nil { + t.Fatalf("fuzz checks for msg=%v failed: %v", + test.msgType, err) + } + }) } } From 7dd3a5b361876bb3f3de6eb170c1e590c22e458c Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Thu, 29 Aug 2024 20:52:29 -0500 Subject: [PATCH 08/10] lnwire: add CustomRecords field to CommitSig In a future commit, we'll use the new field to ensure that if we add any additional records, they aren't over written by the TLV records that would be encoded. --- lnwire/commit_sig.go | 98 ++++++++++++++++++++++++++++++++++++------- lnwire/lnwire_test.go | 2 + 2 files changed, 84 insertions(+), 16 deletions(-) diff --git a/lnwire/commit_sig.go b/lnwire/commit_sig.go index 7deb64ae1c1..c3f1a89b6bb 100644 --- a/lnwire/commit_sig.go +++ b/lnwire/commit_sig.go @@ -2,6 +2,7 @@ package lnwire import ( "bytes" + "fmt" "io" "github.com/lightningnetwork/lnd/tlv" @@ -45,6 +46,11 @@ type CommitSig struct { // being signed for. In this case, the above Sig type MUST be blank. PartialSig OptPartialSigWithNonceTLV + // CustomRecords maps TLV types to byte slices, storing arbitrary data + // intended for inclusion in the ExtraData field of the CommitSig + // message. + CustomRecords CustomRecords + // ExtraData is the set of data that was appended to this message to // fill out the full maximum transport message size. These fields can // be used to specify optional data such as custom TLV fields. @@ -62,8 +68,8 @@ func NewCommitSig() *CommitSig { // interface. var _ Message = (*CommitSig)(nil) -// Decode deserializes a serialized CommitSig message stored in the -// passed io.Reader observing the specified protocol version. +// Decode deserializes a serialized CommitSig message stored in the passed +// io.Reader observing the specified protocol version. // // This is part of the lnwire.Message interface. func (c *CommitSig) Decode(r io.Reader, pver uint32) error { @@ -90,29 +96,57 @@ func (c *CommitSig) Decode(r io.Reader, pver uint32) error { // Set the corresponding TLV types if they were included in the stream. if val, ok := typeMap[c.PartialSig.TlvType()]; ok && val == nil { c.PartialSig = tlv.SomeRecordT(partialSig) + + // Remove the entry from the TLV map. Anything left in the map + // will be included in the custom records field. + delete(typeMap, c.PartialSig.TlvType()) } - if len(tlvRecords) != 0 { - c.ExtraData = tlvRecords + // Parse through the remaining extra data map to separate the custom + // records, from the set of official records. + tlvTypes := newWireTlvMap(typeMap) + + // Set the custom records field to the custom records specific TLV + // record map. + customRecords, err := NewCustomRecordsFromTlvTypeMap( + tlvTypes.customTypes, + ) + if err != nil { + return err + } + c.CustomRecords = customRecords + + // Set custom records to nil if we didn't parse anything out of it so + // that we can use assert.Equal in tests. + if len(customRecords) == 0 { + c.CustomRecords = nil + } + + // Set extra data to nil if we didn't parse anything out of it so that + // we can use assert.Equal in tests. + if len(tlvTypes.officialTypes) == 0 { + c.ExtraData = nil + return nil + } + + // Encode the remaining records back into the extra data field. These + // records are not in the custom records TLV type range and do not have + // associated fields in the CommitSig struct. + c.ExtraData, err = NewExtraOpaqueDataFromTlvTypeMap( + tlvTypes.officialTypes, + ) + if err != nil { + return err } return nil } -// Encode serializes the target CommitSig into the passed io.Writer -// observing the protocol version specified. +// Encode serializes the target CommitSig into the passed io.Writer observing +// the protocol version specified. // // This is part of the lnwire.Message interface. func (c *CommitSig) Encode(w *bytes.Buffer, pver uint32) error { - recordProducers := make([]tlv.RecordProducer, 0, 1) - c.PartialSig.WhenSome(func(sig PartialSigWithNonceTLV) { - recordProducers = append(recordProducers, &sig) - }) - err := EncodeMessageExtraData(&c.ExtraData, recordProducers...) - if err != nil { - return err - } - if err := WriteChannelID(w, c.ChanID); err != nil { return err } @@ -125,7 +159,39 @@ func (c *CommitSig) Encode(w *bytes.Buffer, pver uint32) error { return err } - return WriteBytes(w, c.ExtraData) + // Construct a slice of all the records that we should include in the + // message extra data field. We will start by including any records + // from the extra data field. + msgExtraDataRecords, err := c.ExtraData.RecordProducers() + if err != nil { + return err + } + + // Include the partial sig record if it is set. + c.PartialSig.WhenSome(func(sig PartialSigWithNonceTLV) { + msgExtraDataRecords = append(msgExtraDataRecords, &sig) + }) + + // Include custom records in the extra data wire field if they are + // present. Ensure that the custom records are validated before + // encoding them. + if err := c.CustomRecords.Validate(); err != nil { + return fmt.Errorf("custom records validation error: %w", err) + } + + // Extend the message extra data records slice with TLV records from + // the custom records field. + customTlvRecords := c.CustomRecords.RecordProducers() + msgExtraDataRecords = append(msgExtraDataRecords, customTlvRecords...) + + // We will now construct the message extra data field that will be + // encoded into the byte writer. + var msgExtraData ExtraOpaqueData + if err := msgExtraData.PackRecords(msgExtraDataRecords...); err != nil { + return err + } + + return WriteBytes(w, msgExtraData) } // MsgType returns the integer uniquely identifying this message type on the diff --git a/lnwire/lnwire_test.go b/lnwire/lnwire_test.go index f326cf10a85..5531f883ee2 100644 --- a/lnwire/lnwire_test.go +++ b/lnwire/lnwire_test.go @@ -915,6 +915,8 @@ func TestLightningWireProtocol(t *testing.T) { } } + req.CustomRecords = randCustomRecords(t, r) + // 50/50 chance to attach a partial sig. if r.Int31()%2 == 0 { req.PartialSig = somePartialSigWithNonce(t, r) From f41dd862d06642f15cb47fac21227f3ab88d07fe Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Thu, 29 Aug 2024 21:19:31 -0500 Subject: [PATCH 09/10] htlcswitch+lnwallet: use CustomRecords for aux sig blobs In this commit, we start to use the set of CustomRecords instead of ExtraData for the aux sig blobs. --- htlcswitch/link.go | 27 +++++++++++++++++++++------ lnwallet/channel.go | 35 +++++++++++++++++++++++++++-------- 2 files changed, 48 insertions(+), 14 deletions(-) diff --git a/htlcswitch/link.go b/htlcswitch/link.go index 468eba13198..a59cedf06f3 100644 --- a/htlcswitch/link.go +++ b/htlcswitch/link.go @@ -2164,11 +2164,21 @@ func (l *channelLink) handleUpstreamMsg(msg lnwire.Message) { // We just received a new updates to our local commitment // chain, validate this new commitment, closing the link if // invalid. + auxSigBlob, err := msg.CustomRecords.Serialize() + if err != nil { + l.fail( + LinkFailureError{code: ErrInternalError}, + "unable to serialize custom records: %v", + err, + ) + + return + } err = l.channel.ReceiveNewCommitment(&lnwallet.CommitSigs{ CommitSig: msg.CommitSig, HtlcSigs: msg.HtlcSigs, PartialSig: msg.PartialSig, - AuxSigBlob: msg.ExtraData, + AuxSigBlob: auxSigBlob, }) if err != nil { // If we were unable to reconstruct their proposed @@ -2577,12 +2587,17 @@ func (l *channelLink) updateCommitTx() error { default: } + auxBlobRecords, err := lnwire.ParseCustomRecords(newCommit.AuxSigBlob) + if err != nil { + return fmt.Errorf("error parsing aux sigs: %w", err) + } + commitSig := &lnwire.CommitSig{ - ChanID: l.ChanID(), - CommitSig: newCommit.CommitSig, - HtlcSigs: newCommit.HtlcSigs, - PartialSig: newCommit.PartialSig, - ExtraData: newCommit.AuxSigBlob, + ChanID: l.ChanID(), + CommitSig: newCommit.CommitSig, + HtlcSigs: newCommit.HtlcSigs, + PartialSig: newCommit.PartialSig, + CustomRecords: auxBlobRecords, } l.cfg.Peer.SendMessage(false, commitSig) diff --git a/lnwallet/channel.go b/lnwallet/channel.go index 0c3e4c9ce1b..b420b6134c0 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -4028,6 +4028,10 @@ func (lc *LightningChannel) createCommitDiff(newCommit *commitment, if err != nil { return nil, fmt.Errorf("error packing aux sigs: %w", err) } + auxBlobRecords, err := lnwire.ParseCustomRecords(auxSigBlob) + if err != nil { + return nil, fmt.Errorf("error parsing aux sigs: %w", err) + } return &channeldb.CommitDiff{ Commitment: *diskCommit, @@ -4035,9 +4039,9 @@ func (lc *LightningChannel) createCommitDiff(newCommit *commitment, ChanID: lnwire.NewChanIDFromOutPoint( lc.channelState.FundingOutpoint, ), - CommitSig: commitSig, - HtlcSigs: htlcSigs, - ExtraData: auxSigBlob, + CommitSig: commitSig, + HtlcSigs: htlcSigs, + CustomRecords: auxBlobRecords, }, LogUpdates: logUpdates, OpenedCircuitKeys: openCircuitKeys, @@ -4737,12 +4741,18 @@ func (lc *LightningChannel) SignNextCommitment() (*NewCommitState, error) { // latest commitment update. lc.remoteCommitChain.addCommitment(newCommitView) + auxSigBlob, err := commitDiff.CommitSig.CustomRecords.Serialize() + if err != nil { + return nil, fmt.Errorf("unable to serialize aux sig "+ + "blob: %v", err) + } + return &NewCommitState{ CommitSigs: &CommitSigs{ CommitSig: sig, HtlcSigs: htlcSigs, PartialSig: lnwire.MaybePartialSigWithNonce(partialSig), - AuxSigBlob: commitDiff.CommitSig.ExtraData, + AuxSigBlob: auxSigBlob, }, PendingHTLCs: commitDiff.Commitment.Htlcs, }, nil @@ -4960,14 +4970,23 @@ func (lc *LightningChannel) ProcessChanSyncMsg( // If we signed this state, then we'll accumulate // another update to send over. case err == nil: + blobRecords, err := lnwire.ParseCustomRecords( + newCommit.AuxSigBlob, + ) + if err != nil { + sErr := fmt.Errorf("error parsing "+ + "aux sigs: %w", err) + return nil, nil, nil, sErr + } + commitSig := &lnwire.CommitSig{ ChanID: lnwire.NewChanIDFromOutPoint( lc.channelState.FundingOutpoint, ), - CommitSig: newCommit.CommitSig, - HtlcSigs: newCommit.HtlcSigs, - PartialSig: newCommit.PartialSig, - ExtraData: newCommit.AuxSigBlob, + CommitSig: newCommit.CommitSig, + HtlcSigs: newCommit.HtlcSigs, + PartialSig: newCommit.PartialSig, + CustomRecords: blobRecords, } updates = append(updates, commitSig) From 505843213a20aa21b5781eb293616c60ffcf2fc0 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Thu, 29 Aug 2024 21:20:37 -0500 Subject: [PATCH 10/10] lnwallet: add new TestChanSyncOweCommitmentAuxSigner test This test ensures that when we go to retransmit a signature, we also include the set of CustomRecords. --- lnwallet/channel_test.go | 26 +++++++++++++++++----- lnwallet/mock.go | 47 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 67 insertions(+), 6 deletions(-) diff --git a/lnwallet/channel_test.go b/lnwallet/channel_test.go index eec3515459b..4d972ece49a 100644 --- a/lnwallet/channel_test.go +++ b/lnwallet/channel_test.go @@ -28,6 +28,7 @@ import ( "github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/tlv" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) @@ -3368,6 +3369,10 @@ func TestChanSyncOweCommitment(t *testing.T) { } } +type testSigBlob struct { + BlobInt tlv.RecordT[tlv.TlvType65634, uint16] +} + // TestChanSyncOweCommitmentAuxSigner tests that when one party owes a // signature after a channel reest, if an aux signer is present, then the // signature message sent includes the additional aux sigs as extra data. @@ -3411,12 +3416,23 @@ func TestChanSyncOweCommitmentAuxSigner(t *testing.T) { // We'll set up the mock to expect calls to PackSigs and also // SubmitSubmitSecondLevelSigBatch. - sigBlobs := bytes.Repeat([]byte{0x01}, 64) + var sigBlobBuf bytes.Buffer + sigBlob := testSigBlob{ + BlobInt: tlv.NewPrimitiveRecord[tlv.TlvType65634, uint16](5), + } + tlvStream, err := tlv.NewStream(sigBlob.BlobInt.Record()) + require.NoError(t, err, "unable to create tlv stream") + require.NoError(t, tlvStream.Encode(&sigBlobBuf)) + auxSigner.On( "SubmitSecondLevelSigBatch", mock.Anything, mock.Anything, mock.Anything, ).Return(nil).Twice() - auxSigner.On("PackSigs", mock.Anything).Return(fn.Some(sigBlobs), nil) + auxSigner.On( + "PackSigs", mock.Anything, + ).Return( + fn.Some(sigBlobBuf.Bytes()), nil, + ) _, err = aliceChannel.SignNextCommitment() require.NoError(t, err, "unable to sign commitment") @@ -3443,10 +3459,8 @@ func TestChanSyncOweCommitmentAuxSigner(t *testing.T) { require.True(t, ok) require.True(t, sigMsg.PartialSig.IsSome()) - // The signature should have the ExtraData field set. - require.NotNil(t, sigMsg.ExtraData) - - // TODO(roasbeef): also make one for owe revocation + // The signature should have the CustomRecords field set. + require.NotEmpty(t, sigMsg.CustomRecords) } func testChanSyncOweCommitmentPendingRemote(t *testing.T, diff --git a/lnwallet/mock.go b/lnwallet/mock.go index 1873de79a84..89c31ad9857 100644 --- a/lnwallet/mock.go +++ b/lnwallet/mock.go @@ -17,7 +17,11 @@ import ( "github.com/btcsuite/btcwallet/wallet/txauthor" "github.com/btcsuite/btcwallet/wtxmgr" "github.com/lightningnetwork/lnd/chainntnfs" + "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/lnwallet/chainfee" + "github.com/lightningnetwork/lnd/tlv" + "github.com/stretchr/testify/mock" ) var ( @@ -384,3 +388,46 @@ func (*mockChainIO) GetBlockHeader( return nil, nil } + +type auxSignerMock struct { + mock.Mock +} + +func (a *auxSignerMock) SubmitSecondLevelSigBatch( + chanState *channeldb.OpenChannel, + commitTx *wire.MsgTx, sigJobs []AuxSigJob) error { + + args := a.Called(chanState, commitTx, sigJobs) + + // While we return, we'll also send back an instant response for the + // set of jobs. + for _, sigJob := range sigJobs { + sigJob.Resp <- AuxSigJobResp{} + } + + return args.Error(0) +} + +func (a *auxSignerMock) PackSigs(sigs []fn.Option[tlv.Blob], +) (fn.Option[tlv.Blob], error) { + + args := a.Called(sigs) + + return args.Get(0).(fn.Option[tlv.Blob]), args.Error(1) +} + +func (a *auxSignerMock) UnpackSigs(sigs fn.Option[tlv.Blob]) ( + []fn.Option[tlv.Blob], error) { + + args := a.Called(sigs) + + return args.Get(0).([]fn.Option[tlv.Blob]), args.Error(1) +} + +func (a *auxSignerMock) VerifySecondLevelSigs(chanState *channeldb.OpenChannel, + commitTx *wire.MsgTx, verifyJob []AuxVerifyJob) error { + + args := a.Called(chanState, commitTx, verifyJob) + + return args.Error(0) +}