diff --git a/cmd/mpcium/main.go b/cmd/mpcium/main.go
index 27cb704..6d12286 100644
--- a/cmd/mpcium/main.go
+++ b/cmd/mpcium/main.go
@@ -129,7 +129,6 @@ func runNode(ctx context.Context, c *cli.Command) error {
 	if err != nil {
 		logger.Fatal("Failed to connect to NATS", err)
 	}
-	defer natsConn.Close()
 
 	pubsub := messaging.NewNATSPubSub(natsConn)
 	keygenBroker, err := messaging.NewJetStreamBroker(ctx, natsConn, event.KeygenBrokerStream, []string{
@@ -162,7 +161,7 @@ func runNode(ctx context.Context, c *cli.Command) error {
 	logger.Info("Node is running", "ID", nodeID, "name", nodeName)
 
 	peerNodeIDs := GetPeerIDs(peers)
-	peerRegistry := mpc.NewRegistry(nodeID, peerNodeIDs, consulClient.KV(), directMessaging)
+	peerRegistry := mpc.NewRegistry(nodeID, peerNodeIDs, consulClient.KV(), directMessaging, pubsub, identityStore)
 
 	mpcNode := mpc.NewNode(
 		nodeID,
@@ -176,9 +175,6 @@ func runNode(ctx context.Context, c *cli.Command) error {
 	)
 	defer mpcNode.Close()
 
-	// ECDH session for DH key exchange
-	ecdhSession := mpcNode.GetECDHSession()
-
 	eventConsumer := eventconsumer.NewEventConsumer(
 		mpcNode,
 		pubsub,
@@ -197,8 +193,8 @@ func runNode(ctx context.Context, c *cli.Command) error {
 	timeoutConsumer.Run()
 	defer timeoutConsumer.Close()
 
-	keygenConsumer := eventconsumer.NewKeygenConsumer(natsConn, keygenBroker, pubsub, peerRegistry)
-	signingConsumer := eventconsumer.NewSigningConsumer(natsConn, signingBroker, pubsub, peerRegistry)
+	keygenConsumer := eventconsumer.NewKeygenConsumer(natsConn, keygenBroker, pubsub, peerRegistry, genKeyResultQueue)
+	signingConsumer := eventconsumer.NewSigningConsumer(natsConn, signingBroker, pubsub, peerRegistry, singingResultQueue)
 
 	// Make the node ready before starting the signing consumer
 	if err := peerRegistry.Ready(); err != nil {
@@ -206,12 +202,7 @@ func runNode(ctx context.Context, c *cli.Command) error {
 	}
 	logger.Info("[READY] Node is ready", "nodeID", nodeID)
-	logger.Info("Waiting for ECDH key exchange to complete...", "nodeID", nodeID)
-	if err := ecdhSession.WaitForExchangeComplete(); err != nil {
-		logger.Fatal("ECDH exchange failed", err)
-	}
-
-	logger.Info("ECDH key exchange completed successfully, starting consumers...", "nodeID", nodeID)
+	logger.Info("Starting consumers", "nodeID", nodeID)
 
 	appContext, cancel := context.WithCancel(context.Background())
 	//Setup signal handling to cancel context on termination signals.
 	go func() {
@@ -221,6 +212,11 @@ func runNode(ctx context.Context, c *cli.Command) error {
 		logger.Warn("Shutdown signal received, canceling context...")
 		cancel()
 
+		// Resign from peer registry first (before closing NATS)
+		if err := peerRegistry.Resign(); err != nil {
+			logger.Error("Failed to resign from peer registry", err)
+		}
+
 		// Gracefully close consumers
 		if err := keygenConsumer.Close(); err != nil {
 			logger.Error("Failed to close keygen consumer", err)
@@ -229,10 +225,6 @@ func runNode(ctx context.Context, c *cli.Command) error {
 			logger.Error("Failed to close signing consumer", err)
 		}
 
-		if err := ecdhSession.Close(); err != nil {
-			logger.Error("Failed to close ECDH session", err)
-		}
-
 		err := natsConn.Drain()
 		if err != nil {
 			logger.Error("Failed to drain NATS connection", err)
@@ -264,21 +256,6 @@ func runNode(ctx context.Context, c *cli.Command) error {
 		logger.Info("Signing consumer finished successfully")
 	}()
 
-	go func() {
-		for {
-			select {
-			case <-appContext.Done():
-				return
-			case err := <-ecdhSession.ErrChan():
-				if err != nil {
-					logger.Error("ECDH session error", err)
-					errChan <- fmt.Errorf("ecdh session error: %w", err)
-					return
-				}
-			}
-		}
-	}()
-
 	go func() {
 		wg.Wait()
 		logger.Info("All consumers have finished")
diff --git a/pkg/event/types.go b/pkg/event/types.go
index 5f9b41b..ac26ca9 100644
--- a/pkg/event/types.go
+++ b/pkg/event/types.go
@@ -91,6 +91,8 @@ const (
 	// Context and cancellation errors
 	ErrorCodeContextCancelled ErrorCode = "ERROR_CONTEXT_CANCELLED"
 	ErrorCodeOperationAborted ErrorCode = "ERROR_OPERATION_ABORTED"
+	ErrorCodeNotMajority      ErrorCode = "ERROR_NOT_MAJORITY"
+	ErrorCodeClusterNotReady  ErrorCode = "ERROR_CLUSTER_NOT_READY"
 )
 
 // GetErrorCodeFromError attempts to categorize a generic error into a specific error code
diff --git a/pkg/eventconsumer/keygen_consumer.go b/pkg/eventconsumer/keygen_consumer.go
index 65d88d8..98d9f07 100644
--- a/pkg/eventconsumer/keygen_consumer.go
+++ b/pkg/eventconsumer/keygen_consumer.go
@@ -2,6 +2,8 @@ package eventconsumer
 
 import (
 	"context"
+	"encoding/json"
+	"errors"
 	"fmt"
 	"time"
 
@@ -9,6 +11,7 @@ import (
 	"github.com/fystack/mpcium/pkg/logger"
 	"github.com/fystack/mpcium/pkg/messaging"
 	"github.com/fystack/mpcium/pkg/mpc"
+	"github.com/fystack/mpcium/pkg/types"
 	"github.com/google/uuid"
 	"github.com/nats-io/nats.go"
 	"github.com/nats-io/nats.go/jetstream"
@@ -31,22 +34,30 @@ type KeygenConsumer interface {
 
 // keygenConsumer implements KeygenConsumer.
 type keygenConsumer struct {
-	natsConn     *nats.Conn
-	pubsub       messaging.PubSub
-	jsBroker     messaging.MessageBroker
-	peerRegistry mpc.PeerRegistry
+	natsConn          *nats.Conn
+	pubsub            messaging.PubSub
+	jsBroker          messaging.MessageBroker
+	peerRegistry      mpc.PeerRegistry
+	keygenResultQueue messaging.MessageQueue
 
 	// jsSub holds the JetStream subscription, so it can be cleaned up during Close().
 	jsSub messaging.MessageSubscription
 }
 
 // NewKeygenConsumer returns a new instance of KeygenConsumer.
-func NewKeygenConsumer(natsConn *nats.Conn, jsBroker messaging.MessageBroker, pubsub messaging.PubSub, peerRegistry mpc.PeerRegistry) KeygenConsumer {
+func NewKeygenConsumer(
+	natsConn *nats.Conn,
+	jsBroker messaging.MessageBroker,
+	pubsub messaging.PubSub,
+	peerRegistry mpc.PeerRegistry,
+	keygenResultQueue messaging.MessageQueue,
+) KeygenConsumer {
 	return &keygenConsumer{
-		natsConn:     natsConn,
-		pubsub:       pubsub,
-		jsBroker:     jsBroker,
-		peerRegistry: peerRegistry,
+		natsConn:          natsConn,
+		pubsub:            pubsub,
+		jsBroker:          jsBroker,
+		peerRegistry:      peerRegistry,
+		keygenResultQueue: keygenResultQueue,
 	}
 }
 
@@ -60,6 +71,9 @@ func (sc *keygenConsumer) waitForAllPeersReadyToGenKey(ctx context.Context) erro
 	for {
 		select {
 		case <-ctx.Done():
+			if ctx.Err() == context.Canceled {
+				return nil
+			}
 			return ctx.Err()
 		case <-ticker.C:
 			allPeersReady := sc.peerRegistry.ArePeersReady()
@@ -80,6 +94,9 @@ func (sc *keygenConsumer) waitForAllPeersReadyToGenKey(ctx context.Context) erro
 func (sc *keygenConsumer) Run(ctx context.Context) error {
 	// Wait for sufficient peers before starting to consume messages
 	if err := sc.waitForAllPeersReadyToGenKey(ctx); err != nil {
+		if err == context.Canceled {
+			return nil
+		}
 		return fmt.Errorf("failed to wait for sufficient peers: %w", err)
 	}
 
@@ -104,9 +121,22 @@ func (sc *keygenConsumer) Run(ctx context.Context) error {
 }
 
 func (sc *keygenConsumer) handleKeygenEvent(msg jetstream.Msg) {
+	raw := msg.Data()
+	var keygenMsg types.GenerateKeyMessage
+	sessionID := msg.Headers().Get("SessionID")
+
+	err := json.Unmarshal(raw, &keygenMsg)
+	if err != nil {
+		logger.Error("KeygenConsumer: Failed to unmarshal keygen message", err)
+		sc.handleKeygenError(keygenMsg, event.ErrorCodeUnmarshalFailure, err, sessionID)
+		_ = msg.Ack()
+		return
+	}
 	if !sc.peerRegistry.ArePeersReady() {
-		logger.Warn("KeygenConsumer: Not all peers are ready to sign, skipping message processing")
+		logger.Warn("KeygenConsumer: Not all peers are ready to gen key, skipping message processing")
+		sc.handleKeygenError(keygenMsg, event.ErrorCodeClusterNotReady, errors.New("not all peers are ready"), sessionID)
+		_ = msg.Ack()
 		return
 	}
 
@@ -161,6 +191,33 @@ func (sc *keygenConsumer) handleKeygenEvent(msg jetstream.Msg) {
 	_ = msg.Nak()
 }
 
+func (sc *keygenConsumer) handleKeygenError(keygenMsg types.GenerateKeyMessage, errorCode event.ErrorCode, err error, sessionID string) {
+	keygenResult := event.KeygenResultEvent{
+		ResultType:  event.ResultTypeError,
+		ErrorCode:   string(errorCode),
+		WalletID:    keygenMsg.WalletID,
+		ErrorReason: err.Error(),
+	}
+
+	keygenResultBytes, err := json.Marshal(keygenResult)
+	if err != nil {
+		logger.Error("Failed to marshal keygen result event", err,
+			"walletID", keygenResult.WalletID,
+		)
+		return
+	}
+
+	topic := fmt.Sprintf(mpc.TypeGenerateWalletResultFmt, keygenResult.WalletID)
+	err = sc.keygenResultQueue.Enqueue(topic, keygenResultBytes, &messaging.EnqueueOptions{
+		IdempotententKey: buildIdempotentKey(keygenMsg.WalletID, sessionID, mpc.TypeGenerateWalletResultFmt),
+	})
+	if err != nil {
+		logger.Error("Failed to enqueue keygen result event", err,
+			"walletID", keygenResult.WalletID,
+		)
+	}
+}
+
 // Close unsubscribes from the JetStream subject and cleans up resources.
 func (sc *keygenConsumer) Close() error {
 	if sc.jsSub != nil {
diff --git a/pkg/eventconsumer/sign_consumer.go b/pkg/eventconsumer/sign_consumer.go
index f38e779..24f21fe 100644
--- a/pkg/eventconsumer/sign_consumer.go
+++ b/pkg/eventconsumer/sign_consumer.go
@@ -2,6 +2,7 @@ package eventconsumer
 
 import (
 	"context"
+	"encoding/json"
 	"fmt"
 	"time"
 
@@ -9,6 +10,7 @@ import (
 	"github.com/fystack/mpcium/pkg/logger"
 	"github.com/fystack/mpcium/pkg/messaging"
 	"github.com/fystack/mpcium/pkg/mpc"
+	"github.com/fystack/mpcium/pkg/types"
 	"github.com/google/uuid"
 	"github.com/nats-io/nats.go"
 	"github.com/nats-io/nats.go/jetstream"
@@ -34,25 +36,27 @@ type SigningConsumer interface {
 
 // signingConsumer implements SigningConsumer.
 type signingConsumer struct {
-	natsConn     *nats.Conn
-	pubsub       messaging.PubSub
-	jsBroker     messaging.MessageBroker
-	peerRegistry mpc.PeerRegistry
-	mpcThreshold int
+	natsConn           *nats.Conn
+	pubsub             messaging.PubSub
+	jsBroker           messaging.MessageBroker
+	peerRegistry       mpc.PeerRegistry
+	mpcThreshold       int
+	signingResultQueue messaging.MessageQueue
 
 	// jsSub holds the JetStream subscription, so it can be cleaned up during Close().
 	jsSub messaging.Subscription
 }
 
 // NewSigningConsumer returns a new instance of SigningConsumer.
-func NewSigningConsumer(natsConn *nats.Conn, jsBroker messaging.MessageBroker, pubsub messaging.PubSub, peerRegistry mpc.PeerRegistry) SigningConsumer {
+func NewSigningConsumer(natsConn *nats.Conn, jsBroker messaging.MessageBroker, pubsub messaging.PubSub, peerRegistry mpc.PeerRegistry, signingResultQueue messaging.MessageQueue) SigningConsumer {
 	mpcThreshold := viper.GetInt("mpc_threshold")
 	return &signingConsumer{
-		natsConn:     natsConn,
-		pubsub:       pubsub,
-		jsBroker:     jsBroker,
-		peerRegistry: peerRegistry,
-		mpcThreshold: mpcThreshold,
+		natsConn:           natsConn,
+		pubsub:             pubsub,
+		jsBroker:           jsBroker,
+		peerRegistry:       peerRegistry,
+		mpcThreshold:       mpcThreshold,
+		signingResultQueue: signingResultQueue,
 	}
 }
 
@@ -70,6 +74,9 @@ func (sc *signingConsumer) waitForSufficientPeers(ctx context.Context) error {
 	for {
 		select {
 		case <-ctx.Done():
+			if ctx.Err() == context.Canceled {
+				return nil
+			}
 			return ctx.Err()
 		case <-ticker.C:
 			readyPeers := sc.peerRegistry.GetReadyPeersCount()
@@ -90,6 +97,9 @@ func (sc *signingConsumer) waitForSufficientPeers(ctx context.Context) error {
 func (sc *signingConsumer) Run(ctx context.Context) error {
 	// Wait for sufficient peers before starting to consume messages
 	if err := sc.waitForSufficientPeers(ctx); err != nil {
+		if err == context.Canceled {
+			return nil
+		}
 		return fmt.Errorf("failed to wait for sufficient peers: %w", err)
 	}
 
@@ -130,18 +140,26 @@ func (sc *signingConsumer) Run(ctx context.Context) error {
 // When signing completes, the session publishes the result to a queue and calls the onSuccess callback, which sends a reply to the inbox that the SigningConsumer is monitoring.
 // The reply signals completion, allowing the SigningConsumer to acknowledge the original message.
 func (sc *signingConsumer) handleSigningEvent(msg jetstream.Msg) {
-	// Check if we still have enough peers before processing the message
-	requiredPeers := int64(sc.mpcThreshold + 1)
-	readyPeers := sc.peerRegistry.GetReadyPeersCount()
-
-	if readyPeers < requiredPeers {
-		logger.Warn("SigningConsumer: Not enough peers to process signing request, rejecting message",
-			"ready", readyPeers,
-			"required", requiredPeers)
-		// Immediately return and let nats redeliver the message with backoff
+	// Parse the signing request message to extract transaction details
+	raw := msg.Data()
+	var signingMsg types.SignTxMessage
+	sessionID := msg.Headers().Get("SessionID")
+
+	err := json.Unmarshal(raw, &signingMsg)
+	if err != nil {
+		logger.Error("SigningConsumer: Failed to unmarshal signing message", err)
+		sc.handleSigningError(signingMsg, event.ErrorCodeUnmarshalFailure, err, sessionID)
+		_ = msg.Ack()
 		return
 	}
+	if !sc.peerRegistry.AreMajorityReady() {
+		requiredPeers := int64(sc.mpcThreshold + 1)
+		err := fmt.Errorf("not enough peers to process signing request: ready=%d, required=%d", sc.peerRegistry.GetReadyPeersCount(), requiredPeers)
+		sc.handleSigningError(signingMsg, event.ErrorCodeNotMajority, err, sessionID)
+		_ = msg.Ack()
+		return
+	}
 
 	// Create a reply inbox to receive the signing event response.
 	replyInbox := nats.NewInbox()
@@ -193,6 +211,36 @@ func (sc *signingConsumer) handleSigningEvent(msg jetstream.Msg) {
 	_ = msg.Nak()
 }
 
+func (sc *signingConsumer) handleSigningError(signMsg types.SignTxMessage, errorCode event.ErrorCode, err error, sessionID string) {
+	signingResult := event.SigningResultEvent{
+		ResultType:          event.ResultTypeError,
+		ErrorCode:           errorCode,
+		NetworkInternalCode: signMsg.NetworkInternalCode,
+		WalletID:            signMsg.WalletID,
+		TxID:                signMsg.TxID,
+		ErrorReason:         err.Error(),
+	}
+
+	signingResultBytes, err := json.Marshal(signingResult)
+	if err != nil {
+		logger.Error("Failed to marshal signing result event", err,
+			"walletID", signMsg.WalletID,
+			"txID", signMsg.TxID,
+		)
+		return
+	}
+
+	err = sc.signingResultQueue.Enqueue(event.SigningResultCompleteTopic, signingResultBytes, &messaging.EnqueueOptions{
+		IdempotententKey: buildIdempotentKey(signMsg.TxID, sessionID, mpc.TypeSigningResultFmt),
+	})
+	if err != nil {
+		logger.Error("Failed to enqueue signing result event", err,
+			"walletID", signMsg.WalletID,
+			"txID", signMsg.TxID,
+		)
+	}
+}
+
 // Close unsubscribes from the JetStream subject and cleans up resources.
 func (sc *signingConsumer) Close() error {
 	if sc.jsSub != nil {
@@ -204,3 +252,13 @@ func (sc *signingConsumer) Close() error {
 	}
 	return nil
 }
+
+func buildIdempotentKey(baseID string, sessionID string, formatTemplate string) string {
+	var uniqueKey string
+	if sessionID != "" {
+		uniqueKey = fmt.Sprintf("%s:%s", baseID, sessionID)
+	} else {
+		uniqueKey = baseID
+	}
+	return fmt.Sprintf(formatTemplate, uniqueKey)
+}
diff --git a/pkg/identity/identity.go b/pkg/identity/identity.go
index bb83600..0d2329a 100644
--- a/pkg/identity/identity.go
+++ b/pkg/identity/identity.go
@@ -44,6 +44,8 @@ type Store interface {
 	SetSymmetricKey(peerID string, key []byte)
 	GetSymmetricKey(peerID string) ([]byte, error)
+	RemoveSymmetricKey(peerID string)
+	GetSymetricKeyCount() int
 	CheckSymmetricKeyComplete(desired int) bool
 
 	EncryptMessage(plaintext []byte, peerID string) ([]byte, error)
@@ -238,7 +240,21 @@ func (s *fileStore) GetSymmetricKey(peerID string) ([]byte, error) {
 	return nil, fmt.Errorf("SymmetricKey key not found for node ID: %s", peerID)
 }
 
+func (s *fileStore) RemoveSymmetricKey(peerID string) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	delete(s.symmetricKeys, peerID)
+}
+
+func (s *fileStore) GetSymetricKeyCount() int {
+	s.mu.RLock()
+	defer s.mu.RUnlock()
+	return len(s.symmetricKeys)
+}
+
 func (s *fileStore) CheckSymmetricKeyComplete(desired int) bool {
+	s.mu.RLock()
+	defer s.mu.RUnlock()
 	return len(s.symmetricKeys) == desired
 }
diff --git a/pkg/mpc/key_exchange_session.go b/pkg/mpc/key_exchange_session.go
index 496dc6c..2065f03 100644
--- a/pkg/mpc/key_exchange_session.go
+++ b/pkg/mpc/key_exchange_session.go
@@ -17,8 +17,6 @@ import (
 
 	"encoding/json"
 
-	"sync"
-
 	"github.com/nats-io/nats.go"
 )
@@ -28,25 +26,23 @@ const (
 )
 
 type ECDHSession interface {
-	StartKeyExchange() error
+	ListenKeyExchange() error
 	BroadcastPublicKey() error
-	WaitForExchangeComplete() error
+	RemovePeer(peerID string)
+	GetReadyPeersCount() int
 	ErrChan() <-chan error
 	Close() error
 }
 
 type ecdhSession struct {
-	nodeID           string
-	peerIDs          []string
-	pubSub           messaging.PubSub
-	ecdhSub          messaging.Subscription
-	identityStore    identity.Store
-	privateKey       *ecdh.PrivateKey
-	publicKey        *ecdh.PublicKey
-	exchangeComplete chan struct{}
-	errCh            chan error
-	exchangeDone     bool
-	mu               sync.RWMutex
+	nodeID        string
+	peerIDs       []string
+	pubSub        messaging.PubSub
+	ecdhSub       messaging.Subscription
+	identityStore identity.Store
+	privateKey    *ecdh.PrivateKey
+	publicKey     *ecdh.PublicKey
+	errCh         chan error
 }
 
 func NewECDHSession(
@@ -56,16 +52,27 @@ func NewECDHSession(
 	identityStore identity.Store,
 ) *ecdhSession {
 	return &ecdhSession{
-		nodeID:           nodeID,
-		peerIDs:          peerIDs,
-		pubSub:           pubSub,
-		identityStore:    identityStore,
-		exchangeComplete: make(chan struct{}, 1),
-		errCh:            make(chan error, 1),
+		nodeID:        nodeID,
+		peerIDs:       peerIDs,
+		pubSub:        pubSub,
+		identityStore: identityStore,
+		errCh:         make(chan error, 1),
 	}
 }
 
-func (e *ecdhSession) StartKeyExchange() error {
+func (e *ecdhSession) RemovePeer(peerID string) {
+	e.identityStore.RemoveSymmetricKey(peerID)
+}
+
+func (e *ecdhSession) GetReadyPeersCount() int {
+	return e.identityStore.GetSymetricKeyCount()
+}
+
+func (e *ecdhSession) ErrChan() <-chan error {
+	return e.errCh
+}
+
+func (e *ecdhSession) ListenKeyExchange() error {
 	// Generate an ephemeral ECDH key pair
 	privateKey, err := ecdh.X25519().GenerateKey(rand.Reader)
 	if err != nil {
@@ -85,7 +92,7 @@ func (e *ecdhSession) StartKeyExchange() error {
 		if ecdhMsg.From == e.nodeID {
 			return
 		}
-		logger.Info("Received ECDH message from", "node", ecdhMsg.From)
+		//TODO: consider how to avoid replay attack
 
 		if err := e.identityStore.VerifySignature(&ecdhMsg); err != nil {
 			e.errCh <- err
@@ -106,20 +113,7 @@ func (e *ecdhSession) StartKeyExchange() error {
 		// Derive symmetric key using HKDF
 		symmetricKey := e.deriveSymmetricKey(sharedSecret, ecdhMsg.From)
 		e.identityStore.SetSymmetricKey(ecdhMsg.From, symmetricKey)
-
-		requiredKeyCount := len(e.peerIDs) - 1
-		logger.Info("ECDH progress", "peer", ecdhMsg.From, "required", requiredKeyCount)
-
-		if e.identityStore.CheckSymmetricKeyComplete(requiredKeyCount) {
-			logger.Info("Completed ECDH!", "symmetric key counts of peers", requiredKeyCount)
-			logger.Info("ALL PEERS ARE READY! Starting to accept MPC requests")
-
-			e.mu.Lock()
-			e.exchangeDone = true
-			e.mu.Unlock()
-
-			e.exchangeComplete <- struct{}{}
-		}
+		logger.Debug("ECDH progress", "peer", ecdhMsg.From, "current", e.identityStore.GetSymetricKeyCount())
 	})
 
 	e.ecdhSub = sub
@@ -129,10 +123,6 @@ func (e *ecdhSession) StartKeyExchange() error {
 	return nil
 }
 
-func (s *ecdhSession) ErrChan() <-chan error {
-	return s.errCh
-}
-
 func (s *ecdhSession) Close() error {
 	err := s.ecdhSub.Unsubscribe()
 	if err != nil {
@@ -164,25 +154,6 @@ func (e *ecdhSession) BroadcastPublicKey() error {
 	return nil
 }
 
-func (e *ecdhSession) WaitForExchangeComplete() error {
-	e.mu.RLock()
-	if e.exchangeDone {
-		e.mu.RUnlock()
-		return nil
-	}
-	e.mu.RUnlock()
-	timeout := time.After(ECDHExchangeTimeout) // 2 minutes timeout
-
-	select {
-	case <-e.exchangeComplete:
-		return nil
-	case err := <-e.errCh:
-		return err
-	case <-timeout:
-		return fmt.Errorf("ECDH exchange timeout!")
-	}
-}
-
 func deriveConsistentInfo(a, b string) []byte {
 	if a < b {
 		return []byte(a + b)
diff --git a/pkg/mpc/node.go b/pkg/mpc/node.go
index a881554..d615444 100644
--- a/pkg/mpc/node.go
+++ b/pkg/mpc/node.go
@@ -39,7 +39,6 @@ type Node struct {
 	identityStore identity.Store
 	peerRegistry  PeerRegistry
-	ecdhSession   ECDHSession
 }
 
 func NewNode(
@@ -55,11 +54,6 @@ func NewNode(
 	start := time.Now()
 	elapsed := time.Since(start)
 	logger.Info("Starting new node, preparams is generated successfully!", "elapsed", elapsed.Milliseconds())
-	// Each node initiates the DH key exchange listener at the beginning and invoke message sending when all peers are ready
-	dhSession := NewECDHSession(nodeID, peerIDs, pubSub, identityStore)
-	if err := dhSession.StartKeyExchange(); err != nil {
-		logger.Fatal("Failed to start DH key exchange", err)
-	}
 
 	node := &Node{
 		nodeID: nodeID,
@@ -70,17 +64,11 @@ func NewNode(
 		keyinfoStore:  keyinfoStore,
 		peerRegistry:  peerRegistry,
 		identityStore: identityStore,
-		ecdhSession:   dhSession,
 	}
 	node.ecdsaPreParams = node.generatePreParams()
 
-	ecdhTask := func() {
-		if err := dhSession.BroadcastPublicKey(); err != nil {
-			logger.Fatal("DH key broadcast failed", err)
-		}
-	}
-
-	go peerRegistry.WatchPeersReady(ecdhTask)
+	// Start watching peers - ECDH is now handled by the registry
+	go peerRegistry.WatchPeersReady()
 
 	return node
 }
@@ -95,11 +83,7 @@ func (p *Node) CreateKeyGenSession(
 	resultQueue messaging.MessageQueue,
 ) (KeyGenSession, error) {
 	if !p.peerRegistry.ArePeersReady() {
-		return nil, fmt.Errorf(
-			"Not enough peers to create gen session! Expected %d, got %d",
Expected %d, got %d", - p.peerRegistry.GetTotalPeersCount(), - p.peerRegistry.GetReadyPeersCount(), - ) + return nil, errors.New("All nodes are not ready!") } keyInfo, _ := p.getKeyInfo(sessionType, walletID) @@ -425,14 +409,6 @@ func (p *Node) Close() { } } -func (p *Node) GetECDHSession() ECDHSession { - return p.ecdhSession -} - -func (p *Node) GetDHSession() ECDHSession { - return p.ecdhSession -} - func (p *Node) generatePreParams() []*keygen.LocalPreParams { start := time.Now() // Try to load from kvstore diff --git a/pkg/mpc/registry.go b/pkg/mpc/registry.go index c73f447..92c2cf6 100644 --- a/pkg/mpc/registry.go +++ b/pkg/mpc/registry.go @@ -2,16 +2,19 @@ package mpc import ( "fmt" + "strconv" "strings" "sync" "sync/atomic" "time" + "github.com/fystack/mpcium/pkg/identity" "github.com/fystack/mpcium/pkg/infra" "github.com/fystack/mpcium/pkg/logger" "github.com/fystack/mpcium/pkg/messaging" "github.com/hashicorp/consul/api" "github.com/samber/lo" + "github.com/spf13/viper" ) const ( @@ -21,12 +24,18 @@ const ( type PeerRegistry interface { Ready() error ArePeersReady() bool - WatchPeersReady(callback func()) + AreMajorityReady() bool + WatchPeersReady() // Resign is called by the node when it is going to shutdown Resign() error GetReadyPeersCount() int64 + GetReadyPeersCountExcludeSelf() int64 GetReadyPeersIncludeSelf() []string // get ready peers include self GetTotalPeersCount() int64 + + OnPeerConnected(callback func(peerID string)) + OnPeerDisconnected(callback func(peerID string)) + OnPeerReConnected(callback func(peerID string)) } type registry struct { @@ -37,8 +46,16 @@ type registry struct { mu sync.RWMutex ready bool // ready is true when all peers are ready - consulKV infra.ConsulKV - healthCheck messaging.DirectMessaging + consulKV infra.ConsulKV + healthCheck messaging.DirectMessaging + pubSub messaging.PubSub + identityStore identity.Store + ecdhSession ECDHSession + mpcThreshold int + + onPeerConnected func(peerID string) + onPeerDisconnected func(peerID string) + onPeerReConnected func(peerID string) } func NewRegistry( @@ -46,14 +63,26 @@ func NewRegistry( peerNodeIDs []string, consulKV infra.ConsulKV, directMessaging messaging.DirectMessaging, + pubSub messaging.PubSub, + identityStore identity.Store, ) *registry { + ecdhSession := NewECDHSession(nodeID, peerNodeIDs, pubSub, identityStore) + mpcThreshold := viper.GetInt("mpc_threshold") + if mpcThreshold < 1 { + logger.Fatal("mpc_threshold must be greater than 0", nil) + } + return ®istry{ - consulKV: consulKV, - nodeID: nodeID, - peerNodeIDs: getPeerIDsExceptSelf(nodeID, peerNodeIDs), - readyMap: make(map[string]bool), - readyCount: 1, // self - healthCheck: directMessaging, + consulKV: consulKV, + nodeID: nodeID, + peerNodeIDs: getPeerIDsExceptSelf(nodeID, peerNodeIDs), + readyMap: make(map[string]bool), + readyCount: 1, // self + healthCheck: directMessaging, + pubSub: pubSub, + identityStore: identityStore, + ecdhSession: ecdhSession, + mpcThreshold: mpcThreshold, } } @@ -71,15 +100,23 @@ func (r *registry) readyKey(nodeID string) string { return fmt.Sprintf("ready/%s", nodeID) } -func (r *registry) registerReadyPairs(peerIDs []string, callback func()) { +func (r *registry) registerReadyPairs(peerIDs []string) { for _, peerID := range peerIDs { ready, exist := r.readyMap[peerID] if !exist { atomic.AddInt64(&r.readyCount, 1) logger.Info("Register", "peerID", peerID) + if r.onPeerConnected != nil { + r.onPeerConnected(peerID) + } + go r.triggerECDHExchange() } else if !ready { atomic.AddInt64(&r.readyCount, 1) 
logger.Info("Reconnecting...", "peerID", peerID) + if r.onPeerReConnected != nil { + r.onPeerReConnected(peerID) + } + go r.triggerECDHExchange() } r.readyMap[peerID] = true @@ -89,14 +126,26 @@ func (r *registry) registerReadyPairs(peerIDs []string, callback func()) { r.mu.Lock() r.ready = true r.mu.Unlock() - time.AfterFunc(5*time.Second, callback) + logger.Info("All peers are ready including ECDH exchange completion") } +} +// triggerECDHExchange safely triggers ECDH key exchange +func (r *registry) triggerECDHExchange() { + logger.Info("Triggering ECDH key exchange") + if err := r.ecdhSession.BroadcastPublicKey(); err != nil { + logger.Error("Failed to trigger ECDH exchange", err) + } } // Ready is called by the node when it complete generate preparams and starting to accept // incoming requests func (r *registry) Ready() error { + // Start ECDH exchange first + if err := r.startECDHExchange(); err != nil { + return fmt.Errorf("failed to start ECDH exchange: %w", err) + } + k := r.readyKey(r.nodeID) kv := &api.KVPair{ @@ -110,7 +159,13 @@ func (r *registry) Ready() error { } _, err = r.healthCheck.Listen(r.composeHealthCheckTopic(r.nodeID), func(data []byte) { - logger.Debug("Health check", "peerID", string(data)) + peerID, ecdhReadyPeersCount, _ := parseHealthDataSplit(string(data)) + logger.Debug("Health check ok", "peerID", peerID) + if ecdhReadyPeersCount < int(r.GetReadyPeersCountExcludeSelf()) { + logger.Info("[ECDH exchange retriggerd] not all peers are ready", "peerID", peerID) + go r.triggerECDHExchange() + + } }) if err != nil { return fmt.Errorf("Listen health check failed: %w", err) @@ -118,13 +173,10 @@ func (r *registry) Ready() error { return nil } -func (r *registry) composeHealthCheckTopic(nodeID string) string { - return fmt.Sprintf("healthcheck:%s", nodeID) -} +func (r *registry) WatchPeersReady() { + go r.checkPeersHealth() -func (r *registry) WatchPeersReady(callback func()) { ticker := time.NewTicker(ReadinessCheckPeriod) - go r.checkPeersHeath() // first tick is executed immediately for ; true; <-ticker.C { pairs, _, err := r.consulKV.List("ready/", nil) @@ -151,17 +203,24 @@ func (r *registry) WatchPeersReady(callback func()) { logger.Warn("Peer disconnected!", "peerID", peerID) r.readyMap[peerID] = false atomic.AddInt64(&r.readyCount, -1) + + // Remove ECDH key for disconnected peer + r.ecdhSession.RemovePeer(peerID) + + if r.onPeerDisconnected != nil { + r.onPeerDisconnected(peerID) + } } } } - r.registerReadyPairs(newReadyPeerIDs, callback) + r.registerReadyPairs(newReadyPeerIDs) } } -func (r *registry) checkPeersHeath() { +func (r *registry) checkPeersHealth() { for { time.Sleep(5 * time.Second) if !r.ArePeersReady() { @@ -175,7 +234,7 @@ func (r *registry) checkPeersHeath() { } readyPeerIDs := r.getReadyPeersFromKVStore(pairs) for _, peerID := range readyPeerIDs { - err := r.healthCheck.SendToOtherWithRetry(r.composeHealthCheckTopic(peerID), []byte(peerID), messaging.RetryConfig{ + err := r.healthCheck.SendToOtherWithRetry(r.composeHealthCheckTopic(peerID), []byte(r.composeHealthData()), messaging.RetryConfig{ RetryAttempt: 2, }) if err != nil && strings.Contains(err.Error(), "no responders") { @@ -189,6 +248,8 @@ func (r *registry) checkPeersHeath() { } } +// GetReadyPeersCount returns the number of ready peers including self +// should -1 if want to exclude self func (r *registry) GetReadyPeersCount() int64 { return atomic.LoadInt64(&r.readyCount) } @@ -227,7 +288,17 @@ func (r *registry) ArePeersReady() bool { r.mu.RLock() defer r.mu.RUnlock() - return 
+	// Check both peer connectivity and ECDH completion
+	return r.ready && r.isECDHReady()
 }
+
+// AreMajorityReady checks if a majority of peers are ready.
+// Returns true only if:
+// 1. The number of ready peers (including self) is at least mpcThreshold+1.
+// 2. Symmetric keys are fully established among all ready peers (excluding self).
+func (r *registry) AreMajorityReady() bool {
+	readyCount := r.GetReadyPeersCount()
+	return int(readyCount) >= r.mpcThreshold+1 && r.isECDHReady()
+}
 
 func (r *registry) GetTotalPeersCount() int64 {
@@ -245,3 +316,60 @@ func (r *registry) Resign() error {
 
 	return nil
 }
+
+func (r *registry) OnPeerConnected(callback func(peerID string)) {
+	r.onPeerConnected = callback
+}
+
+func (r *registry) OnPeerDisconnected(callback func(peerID string)) {
+	r.onPeerDisconnected = callback
+}
+
+func (r *registry) OnPeerReConnected(callback func(peerID string)) {
+	r.onPeerReConnected = callback
+}
+
+// startECDHExchange starts the ECDH key exchange process.
+func (r *registry) startECDHExchange() error {
+	if err := r.ecdhSession.ListenKeyExchange(); err != nil {
+		return fmt.Errorf("failed to start ECDH listener: %w", err)
+	}
+
+	if err := r.ecdhSession.BroadcastPublicKey(); err != nil {
+		return fmt.Errorf("failed to broadcast ECDH public key: %w", err)
+	}
+
+	return nil
+}
+
+func (r *registry) GetReadyPeersCountExcludeSelf() int64 {
+	return r.GetReadyPeersCount() - 1
+}
+
+func (r *registry) isECDHReady() bool {
+	requiredKeyCount := r.GetReadyPeersCountExcludeSelf()
+	return r.identityStore.CheckSymmetricKeyComplete(int(requiredKeyCount))
+}
+
+func (r *registry) composeHealthCheckTopic(nodeID string) string {
+	return fmt.Sprintf("healthcheck:%s", nodeID)
+}
+
+func (r *registry) composeHealthData() string {
+	return fmt.Sprintf("%s,%d", r.nodeID, r.ecdhSession.GetReadyPeersCount())
+}
+
+func parseHealthDataSplit(s string) (peerID string, readyCount int, err error) {
+	parts := strings.SplitN(s, ",", 2)
+	if len(parts) != 2 {
+		return "", 0, fmt.Errorf("invalid format: %q", s)
+	}

+	peerID = parts[0]
+	readyCount, err = strconv.Atoi(parts[1])
+	if err != nil {
+		return "", 0, err
+	}
+	return peerID, readyCount, nil
+}
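
Illustrative sketch (not part of the patch): the health-check payload introduced in registry.go is a plain "nodeID,ecdhReadyCount" string, produced by composeHealthData and decoded by parseHealthDataSplit. A minimal round-trip test — assuming a hypothetical pkg/mpc/registry_test.go in the same package — could look like this:

package mpc

import (
	"fmt"
	"testing"
)

func TestParseHealthDataSplit(t *testing.T) {
	// Compose a payload in the same "nodeID,count" format used by composeHealthData.
	payload := fmt.Sprintf("%s,%d", "node-1", 3)

	peerID, readyCount, err := parseHealthDataSplit(payload)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if peerID != "node-1" || readyCount != 3 {
		t.Fatalf("got (%q, %d), want (%q, %d)", peerID, readyCount, "node-1", 3)
	}

	// A payload without a comma-separated count should be rejected.
	if _, _, err := parseHealthDataSplit("node-1"); err == nil {
		t.Fatal("expected error for malformed payload")
	}
}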