Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 19 additions & 11 deletions contractcourt/briefcase_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,9 @@ func assertResolversEqual(t *testing.T, originalResolver ContractResolver,
t.Fatalf("expected %v, got %v",
ogRes.broadcastHeight, diskRes.broadcastHeight)
}
if ogRes.htlcIndex != diskRes.htlcIndex {
t.Fatalf("expected %v, got %v", ogRes.htlcIndex,
diskRes.htlcIndex)
if ogRes.htlc.HtlcIndex != diskRes.htlc.HtlcIndex {
t.Fatalf("expected %v, got %v", ogRes.htlc.HtlcIndex,
diskRes.htlc.HtlcIndex)
}
}

Expand All @@ -184,9 +184,9 @@ func assertResolversEqual(t *testing.T, originalResolver ContractResolver,
t.Fatalf("expected %v, got %v",
ogRes.broadcastHeight, diskRes.broadcastHeight)
}
if ogRes.payHash != diskRes.payHash {
t.Fatalf("expected %v, got %v", ogRes.payHash,
diskRes.payHash)
if ogRes.htlc.RHash != diskRes.htlc.RHash {
t.Fatalf("expected %v, got %v", ogRes.htlc.RHash,
diskRes.htlc.RHash)
}
}

Expand Down Expand Up @@ -265,7 +265,9 @@ func TestContractInsertionRetrieval(t *testing.T) {
outputIncubating: true,
resolved: true,
broadcastHeight: 102,
htlcIndex: 12,
htlc: channeldb.HTLC{
HtlcIndex: 12,
},
}
successResolver := htlcSuccessResolver{
htlcResolution: lnwallet.IncomingHtlcResolution{
Expand All @@ -278,8 +280,10 @@ func TestContractInsertionRetrieval(t *testing.T) {
outputIncubating: true,
resolved: true,
broadcastHeight: 109,
payHash: testPreimage,
sweepTx: nil,
htlc: channeldb.HTLC{
RHash: testPreimage,
},
sweepTx: nil,
}
resolvers := []ContractResolver{
&timeoutResolver,
Expand Down Expand Up @@ -395,7 +399,9 @@ func TestContractResolution(t *testing.T) {
outputIncubating: true,
resolved: true,
broadcastHeight: 192,
htlcIndex: 9912,
htlc: channeldb.HTLC{
HtlcIndex: 9912,
},
}

// First, we'll insert the resolver into the database and ensure that
Expand Down Expand Up @@ -454,7 +460,9 @@ func TestContractSwapping(t *testing.T) {
outputIncubating: true,
resolved: true,
broadcastHeight: 102,
htlcIndex: 12,
htlc: channeldb.HTLC{
HtlcIndex: 12,
},
}
contestResolver := &htlcOutgoingContestResolver{
htlcTimeoutResolver: timeoutResolver,
Expand Down
4 changes: 4 additions & 0 deletions contractcourt/chain_arbitrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,10 @@ type ChainArbitratorConfig struct {
// NotifyClosedChannel is a function closure that the ChainArbitrator
// will use to notify the ChannelNotifier about a newly closed channel.
NotifyClosedChannel func(wire.OutPoint)

// OnionProcessor is used to decode onion payloads for on-chain
// resolution.
OnionProcessor OnionProcessor
}

// ChainArbitrator is a sub-system that oversees the on-chain resolution of all
Expand Down
129 changes: 29 additions & 100 deletions contractcourt/channel_arbitrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -501,96 +501,24 @@ func (c *ChannelArbitrator) relaunchResolvers(commitSet *CommitSet) error {
"resolvers", c.cfg.ChanPoint, len(unresolvedContracts))

for _, resolver := range unresolvedContracts {
if err := c.supplementResolver(resolver, htlcMap); err != nil {
return err
htlcResolver, ok := resolver.(htlcContractResolver)
if !ok {
continue
}
}

c.launchResolvers(unresolvedContracts)

return nil
}

// supplementResolver takes a resolver as it is restored from the log and fills
// in missing data from the htlcMap.
func (c *ChannelArbitrator) supplementResolver(resolver ContractResolver,
htlcMap map[wire.OutPoint]*channeldb.HTLC) error {

switch r := resolver.(type) {

case *htlcSuccessResolver:
return c.supplementSuccessResolver(r, htlcMap)

case *htlcIncomingContestResolver:
return c.supplementIncomingContestResolver(r, htlcMap)

case *htlcTimeoutResolver:
return c.supplementTimeoutResolver(r, htlcMap)

case *htlcOutgoingContestResolver:
return c.supplementTimeoutResolver(
&r.htlcTimeoutResolver, htlcMap,
)
}

return nil
}

// supplementSuccessResolver takes a htlcIncomingContestResolver as it is
// restored from the log and fills in missing data from the htlcMap.
func (c *ChannelArbitrator) supplementIncomingContestResolver(
r *htlcIncomingContestResolver,
htlcMap map[wire.OutPoint]*channeldb.HTLC) error {

res := r.htlcResolution
htlcPoint := res.HtlcPoint()
htlc, ok := htlcMap[htlcPoint]
if !ok {
return errors.New(
"htlc for incoming contest resolver unavailable",
)
}
htlcPoint := htlcResolver.HtlcPoint()
htlc, ok := htlcMap[htlcPoint]
if !ok {
return fmt.Errorf(
"htlc resolver %T unavailable", resolver,
)
}

r.htlcAmt = htlc.Amt
r.circuitKey = channeldb.CircuitKey{
ChanID: c.cfg.ShortChanID,
HtlcID: htlc.HtlcIndex,
htlcResolver.Supplement(*htlc)
}

return nil
}

// supplementSuccessResolver takes a htlcSuccessResolver as it is restored from
// the log and fills in missing data from the htlcMap.
func (c *ChannelArbitrator) supplementSuccessResolver(r *htlcSuccessResolver,
htlcMap map[wire.OutPoint]*channeldb.HTLC) error {

res := r.htlcResolution
htlcPoint := res.HtlcPoint()
htlc, ok := htlcMap[htlcPoint]
if !ok {
return errors.New(
"htlc for success resolver unavailable",
)
}
r.htlcAmt = htlc.Amt
return nil
}
c.launchResolvers(unresolvedContracts)

// supplementTimeoutResolver takes a htlcSuccessResolver as it is restored from
// the log and fills in missing data from the htlcMap.
func (c *ChannelArbitrator) supplementTimeoutResolver(r *htlcTimeoutResolver,
htlcMap map[wire.OutPoint]*channeldb.HTLC) error {

res := r.htlcResolution
htlcPoint := res.HtlcPoint()
htlc, ok := htlcMap[htlcPoint]
if !ok {
return errors.New(
"htlc for timeout resolver unavailable",
)
}
r.htlcAmt = htlc.Amt
return nil
}

Expand Down Expand Up @@ -1224,8 +1152,10 @@ func (c *ChannelArbitrator) checkCommitChainActions(height uint32,
// * race condition if adding and we broadcast, etc
// * or would make each instance sync?

log.Debugf("ChannelArbitrator(%v): checking chain actions at "+
"height=%v", c.cfg.ChanPoint, height)
log.Debugf("ChannelArbitrator(%v): checking commit chain actions at "+
"height=%v, in_htlc_count=%v, out_htlc_count=%v",
c.cfg.ChanPoint, height,
len(htlcs.incomingHTLCs), len(htlcs.outgoingHTLCs))

actionMap := make(ChainActionMap)

Expand Down Expand Up @@ -1719,6 +1649,8 @@ func (c *ChannelArbitrator) prepContractResolutions(
// claim the HTLC (second-level or directly), then add the pre
case HtlcClaimAction:
for _, htlc := range htlcs {
htlc := htlc

htlcOp := wire.OutPoint{
Hash: commitHash,
Index: uint32(htlc.OutputIndex),
Expand All @@ -1734,8 +1666,7 @@ func (c *ChannelArbitrator) prepContractResolutions(
}

resolver := newSuccessResolver(
resolution, height,
htlc.RHash, htlc.Amt, resolverCfg,
resolution, height, htlc, resolverCfg,
)
htlcResolvers = append(htlcResolvers, resolver)
}
Expand All @@ -1745,6 +1676,8 @@ func (c *ChannelArbitrator) prepContractResolutions(
// backwards.
case HtlcTimeoutAction:
for _, htlc := range htlcs {
htlc := htlc

htlcOp := wire.OutPoint{
Hash: commitHash,
Index: uint32(htlc.OutputIndex),
Expand All @@ -1758,8 +1691,7 @@ func (c *ChannelArbitrator) prepContractResolutions(
}

resolver := newTimeoutResolver(
resolution, height, htlc.HtlcIndex,
htlc.Amt, resolverCfg,
resolution, height, htlc, resolverCfg,
)
htlcResolvers = append(htlcResolvers, resolver)
}
Expand All @@ -1769,6 +1701,8 @@ func (c *ChannelArbitrator) prepContractResolutions(
// learn of the pre-image, or let the remote party time out.
case HtlcIncomingWatchAction:
for _, htlc := range htlcs {
htlc := htlc

htlcOp := wire.OutPoint{
Hash: commitHash,
Index: uint32(htlc.OutputIndex),
Expand All @@ -1785,15 +1719,9 @@ func (c *ChannelArbitrator) prepContractResolutions(
continue
}

circuitKey := channeldb.CircuitKey{
HtlcID: htlc.HtlcIndex,
ChanID: c.cfg.ShortChanID,
}

resolver := newIncomingContestResolver(
htlc.RefundTimeout, circuitKey,
resolution, height, htlc.RHash,
htlc.Amt, resolverCfg,
resolution, height, htlc,
resolverCfg,
)
htlcResolvers = append(htlcResolvers, resolver)
}
Expand All @@ -1803,6 +1731,8 @@ func (c *ChannelArbitrator) prepContractResolutions(
// backwards), or just timeout.
case HtlcOutgoingWatchAction:
for _, htlc := range htlcs {
htlc := htlc

htlcOp := wire.OutPoint{
Hash: commitHash,
Index: uint32(htlc.OutputIndex),
Expand All @@ -1817,8 +1747,7 @@ func (c *ChannelArbitrator) prepContractResolutions(
}

resolver := newOutgoingContestResolver(
resolution, height, htlc.HtlcIndex,
htlc.Amt, resolverCfg,
resolution, height, htlc, resolverCfg,
)
htlcResolvers = append(htlcResolvers, resolver)
}
Expand Down
7 changes: 4 additions & 3 deletions contractcourt/channel_arbitrator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,7 @@ func createTestChannelArbitrator(t *testing.T, log ArbitratorLog) (*chanArbTestC
incubateChan <- struct{}{}
return nil
},
OnionProcessor: &mockOnionProcessor{},
}

// We'll use the resolvedChan to synchronize on call to
Expand Down Expand Up @@ -858,10 +859,10 @@ func TestChannelArbitratorLocalForceClosePendingHtlc(t *testing.T) {
resolver)
}

// The resolver should have its htlcAmt field populated as it.
if int64(outgoingResolver.htlcAmt) != int64(htlcAmt) {
// The resolver should have its htlc amt field populated as it.
if int64(outgoingResolver.htlc.Amt) != int64(htlcAmt) {
t.Fatalf("wrong htlc amount: expected %v, got %v,",
htlcAmt, int64(outgoingResolver.htlcAmt))
htlcAmt, int64(outgoingResolver.htlc.Amt))
}

// htlcOutgoingContestResolver is now active and waiting for the HTLC to
Expand Down
15 changes: 15 additions & 0 deletions contractcourt/contract_resolvers.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ import (
"encoding/binary"
"errors"
"io"

"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/channeldb"
)

var (
Expand Down Expand Up @@ -51,6 +54,18 @@ type ContractResolver interface {
Stop()
}

// htlcContractResolver is the required interface for htlc resolvers.
type htlcContractResolver interface {
ContractResolver

// HtlcPoint returns the htlc's outpoint on the commitment tx.
HtlcPoint() wire.OutPoint

// Supplement adds additional information to the resolver that is
// required before Resolve() is called.
Supplement(htlc channeldb.HTLC)
}

// reportingContractResolver is a ContractResolver that also exposes a report on
// the resolution state of the contract.
type reportingContractResolver interface {
Expand Down
Loading