diff --git a/discovery/chan_series.go b/discovery/chan_series.go index 4a9a51914e..8fa460a852 100644 --- a/discovery/chan_series.go +++ b/discovery/chan_series.go @@ -6,6 +6,7 @@ import ( "time" "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/lightningnetwork/lnd/fn/v2" graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/netann" @@ -115,7 +116,10 @@ func (c *ChanSeries) UpdatesInHorizon(chain chainhash.Hash, // First, we'll query for all the set of channels that have an // update that falls within the specified horizon. chansInHorizon := c.graph.ChanUpdatesInHorizon( - context.TODO(), startTime, endTime, + context.TODO(), graphdb.ChanUpdateRange{ + StartTime: fn.Some(startTime), + EndTime: fn.Some(endTime), + }, ) for channel, err := range chansInHorizon { @@ -181,7 +185,10 @@ func (c *ChanSeries) UpdatesInHorizon(chain chainhash.Hash, // update within the horizon as well. We send these second to // ensure that they follow any active channels they have. nodeAnnsInHorizon := c.graph.NodeUpdatesInHorizon( - context.TODO(), startTime, endTime, + context.TODO(), graphdb.NodeUpdateRange{ + StartTime: fn.Some(startTime), + EndTime: fn.Some(endTime), + }, graphdb.WithIterPublicNodesOnly(), ) for nodeAnn, err := range nodeAnnsInHorizon { diff --git a/docs/release-notes/release-notes-0.21.0.md b/docs/release-notes/release-notes-0.21.0.md index cb5fa1a3f9..0fbfe2d12b 100644 --- a/docs/release-notes/release-notes-0.21.0.md +++ b/docs/release-notes/release-notes-0.21.0.md @@ -267,6 +267,12 @@ [4](https://github.com/lightningnetwork/lnd/pull/10542), [5](https://github.com/lightningnetwork/lnd/pull/10572), [6](https://github.com/lightningnetwork/lnd/pull/10582). +* Make the [graph `Store` interface + cross-version](https://github.com/lightningnetwork/lnd/pull/10656) so that + query methods (`ForEachNode`, `ForEachChannel`, `NodeUpdatesInHorizon`, + `ChanUpdatesInHorizon`, `FilterKnownChanIDs`) work across gossip v1 and v2. + Add `PreferHighest` fetch helpers and `GetVersions` queries so callers can + retrieve channels without knowing which gossip version announced them. * Updated waiting proof persistence for gossip upgrades by introducing typed waiting proof keys and payloads, with a DB migration to rewrite legacy waiting proof records to the new key/value format diff --git a/graph/builder.go b/graph/builder.go index 614c1102fc..31d2199e6a 100644 --- a/graph/builder.go +++ b/graph/builder.go @@ -12,6 +12,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/batch" "github.com/lightningnetwork/lnd/chainntnfs" + "github.com/lightningnetwork/lnd/fn/v2" graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/lnutils" @@ -648,7 +649,11 @@ func (b *Builder) pruneZombieChans() error { startTime := time.Unix(0, 0) endTime := time.Now().Add(-1 * chanExpiry) oldEdgesIter := b.cfg.Graph.ChanUpdatesInHorizon( - context.TODO(), startTime, endTime, + context.TODO(), lnwire.GossipVersion1, + graphdb.ChanUpdateRange{ + StartTime: fn.Some(startTime), + EndTime: fn.Some(endTime), + }, ) for u, err := range oldEdgesIter { diff --git a/graph/db/benchmark_test.go b/graph/db/benchmark_test.go index 19d6a134a1..062d4f5239 100644 --- a/graph/db/benchmark_test.go +++ b/graph/db/benchmark_test.go @@ -372,7 +372,7 @@ func TestPopulateDBs(t *testing.T) { numPolicies = 0 ) err := graph.ForEachChannel( - ctx, lnwire.GossipVersion1, + ctx, func(info *models.ChannelEdgeInfo, policy, policy2 *models.ChannelEdgePolicy) error { @@ -500,7 +500,7 @@ func syncGraph(t *testing.T, src, dest *ChannelGraph) { } var wgChans sync.WaitGroup - err = src.ForEachChannel(ctx, lnwire.GossipVersion1, + err = src.ForEachChannel(ctx, func(info *models.ChannelEdgeInfo, policy1, policy2 *models.ChannelEdgePolicy) error { @@ -624,7 +624,7 @@ func BenchmarkGraphReadMethods(b *testing.B) { name: "ForEachNode", fn: func(b testing.TB, store Store) { err := store.ForEachNode( - ctx, lnwire.GossipVersion1, + ctx, func(_ *models.Node) error { // Increment the counter to // ensure the callback is doing @@ -640,12 +640,11 @@ func BenchmarkGraphReadMethods(b *testing.B) { { name: "ForEachChannel", fn: func(b testing.TB, store Store) { - //nolint:ll - err := store.ForEachChannel( - ctx, lnwire.GossipVersion1, + err := store.ForEachChannel(ctx, func(_ *models.ChannelEdgeInfo, + _, _ *models.ChannelEdgePolicy, - _ *models.ChannelEdgePolicy) error { + ) error { // Increment the counter to // ensure the callback is doing @@ -662,7 +661,13 @@ func BenchmarkGraphReadMethods(b *testing.B) { name: "NodeUpdatesInHorizon", fn: func(b testing.TB, store Store) { iter := store.NodeUpdatesInHorizon( - ctx, time.Unix(0, 0), time.Now(), + ctx, lnwire.GossipVersion1, + NodeUpdateRange{ + StartTime: fn.Some( + time.Unix(0, 0), + ), + EndTime: fn.Some(time.Now()), + }, ) _, err := fn.CollectErr(iter) require.NoError(b, err) @@ -713,7 +718,13 @@ func BenchmarkGraphReadMethods(b *testing.B) { name: "ChanUpdatesInHorizon", fn: func(b testing.TB, store Store) { iter := store.ChanUpdatesInHorizon( - ctx, time.Unix(0, 0), time.Now(), + ctx, lnwire.GossipVersion1, + ChanUpdateRange{ + StartTime: fn.Some( + time.Unix(0, 0), + ), + EndTime: fn.Some(time.Now()), + }, ) _, err := fn.CollectErr(iter) require.NoError(b, err) @@ -817,7 +828,7 @@ func BenchmarkFindOptimalSQLQueryConfig(b *testing.B) { ) err := store.ForEachNode( - ctx, lnwire.GossipVersion1, + ctx, func(_ *models.Node) error { numNodes++ @@ -828,7 +839,7 @@ func BenchmarkFindOptimalSQLQueryConfig(b *testing.B) { //nolint:ll err = store.ForEachChannel( - ctx, lnwire.GossipVersion1, + ctx, func(_ *models.ChannelEdgeInfo, _, _ *models.ChannelEdgePolicy) error { diff --git a/graph/db/graph.go b/graph/db/graph.go index 507312f656..36a12a2179 100644 --- a/graph/db/graph.go +++ b/graph/db/graph.go @@ -63,6 +63,14 @@ type ChannelGraph struct { cancel fn.Option[context.CancelFunc] } +// preferHighestNodeDirectedChanneler is implemented by stores that can stream +// cross-version node-directed channel traversals directly. +type preferHighestNodeDirectedChanneler interface { + ForEachNodeDirectedChannelPreferHighest(ctx context.Context, + node route.Vertex, cb func(channel *DirectedChannel) error, + reset func()) error +} + // NewChannelGraph creates a new ChannelGraph instance with the given backend. func NewChannelGraph(v1Store Store, options ...ChanGraphOption) (*ChannelGraph, error) { @@ -238,8 +246,9 @@ func (c *ChannelGraph) populateCache(ctx context.Context) error { for _, v := range []lnwire.GossipVersion{ gossipV1, gossipV2, } { - // TODO(elle): If we have both v1 and v2 entries for the same - // node/channel, prefer v2 when merging. + // We iterate v1 first, then v2. Since AddNodeFeatures and + // AddChannel overwrite on key collision, v2 data naturally + // takes precedence when both versions exist. err := c.db.ForEachNodeCacheable(ctx, v, func(node route.Vertex, features *lnwire.FeatureVector) error { @@ -299,12 +308,59 @@ func (c *ChannelGraph) ForEachNodeDirectedChannel(ctx context.Context, return c.cache.graphCache.ForEachChannel(node, cb) } - // TODO(elle): once the no-cache path needs to support - // pathfinding across gossip versions, this should iterate - // across all versions rather than defaulting to v1. - return c.db.ForEachNodeDirectedChannel( - ctx, gossipV1, node, cb, reset, - ) + if db, ok := c.db.(preferHighestNodeDirectedChanneler); ok { + return db.ForEachNodeDirectedChannelPreferHighest( + ctx, node, cb, reset, + ) + } + + // Iterate across all gossip versions (highest first) so that + // channels announced via v2 are preferred over v1. We buffer + // results and deliver them at the end so that ExecTx retries + // within a single version don't corrupt the caller's state or + // lose channels from already-committed versions. + seen := make(map[uint64]struct{}) + var allChannels []*DirectedChannel + + for _, v := range []lnwire.GossipVersion{gossipV2, gossipV1} { + prevLen := len(allChannels) + err := c.db.ForEachNodeDirectedChannel( + ctx, v, node, + func(channel *DirectedChannel) error { + if _, ok := seen[channel.ChannelID]; ok { + return nil + } + seen[channel.ChannelID] = struct{}{} + + allChannels = append(allChannels, channel) + + return nil + }, + func() { + // On ExecTx retry, undo this version's + // additions while keeping channels from + // earlier (already committed) versions. + for _, ch := range allChannels[prevLen:] { + delete(seen, ch.ChannelID) + } + allChannels = allChannels[:prevLen] + }, + ) + if err != nil && + !errors.Is(err, ErrVersionNotSupportedForKVDB) { + + return err + } + } + + // Deliver the collected channels to the caller. + for _, ch := range allChannels { + if err := cb(ch); err != nil { + return err + } + } + + return nil } // FetchNodeFeatures returns the features of the given node. If no features are @@ -320,7 +376,22 @@ func (c *ChannelGraph) FetchNodeFeatures(ctx context.Context, return c.cache.graphCache.GetFeatures(node), nil } - return c.db.FetchNodeFeatures(ctx, lnwire.GossipVersion1, node) + // Try v2 first, fall back to v1 if the v2 features are empty. + for _, v := range []lnwire.GossipVersion{gossipV2, gossipV1} { + features, err := c.db.FetchNodeFeatures(ctx, v, node) + if errors.Is(err, ErrVersionNotSupportedForKVDB) { + continue + } + if err != nil { + return nil, err + } + + if !features.IsEmpty() { + return features, nil + } + } + + return lnwire.EmptyFeatureVector(), nil } // GraphSession will provide the call-back with access to a NodeTraverser @@ -600,56 +671,6 @@ func (c *ChannelGraph) PruneGraphNodes(ctx context.Context) error { return nil } -// FilterKnownChanIDs takes a set of channel IDs and return the subset of chan -// ID's that we don't know and are not known zombies of the passed set. In other -// words, we perform a set difference of our set of chan ID's and the ones -// passed in. This method can be used by callers to determine the set of -// channels another peer knows of that we don't. -func (c *ChannelGraph) FilterKnownChanIDs(ctx context.Context, - chansInfo []ChannelUpdateInfo, - isZombieChan func(ChannelUpdateInfo) bool) ([]uint64, error) { - - unknown, knownZombies, err := c.db.FilterKnownChanIDs(ctx, chansInfo) - if err != nil { - return nil, err - } - - for _, info := range knownZombies { - // TODO(ziggie): Make sure that for the strict pruning case we - // compare the pubkeys and whether the right timestamp is not - // older than the `ChannelPruneExpiry`. - // - // NOTE: The timestamp data has no verification attached to it - // in the `ReplyChannelRange` msg so we are trusting this data - // at this point. However it is not critical because we are just - // removing the channel from the db when the timestamps are more - // recent. During the querying of the gossip msg verification - // happens as usual. However we should start punishing peers - // when they don't provide us honest data ? - if isZombieChan(info) { - continue - } - - // If we have marked it as a zombie but the latest update - // info could bring it back from the dead, then we mark it - // alive, and we let it be added to the set of IDs to query our - // peer for. - err := c.db.MarkEdgeLive( - ctx, info.Version, - info.ShortChannelID.ToUint64(), - ) - // Since there is a chance that the edge could have been marked - // as "live" between the FilterKnownChanIDs call and the - // MarkEdgeLive call, we ignore the error if the edge is already - // marked as live. - if err != nil && !errors.Is(err, ErrZombieEdgeNotFound) { - return nil, err - } - } - - return unknown, nil -} - // MarkEdgeZombie attempts to mark a channel identified by its channel ID as a // zombie for the given gossip version. This method is used on an ad-hoc basis, // when channels need to be marked as zombies outside the normal pruning cycle. @@ -722,12 +743,12 @@ func (c *ChannelGraph) ForEachNodeCacheable(ctx context.Context, } // NodeUpdatesInHorizon returns all known lightning nodes with updates in the -// range. +// range for the given gossip version. func (c *ChannelGraph) NodeUpdatesInHorizon(ctx context.Context, - startTime, endTime time.Time, + v lnwire.GossipVersion, r NodeUpdateRange, opts ...IteratorOption) iter.Seq2[*models.Node, error] { - return c.db.NodeUpdatesInHorizon(ctx, startTime, endTime, opts...) + return c.db.NodeUpdatesInHorizon(ctx, v, r, opts...) } // HasV1Node determines if the graph has a vertex identified by the target node @@ -738,13 +759,14 @@ func (c *ChannelGraph) HasV1Node(ctx context.Context, return c.db.HasV1Node(ctx, nodePub) } -// ForEachChannel iterates through all channel edges stored within the graph. +// ForEachChannel iterates through all channel edges stored within the graph +// across all gossip versions. func (c *ChannelGraph) ForEachChannel(ctx context.Context, - v lnwire.GossipVersion, cb func(*models.ChannelEdgeInfo, - *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error, + cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy, + *models.ChannelEdgePolicy) error, reset func()) error { - return c.db.ForEachChannel(ctx, v, cb, reset) + return c.db.ForEachChannel(ctx, cb, reset) } // DisabledChannelIDs returns the channel ids of disabled channels. @@ -784,12 +806,12 @@ func (c *ChannelGraph) HighestChanID(ctx context.Context, } // ChanUpdatesInHorizon returns all known channel edges with updates in the -// horizon. +// range for the given gossip version. func (c *ChannelGraph) ChanUpdatesInHorizon(ctx context.Context, - startTime, endTime time.Time, + v lnwire.GossipVersion, r ChanUpdateRange, opts ...IteratorOption) iter.Seq2[ChannelEdge, error] { - return c.db.ChanUpdatesInHorizon(ctx, startTime, endTime, opts...) + return c.db.ChanUpdatesInHorizon(ctx, v, r, opts...) } // FilterChannelRange returns channel IDs within the passed block height range @@ -822,26 +844,39 @@ func (c *ChannelGraph) FetchChanInfos(ctx context.Context, } // FetchChannelEdgesByOutpoint attempts to lookup directed edges by funding -// outpoint. +// outpoint, returning the highest available gossip version. func (c *ChannelGraph) FetchChannelEdgesByOutpoint(ctx context.Context, op *wire.OutPoint) ( *models.ChannelEdgeInfo, *models.ChannelEdgePolicy, *models.ChannelEdgePolicy, error) { - return c.db.FetchChannelEdgesByOutpoint( - ctx, lnwire.GossipVersion1, op, - ) + return c.db.FetchChannelEdgesByOutpointPreferHighest(ctx, op) } -// FetchChannelEdgesByID attempts to lookup directed edges by channel ID. +// FetchChannelEdgesByID attempts to lookup directed edges by channel ID, +// returning the highest available gossip version. func (c *ChannelGraph) FetchChannelEdgesByID(ctx context.Context, chanID uint64) ( *models.ChannelEdgeInfo, *models.ChannelEdgePolicy, *models.ChannelEdgePolicy, error) { - return c.db.FetchChannelEdgesByID( - ctx, lnwire.GossipVersion1, chanID, - ) + return c.db.FetchChannelEdgesByIDPreferHighest(ctx, chanID) +} + +// GetVersionsBySCID returns the list of gossip versions for which a channel +// with the given SCID exists in the database. +func (c *ChannelGraph) GetVersionsBySCID(ctx context.Context, + chanID uint64) ([]lnwire.GossipVersion, error) { + + return c.db.GetVersionsBySCID(ctx, chanID) +} + +// GetVersionsByOutpoint returns the list of gossip versions for which a channel +// with the given funding outpoint exists in the database. +func (c *ChannelGraph) GetVersionsByOutpoint(ctx context.Context, + op *wire.OutPoint) ([]lnwire.GossipVersion, error) { + + return c.db.GetVersionsByOutpoint(ctx, op) } // PutClosedScid stores a SCID for a closed channel in the database. @@ -936,7 +971,7 @@ func (c *VersionedGraph) ForEachNodeCached(ctx context.Context, func (c *VersionedGraph) ForEachNode(ctx context.Context, cb func(*models.Node) error, reset func()) error { - return c.db.ForEachNode(ctx, c.v, cb, reset) + return c.db.ForEachNode(ctx, cb, reset) } // NumZombies returns the current number of zombie channels in the graph. @@ -944,13 +979,74 @@ func (c *VersionedGraph) NumZombies(ctx context.Context) (uint64, error) { return c.db.NumZombies(ctx, c.v) } -// NodeUpdatesInHorizon returns all known lightning nodes which have an update -// timestamp within the passed range. +// NodeUpdatesInHorizon returns all known lightning nodes which have updates +// within the passed range. func (c *VersionedGraph) NodeUpdatesInHorizon(ctx context.Context, - startTime, endTime time.Time, + r NodeUpdateRange, opts ...IteratorOption) iter.Seq2[*models.Node, error] { - return c.db.NodeUpdatesInHorizon(ctx, startTime, endTime, opts...) + return c.db.NodeUpdatesInHorizon(ctx, c.v, r, opts...) +} + +// FilterKnownChanIDs takes a set of channel IDs and returns the subset of chan +// ID's that we don't know and are not known zombies of the passed set. In +// other words, we perform a set difference of our set of chan ID's and the ones +// passed in. This method can be used by callers to determine the set of +// channels another peer knows of that we don't. +func (c *VersionedGraph) FilterKnownChanIDs(ctx context.Context, + chansInfo []ChannelUpdateInfo, + isZombieChan func(ChannelUpdateInfo) bool) ([]uint64, error) { + + unknown, knownZombies, err := c.db.FilterKnownChanIDs( + ctx, c.v, chansInfo, + ) + if err != nil { + return nil, err + } + + for _, info := range knownZombies { + // TODO(ziggie): Make sure that for the strict pruning case we + // compare the pubkeys and whether the right timestamp is not + // older than the `ChannelPruneExpiry`. + // + // NOTE: The timestamp data has no verification attached to it + // in the `ReplyChannelRange` msg so we are trusting this data + // at this point. However it is not critical because we are just + // removing the channel from the db when the timestamps are more + // recent. During the querying of the gossip msg verification + // happens as usual. However we should start punishing peers + // when they don't provide us honest data ? + if isZombieChan(info) { + continue + } + + // If we have marked it as a zombie but the latest update + // info could bring it back from the dead, then we mark it + // alive, and we let it be added to the set of IDs to query + // our peer for. + err := c.db.MarkEdgeLive( + ctx, info.Version, + info.ShortChannelID.ToUint64(), + ) + // Since there is a chance that the edge could have been + // marked as "live" between the FilterKnownChanIDs call and + // the MarkEdgeLive call, we ignore the error if the edge is + // already marked as live. + if err != nil && !errors.Is(err, ErrZombieEdgeNotFound) { + return nil, err + } + } + + return unknown, nil +} + +// ChanUpdatesInHorizon returns all known channel edges with updates in the +// range. +func (c *VersionedGraph) ChanUpdatesInHorizon(ctx context.Context, + r ChanUpdateRange, + opts ...IteratorOption) iter.Seq2[ChannelEdge, error] { + + return c.db.ChanUpdatesInHorizon(ctx, c.v, r, opts...) } // ChannelView returns the verifiable edge information for each active channel. @@ -1107,7 +1203,7 @@ func (c *VersionedGraph) ForEachChannel(ctx context.Context, cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error, reset func()) error { - return c.db.ForEachChannel(ctx, c.v, cb, reset) + return c.db.ForEachChannel(ctx, cb, reset) } // ForEachNodeCacheable iterates through all stored vertices/nodes in the graph. diff --git a/graph/db/graph_test.go b/graph/db/graph_test.go index 46ac8d0eca..b8052ccc4c 100644 --- a/graph/db/graph_test.go +++ b/graph/db/graph_test.go @@ -1799,7 +1799,7 @@ func TestGraphTraversal(t *testing.T) { // Iterate through all the known channels within the graph DB, once // again if the map is empty that indicates that all edges have // properly been reached. - err = graph.ForEachChannel(ctx, lnwire.GossipVersion1, + err = graph.ForEachChannel(ctx, func(ei *models.ChannelEdgeInfo, _ *models.ChannelEdgePolicy, _ *models.ChannelEdgePolicy) error { @@ -2154,7 +2154,7 @@ func assertPruneTip(t *testing.T, graph *ChannelGraph, func assertNumChans(t *testing.T, graph *ChannelGraph, n int) { numChans := 0 err := graph.ForEachChannel( - t.Context(), lnwire.GossipVersion1, + t.Context(), func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error { @@ -2435,7 +2435,10 @@ func TestChanUpdatesInHorizon(t *testing.T) { // If we issue an arbitrary query before any channel updates are // inserted in the database, we should get zero results. chanIter := graph.ChanUpdatesInHorizon( - ctx, time.Unix(999, 0), time.Unix(9999, 0), + ctx, lnwire.GossipVersion1, ChanUpdateRange{ + StartTime: fn.Some(time.Unix(999, 0)), + EndTime: fn.Some(time.Unix(9999, 0)), + }, ) chanUpdates, err := fn.CollectErr(chanIter) @@ -2542,7 +2545,10 @@ func TestChanUpdatesInHorizon(t *testing.T) { } for _, queryCase := range queryCases { respIter := graph.ChanUpdatesInHorizon( - ctx, queryCase.start, queryCase.end, + ctx, lnwire.GossipVersion1, ChanUpdateRange{ + StartTime: fn.Some(queryCase.start), + EndTime: fn.Some(queryCase.end), + }, ) resp, err := fn.CollectErr(respIter) @@ -2579,7 +2585,10 @@ func TestNodeUpdatesInHorizon(t *testing.T) { // If we issue an arbitrary query before we insert any nodes into the // database, then we shouldn't get any results back. nodeUpdatesIter := graph.NodeUpdatesInHorizon( - ctx, time.Unix(999, 0), time.Unix(9999, 0), + ctx, lnwire.GossipVersion1, NodeUpdateRange{ + StartTime: fn.Some(time.Unix(999, 0)), + EndTime: fn.Some(time.Unix(9999, 0)), + }, ) nodeUpdates, err := fn.CollectErr(nodeUpdatesIter) require.NoError(t, err, "unable to query for node updates") @@ -2654,7 +2663,10 @@ func TestNodeUpdatesInHorizon(t *testing.T) { } for _, queryCase := range queryCases { iter := graph.NodeUpdatesInHorizon( - ctx, queryCase.start, queryCase.end, + ctx, lnwire.GossipVersion1, NodeUpdateRange{ + StartTime: fn.Some(queryCase.start), + EndTime: fn.Some(queryCase.end), + }, ) resp, err := fn.CollectErr(iter) @@ -2782,7 +2794,11 @@ func testNodeUpdatesWithBatchSize(t *testing.T, ctx context.Context, for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { iter := testGraph.NodeUpdatesInHorizon( - ctx, tc.start, tc.end, + ctx, lnwire.GossipVersion1, + NodeUpdateRange{ + StartTime: fn.Some(tc.start), + EndTime: fn.Some(tc.end), + }, WithNodeUpdateIterBatchSize( batchSize, ), @@ -2855,7 +2871,13 @@ func TestNodeUpdatesInHorizonEarlyTermination(t *testing.T) { for _, stopAt := range terminationPoints { t.Run(fmt.Sprintf("StopAt%d", stopAt), func(t *testing.T) { iter := graph.NodeUpdatesInHorizon( - ctx, startTime, startTime.Add(200*time.Hour), + ctx, lnwire.GossipVersion1, + NodeUpdateRange{ + StartTime: fn.Some(startTime), + EndTime: fn.Some( + startTime.Add(200 * time.Hour), + ), + }, WithNodeUpdateIterBatchSize(10), ) @@ -2944,7 +2966,12 @@ func TestChanUpdatesInHorizonBoundaryConditions(t *testing.T) { // Now we'll run the main query, and verify that we get // back the expected number of channels. iter := graph.ChanUpdatesInHorizon( - ctx, startTime, startTime.Add(26*time.Hour), + ctx, lnwire.GossipVersion1, ChanUpdateRange{ + StartTime: fn.Some(startTime), + EndTime: fn.Some( + startTime.Add(26 * time.Hour), + ), + }, WithChanUpdateIterBatchSize(batchSize), ) @@ -2970,7 +2997,9 @@ func TestFilterKnownChanIDsZombieRevival(t *testing.T) { t.Parallel() ctx := t.Context() - graph := MakeTestGraph(t) + graph := NewVersionedGraph( + MakeTestGraph(t), lnwire.GossipVersion1, + ) var ( scid1 = lnwire.ShortChannelID{BlockHeight: 1} @@ -2978,9 +3007,8 @@ func TestFilterKnownChanIDsZombieRevival(t *testing.T) { scid3 = lnwire.ShortChannelID{BlockHeight: 3} ) - v1Graph := NewVersionedGraph(graph, lnwire.GossipVersion1) isZombie := func(scid lnwire.ShortChannelID) bool { - zombie, _, _, err := v1Graph.IsZombieEdge(ctx, scid.ToUint64()) + zombie, _, _, err := graph.IsZombieEdge(ctx, scid.ToUint64()) require.NoError(t, err) return zombie @@ -3004,13 +3032,15 @@ func TestFilterKnownChanIDsZombieRevival(t *testing.T) { // Call FilterKnownChanIDs with an isStillZombie call-back that would // result in the current zombies still be considered as zombies. - _, err = graph.FilterKnownChanIDs(ctx, []ChannelUpdateInfo{ - {ShortChannelID: scid1, Version: lnwire.GossipVersion1}, - {ShortChannelID: scid2, Version: lnwire.GossipVersion1}, - {ShortChannelID: scid3, Version: lnwire.GossipVersion1}, - }, func(_ ChannelUpdateInfo) bool { - return true - }) + _, err = graph.FilterKnownChanIDs( + ctx, []ChannelUpdateInfo{ + {ShortChannelID: scid1, Version: lnwire.GossipVersion1}, + {ShortChannelID: scid2, Version: lnwire.GossipVersion1}, + {ShortChannelID: scid3, Version: lnwire.GossipVersion1}, + }, func(_ ChannelUpdateInfo) bool { + return true + }, + ) require.NoError(t, err) require.True(t, isZombie(scid1)) @@ -3020,17 +3050,19 @@ func TestFilterKnownChanIDsZombieRevival(t *testing.T) { // Now call it again but this time with a isStillZombie call-back that // would result in channel with SCID 2 no longer being considered a // zombie. - _, err = graph.FilterKnownChanIDs(ctx, []ChannelUpdateInfo{ - {ShortChannelID: scid1, Version: lnwire.GossipVersion1}, - { - ShortChannelID: scid2, - Version: lnwire.GossipVersion1, - Node1Freshness: lnwire.UnixTimestamp(1000), + _, err = graph.FilterKnownChanIDs( + ctx, []ChannelUpdateInfo{ + {ShortChannelID: scid1, Version: lnwire.GossipVersion1}, + { + ShortChannelID: scid2, + Version: lnwire.GossipVersion1, + Node1Freshness: lnwire.UnixTimestamp(1000), + }, + {ShortChannelID: scid3, Version: lnwire.GossipVersion1}, + }, func(info ChannelUpdateInfo) bool { + return info.Node1Freshness != lnwire.UnixTimestamp(1000) }, - {ShortChannelID: scid3, Version: lnwire.GossipVersion1}, - }, func(info ChannelUpdateInfo) bool { - return info.Node1Freshness != lnwire.UnixTimestamp(1000) - }) + ) require.NoError(t, err) // Show that SCID 2 has been marked as live. @@ -3046,7 +3078,9 @@ func TestFilterKnownChanIDs(t *testing.T) { t.Parallel() ctx := t.Context() - graph := MakeTestGraph(t) + graph := NewVersionedGraph( + MakeTestGraph(t), lnwire.GossipVersion1, + ) isZombieUpdate := func(_ ChannelUpdateInfo) bool { return true @@ -3106,8 +3140,7 @@ func TestFilterKnownChanIDs(t *testing.T) { ) require.NoError(t, graph.AddChannelEdge(ctx, channel)) err := graph.DeleteChannelEdges( - ctx, lnwire.GossipVersion1, false, true, - channel.ChannelID, + ctx, false, true, channel.ChannelID, ) require.NoError(t, err) @@ -3372,7 +3405,10 @@ func TestStressTestChannelGraphAPI(t *testing.T) { chanIDs = append(chanIDs, info) } - _, err := graph.FilterKnownChanIDs( + vg := NewVersionedGraph( + graph, lnwire.GossipVersion1, + ) + _, err := vg.FilterKnownChanIDs( ctx, chanIDs, func(_ ChannelUpdateInfo) bool { return rand.Intn(2) == 0 @@ -3422,8 +3458,17 @@ func TestStressTestChannelGraphAPI(t *testing.T) { name: "ChanUpdateInHorizon", fn: func() error { iter := graph.ChanUpdatesInHorizon( - ctx, time.Now().Add(-time.Hour), - time.Now(), + ctx, lnwire.GossipVersion1, + ChanUpdateRange{ + StartTime: fn.Some( + time.Now().Add( + -time.Hour, + ), + ), + EndTime: fn.Some( + time.Now(), + ), + }, ) _, err := fn.CollectErr(iter) @@ -4219,7 +4264,10 @@ func TestNodePruningUpdateIndexDeletion(t *testing.T) { startTime := time.Unix(9, 0) endTime := node1.LastUpdate.Add(time.Minute) nodesInHorizonIter := graph.NodeUpdatesInHorizon( - ctx, startTime, endTime, + ctx, NodeUpdateRange{ + StartTime: fn.Some(startTime), + EndTime: fn.Some(endTime), + }, ) // We should only have a single node, and that node should exactly @@ -4237,7 +4285,10 @@ func TestNodePruningUpdateIndexDeletion(t *testing.T) { // Now that the node has been deleted, we'll again query the nodes in // the horizon. This time we should have no nodes at all. nodesInHorizonIter = graph.NodeUpdatesInHorizon( - ctx, startTime, endTime, + ctx, NodeUpdateRange{ + StartTime: fn.Some(startTime), + EndTime: fn.Some(endTime), + }, ) nodesInHorizon, err = fn.CollectErr(nodesInHorizonIter) require.NoError(t, err, "unable to fetch nodes in horizon") @@ -5786,3 +5837,879 @@ func TestLightningNodePersistence(t *testing.T) { require.Equal(t, nodeAnnBytes, b.Bytes()) } + +// TestUpdateRangeValidateForVersion verifies that ChanUpdateRange and +// NodeUpdateRange reject invalid field combinations for each gossip version. +func TestUpdateRangeValidateForVersion(t *testing.T) { + t.Parallel() + + now := time.Now() + + tests := []struct { + name string + fn func() error + wantErr string + }{ + { + name: "v1 chan range with time - ok", + fn: func() error { + r := ChanUpdateRange{ + StartTime: fn.Some(now), + EndTime: fn.Some(now), + } + + return r.validateForVersion( + lnwire.GossipVersion1, + ) + }, + }, + { + name: "v1 chan range with height - rejected", + fn: func() error { + r := ChanUpdateRange{ + StartHeight: fn.Some(uint32(1)), + EndHeight: fn.Some(uint32(100)), + } + + return r.validateForVersion( + lnwire.GossipVersion1, + ) + }, + wantErr: "v1 chan update range must use time", + }, + { + name: "v2 chan range with height - ok", + fn: func() error { + r := ChanUpdateRange{ + StartHeight: fn.Some(uint32(1)), + EndHeight: fn.Some(uint32(100)), + } + + return r.validateForVersion( + lnwire.GossipVersion2, + ) + }, + }, + { + name: "v2 chan range with time - rejected", + fn: func() error { + r := ChanUpdateRange{ + StartTime: fn.Some(now), + EndTime: fn.Some(now), + } + + return r.validateForVersion( + lnwire.GossipVersion2, + ) + }, + wantErr: "v2 chan update range must use blocks", + }, + { + name: "mixed chan range - rejected", + fn: func() error { + r := ChanUpdateRange{ + StartTime: fn.Some(now), + StartHeight: fn.Some(uint32(1)), + } + + return r.validateForVersion( + lnwire.GossipVersion1, + ) + }, + wantErr: "both time and block", + }, + { + name: "v1 node range with time - ok", + fn: func() error { + r := NodeUpdateRange{ + StartTime: fn.Some(now), + EndTime: fn.Some(now), + } + + return r.validateForVersion( + lnwire.GossipVersion1, + ) + }, + }, + { + name: "v2 node range with height - ok", + fn: func() error { + r := NodeUpdateRange{ + StartHeight: fn.Some(uint32(1)), + EndHeight: fn.Some(uint32(100)), + } + + return r.validateForVersion( + lnwire.GossipVersion2, + ) + }, + }, + { + name: "v2 node range with time - rejected", + fn: func() error { + r := NodeUpdateRange{ + StartTime: fn.Some(now), + EndTime: fn.Some(now), + } + + return r.validateForVersion( + lnwire.GossipVersion2, + ) + }, + wantErr: "v2 node update range must use height", + }, + { + name: "v1 chan range missing bounds - rejected", + fn: func() error { + r := ChanUpdateRange{ + StartTime: fn.Some(now), + } + + return r.validateForVersion( + lnwire.GossipVersion1, + ) + }, + wantErr: "missing time bounds", + }, + { + name: "v1 chan range inverted - rejected", + fn: func() error { + r := ChanUpdateRange{ + StartTime: fn.Some(now.Add(time.Hour)), + EndTime: fn.Some(now), + } + + return r.validateForVersion( + lnwire.GossipVersion1, + ) + }, + wantErr: "start time after end time", + }, + { + name: "v2 chan range inverted - rejected", + fn: func() error { + r := ChanUpdateRange{ + StartHeight: fn.Some(uint32(100)), + EndHeight: fn.Some(uint32(50)), + } + + return r.validateForVersion( + lnwire.GossipVersion2, + ) + }, + wantErr: "start height after end height", + }, + { + name: "v1 node range inverted - rejected", + fn: func() error { + r := NodeUpdateRange{ + StartTime: fn.Some(now.Add(time.Hour)), + EndTime: fn.Some(now), + } + + return r.validateForVersion( + lnwire.GossipVersion1, + ) + }, + wantErr: "start time after end time", + }, + { + name: "v2 node range inverted - rejected", + fn: func() error { + r := NodeUpdateRange{ + StartHeight: fn.Some(uint32(100)), + EndHeight: fn.Some(uint32(50)), + } + + return r.validateForVersion( + lnwire.GossipVersion2, + ) + }, + wantErr: "start height after end height", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + err := tc.fn() + if tc.wantErr == "" { + require.NoError(t, err) + } else { + require.ErrorContains(t, err, + tc.wantErr) + } + }) + } +} + +// TestV2HorizonQueries tests that NodeUpdatesInHorizon and +// ChanUpdatesInHorizon work with v2 gossip (block-height ranges). This test +// only runs on SQL backends since KV does not support v2. +func TestV2HorizonQueries(t *testing.T) { + t.Parallel() + + if !isSQLDB { + t.Skip("v2 horizon queries only supported on SQL backends") + } + + ctx := t.Context() + graph := MakeTestGraph(t) + + // Create two v2 nodes with specific block heights. + node1 := createTestVertex(t, lnwire.GossipVersion2) + node1.LastBlockHeight = 100 + require.NoError(t, graph.AddNode(ctx, node1)) + + node2 := createTestVertex(t, lnwire.GossipVersion2) + node2.LastBlockHeight = 200 + require.NoError(t, graph.AddNode(ctx, node2)) + + // Create a third node outside the query range. + node3 := createTestVertex(t, lnwire.GossipVersion2) + node3.LastBlockHeight = 500 + require.NoError(t, graph.AddNode(ctx, node3)) + + // --- NodeUpdatesInHorizon v2 --- + + // Query for nodes in block range [50, 250]. + nodeIter := graph.NodeUpdatesInHorizon( + ctx, lnwire.GossipVersion2, NodeUpdateRange{ + StartHeight: fn.Some(uint32(50)), + EndHeight: fn.Some(uint32(250)), + }, + ) + nodes, err := fn.CollectErr(nodeIter) + require.NoError(t, err) + require.Len(t, nodes, 2) + + // Query for nodes in block range [150, 600] should return node2 and + // node3. + nodeIter = graph.NodeUpdatesInHorizon( + ctx, lnwire.GossipVersion2, NodeUpdateRange{ + StartHeight: fn.Some(uint32(150)), + EndHeight: fn.Some(uint32(600)), + }, + ) + nodes, err = fn.CollectErr(nodeIter) + require.NoError(t, err) + require.Len(t, nodes, 2) + + // Query for nodes in block range [300, 400] should return nothing. + nodeIter = graph.NodeUpdatesInHorizon( + ctx, lnwire.GossipVersion2, NodeUpdateRange{ + StartHeight: fn.Some(uint32(300)), + EndHeight: fn.Some(uint32(400)), + }, + ) + nodes, err = fn.CollectErr(nodeIter) + require.NoError(t, err) + require.Empty(t, nodes) + + // --- ChanUpdatesInHorizon v2 --- + + // Create a v2 channel between node1 and node2. + edgeInfo, _ := createEdge( + lnwire.GossipVersion2, 100, 1, 0, 10, node1, node2, + ) + require.NoError(t, graph.AddChannelEdge(ctx, edgeInfo)) + + // Add v2 policies with specific block heights. + edge1 := &models.ChannelEdgePolicy{ + Version: lnwire.GossipVersion2, + ChannelID: edgeInfo.ChannelID, + LastBlockHeight: 150, + TimeLockDelta: 14, + MinHTLC: 1000, + MaxHTLC: 1000000, + FeeBaseMSat: 1000, + FeeProportionalMillionths: 200, + } + edge2 := &models.ChannelEdgePolicy{ + Version: lnwire.GossipVersion2, + SecondPeer: true, + ChannelID: edgeInfo.ChannelID, + LastBlockHeight: 160, + TimeLockDelta: 14, + MinHTLC: 1000, + MaxHTLC: 1000000, + FeeBaseMSat: 1000, + FeeProportionalMillionths: 200, + } + require.NoError(t, graph.UpdateEdgePolicy(ctx, edge1)) + require.NoError(t, graph.UpdateEdgePolicy(ctx, edge2)) + + // Query for channel updates in block range [100, 200]. + chanIter := graph.ChanUpdatesInHorizon( + ctx, lnwire.GossipVersion2, ChanUpdateRange{ + StartHeight: fn.Some(uint32(100)), + EndHeight: fn.Some(uint32(200)), + }, + ) + channels, err := fn.CollectErr(chanIter) + require.NoError(t, err) + require.Len(t, channels, 1) + require.Equal(t, edgeInfo.ChannelID, channels[0].Info.ChannelID) + + // Query for channel updates in block range [200, 300] should return + // nothing since policies are at heights 150 and 160. + chanIter = graph.ChanUpdatesInHorizon( + ctx, lnwire.GossipVersion2, ChanUpdateRange{ + StartHeight: fn.Some(uint32(200)), + EndHeight: fn.Some(uint32(300)), + }, + ) + channels, err = fn.CollectErr(chanIter) + require.NoError(t, err) + require.Empty(t, channels) +} + +// TestPreferHighestAndGetVersions tests the four new Store methods: +// FetchChannelEdgesByIDPreferHighest, FetchChannelEdgesByOutpointPreferHighest, +// GetVersionsBySCID, and GetVersionsByOutpoint. +func TestPreferHighestAndGetVersions(t *testing.T) { + t.Parallel() + ctx := t.Context() + + graph := MakeTestGraph(t) + store := graph.db + + node1Priv, err := btcec.NewPrivateKey() + require.NoError(t, err) + node2Priv, err := btcec.NewPrivateKey() + require.NoError(t, err) + + node1V1 := createNode(t, lnwire.GossipVersion1, node1Priv) + node2V1 := createNode(t, lnwire.GossipVersion1, node2Priv) + + require.NoError(t, graph.AddNode(ctx, node1V1)) + require.NoError(t, graph.AddNode(ctx, node2V1)) + + // Create and add a v1 channel edge. + edgeInfo, scid := createEdge( + lnwire.GossipVersion1, 100, 1, 0, 1, node1V1, node2V1, + ) + require.NoError(t, graph.AddChannelEdge(ctx, edgeInfo)) + + chanID := scid.ToUint64() + op := edgeInfo.ChannelPoint + + // FetchChannelEdgesByIDPreferHighest should return the v1 channel. + info, _, _, err := store.FetchChannelEdgesByIDPreferHighest( + ctx, chanID, + ) + require.NoError(t, err) + require.Equal(t, chanID, info.ChannelID) + + // FetchChannelEdgesByOutpointPreferHighest should also return it. + info, _, _, err = store.FetchChannelEdgesByOutpointPreferHighest( + ctx, &op, + ) + require.NoError(t, err) + require.Equal(t, chanID, info.ChannelID) + + // Querying a non-existent channel should return an error. + _, _, _, err = store.FetchChannelEdgesByIDPreferHighest(ctx, 999999) + require.Error(t, err) + + // GetVersionsBySCID should report v1. + versions, err := store.GetVersionsBySCID(ctx, chanID) + require.NoError(t, err) + require.Equal(t, []lnwire.GossipVersion{ + lnwire.GossipVersion1, + }, versions) + + // GetVersionsByOutpoint should also report v1. + versions, err = store.GetVersionsByOutpoint(ctx, &op) + require.NoError(t, err) + require.Equal(t, []lnwire.GossipVersion{ + lnwire.GossipVersion1, + }, versions) + + // GetVersions for a non-existent SCID should return empty. + versions, err = store.GetVersionsBySCID(ctx, 999999) + require.NoError(t, err) + require.Empty(t, versions) + + if !isSQLDB { + return + } + + node1V2 := createNode(t, lnwire.GossipVersion2, node1Priv) + node2V2 := createNode(t, lnwire.GossipVersion2, node2Priv) + require.NoError(t, graph.AddNode(ctx, node1V2)) + require.NoError(t, graph.AddNode(ctx, node2V2)) + + // Add a duplicate v1/v2 channel and verify prefer-highest chooses + // the v2 edge while GetVersions reports both versions. + dupV1, dupSCID := createEdge( + lnwire.GossipVersion1, 101, 1, 0, 2, node1V1, node2V1, + ) + dupV2, _ := createEdge( + lnwire.GossipVersion2, 101, 1, 0, 2, node1V2, node2V2, + ) + require.NoError(t, graph.AddChannelEdge(ctx, dupV1)) + require.NoError(t, graph.AddChannelEdge(ctx, dupV2)) + + dupChanID := dupSCID.ToUint64() + dupOutpoint := dupV1.ChannelPoint + + info, _, _, err = store.FetchChannelEdgesByIDPreferHighest( + ctx, dupChanID, + ) + require.NoError(t, err) + require.Equal(t, dupChanID, info.ChannelID) + require.Equal(t, lnwire.GossipVersion2, info.Version) + + info, _, _, err = store.FetchChannelEdgesByOutpointPreferHighest( + ctx, &dupOutpoint, + ) + require.NoError(t, err) + require.Equal(t, dupChanID, info.ChannelID) + require.Equal(t, lnwire.GossipVersion2, info.Version) + versions, err = store.GetVersionsBySCID(ctx, dupChanID) + require.NoError(t, err) + require.Equal(t, []lnwire.GossipVersion{ + lnwire.GossipVersion1, + lnwire.GossipVersion2, + }, versions) + + versions, err = store.GetVersionsByOutpoint(ctx, &dupOutpoint) + require.NoError(t, err) + require.Equal(t, []lnwire.GossipVersion{ + lnwire.GossipVersion1, + lnwire.GossipVersion2, + }, versions) + // Add another duplicate v1/v2 channel where only the v1 version has a + // policy. Prefer-highest should return the lower version with usable + // policy data instead of the higher version shell. + policyPrefV1, policyPrefSCID := createEdge( + lnwire.GossipVersion1, 102, 1, 0, 3, node1V1, node2V1, + ) + policyPrefV2, _ := createEdge( + lnwire.GossipVersion2, 102, 1, 0, 3, node1V2, node2V2, + ) + require.NoError(t, graph.AddChannelEdge(ctx, policyPrefV1)) + require.NoError(t, graph.AddChannelEdge(ctx, policyPrefV2)) + + policyOnlyV1 := newEdgePolicy( + lnwire.GossipVersion1, policyPrefV1.ChannelID, 1000, true, + ) + policyOnlyV1.ToNode = node2V1.PubKeyBytes + policyOnlyV1.SigBytes = testSig.Serialize() + require.NoError(t, graph.UpdateEdgePolicy(ctx, policyOnlyV1)) + + policyPrefChanID := policyPrefSCID.ToUint64() + policyPrefOutpoint := policyPrefV1.ChannelPoint + + info, p1, p2, err := store.FetchChannelEdgesByIDPreferHighest( + ctx, policyPrefChanID, + ) + require.NoError(t, err) + require.Equal(t, policyPrefChanID, info.ChannelID) + require.Equal(t, lnwire.GossipVersion1, info.Version) + require.NotNil(t, p1) + require.Nil(t, p2) + + info, p1, p2, err = store.FetchChannelEdgesByOutpointPreferHighest( + ctx, &policyPrefOutpoint, + ) + require.NoError(t, err) + require.Equal(t, policyPrefChanID, info.ChannelID) + require.Equal(t, lnwire.GossipVersion1, info.Version) + require.NotNil(t, p1) + require.Nil(t, p2) +} +// TestPreferHighestNodeTraversal verifies that ChannelGraph's +// ForEachNodeDirectedChannel and FetchNodeFeatures correctly prefer v2 over v1 +// when the graph cache is disabled (exercising the no-cache code paths). +func TestPreferHighestNodeTraversal(t *testing.T) { + t.Parallel() + + if !isSQLDB { + t.Skip("prefer-highest requires SQL backend") + } + + ctx := t.Context() + + // Disable the cache so we exercise the no-cache code paths in + // ChannelGraph.ForEachNodeDirectedChannel and FetchNodeFeatures. + graph := MakeTestGraph(t, WithUseGraphCache(false)) + + // --- FetchNodeFeatures --- + + // Create a v1-only node and verify its features are returned. + privV1, err := btcec.NewPrivateKey() + require.NoError(t, err) + + nodeV1 := createNode(t, lnwire.GossipVersion1, privV1) + require.NoError(t, graph.AddNode(ctx, nodeV1)) + + features, err := graph.FetchNodeFeatures(ctx, nodeV1.PubKeyBytes) + require.NoError(t, err) + require.False(t, features.IsEmpty(), + "v1-only node should have features") + + // Create a v2-only node and verify its features are returned + // (exercises the v2 fallback). + privV2, err := btcec.NewPrivateKey() + require.NoError(t, err) + + nodeV2 := createNode(t, lnwire.GossipVersion2, privV2) + require.NoError(t, graph.AddNode(ctx, nodeV2)) + + features, err = graph.FetchNodeFeatures(ctx, nodeV2.PubKeyBytes) + require.NoError(t, err) + require.False(t, features.IsEmpty(), + "v2-only node should have features") + + // Create a node with both v1 and v2 announcements. + privBoth, err := btcec.NewPrivateKey() + require.NoError(t, err) + + nodeBothV1 := createNode(t, lnwire.GossipVersion1, privBoth) + v1Features := lnwire.NewFeatureVector( + lnwire.NewRawFeatureVector(lnwire.GossipQueriesRequired), + lnwire.Features, + ) + nodeBothV1.Features = v1Features + require.NoError(t, graph.AddNode(ctx, nodeBothV1)) + + nodeBothV2 := createNode(t, lnwire.GossipVersion2, privBoth) + v2Features := lnwire.NewFeatureVector( + lnwire.NewRawFeatureVector(lnwire.TLVOnionPayloadRequired), + lnwire.Features, + ) + nodeBothV2.Features = v2Features + require.NoError(t, graph.AddNode(ctx, nodeBothV2)) + + features, err = graph.FetchNodeFeatures( + ctx, nodeBothV1.PubKeyBytes, + ) + require.NoError(t, err) + require.Equal(t, v2Features, features) + require.NotEqual(t, v1Features, features) + + // --- ForEachNodeDirectedChannel --- + + // Add a v1 channel between nodeV1 and nodeBothV1. + edge, _ := createEdge( + lnwire.GossipVersion1, 100, 0, 0, 0, + nodeV1, nodeBothV1, + ) + require.NoError(t, graph.AddChannelEdge(ctx, edge)) + + pol := newEdgePolicy( + lnwire.GossipVersion1, edge.ChannelID, 1000, true, + ) + pol.ToNode = nodeBothV1.PubKeyBytes + pol.SigBytes = testSig.Serialize() + require.NoError(t, graph.UpdateEdgePolicy(ctx, pol)) + + // ForEachNodeDirectedChannel should find the channel. + var foundChannels int + err = graph.ForEachNodeDirectedChannel( + ctx, nodeV1.PubKeyBytes, + func(_ *DirectedChannel) error { + foundChannels++ + return nil + }, func() { + foundChannels = 0 + }, + ) + require.NoError(t, err) + require.Equal(t, 1, foundChannels, + "expected 1 channel for v1 node") +} +// TestPreferHighestForEachNode verifies that SQLStore.ForEachNode returns one +// node per pubkey, preferring the highest announced version and otherwise +// falling back to the highest-version shell node. +// TestPreferHighestForEachNode verifies that SQLStore.ForEachNode returns one +// node per pubkey, preferring the highest announced version and otherwise +// falling back to the highest-version shell node. +func TestPreferHighestForEachNode(t *testing.T) { + t.Parallel() + + if !isSQLDB { + t.Skip("prefer-highest requires SQL backend") + } + + ctx := t.Context() + graph := MakeTestGraph(t) + store := graph.db + + v1Only := createTestVertex(t, lnwire.GossipVersion1) + v1Only.Alias = fn.Some("v1-only") + require.NoError(t, graph.AddNode(ctx, v1Only)) + + bothPriv, err := btcec.NewPrivateKey() + require.NoError(t, err) + + bothV1 := createNode(t, lnwire.GossipVersion1, bothPriv) + bothV1.Alias = fn.Some("both-v1") + require.NoError(t, graph.AddNode(ctx, bothV1)) + + bothV2 := createNode(t, lnwire.GossipVersion2, bothPriv) + bothV2.Alias = fn.Some("both-v2") + require.NoError(t, graph.AddNode(ctx, bothV2)) + + shellPriv, err := btcec.NewPrivateKey() + require.NoError(t, err) + + shellPub, err := route.NewVertexFromBytes( + shellPriv.PubKey().SerializeCompressed(), + ) + require.NoError(t, err) + + require.NoError(t, graph.AddNode( + ctx, models.NewShellNode(lnwire.GossipVersion1, shellPub), + )) + require.NoError(t, graph.AddNode( + ctx, models.NewShellNode(lnwire.GossipVersion2, shellPub), + )) + + nodesByPub := make(map[route.Vertex]*models.Node) + err = store.ForEachNode(ctx, func(node *models.Node) error { + nodesByPub[node.PubKeyBytes] = node + return nil + }, func() { + clear(nodesByPub) + }) + require.NoError(t, err) + require.Len(t, nodesByPub, 3) + + gotV1Only := nodesByPub[v1Only.PubKeyBytes] + require.NotNil(t, gotV1Only) + require.Equal(t, lnwire.GossipVersion1, gotV1Only.Version) + require.Equal(t, "v1-only", gotV1Only.Alias.UnwrapOr("")) + require.True(t, gotV1Only.HaveAnnouncement()) + + gotBoth := nodesByPub[bothV1.PubKeyBytes] + require.NotNil(t, gotBoth) + require.Equal(t, lnwire.GossipVersion2, gotBoth.Version) + require.Equal(t, "both-v2", gotBoth.Alias.UnwrapOr("")) + require.True(t, gotBoth.HaveAnnouncement()) + + gotShell := nodesByPub[shellPub] + require.NotNil(t, gotShell) + require.Equal(t, lnwire.GossipVersion2, gotShell.Version) + require.False(t, gotShell.HaveAnnouncement()) +} + +// TestPreferHighestForEachChannel verifies that SQLStore.ForEachChannel returns +// one channel per SCID, preferring a higher-version channel when both versions +// have policies, preserving lower-version policy data when the higher version +// has none, and otherwise falling back to the highest-version no-policy +// channel. +func TestPreferHighestForEachChannel(t *testing.T) { + t.Parallel() + + if !isSQLDB { + t.Skip("prefer-highest requires SQL backend") + } + + ctx := t.Context() + graph := MakeTestGraph(t) + store := graph.db + + node1Priv, err := btcec.NewPrivateKey() + require.NoError(t, err) + node2Priv, err := btcec.NewPrivateKey() + require.NoError(t, err) + + node1V1 := createNode(t, lnwire.GossipVersion1, node1Priv) + node1V2 := createNode(t, lnwire.GossipVersion2, node1Priv) + node2V1 := createNode(t, lnwire.GossipVersion1, node2Priv) + node2V2 := createNode(t, lnwire.GossipVersion2, node2Priv) + + require.NoError(t, graph.AddNode(ctx, node1V1)) + require.NoError(t, graph.AddNode(ctx, node1V2)) + require.NoError(t, graph.AddNode(ctx, node2V1)) + require.NoError(t, graph.AddNode(ctx, node2V2)) + + v1Only, _ := createEdge( + lnwire.GossipVersion1, 200, 0, 0, 1, node1V1, node2V1, + ) + require.NoError(t, graph.AddChannelEdge(ctx, v1Only)) + + policyPrefV1, _ := createEdge( + lnwire.GossipVersion1, 201, 0, 0, 2, node1V1, node2V1, + ) + policyPrefV2, _ := createEdge( + lnwire.GossipVersion2, 201, 0, 0, 2, node1V2, node2V2, + ) + require.NoError(t, graph.AddChannelEdge(ctx, policyPrefV1)) + require.NoError(t, graph.AddChannelEdge(ctx, policyPrefV2)) + + policyOnlyV1 := newEdgePolicy( + lnwire.GossipVersion1, policyPrefV1.ChannelID, 1000, true, + ) + policyOnlyV1.ToNode = node2V1.PubKeyBytes + policyOnlyV1.SigBytes = testSig.Serialize() + require.NoError(t, graph.UpdateEdgePolicy(ctx, policyOnlyV1)) + + versionPrefV1, _ := createEdge( + lnwire.GossipVersion1, 202, 0, 0, 3, node1V1, node2V1, + ) + versionPrefV2, _ := createEdge( + lnwire.GossipVersion2, 202, 0, 0, 3, node1V2, node2V2, + ) + require.NoError(t, graph.AddChannelEdge(ctx, versionPrefV1)) + require.NoError(t, graph.AddChannelEdge(ctx, versionPrefV2)) + + versionPolicyV1 := newEdgePolicy( + lnwire.GossipVersion1, versionPrefV1.ChannelID, 1001, true, + ) + versionPolicyV1.ToNode = node2V1.PubKeyBytes + versionPolicyV1.SigBytes = testSig.Serialize() + require.NoError(t, graph.UpdateEdgePolicy(ctx, versionPolicyV1)) + + versionPolicyV2 := newEdgePolicy( + lnwire.GossipVersion2, versionPrefV2.ChannelID, 1002, true, + ) + versionPolicyV2.ToNode = node2V2.PubKeyBytes + versionPolicyV2.SigBytes = testSig.Serialize() + require.NoError(t, graph.UpdateEdgePolicy(ctx, versionPolicyV2)) + + shellPrefV1, _ := createEdge( + lnwire.GossipVersion1, 203, 0, 0, 4, node1V1, node2V1, + ) + shellPrefV2, _ := createEdge( + lnwire.GossipVersion2, 203, 0, 0, 4, node1V2, node2V2, + ) + require.NoError(t, graph.AddChannelEdge(ctx, shellPrefV1)) + require.NoError(t, graph.AddChannelEdge(ctx, shellPrefV2)) + + type channelResult struct { + info *models.ChannelEdgeInfo + p1 *models.ChannelEdgePolicy + p2 *models.ChannelEdgePolicy + } + channelsByID := make(map[uint64]channelResult) + err = store.ForEachChannel(ctx, func(info *models.ChannelEdgeInfo, + p1, p2 *models.ChannelEdgePolicy) error { + + channelsByID[info.ChannelID] = channelResult{ + info: info, + p1: p1, + p2: p2, + } + return nil + }, func() { + clear(channelsByID) + }) + require.NoError(t, err) + require.Len(t, channelsByID, 4) + + gotV1Only := channelsByID[v1Only.ChannelID] + require.Equal(t, lnwire.GossipVersion1, gotV1Only.info.Version) + require.Nil(t, gotV1Only.p1) + require.Nil(t, gotV1Only.p2) + + gotPolicyPref := channelsByID[policyPrefV1.ChannelID] + require.Equal(t, lnwire.GossipVersion1, gotPolicyPref.info.Version) + require.NotNil(t, gotPolicyPref.p1) + require.Nil(t, gotPolicyPref.p2) + + gotVersionPref := channelsByID[versionPrefV1.ChannelID] + require.Equal(t, lnwire.GossipVersion2, gotVersionPref.info.Version) + require.NotNil(t, gotVersionPref.p1) + + gotShellPref := channelsByID[shellPrefV1.ChannelID] + require.Equal(t, lnwire.GossipVersion2, gotShellPref.info.Version) + require.Nil(t, gotShellPref.p1) + require.Nil(t, gotShellPref.p2) +} + +// TestPreferHighestNodeDirectedChannelTraversal verifies that the no-cache +// ChannelGraph.ForEachNodeDirectedChannel path streams one directed channel per +// SCID while preferring the v2 advertisement when both versions exist. +func TestPreferHighestNodeDirectedChannelTraversal(t *testing.T) { + t.Parallel() + + if !isSQLDB { + t.Skip("prefer-highest requires SQL backend") + } + + ctx := t.Context() + graph := MakeTestGraph(t, WithUseGraphCache(false)) + + localPriv, err := btcec.NewPrivateKey() + require.NoError(t, err) + peerBothPriv, err := btcec.NewPrivateKey() + require.NoError(t, err) + peerV1OnlyPriv, err := btcec.NewPrivateKey() + require.NoError(t, err) + + localV1 := createNode(t, lnwire.GossipVersion1, localPriv) + localV2 := createNode(t, lnwire.GossipVersion2, localPriv) + peerBothV1 := createNode(t, lnwire.GossipVersion1, peerBothPriv) + peerBothV2 := createNode(t, lnwire.GossipVersion2, peerBothPriv) + peerV1Only := createNode(t, lnwire.GossipVersion1, peerV1OnlyPriv) + + require.NoError(t, graph.AddNode(ctx, localV1)) + require.NoError(t, graph.AddNode(ctx, localV2)) + require.NoError(t, graph.AddNode(ctx, peerBothV1)) + require.NoError(t, graph.AddNode(ctx, peerBothV2)) + require.NoError(t, graph.AddNode(ctx, peerV1Only)) + + dupV1, _ := createEdge( + lnwire.GossipVersion1, 400, 0, 0, 10, localV1, peerBothV1, + ) + dupV2, _ := createEdge( + lnwire.GossipVersion2, 400, 0, 0, 10, localV2, peerBothV2, + ) + v1Only, _ := createEdge( + lnwire.GossipVersion1, 401, 0, 0, 11, localV1, peerV1Only, + ) + + require.NoError(t, graph.AddChannelEdge(ctx, dupV1)) + require.NoError(t, graph.AddChannelEdge(ctx, dupV2)) + require.NoError(t, graph.AddChannelEdge(ctx, v1Only)) + + addPolicies := func(edgeInfo *models.ChannelEdgeInfo, + version lnwire.GossipVersion, + fee lnwire.MilliSatoshi) { + + policy1 := newEdgePolicy(version, edgeInfo.ChannelID, 1000, true) + policy1.ToNode = edgeInfo.NodeKey2Bytes + policy1.SigBytes = testSig.Serialize() + policy1.FeeBaseMSat = fee + + policy2 := newEdgePolicy(version, edgeInfo.ChannelID, 1001, false) + policy2.ToNode = edgeInfo.NodeKey1Bytes + policy2.SigBytes = testSig.Serialize() + policy2.FeeBaseMSat = fee + + require.NoError(t, graph.UpdateEdgePolicy(ctx, policy1)) + require.NoError(t, graph.UpdateEdgePolicy(ctx, policy2)) + } + + addPolicies(dupV1, lnwire.GossipVersion1, 1111) + addPolicies(dupV2, lnwire.GossipVersion2, 2222) + addPolicies(v1Only, lnwire.GossipVersion1, 3333) + + channelsByID := make(map[uint64]*DirectedChannel) + err = graph.ForEachNodeDirectedChannel( + ctx, localV1.PubKeyBytes, + func(channel *DirectedChannel) error { + channelsByID[channel.ChannelID] = channel.DeepCopy() + return nil + }, func() { + clear(channelsByID) + }, + ) + require.NoError(t, err) + require.Len(t, channelsByID, 2) + + gotDup := channelsByID[dupV1.ChannelID] + require.NotNil(t, gotDup) + require.NotNil(t, gotDup.InPolicy) + require.Equal(t, lnwire.MilliSatoshi(2222), gotDup.InPolicy.FeeBaseMSat) + + gotV1Only := channelsByID[v1Only.ChannelID] + require.NotNil(t, gotV1Only) + require.NotNil(t, gotV1Only.InPolicy) + require.Equal( + t, lnwire.MilliSatoshi(3333), gotV1Only.InPolicy.FeeBaseMSat, + ) +} diff --git a/graph/db/interfaces.go b/graph/db/interfaces.go index a725cb8ee4..33af4c1e00 100644 --- a/graph/db/interfaces.go +++ b/graph/db/interfaces.go @@ -29,8 +29,8 @@ type NodeTraverser interface { nodePub route.Vertex) (*lnwire.FeatureVector, error) } -// Store represents the main interface for the channel graph database for all -// channels and nodes gossiped via the V1 gossip protocol as defined in BOLT 7. +// Store represents the main interface for the channel graph database. It +// supports channels and nodes from multiple gossip protocol versions. type Store interface { //nolint:interfacebloat // ForEachNodeDirectedChannel calls the callback for every channel of // the given node. @@ -95,11 +95,11 @@ type Store interface { //nolint:interfacebloat chans map[uint64]*DirectedChannel) error, reset func()) error - // ForEachNode iterates through all the stored vertices/nodes in the - // graph, executing the passed callback with each node encountered. If - // the callback returns an error, then the transaction is aborted and - // the iteration stops early. - ForEachNode(ctx context.Context, v lnwire.GossipVersion, + // ForEachNode iterates through all nodes in the graph across all + // gossip versions, yielding each unique node exactly once. The + // callback receives the best available Node (highest advertised + // version preferred, falling back to shell nodes). + ForEachNode(ctx context.Context, cb func(*models.Node) error, reset func()) error // ForEachNodeCacheable iterates through all the stored vertices/nodes @@ -120,11 +120,12 @@ type Store interface { //nolint:interfacebloat DeleteNode(ctx context.Context, v lnwire.GossipVersion, nodePub route.Vertex) error - // NodeUpdatesInHorizon returns all the known lightning node which have - // an update timestamp within the passed range. This method can be used - // by two nodes to quickly determine if they have the same set of up to - // date node announcements. - NodeUpdatesInHorizon(ctx context.Context, startTime, endTime time.Time, + // NodeUpdatesInHorizon returns all the known lightning nodes which have + // updates within the passed range for the given gossip version. This + // method can be used by two nodes to quickly determine if they have the + // same set of up to date node announcements. + NodeUpdatesInHorizon(ctx context.Context, v lnwire.GossipVersion, + r NodeUpdateRange, opts ...IteratorOption) iter.Seq2[*models.Node, error] // FetchNode attempts to look up a target node by its identity @@ -160,21 +161,16 @@ type Store interface { //nolint:interfacebloat GraphSession(ctx context.Context, cb func(graph NodeTraverser) error, reset func()) error - // ForEachChannel iterates through all the channel edges stored within - // the graph and invokes the passed callback for each edge. The callback - // takes two edges as since this is a directed graph, both the in/out - // edges are visited. If the callback returns an error, then the - // transaction is aborted and the iteration stops early. - // - // NOTE: If an edge can't be found, or wasn't advertised, then a nil - // pointer for that particular channel edge routing policy will be - // passed into the callback. - // - // TODO(elle): add a cross-version iteration API and make this iterate - // over all versions. - ForEachChannel(ctx context.Context, v lnwire.GossipVersion, + // ForEachChannel iterates through all channel edges stored within the + // graph across all gossip versions, yielding each unique channel + // exactly once. The callback receives the edge info and both + // directional policies. When both versions are present, v2 is + // preferred. Nil pointers are passed for policies that haven't been + // advertised. + ForEachChannel(ctx context.Context, cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy, - *models.ChannelEdgePolicy) error, reset func()) error + *models.ChannelEdgePolicy) error, + reset func()) error // ForEachChannelCacheable iterates through all the channel edges stored // within the graph and invokes the passed callback for each edge. The @@ -256,10 +252,10 @@ type Store interface { //nolint:interfacebloat uint64, error) // ChanUpdatesInHorizon returns all the known channel edges which have - // at least one edge that has an update timestamp within the specified - // horizon. - ChanUpdatesInHorizon(ctx context.Context, - startTime, endTime time.Time, + // at least one edge update within the specified range for the given + // gossip version. + ChanUpdatesInHorizon(ctx context.Context, v lnwire.GossipVersion, + r ChanUpdateRange, opts ...IteratorOption) iter.Seq2[ChannelEdge, error] // FilterKnownChanIDs takes a set of channel IDs and return the subset @@ -269,9 +265,9 @@ type Store interface { //nolint:interfacebloat // callers to determine the set of channels another peer knows of that // we don't. The ChannelUpdateInfos for the known zombies is also // returned. - FilterKnownChanIDs(ctx context.Context, - chansInfo []ChannelUpdateInfo) ([]uint64, []ChannelUpdateInfo, - error) + FilterKnownChanIDs(ctx context.Context, v lnwire.GossipVersion, + chansInfo []ChannelUpdateInfo) ([]uint64, + []ChannelUpdateInfo, error) // FilterChannelRange returns the channel ID's of all known channels // which were mined in a block height within the passed range for the @@ -321,6 +317,35 @@ type Store interface { //nolint:interfacebloat *models.ChannelEdgeInfo, *models.ChannelEdgePolicy, *models.ChannelEdgePolicy, error) + // FetchChannelEdgesByIDPreferHighest behaves like FetchChannelEdgesByID + // but is version-agnostic: if the channel exists under multiple gossip + // versions it returns the record with the highest version number. + FetchChannelEdgesByIDPreferHighest(ctx context.Context, + chanID uint64) ( + *models.ChannelEdgeInfo, *models.ChannelEdgePolicy, + *models.ChannelEdgePolicy, error) + + // FetchChannelEdgesByOutpointPreferHighest behaves like + // FetchChannelEdgesByOutpoint but is version-agnostic: if the channel + // exists under multiple gossip versions it returns the record with the + // highest version number. + FetchChannelEdgesByOutpointPreferHighest(ctx context.Context, + op *wire.OutPoint) ( + *models.ChannelEdgeInfo, *models.ChannelEdgePolicy, + *models.ChannelEdgePolicy, error) + + // GetVersionsBySCID returns the list of gossip versions for which a + // channel with the given SCID exists in the database, ordered + // ascending. + GetVersionsBySCID(ctx context.Context, + chanID uint64) ([]lnwire.GossipVersion, error) + + // GetVersionsByOutpoint returns the list of gossip versions for which + // a channel with the given funding outpoint exists in the database, + // ordered ascending. + GetVersionsByOutpoint(ctx context.Context, + op *wire.OutPoint) ([]lnwire.GossipVersion, error) + // ChannelView returns the verifiable edge information for each active // channel within the known channel graph for the given gossip version. // The set of UTXO's (along with their scripts) returned are the ones diff --git a/graph/db/kv_store.go b/graph/db/kv_store.go index aa32f75daf..364185613b 100644 --- a/graph/db/kv_store.go +++ b/graph/db/kv_store.go @@ -408,13 +408,10 @@ func (c *KVStore) AddrsForNode(ctx context.Context, v lnwire.GossipVersion, // NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer // for that particular channel edge routing policy will be passed into the // callback. -func (c *KVStore) ForEachChannel(_ context.Context, v lnwire.GossipVersion, +func (c *KVStore) ForEachChannel(_ context.Context, cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy, - *models.ChannelEdgePolicy) error, reset func()) error { - - if v != lnwire.GossipVersion1 { - return ErrVersionNotSupportedForKVDB - } + *models.ChannelEdgePolicy) error, + reset func()) error { return forEachChannel(c.db, cb, reset) } @@ -842,13 +839,9 @@ func (c *KVStore) DisabledChannelIDs( // early. // // NOTE: this is part of the Store interface. -func (c *KVStore) ForEachNode(_ context.Context, v lnwire.GossipVersion, +func (c *KVStore) ForEachNode(_ context.Context, cb func(*models.Node) error, reset func()) error { - if v != lnwire.GossipVersion1 { - return ErrVersionNotSupportedForKVDB - } - return forEachNode(c.db, func(tx kvdb.RTx, node *models.Node) error { @@ -2394,9 +2387,10 @@ func (c *KVStore) fetchNextChanUpdateBatch( } // ChanUpdatesInHorizon returns all the known channel edges which have at least -// one edge that has an update timestamp within the specified horizon. +// one edge that has an update within the specified range for the given gossip +// version. func (c *KVStore) ChanUpdatesInHorizon(_ context.Context, - startTime, endTime time.Time, + v lnwire.GossipVersion, r ChanUpdateRange, opts ...IteratorOption) iter.Seq2[ChannelEdge, error] { cfg := defaultIteratorConfig() @@ -2405,8 +2399,19 @@ func (c *KVStore) ChanUpdatesInHorizon(_ context.Context, } return func(yield func(ChannelEdge, error) bool) { + if v != lnwire.GossipVersion1 { + yield(ChannelEdge{}, ErrVersionNotSupportedForKVDB) + return + } + if err := r.validateForVersion(v); err != nil { + yield(ChannelEdge{}, err) + return + } + iterState := newChanUpdatesIterator( - cfg.chanUpdateIterBatchSize, startTime, endTime, + cfg.chanUpdateIterBatchSize, + r.StartTime.UnwrapOr(time.Time{}), + r.EndTime.UnwrapOr(time.Time{}), ) for { @@ -2455,8 +2460,8 @@ func (c *KVStore) ChanUpdatesInHorizon(_ context.Context, float64(iterState.total), iterState.hits, iterState.total) } else { - log.Tracef("ChanUpdatesInHorizon returned no edges "+ - "in horizon (%s, %s)", startTime, endTime) + log.Tracef("ChanUpdatesInHorizon(v%d) returned "+ + "no edges in horizon", v) } } } @@ -2645,9 +2650,9 @@ func (c *KVStore) fetchNextNodeBatch( } // NodeUpdatesInHorizon returns all the known lightning node which have an -// update timestamp within the passed range. -func (c *KVStore) NodeUpdatesInHorizon(_ context.Context, startTime, - endTime time.Time, +// update timestamp within the passed range for the given gossip version. +func (c *KVStore) NodeUpdatesInHorizon(_ context.Context, + v lnwire.GossipVersion, r NodeUpdateRange, opts ...IteratorOption) iter.Seq2[*models.Node, error] { cfg := defaultIteratorConfig() @@ -2656,10 +2661,20 @@ func (c *KVStore) NodeUpdatesInHorizon(_ context.Context, startTime, } return func(yield func(*models.Node, error) bool) { + if v != lnwire.GossipVersion1 { + yield(nil, ErrVersionNotSupportedForKVDB) + return + } + if err := r.validateForVersion(v); err != nil { + yield(nil, err) + return + } + // Initialize iterator state. state := newNodeUpdatesIterator( cfg.nodeUpdateIterBatchSize, - startTime, endTime, + r.StartTime.UnwrapOr(time.Time{}), + r.EndTime.UnwrapOr(time.Time{}), cfg.iterPublicNodes, ) @@ -2696,8 +2711,13 @@ func (c *KVStore) NodeUpdatesInHorizon(_ context.Context, startTime, // channels another peer knows of that we don't. The ChannelUpdateInfos for the // known zombies is also returned. func (c *KVStore) FilterKnownChanIDs(_ context.Context, + v lnwire.GossipVersion, chansInfo []ChannelUpdateInfo) ([]uint64, []ChannelUpdateInfo, error) { + if v != lnwire.GossipVersion1 { + return nil, nil, ErrVersionNotSupportedForKVDB + } + var ( newChanIDs []uint64 knownZombies []ChannelUpdateInfo @@ -4166,7 +4186,85 @@ func (c *KVStore) FetchChannelEdgesByID(_ context.Context, return edgeInfo, policy1, policy2, nil } -// IsPublicNode is a helper method that determines whether the node with the +// FetchChannelEdgesByIDPreferHighest looks up the channel by ID. The KV store +// only supports gossip v1, so this simply delegates to the versioned fetch. +// +// NOTE: part of the Store interface. +func (c *KVStore) FetchChannelEdgesByIDPreferHighest(ctx context.Context, + chanID uint64) ( + *models.ChannelEdgeInfo, *models.ChannelEdgePolicy, + *models.ChannelEdgePolicy, error) { + + return c.FetchChannelEdgesByID(ctx, lnwire.GossipVersion1, chanID) +} + +// FetchChannelEdgesByOutpointPreferHighest looks up the channel by funding +// outpoint. The KV store only supports gossip v1, so this simply delegates to +// the versioned fetch. +// +// NOTE: part of the Store interface. +func (c *KVStore) FetchChannelEdgesByOutpointPreferHighest( + ctx context.Context, op *wire.OutPoint) ( + *models.ChannelEdgeInfo, *models.ChannelEdgePolicy, + *models.ChannelEdgePolicy, error) { + + return c.FetchChannelEdgesByOutpoint( + ctx, lnwire.GossipVersion1, op, + ) +} + +// GetVersionsBySCID returns the gossip versions for which a channel with the +// given SCID exists. The KV store only supports gossip v1, so at most one +// version is returned. +// +// NOTE: part of the Store interface. +func (c *KVStore) GetVersionsBySCID(ctx context.Context, + chanID uint64) ([]lnwire.GossipVersion, error) { + + _, _, _, err := c.FetchChannelEdgesByID( + ctx, lnwire.GossipVersion1, chanID, + ) + switch { + case errors.Is(err, ErrEdgeNotFound): + return nil, nil + + case errors.Is(err, ErrZombieEdge): + return nil, nil + + case err != nil: + return nil, err + + default: + return []lnwire.GossipVersion{lnwire.GossipVersion1}, nil + } +} + +// GetVersionsByOutpoint returns the gossip versions for which a channel with +// the given funding outpoint exists. The KV store only supports gossip v1, so +// at most one version is returned. +// +// NOTE: part of the Store interface. +func (c *KVStore) GetVersionsByOutpoint(ctx context.Context, + op *wire.OutPoint) ([]lnwire.GossipVersion, error) { + + _, _, _, err := c.FetchChannelEdgesByOutpoint( + ctx, lnwire.GossipVersion1, op, + ) + switch { + case errors.Is(err, ErrEdgeNotFound): + return nil, nil + + case errors.Is(err, ErrZombieEdge): + return nil, nil + + case err != nil: + return nil, err + + default: + return []lnwire.GossipVersion{lnwire.GossipVersion1}, nil + } +} + // given public key is seen as a public node in the graph from the graph's // source node's point of view. func (c *KVStore) IsPublicNode(_ context.Context, v lnwire.GossipVersion, diff --git a/graph/db/options.go b/graph/db/options.go index df49fd7e0f..02e8895782 100644 --- a/graph/db/options.go +++ b/graph/db/options.go @@ -1,6 +1,13 @@ package graphdb -import "time" +import ( + "fmt" + "iter" + "time" + + "github.com/lightningnetwork/lnd/fn/v2" + "github.com/lightningnetwork/lnd/lnwire" +) const ( // DefaultRejectCacheSize is the default number of rejectCacheEntries to @@ -39,6 +46,181 @@ type iterConfig struct { iterPublicNodes bool } +// ChanUpdateRange describes a range for channel updates. Only one of the time +// or height ranges should be set depending on the gossip version. +type ChanUpdateRange struct { + // StartTime is the inclusive lower time bound (v1 gossip only). + StartTime fn.Option[time.Time] + + // EndTime is the exclusive upper time bound (v1 gossip only). + EndTime fn.Option[time.Time] + + // StartHeight is the inclusive lower block-height bound (v2 gossip + // only). + StartHeight fn.Option[uint32] + + // EndHeight is the exclusive upper block-height bound (v2 gossip + // only). + EndHeight fn.Option[uint32] +} + +// validateForVersion checks that the range fields are consistent with the +// given gossip version: v1 requires time bounds, v2 requires block-height +// bounds, and mixing the two is rejected. +func (r ChanUpdateRange) validateForVersion(v lnwire.GossipVersion) error { + hasStartTime := r.StartTime.IsSome() + hasEndTime := r.EndTime.IsSome() + hasTimeRange := hasStartTime || hasEndTime + + hasStartHeight := r.StartHeight.IsSome() + hasEndHeight := r.EndHeight.IsSome() + hasBlockRange := hasStartHeight || hasEndHeight + + if hasTimeRange && hasBlockRange { + return fmt.Errorf("chan update range has both " + + "time and block ranges") + } + + switch v { + case lnwire.GossipVersion1: + if hasBlockRange { + return fmt.Errorf("v1 chan update range must use time") + } + if !hasTimeRange { + return fmt.Errorf("v1 chan update range missing time") + } + if !hasStartTime || !hasEndTime { + return fmt.Errorf("v1 chan update range " + + "missing time bounds") + } + + start := r.StartTime.UnwrapOr(time.Time{}) + end := r.EndTime.UnwrapOr(time.Time{}) + if start.After(end) { + return fmt.Errorf("v1 chan update range: " + + "start time after end time") + } + + case lnwire.GossipVersion2: + if hasTimeRange { + return fmt.Errorf("v2 chan update range " + + "must use blocks") + } + if !hasBlockRange { + return fmt.Errorf("v2 chan update range " + + "missing block range") + } + if !hasStartHeight || !hasEndHeight { + return fmt.Errorf("v2 chan update range " + + "missing block bounds") + } + + start := r.StartHeight.UnwrapOr(0) + end := r.EndHeight.UnwrapOr(0) + if start > end { + return fmt.Errorf("v2 chan update range: " + + "start height after end height") + } + + default: + return fmt.Errorf("unknown gossip version: %v", v) + } + + return nil +} + +// chanUpdateRangeErrIter returns an iterator that yields a single error. +func chanUpdateRangeErrIter(err error) iter.Seq2[ChannelEdge, error] { + return func(yield func(ChannelEdge, error) bool) { + _ = yield(ChannelEdge{}, err) + } +} + +// NodeUpdateRange describes a range for node updates. Only one of the time or +// height ranges should be set depending on the gossip version. +type NodeUpdateRange struct { + // StartTime is the inclusive lower time bound (v1 gossip only). + StartTime fn.Option[time.Time] + + // EndTime is the inclusive upper time bound (v1 gossip only). + EndTime fn.Option[time.Time] + + // StartHeight is the inclusive lower block-height bound (v2 gossip + // only). + StartHeight fn.Option[uint32] + + // EndHeight is the inclusive upper block-height bound (v2 gossip + // only). + EndHeight fn.Option[uint32] +} + +// validateForVersion checks that the range fields are consistent with the +// given gossip version: v1 requires time bounds, v2 requires block-height +// bounds, and mixing the two is rejected. +func (r NodeUpdateRange) validateForVersion(v lnwire.GossipVersion) error { + hasStartTime := r.StartTime.IsSome() + hasEndTime := r.EndTime.IsSome() + hasTimeRange := hasStartTime || hasEndTime + + hasStartHeight := r.StartHeight.IsSome() + hasEndHeight := r.EndHeight.IsSome() + hasBlockRange := hasStartHeight || hasEndHeight + + if hasTimeRange && hasBlockRange { + return fmt.Errorf("node update range has both " + + "time and block ranges") + } + + switch v { + case lnwire.GossipVersion1: + if hasBlockRange { + return fmt.Errorf("v1 node update range " + + "must use time") + } + if !hasTimeRange { + return fmt.Errorf("v1 node update range " + + "missing time") + } + if !hasStartTime || !hasEndTime { + return fmt.Errorf("v1 node update range " + + "missing time bounds") + } + + start := r.StartTime.UnwrapOr(time.Time{}) + end := r.EndTime.UnwrapOr(time.Time{}) + if start.After(end) { + return fmt.Errorf("v1 node update range: " + + "start time after end time") + } + + case lnwire.GossipVersion2: + if hasTimeRange { + return fmt.Errorf("v2 node update range " + + "must use height") + } + if !hasBlockRange { + return fmt.Errorf("v2 node update range " + + "missing height") + } + if !hasStartHeight || !hasEndHeight { + return fmt.Errorf("v2 node update range " + + "missing height bounds") + } + + start := r.StartHeight.UnwrapOr(0) + end := r.EndHeight.UnwrapOr(0) + if start > end { + return fmt.Errorf("v2 node update range: " + + "start height after end height") + } + + default: + return fmt.Errorf("unknown gossip version: %d", v) + } + + return nil +} + // defaultIteratorConfig returns the default configuration. func defaultIteratorConfig() *iterConfig { return &iterConfig{ diff --git a/graph/db/sql_store.go b/graph/db/sql_store.go index 0e3907599e..47b253d8d9 100644 --- a/graph/db/sql_store.go +++ b/graph/db/sql_store.go @@ -52,7 +52,9 @@ type SQLQueries interface { GetNodesByIDs(ctx context.Context, ids []int64) ([]sqlc.GraphNode, error) GetNodeIDByPubKey(ctx context.Context, arg sqlc.GetNodeIDByPubKeyParams) (int64, error) GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.GraphNode, error) + GetNodesByBlockHeightRange(ctx context.Context, arg sqlc.GetNodesByBlockHeightRangeParams) ([]sqlc.GraphNode, error) ListNodesPaginated(ctx context.Context, arg sqlc.ListNodesPaginatedParams) ([]sqlc.GraphNode, error) + ListPreferredNodesPaginated(ctx context.Context, arg sqlc.ListPreferredNodesPaginatedParams) ([]sqlc.ListPreferredNodesPaginatedRow, error) ListNodeIDsAndPubKeys(ctx context.Context, arg sqlc.ListNodeIDsAndPubKeysParams) ([]sqlc.ListNodeIDsAndPubKeysRow, error) IsPublicV1Node(ctx context.Context, pubKey []byte) (bool, error) IsPublicV2Node(ctx context.Context, pubKey []byte) (bool, error) @@ -101,12 +103,15 @@ type SQLQueries interface { GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error) HighestSCID(ctx context.Context, version int16) ([]byte, error) ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error) + ListPreferredDirectedChannelsPaginated(ctx context.Context, arg sqlc.ListPreferredDirectedChannelsPaginatedParams) ([]sqlc.ListPreferredDirectedChannelsPaginatedRow, error) ListChannelsForNodeIDs(ctx context.Context, arg sqlc.ListChannelsForNodeIDsParams) ([]sqlc.ListChannelsForNodeIDsRow, error) ListChannelsWithPoliciesPaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesPaginatedParams) ([]sqlc.ListChannelsWithPoliciesPaginatedRow, error) + ListPreferredChannelsWithPoliciesPaginated(ctx context.Context, arg sqlc.ListPreferredChannelsWithPoliciesPaginatedParams) ([]sqlc.ListPreferredChannelsWithPoliciesPaginatedRow, error) ListChannelsWithPoliciesForCachePaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesForCachePaginatedParams) ([]sqlc.ListChannelsWithPoliciesForCachePaginatedRow, error) ListChannelsPaginated(ctx context.Context, arg sqlc.ListChannelsPaginatedParams) ([]sqlc.ListChannelsPaginatedRow, error) ListChannelsPaginatedV2(ctx context.Context, arg sqlc.ListChannelsPaginatedV2Params) ([]sqlc.ListChannelsPaginatedV2Row, error) GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error) + GetChannelsByPolicyBlockRange(ctx context.Context, arg sqlc.GetChannelsByPolicyBlockRangeParams) ([]sqlc.GetChannelsByPolicyBlockRangeRow, error) GetChannelByOutpointWithPolicies(ctx context.Context, arg sqlc.GetChannelByOutpointWithPoliciesParams) (sqlc.GetChannelByOutpointWithPoliciesRow, error) GetPublicV1ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV1ChannelsBySCIDParams) ([]sqlc.GraphChannel, error) GetPublicV2ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV2ChannelsBySCIDParams) ([]sqlc.GraphChannel, error) @@ -606,14 +611,14 @@ func (s *SQLStore) SetSourceNode(ctx context.Context, }, sqldb.NoOpReset) } -// NodeUpdatesInHorizon returns all the known lightning node which have an -// update timestamp within the passed range. This method can be used by two -// nodes to quickly determine if they have the same set of up to date node -// announcements. +// NodeUpdatesInHorizon returns all the known lightning nodes which have +// updates within the passed range for the given gossip version. This method can +// be used by two nodes to quickly determine if they have the same set of +// up-to-date node announcements. // // NOTE: This is part of the Store interface. func (s *SQLStore) NodeUpdatesInHorizon(ctx context.Context, - startTime, endTime time.Time, + v lnwire.GossipVersion, r NodeUpdateRange, opts ...IteratorOption) iter.Seq2[*models.Node, error] { cfg := defaultIteratorConfig() @@ -621,27 +626,33 @@ func (s *SQLStore) NodeUpdatesInHorizon(ctx context.Context, opt(cfg) } + batchSize := cfg.nodeUpdateIterBatchSize + return func(yield func(*models.Node, error) bool) { var ( lastUpdateTime sql.NullInt64 + lastBlock sql.NullInt64 lastPubKey = make([]byte, 33) hasMore = true ) - // Each iteration, we'll read a batch amount of nodes, yield - // them, then decide is we have more or not. - for hasMore { - var batch []*models.Node + if err := r.validateForVersion(v); err != nil { + yield(nil, err) + return + } - //nolint:ll - err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error { - //nolint:ll - params := sqlc.GetNodesByLastUpdateRangeParams{ + queryV1 := func(db SQLQueries) ([]sqlc.GraphNode, error) { + return db.GetNodesByLastUpdateRange( + ctx, sqlc.GetNodesByLastUpdateRangeParams{ StartTime: sqldb.SQLInt64( - startTime.Unix(), + r.StartTime.UnwrapOr( + time.Time{}, + ).Unix(), ), EndTime: sqldb.SQLInt64( - endTime.Unix(), + r.EndTime.UnwrapOr( + time.Time{}, + ).Unix(), ), LastUpdate: lastUpdateTime, LastPubKey: lastPubKey, @@ -652,44 +663,106 @@ func (s *SQLStore) NodeUpdatesInHorizon(ctx context.Context, MaxResults: sqldb.SQLInt32( cfg.nodeUpdateIterBatchSize, ), + }, + ) + } + + queryV2 := func(db SQLQueries) ([]sqlc.GraphNode, error) { + startHeight := int64(r.StartHeight.UnwrapOr(0)) + endHeight := int64(r.EndHeight.UnwrapOr(0)) + + return db.GetNodesByBlockHeightRange( + ctx, sqlc.GetNodesByBlockHeightRangeParams{ + Version: int16(v), + StartHeight: sqldb.SQLInt64( + startHeight, + ), + EndHeight: sqldb.SQLInt64( + endHeight, + ), + LastBlockHeight: lastBlock, + LastPubKey: lastPubKey, + OnlyPublic: sql.NullBool{ + Bool: cfg.iterPublicNodes, + Valid: true, + }, + MaxResults: sqldb.SQLInt32( + cfg.nodeUpdateIterBatchSize, + ), + }, + ) + } + + // queryNodes fetches the next batch of nodes in the + // horizon range, dispatching to the version-appropriate + // query. + queryNodes := func(db SQLQueries) ([]sqlc.GraphNode, error) { + switch v { + case gossipV1: + return queryV1(db) + + case gossipV2: + return queryV2(db) + + default: + return nil, fmt.Errorf("unknown gossip "+ + "version: %v", v) + } + } + + // processNode is called for each node in a batch to + // accumulate results and update pagination cursors. + processNode := func(_ int64, + node *models.Node, batch *[]*models.Node) error { + + *batch = append(*batch, node) + + switch v { + case gossipV1: + lastUpdateTime = sql.NullInt64{ + Int64: node.LastUpdate.Unix(), + Valid: true, } - rows, err := db.GetNodesByLastUpdateRange( - ctx, params, - ) - if err != nil { - return err + case gossipV2: + lastBlock = sql.NullInt64{ + Int64: int64(node.LastBlockHeight), + Valid: true, } + } + lastPubKey = node.PubKeyBytes[:] - hasMore = len(rows) == cfg.nodeUpdateIterBatchSize + return nil + } - err = forEachNodeInBatch( - ctx, s.cfg.QueryCfg, db, rows, - func(_ int64, node *models.Node) error { - batch = append(batch, node) + // Each iteration, we'll read a batch amount of nodes, + // yield them, then decide if we have more or not. + for hasMore { + var batch []*models.Node - // Update pagination cursors - // based on the last processed - // node. - lastUpdateTime = sql.NullInt64{ - Int64: node.LastUpdate. - Unix(), - Valid: true, - } - lastPubKey = node.PubKeyBytes[:] + err := s.db.ExecTx( + ctx, sqldb.ReadTxOpt(), + func(db SQLQueries) error { + rows, err := queryNodes(db) + if err != nil { + return err + } - return nil - }, - ) - if err != nil { - return fmt.Errorf("unable to build "+ - "nodes: %w", err) - } + hasMore = len(rows) == batchSize - return nil - }, func() { - batch = []*models.Node{} - }) + return forEachNodeInBatch( + ctx, s.cfg.QueryCfg, db, + rows, func(id int64, + n *models.Node) error { + return processNode( + id, n, &batch, + ) + }, + ) + }, func() { + batch = nil + }, + ) if err != nil { log.Errorf("NodeUpdatesInHorizon batch "+ "error: %v", err) @@ -705,7 +778,8 @@ func (s *SQLStore) NodeUpdatesInHorizon(ctx context.Context, } } - // If the batch didn't yield anything, then we're done. + // If the batch didn't yield anything, then + // we're done. if len(batch) == 0 { break } @@ -985,17 +1059,12 @@ func (s *SQLStore) ForEachSourceNodeChannel(ctx context.Context, // early. // // NOTE: part of the Store interface. -func (s *SQLStore) ForEachNode(ctx context.Context, v lnwire.GossipVersion, +func (s *SQLStore) ForEachNode(ctx context.Context, cb func(node *models.Node) error, reset func()) error { return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error { - return forEachNodePaginated( - ctx, s.cfg.QueryCfg, db, - v, func(_ context.Context, _ int64, - node *models.Node) error { - - return cb(node) - }, + return forEachPreferredNodePaginated( + ctx, s.cfg.QueryCfg, db, cb, ) }, reset) } @@ -1017,6 +1086,20 @@ func (s *SQLStore) ForEachNodeDirectedChannel(ctx context.Context, }, reset) } +// ForEachNodeDirectedChannelPreferHighest iterates through all channels of a +// node across gossip versions, preferring v2 channels over v1 when both are +// present for the same SCID. +func (s *SQLStore) ForEachNodeDirectedChannelPreferHighest( + ctx context.Context, nodePub route.Vertex, + cb func(channel *DirectedChannel) error, reset func()) error { + + return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error { + return forEachPreferredNodeDirectedChannel( + ctx, s.cfg.QueryCfg, db, nodePub, cb, + ) + }, reset) +} + // ForEachNodeCacheable iterates through all the stored vertices/nodes in the // graph, executing the passed callback with each node encountered. If the // callback returns an error, then the transaction is aborted and the iteration @@ -1092,6 +1175,26 @@ func extractMaxUpdateTime( } } +// extractMaxBlockHeight returns the maximum of the two policy block heights. +// This is used for pagination cursor tracking in v2 gossip queries. +func extractMaxBlockHeight( + row sqlc.GetChannelsByPolicyLastUpdateRangeRow) int64 { + + switch { + case row.Policy1BlockHeight.Valid && + row.Policy2BlockHeight.Valid: + + return max(row.Policy1BlockHeight.Int64, + row.Policy2BlockHeight.Int64) + case row.Policy1BlockHeight.Valid: + return row.Policy1BlockHeight.Int64 + case row.Policy2BlockHeight.Valid: + return row.Policy2BlockHeight.Int64 + default: + return 0 + } +} + // buildChannelFromRow constructs a ChannelEdge from a database row. // This includes building the nodes, channel info, and policies. func (s *SQLStore) buildChannelFromRow(ctx context.Context, db SQLQueries, @@ -1172,128 +1275,216 @@ func (s *SQLStore) updateChanCacheBatch(v lnwire.GossipVersion, // 6. Repeat with updated pagination cursor until no more results // // NOTE: This is part of the Store interface. +// +//nolint:funlen func (s *SQLStore) ChanUpdatesInHorizon(ctx context.Context, - startTime, endTime time.Time, + v lnwire.GossipVersion, r ChanUpdateRange, opts ...IteratorOption) iter.Seq2[ChannelEdge, error] { + if err := r.validateForVersion(v); err != nil { + return chanUpdateRangeErrIter(err) + } + // Apply options. cfg := defaultIteratorConfig() for _, opt := range opts { opt(cfg) } + batchSize := cfg.chanUpdateIterBatchSize + return func(yield func(ChannelEdge, error) bool) { var ( - edgesSeen = make(map[uint64]struct{}) - edgesToCache = make(map[uint64]ChannelEdge) - hits int - total int - lastUpdateTime sql.NullInt64 - lastID sql.NullInt64 - hasMore = true + edgesSeen = make(map[uint64]struct{}) + edgesToCache = make(map[uint64]ChannelEdge) + hits int + total int + lastUpdateTime sql.NullInt64 + lastBlockHeight sql.NullInt64 + lastID sql.NullInt64 + hasMore = true ) - // Each iteration, we'll read a batch amount of channel updates - // (consulting the cache along the way), yield them, then loop - // back to decide if we have any more updates to read out. + queryV1 := func(db SQLQueries) ( + []sqlc.GetChannelsByPolicyLastUpdateRangeRow, error) { + + return db.GetChannelsByPolicyLastUpdateRange( + ctx, + sqlc.GetChannelsByPolicyLastUpdateRangeParams{ + StartTime: sqldb.SQLInt64( + r.StartTime.UnwrapOr( + time.Time{}, + ).Unix(), + ), + EndTime: sqldb.SQLInt64( + r.EndTime.UnwrapOr( + time.Time{}, + ).Unix(), + ), + LastUpdateTime: lastUpdateTime, + LastID: lastID, + MaxResults: sql.NullInt32{ + Int32: int32(batchSize), + Valid: true, + }, + }, + ) + } + + type updateRow = sqlc.GetChannelsByPolicyLastUpdateRangeRow + queryV2 := func(db SQLQueries) ( + []updateRow, error) { + + startHeight := int64(r.StartHeight.UnwrapOr(0)) + endHeight := int64(r.EndHeight.UnwrapOr(0)) + + blockRows, err := db.GetChannelsByPolicyBlockRange( + ctx, + sqlc.GetChannelsByPolicyBlockRangeParams{ + Version: int16(v), + StartHeight: sqldb.SQLInt64( + startHeight, + ), + EndHeight: sqldb.SQLInt64( + endHeight, + ), + LastBlockHeight: lastBlockHeight, + LastID: lastID, + MaxResults: sql.NullInt32{ + Int32: int32(batchSize), + Valid: true, + }, + }, + ) + if err != nil { + return nil, err + } + + rows := make([]updateRow, 0, len(blockRows)) + for _, br := range blockRows { + rows = append(rows, updateRow(br)) + } + + return rows, nil + } + + // queryChannels fetches the next batch of channels whose + // policies fall within the horizon range. + queryChannels := func(db SQLQueries) ( + []sqlc.GetChannelsByPolicyLastUpdateRangeRow, error) { + + switch v { + case gossipV1: + return queryV1(db) + + case gossipV2: + return queryV2(db) + + default: + return nil, fmt.Errorf("unknown gossip "+ + "version: %v", v) + } + } + + // processRow handles a single channel row: updates + // pagination cursors, checks the seen set and cache, and + // builds the channel edge if needed. + processRow := func(ctx context.Context, db SQLQueries, + row sqlc.GetChannelsByPolicyLastUpdateRangeRow, + batch *[]ChannelEdge) error { + + switch v { + case gossipV1: + lastUpdateTime = sql.NullInt64{ + Int64: extractMaxUpdateTime(row), + Valid: true, + } + case gossipV2: + lastBlockHeight = sql.NullInt64{ + Int64: extractMaxBlockHeight(row), + Valid: true, + } + } + lastID = sql.NullInt64{ + Int64: row.GraphChannel.ID, + Valid: true, + } + + chanIDInt := byteOrder.Uint64(row.GraphChannel.Scid) + + if _, ok := edgesSeen[chanIDInt]; ok { + return nil + } + + // Check cache (we already hold shared read + // lock). + channel, ok := s.chanCache.get(v, chanIDInt) + if ok { + hits++ + total++ + edgesSeen[chanIDInt] = struct{}{} + *batch = append(*batch, channel) + + return nil + } + + chanEdge, err := s.buildChannelFromRow( + ctx, db, row, + ) + if err != nil { + return err + } + + edgesSeen[chanIDInt] = struct{}{} + edgesToCache[chanIDInt] = chanEdge + *batch = append(*batch, chanEdge) + total++ + + return nil + } + + // Each iteration, we'll read a batch amount of channel + // updates (consulting the cache along the way), yield + // them, then loop back to decide if we have any more + // updates to read out. for hasMore { var batch []ChannelEdge - // Acquire read lock before starting transaction to - // ensure consistent lock ordering (cacheMu -> DB) and - // prevent deadlock with write operations. + // Acquire read lock before starting transaction + // to ensure consistent lock ordering + // (cacheMu -> DB) and prevent deadlock with + // write operations. s.cacheMu.RLock() err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error { - //nolint:ll - params := sqlc.GetChannelsByPolicyLastUpdateRangeParams{ - Version: int16(lnwire.GossipVersion1), - StartTime: sqldb.SQLInt64( - startTime.Unix(), - ), - EndTime: sqldb.SQLInt64( - endTime.Unix(), - ), - LastUpdateTime: lastUpdateTime, - LastID: lastID, - MaxResults: sql.NullInt32{ - Int32: int32( - cfg.chanUpdateIterBatchSize, - ), - Valid: true, - }, - } - //nolint:ll - rows, err := db.GetChannelsByPolicyLastUpdateRange( - ctx, params, - ) + rows, err := queryChannels(db) if err != nil { return err } - //nolint:ll - hasMore = len(rows) == cfg.chanUpdateIterBatchSize + hasMore = len(rows) == batchSize - //nolint:ll for _, row := range rows { - lastUpdateTime = sql.NullInt64{ - Int64: extractMaxUpdateTime(row), - Valid: true, - } - lastID = sql.NullInt64{ - Int64: row.GraphChannel.ID, - Valid: true, - } - - // Skip if we've already - // processed this channel. - chanIDInt := byteOrder.Uint64( - row.GraphChannel.Scid, - ) - _, ok := edgesSeen[chanIDInt] - if ok { - continue - } - - // Check cache (we already hold - // shared read lock). - channel, ok := s.chanCache.get( - lnwire.GossipVersion1, - chanIDInt, - ) - if ok { - hits++ - total++ - edgesSeen[chanIDInt] = struct{}{} - batch = append(batch, channel) - - continue - } - - chanEdge, err := s.buildChannelFromRow( - ctx, db, row, + err := processRow( + ctx, db, row, &batch, ) if err != nil { return err } - - edgesSeen[chanIDInt] = struct{}{} - edgesToCache[chanIDInt] = chanEdge - - batch = append(batch, chanEdge) - - total++ } return nil }, func() { batch = nil - edgesSeen = make(map[uint64]struct{}) + edgesSeen = make( + map[uint64]struct{}, + ) edgesToCache = make( map[uint64]ChannelEdge, ) - }) + }, + ) // Release read lock after transaction completes. s.cacheMu.RUnlock() @@ -1313,11 +1504,10 @@ func (s *SQLStore) ChanUpdatesInHorizon(ctx context.Context, } } - // Update cache after successful batch yield, setting - // the cache lock only once for the entire batch. - s.updateChanCacheBatch( - lnwire.GossipVersion1, edgesToCache, - ) + // Update cache after successful batch yield, + // setting the cache lock only once for the + // entire batch. + s.updateChanCacheBatch(v, edgesToCache) edgesToCache = make(map[uint64]ChannelEdge) // If the batch didn't yield anything, then we're done. @@ -1327,12 +1517,12 @@ func (s *SQLStore) ChanUpdatesInHorizon(ctx context.Context, } if total > 0 { - log.Debugf("ChanUpdatesInHorizon hit percentage: "+ - "%.2f (%d/%d)", + log.Debugf("ChanUpdatesInHorizon(v%d) hit "+ + "percentage: %.2f (%d/%d)", v, float64(hits)*100/float64(total), hits, total) } else { - log.Debugf("ChanUpdatesInHorizon returned no edges "+ - "in horizon (%s, %s)", startTime, endTime) + log.Debugf("ChanUpdatesInHorizon(v%d) returned "+ + "no edges in horizon", v) } } } @@ -1645,16 +1835,14 @@ func (s *SQLStore) ForEachChannelCacheable(ctx context.Context, // // NOTE: part of the Store interface. func (s *SQLStore) ForEachChannel(ctx context.Context, - v lnwire.GossipVersion, cb func(*models.ChannelEdgeInfo, - *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error, + cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy, + *models.ChannelEdgePolicy) error, reset func()) error { - if !isKnownGossipVersion(v) { - return fmt.Errorf("unsupported gossip version: %d", v) - } - return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error { - return forEachChannelWithPolicies(ctx, db, s.cfg, v, cb) + return forEachPreferredChannelWithPolicies( + ctx, db, s.cfg, cb, + ) }, reset) } @@ -2296,6 +2484,162 @@ func (s *SQLStore) FetchChannelEdgesByOutpoint(ctx context.Context, return edge, policy1, policy2, nil } +// gossipVersionsDescending lists gossip versions from highest to lowest for +// prefer-highest iteration. +var gossipVersionsDescending = []lnwire.GossipVersion{gossipV2, gossipV1} + +// FetchChannelEdgesByIDPreferHighest tries each known gossip version from +// highest to lowest and returns the first result that has at least one policy. +// If no version has policies, the highest version found is returned. This +// prevents a v2 channel with no policies from hiding a v1 channel that has +// valid policy data. +// +// NOTE: part of the Store interface. +func (s *SQLStore) FetchChannelEdgesByIDPreferHighest(ctx context.Context, + chanID uint64) ( + *models.ChannelEdgeInfo, *models.ChannelEdgePolicy, + *models.ChannelEdgePolicy, error) { + + var ( + bestInfo *models.ChannelEdgeInfo + bestP1 *models.ChannelEdgePolicy + bestP2 *models.ChannelEdgePolicy + ) + + for _, v := range gossipVersionsDescending { + info, p1, p2, err := s.FetchChannelEdgesByID(ctx, v, chanID) + if errors.Is(err, ErrEdgeNotFound) || + errors.Is(err, ErrZombieEdge) { + + continue + } + if err != nil { + return nil, nil, nil, err + } + + // If this version has policies, return immediately. + if p1 != nil || p2 != nil { + return info, p1, p2, nil + } + + // Otherwise, remember the highest version as a + // fallback in case no version has policies. + if bestInfo == nil { + bestInfo = info + bestP1 = p1 + bestP2 = p2 + } + } + + if bestInfo != nil { + return bestInfo, bestP1, bestP2, nil + } + + return nil, nil, nil, ErrEdgeNotFound +} + +// FetchChannelEdgesByOutpointPreferHighest tries each known gossip version +// from highest to lowest and returns the first result that has at least one +// policy. If no version has policies, the highest version found is returned. +// This prevents a v2 channel with no policies from hiding a v1 channel that +// has valid policy data. +// +// NOTE: part of the Store interface. +func (s *SQLStore) FetchChannelEdgesByOutpointPreferHighest( + ctx context.Context, op *wire.OutPoint) ( + *models.ChannelEdgeInfo, *models.ChannelEdgePolicy, + *models.ChannelEdgePolicy, error) { + + var ( + bestInfo *models.ChannelEdgeInfo + bestP1 *models.ChannelEdgePolicy + bestP2 *models.ChannelEdgePolicy + ) + + for _, v := range gossipVersionsDescending { + info, p1, p2, err := s.FetchChannelEdgesByOutpoint( + ctx, v, op, + ) + if errors.Is(err, ErrEdgeNotFound) || + errors.Is(err, ErrZombieEdge) { + + continue + } + if err != nil { + return nil, nil, nil, err + } + + // If this version has policies, return immediately. + if p1 != nil || p2 != nil { + return info, p1, p2, nil + } + + // Otherwise, remember the highest version as a + // fallback in case no version has policies. + if bestInfo == nil { + bestInfo = info + bestP1 = p1 + bestP2 = p2 + } + } + + if bestInfo != nil { + return bestInfo, bestP1, bestP2, nil + } + + return nil, nil, nil, ErrEdgeNotFound +} + +// GetVersionsBySCID returns the gossip versions for which a channel with the +// given SCID exists in the database. +// +// NOTE: part of the Store interface. +func (s *SQLStore) GetVersionsBySCID(ctx context.Context, + chanID uint64) ([]lnwire.GossipVersion, error) { + + var versions []lnwire.GossipVersion + for _, v := range []lnwire.GossipVersion{gossipV1, gossipV2} { + _, _, _, err := s.FetchChannelEdgesByID(ctx, v, chanID) + if errors.Is(err, ErrEdgeNotFound) || + errors.Is(err, ErrZombieEdge) { + + continue + } + if err != nil { + return nil, err + } + + versions = append(versions, v) + } + + return versions, nil +} + +// GetVersionsByOutpoint returns the gossip versions for which a channel with +// the given funding outpoint exists in the database. +// +// NOTE: part of the Store interface. +func (s *SQLStore) GetVersionsByOutpoint(ctx context.Context, + op *wire.OutPoint) ([]lnwire.GossipVersion, error) { + + var versions []lnwire.GossipVersion + for _, v := range []lnwire.GossipVersion{gossipV1, gossipV2} { + _, _, _, err := s.FetchChannelEdgesByOutpoint(ctx, v, op) + if errors.Is(err, ErrEdgeNotFound) || + errors.Is(err, ErrZombieEdge) { + + continue + } + if err != nil { + return nil, err + } + + versions = append(versions, v) + } + + return versions, nil +} + // HasV1ChannelEdge returns true if the database knows of a channel edge // with the passed channel ID, and false otherwise. If an edge with that ID // is found within the graph, then two time stamps representing the last time @@ -2740,6 +3084,7 @@ func (s *SQLStore) forEachChanWithPoliciesInSCIDList(ctx context.Context, // // NOTE: part of the Store interface. func (s *SQLStore) FilterKnownChanIDs(ctx context.Context, + v lnwire.GossipVersion, chansInfo []ChannelUpdateInfo) ([]uint64, []ChannelUpdateInfo, error) { var ( @@ -2769,7 +3114,7 @@ func (s *SQLStore) FilterKnownChanIDs(ctx context.Context, return nil } - err := s.forEachChanInSCIDList(ctx, db, cb, chansInfo) + err := s.forEachChanInSCIDList(ctx, db, v, cb, chansInfo) if err != nil { return fmt.Errorf("unable to iterate through "+ "channels: %w", err) @@ -2788,7 +3133,7 @@ func (s *SQLStore) FilterKnownChanIDs(ctx context.Context, isZombie, err := db.IsZombieChannel( ctx, sqlc.IsZombieChannelParams{ Scid: channelIDToBytes(channelID), - Version: int16(lnwire.GossipVersion1), + Version: int16(v), }, ) if err != nil { @@ -2827,6 +3172,7 @@ func (s *SQLStore) FilterKnownChanIDs(ctx context.Context, // ChannelUpdateInfo slice. The callback function is called for each channel // that is found. func (s *SQLStore) forEachChanInSCIDList(ctx context.Context, db SQLQueries, + v lnwire.GossipVersion, cb func(ctx context.Context, channel sqlc.GraphChannel) error, chansInfo []ChannelUpdateInfo) error { @@ -2835,7 +3181,7 @@ func (s *SQLStore) forEachChanInSCIDList(ctx context.Context, db SQLQueries, return db.GetChannelsBySCIDs( ctx, sqlc.GetChannelsBySCIDsParams{ - Version: int16(lnwire.GossipVersion1), + Version: int16(v), Scids: scids, }, ) @@ -3490,6 +3836,169 @@ func (s *sqlNodeTraverser) ForEachNodeDirectedChannel( ) } +// optionalNodeID looks up a versioned node ID and returns -1 when that version +// of the node is absent. +func optionalNodeID(ctx context.Context, db SQLQueries, v lnwire.GossipVersion, + nodePub route.Vertex) (int64, error) { + + id, err := db.GetNodeIDByPubKey( + ctx, sqlc.GetNodeIDByPubKeyParams{ + Version: int16(v), + PubKey: nodePub[:], + }, + ) + if errors.Is(err, sql.ErrNoRows) { + return -1, nil + } + if err != nil { + return 0, fmt.Errorf("unable to fetch node(%x): %w", + nodePub[:], err) + } + + return id, nil +} + +// forEachPreferredNodeDirectedChannel iterates through all channels of a given +// node across gossip versions, preferring v2 channels over v1 when both +// versions advertise the same SCID. +func forEachPreferredNodeDirectedChannel(ctx context.Context, + cfg *sqldb.QueryConfig, db SQLQueries, nodePub route.Vertex, + cb func(channel *DirectedChannel) error) error { + + nodeIDV2, err := optionalNodeID(ctx, db, gossipV2, nodePub) + if err != nil { + return err + } + + nodeIDV1, err := optionalNodeID(ctx, db, gossipV1, nodePub) + if err != nil { + return err + } + + if nodeIDV1 == -1 && nodeIDV2 == -1 { + return nil + } + + featuresByVersion := map[lnwire.GossipVersion]*lnwire.FeatureVector{ + gossipV1: lnwire.EmptyFeatureVector(), + gossipV2: lnwire.EmptyFeatureVector(), + } + + if nodeIDV1 != -1 { + featuresByVersion[gossipV1], err = getNodeFeatures( + ctx, db, nodeIDV1, + ) + if err != nil { + return fmt.Errorf("unable to fetch v1 node features: %w", + err) + } + } + + if nodeIDV2 != -1 { + featuresByVersion[gossipV2], err = getNodeFeatures( + ctx, db, nodeIDV2, + ) + if err != nil { + return fmt.Errorf("unable to fetch v2 node features: %w", + err) + } + } + + toNodeCallback := func() route.Vertex { + return nodePub + } + + pageQueryFunc := func(ctx context.Context, cursor []byte, + limit int32) ([]sqlc.ListPreferredDirectedChannelsPaginatedRow, + error) { + + return db.ListPreferredDirectedChannelsPaginated( + ctx, sqlc.ListPreferredDirectedChannelsPaginatedParams{ + NodeIDV2: nodeIDV2, + NodeIDV1: nodeIDV1, + Scid: cursor, + PageLimit: limit, + }, + ) + } + + extractCursor := func( + row sqlc.ListPreferredDirectedChannelsPaginatedRow) []byte { + + return row.GraphChannel.Scid + } + + processItem := func(_ context.Context, + row sqlc.ListPreferredDirectedChannelsPaginatedRow) error { + + node1, node2, err := buildNodeVertices( + row.Node1Pubkey, row.Node2Pubkey, + ) + if err != nil { + return fmt.Errorf("unable to build node vertices: %w", + err) + } + + edge := buildCacheableChannelInfo( + row.GraphChannel.Scid, row.GraphChannel.Capacity.Int64, + node1, node2, + ) + + dbPol1, dbPol2, err := extractChannelPolicies(row) + if err != nil { + return err + } + + p1, p2, err := buildCachedChanPolicies( + dbPol1, dbPol2, edge.ChannelID, node1, node2, + ) + if err != nil { + return err + } + + outPolicy, inPolicy := p1, p2 + if p1 != nil && node2 == nodePub { + outPolicy, inPolicy = p2, p1 + } else if p2 != nil && node1 != nodePub { + outPolicy, inPolicy = p2, p1 + } + + var cachedInPolicy *models.CachedEdgePolicy + if inPolicy != nil { + cachedInPolicy = inPolicy + cachedInPolicy.ToNodePubKey = toNodeCallback + cachedInPolicy.ToNodeFeatures = + featuresByVersion[lnwire.GossipVersion( + row.GraphChannel.Version, + )] + } + + directedChannel := &DirectedChannel{ + ChannelID: edge.ChannelID, + IsNode1: nodePub == edge.NodeKey1Bytes, + OtherNode: edge.NodeKey2Bytes, + Capacity: edge.Capacity, + OutPolicySet: outPolicy != nil, + InPolicy: cachedInPolicy, + } + if outPolicy != nil { + outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) { + directedChannel.InboundFee = fee + }) + } + + if nodePub == edge.NodeKey2Bytes { + directedChannel.OtherNode = edge.NodeKey1Bytes + } + + return cb(directedChannel) + } + + return sqldb.ExecutePaginatedQuery( + ctx, cfg, []byte{}, pageQueryFunc, extractCursor, processItem, + ) +} + // FetchNodeFeatures returns the features of the given node. If the node is // unknown, assume no additional features are supported. // @@ -5507,6 +6016,102 @@ func extractChannelPolicies(row any) (*sqlc.GraphChannelPolicy, return policy1, policy2, nil + case sqlc.ListPreferredDirectedChannelsPaginatedRow: + if r.Policy1ID.Valid { + policy1 = &sqlc.GraphChannelPolicy{ + ID: r.Policy1ID.Int64, + Version: r.Policy1Version.Int16, + ChannelID: r.GraphChannel.ID, + NodeID: r.Policy1NodeID.Int64, + Timelock: r.Policy1Timelock.Int32, + FeePpm: r.Policy1FeePpm.Int64, + BaseFeeMsat: r.Policy1BaseFeeMsat.Int64, + MinHtlcMsat: r.Policy1MinHtlcMsat.Int64, + MaxHtlcMsat: r.Policy1MaxHtlcMsat, + LastUpdate: r.Policy1LastUpdate, + InboundBaseFeeMsat: r.Policy1InboundBaseFeeMsat, + InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat, + Disabled: r.Policy1Disabled, + MessageFlags: r.Policy1MessageFlags, + ChannelFlags: r.Policy1ChannelFlags, + Signature: r.Policy1Signature, + BlockHeight: r.Policy1BlockHeight, + DisableFlags: r.Policy1DisableFlags, + } + } + if r.Policy2ID.Valid { + policy2 = &sqlc.GraphChannelPolicy{ + ID: r.Policy2ID.Int64, + Version: r.Policy2Version.Int16, + ChannelID: r.GraphChannel.ID, + NodeID: r.Policy2NodeID.Int64, + Timelock: r.Policy2Timelock.Int32, + FeePpm: r.Policy2FeePpm.Int64, + BaseFeeMsat: r.Policy2BaseFeeMsat.Int64, + MinHtlcMsat: r.Policy2MinHtlcMsat.Int64, + MaxHtlcMsat: r.Policy2MaxHtlcMsat, + LastUpdate: r.Policy2LastUpdate, + InboundBaseFeeMsat: r.Policy2InboundBaseFeeMsat, + InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat, + Disabled: r.Policy2Disabled, + MessageFlags: r.Policy2MessageFlags, + ChannelFlags: r.Policy2ChannelFlags, + Signature: r.Policy2Signature, + BlockHeight: r.Policy2BlockHeight, + DisableFlags: r.Policy2DisableFlags, + } + } + + return policy1, policy2, nil + + case sqlc.ListPreferredChannelsWithPoliciesPaginatedRow: + if r.Policy1ID.Valid { + policy1 = &sqlc.GraphChannelPolicy{ + ID: r.Policy1ID.Int64, + Version: r.Policy1Version.Int16, + ChannelID: r.GraphChannel.ID, + NodeID: r.Policy1NodeID.Int64, + Timelock: r.Policy1Timelock.Int32, + FeePpm: r.Policy1FeePpm.Int64, + BaseFeeMsat: r.Policy1BaseFeeMsat.Int64, + MinHtlcMsat: r.Policy1MinHtlcMsat.Int64, + MaxHtlcMsat: r.Policy1MaxHtlcMsat, + LastUpdate: r.Policy1LastUpdate, + InboundBaseFeeMsat: r.Policy1InboundBaseFeeMsat, + InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat, + Disabled: r.Policy1Disabled, + MessageFlags: r.Policy1MessageFlags, + ChannelFlags: r.Policy1ChannelFlags, + Signature: r.Policy1Signature, + BlockHeight: r.Policy1BlockHeight, + DisableFlags: r.Policy1DisableFlags, + } + } + if r.Policy2ID.Valid { + policy2 = &sqlc.GraphChannelPolicy{ + ID: r.Policy2ID.Int64, + Version: r.Policy2Version.Int16, + ChannelID: r.GraphChannel.ID, + NodeID: r.Policy2NodeID.Int64, + Timelock: r.Policy2Timelock.Int32, + FeePpm: r.Policy2FeePpm.Int64, + BaseFeeMsat: r.Policy2BaseFeeMsat.Int64, + MinHtlcMsat: r.Policy2MinHtlcMsat.Int64, + MaxHtlcMsat: r.Policy2MaxHtlcMsat, + LastUpdate: r.Policy2LastUpdate, + InboundBaseFeeMsat: r.Policy2InboundBaseFeeMsat, + InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat, + Disabled: r.Policy2Disabled, + MessageFlags: r.Policy2MessageFlags, + ChannelFlags: r.Policy2ChannelFlags, + Signature: r.Policy2Signature, + BlockHeight: r.Policy2BlockHeight, + DisableFlags: r.Policy2DisableFlags, + } + } + + return policy1, policy2, nil + case sqlc.ListChannelsWithPoliciesPaginatedRow: if r.Policy1ID.Valid { policy1 = &sqlc.GraphChannelPolicy{ @@ -6076,32 +6681,32 @@ func batchLoadChannelPolicyExtrasHelper(ctx context.Context, ) } -// forEachNodePaginated executes a paginated query to process each node in the -// graph. It uses the provided SQLQueries interface to fetch nodes in batches -// and applies the provided processNode function to each node. -func forEachNodePaginated(ctx context.Context, cfg *sqldb.QueryConfig, - db SQLQueries, protocol lnwire.GossipVersion, - processNode func(context.Context, int64, - *models.Node) error) error { +// forEachPreferredNodePaginated executes a paginated query that yields one +// preferred node per pubkey across all gossip versions. +func forEachPreferredNodePaginated(ctx context.Context, cfg *sqldb.QueryConfig, + db SQLQueries, processNode func(*models.Node) error) error { - pageQueryFunc := func(ctx context.Context, lastID int64, - limit int32) ([]sqlc.GraphNode, error) { + pageQueryFunc := func(ctx context.Context, cursor []byte, + limit int32) ([]sqlc.ListPreferredNodesPaginatedRow, error) { - return db.ListNodesPaginated( - ctx, sqlc.ListNodesPaginatedParams{ - Version: int16(protocol), - ID: lastID, - Limit: limit, + return db.ListPreferredNodesPaginated( + ctx, sqlc.ListPreferredNodesPaginatedParams{ + PubKey: cursor, + Limit: limit, }, ) } - extractPageCursor := func(node sqlc.GraphNode) int64 { - return node.ID + extractPageCursor := func( + row sqlc.ListPreferredNodesPaginatedRow) []byte { + + return row.GraphNode.PubKey } - collectFunc := func(node sqlc.GraphNode) (int64, error) { - return node.ID, nil + collectFunc := func( + row sqlc.ListPreferredNodesPaginatedRow) (int64, error) { + + return row.GraphNode.ID, nil } batchQueryFunc := func(ctx context.Context, @@ -6110,63 +6715,63 @@ func forEachNodePaginated(ctx context.Context, cfg *sqldb.QueryConfig, return batchLoadNodeData(ctx, cfg, db, nodeIDs) } - processItem := func(ctx context.Context, dbNode sqlc.GraphNode, + processItem := func(_ context.Context, + row sqlc.ListPreferredNodesPaginatedRow, batchData *batchNodeData) error { + dbNode := row.GraphNode node, err := buildNodeWithBatchData(dbNode, batchData) if err != nil { - return fmt.Errorf("unable to build "+ - "node(id=%d): %w", dbNode.ID, err) + return fmt.Errorf("unable to build node(id=%d): %w", + dbNode.ID, err) } - return processNode(ctx, dbNode.ID, node) + return processNode(node) } return sqldb.ExecuteCollectAndBatchWithSharedDataQuery( - ctx, cfg, int64(-1), pageQueryFunc, extractPageCursor, + ctx, cfg, []byte{}, pageQueryFunc, extractPageCursor, collectFunc, batchQueryFunc, processItem, ) } -// forEachChannelWithPolicies executes a paginated query to process each channel -// with policies in the graph. -func forEachChannelWithPolicies(ctx context.Context, db SQLQueries, - cfg *SQLStoreConfig, v lnwire.GossipVersion, - processChannel func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy, - *models.ChannelEdgePolicy) error) error { +// forEachPreferredChannelWithPolicies executes a paginated query that yields +// one preferred channel per SCID across all gossip versions. +func forEachPreferredChannelWithPolicies(ctx context.Context, db SQLQueries, + cfg *SQLStoreConfig, processChannel func(*models.ChannelEdgeInfo, + *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error { type channelBatchIDs struct { channelID int64 policyIDs []int64 } - pageQueryFunc := func(ctx context.Context, lastID int64, - limit int32) ([]sqlc.ListChannelsWithPoliciesPaginatedRow, + pageQueryFunc := func(ctx context.Context, cursor []byte, + limit int32) ([]sqlc.ListPreferredChannelsWithPoliciesPaginatedRow, error) { - return db.ListChannelsWithPoliciesPaginated( - ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{ - Version: int16(v), - ID: lastID, - Limit: limit, + return db.ListPreferredChannelsWithPoliciesPaginated( + ctx, sqlc.ListPreferredChannelsWithPoliciesPaginatedParams{ + Scid: cursor, + Limit: limit, }, ) } extractPageCursor := func( - row sqlc.ListChannelsWithPoliciesPaginatedRow) int64 { + row sqlc.ListPreferredChannelsWithPoliciesPaginatedRow) []byte { - return row.GraphChannel.ID + return row.GraphChannel.Scid } - collectFunc := func(row sqlc.ListChannelsWithPoliciesPaginatedRow) ( + collectFunc := func( + row sqlc.ListPreferredChannelsWithPoliciesPaginatedRow) ( channelBatchIDs, error) { ids := channelBatchIDs{ channelID: row.GraphChannel.ID, } - // Extract policy IDs from the row. dbPol1, dbPol2, err := extractChannelPolicies(row) if err != nil { return ids, err @@ -6183,15 +6788,11 @@ func forEachChannelWithPolicies(ctx context.Context, db SQLQueries, } batchDataFunc := func(ctx context.Context, - allIDs []channelBatchIDs) (*batchChannelData, error) { - - // Separate channel IDs from policy IDs. - var ( - channelIDs = make([]int64, len(allIDs)) - policyIDs = make([]int64, 0, len(allIDs)*2) - ) + pageIDs []channelBatchIDs) (*batchChannelData, error) { - for i, ids := range allIDs { + channelIDs := make([]int64, len(pageIDs)) + policyIDs := make([]int64, 0, len(pageIDs)*2) + for i, ids := range pageIDs { channelIDs[i] = ids.channelID policyIDs = append(policyIDs, ids.policyIDs...) } @@ -6202,7 +6803,7 @@ func forEachChannelWithPolicies(ctx context.Context, db SQLQueries, } processItem := func(ctx context.Context, - row sqlc.ListChannelsWithPoliciesPaginatedRow, + row sqlc.ListPreferredChannelsWithPoliciesPaginatedRow, batchData *batchChannelData) error { node1, node2, err := buildNodeVertices( @@ -6237,7 +6838,7 @@ func forEachChannelWithPolicies(ctx context.Context, db SQLQueries, } return sqldb.ExecuteCollectAndBatchWithSharedDataQuery( - ctx, cfg.QueryCfg, int64(-1), pageQueryFunc, extractPageCursor, + ctx, cfg.QueryCfg, []byte{}, pageQueryFunc, extractPageCursor, collectFunc, batchDataFunc, processItem, ) } diff --git a/sqldb/sqlc/graph.sql.go b/sqldb/sqlc/graph.sql.go index dc0a0641ef..a1e0436008 100644 --- a/sqldb/sqlc/graph.sql.go +++ b/sqldb/sqlc/graph.sql.go @@ -1173,7 +1173,14 @@ func (q *Queries) GetChannelsByOutpoints(ctx context.Context, outpoints []string return items, nil } -const getChannelsByPolicyLastUpdateRange = `-- name: GetChannelsByPolicyLastUpdateRange :many +const getChannelsByPolicyBlockRange = `-- name: GetChannelsByPolicyBlockRange :many +WITH candidate_channels AS ( + SELECT DISTINCT channel_id + FROM graph_channel_policies + WHERE version = $1 + AND block_height >= $2 + AND block_height < $3 +) SELECT c.id, c.version, c.scid, c.node_id_1, c.node_id_2, c.outpoint, c.capacity, c.bitcoin_key_1, c.bitcoin_key_2, c.node_1_signature, c.node_2_signature, c.bitcoin_1_signature, c.bitcoin_2_signature, c.signature, c.funding_pk_script, c.merkle_root_hash, n1.id, n1.version, n1.pub_key, n1.alias, n1.last_update, n1.color, n1.signature, n1.block_height, @@ -1217,7 +1224,8 @@ SELECT cp2.block_height AS policy2_block_height, cp2.disable_flags AS policy2_disable_flags -FROM graph_channels c +FROM candidate_channels cc + JOIN graph_channels c ON c.id = cc.channel_id JOIN graph_nodes n1 ON c.node_id_1 = n1.id JOIN graph_nodes n2 ON c.node_id_2 = n2.id LEFT JOIN graph_channel_policies cp1 @@ -1226,9 +1234,246 @@ FROM graph_channels c ON cp2.channel_id = c.id AND cp2.node_id = c.node_id_2 AND cp2.version = c.version WHERE c.version = $1 AND ( - (cp1.last_update >= $2 AND cp1.last_update < $3) + (cp1.block_height >= $2 AND cp1.block_height < $3) + OR + (cp2.block_height >= $2 AND cp2.block_height < $3) + ) + -- Pagination using compound cursor (max_block_height, id). + -- We use COALESCE with -1 as sentinel since heights are always positive. + AND ( + (CASE + WHEN COALESCE(cp1.block_height, 0) >= COALESCE(cp2.block_height, 0) + THEN COALESCE(cp1.block_height, 0) + ELSE COALESCE(cp2.block_height, 0) + END > COALESCE($4, -1)) + OR + (CASE + WHEN COALESCE(cp1.block_height, 0) >= COALESCE(cp2.block_height, 0) + THEN COALESCE(cp1.block_height, 0) + ELSE COALESCE(cp2.block_height, 0) + END = COALESCE($4, -1) + AND c.id > COALESCE($5, -1)) + ) +ORDER BY + CASE + WHEN COALESCE(cp1.block_height, 0) >= COALESCE(cp2.block_height, 0) + THEN COALESCE(cp1.block_height, 0) + ELSE COALESCE(cp2.block_height, 0) + END ASC, + c.id ASC +LIMIT COALESCE($6, 999999999) +` + +type GetChannelsByPolicyBlockRangeParams struct { + Version int16 + StartHeight sql.NullInt64 + EndHeight sql.NullInt64 + LastBlockHeight sql.NullInt64 + LastID sql.NullInt64 + MaxResults interface{} +} + +type GetChannelsByPolicyBlockRangeRow struct { + GraphChannel GraphChannel + GraphNode GraphNode + GraphNode_2 GraphNode + Policy1ID sql.NullInt64 + Policy1NodeID sql.NullInt64 + Policy1Version sql.NullInt16 + Policy1Timelock sql.NullInt32 + Policy1FeePpm sql.NullInt64 + Policy1BaseFeeMsat sql.NullInt64 + Policy1MinHtlcMsat sql.NullInt64 + Policy1MaxHtlcMsat sql.NullInt64 + Policy1LastUpdate sql.NullInt64 + Policy1Disabled sql.NullBool + Policy1InboundBaseFeeMsat sql.NullInt64 + Policy1InboundFeeRateMilliMsat sql.NullInt64 + Policy1MessageFlags sql.NullInt16 + Policy1ChannelFlags sql.NullInt16 + Policy1Signature []byte + Policy1BlockHeight sql.NullInt64 + Policy1DisableFlags sql.NullInt16 + Policy2ID sql.NullInt64 + Policy2NodeID sql.NullInt64 + Policy2Version sql.NullInt16 + Policy2Timelock sql.NullInt32 + Policy2FeePpm sql.NullInt64 + Policy2BaseFeeMsat sql.NullInt64 + Policy2MinHtlcMsat sql.NullInt64 + Policy2MaxHtlcMsat sql.NullInt64 + Policy2LastUpdate sql.NullInt64 + Policy2Disabled sql.NullBool + Policy2InboundBaseFeeMsat sql.NullInt64 + Policy2InboundFeeRateMilliMsat sql.NullInt64 + Policy2MessageFlags sql.NullInt16 + Policy2ChannelFlags sql.NullInt16 + Policy2Signature []byte + Policy2BlockHeight sql.NullInt64 + Policy2DisableFlags sql.NullInt16 +} + +func (q *Queries) GetChannelsByPolicyBlockRange(ctx context.Context, arg GetChannelsByPolicyBlockRangeParams) ([]GetChannelsByPolicyBlockRangeRow, error) { + rows, err := q.db.QueryContext(ctx, getChannelsByPolicyBlockRange, + arg.Version, + arg.StartHeight, + arg.EndHeight, + arg.LastBlockHeight, + arg.LastID, + arg.MaxResults, + ) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetChannelsByPolicyBlockRangeRow + for rows.Next() { + var i GetChannelsByPolicyBlockRangeRow + if err := rows.Scan( + &i.GraphChannel.ID, + &i.GraphChannel.Version, + &i.GraphChannel.Scid, + &i.GraphChannel.NodeID1, + &i.GraphChannel.NodeID2, + &i.GraphChannel.Outpoint, + &i.GraphChannel.Capacity, + &i.GraphChannel.BitcoinKey1, + &i.GraphChannel.BitcoinKey2, + &i.GraphChannel.Node1Signature, + &i.GraphChannel.Node2Signature, + &i.GraphChannel.Bitcoin1Signature, + &i.GraphChannel.Bitcoin2Signature, + &i.GraphChannel.Signature, + &i.GraphChannel.FundingPkScript, + &i.GraphChannel.MerkleRootHash, + &i.GraphNode.ID, + &i.GraphNode.Version, + &i.GraphNode.PubKey, + &i.GraphNode.Alias, + &i.GraphNode.LastUpdate, + &i.GraphNode.Color, + &i.GraphNode.Signature, + &i.GraphNode.BlockHeight, + &i.GraphNode_2.ID, + &i.GraphNode_2.Version, + &i.GraphNode_2.PubKey, + &i.GraphNode_2.Alias, + &i.GraphNode_2.LastUpdate, + &i.GraphNode_2.Color, + &i.GraphNode_2.Signature, + &i.GraphNode_2.BlockHeight, + &i.Policy1ID, + &i.Policy1NodeID, + &i.Policy1Version, + &i.Policy1Timelock, + &i.Policy1FeePpm, + &i.Policy1BaseFeeMsat, + &i.Policy1MinHtlcMsat, + &i.Policy1MaxHtlcMsat, + &i.Policy1LastUpdate, + &i.Policy1Disabled, + &i.Policy1InboundBaseFeeMsat, + &i.Policy1InboundFeeRateMilliMsat, + &i.Policy1MessageFlags, + &i.Policy1ChannelFlags, + &i.Policy1Signature, + &i.Policy1BlockHeight, + &i.Policy1DisableFlags, + &i.Policy2ID, + &i.Policy2NodeID, + &i.Policy2Version, + &i.Policy2Timelock, + &i.Policy2FeePpm, + &i.Policy2BaseFeeMsat, + &i.Policy2MinHtlcMsat, + &i.Policy2MaxHtlcMsat, + &i.Policy2LastUpdate, + &i.Policy2Disabled, + &i.Policy2InboundBaseFeeMsat, + &i.Policy2InboundFeeRateMilliMsat, + &i.Policy2MessageFlags, + &i.Policy2ChannelFlags, + &i.Policy2Signature, + &i.Policy2BlockHeight, + &i.Policy2DisableFlags, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getChannelsByPolicyLastUpdateRange = `-- name: GetChannelsByPolicyLastUpdateRange :many +WITH candidate_channels AS ( + SELECT DISTINCT channel_id + FROM graph_channel_policies + WHERE version = 1 + AND last_update >= $1 + AND last_update < $2 +) +SELECT + c.id, c.version, c.scid, c.node_id_1, c.node_id_2, c.outpoint, c.capacity, c.bitcoin_key_1, c.bitcoin_key_2, c.node_1_signature, c.node_2_signature, c.bitcoin_1_signature, c.bitcoin_2_signature, c.signature, c.funding_pk_script, c.merkle_root_hash, + n1.id, n1.version, n1.pub_key, n1.alias, n1.last_update, n1.color, n1.signature, n1.block_height, + n2.id, n2.version, n2.pub_key, n2.alias, n2.last_update, n2.color, n2.signature, n2.block_height, + + -- Policy 1 (node_id_1) + cp1.id AS policy1_id, + cp1.node_id AS policy1_node_id, + cp1.version AS policy1_version, + cp1.timelock AS policy1_timelock, + cp1.fee_ppm AS policy1_fee_ppm, + cp1.base_fee_msat AS policy1_base_fee_msat, + cp1.min_htlc_msat AS policy1_min_htlc_msat, + cp1.max_htlc_msat AS policy1_max_htlc_msat, + cp1.last_update AS policy1_last_update, + cp1.disabled AS policy1_disabled, + cp1.inbound_base_fee_msat AS policy1_inbound_base_fee_msat, + cp1.inbound_fee_rate_milli_msat AS policy1_inbound_fee_rate_milli_msat, + cp1.message_flags AS policy1_message_flags, + cp1.channel_flags AS policy1_channel_flags, + cp1.signature AS policy1_signature, + cp1.block_height AS policy1_block_height, + cp1.disable_flags AS policy1_disable_flags, + + -- Policy 2 (node_id_2) + cp2.id AS policy2_id, + cp2.node_id AS policy2_node_id, + cp2.version AS policy2_version, + cp2.timelock AS policy2_timelock, + cp2.fee_ppm AS policy2_fee_ppm, + cp2.base_fee_msat AS policy2_base_fee_msat, + cp2.min_htlc_msat AS policy2_min_htlc_msat, + cp2.max_htlc_msat AS policy2_max_htlc_msat, + cp2.last_update AS policy2_last_update, + cp2.disabled AS policy2_disabled, + cp2.inbound_base_fee_msat AS policy2_inbound_base_fee_msat, + cp2.inbound_fee_rate_milli_msat AS policy2_inbound_fee_rate_milli_msat, + cp2.message_flags AS policy2_message_flags, + cp2.channel_flags AS policy2_channel_flags, + cp2.signature AS policy2_signature, + cp2.block_height AS policy2_block_height, + cp2.disable_flags AS policy2_disable_flags + +FROM candidate_channels cc + JOIN graph_channels c ON c.id = cc.channel_id + JOIN graph_nodes n1 ON c.node_id_1 = n1.id + JOIN graph_nodes n2 ON c.node_id_2 = n2.id + LEFT JOIN graph_channel_policies cp1 + ON cp1.channel_id = c.id AND cp1.node_id = c.node_id_1 AND cp1.version = c.version + LEFT JOIN graph_channel_policies cp2 + ON cp2.channel_id = c.id AND cp2.node_id = c.node_id_2 AND cp2.version = c.version +WHERE c.version = 1 + AND ( + (cp1.last_update >= $1 AND cp1.last_update < $2) OR - (cp2.last_update >= $2 AND cp2.last_update < $3) + (cp2.last_update >= $1 AND cp2.last_update < $2) ) -- Pagination using compound cursor (max_update_time, id). -- We use COALESCE with -1 as sentinel since timestamps are always positive. @@ -1237,14 +1482,14 @@ WHERE c.version = $1 WHEN COALESCE(cp1.last_update, 0) >= COALESCE(cp2.last_update, 0) THEN COALESCE(cp1.last_update, 0) ELSE COALESCE(cp2.last_update, 0) - END > COALESCE($4, -1)) + END > COALESCE($3, -1)) OR (CASE WHEN COALESCE(cp1.last_update, 0) >= COALESCE(cp2.last_update, 0) THEN COALESCE(cp1.last_update, 0) ELSE COALESCE(cp2.last_update, 0) - END = COALESCE($4, -1) - AND c.id > COALESCE($5, -1)) + END = COALESCE($3, -1) + AND c.id > COALESCE($4, -1)) ) ORDER BY CASE @@ -1253,11 +1498,10 @@ ORDER BY ELSE COALESCE(cp2.last_update, 0) END ASC, c.id ASC -LIMIT COALESCE($6, 999999999) +LIMIT COALESCE($5, 999999999) ` type GetChannelsByPolicyLastUpdateRangeParams struct { - Version int16 StartTime sql.NullInt64 EndTime sql.NullInt64 LastUpdateTime sql.NullInt64 @@ -1307,7 +1551,6 @@ type GetChannelsByPolicyLastUpdateRangeRow struct { func (q *Queries) GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg GetChannelsByPolicyLastUpdateRangeParams) ([]GetChannelsByPolicyLastUpdateRangeRow, error) { rows, err := q.db.QueryContext(ctx, getChannelsByPolicyLastUpdateRange, - arg.Version, arg.StartTime, arg.EndTime, arg.LastUpdateTime, @@ -2070,24 +2313,62 @@ func (q *Queries) GetNodeIDByPubKey(ctx context.Context, arg GetNodeIDByPubKeyPa return id, err } -const getNodesByIDs = `-- name: GetNodesByIDs :many +const getNodesByBlockHeightRange = `-- name: GetNodesByBlockHeightRange :many SELECT id, version, pub_key, alias, last_update, color, signature, block_height FROM graph_nodes -WHERE id IN (/*SLICE:ids*/?) +WHERE graph_nodes.version = $1 + AND block_height >= $2 + AND block_height <= $3 + -- Pagination: We use (block_height, pub_key) as a compound cursor. + -- This ensures stable ordering and allows us to resume from where we left off. + -- We use COALESCE with -1 as sentinel since heights are always positive. + AND ( + -- Include rows with block_height greater than cursor (or all rows if cursor is -1). + block_height > COALESCE($4, -1) + OR + -- For rows with same block_height, use pub_key as tiebreaker. + (block_height = COALESCE($4, -1) + AND pub_key > $5) + ) + -- Optional filter for public nodes only. + AND ( + -- If only_public is false or not provided, include all nodes. + COALESCE($6, FALSE) IS FALSE + OR + -- For V2 protocol, a node is public if it has at least one public channel. + -- A public channel has signature set (channel announcement received). + EXISTS ( + SELECT 1 + FROM graph_channels c + WHERE c.version = 2 + AND COALESCE(length(c.signature), 0) > 0 + AND (c.node_id_1 = graph_nodes.id OR c.node_id_2 = graph_nodes.id) + ) + ) +ORDER BY block_height ASC, pub_key ASC +LIMIT COALESCE($7, 999999999) ` -func (q *Queries) GetNodesByIDs(ctx context.Context, ids []int64) ([]GraphNode, error) { - query := getNodesByIDs - var queryParams []interface{} - if len(ids) > 0 { - for _, v := range ids { - queryParams = append(queryParams, v) - } - query = strings.Replace(query, "/*SLICE:ids*/?", makeQueryParams(len(queryParams), len(ids)), 1) - } else { - query = strings.Replace(query, "/*SLICE:ids*/?", "NULL", 1) - } - rows, err := q.db.QueryContext(ctx, query, queryParams...) +type GetNodesByBlockHeightRangeParams struct { + Version int16 + StartHeight sql.NullInt64 + EndHeight sql.NullInt64 + LastBlockHeight sql.NullInt64 + LastPubKey []byte + OnlyPublic interface{} + MaxResults interface{} +} + +func (q *Queries) GetNodesByBlockHeightRange(ctx context.Context, arg GetNodesByBlockHeightRangeParams) ([]GraphNode, error) { + rows, err := q.db.QueryContext(ctx, getNodesByBlockHeightRange, + arg.Version, + arg.StartHeight, + arg.EndHeight, + arg.LastBlockHeight, + arg.LastPubKey, + arg.OnlyPublic, + arg.MaxResults, + ) if err != nil { return nil, err } @@ -2118,27 +2399,76 @@ func (q *Queries) GetNodesByIDs(ctx context.Context, ids []int64) ([]GraphNode, return items, nil } -const getNodesByLastUpdateRange = `-- name: GetNodesByLastUpdateRange :many +const getNodesByIDs = `-- name: GetNodesByIDs :many SELECT id, version, pub_key, alias, last_update, color, signature, block_height FROM graph_nodes -WHERE last_update >= $1 - AND last_update <= $2 - -- Pagination: We use (last_update, pub_key) as a compound cursor. - -- This ensures stable ordering and allows us to resume from where we left off. - -- We use COALESCE with -1 as sentinel since timestamps are always positive. - AND ( - -- Include rows with last_update greater than cursor (or all rows if cursor is -1) - last_update > COALESCE($3, -1) - OR - -- For rows with same last_update, use pub_key as tiebreaker - (last_update = COALESCE($3, -1) - AND pub_key > $4) - ) - -- Optional filter for public nodes only - AND ( - -- If only_public is false or not provided, include all nodes - COALESCE($5, FALSE) IS FALSE - OR +WHERE id IN (/*SLICE:ids*/?) +` + +func (q *Queries) GetNodesByIDs(ctx context.Context, ids []int64) ([]GraphNode, error) { + query := getNodesByIDs + var queryParams []interface{} + if len(ids) > 0 { + for _, v := range ids { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:ids*/?", makeQueryParams(len(queryParams), len(ids)), 1) + } else { + query = strings.Replace(query, "/*SLICE:ids*/?", "NULL", 1) + } + rows, err := q.db.QueryContext(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GraphNode + for rows.Next() { + var i GraphNode + if err := rows.Scan( + &i.ID, + &i.Version, + &i.PubKey, + &i.Alias, + &i.LastUpdate, + &i.Color, + &i.Signature, + &i.BlockHeight, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getNodesByLastUpdateRange = `-- name: GetNodesByLastUpdateRange :many +SELECT id, version, pub_key, alias, last_update, color, signature, block_height +FROM graph_nodes +WHERE graph_nodes.version = 1 + AND last_update >= $1 + AND last_update <= $2 + -- Pagination: We use (last_update, pub_key) as a compound cursor. + -- This ensures stable ordering and allows us to resume from where we left off. + -- We use COALESCE with -1 as sentinel since timestamps are always positive. + AND ( + -- Include rows with last_update greater than cursor (or all rows if cursor is -1) + last_update > COALESCE($3, -1) + OR + -- For rows with same last_update, use pub_key as tiebreaker + (last_update = COALESCE($3, -1) + AND pub_key > $4) + ) + -- Optional filter for public nodes only + AND ( + -- If only_public is false or not provided, include all nodes + COALESCE($5, FALSE) IS FALSE + OR -- For V1 protocol, a node is public if it has at least one public channel. -- A public channel has bitcoin_1_signature set (channel announcement received). EXISTS ( @@ -3837,6 +4167,561 @@ func (q *Queries) ListNodesPaginated(ctx context.Context, arg ListNodesPaginated return items, nil } +const listPreferredChannelsWithPoliciesPaginated = `-- name: ListPreferredChannelsWithPoliciesPaginated :many +WITH page_scids(cursor_scid) AS ( + SELECT page.scid + FROM ( + SELECT c2.scid AS scid + FROM graph_channels c2 + WHERE c2.version = 2 + AND c2.scid > $1 + UNION + SELECT c1.scid AS scid + FROM graph_channels c1 + WHERE c1.version = 1 + AND c1.scid > $1 + ) AS page + ORDER BY page.scid + LIMIT $2 +), +selected_channels AS ( + SELECT + s.cursor_scid AS selected_scid, + COALESCE( + ( + SELECT c.id + FROM graph_channels c + WHERE c.scid = s.cursor_scid + AND c.version = 2 + AND EXISTS ( + SELECT 1 + FROM graph_channel_policies p + WHERE p.channel_id = c.id + AND p.version = 2 + ) + ), + ( + SELECT c.id + FROM graph_channels c + WHERE c.scid = s.cursor_scid + AND c.version = 1 + AND EXISTS ( + SELECT 1 + FROM graph_channel_policies p + WHERE p.channel_id = c.id + AND p.version = 1 + ) + ), + ( + SELECT c.id + FROM graph_channels c + WHERE c.scid = s.cursor_scid + AND c.version = 2 + ), + ( + SELECT c.id + FROM graph_channels c + WHERE c.scid = s.cursor_scid + AND c.version = 1 + ) + ) AS channel_db_id + FROM page_scids s +) +SELECT + c.id, c.version, c.scid, c.node_id_1, c.node_id_2, c.outpoint, c.capacity, c.bitcoin_key_1, c.bitcoin_key_2, c.node_1_signature, c.node_2_signature, c.bitcoin_1_signature, c.bitcoin_2_signature, c.signature, c.funding_pk_script, c.merkle_root_hash, + + -- Join node pubkeys + n1.pub_key AS node1_pubkey, + n2.pub_key AS node2_pubkey, + + -- Node 1 policy + cp1.id AS policy_1_id, + cp1.node_id AS policy_1_node_id, + cp1.version AS policy_1_version, + cp1.timelock AS policy_1_timelock, + cp1.fee_ppm AS policy_1_fee_ppm, + cp1.base_fee_msat AS policy_1_base_fee_msat, + cp1.min_htlc_msat AS policy_1_min_htlc_msat, + cp1.max_htlc_msat AS policy_1_max_htlc_msat, + cp1.last_update AS policy_1_last_update, + cp1.disabled AS policy_1_disabled, + cp1.inbound_base_fee_msat AS policy1_inbound_base_fee_msat, + cp1.inbound_fee_rate_milli_msat AS policy1_inbound_fee_rate_milli_msat, + cp1.message_flags AS policy1_message_flags, + cp1.channel_flags AS policy1_channel_flags, + cp1.block_height AS policy1_block_height, + cp1.disable_flags AS policy1_disable_flags, + cp1.signature AS policy_1_signature, + + -- Node 2 policy + cp2.id AS policy_2_id, + cp2.node_id AS policy_2_node_id, + cp2.version AS policy_2_version, + cp2.timelock AS policy_2_timelock, + cp2.fee_ppm AS policy_2_fee_ppm, + cp2.base_fee_msat AS policy_2_base_fee_msat, + cp2.min_htlc_msat AS policy_2_min_htlc_msat, + cp2.max_htlc_msat AS policy_2_max_htlc_msat, + cp2.last_update AS policy_2_last_update, + cp2.disabled AS policy_2_disabled, + cp2.inbound_base_fee_msat AS policy2_inbound_base_fee_msat, + cp2.inbound_fee_rate_milli_msat AS policy2_inbound_fee_rate_milli_msat, + cp2.message_flags AS policy2_message_flags, + cp2.channel_flags AS policy2_channel_flags, + cp2.signature AS policy_2_signature, + cp2.block_height AS policy_2_block_height, + cp2.disable_flags AS policy_2_disable_flags + +FROM selected_channels s +JOIN graph_channels c ON c.id = s.channel_db_id +JOIN graph_nodes n1 ON c.node_id_1 = n1.id +JOIN graph_nodes n2 ON c.node_id_2 = n2.id +LEFT JOIN graph_channel_policies cp1 + ON cp1.channel_id = c.id AND cp1.node_id = c.node_id_1 AND cp1.version = c.version +LEFT JOIN graph_channel_policies cp2 + ON cp2.channel_id = c.id AND cp2.node_id = c.node_id_2 AND cp2.version = c.version +ORDER BY s.selected_scid +` + +type ListPreferredChannelsWithPoliciesPaginatedParams struct { + Scid []byte + Limit int32 +} + +type ListPreferredChannelsWithPoliciesPaginatedRow struct { + GraphChannel GraphChannel + Node1Pubkey []byte + Node2Pubkey []byte + Policy1ID sql.NullInt64 + Policy1NodeID sql.NullInt64 + Policy1Version sql.NullInt16 + Policy1Timelock sql.NullInt32 + Policy1FeePpm sql.NullInt64 + Policy1BaseFeeMsat sql.NullInt64 + Policy1MinHtlcMsat sql.NullInt64 + Policy1MaxHtlcMsat sql.NullInt64 + Policy1LastUpdate sql.NullInt64 + Policy1Disabled sql.NullBool + Policy1InboundBaseFeeMsat sql.NullInt64 + Policy1InboundFeeRateMilliMsat sql.NullInt64 + Policy1MessageFlags sql.NullInt16 + Policy1ChannelFlags sql.NullInt16 + Policy1BlockHeight sql.NullInt64 + Policy1DisableFlags sql.NullInt16 + Policy1Signature []byte + Policy2ID sql.NullInt64 + Policy2NodeID sql.NullInt64 + Policy2Version sql.NullInt16 + Policy2Timelock sql.NullInt32 + Policy2FeePpm sql.NullInt64 + Policy2BaseFeeMsat sql.NullInt64 + Policy2MinHtlcMsat sql.NullInt64 + Policy2MaxHtlcMsat sql.NullInt64 + Policy2LastUpdate sql.NullInt64 + Policy2Disabled sql.NullBool + Policy2InboundBaseFeeMsat sql.NullInt64 + Policy2InboundFeeRateMilliMsat sql.NullInt64 + Policy2MessageFlags sql.NullInt16 + Policy2ChannelFlags sql.NullInt16 + Policy2Signature []byte + Policy2BlockHeight sql.NullInt64 + Policy2DisableFlags sql.NullInt16 +} + +func (q *Queries) ListPreferredChannelsWithPoliciesPaginated(ctx context.Context, arg ListPreferredChannelsWithPoliciesPaginatedParams) ([]ListPreferredChannelsWithPoliciesPaginatedRow, error) { + rows, err := q.db.QueryContext(ctx, listPreferredChannelsWithPoliciesPaginated, arg.Scid, arg.Limit) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ListPreferredChannelsWithPoliciesPaginatedRow + for rows.Next() { + var i ListPreferredChannelsWithPoliciesPaginatedRow + if err := rows.Scan( + &i.GraphChannel.ID, + &i.GraphChannel.Version, + &i.GraphChannel.Scid, + &i.GraphChannel.NodeID1, + &i.GraphChannel.NodeID2, + &i.GraphChannel.Outpoint, + &i.GraphChannel.Capacity, + &i.GraphChannel.BitcoinKey1, + &i.GraphChannel.BitcoinKey2, + &i.GraphChannel.Node1Signature, + &i.GraphChannel.Node2Signature, + &i.GraphChannel.Bitcoin1Signature, + &i.GraphChannel.Bitcoin2Signature, + &i.GraphChannel.Signature, + &i.GraphChannel.FundingPkScript, + &i.GraphChannel.MerkleRootHash, + &i.Node1Pubkey, + &i.Node2Pubkey, + &i.Policy1ID, + &i.Policy1NodeID, + &i.Policy1Version, + &i.Policy1Timelock, + &i.Policy1FeePpm, + &i.Policy1BaseFeeMsat, + &i.Policy1MinHtlcMsat, + &i.Policy1MaxHtlcMsat, + &i.Policy1LastUpdate, + &i.Policy1Disabled, + &i.Policy1InboundBaseFeeMsat, + &i.Policy1InboundFeeRateMilliMsat, + &i.Policy1MessageFlags, + &i.Policy1ChannelFlags, + &i.Policy1BlockHeight, + &i.Policy1DisableFlags, + &i.Policy1Signature, + &i.Policy2ID, + &i.Policy2NodeID, + &i.Policy2Version, + &i.Policy2Timelock, + &i.Policy2FeePpm, + &i.Policy2BaseFeeMsat, + &i.Policy2MinHtlcMsat, + &i.Policy2MaxHtlcMsat, + &i.Policy2LastUpdate, + &i.Policy2Disabled, + &i.Policy2InboundBaseFeeMsat, + &i.Policy2InboundFeeRateMilliMsat, + &i.Policy2MessageFlags, + &i.Policy2ChannelFlags, + &i.Policy2Signature, + &i.Policy2BlockHeight, + &i.Policy2DisableFlags, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listPreferredDirectedChannelsPaginated = `-- name: ListPreferredDirectedChannelsPaginated :many +WITH page_scids(cursor_scid) AS ( + SELECT page.scid + FROM ( + SELECT c2.scid AS scid + FROM graph_channels c2 + WHERE c2.version = 2 + AND (c2.node_id_1 = $1 OR c2.node_id_2 = $1) + AND c2.scid > $2 + UNION + SELECT c1.scid AS scid + FROM graph_channels c1 + WHERE c1.version = 1 + AND (c1.node_id_1 = $3 OR c1.node_id_2 = $3) + AND c1.scid > $2 + ) AS page + ORDER BY page.scid + LIMIT $4 +), +selected_channels AS ( + SELECT + s.cursor_scid AS selected_scid, + COALESCE( + ( + SELECT c.id + FROM graph_channels c + WHERE c.scid = s.cursor_scid + AND c.version = 2 + AND (c.node_id_1 = $1 OR c.node_id_2 = $1) + ), + ( + SELECT c.id + FROM graph_channels c + WHERE c.scid = s.cursor_scid + AND c.version = 1 + AND (c.node_id_1 = $3 OR c.node_id_2 = $3) + ) + ) AS channel_db_id + FROM page_scids s +) +SELECT c.id, c.version, c.scid, c.node_id_1, c.node_id_2, c.outpoint, c.capacity, c.bitcoin_key_1, c.bitcoin_key_2, c.node_1_signature, c.node_2_signature, c.bitcoin_1_signature, c.bitcoin_2_signature, c.signature, c.funding_pk_script, c.merkle_root_hash, + n1.pub_key AS node1_pubkey, + n2.pub_key AS node2_pubkey, + + -- Policy 1 + cp1.id AS policy1_id, + cp1.node_id AS policy1_node_id, + cp1.version AS policy1_version, + cp1.timelock AS policy1_timelock, + cp1.fee_ppm AS policy1_fee_ppm, + cp1.base_fee_msat AS policy1_base_fee_msat, + cp1.min_htlc_msat AS policy1_min_htlc_msat, + cp1.max_htlc_msat AS policy1_max_htlc_msat, + cp1.last_update AS policy1_last_update, + cp1.disabled AS policy1_disabled, + cp1.inbound_base_fee_msat AS policy1_inbound_base_fee_msat, + cp1.inbound_fee_rate_milli_msat AS policy1_inbound_fee_rate_milli_msat, + cp1.message_flags AS policy1_message_flags, + cp1.channel_flags AS policy1_channel_flags, + cp1.signature AS policy1_signature, + cp1.block_height AS policy1_block_height, + cp1.disable_flags AS policy1_disable_flags, + + -- Policy 2 + cp2.id AS policy2_id, + cp2.node_id AS policy2_node_id, + cp2.version AS policy2_version, + cp2.timelock AS policy2_timelock, + cp2.fee_ppm AS policy2_fee_ppm, + cp2.base_fee_msat AS policy2_base_fee_msat, + cp2.min_htlc_msat AS policy2_min_htlc_msat, + cp2.max_htlc_msat AS policy2_max_htlc_msat, + cp2.last_update AS policy2_last_update, + cp2.disabled AS policy2_disabled, + cp2.inbound_base_fee_msat AS policy2_inbound_base_fee_msat, + cp2.inbound_fee_rate_milli_msat AS policy2_inbound_fee_rate_milli_msat, + cp2.message_flags AS policy2_message_flags, + cp2.channel_flags AS policy2_channel_flags, + cp2.signature AS policy2_signature, + cp2.block_height AS policy2_block_height, + cp2.disable_flags AS policy2_disable_flags + +FROM selected_channels s +JOIN graph_channels c ON c.id = s.channel_db_id +JOIN graph_nodes n1 ON c.node_id_1 = n1.id +JOIN graph_nodes n2 ON c.node_id_2 = n2.id +LEFT JOIN graph_channel_policies cp1 + ON cp1.channel_id = c.id AND cp1.node_id = c.node_id_1 AND cp1.version = c.version +LEFT JOIN graph_channel_policies cp2 + ON cp2.channel_id = c.id AND cp2.node_id = c.node_id_2 AND cp2.version = c.version +ORDER BY s.selected_scid +` + +type ListPreferredDirectedChannelsPaginatedParams struct { + NodeIDV2 int64 + Scid []byte + NodeIDV1 int64 + PageLimit int32 +} + +type ListPreferredDirectedChannelsPaginatedRow struct { + GraphChannel GraphChannel + Node1Pubkey []byte + Node2Pubkey []byte + Policy1ID sql.NullInt64 + Policy1NodeID sql.NullInt64 + Policy1Version sql.NullInt16 + Policy1Timelock sql.NullInt32 + Policy1FeePpm sql.NullInt64 + Policy1BaseFeeMsat sql.NullInt64 + Policy1MinHtlcMsat sql.NullInt64 + Policy1MaxHtlcMsat sql.NullInt64 + Policy1LastUpdate sql.NullInt64 + Policy1Disabled sql.NullBool + Policy1InboundBaseFeeMsat sql.NullInt64 + Policy1InboundFeeRateMilliMsat sql.NullInt64 + Policy1MessageFlags sql.NullInt16 + Policy1ChannelFlags sql.NullInt16 + Policy1Signature []byte + Policy1BlockHeight sql.NullInt64 + Policy1DisableFlags sql.NullInt16 + Policy2ID sql.NullInt64 + Policy2NodeID sql.NullInt64 + Policy2Version sql.NullInt16 + Policy2Timelock sql.NullInt32 + Policy2FeePpm sql.NullInt64 + Policy2BaseFeeMsat sql.NullInt64 + Policy2MinHtlcMsat sql.NullInt64 + Policy2MaxHtlcMsat sql.NullInt64 + Policy2LastUpdate sql.NullInt64 + Policy2Disabled sql.NullBool + Policy2InboundBaseFeeMsat sql.NullInt64 + Policy2InboundFeeRateMilliMsat sql.NullInt64 + Policy2MessageFlags sql.NullInt16 + Policy2ChannelFlags sql.NullInt16 + Policy2Signature []byte + Policy2BlockHeight sql.NullInt64 + Policy2DisableFlags sql.NullInt16 +} + +func (q *Queries) ListPreferredDirectedChannelsPaginated(ctx context.Context, arg ListPreferredDirectedChannelsPaginatedParams) ([]ListPreferredDirectedChannelsPaginatedRow, error) { + rows, err := q.db.QueryContext(ctx, listPreferredDirectedChannelsPaginated, + arg.NodeIDV2, + arg.Scid, + arg.NodeIDV1, + arg.PageLimit, + ) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ListPreferredDirectedChannelsPaginatedRow + for rows.Next() { + var i ListPreferredDirectedChannelsPaginatedRow + if err := rows.Scan( + &i.GraphChannel.ID, + &i.GraphChannel.Version, + &i.GraphChannel.Scid, + &i.GraphChannel.NodeID1, + &i.GraphChannel.NodeID2, + &i.GraphChannel.Outpoint, + &i.GraphChannel.Capacity, + &i.GraphChannel.BitcoinKey1, + &i.GraphChannel.BitcoinKey2, + &i.GraphChannel.Node1Signature, + &i.GraphChannel.Node2Signature, + &i.GraphChannel.Bitcoin1Signature, + &i.GraphChannel.Bitcoin2Signature, + &i.GraphChannel.Signature, + &i.GraphChannel.FundingPkScript, + &i.GraphChannel.MerkleRootHash, + &i.Node1Pubkey, + &i.Node2Pubkey, + &i.Policy1ID, + &i.Policy1NodeID, + &i.Policy1Version, + &i.Policy1Timelock, + &i.Policy1FeePpm, + &i.Policy1BaseFeeMsat, + &i.Policy1MinHtlcMsat, + &i.Policy1MaxHtlcMsat, + &i.Policy1LastUpdate, + &i.Policy1Disabled, + &i.Policy1InboundBaseFeeMsat, + &i.Policy1InboundFeeRateMilliMsat, + &i.Policy1MessageFlags, + &i.Policy1ChannelFlags, + &i.Policy1Signature, + &i.Policy1BlockHeight, + &i.Policy1DisableFlags, + &i.Policy2ID, + &i.Policy2NodeID, + &i.Policy2Version, + &i.Policy2Timelock, + &i.Policy2FeePpm, + &i.Policy2BaseFeeMsat, + &i.Policy2MinHtlcMsat, + &i.Policy2MaxHtlcMsat, + &i.Policy2LastUpdate, + &i.Policy2Disabled, + &i.Policy2InboundBaseFeeMsat, + &i.Policy2InboundFeeRateMilliMsat, + &i.Policy2MessageFlags, + &i.Policy2ChannelFlags, + &i.Policy2Signature, + &i.Policy2BlockHeight, + &i.Policy2DisableFlags, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listPreferredNodesPaginated = `-- name: ListPreferredNodesPaginated :many +WITH page_pub_keys(cursor_pub_key) AS ( + SELECT page.pub_key + FROM ( + SELECT n2.pub_key AS pub_key + FROM graph_nodes n2 + WHERE n2.version = 2 + AND n2.pub_key > $1 + UNION + SELECT n1.pub_key AS pub_key + FROM graph_nodes n1 + WHERE n1.version = 1 + AND n1.pub_key > $1 + ) AS page + ORDER BY page.pub_key + LIMIT $2 +), +selected_nodes AS ( + SELECT + p.cursor_pub_key AS selected_pub_key, + COALESCE( + ( + SELECT n.id + FROM graph_nodes n + WHERE n.pub_key = p.cursor_pub_key + AND n.version = 2 + AND COALESCE(length(n.signature), 0) > 0 + ), + ( + SELECT n.id + FROM graph_nodes n + WHERE n.pub_key = p.cursor_pub_key + AND n.version = 1 + AND COALESCE(length(n.signature), 0) > 0 + ), + ( + SELECT n.id + FROM graph_nodes n + WHERE n.pub_key = p.cursor_pub_key + AND n.version = 2 + ), + ( + SELECT n.id + FROM graph_nodes n + WHERE n.pub_key = p.cursor_pub_key + AND n.version = 1 + ) + ) AS node_id + FROM page_pub_keys p +) +SELECT n.id, n.version, n.pub_key, n.alias, n.last_update, n.color, n.signature, n.block_height +FROM selected_nodes s +JOIN graph_nodes n ON n.id = s.node_id +ORDER BY s.selected_pub_key +` + +type ListPreferredNodesPaginatedParams struct { + PubKey []byte + Limit int32 +} + +type ListPreferredNodesPaginatedRow struct { + GraphNode GraphNode +} + +func (q *Queries) ListPreferredNodesPaginated(ctx context.Context, arg ListPreferredNodesPaginatedParams) ([]ListPreferredNodesPaginatedRow, error) { + rows, err := q.db.QueryContext(ctx, listPreferredNodesPaginated, arg.PubKey, arg.Limit) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ListPreferredNodesPaginatedRow + for rows.Next() { + var i ListPreferredNodesPaginatedRow + if err := rows.Scan( + &i.GraphNode.ID, + &i.GraphNode.Version, + &i.GraphNode.PubKey, + &i.GraphNode.Alias, + &i.GraphNode.LastUpdate, + &i.GraphNode.Color, + &i.GraphNode.Signature, + &i.GraphNode.BlockHeight, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const nodeExists = `-- name: NodeExists :one SELECT EXISTS ( SELECT 1 diff --git a/sqldb/sqlc/migrations/000009_graph_v2.down.sql b/sqldb/sqlc/migrations/000009_graph_v2.down.sql index b566e4016d..16c00511aa 100644 --- a/sqldb/sqlc/migrations/000009_graph_v2.down.sql +++ b/sqldb/sqlc/migrations/000009_graph_v2.down.sql @@ -1,3 +1,7 @@ +DROP INDEX IF EXISTS graph_channel_policy_version_last_update_channel_id_idx; +DROP INDEX IF EXISTS graph_channel_policy_version_block_height_channel_id_idx; +DROP INDEX IF EXISTS graph_nodes_version_block_height_pub_key_idx; + -- Remove the block_height column from graph_nodes ALTER TABLE graph_nodes DROP COLUMN block_height; @@ -14,4 +18,4 @@ ALTER TABLE graph_channels DROP COLUMN merkle_root_hash; ALTER TABLE graph_channel_policies DROP COLUMN block_height; -- Remove the disable_flags column from graph_channel_policies -ALTER TABLE graph_channel_policies DROP COLUMN disable_flags; \ No newline at end of file +ALTER TABLE graph_channel_policies DROP COLUMN disable_flags; diff --git a/sqldb/sqlc/migrations/000009_graph_v2.up.sql b/sqldb/sqlc/migrations/000009_graph_v2.up.sql index 19a3e99f91..80345178de 100644 --- a/sqldb/sqlc/migrations/000009_graph_v2.up.sql +++ b/sqldb/sqlc/migrations/000009_graph_v2.up.sql @@ -2,6 +2,11 @@ -- It may be zero if we have not received a node announcement yet. ALTER TABLE graph_nodes ADD COLUMN block_height BIGINT; +-- Support v2 node horizon queries by indexing the versioned block-height and +-- pubkey cursor fields together. +CREATE INDEX IF NOT EXISTS graph_nodes_version_block_height_pub_key_idx + ON graph_nodes(version, block_height, pub_key); + -- The signature of the channel announcement. If this is null, then the channel -- belongs to the source node and the channel has not been announced yet. ALTER TABLE graph_channels ADD COLUMN signature BLOB; @@ -20,4 +25,14 @@ ALTER TABLE graph_channel_policies ADD COLUMN block_height BIGINT; -- A bitfield describing the disabled flags for a v2 channel update. ALTER TABLE graph_channel_policies ADD COLUMN disable_flags SMALLINT - CHECK (disable_flags >= 0 AND disable_flags <= 255); \ No newline at end of file + CHECK (disable_flags >= 0 AND disable_flags <= 255); + +-- Support v2 channel horizon queries by indexing the versioned block-height +-- cursor fields on channel policies. +CREATE INDEX IF NOT EXISTS graph_channel_policy_version_block_height_channel_id_idx + ON graph_channel_policies(version, block_height, channel_id); + +-- Support version-aware channel horizon queries by indexing the v1 timestamp +-- cursor fields on channel policies. +CREATE INDEX IF NOT EXISTS graph_channel_policy_version_last_update_channel_id_idx + ON graph_channel_policies(version, last_update, channel_id); diff --git a/sqldb/sqlc/querier.go b/sqldb/sqlc/querier.go index d2481be8c2..adb3b4eb76 100644 --- a/sqldb/sqlc/querier.go +++ b/sqldb/sqlc/querier.go @@ -103,6 +103,7 @@ type Querier interface { GetChannelPolicyExtraTypesBatch(ctx context.Context, policyIds []int64) ([]GetChannelPolicyExtraTypesBatchRow, error) GetChannelsByIDs(ctx context.Context, ids []int64) ([]GetChannelsByIDsRow, error) GetChannelsByOutpoints(ctx context.Context, outpoints []string) ([]GetChannelsByOutpointsRow, error) + GetChannelsByPolicyBlockRange(ctx context.Context, arg GetChannelsByPolicyBlockRangeParams) ([]GetChannelsByPolicyBlockRangeRow, error) GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg GetChannelsByPolicyLastUpdateRangeParams) ([]GetChannelsByPolicyLastUpdateRangeRow, error) GetChannelsBySCIDRange(ctx context.Context, arg GetChannelsBySCIDRangeParams) ([]GetChannelsBySCIDRangeRow, error) GetChannelsBySCIDWithPolicies(ctx context.Context, arg GetChannelsBySCIDWithPoliciesParams) ([]GetChannelsBySCIDWithPoliciesRow, error) @@ -128,6 +129,7 @@ type Querier interface { GetNodeFeaturesBatch(ctx context.Context, ids []int64) ([]GraphNodeFeature, error) GetNodeFeaturesByPubKey(ctx context.Context, arg GetNodeFeaturesByPubKeyParams) ([]int32, error) GetNodeIDByPubKey(ctx context.Context, arg GetNodeIDByPubKeyParams) (int64, error) + GetNodesByBlockHeightRange(ctx context.Context, arg GetNodesByBlockHeightRangeParams) ([]GraphNode, error) GetNodesByIDs(ctx context.Context, ids []int64) ([]GraphNode, error) GetNodesByLastUpdateRange(ctx context.Context, arg GetNodesByLastUpdateRangeParams) ([]GraphNode, error) GetPruneEntriesForHeights(ctx context.Context, heights []int64) ([]GraphPruneLog, error) @@ -211,6 +213,9 @@ type Querier interface { ListChannelsWithPoliciesPaginated(ctx context.Context, arg ListChannelsWithPoliciesPaginatedParams) ([]ListChannelsWithPoliciesPaginatedRow, error) ListNodeIDsAndPubKeys(ctx context.Context, arg ListNodeIDsAndPubKeysParams) ([]ListNodeIDsAndPubKeysRow, error) ListNodesPaginated(ctx context.Context, arg ListNodesPaginatedParams) ([]GraphNode, error) + ListPreferredChannelsWithPoliciesPaginated(ctx context.Context, arg ListPreferredChannelsWithPoliciesPaginatedParams) ([]ListPreferredChannelsWithPoliciesPaginatedRow, error) + ListPreferredDirectedChannelsPaginated(ctx context.Context, arg ListPreferredDirectedChannelsPaginatedParams) ([]ListPreferredDirectedChannelsPaginatedRow, error) + ListPreferredNodesPaginated(ctx context.Context, arg ListPreferredNodesPaginatedParams) ([]ListPreferredNodesPaginatedRow, error) NextInvoiceSettleIndex(ctx context.Context) (int64, error) NodeExists(ctx context.Context, arg NodeExistsParams) (bool, error) OnAMPSubInvoiceCanceled(ctx context.Context, arg OnAMPSubInvoiceCanceledParams) error diff --git a/sqldb/sqlc/queries/graph.sql b/sqldb/sqlc/queries/graph.sql index 78c1ebe7c9..a7479297b3 100644 --- a/sqldb/sqlc/queries/graph.sql +++ b/sqldb/sqlc/queries/graph.sql @@ -80,6 +80,61 @@ WHERE version = $1 AND id > $2 ORDER BY id LIMIT $3; +-- name: ListPreferredNodesPaginated :many +WITH page_pub_keys(cursor_pub_key) AS ( + SELECT page.pub_key + FROM ( + SELECT n2.pub_key AS pub_key + FROM graph_nodes n2 + WHERE n2.version = 2 + AND n2.pub_key > $1 + UNION + SELECT n1.pub_key AS pub_key + FROM graph_nodes n1 + WHERE n1.version = 1 + AND n1.pub_key > $1 + ) AS page + ORDER BY page.pub_key + LIMIT $2 +), +selected_nodes AS ( + SELECT + p.cursor_pub_key AS selected_pub_key, + COALESCE( + ( + SELECT n.id + FROM graph_nodes n + WHERE n.pub_key = p.cursor_pub_key + AND n.version = 2 + AND COALESCE(length(n.signature), 0) > 0 + ), + ( + SELECT n.id + FROM graph_nodes n + WHERE n.pub_key = p.cursor_pub_key + AND n.version = 1 + AND COALESCE(length(n.signature), 0) > 0 + ), + ( + SELECT n.id + FROM graph_nodes n + WHERE n.pub_key = p.cursor_pub_key + AND n.version = 2 + ), + ( + SELECT n.id + FROM graph_nodes n + WHERE n.pub_key = p.cursor_pub_key + AND n.version = 1 + ) + ) AS node_id + FROM page_pub_keys p +) +SELECT sqlc.embed(n) +FROM selected_nodes s +JOIN graph_nodes n ON n.id = s.node_id +ORDER BY s.selected_pub_key; + -- name: ListNodeIDsAndPubKeys :many SELECT id, pub_key FROM graph_nodes @@ -227,7 +282,8 @@ ORDER BY node_id, type, position; -- name: GetNodesByLastUpdateRange :many SELECT * FROM graph_nodes -WHERE last_update >= @start_time +WHERE graph_nodes.version = 1 + AND last_update >= @start_time AND last_update <= @end_time -- Pagination: We use (last_update, pub_key) as a compound cursor. -- This ensures stable ordering and allows us to resume from where we left off. @@ -258,6 +314,41 @@ WHERE last_update >= @start_time ORDER BY last_update ASC, pub_key ASC LIMIT COALESCE(sqlc.narg('max_results'), 999999999); +-- name: GetNodesByBlockHeightRange :many +SELECT * +FROM graph_nodes +WHERE graph_nodes.version = @version + AND block_height >= @start_height + AND block_height <= @end_height + -- Pagination: We use (block_height, pub_key) as a compound cursor. + -- This ensures stable ordering and allows us to resume from where we left off. + -- We use COALESCE with -1 as sentinel since heights are always positive. + AND ( + -- Include rows with block_height greater than cursor (or all rows if cursor is -1). + block_height > COALESCE(sqlc.narg('last_block_height'), -1) + OR + -- For rows with same block_height, use pub_key as tiebreaker. + (block_height = COALESCE(sqlc.narg('last_block_height'), -1) + AND pub_key > sqlc.narg('last_pub_key')) + ) + -- Optional filter for public nodes only. + AND ( + -- If only_public is false or not provided, include all nodes. + COALESCE(sqlc.narg('only_public'), FALSE) IS FALSE + OR + -- For V2 protocol, a node is public if it has at least one public channel. + -- A public channel has signature set (channel announcement received). + EXISTS ( + SELECT 1 + FROM graph_channels c + WHERE c.version = 2 + AND COALESCE(length(c.signature), 0) > 0 + AND (c.node_id_1 = graph_nodes.id OR c.node_id_2 = graph_nodes.id) + ) + ) +ORDER BY block_height ASC, pub_key ASC +LIMIT COALESCE(sqlc.narg('max_results'), 999999999); + -- name: DeleteNodeAddresses :exec DELETE FROM graph_node_addresses WHERE node_id = $1; @@ -498,6 +589,13 @@ FROM graph_channels c WHERE c.id IN (sqlc.slice('ids')/*SLICE:ids*/); -- name: GetChannelsByPolicyLastUpdateRange :many +WITH candidate_channels AS ( + SELECT DISTINCT channel_id + FROM graph_channel_policies + WHERE version = 1 + AND last_update >= @start_time + AND last_update < @end_time +) SELECT sqlc.embed(c), sqlc.embed(n1), @@ -541,14 +639,15 @@ SELECT cp2.block_height AS policy2_block_height, cp2.disable_flags AS policy2_disable_flags -FROM graph_channels c +FROM candidate_channels cc + JOIN graph_channels c ON c.id = cc.channel_id JOIN graph_nodes n1 ON c.node_id_1 = n1.id JOIN graph_nodes n2 ON c.node_id_2 = n2.id LEFT JOIN graph_channel_policies cp1 ON cp1.channel_id = c.id AND cp1.node_id = c.node_id_1 AND cp1.version = c.version LEFT JOIN graph_channel_policies cp2 ON cp2.channel_id = c.id AND cp2.node_id = c.node_id_2 AND cp2.version = c.version -WHERE c.version = @version +WHERE c.version = 1 AND ( (cp1.last_update >= @start_time AND cp1.last_update < @end_time) OR @@ -579,6 +678,96 @@ ORDER BY c.id ASC LIMIT COALESCE(sqlc.narg('max_results'), 999999999); +-- name: GetChannelsByPolicyBlockRange :many +WITH candidate_channels AS ( + SELECT DISTINCT channel_id + FROM graph_channel_policies + WHERE version = @version + AND block_height >= @start_height + AND block_height < @end_height +) +SELECT + sqlc.embed(c), + sqlc.embed(n1), + sqlc.embed(n2), + + -- Policy 1 (node_id_1) + cp1.id AS policy1_id, + cp1.node_id AS policy1_node_id, + cp1.version AS policy1_version, + cp1.timelock AS policy1_timelock, + cp1.fee_ppm AS policy1_fee_ppm, + cp1.base_fee_msat AS policy1_base_fee_msat, + cp1.min_htlc_msat AS policy1_min_htlc_msat, + cp1.max_htlc_msat AS policy1_max_htlc_msat, + cp1.last_update AS policy1_last_update, + cp1.disabled AS policy1_disabled, + cp1.inbound_base_fee_msat AS policy1_inbound_base_fee_msat, + cp1.inbound_fee_rate_milli_msat AS policy1_inbound_fee_rate_milli_msat, + cp1.message_flags AS policy1_message_flags, + cp1.channel_flags AS policy1_channel_flags, + cp1.signature AS policy1_signature, + cp1.block_height AS policy1_block_height, + cp1.disable_flags AS policy1_disable_flags, + + -- Policy 2 (node_id_2) + cp2.id AS policy2_id, + cp2.node_id AS policy2_node_id, + cp2.version AS policy2_version, + cp2.timelock AS policy2_timelock, + cp2.fee_ppm AS policy2_fee_ppm, + cp2.base_fee_msat AS policy2_base_fee_msat, + cp2.min_htlc_msat AS policy2_min_htlc_msat, + cp2.max_htlc_msat AS policy2_max_htlc_msat, + cp2.last_update AS policy2_last_update, + cp2.disabled AS policy2_disabled, + cp2.inbound_base_fee_msat AS policy2_inbound_base_fee_msat, + cp2.inbound_fee_rate_milli_msat AS policy2_inbound_fee_rate_milli_msat, + cp2.message_flags AS policy2_message_flags, + cp2.channel_flags AS policy2_channel_flags, + cp2.signature AS policy2_signature, + cp2.block_height AS policy2_block_height, + cp2.disable_flags AS policy2_disable_flags + +FROM candidate_channels cc + JOIN graph_channels c ON c.id = cc.channel_id + JOIN graph_nodes n1 ON c.node_id_1 = n1.id + JOIN graph_nodes n2 ON c.node_id_2 = n2.id + LEFT JOIN graph_channel_policies cp1 + ON cp1.channel_id = c.id AND cp1.node_id = c.node_id_1 AND cp1.version = c.version + LEFT JOIN graph_channel_policies cp2 + ON cp2.channel_id = c.id AND cp2.node_id = c.node_id_2 AND cp2.version = c.version +WHERE c.version = @version + AND ( + (cp1.block_height >= @start_height AND cp1.block_height < @end_height) + OR + (cp2.block_height >= @start_height AND cp2.block_height < @end_height) + ) + -- Pagination using compound cursor (max_block_height, id). + -- We use COALESCE with -1 as sentinel since heights are always positive. + AND ( + (CASE + WHEN COALESCE(cp1.block_height, 0) >= COALESCE(cp2.block_height, 0) + THEN COALESCE(cp1.block_height, 0) + ELSE COALESCE(cp2.block_height, 0) + END > COALESCE(sqlc.narg('last_block_height'), -1)) + OR + (CASE + WHEN COALESCE(cp1.block_height, 0) >= COALESCE(cp2.block_height, 0) + THEN COALESCE(cp1.block_height, 0) + ELSE COALESCE(cp2.block_height, 0) + END = COALESCE(sqlc.narg('last_block_height'), -1) + AND c.id > COALESCE(sqlc.narg('last_id'), -1)) + ) +ORDER BY + CASE + WHEN COALESCE(cp1.block_height, 0) >= COALESCE(cp2.block_height, 0) + THEN COALESCE(cp1.block_height, 0) + ELSE COALESCE(cp2.block_height, 0) + END ASC, + c.id ASC +LIMIT COALESCE(sqlc.narg('max_results'), 999999999); + -- name: GetChannelByOutpointWithPolicies :one SELECT sqlc.embed(c), @@ -752,6 +941,98 @@ FROM graph_channels c WHERE c.version = $1 AND (c.node_id_1 = $2 OR c.node_id_2 = $2); +-- name: ListPreferredDirectedChannelsPaginated :many +WITH page_scids(cursor_scid) AS ( + SELECT page.scid + FROM ( + SELECT c2.scid AS scid + FROM graph_channels c2 + WHERE c2.version = 2 + AND (c2.node_id_1 = @node_id_v2 OR c2.node_id_2 = @node_id_v2) + AND c2.scid > @scid + UNION + SELECT c1.scid AS scid + FROM graph_channels c1 + WHERE c1.version = 1 + AND (c1.node_id_1 = @node_id_v1 OR c1.node_id_2 = @node_id_v1) + AND c1.scid > @scid + ) AS page + ORDER BY page.scid + LIMIT @page_limit +), +selected_channels AS ( + SELECT + s.cursor_scid AS selected_scid, + COALESCE( + ( + SELECT c.id + FROM graph_channels c + WHERE c.scid = s.cursor_scid + AND c.version = 2 + AND (c.node_id_1 = @node_id_v2 OR c.node_id_2 = @node_id_v2) + ), + ( + SELECT c.id + FROM graph_channels c + WHERE c.scid = s.cursor_scid + AND c.version = 1 + AND (c.node_id_1 = @node_id_v1 OR c.node_id_2 = @node_id_v1) + ) + ) AS channel_db_id + FROM page_scids s +) +SELECT sqlc.embed(c), + n1.pub_key AS node1_pubkey, + n2.pub_key AS node2_pubkey, + + -- Policy 1 + cp1.id AS policy1_id, + cp1.node_id AS policy1_node_id, + cp1.version AS policy1_version, + cp1.timelock AS policy1_timelock, + cp1.fee_ppm AS policy1_fee_ppm, + cp1.base_fee_msat AS policy1_base_fee_msat, + cp1.min_htlc_msat AS policy1_min_htlc_msat, + cp1.max_htlc_msat AS policy1_max_htlc_msat, + cp1.last_update AS policy1_last_update, + cp1.disabled AS policy1_disabled, + cp1.inbound_base_fee_msat AS policy1_inbound_base_fee_msat, + cp1.inbound_fee_rate_milli_msat AS policy1_inbound_fee_rate_milli_msat, + cp1.message_flags AS policy1_message_flags, + cp1.channel_flags AS policy1_channel_flags, + cp1.signature AS policy1_signature, + cp1.block_height AS policy1_block_height, + cp1.disable_flags AS policy1_disable_flags, + + -- Policy 2 + cp2.id AS policy2_id, + cp2.node_id AS policy2_node_id, + cp2.version AS policy2_version, + cp2.timelock AS policy2_timelock, + cp2.fee_ppm AS policy2_fee_ppm, + cp2.base_fee_msat AS policy2_base_fee_msat, + cp2.min_htlc_msat AS policy2_min_htlc_msat, + cp2.max_htlc_msat AS policy2_max_htlc_msat, + cp2.last_update AS policy2_last_update, + cp2.disabled AS policy2_disabled, + cp2.inbound_base_fee_msat AS policy2_inbound_base_fee_msat, + cp2.inbound_fee_rate_milli_msat AS policy2_inbound_fee_rate_milli_msat, + cp2.message_flags AS policy2_message_flags, + cp2.channel_flags AS policy2_channel_flags, + cp2.signature AS policy2_signature, + cp2.block_height AS policy2_block_height, + cp2.disable_flags AS policy2_disable_flags + +FROM selected_channels s +JOIN graph_channels c ON c.id = s.channel_db_id +JOIN graph_nodes n1 ON c.node_id_1 = n1.id +JOIN graph_nodes n2 ON c.node_id_2 = n2.id +LEFT JOIN graph_channel_policies cp1 + ON cp1.channel_id = c.id AND cp1.node_id = c.node_id_1 AND cp1.version = c.version +LEFT JOIN graph_channel_policies cp2 + ON cp2.channel_id = c.id AND cp2.node_id = c.node_id_2 AND cp2.version = c.version +ORDER BY s.selected_scid; + -- name: GetPublicV1ChannelsBySCID :many SELECT * FROM graph_channels @@ -841,6 +1122,121 @@ WHERE c.version = $1 AND c.id > $2 ORDER BY c.id LIMIT $3; +-- name: ListPreferredChannelsWithPoliciesPaginated :many +WITH page_scids(cursor_scid) AS ( + SELECT page.scid + FROM ( + SELECT c2.scid AS scid + FROM graph_channels c2 + WHERE c2.version = 2 + AND c2.scid > $1 + UNION + SELECT c1.scid AS scid + FROM graph_channels c1 + WHERE c1.version = 1 + AND c1.scid > $1 + ) AS page + ORDER BY page.scid + LIMIT $2 +), +selected_channels AS ( + SELECT + s.cursor_scid AS selected_scid, + COALESCE( + ( + SELECT c.id + FROM graph_channels c + WHERE c.scid = s.cursor_scid + AND c.version = 2 + AND EXISTS ( + SELECT 1 + FROM graph_channel_policies p + WHERE p.channel_id = c.id + AND p.version = 2 + ) + ), + ( + SELECT c.id + FROM graph_channels c + WHERE c.scid = s.cursor_scid + AND c.version = 1 + AND EXISTS ( + SELECT 1 + FROM graph_channel_policies p + WHERE p.channel_id = c.id + AND p.version = 1 + ) + ), + ( + SELECT c.id + FROM graph_channels c + WHERE c.scid = s.cursor_scid + AND c.version = 2 + ), + ( + SELECT c.id + FROM graph_channels c + WHERE c.scid = s.cursor_scid + AND c.version = 1 + ) + ) AS channel_db_id + FROM page_scids s +) +SELECT + sqlc.embed(c), + + -- Join node pubkeys + n1.pub_key AS node1_pubkey, + n2.pub_key AS node2_pubkey, + + -- Node 1 policy + cp1.id AS policy_1_id, + cp1.node_id AS policy_1_node_id, + cp1.version AS policy_1_version, + cp1.timelock AS policy_1_timelock, + cp1.fee_ppm AS policy_1_fee_ppm, + cp1.base_fee_msat AS policy_1_base_fee_msat, + cp1.min_htlc_msat AS policy_1_min_htlc_msat, + cp1.max_htlc_msat AS policy_1_max_htlc_msat, + cp1.last_update AS policy_1_last_update, + cp1.disabled AS policy_1_disabled, + cp1.inbound_base_fee_msat AS policy1_inbound_base_fee_msat, + cp1.inbound_fee_rate_milli_msat AS policy1_inbound_fee_rate_milli_msat, + cp1.message_flags AS policy1_message_flags, + cp1.channel_flags AS policy1_channel_flags, + cp1.block_height AS policy1_block_height, + cp1.disable_flags AS policy1_disable_flags, + cp1.signature AS policy_1_signature, + + -- Node 2 policy + cp2.id AS policy_2_id, + cp2.node_id AS policy_2_node_id, + cp2.version AS policy_2_version, + cp2.timelock AS policy_2_timelock, + cp2.fee_ppm AS policy_2_fee_ppm, + cp2.base_fee_msat AS policy_2_base_fee_msat, + cp2.min_htlc_msat AS policy_2_min_htlc_msat, + cp2.max_htlc_msat AS policy_2_max_htlc_msat, + cp2.last_update AS policy_2_last_update, + cp2.disabled AS policy_2_disabled, + cp2.inbound_base_fee_msat AS policy2_inbound_base_fee_msat, + cp2.inbound_fee_rate_milli_msat AS policy2_inbound_fee_rate_milli_msat, + cp2.message_flags AS policy2_message_flags, + cp2.channel_flags AS policy2_channel_flags, + cp2.signature AS policy_2_signature, + cp2.block_height AS policy_2_block_height, + cp2.disable_flags AS policy_2_disable_flags + +FROM selected_channels s +JOIN graph_channels c ON c.id = s.channel_db_id +JOIN graph_nodes n1 ON c.node_id_1 = n1.id +JOIN graph_nodes n2 ON c.node_id_2 = n2.id +LEFT JOIN graph_channel_policies cp1 + ON cp1.channel_id = c.id AND cp1.node_id = c.node_id_1 AND cp1.version = c.version +LEFT JOIN graph_channel_policies cp2 + ON cp2.channel_id = c.id AND cp2.node_id = c.node_id_2 AND cp2.version = c.version +ORDER BY s.selected_scid; + -- name: ListChannelsWithPoliciesForCachePaginated :many SELECT c.id as id,