Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion channeldb/invoices.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ const (
resolveTimeType tlv.Type = 11
expiryHeightType tlv.Type = 13
htlcStateType tlv.Type = 15
mppTotalAmtType tlv.Type = 17

// A set of tlv type definitions used to serialize invoice bodiees.
//
Expand Down Expand Up @@ -289,6 +290,10 @@ type InvoiceHTLC struct {
// Amt is the amount that is carried by this htlc.
Amt lnwire.MilliSatoshi

// MppTotalAmt is a field for mpp that indicates the expected total
// amount.
MppTotalAmt lnwire.MilliSatoshi

// AcceptHeight is the block height at which the invoice registry
// decided to accept this htlc as a payment to the invoice. At this
// height, the invoice cltv delay must have been met.
Expand Down Expand Up @@ -323,6 +328,10 @@ type HtlcAcceptDesc struct {
// Amt is the amount that is carried by this htlc.
Amt lnwire.MilliSatoshi

// MppTotalAmt is a field for mpp that indicates the expected total
// amount.
MppTotalAmt lnwire.MilliSatoshi

// Expiry is the expiry height of this htlc.
Expiry uint32

Expand Down Expand Up @@ -1018,6 +1027,7 @@ func serializeHtlcs(w io.Writer, htlcs map[CircuitKey]*InvoiceHTLC) error {
// Encode the htlc in a tlv stream.
chanID := key.ChanID.ToUint64()
amt := uint64(htlc.Amt)
mppTotalAmt := uint64(htlc.MppTotalAmt)
acceptTime := uint64(htlc.AcceptTime.UnixNano())
resolveTime := uint64(htlc.ResolveTime.UnixNano())
state := uint8(htlc.State)
Expand All @@ -1034,6 +1044,7 @@ func serializeHtlcs(w io.Writer, htlcs map[CircuitKey]*InvoiceHTLC) error {
tlv.MakePrimitiveRecord(resolveTimeType, &resolveTime),
tlv.MakePrimitiveRecord(expiryHeightType, &htlc.Expiry),
tlv.MakePrimitiveRecord(htlcStateType, &state),
tlv.MakePrimitiveRecord(mppTotalAmtType, &mppTotalAmt),
)

// Convert the custom records to tlv.Record types that are ready
Expand Down Expand Up @@ -1193,7 +1204,7 @@ func deserializeHtlcs(r io.Reader) (map[CircuitKey]*InvoiceHTLC, error) {
chanID uint64
state uint8
acceptTime, resolveTime uint64
amt uint64
amt, mppTotalAmt uint64
)
tlvStream, err := tlv.NewStream(
tlv.MakePrimitiveRecord(chanIDType, &chanID),
Expand All @@ -1206,6 +1217,7 @@ func deserializeHtlcs(r io.Reader) (map[CircuitKey]*InvoiceHTLC, error) {
tlv.MakePrimitiveRecord(resolveTimeType, &resolveTime),
tlv.MakePrimitiveRecord(expiryHeightType, &htlc.Expiry),
tlv.MakePrimitiveRecord(htlcStateType, &state),
tlv.MakePrimitiveRecord(mppTotalAmtType, &mppTotalAmt),
)
if err != nil {
return nil, err
Expand All @@ -1221,6 +1233,7 @@ func deserializeHtlcs(r io.Reader) (map[CircuitKey]*InvoiceHTLC, error) {
htlc.ResolveTime = time.Unix(0, int64(resolveTime))
htlc.State = HtlcState(state)
htlc.Amt = lnwire.MilliSatoshi(amt)
htlc.MppTotalAmt = lnwire.MilliSatoshi(mppTotalAmt)
Comment thread
halseth marked this conversation as resolved.
Outdated

// Reconstruct the custom records fields from the parsed types
// map return from the tlv parser.
Expand Down Expand Up @@ -1324,6 +1337,7 @@ func (d *DB) updateInvoice(hash lntypes.Hash, invoices, settleIndex *bbolt.Bucke

htlc := &InvoiceHTLC{
Amt: htlcUpdate.Amt,
MppTotalAmt: htlcUpdate.MppTotalAmt,
Expiry: htlcUpdate.Expiry,
AcceptHeight: uint32(htlcUpdate.AcceptHeight),
AcceptTime: now,
Expand Down
9 changes: 6 additions & 3 deletions htlcswitch/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -790,9 +790,12 @@ func newMockRegistry(minDelta uint32) *mockInvoiceRegistry {
panic(err)
}

finalCltvRejectDelta := int32(5)

registry := invoices.NewRegistry(cdb, finalCltvRejectDelta)
registry := invoices.NewRegistry(
cdb,
&invoices.RegistryConfig{
FinalCltvRejectDelta: 5,
},
)
registry.Start()

return &mockInvoiceRegistry{
Expand Down
77 changes: 77 additions & 0 deletions invoices/clock_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package invoices
Comment thread
joostjager marked this conversation as resolved.
Outdated

import (
"sync"
"time"
)

// testClock can be used in tests to mock time.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Really nice! May I suggest you to extend clock.Clock interface and use that as a common ground instead? Maybe also move this test clock in that package such that others can use it?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't want to delay this PR with a discussion on where to put that package/file. Given the time those took in past prs I'd rather avoid it (reference class forecasting).

type testClock struct {
currentTime time.Time
timeChanMap map[time.Time][]chan time.Time
timeLock sync.Mutex
}

// newTestClock returns a new test clock.
func newTestClock(startTime time.Time) *testClock {
return &testClock{
currentTime: startTime,
timeChanMap: make(map[time.Time][]chan time.Time),
}
}

// now returns the current (test) time.
func (c *testClock) now() time.Time {
c.timeLock.Lock()
defer c.timeLock.Unlock()

return c.currentTime
}

// tickAfter returns a channel that will receive a tick at the specified time.
func (c *testClock) tickAfter(duration time.Duration) <-chan time.Time {
c.timeLock.Lock()
defer c.timeLock.Unlock()

triggerTime := c.currentTime.Add(duration)
log.Debugf("tickAfter called: duration=%v, trigger_time=%v",
duration, triggerTime)

ch := make(chan time.Time, 1)

// If already expired, tick immediately.
if !triggerTime.After(c.currentTime) {
ch <- c.currentTime
return ch
}

// Otherwise store the channel until the trigger time is there.
chans := c.timeChanMap[triggerTime]
chans = append(chans, ch)
c.timeChanMap[triggerTime] = chans

return ch
}

// setTime sets the (test) time and triggers tick channels when they expire.
func (c *testClock) setTime(now time.Time) {
c.timeLock.Lock()
defer c.timeLock.Unlock()

c.currentTime = now
remainingChans := make(map[time.Time][]chan time.Time)
for triggerTime, chans := range c.timeChanMap {
// If the trigger time is still in the future, keep this channel
// in the channel map for later.
if triggerTime.After(now) {
remainingChans[triggerTime] = chans
continue
}

for _, c := range chans {
c <- now
}
}

c.timeChanMap = remainingChans
}
Loading