diff --git a/cl/blockbuilder/blockbuilder.go b/cl/blockbuilder/blockbuilder.go index fdce4bf49..b2110e6fc 100644 --- a/cl/blockbuilder/blockbuilder.go +++ b/cl/blockbuilder/blockbuilder.go @@ -121,7 +121,7 @@ func (bb *BlockBuilder) GetPayload(ctx context.Context) error { if bb.executionHead == nil { bb.logger.Info("executionHead is nil, it'll be set by RPC. CL is likely being restarted") err = util.RetryWithBackoff(ctx, maxAttempts, bb.logger, func() error { - innerErr := bb.setExecutionHeadFromRPC(ctx) + innerErr := bb.SetExecutionHeadFromRPC(ctx) if innerErr != nil { bb.logger.Warn( "Failed to set execution head from rpc, retrying...", @@ -493,7 +493,7 @@ func (bb *BlockBuilder) updateForkChoice(ctx context.Context, fcs engine.Forkcho }) } -func (bb *BlockBuilder) setExecutionHeadFromRPC(ctx context.Context) error { +func (bb *BlockBuilder) SetExecutionHeadFromRPC(ctx context.Context) error { header, err := bb.engineCl.HeaderByNumber(ctx, nil) // nil for the latest block if err != nil { return fmt.Errorf("failed to get the latest block header: %w", err) diff --git a/cl/cmd/singlenode/main.go b/cl/cmd/singlenode/main.go index 3d4e1b306..dc1386aa3 100644 --- a/cl/cmd/singlenode/main.go +++ b/cl/cmd/singlenode/main.go @@ -13,8 +13,11 @@ import ( "syscall" "time" + "github.com/primev/mev-commit/cl/blockbuilder" + "github.com/primev/mev-commit/cl/ethclient" "github.com/primev/mev-commit/cl/singlenode" - "github.com/primev/mev-commit/cl/singlenode/membernode" + "github.com/primev/mev-commit/cl/singlenode/follower" + "github.com/primev/mev-commit/cl/singlenode/payloadstore" "github.com/primev/mev-commit/x/util" "github.com/urfave/cli/v2" "github.com/urfave/cli/v2/altsrc" @@ -211,29 +214,12 @@ var ( Value: 5 * time.Millisecond, }) - // Member node specific flags - leaderAPIURLFlag = altsrc.NewStringFlag(&cli.StringFlag{ - Name: "leader-api-url", - Usage: "Leader node API URL for member nodes (e.g., 'http://leader:9090')", - EnvVars: []string{"MEMBER_LEADER_API_URL"}, - Category: categoryMember, - Action: func(_ *cli.Context, s string) error { - if s == "" { - return nil // Will be validated in member command - } - if _, err := url.Parse(s); err != nil { - return fmt.Errorf("invalid leader-api-url: %v", err) - } - return nil - }, - }) - - pollIntervalFlag = altsrc.NewDurationFlag(&cli.DurationFlag{ - Name: "poll-interval", - Usage: "Interval for polling leader node for new payloads (e.g., '1s')", - EnvVars: []string{"MEMBER_POLL_INTERVAL"}, - Value: 1 * time.Second, - Category: categoryMember, + // Follower node specific flags + syncBatchSizeFlag = altsrc.NewUint64Flag(&cli.Uint64Flag{ + Name: "sync-batch-size", + Usage: "Number of payloads per request to the EL during sync", + EnvVars: []string{"FOLLOWER_SYNC_BATCH_SIZE"}, + Value: 100, }) ) @@ -256,7 +242,7 @@ func main() { txPoolPollingIntervalFlag, } - memberFlags := []cli.Flag{ + followerFlags := []cli.Flag{ configFlag, instanceIDFlag, ethClientURLFlag, @@ -265,8 +251,8 @@ func main() { logLevelFlag, logTagsFlag, healthAddrPortFlag, - leaderAPIURLFlag, - pollIntervalFlag, + postgresDSNFlag, + syncBatchSizeFlag, } app := &cli.App{ @@ -290,10 +276,10 @@ func main() { }, }, { - Name: "member", + Name: "follower", Usage: "Start as member node (follows leader)", - Flags: memberFlags, - Before: altsrc.InitInputSourceWithContext(memberFlags, + Flags: followerFlags, + Before: altsrc.InitInputSourceWithContext(followerFlags, func(c *cli.Context) (altsrc.InputSourceContext, error) { configFile := c.String(configFlag.Name) if configFile != "" { @@ -302,7 +288,7 @@ func main() { return &altsrc.MapInputSource{}, nil }), Action: func(c *cli.Context) error { - return startMemberNode(c) + return startFollowerNode(c) }, }, // Keep the old "start" command for backward compatibility @@ -380,12 +366,7 @@ func startLeaderNode(c *cli.Context) error { return nil } -func startMemberNode(c *cli.Context) error { - leaderURL := c.String(leaderAPIURLFlag.Name) - if leaderURL == "" { - return fmt.Errorf("leader-api-url is required for member nodes") - } - +func startFollowerNode(c *cli.Context) error { logger, err := util.NewLogger( c.String(logLevelFlag.Name), c.String(logFmtFlag.Name), @@ -395,36 +376,61 @@ func startMemberNode(c *cli.Context) error { if err != nil { return fmt.Errorf("failed to create logger: %w", err) } - logger = logger.With("app", "snode", "role", "member") - - cfg := membernode.Config{ - InstanceID: c.String(instanceIDFlag.Name), - LeaderAPIURL: leaderURL, - EthClientURL: c.String(ethClientURLFlag.Name), - JWTSecret: c.String(jwtSecretFlag.Name), - HealthAddr: c.String(healthAddrPortFlag.Name), - PollInterval: c.Duration(pollIntervalFlag.Name), - } + logger = logger.With("app", "snode", "role", "follower") - logger.Info("Starting member node", "config", cfg) + logger.Info("Starting follower node") - // Create a root context that can be cancelled for graceful shutdown rootCtx, rootCancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) defer rootCancel() - memberNode, err := membernode.NewMemberNodeApp(rootCtx, cfg, logger) + postgresDSN := c.String(postgresDSNFlag.Name) + if postgresDSN == "" { + return fmt.Errorf("postgresDSN is required") + } + repo, err := payloadstore.NewPostgresFollower(rootCtx, postgresDSN, logger) if err != nil { - logger.Error("Failed to initialize MemberNodeApp", "error", err) - return err + return fmt.Errorf("failed to initialize payload repository: %w", err) } + syncBatchSize := c.Uint64(syncBatchSizeFlag.Name) + if syncBatchSize == 0 { + return fmt.Errorf("sync-batch-size is required") + } + ethClientURL := c.String(ethClientURLFlag.Name) + if ethClientURL == "" { + return fmt.Errorf("eth-client-url is required") + } + jwtSecret := c.String(jwtSecretFlag.Name) + if jwtSecret == "" { + return fmt.Errorf("jwt-secret is required") + } + jwtBytes, err := hex.DecodeString(jwtSecret) + if err != nil { + return fmt.Errorf("failed to decode JWT secret: %w", err) + } + engineCL, err := ethclient.NewAuthClient(rootCtx, ethClientURL, jwtBytes) + if err != nil { + return fmt.Errorf("failed to create Ethereum engine client: %w", err) + } + bb := blockbuilder.NewMemberBlockBuilder(engineCL, logger.With("component", "BlockBuilder")) - memberNode.Start() - - <-rootCtx.Done() - - logger.Info("Shutdown signal received, stopping member node...") - memberNode.Stop() + followerNode, err := follower.NewFollower( + logger, + repo, + syncBatchSize, + bb, + ) + if err != nil { + logger.Error("Failed to initialize Follower", "error", err) + return err + } - logger.Info("Member node shutdown completed.") - return nil + done := followerNode.Start(rootCtx) + select { + case <-done: + logger.Info("Follower node shutdown completed.") + return nil + case <-rootCtx.Done(): + logger.Info("Follower node shutdown completed.") + return nil + } } diff --git a/cl/go.mod b/cl/go.mod index 1f8c77bed..3325d7dd5 100644 --- a/cl/go.mod +++ b/cl/go.mod @@ -11,6 +11,7 @@ require ( github.com/lib/pq v1.10.9 github.com/redis/go-redis/v9 v9.6.1 github.com/urfave/cli/v2 v2.27.5 + golang.org/x/sync v0.11.0 golang.org/x/tools v0.29.0 ) @@ -34,7 +35,9 @@ require ( github.com/go-ole/go-ole v1.3.0 // indirect github.com/gofrs/flock v0.8.1 // indirect github.com/holiman/uint256 v1.3.2 // indirect + github.com/klauspost/compress v1.17.9 // indirect github.com/mattn/go-runewidth v0.0.16 // indirect + github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/mmcloughlin/addchain v0.4.0 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/nxadm/tail v1.4.11 // indirect @@ -44,6 +47,7 @@ require ( github.com/pion/transport/v2 v2.2.5 // indirect github.com/pion/transport/v3 v3.0.2 // indirect github.com/rivo/uniseg v0.4.7 // indirect + github.com/rs/cors v1.8.3 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect github.com/shirou/gopsutil v3.21.4-0.20210419000835-c7a38de76ee5+incompatible // indirect github.com/stretchr/objx v0.5.2 // indirect @@ -52,8 +56,8 @@ require ( github.com/tklauser/numcpus v0.7.0 // indirect github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 // indirect + golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 // indirect golang.org/x/mod v0.22.0 // indirect - golang.org/x/sync v0.11.0 // indirect rsc.io/tmplfunc v0.0.3 // indirect ) @@ -64,19 +68,15 @@ require ( github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/decred/dcrd/dcrec/secp256k1/v4 v4.3.0 // indirect github.com/ethereum/go-ethereum v1.15.11 - github.com/fsnotify/fsnotify v1.7.0 // indirect github.com/golang/snappy v0.0.5-0.20220116011046-fa5810519dcb // indirect github.com/gorilla/websocket v1.5.3 // indirect github.com/heyvito/go-leader v0.1.0 - github.com/klauspost/compress v1.17.9 // indirect - github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect - github.com/primev/mev-commit/x v0.0.0-20241029202458-b151c03fa49e + github.com/primev/mev-commit/x v0.0.1 github.com/prometheus/client_golang v1.19.1 github.com/prometheus/client_model v0.6.1 // indirect github.com/prometheus/common v0.55.0 // indirect github.com/prometheus/procfs v0.15.1 // indirect - github.com/rs/cors v1.8.3 // indirect github.com/stretchr/testify v1.10.0 github.com/vmihailenco/msgpack/v5 v5.4.1 golang.org/x/crypto v0.35.0 // indirect @@ -84,3 +84,5 @@ require ( google.golang.org/protobuf v1.34.2 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) + +replace github.com/primev/mev-commit/x => ../x diff --git a/cl/go.sum b/cl/go.sum index d406a3e90..f3f5a9ab3 100644 --- a/cl/go.sum +++ b/cl/go.sum @@ -155,8 +155,6 @@ github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/primev/mev-commit/x v0.0.0-20241029202458-b151c03fa49e h1:4UCC8PDllbWQNGuliOw3rxcmH79CUn6+xZhVGx3qZnQ= -github.com/primev/mev-commit/x v0.0.0-20241029202458-b151c03fa49e/go.mod h1:EEJMyKLa7ZLqTghwo1PlvHJPDW4MVl5WzBnhE6f5jDM= github.com/prometheus/client_golang v1.19.1 h1:wZWJDwK+NameRJuPGDhlnFgx8e8HN3XHQeLaYJFJBOE= github.com/prometheus/client_golang v1.19.1/go.mod h1:mP78NwGzrVks5S2H6ab8+ZZGJLZUq1hoULYBAYBw1Ho= github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E= diff --git a/cl/singlenode/follower/export.go b/cl/singlenode/follower/export.go new file mode 100644 index 000000000..e32be7cdb --- /dev/null +++ b/cl/singlenode/follower/export.go @@ -0,0 +1,19 @@ +package follower + +import ( + "context" + + "github.com/primev/mev-commit/cl/types" +) + +func (f *Follower) PayloadCh() <-chan types.PayloadInfo { + return f.payloadCh +} + +func (f *Follower) SyncFromSharedDB(ctx context.Context) error { + return f.syncFromSharedDB(ctx) +} + +func (f *Follower) GetExecutionHead() *types.ExecutionHead { + return f.getExecutionHead() +} diff --git a/cl/singlenode/follower/follower.go b/cl/singlenode/follower/follower.go new file mode 100644 index 000000000..12bd5bbb7 --- /dev/null +++ b/cl/singlenode/follower/follower.go @@ -0,0 +1,190 @@ +package follower + +import ( + "context" + "errors" + "fmt" + "log/slog" + "sync" + "time" + + "github.com/primev/mev-commit/cl/types" + "golang.org/x/sync/errgroup" +) + +type Follower struct { + logger *slog.Logger + sharedDB payloadDB + syncBatchSize uint64 + payloadCh chan types.PayloadInfo + bbMutex sync.RWMutex + bb blockBuilder +} + +const ( + defaultBackoff = 200 * time.Millisecond +) + +type payloadDB interface { + GetPayloadsSince(ctx context.Context, sinceHeight uint64, limit int) ([]types.PayloadInfo, error) + GetLatestHeight(ctx context.Context) (uint64, error) +} + +type blockBuilder interface { + GetExecutionHead() *types.ExecutionHead + FinalizeBlock(ctx context.Context, payloadIDStr, executionPayloadStr, msgID string) error + SetExecutionHeadFromRPC(ctx context.Context) error +} + +func NewFollower( + logger *slog.Logger, + sharedDB payloadDB, + syncBatchSize uint64, + bb blockBuilder, +) (*Follower, error) { + if sharedDB == nil { + return nil, errors.New("payload repository not provided") + } + if syncBatchSize == 0 { + return nil, errors.New("sync batch size must be greater than 0") + } + return &Follower{ + logger: logger, + sharedDB: sharedDB, + syncBatchSize: syncBatchSize, + payloadCh: make(chan types.PayloadInfo), + bb: bb, + }, nil +} + +func (f *Follower) Start(ctx context.Context) <-chan struct{} { + + done := make(chan struct{}) + eg, egCtx := errgroup.WithContext(ctx) + eg.Go(func() error { + return f.handlePayloads(egCtx) + }) + + eg.Go(func() error { + f.logger.Info("Starting sync from shared DB") + return f.syncFromSharedDB(egCtx) + }) + + go func() { + defer close(done) + if err := eg.Wait(); err != nil { + f.logger.Error("follower failed, exiting", "error", err) + } + }() + + return done +} + +func (f *Follower) syncFromSharedDB(ctx context.Context) error { + if f.getExecutionHead() == nil { + if err := f.setExecutionHeadFromRPC(ctx); err != nil { + f.logger.Error("failed to set execution head from rpc", "error", err) + return err + } + f.logger.Debug("set execution head from rpc") + } + + lastSignalledBlock := f.getExecutionHead().BlockHeight + f.logger.Debug("lastSignalledBlock set from execution head", "block height", lastSignalledBlock) + + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + cctx, cancel := context.WithTimeout(ctx, 5*time.Second) + targetBlock, err := f.sharedDB.GetLatestHeight(cctx) + cancel() + if err != nil { + f.sleepRespectingContext(ctx, defaultBackoff) + continue + } + + if lastSignalledBlock > targetBlock { + return fmt.Errorf("internal invariant has been broken. Follower EL is ahead of signer") + } + + blocksRemaining := targetBlock - lastSignalledBlock + + if blocksRemaining == 0 { + f.sleepRespectingContext(ctx, time.Millisecond) // New payload will likely be available within milliseconds + continue + } + f.logger.Debug("non-zero blocksRemaining", "blocksRemaining", blocksRemaining) + + limit := min(f.syncBatchSize, blocksRemaining) + + cctx, cancel = context.WithTimeout(ctx, 5*time.Second) + payloads, err := f.sharedDB.GetPayloadsSince(cctx, lastSignalledBlock+1, int(limit)) + cancel() + if err != nil { + f.logger.Error("failed to get payloads since", "error", err) + f.sleepRespectingContext(ctx, defaultBackoff) + continue + } + if len(payloads) == 0 { + f.logger.Error("no payloads returned from valid query") + f.sleepRespectingContext(ctx, defaultBackoff) + continue + } + f.logger.Debug("number of payloads returned", "number of payloads", len(payloads)) + + for i := range payloads { + p := payloads[i] + select { + case <-ctx.Done(): + return ctx.Err() + case f.payloadCh <- p: + lastSignalledBlock = p.BlockHeight + } + } + } +} + +func (f *Follower) sleepRespectingContext(ctx context.Context, duration time.Duration) { + select { + case <-ctx.Done(): + return + case <-time.After(duration): + } +} + +func (f *Follower) handlePayloads(ctx context.Context) error { + for { + select { + case <-ctx.Done(): + return ctx.Err() + case p := <-f.payloadCh: + if err := f.finalizeBlock(ctx, p.PayloadID, p.ExecutionPayload, ""); err != nil { + f.logger.Error("Failed to process payload", "height", p.BlockHeight, "error", err) + continue + } + f.logger.Info("Successfully processed payload", "height", p.BlockHeight) + } + } +} + +func (f *Follower) getExecutionHead() *types.ExecutionHead { + f.bbMutex.RLock() + defer f.bbMutex.RUnlock() + return f.bb.GetExecutionHead() +} + +func (f *Follower) setExecutionHeadFromRPC(ctx context.Context) error { + f.bbMutex.Lock() + defer f.bbMutex.Unlock() + return f.bb.SetExecutionHeadFromRPC(ctx) +} + +func (f *Follower) finalizeBlock(ctx context.Context, payloadIDStr, executionPayloadStr, msgID string) error { + f.bbMutex.Lock() + defer f.bbMutex.Unlock() + return f.bb.FinalizeBlock(ctx, payloadIDStr, executionPayloadStr, msgID) +} diff --git a/cl/singlenode/follower/follower_test.go b/cl/singlenode/follower/follower_test.go new file mode 100644 index 000000000..046132365 --- /dev/null +++ b/cl/singlenode/follower/follower_test.go @@ -0,0 +1,498 @@ +package follower_test + +import ( + "context" + "fmt" + "io" + "strconv" + "testing" + "time" + + "github.com/primev/mev-commit/cl/singlenode/follower" + "github.com/primev/mev-commit/cl/types" + "github.com/primev/mev-commit/x/util" +) + +type mockPayloadDB struct { + GetPayloadsSinceFunc func(ctx context.Context, sinceHeight uint64, limit int) ([]types.PayloadInfo, error) + GetLatestHeightFunc func(ctx context.Context) (uint64, error) +} + +func (m *mockPayloadDB) GetPayloadsSince(ctx context.Context, sinceHeight uint64, limit int) ([]types.PayloadInfo, error) { + return m.GetPayloadsSinceFunc(ctx, sinceHeight, limit) +} + +func (m *mockPayloadDB) GetLatestHeight(ctx context.Context) (uint64, error) { + return m.GetLatestHeightFunc(ctx) +} + +type mockBlockBuilder struct { + executionHead *types.ExecutionHead + SetExecutionHeadFromRPCFunc func(ctx context.Context) error + FinalizeBlockFunc func(ctx context.Context, payloadIDStr, executionPayloadStr, msgID string) error +} + +func (m *mockBlockBuilder) GetExecutionHead() *types.ExecutionHead { + return m.executionHead +} + +func (m *mockBlockBuilder) SetExecutionHeadFromRPC(ctx context.Context) error { + return m.SetExecutionHeadFromRPCFunc(ctx) +} + +func (m *mockBlockBuilder) FinalizeBlock(ctx context.Context, payloadIDStr, executionPayloadStr, msgID string) error { + return m.FinalizeBlockFunc(ctx, payloadIDStr, executionPayloadStr, msgID) +} + +func newMockBlockBuilder() *mockBlockBuilder { + return &mockBlockBuilder{ + executionHead: nil, + SetExecutionHeadFromRPCFunc: func(ctx context.Context) error { + return nil + }, + FinalizeBlockFunc: func(ctx context.Context, payloadIDStr, executionPayloadStr, msgID string) error { + return nil + }, + } +} + +func TestFollower_syncFromSharedDB(t *testing.T) { + t.Parallel() + + lastProcessed := uint64(500) + latest := uint64(550) + + logger := util.NewTestLogger(io.Discard) + payloadRepo := &mockPayloadDB{ + GetLatestHeightFunc: func(ctx context.Context) (uint64, error) { + return latest, nil + }, + GetPayloadsSinceFunc: func(ctx context.Context, sinceHeight uint64, limit int) ([]types.PayloadInfo, error) { + if sinceHeight != 501 { + t.Fatal("unexpected sinceHeight", sinceHeight) + } + toReturn := []types.PayloadInfo{} + for i := sinceHeight; i <= latest; i++ { + toReturn = append(toReturn, types.PayloadInfo{BlockHeight: i}) + } + return toReturn, nil + }, + } + syncBatchSize := uint64(100) + + bb := newMockBlockBuilder() + follower, err := follower.NewFollower(logger, payloadRepo, syncBatchSize, bb) + if err != nil { + t.Fatal(err) + } + + bb.SetExecutionHeadFromRPCFunc = func(ctx context.Context) error { + bb.executionHead = &types.ExecutionHead{BlockHeight: lastProcessed} + return nil + } + + errCh := make(chan error) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go func() { + err := follower.SyncFromSharedDB(ctx) + if err != nil { + errCh <- err + } + }() + + payloadCh := follower.PayloadCh() + + // expect 50 payloads + received := 0 + expectedBlockHeight := uint64(501) + for received < 50 { + select { + case p := <-payloadCh: + if p == (types.PayloadInfo{}) { + t.Fatalf("received zero payload for expected block height %d", expectedBlockHeight) + } + if p.BlockHeight != expectedBlockHeight { + t.Fatalf("expected payload height %d, got %d", expectedBlockHeight, p.BlockHeight) + } + expectedBlockHeight++ + received++ + case <-time.After(1 * time.Second): + t.Fatalf("timeout waiting for payload for expected block height %d", expectedBlockHeight) + } + } + if received != 50 { + t.Fatalf("expected 50 payloads, got %d", received) + } + + // No more than 50 + select { + case err := <-errCh: + t.Fatal(err) + case <-payloadCh: + t.Fatal("received unexpected payload") + case <-time.After(1 * time.Second): + } +} + +func TestFollower_syncFromSharedDB_NoRows(t *testing.T) { + t.Parallel() + + attempts := 0 + logger := util.NewTestLogger(io.Discard) + payloadRepo := &mockPayloadDB{ + GetLatestHeightFunc: func(ctx context.Context) (uint64, error) { + if attempts < 3 { + attempts++ + return 0, nil + } + return 15, nil + }, + GetPayloadsSinceFunc: func(ctx context.Context, sinceHeight uint64, limit int) ([]types.PayloadInfo, error) { + if sinceHeight != 1 { + return nil, fmt.Errorf("unexpected sinceHeight %d", sinceHeight) + } + if limit != 15 { + return nil, fmt.Errorf("unexpected limit %d", limit) + } + toReturn := []types.PayloadInfo{} + for i := 1; i <= 15; i++ { + toReturn = append(toReturn, types.PayloadInfo{BlockHeight: uint64(i)}) + } + return toReturn, nil + }, + } + syncBatchSize := uint64(100) + + bb := newMockBlockBuilder() + follower, err := follower.NewFollower(logger, payloadRepo, syncBatchSize, bb) + if err != nil { + t.Fatal(err) + } + + bb.SetExecutionHeadFromRPCFunc = func(ctx context.Context) error { + bb.executionHead = &types.ExecutionHead{BlockHeight: 0} // Only genesis block is available + return nil + } + + errCh := make(chan error) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go func() { + err := follower.SyncFromSharedDB(ctx) + if err != nil { + errCh <- err + } + }() + + payloadCh := follower.PayloadCh() + + // expect 15 payloads + received := 0 + expectedBlockHeight := uint64(1) + for received < 15 { + select { + case p := <-payloadCh: + if p == (types.PayloadInfo{}) { + t.Fatalf("received zero payload at %d", expectedBlockHeight) + } + if p.BlockHeight != expectedBlockHeight { + t.Fatalf("expected payload height %d, got %d", expectedBlockHeight, p.BlockHeight) + } + expectedBlockHeight++ + received++ + case <-time.After(10 * time.Second): + t.Fatalf("timeout waiting for payload %d", expectedBlockHeight) + } + } + if received != 15 { + t.Fatalf("expected 15 payloads, got %d", received) + } + + // No more than 15 + select { + case err := <-errCh: + t.Fatal(err) + case <-payloadCh: + t.Fatal("received unexpected payload") + case <-time.After(1 * time.Second): + } +} + +func TestFollower_syncFromSharedDB_MultipleIterations(t *testing.T) { + t.Parallel() + + lastProcessed := uint64(200) + latest := uint64(250) + + logger := util.NewTestLogger(io.Discard) + + numGetLatestHeightCalls := 0 + numGetPayloadsCalls := 0 + payloadRepo := &mockPayloadDB{ + GetLatestHeightFunc: func(ctx context.Context) (uint64, error) { + numGetLatestHeightCalls++ + if numGetLatestHeightCalls > 3 { + return 253, nil // Simulate that DB has only been updated up to block 253 + } + return latest + uint64(numGetLatestHeightCalls), nil + }, + GetPayloadsSinceFunc: func(ctx context.Context, sinceHeight uint64, limit int) ([]types.PayloadInfo, error) { + numGetPayloadsCalls++ + switch numGetPayloadsCalls { + case 1: + // First iteration should request payloads from 201 to 220 + if sinceHeight != 201 { + t.Fatal("unexpected sinceHeight", sinceHeight) + } + if limit != 20 { + t.Fatal("unexpected limit", limit) + } + case 2: + // Second iteration should request payloads from 221 to 240 + if sinceHeight != 221 { + t.Fatal("unexpected sinceHeight", sinceHeight) + } + if limit != 20 { + t.Fatal("unexpected limit", limit) + } + case 3: + // Third iteration should request payloads from 241 to 253 + if sinceHeight != 241 { + t.Fatal("unexpected sinceHeight", sinceHeight) + } + if limit != 13 { + t.Fatal("unexpected limit", limit) + } + default: + t.Fatal("unexpected numGetPayloadsCalls", numGetPayloadsCalls) + return nil, nil + } + toReturn := []types.PayloadInfo{} + for i := sinceHeight; i < sinceHeight+uint64(limit); i++ { + toReturn = append(toReturn, types.PayloadInfo{BlockHeight: i}) + } + return toReturn, nil + }, + } + syncBatchSize := uint64(20) + + bb := newMockBlockBuilder() + follower, err := follower.NewFollower(logger, payloadRepo, syncBatchSize, bb) + if err != nil { + t.Fatal(err) + } + + bb.SetExecutionHeadFromRPCFunc = func(ctx context.Context) error { + bb.executionHead = &types.ExecutionHead{BlockHeight: lastProcessed} + return nil + } + + errCh := make(chan error) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go func() { + err := follower.SyncFromSharedDB(ctx) + if err != nil { + errCh <- err + } + }() + + payloadCh := follower.PayloadCh() + + // expect payloads up to 253 + received := 0 + expectedBlockHeight := uint64(201) + for received < 53 { + select { + case p := <-payloadCh: + if p == (types.PayloadInfo{}) { + t.Fatalf("received zero payload at %d", expectedBlockHeight) + } + if p.BlockHeight != expectedBlockHeight { + t.Fatalf("expected payload height %d, got %d", expectedBlockHeight, p.BlockHeight) + } + expectedBlockHeight++ + received++ + case <-time.After(10 * time.Second): + t.Fatalf("timeout waiting for payload %d", expectedBlockHeight) + } + } + if received != 53 { + t.Fatalf("expected 53 payloads, got %d", received) + } + + // No more than 53 + select { + case err := <-errCh: + t.Fatal(err) + case <-payloadCh: + t.Fatal("received unexpected payload") + case <-time.After(1 * time.Second): + } +} + +func TestFollower_Start_SimulateNewChain(t *testing.T) { + t.Parallel() + + logger := util.NewTestLogger(io.Discard) + + getLatestCalls := 0 + + payloadRepo := &mockPayloadDB{ + GetLatestHeightFunc: func(ctx context.Context) (uint64, error) { + getLatestCalls++ + if getLatestCalls <= 3 { + return 0, nil + } + return 1, nil + }, + GetPayloadsSinceFunc: func(ctx context.Context, sinceHeight uint64, limit int) ([]types.PayloadInfo, error) { + if sinceHeight != 1 { + t.Fatalf("unexpected sinceHeight %d", sinceHeight) + } + if limit != 1 { + t.Fatalf("unexpected limit %d", limit) + } + return []types.PayloadInfo{{BlockHeight: 1}}, nil + }, + } + + syncBatchSize := uint64(100) + + bb := newMockBlockBuilder() + follower, err := follower.NewFollower(logger, payloadRepo, syncBatchSize, bb) + if err != nil { + t.Fatal(err) + } + + bb.SetExecutionHeadFromRPCFunc = func(ctx context.Context) error { + bb.executionHead = &types.ExecutionHead{BlockHeight: 0} // Only genesis block is available + return nil + } + + bb.FinalizeBlockFunc = func(ctx context.Context, payloadIDStr, executionPayloadStr, msgID string) error { + bb.executionHead = &types.ExecutionHead{BlockHeight: 1} + return nil + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + done := follower.Start(ctx) + + deadline := time.Now().Add(5 * time.Second) + for { + lp := follower.GetExecutionHead() + if lp == nil { + continue + } + if lp.BlockHeight >= 1 { + break + } + if time.Now().After(deadline) { + t.Fatalf("timeout waiting for first block to be processed") + } + time.Sleep(10 * time.Millisecond) + } + + finalExecutionHead := follower.GetExecutionHead() + if finalExecutionHead == nil { + t.Fatal("execution head is nil") + } + if finalExecutionHead.BlockHeight != 1 { + t.Fatalf("expected execution head block height to be 1, got %d", finalExecutionHead.BlockHeight) + } + + cancel() + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for follower to stop") + } +} + +func TestFollower_Start_SyncExistingChain(t *testing.T) { + t.Parallel() + + logger := util.NewTestLogger(io.Discard) + + lastProcessed := uint64(450) + latest := uint64(700) + + payloadRepo := &mockPayloadDB{ + GetLatestHeightFunc: func(ctx context.Context) (uint64, error) { + return latest, nil + }, + GetPayloadsSinceFunc: func(ctx context.Context, sinceHeight uint64, limit int) ([]types.PayloadInfo, error) { + toReturn := make([]types.PayloadInfo, 0, limit) + for i := uint64(0); i < uint64(limit); i++ { + toReturn = append(toReturn, types.PayloadInfo{ + BlockHeight: sinceHeight + i, + // Encode just the block height + ExecutionPayload: fmt.Sprintf("%d", sinceHeight+i), + }) + } + return toReturn, nil + }, + } + + syncBatchSize := uint64(20) + + bb := newMockBlockBuilder() + follower, err := follower.NewFollower(logger, payloadRepo, syncBatchSize, bb) + if err != nil { + t.Fatal(err) + } + + bb.SetExecutionHeadFromRPCFunc = func(ctx context.Context) error { + bb.executionHead = &types.ExecutionHead{BlockHeight: lastProcessed} + return nil + } + + bb.FinalizeBlockFunc = func(ctx context.Context, payloadIDStr, executionPayloadStr, msgID string) error { + // decode block num from executionPayloadStr + blockNum, err := strconv.ParseUint(executionPayloadStr, 10, 64) + if err != nil { + return err + } + bb.executionHead = &types.ExecutionHead{BlockHeight: blockNum} + return nil + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + done := follower.Start(ctx) + + deadline := time.Now().Add(5 * time.Second) + for { + lp := follower.GetExecutionHead() + if lp == nil { + continue + } + if lp.BlockHeight >= 700 { + break + } + if time.Now().After(deadline) { + t.Fatalf("timeout waiting for sync + steady-state; last processed: %d", lp) + } + time.Sleep(10 * time.Millisecond) + } + + finalExecutionHead := follower.GetExecutionHead() + if finalExecutionHead == nil { + t.Fatal("execution head is nil") + } + if finalExecutionHead.BlockHeight != 700 { + t.Fatalf("expected execution head block height to be 700, got %d", finalExecutionHead.BlockHeight) + } + + cancel() + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for follower to stop") + } +} diff --git a/cl/singlenode/payloadstore/postgres.go b/cl/singlenode/payloadstore/postgres.go index 50b50b7ed..e6f50f6b1 100644 --- a/cl/singlenode/payloadstore/postgres.go +++ b/cl/singlenode/payloadstore/postgres.go @@ -71,6 +71,25 @@ func NewPostgresRepository(ctx context.Context, dsn string, logger *slog.Logger) return &PostgresRepository{db: db, logger: l}, nil } +func NewPostgresFollower(ctx context.Context, dsn string, logger *slog.Logger) (*PostgresRepository, error) { + l := logger.With("component", "postgresFollower") + + db, err := sql.Open("postgres", dsn) + if err != nil { + return nil, fmt.Errorf("failed to open postgres connection: %w", err) + } + + pingCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + if err := db.PingContext(pingCtx); err != nil { + _ = db.Close() + return nil, fmt.Errorf("failed to ping postgres: %w", err) + } + + l.Info("Connected to PostgreSQL") + return &PostgresRepository{db: db, logger: l}, nil +} + // SavePayload saves the payload information to the database. func (r *PostgresRepository) SavePayload(ctx context.Context, info *types.PayloadInfo) error { query := ` @@ -259,6 +278,25 @@ func (r *PostgresRepository) GetLatestPayload(ctx context.Context) (*types.Paylo return &payload, nil } +func (r *PostgresRepository) GetLatestHeight(ctx context.Context) (uint64, error) { + query := ` + SELECT MAX(block_height) FROM execution_payloads; + ` + queryCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + var n sql.NullInt64 + if err := r.db.QueryRowContext(queryCtx, query).Scan(&n); err != nil { + // MAX should never return sql.ErrNoRow, always bubble errors + return 0, err + } + if !n.Valid { + // Empty table -> new chain + return 0, nil + } + return uint64(n.Int64), nil +} + // Close closes the database connection. func (r *PostgresRepository) Close() error { if r.db != nil {