diff --git a/p2p/pkg/node/node.go b/p2p/pkg/node/node.go index 6f10d9ef6..9fd054ec7 100644 --- a/p2p/pkg/node/node.go +++ b/p2p/pkg/node/node.go @@ -56,6 +56,7 @@ import ( providerapi "github.com/primev/mev-commit/p2p/pkg/rpc/provider" validatorapi "github.com/primev/mev-commit/p2p/pkg/rpc/validator" "github.com/primev/mev-commit/p2p/pkg/signer" + "github.com/primev/mev-commit/p2p/pkg/stakemanager" "github.com/primev/mev-commit/p2p/pkg/storage" inmem "github.com/primev/mev-commit/p2p/pkg/storage/inmem" pebblestorage "github.com/primev/mev-commit/p2p/pkg/storage/pebble" @@ -316,14 +317,31 @@ func NewNode(opts *Options) (*Node, error) { keysStore := keysstore.New(store) - p2pSvc, err := libp2p.New(&libp2p.Options{ - KeySigner: opts.KeySigner, - Secret: opts.Secret, - PeerType: peerType, - Register: &providerStakeChecker{ - providerRegistry: providerRegistry, - from: opts.KeySigner.GetAddress(), + stakeMgr, err := stakemanager.NewStakeManager( + opts.Logger.With("component", "stakemanager"), + opts.KeySigner.GetAddress(), + evtMgr, + providerRegistry, + notificationsSvc, + ) + if err != nil { + opts.Logger.Error("failed to create stake manager", "error", err) + return nil, errors.Join(err, nd.Close()) + } + + startables = append( + startables, + StartableObjWithDesc{ + Desc: "stakemanager", + Startable: stakeMgr, }, + ) + + p2pSvc, err := libp2p.New(&libp2p.Options{ + KeySigner: opts.KeySigner, + Secret: opts.Secret, + PeerType: peerType, + Register: stakeMgr, Store: keysStore, Logger: opts.Logger.With("component", "p2p"), ListenPort: opts.P2PPort, @@ -957,30 +975,6 @@ func (f StartableFunc) Start(ctx context.Context) <-chan struct{} { return f(ctx) } -type providerStakeChecker struct { - providerRegistry *providerregistry.Providerregistry - from common.Address -} - -func (p *providerStakeChecker) CheckProviderRegistered(ctx context.Context, provider common.Address) bool { - callOpts := &bind.CallOpts{ - From: p.from, - Context: ctx, - } - - minStake, err := p.providerRegistry.MinStake(callOpts) - if err != nil { - return false - } - - stake, err := p.providerRegistry.GetProviderStake(callOpts, provider) - if err != nil { - return false - } - - return stake.Cmp(minStake) >= 0 -} - type progressStore struct { contractRPC *ethclient.Client lastBlock atomic.Uint64 diff --git a/p2p/pkg/notifications/notifications.go b/p2p/pkg/notifications/notifications.go index f319b845a..15d89c7b5 100644 --- a/p2p/pkg/notifications/notifications.go +++ b/p2p/pkg/notifications/notifications.go @@ -11,6 +11,10 @@ const ( TopicPeerDisconnected Topic = "peer_disconnected" TopicValidatorOptedIn Topic = "validator_opted_in" TopicEpochValidatorsOptedIn Topic = "epoch_validators_opted_in" + TopicProviderRegistered Topic = "provider_registered" + TopicProviderDeposit Topic = "provider_deposit" + TopicProviderSlashed Topic = "provider_slashed" + TopicProviderDeregistered Topic = "provider_deregistered" ) func IsTopicValid(topic Topic) bool { diff --git a/p2p/pkg/stakemanager/stakemanager.go b/p2p/pkg/stakemanager/stakemanager.go new file mode 100644 index 000000000..60ec36202 --- /dev/null +++ b/p2p/pkg/stakemanager/stakemanager.go @@ -0,0 +1,313 @@ +package stakemanager + +import ( + "context" + "fmt" + "log/slog" + "math/big" + "sync" + "sync/atomic" + + "github.com/ethereum/go-ethereum/accounts/abi/bind" + "github.com/ethereum/go-ethereum/common" + lru "github.com/hashicorp/golang-lru/v2" + providerregistry "github.com/primev/mev-commit/contracts-abi/clients/ProviderRegistry" + "github.com/primev/mev-commit/p2p/pkg/notifications" + "github.com/primev/mev-commit/x/contracts/events" + "golang.org/x/sync/errgroup" +) + +type ProviderRegistryContract interface { + GetProviderStake(*bind.CallOpts, common.Address) (*big.Int, error) + MinStake(*bind.CallOpts) (*big.Int, error) +} + +type StakeManager struct { + owner common.Address + evtMgr events.EventManager + providerRegistry ProviderRegistryContract + notifier notifications.Notifier + stakeMu sync.RWMutex + stakes *lru.Cache[common.Address, *big.Int] + unstakeReqs *lru.Cache[common.Address, struct{}] + minStake atomic.Pointer[big.Int] + logger *slog.Logger +} + +func NewStakeManager( + logger *slog.Logger, + owner common.Address, + evtMgr events.EventManager, + providerRegistry ProviderRegistryContract, + notifier notifications.Notifier, +) (*StakeManager, error) { + minStake, err := providerRegistry.MinStake(&bind.CallOpts{ + From: owner, + }) + if err != nil { + return nil, fmt.Errorf("failed to get min stake: %w", err) + } + stakes, err := lru.New[common.Address, *big.Int](1000) + if err != nil { + return nil, fmt.Errorf("failed to create stakes cache: %w", err) + } + unstakeReqs, err := lru.New[common.Address, struct{}](1000) + if err != nil { + return nil, fmt.Errorf("failed to create unstake requests cache: %w", err) + } + sm := &StakeManager{ + providerRegistry: providerRegistry, + evtMgr: evtMgr, + logger: logger, + owner: owner, + stakes: stakes, + unstakeReqs: unstakeReqs, + notifier: notifier, + } + sm.minStake.Store(minStake) + return sm, nil +} + +func (sm *StakeManager) Start(ctx context.Context) <-chan struct{} { + doneChan := make(chan struct{}) + + eg, egCtx := errgroup.WithContext(ctx) + + ch1 := make(chan *providerregistry.ProviderregistryProviderRegistered, 10) + ev1 := events.NewChannelEventHandler(egCtx, "ProviderRegistered", ch1) + + ch2 := make(chan *providerregistry.ProviderregistryFundsSlashed, 10) + ev2 := events.NewChannelEventHandler(egCtx, "FundsSlashed", ch2) + + ch3 := make(chan *providerregistry.ProviderregistryFundsDeposited, 10) + ev3 := events.NewChannelEventHandler(egCtx, "FundsDeposited", ch3) + + ch4 := make(chan *providerregistry.ProviderregistryUnstake, 10) + ev4 := events.NewChannelEventHandler(egCtx, "Unstake", ch4) + + ch5 := make(chan *providerregistry.ProviderregistryMinStakeUpdated, 10) + ev5 := events.NewChannelEventHandler(egCtx, "MinStakeUpdated", ch5) + + sub, err := sm.evtMgr.Subscribe(ev1, ev2, ev3, ev4, ev5) + if err != nil { + close(doneChan) + return doneChan + } + + eg.Go(func() error { + defer sub.Unsubscribe() + + select { + case <-egCtx.Done(): + sm.logger.Info("event subscription context done") + return nil + case err := <-sub.Err(): + return fmt.Errorf("error in event subscription: %w", err) + } + }) + + eg.Go(func() error { + for { + select { + case <-egCtx.Done(): + sm.logger.Info("clear balances set balances context done") + return nil + case evt := <-ch1: + // handle ProviderRegistered event + sm.stakeMu.RLock() + _, found := sm.stakes.Get(evt.Provider) + sm.stakeMu.RUnlock() + + if !found { + sm.stakeMu.Lock() + _ = sm.stakes.Add(evt.Provider, evt.StakedAmount) + _ = sm.unstakeReqs.Remove(evt.Provider) + sm.stakeMu.Unlock() + + sm.notifier.Notify( + notifications.NewNotification( + notifications.TopicProviderRegistered, + map[string]any{ + "provider": evt.Provider.Hex(), + }, + ), + ) + } + case evt := <-ch2: + // handle FundsSlashed event + sm.stakeMu.RLock() + stake, found := sm.stakes.Get(evt.Provider) + sm.stakeMu.RUnlock() + + hasEnoughStake := false + + if !found { + // if not tracked locally, get the latest value from on-chain + s, err := sm.providerRegistry.GetProviderStake(&bind.CallOpts{ + From: sm.owner, + Context: egCtx, + }, evt.Provider) + if err != nil { + sm.logger.Error("failed to get provider stake", "error", err) + continue + } + + sm.stakeMu.Lock() + _ = sm.stakes.Add(evt.Provider, s) + sm.stakeMu.Unlock() + hasEnoughStake = s.Cmp(sm.minStake.Load()) >= 0 + } else { + // update the local value + newStake := new(big.Int).Sub(stake, evt.Amount) + sm.stakeMu.Lock() + _ = sm.stakes.Add(evt.Provider, newStake) + sm.stakeMu.Unlock() + hasEnoughStake = newStake.Cmp(sm.minStake.Load()) >= 0 + } + + sm.notifier.Notify( + notifications.NewNotification( + notifications.TopicProviderSlashed, + map[string]any{ + "provider": evt.Provider.Hex(), + "amount": evt.Amount, + "hasEnoughStake": hasEnoughStake, + }, + ), + ) + case evt := <-ch3: + // handle FundsDeposited event + sm.stakeMu.RLock() + stake, found := sm.stakes.Get(evt.Provider) + sm.stakeMu.RUnlock() + + hasEnoughStake := false + + if !found { + // if not tracked locally, get the latest value from on-chain + s, err := sm.providerRegistry.GetProviderStake(&bind.CallOpts{ + From: sm.owner, + Context: egCtx, + }, evt.Provider) + if err != nil { + sm.logger.Error("failed to get provider stake", "error", err) + continue + } + + sm.stakeMu.Lock() + _ = sm.stakes.Add(evt.Provider, s) + sm.stakeMu.Unlock() + hasEnoughStake = s.Cmp(sm.minStake.Load()) >= 0 + } else { + // update the local value + newStake := new(big.Int).Add(stake, evt.Amount) + sm.stakeMu.Lock() + _ = sm.stakes.Add(evt.Provider, newStake) + sm.stakeMu.Unlock() + hasEnoughStake = newStake.Cmp(sm.minStake.Load()) >= 0 + } + + sm.notifier.Notify( + notifications.NewNotification( + notifications.TopicProviderDeposit, + map[string]any{ + "provider": evt.Provider.Hex(), + "amount": evt.Amount, + "hasEnoughStake": hasEnoughStake, + }, + ), + ) + case evt := <-ch4: + // handle Unstake event + // even after unstaking, the provider stake is still non-zero till + // the withdraw request is processed. So, we keep track of the unstake + // requests in order to avoid querying the on-chain value for the provider + // stake. + sm.stakeMu.RLock() + _, found := sm.unstakeReqs.Get(evt.Provider) + sm.stakeMu.RUnlock() + + if !found { + sm.stakeMu.Lock() + _ = sm.unstakeReqs.Add(evt.Provider, struct{}{}) + _ = sm.stakes.Remove(evt.Provider) + sm.stakeMu.Unlock() + + sm.notifier.Notify( + notifications.NewNotification( + notifications.TopicProviderDeregistered, + map[string]any{ + "provider": evt.Provider.Hex(), + }, + ), + ) + } + case evt := <-ch5: + // handle MinStakeUpdated event + sm.minStake.Store(evt.NewMinStake) + } + } + }) + + go func() { + defer close(doneChan) + if err := eg.Wait(); err != nil { + sm.logger.Error("error in StakeManager", "error", err) + } + }() + + return doneChan +} + +func (sm *StakeManager) GetStake(ctx context.Context, provider common.Address) (*big.Int, error) { + sm.stakeMu.RLock() + _, found := sm.unstakeReqs.Get(provider) + sm.stakeMu.RUnlock() + if found { + sm.logger.Debug("provider is in unstake requests, returning zero stake") + return big.NewInt(0), nil + } + + sm.stakeMu.RLock() + stake, found := sm.stakes.Get(provider) + sm.stakeMu.RUnlock() + if !found { + stake, err := sm.providerRegistry.GetProviderStake(&bind.CallOpts{ + From: sm.owner, + Context: ctx, + }, provider) + if err != nil { + return nil, fmt.Errorf("failed to get provider stake: %w", err) + } + + // if stake is zero, it means the provider is not registered + if stake.Cmp(big.NewInt(0)) > 0 { + sm.stakeMu.Lock() + _ = sm.stakes.Add(provider, stake) + sm.stakeMu.Unlock() + sm.logger.Debug( + "fetched provider stake from on-chain", + "stake", stake.String(), + "minStake", sm.minStake.Load().String(), + ) + } + + return stake, nil + } + + return stake, nil +} + +func (sm *StakeManager) MinStake() *big.Int { + return new(big.Int).Set(sm.minStake.Load()) +} + +func (sm *StakeManager) CheckProviderRegistered(ctx context.Context, provider common.Address) bool { + stake, err := sm.GetStake(ctx, provider) + if err != nil { + sm.logger.Error("failed to get provider stake", "error", err) + return false + } + + return stake.Cmp(sm.minStake.Load()) >= 0 +} diff --git a/p2p/pkg/stakemanager/stakemanager_test.go b/p2p/pkg/stakemanager/stakemanager_test.go new file mode 100644 index 000000000..d82224e4e --- /dev/null +++ b/p2p/pkg/stakemanager/stakemanager_test.go @@ -0,0 +1,378 @@ +package stakemanager_test + +import ( + "context" + "math/big" + "os" + "strings" + "testing" + "time" + + "github.com/ethereum/go-ethereum/accounts/abi" + "github.com/ethereum/go-ethereum/accounts/abi/bind" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" + providerregistry "github.com/primev/mev-commit/contracts-abi/clients/ProviderRegistry" + "github.com/primev/mev-commit/p2p/pkg/notifications" + "github.com/primev/mev-commit/p2p/pkg/stakemanager" + "github.com/primev/mev-commit/x/contracts/events" + "github.com/primev/mev-commit/x/util" +) + +type mockProviderRegistry struct { + providerStakes map[common.Address]*big.Int + minStake *big.Int +} + +func (m *mockProviderRegistry) GetProviderStake(_ *bind.CallOpts, addr common.Address) (*big.Int, error) { + stake, found := m.providerStakes[addr] + if !found { + return big.NewInt(0), nil + } + + return stake, nil +} + +func (m *mockProviderRegistry) MinStake(_ *bind.CallOpts) (*big.Int, error) { + return new(big.Int).Set(m.minStake), nil +} + +type mockNotifier struct { + evt chan *notifications.Notification +} + +func (m *mockNotifier) Notify(n *notifications.Notification) { + m.evt <- n +} + +func TestStakeManager(t *testing.T) { + t.Parallel() + + providerStakes := map[common.Address]*big.Int{ + common.HexToAddress("0x456"): big.NewInt(100), + common.HexToAddress("0x789"): big.NewInt(200), + common.HexToAddress("0xabc"): big.NewInt(500), + } + + owner := common.HexToAddress("0x123") + providerRegistry := &mockProviderRegistry{ + providerStakes: providerStakes, + minStake: big.NewInt(10), + } + + notifier := &mockNotifier{ + evt: make(chan *notifications.Notification, 10), + } + + prABI, err := abi.JSON(strings.NewReader(providerregistry.ProviderregistryABI)) + if err != nil { + t.Fatalf("failed to parse provider registry ABI: %v", err) + } + evtMgr := events.NewListener( + util.NewTestLogger(os.Stdout), + &prABI, + ) + + stakeMgr, err := stakemanager.NewStakeManager( + util.NewTestLogger(os.Stdout), + owner, + evtMgr, + providerRegistry, + notifier, + ) + if err != nil { + t.Fatalf("failed to create stake manager: %v", err) + } + + ctx, cancel := context.WithCancel(context.Background()) + + done := stakeMgr.Start(ctx) + + minStake := stakeMgr.MinStake() + if minStake.Cmp(big.NewInt(10)) != 0 { + t.Errorf("unexpected min stake: %v", minStake) + } + + // Simulates getting the value from the contract + stake, err := stakeMgr.GetStake(context.Background(), common.HexToAddress("0x456")) + if err != nil { + t.Fatalf("failed to get stake: %v", err) + } + + if stake.Cmp(big.NewInt(100)) != 0 { + t.Errorf("unexpected stake: %v", stake) + } + + if !stakeMgr.CheckProviderRegistered(context.Background(), common.HexToAddress("0x456")) { + t.Errorf("provider should be registered") + } + + err = publishMinStakeUpdated( + evtMgr, + &prABI, + providerregistry.ProviderregistryMinStakeUpdated{ + NewMinStake: big.NewInt(20), + }, + ) + if err != nil { + t.Fatalf("failed to publish min stake updated: %v", err) + } + + start := time.Now() + for stakeMgr.MinStake().Cmp(big.NewInt(20)) != 0 { + if time.Since(start) > 5*time.Second { + t.Fatalf("timed out waiting for min stake to update") + } + time.Sleep(100 * time.Millisecond) + } + + for addr, stake := range providerStakes { + err = publishProviderRegistered( + evtMgr, + &prABI, + providerregistry.ProviderregistryProviderRegistered{ + Provider: addr, + StakedAmount: stake, + }, + ) + if err != nil { + t.Fatalf("failed to publish provider registered: %v", err) + } + } + + count := 0 + for e := range notifier.evt { + if e.Topic() == notifications.TopicProviderRegistered { + count++ + } + // We will get the event only for unknown providers + if count == len(providerStakes)-1 { + break + } + } + + for addr, stake := range providerStakes { + providerStakes[addr] = new(big.Int).Sub(stake, big.NewInt(10)) + err = publishFundsSlashed( + evtMgr, + &prABI, + providerregistry.ProviderregistryFundsSlashed{ + Provider: addr, + Amount: big.NewInt(10), + }, + ) + if err != nil { + t.Fatalf("failed to publish funds slashed: %v", err) + } + } + + count = 0 + for e := range notifier.evt { + if e.Topic() == notifications.TopicProviderSlashed { + count++ + } + addr := common.HexToAddress(e.Value()["provider"].(string)) + stake, err := stakeMgr.GetStake(context.Background(), addr) + if err != nil { + t.Fatalf("failed to get stake: %v", err) + } + if stake.Cmp(providerStakes[addr]) != 0 { + t.Errorf("unexpected stake: %v", stake) + } + if count == len(providerStakes) { + break + } + } + + for addr, stake := range providerStakes { + providerStakes[addr] = new(big.Int).Add(stake, big.NewInt(10)) + err = publishFundsDeposited( + evtMgr, + &prABI, + providerregistry.ProviderregistryFundsDeposited{ + Provider: addr, + Amount: big.NewInt(10), + }, + ) + if err != nil { + t.Fatalf("failed to publish funds deposited: %v", err) + } + } + + count = 0 + for e := range notifier.evt { + if e.Topic() == notifications.TopicProviderDeposit { + count++ + } + addr := common.HexToAddress(e.Value()["provider"].(string)) + stake, err := stakeMgr.GetStake(context.Background(), addr) + if err != nil { + t.Fatalf("failed to get stake: %v", err) + } + if stake.Cmp(providerStakes[addr]) != 0 { + t.Errorf("unexpected stake: %v", stake) + } + if count == len(providerStakes) { + break + } + } + + for addr, _ := range providerStakes { + if !stakeMgr.CheckProviderRegistered(context.Background(), addr) { + t.Errorf("provider should be registered") + } + } + + err = publishUnstake( + evtMgr, + &prABI, + providerregistry.ProviderregistryUnstake{ + Provider: common.HexToAddress("0x456"), + Timestamp: big.NewInt(0), + }, + ) + if err != nil { + t.Fatalf("failed to publish unstake: %v", err) + } + + e := <-notifier.evt + if e.Topic() != notifications.TopicProviderDeregistered { + t.Errorf("unexpected notification: %v", e) + } + + if stakeMgr.CheckProviderRegistered(context.Background(), common.HexToAddress("0x456")) { + t.Errorf("provider should not be registered") + } + + cancel() + <-done +} + +func publishProviderRegistered( + evtMgr events.EventManager, + prABI *abi.ABI, + ev providerregistry.ProviderregistryProviderRegistered, +) error { + event := prABI.Events["ProviderRegistered"] + buf, err := event.Inputs.NonIndexed().Pack( + ev.StakedAmount, + ) + if err != nil { + return err + } + + // Creating a Log object + testLog := types.Log{ + Topics: []common.Hash{ + event.ID, // The first topic is the hash of the event signature + common.HexToHash(ev.Provider.Hex()), // The next topics are the indexed event parameters + }, + // Non-indexed parameters are stored in the Data field + Data: buf, + } + + evtMgr.PublishLogEvent(context.Background(), testLog) + return nil +} + +func publishFundsSlashed( + evMgr events.EventManager, + prABI *abi.ABI, + ev providerregistry.ProviderregistryFundsSlashed, +) error { + event := prABI.Events["FundsSlashed"] + buf, err := event.Inputs.NonIndexed().Pack( + ev.Amount, + ) + if err != nil { + return err + } + + // Creating a Log object + testLog := types.Log{ + Topics: []common.Hash{ + event.ID, // The first topic is the hash of the event signature + common.HexToHash(ev.Provider.Hex()), // The next topics are the indexed event parameters + }, + // Non-indexed parameters are stored in the Data field + Data: buf, + } + + evMgr.PublishLogEvent(context.Background(), testLog) + return nil +} + +func publishFundsDeposited( + evMgr events.EventManager, + prABI *abi.ABI, + ev providerregistry.ProviderregistryFundsDeposited, +) error { + event := prABI.Events["FundsDeposited"] + buf, err := event.Inputs.NonIndexed().Pack( + ev.Amount, + ) + if err != nil { + return err + } + + // Creating a Log object + testLog := types.Log{ + Topics: []common.Hash{ + event.ID, // The first topic is the hash of the event signature + common.HexToHash(ev.Provider.Hex()), // The next topics are the indexed event parameters + }, + // Non-indexed parameters are stored in the Data field + Data: buf, + } + + evMgr.PublishLogEvent(context.Background(), testLog) + return nil +} + +func publishUnstake( + evMgr events.EventManager, + prABI *abi.ABI, + ev providerregistry.ProviderregistryUnstake, +) error { + event := prABI.Events["Unstake"] + buf, err := event.Inputs.NonIndexed().Pack( + ev.Timestamp, + ) + if err != nil { + return err + } + + // Creating a Log object + testLog := types.Log{ + Topics: []common.Hash{ + event.ID, // The first topic is the hash of the event signature + common.HexToHash(ev.Provider.Hex()), // The next topics are the indexed event parameters + }, + // Non-indexed parameters are stored in the Data field + Data: buf, + } + + evMgr.PublishLogEvent(context.Background(), testLog) + return nil +} + +func publishMinStakeUpdated( + evMgr events.EventManager, + prABI *abi.ABI, + ev providerregistry.ProviderregistryMinStakeUpdated, +) error { + event := prABI.Events["MinStakeUpdated"] + + // Creating a Log object + testLog := types.Log{ + Topics: []common.Hash{ + event.ID, // The first topic is the hash of the event signature + common.BigToHash(ev.NewMinStake), + }, + // Non-indexed parameters are stored in the Data field + Data: nil, + } + + evMgr.PublishLogEvent(context.Background(), testLog) + return nil +} diff --git a/x/contracts/events/events.go b/x/contracts/events/events.go index a950a891e..d29296a8e 100644 --- a/x/contracts/events/events.go +++ b/x/contracts/events/events.go @@ -103,6 +103,15 @@ func (h *eventHandler[T]) topic() common.Hash { return h.topicID } +func NewChannelEventHandler[T any](ctx context.Context, name string, ch chan<- *T) EventHandler { + return NewEventHandler(name, func(obj *T) { + select { + case <-ctx.Done(): + case ch <- obj: + } + }) +} + // EventManager is an interface for subscribing to contract events. The EventHandler callback // is called when an event is received. The Subscription returned by the Subscribe // method can be used to unsubscribe from the event and also to receive any errors