diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 48a8329..d090edb 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -6,6 +6,9 @@ on: pull_request: branches: ["*"] +env: + GO_VERSION: "1.24" + jobs: test: runs-on: ubuntu-latest @@ -15,19 +18,10 @@ jobs: uses: actions/checkout@v4 - name: Set up Go - uses: actions/setup-go@v4 + uses: actions/setup-go@v5 with: - go-version: "1.23" - - - name: Cache Go modules - uses: actions/cache@v3 - with: - path: | - ~/.cache/go-build - ~/go/pkg/mod - key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} - restore-keys: | - ${{ runner.os }}-go- + go-version: ${{ env.GO_VERSION }} + cache: true - name: Install dependencies run: go mod download @@ -50,9 +44,13 @@ jobs: uses: actions/checkout@v4 - name: Set up Go - uses: actions/setup-go@v4 + uses: actions/setup-go@v5 with: - go-version: "1.23" + go-version: ${{ env.GO_VERSION }} + cache: true + + - name: Clean Go build cache + run: go clean -cache -modcache - name: Run golangci-lint uses: golangci/golangci-lint-action@v3 @@ -74,19 +72,10 @@ jobs: uses: actions/checkout@v4 - name: Set up Go - uses: actions/setup-go@v4 + uses: actions/setup-go@v5 with: - go-version: "1.23" - - - name: Cache Go modules - uses: actions/cache@v3 - with: - path: | - ~/.cache/go-build - ~/go/pkg/mod - key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} - restore-keys: | - ${{ runner.os }}-go- + go-version: ${{ env.GO_VERSION }} + cache: true - name: Install dependencies run: go mod download @@ -153,19 +142,10 @@ jobs: uses: actions/checkout@v4 - name: Set up Go - uses: actions/setup-go@v4 + uses: actions/setup-go@v5 with: - go-version: "1.23" - - - name: Cache Go modules - uses: actions/cache@v3 - with: - path: | - ~/.cache/go-build - ~/go/pkg/mod - key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} - restore-keys: | - ${{ runner.os }}-go- + go-version: ${{ env.GO_VERSION }} + cache: true - name: Install dependencies run: go mod download @@ -200,19 +180,10 
@@ jobs: uses: actions/checkout@v4 - name: Set up Go - uses: actions/setup-go@v4 - with: - go-version: "1.23" - - - name: Cache Go modules - uses: actions/cache@v3 + uses: actions/setup-go@v5 with: - path: | - ~/.cache/go-build - ~/go/pkg/mod - key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} - restore-keys: | - ${{ runner.os }}-go- + go-version: ${{ env.GO_VERSION }} + cache: true - name: Install dependencies run: go mod download @@ -293,19 +264,10 @@ jobs: uses: actions/checkout@v4 - name: Set up Go - uses: actions/setup-go@v4 - with: - go-version: "1.23" - - - name: Cache Go modules - uses: actions/cache@v3 + uses: actions/setup-go@v5 with: - path: | - ~/.cache/go-build - ~/go/pkg/mod - key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} - restore-keys: | - ${{ runner.os }}-go- + go-version: ${{ env.GO_VERSION }} + cache: true - name: Build mpcium run: go build -v ./cmd/mpcium diff --git a/.github/workflows/e2e-tests.yml b/.github/workflows/e2e-tests.yml index 8714318..e3443f3 100644 --- a/.github/workflows/e2e-tests.yml +++ b/.github/workflows/e2e-tests.yml @@ -7,7 +7,7 @@ on: branches: [master] env: - GO_VERSION: "1.23" + GO_VERSION: "1.24" CGO_ENABLED: 0 DOCKER_BUILDKIT: 1 GO_BUILD_FLAGS: -trimpath -ldflags="-s -w" diff --git a/.gitignore b/.gitignore index b09913b..5832d41 100644 --- a/.gitignore +++ b/.gitignore @@ -17,3 +17,7 @@ e2e/coverage.html e2e/logs/ # Generated config file (template is tracked) e2e/config.test.yaml +node0 +node1 +node2 +config.yaml diff --git a/cmd/mpcium/main.go b/cmd/mpcium/main.go index e47edfc..6d12286 100644 --- a/cmd/mpcium/main.go +++ b/cmd/mpcium/main.go @@ -129,7 +129,6 @@ func runNode(ctx context.Context, c *cli.Command) error { if err != nil { logger.Fatal("Failed to connect to NATS", err) } - defer natsConn.Close() pubsub := messaging.NewNATSPubSub(natsConn) keygenBroker, err := messaging.NewJetStreamBroker(ctx, natsConn, event.KeygenBrokerStream, []string{ @@ -159,10 +158,10 @@ func runNode(ctx context.Context, 
c *cli.Command) error { reshareResultQueue := mqManager.NewMessageQueue("mpc_reshare_result") defer reshareResultQueue.Close() - logger.Info("Node is running", "peerID", nodeID, "name", nodeName) + logger.Info("Node is running", "ID", nodeID, "name", nodeName) peerNodeIDs := GetPeerIDs(peers) - peerRegistry := mpc.NewRegistry(nodeID, peerNodeIDs, consulClient.KV()) + peerRegistry := mpc.NewRegistry(nodeID, peerNodeIDs, consulClient.KV(), directMessaging, pubsub, identityStore) mpcNode := mpc.NewNode( nodeID, @@ -194,16 +193,18 @@ func runNode(ctx context.Context, c *cli.Command) error { timeoutConsumer.Run() defer timeoutConsumer.Close() - keygenConsumer := eventconsumer.NewKeygenConsumer(natsConn, keygenBroker, pubsub, peerRegistry) - signingConsumer := eventconsumer.NewSigningConsumer(natsConn, signingBroker, pubsub, peerRegistry) + keygenConsumer := eventconsumer.NewKeygenConsumer(natsConn, keygenBroker, pubsub, peerRegistry, genKeyResultQueue) + signingConsumer := eventconsumer.NewSigningConsumer(natsConn, signingBroker, pubsub, peerRegistry, singingResultQueue) // Make the node ready before starting the signing consumer if err := peerRegistry.Ready(); err != nil { logger.Error("Failed to mark peer registry as ready", err) } logger.Info("[READY] Node is ready", "nodeID", nodeID) + + logger.Info("Starting consumers", "nodeID", nodeID) appContext, cancel := context.WithCancel(context.Background()) - // Setup signal handling to cancel context on termination signals. + //Setup signal handling to cancel context on termination signals. 
go func() { sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM) @@ -211,6 +212,11 @@ func runNode(ctx context.Context, c *cli.Command) error { logger.Warn("Shutdown signal received, canceling context...") cancel() + // Resign from peer registry first (before closing NATS) + if err := peerRegistry.Resign(); err != nil { + logger.Error("Failed to resign from peer registry", err) + } + // Gracefully close consumers if err := keygenConsumer.Close(); err != nil { logger.Error("Failed to close keygen consumer", err) @@ -218,10 +224,15 @@ func runNode(ctx context.Context, c *cli.Command) error { if err := signingConsumer.Close(); err != nil { logger.Error("Failed to close signing consumer", err) } + + err := natsConn.Drain() + if err != nil { + logger.Error("Failed to drain NATS connection", err) + } }() var wg sync.WaitGroup - errChan := make(chan error, 2) + errChan := make(chan error, 3) wg.Add(1) go func() { @@ -250,7 +261,6 @@ func runNode(ctx context.Context, c *cli.Command) error { logger.Info("All consumers have finished") close(errChan) }() - for err := range errChan { if err != nil { logger.Error("Consumer error received", err) @@ -258,6 +268,7 @@ func runNode(ctx context.Context, c *cli.Command) error { return err } } + return nil } diff --git a/config.yaml.template b/config.yaml.template index 2423261..694c90c 100644 --- a/config.yaml.template +++ b/config.yaml.template @@ -5,7 +5,7 @@ consul: mpc_threshold: 2 environment: development -badger_password: "your_badger_password" +badger_password: "REPLACE_WITH_STRONG_PASSWORD" event_initiator_pubkey: "event_initiator_pubkey" db_path: "."
backup_enabled: true diff --git a/e2e/reshare_test.go b/e2e/reshare_test.go index 86cf670..f865e5d 100644 --- a/e2e/reshare_test.go +++ b/e2e/reshare_test.go @@ -92,7 +92,7 @@ func testKeyGenerationForResharing(t *testing.T, suite *E2ETestSuite) { require.NoError(t, err, "Failed to setup keygen result listener") // Add a small delay to ensure the result listener is fully set up - time.Sleep(2 * time.Second) + time.Sleep(10 * time.Second) // Trigger key generation for all wallets for _, walletID := range walletIDs { @@ -179,7 +179,7 @@ func testResharingAllNodes(t *testing.T, suite *E2ETestSuite) { require.NoError(t, err, "Failed to setup resharing result listener") // Wait for listener setup - time.Sleep(2 * time.Second) + time.Sleep(10 * time.Second) // Test resharing for both key types for i, walletID := range suite.walletIDs { @@ -360,7 +360,7 @@ func testSigningAfterResharing(t *testing.T, suite *E2ETestSuite) { require.NoError(t, err, "Failed to setup signing result listener") // Wait for listener setup - time.Sleep(2 * time.Second) + time.Sleep(10 * time.Second) // Test messages to sign testMessages := []string{ diff --git a/e2e/setup_test_identities.sh b/e2e/setup_test_identities.sh index c2a29c6..46b62f6 100755 --- a/e2e/setup_test_identities.sh +++ b/e2e/setup_test_identities.sh @@ -2,7 +2,6 @@ # E2E Test Identity Setup Script # This script sets up identities for testing with separate test database paths - set -e # Number of test nodes diff --git a/e2e/sign_test.go b/e2e/sign_test.go index c5d454f..fe197c9 100644 --- a/e2e/sign_test.go +++ b/e2e/sign_test.go @@ -150,7 +150,7 @@ func testKeyGenerationForSigning(t *testing.T, suite *E2ETestSuite) { require.NoError(t, err, "Failed to setup keygen result listener") // Add a small delay to ensure the result listener is fully set up - time.Sleep(2 * time.Second) + time.Sleep(10 * time.Second) // Trigger key generation for all wallets for _, walletID := range walletIDs { diff --git a/examples/reshare/main.go 
b/examples/reshare/main.go index 880868e..3de3174 100644 --- a/examples/reshare/main.go +++ b/examples/reshare/main.go @@ -49,10 +49,10 @@ func main() { resharingMsg := &types.ResharingMessage{ SessionID: uuid.NewString(), - WalletID: "bf2cc849-8e55-47e4-ab73-e17fb1eb690c", - NodeIDs: []string{"d926fa75-72c7-4538-9052-4a064a84981d", "7b1090cd-ffe3-46ff-8375-594dd3204169"}, // new peer IDs + WalletID: "506d2d40-483a-49f1-93c8-27dd4fe9740c", + NodeIDs: []string{"c95c340e-5a18-472d-b9b0-5ac68218213a", "ac37e85f-caca-4bee-8a3a-49a0fe35abff"}, // new peer IDs - NewThreshold: 2, // t+1 <= len(NodeIDs) + NewThreshold: 1, // t+1 <= len(NodeIDs) KeyType: types.KeyTypeEd25519, } err = mpcClient.Resharing(resharingMsg) diff --git a/pkg/encryption/aes.go b/pkg/encryption/aes.go index bc9f386..ba31b45 100644 --- a/pkg/encryption/aes.go +++ b/pkg/encryption/aes.go @@ -4,6 +4,8 @@ import ( "crypto/aes" "crypto/cipher" "crypto/rand" + "errors" + "fmt" ) func EncryptAESGCM(plain, key []byte) (ciphertext, nonce []byte, err error) { @@ -34,3 +36,51 @@ func DecryptAESGCM(ciphertext, key, nonce []byte) ([]byte, error) { } return aead.Open(nil, nonce, ciphertext, nil) } + +// EncryptAESGCMWithNonceEmbed encrypts plaintext and embeds the nonce at the start of the returned slice. +func EncryptAESGCMWithNonceEmbed(plaintext, key []byte) ([]byte, error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, fmt.Errorf("failed to create AES cipher: %w", err) + } + + aead, err := cipher.NewGCM(block) + if err != nil { + return nil, fmt.Errorf("failed to create GCM: %w", err) + } + + nonce := make([]byte, aead.NonceSize()) + if _, err := rand.Read(nonce); err != nil { + return nil, fmt.Errorf("failed to generate nonce: %w", err) + } + + ciphertext := aead.Seal(nil, nonce, plaintext, nil) + return append(nonce, ciphertext...), nil +} + +// DecryptAESGCMWithNonceEmbed decrypts ciphertext where the nonce is embedded at the start of the slice. 
+func DecryptAESGCMWithNonceEmbed(data, key []byte) ([]byte, error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, fmt.Errorf("failed to create AES cipher: %w", err) + } + aead, err := cipher.NewGCM(block) + if err != nil { + return nil, fmt.Errorf("failed to create GCM: %w", err) + } + + nonceSize := aead.NonceSize() + if len(data) < nonceSize { + return nil, errors.New("ciphertext too short") + } + + nonce := data[:nonceSize] + ciphertext := data[nonceSize:] + + plaintext, err := aead.Open(nil, nonce, ciphertext, nil) + if err != nil { + return nil, fmt.Errorf("decryption failed: %w", err) + } + + return plaintext, nil +} diff --git a/pkg/event/types.go b/pkg/event/types.go index 5f9b41b..ac26ca9 100644 --- a/pkg/event/types.go +++ b/pkg/event/types.go @@ -91,6 +91,8 @@ const ( // Context and cancellation errors ErrorCodeContextCancelled ErrorCode = "ERROR_CONTEXT_CANCELLED" ErrorCodeOperationAborted ErrorCode = "ERROR_OPERATION_ABORTED" + ErrorCodeNotMajority ErrorCode = "ERROR_NOT_MAJORITY" + ErrorCodeClusterNotReady ErrorCode = "ERROR_CLUSTER_NOT_READY" ) // GetErrorCodeFromError attempts to categorize a generic error into a specific error code diff --git a/pkg/eventconsumer/event_consumer.go b/pkg/eventconsumer/event_consumer.go index 6fb7b5b..c84e684 100644 --- a/pkg/eventconsumer/event_consumer.go +++ b/pkg/eventconsumer/event_consumer.go @@ -628,14 +628,16 @@ func (ec *eventConsumer) consumeReshareEvent() error { return } + ctx := context.Background() + var wg sync.WaitGroup + successEvent := &event.ResharingResultEvent{ WalletID: walletID, NewThreshold: msg.NewThreshold, KeyType: msg.KeyType, ResultType: event.ResultTypeSuccess, } - ctx := context.Background() - var wg sync.WaitGroup + if oldSession != nil { err := oldSession.Init() if err != nil { @@ -644,6 +646,7 @@ func (ec *eventConsumer) consumeReshareEvent() error { } oldSession.ListenToIncomingMessageAsync() } + if newSession != nil { err := newSession.Init() if err != nil { 
@@ -651,12 +654,18 @@ func (ec *eventConsumer) consumeReshareEvent() error { return } newSession.ListenToIncomingMessageAsync() + // In resharing process, we need to ensure that the new session is aware of the old committee peers. + // Then new committee peers can start listening to the old committee peers + // and thus enable receiving direct messages from them. + extraOldCommiteePeers := newSession.GetLegacyCommitteePeers() + newSession.ListenToPeersAsync(extraOldCommiteePeers) } ec.warmUpSession() if oldSession != nil { ctxOld, doneOld := context.WithCancel(ctx) go oldSession.Reshare(doneOld) + wg.Add(1) go func() { defer wg.Done() diff --git a/pkg/eventconsumer/keygen_consumer.go b/pkg/eventconsumer/keygen_consumer.go index 1d85402..98d9f07 100644 --- a/pkg/eventconsumer/keygen_consumer.go +++ b/pkg/eventconsumer/keygen_consumer.go @@ -2,6 +2,8 @@ package eventconsumer import ( "context" + "encoding/json" + "errors" "fmt" "time" @@ -9,6 +11,7 @@ import ( "github.com/fystack/mpcium/pkg/logger" "github.com/fystack/mpcium/pkg/messaging" "github.com/fystack/mpcium/pkg/mpc" + "github.com/fystack/mpcium/pkg/types" "github.com/google/uuid" "github.com/nats-io/nats.go" "github.com/nats-io/nats.go/jetstream" @@ -31,22 +34,30 @@ type KeygenConsumer interface { // keygenConsumer implements KeygenConsumer. type keygenConsumer struct { - natsConn *nats.Conn - pubsub messaging.PubSub - jsBroker messaging.MessageBroker - peerRegistry mpc.PeerRegistry + natsConn *nats.Conn + pubsub messaging.PubSub + jsBroker messaging.MessageBroker + peerRegistry mpc.PeerRegistry + keygenResultQueue messaging.MessageQueue // jsSub holds the JetStream subscription, so it can be cleaned up during Close(). jsSub messaging.MessageSubscription } // NewKeygenConsumer returns a new instance of KeygenConsumer. 
-func NewKeygenConsumer(natsConn *nats.Conn, jsBroker messaging.MessageBroker, pubsub messaging.PubSub, peerRegistry mpc.PeerRegistry) KeygenConsumer { +func NewKeygenConsumer( + natsConn *nats.Conn, + jsBroker messaging.MessageBroker, + pubsub messaging.PubSub, + peerRegistry mpc.PeerRegistry, + keygenResultQueue messaging.MessageQueue, +) KeygenConsumer { return &keygenConsumer{ - natsConn: natsConn, - pubsub: pubsub, - jsBroker: jsBroker, - peerRegistry: peerRegistry, + natsConn: natsConn, + pubsub: pubsub, + jsBroker: jsBroker, + peerRegistry: peerRegistry, + keygenResultQueue: keygenResultQueue, } } @@ -60,6 +71,9 @@ func (sc *keygenConsumer) waitForAllPeersReadyToGenKey(ctx context.Context) erro for { select { case <-ctx.Done(): + if ctx.Err() == context.Canceled { + return nil + } return ctx.Err() case <-ticker.C: allPeersReady := sc.peerRegistry.ArePeersReady() @@ -80,6 +94,9 @@ func (sc *keygenConsumer) waitForAllPeersReadyToGenKey(ctx context.Context) erro func (sc *keygenConsumer) Run(ctx context.Context) error { // Wait for sufficient peers before starting to consume messages if err := sc.waitForAllPeersReadyToGenKey(ctx); err != nil { + if err == context.Canceled { + return nil + } return fmt.Errorf("failed to wait for sufficient peers: %w", err) } @@ -90,10 +107,10 @@ func (sc *keygenConsumer) Run(ctx context.Context) error { sc.handleKeygenEvent, ) if err != nil { - return fmt.Errorf("failed to subscribe to signing events: %w", err) + return fmt.Errorf("failed to subscribe to keygen events: %w", err) } sc.jsSub = sub - logger.Info("SigningConsumer: Subscribed to signing events") + logger.Info("KeygenConsumer: Subscribed to keygen events") // Block until context cancellation.
<-ctx.Done() @@ -103,26 +120,23 @@ func (sc *keygenConsumer) Run(ctx context.Context) error { return sc.Close() } -// The handleSigningEvent function in sign_consumer.go acts as a bridge between the JetStream-based event queue and the MPC (Multi-Party Computation) signing system -// Creates a reply channel: It generates a unique inbox address using nats.NewInbox() to receive the signing response. -// Sets up response handling: It creates a synchronous subscription to listen for replies on this inbox. -// Forwards the signing request: It publishes the original signing event data to the MPCSigningEventTopic with the reply inbox attached, which triggers the MPC signing process. -// Polls for completion: It enters a polling loop that checks for a reply message, continuing until either: -// A reply is received (successful signing) -// An error occurs (failed signing) -// The timeout is reached (30 seconds) -// Completes the transaction: It either acknowledges (Ack) the message if signing was successful or negatively acknowledges (Nak) it if there was a timeout or error. -// MPC Session Interaction -// The signing consumer doesn't directly interact with MPC sessions. Instead: -// It publishes the signing request to the MPCSigningEventTopic, which is consumed by the eventconsumer.consumeTxSigningEvent handler. -// This handler creates the appropriate signing session (SigningSession for ECDSA or EDDSASigningSession for EdDSA) via the MPC node's creation methods. -// The MPC signing sessions manage the distributed cryptographic operations across multiple nodes, handling message routing, party updates, and signature verification. -// When signing completes, the session publishes the result to a queue and calls the onSuccess callback, which sends a reply to the inbox that the KeygenConsumer is monitoring. -// The reply signals completion, allowing the KeygenConsumer to acknowledge the original message. 
func (sc *keygenConsumer) handleKeygenEvent(msg jetstream.Msg) { + raw := msg.Data() + var keygenMsg types.GenerateKeyMessage + sessionID := msg.Headers().Get("SessionID") + + err := json.Unmarshal(raw, &keygenMsg) + if err != nil { + logger.Error("KeygenConsumer: Failed to unmarshal keygen message", err) + sc.handleKeygenError(keygenMsg, event.ErrorCodeUnmarshalFailure, err, sessionID) + _ = msg.Ack() + return + } if !sc.peerRegistry.ArePeersReady() { - logger.Warn("KeygenConsumer: Not all peers are ready to sign, skipping message processing") + logger.Warn("KeygenConsumer: Not all peers are ready to gen key, skipping message processing") + sc.handleKeygenError(keygenMsg, event.ErrorCodeClusterNotReady, errors.New("not all peers are ready"), sessionID) + _ = msg.Ack() return } @@ -165,7 +179,7 @@ func (sc *keygenConsumer) handleKeygenEvent(msg jetstream.Msg) { break } if replyMsg != nil { - logger.Info("KeygenConsumer: Completed signing event; reply received") + logger.Info("KeygenConsumer: Completed keygen event; reply received") if ackErr := msg.Ack(); ackErr != nil { logger.Error("KeygenConsumer: ACK failed", ackErr) } @@ -173,10 +187,37 @@ func (sc *keygenConsumer) handleKeygenEvent(msg jetstream.Msg) { - logger.Warn("KeygenConsumer: Timeout waiting for signing event response") + logger.Warn("KeygenConsumer: Timeout waiting for keygen event response") _ = msg.Nak() } +func (sc *keygenConsumer) handleKeygenError(keygenMsg types.GenerateKeyMessage, errorCode event.ErrorCode, err error, sessionID string) { + keygenResult := event.KeygenResultEvent{ + ResultType: event.ResultTypeError, + ErrorCode: string(errorCode), + WalletID: keygenMsg.WalletID, + ErrorReason: err.Error(), + } + + keygenResultBytes, err := json.Marshal(keygenResult) + if err != nil { + logger.Error("Failed to marshal keygen result event", err, + "walletID", keygenResult.WalletID, + ) + return + } + + topic := fmt.Sprintf(mpc.TypeGenerateWalletResultFmt, keygenResult.WalletID) + err =
sc.keygenResultQueue.Enqueue(topic, keygenResultBytes, &messaging.EnqueueOptions{ + IdempotententKey: buildIdempotentKey(keygenMsg.WalletID, sessionID, mpc.TypeGenerateWalletResultFmt), + }) + if err != nil { + logger.Error("Failed to enqueue keygen result event", err, + "walletID", keygenMsg.WalletID, + ) + } +} + // Close unsubscribes from the JetStream subject and cleans up resources. func (sc *keygenConsumer) Close() error { if sc.jsSub != nil { diff --git a/pkg/eventconsumer/sign_consumer.go b/pkg/eventconsumer/sign_consumer.go index f38e779..24f21fe 100644 --- a/pkg/eventconsumer/sign_consumer.go +++ b/pkg/eventconsumer/sign_consumer.go @@ -2,6 +2,7 @@ package eventconsumer import ( "context" + "encoding/json" "fmt" "time" @@ -9,6 +10,7 @@ import ( "github.com/fystack/mpcium/pkg/logger" "github.com/fystack/mpcium/pkg/messaging" "github.com/fystack/mpcium/pkg/mpc" + "github.com/fystack/mpcium/pkg/types" "github.com/google/uuid" "github.com/nats-io/nats.go" "github.com/nats-io/nats.go/jetstream" @@ -34,25 +36,27 @@ type SigningConsumer interface { // signingConsumer implements SigningConsumer. type signingConsumer struct { - natsConn *nats.Conn - pubsub messaging.PubSub - jsBroker messaging.MessageBroker - peerRegistry mpc.PeerRegistry - mpcThreshold int + natsConn *nats.Conn + pubsub messaging.PubSub + jsBroker messaging.MessageBroker + peerRegistry mpc.PeerRegistry + mpcThreshold int + signingResultQueue messaging.MessageQueue // jsSub holds the JetStream subscription, so it can be cleaned up during Close(). jsSub messaging.Subscription } // NewSigningConsumer returns a new instance of SigningConsumer. 
-func NewSigningConsumer(natsConn *nats.Conn, jsBroker messaging.MessageBroker, pubsub messaging.PubSub, peerRegistry mpc.PeerRegistry) SigningConsumer { +func NewSigningConsumer(natsConn *nats.Conn, jsBroker messaging.MessageBroker, pubsub messaging.PubSub, peerRegistry mpc.PeerRegistry, signingResultQueue messaging.MessageQueue) SigningConsumer { mpcThreshold := viper.GetInt("mpc_threshold") return &signingConsumer{ - natsConn: natsConn, - pubsub: pubsub, - jsBroker: jsBroker, - peerRegistry: peerRegistry, - mpcThreshold: mpcThreshold, + natsConn: natsConn, + pubsub: pubsub, + jsBroker: jsBroker, + peerRegistry: peerRegistry, + mpcThreshold: mpcThreshold, + signingResultQueue: signingResultQueue, } } @@ -70,6 +74,9 @@ func (sc *signingConsumer) waitForSufficientPeers(ctx context.Context) error { for { select { case <-ctx.Done(): + if ctx.Err() == context.Canceled { + return nil + } return ctx.Err() case <-ticker.C: readyPeers := sc.peerRegistry.GetReadyPeersCount() @@ -90,6 +97,9 @@ func (sc *signingConsumer) waitForSufficientPeers(ctx context.Context) error { func (sc *signingConsumer) Run(ctx context.Context) error { // Wait for sufficient peers before starting to consume messages if err := sc.waitForSufficientPeers(ctx); err != nil { + if err == context.Canceled { + return nil + } return fmt.Errorf("failed to wait for sufficient peers: %w", err) } @@ -130,18 +140,26 @@ func (sc *signingConsumer) Run(ctx context.Context) error { // When signing completes, the session publishes the result to a queue and calls the onSuccess callback, which sends a reply to the inbox that the SigningConsumer is monitoring. // The reply signals completion, allowing the SigningConsumer to acknowledge the original message. 
func (sc *signingConsumer) handleSigningEvent(msg jetstream.Msg) { - // Check if we still have enough peers before processing the message - requiredPeers := int64(sc.mpcThreshold + 1) - readyPeers := sc.peerRegistry.GetReadyPeersCount() - - if readyPeers < requiredPeers { - logger.Warn("SigningConsumer: Not enough peers to process signing request, rejecting message", - "ready", readyPeers, - "required", requiredPeers) - // Immediately return and let nats redeliver the message with backoff + // Parse the signing request message to extract transaction details + raw := msg.Data() + var signingMsg types.SignTxMessage + sessionID := msg.Headers().Get("SessionID") + + err := json.Unmarshal(raw, &signingMsg) + if err != nil { + logger.Error("SigningConsumer: Failed to unmarshal signing message", err) + sc.handleSigningError(signingMsg, event.ErrorCodeUnmarshalFailure, err, sessionID) + _ = msg.Ack() return } + if !sc.peerRegistry.AreMajorityReady() { + requiredPeers := int64(sc.mpcThreshold + 1) + err := fmt.Errorf("not enough peers to process signing request: ready=%d, required=%d", sc.peerRegistry.GetReadyPeersCount(), requiredPeers) + sc.handleSigningError(signingMsg, event.ErrorCodeNotMajority, err, sessionID) + _ = msg.Ack() + return + } // Create a reply inbox to receive the signing event response. 
replyInbox := nats.NewInbox() @@ -193,6 +211,36 @@ func (sc *signingConsumer) handleSigningEvent(msg jetstream.Msg) { _ = msg.Nak() } +func (sc *signingConsumer) handleSigningError(signMsg types.SignTxMessage, errorCode event.ErrorCode, err error, sessionID string) { + signingResult := event.SigningResultEvent{ + ResultType: event.ResultTypeError, + ErrorCode: errorCode, + NetworkInternalCode: signMsg.NetworkInternalCode, + WalletID: signMsg.WalletID, + TxID: signMsg.TxID, + ErrorReason: err.Error(), + } + + signingResultBytes, err := json.Marshal(signingResult) + if err != nil { + logger.Error("Failed to marshal signing result event", err, + "walletID", signMsg.WalletID, + "txID", signMsg.TxID, + ) + return + } + + err = sc.signingResultQueue.Enqueue(event.SigningResultCompleteTopic, signingResultBytes, &messaging.EnqueueOptions{ + IdempotententKey: buildIdempotentKey(signMsg.TxID, sessionID, mpc.TypeSigningResultFmt), + }) + if err != nil { + logger.Error("Failed to enqueue signing result event", err, + "walletID", signMsg.WalletID, + "txID", signMsg.TxID, + ) + } +} + // Close unsubscribes from the JetStream subject and cleans up resources. 
func (sc *signingConsumer) Close() error { if sc.jsSub != nil { @@ -204,3 +252,13 @@ func (sc *signingConsumer) Close() error { } return nil } + +func buildIdempotentKey(baseID string, sessionID string, formatTemplate string) string { + var uniqueKey string + if sessionID != "" { + uniqueKey = fmt.Sprintf("%s:%s", baseID, sessionID) + } else { + uniqueKey = baseID + } + return fmt.Sprintf(formatTemplate, uniqueKey) +} diff --git a/pkg/identity/identity.go b/pkg/identity/identity.go index 68b9176..0d2329a 100644 --- a/pkg/identity/identity.go +++ b/pkg/identity/identity.go @@ -17,6 +17,7 @@ import ( "golang.org/x/term" "github.com/fystack/mpcium/pkg/common/pathutil" + "github.com/fystack/mpcium/pkg/encryption" "github.com/fystack/mpcium/pkg/logger" "github.com/fystack/mpcium/pkg/types" "github.com/spf13/viper" @@ -37,6 +38,18 @@ type Store interface { VerifyInitiatorMessage(msg types.InitiatorMessage) error SignMessage(msg *types.TssMessage) ([]byte, error) VerifyMessage(msg *types.TssMessage) error + + SignEcdhMessage(msg *types.ECDHMessage) ([]byte, error) + VerifySignature(msg *types.ECDHMessage) error + + SetSymmetricKey(peerID string, key []byte) + GetSymmetricKey(peerID string) ([]byte, error) + RemoveSymmetricKey(peerID string) + GetSymetricKeyCount() int + CheckSymmetricKeyComplete(desired int) bool + + EncryptMessage(plaintext []byte, peerID string) ([]byte, error) + DecryptMessage(cipher []byte, peerID string) ([]byte, error) } // fileStore implements the Store interface using the filesystem @@ -48,9 +61,9 @@ type fileStore struct { publicKeys map[string][]byte mu sync.RWMutex - // Cached private key privateKey []byte initiatorPubKey []byte + symmetricKeys map[string][]byte } // NewFileStore creates a new identity store @@ -97,6 +110,7 @@ func NewFileStore(identityDir, nodeName string, decrypt bool) (*fileStore, error publicKeys: make(map[string][]byte), privateKey: privateKey, initiatorPubKey: initiatorPubKey, + symmetricKeys: make(map[string][]byte), } 
// Check that each node in peers.json has an identity file @@ -207,6 +221,43 @@ func loadPrivateKey(identityDir, nodeName string, decrypt bool) (string, error) } } +// Set SymmetricKey: adds or updates a symmetric key for a given peer ID. +func (s *fileStore) SetSymmetricKey(peerID string, key []byte) { + s.mu.Lock() + defer s.mu.Unlock() + s.symmetricKeys[peerID] = key +} + +// Get SymmetricKey: retrieves a peer node's dh symmetric-key by its ID +func (s *fileStore) GetSymmetricKey(peerID string) ([]byte, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + if key, exists := s.symmetricKeys[peerID]; exists { + return key, nil + } + + return nil, fmt.Errorf("SymmetricKey key not found for node ID: %s", peerID) +} + +func (s *fileStore) RemoveSymmetricKey(peerID string) { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.symmetricKeys, peerID) +} + +func (s *fileStore) GetSymetricKeyCount() int { + s.mu.RLock() + defer s.mu.RUnlock() + return len(s.symmetricKeys) +} + +func (s *fileStore) CheckSymmetricKeyComplete(desired int) bool { + s.mu.RLock() + defer s.mu.RUnlock() + return len(s.symmetricKeys) == desired +} + // GetPublicKey retrieves a node's public key by its ID func (s *fileStore) GetPublicKey(nodeID string) ([]byte, error) { s.mu.RLock() @@ -259,6 +310,70 @@ func (s *fileStore) VerifyMessage(msg *types.TssMessage) error { return nil } +func (s *fileStore) EncryptMessage(plaintext []byte, peerID string) ([]byte, error) { + key, err := s.GetSymmetricKey(peerID) + if err != nil { + return nil, err + } + + if key == nil { + return nil, fmt.Errorf("no symmetric key for peer %s", peerID) + } + + return encryption.EncryptAESGCMWithNonceEmbed(plaintext, key) +} + +func (s *fileStore) DecryptMessage(cipher []byte, peerID string) ([]byte, error) { + key, err := s.GetSymmetricKey(peerID) + + if err != nil { + return nil, err + } + + if key == nil { + return nil, fmt.Errorf("no symmetric key for peer %s", peerID) + } + return encryption.DecryptAESGCMWithNonceEmbed(cipher, 
key) +} + +// Sign ECDH key exchange message +func (s *fileStore) SignEcdhMessage(msg *types.ECDHMessage) ([]byte, error) { + // Get deterministic bytes for signing + msgBytes, err := msg.MarshalForSigning() + if err != nil { + return nil, fmt.Errorf("failed to marshal message for signing: %w", err) + } + + signature := ed25519.Sign(s.privateKey, msgBytes) + return signature, nil +} + +// Verify ECDH key exchange message +func (s *fileStore) VerifySignature(msg *types.ECDHMessage) error { + if msg.Signature == nil { + return fmt.Errorf("ECDH message has no signature") + } + + // Get the sender's public key + senderPk, err := s.GetPublicKey(msg.From) + if err != nil { + return fmt.Errorf("failed to get sender's public key: %w", err) + } + + // Get deterministic bytes for verification + msgBytes, err := msg.MarshalForSigning() + if err != nil { + return fmt.Errorf("failed to marshal message for verification: %w", err) + } + + // Verify the signature + if !ed25519.Verify(senderPk, msgBytes, msg.Signature) { + return fmt.Errorf("invalid signature") + } + + return nil +} + // VerifyInitiatorMessage verifies that a message was signed by the known initiator func (s *fileStore) VerifyInitiatorMessage(msg types.InitiatorMessage) error { // Get the raw message that was signed diff --git a/pkg/messaging/point2point.go b/pkg/messaging/point2point.go index 6919c05..225f090 100644 --- a/pkg/messaging/point2point.go +++ b/pkg/messaging/point2point.go @@ -1,6 +1,8 @@ package messaging import ( + "fmt" + "sync" "time" "github.com/avast/retry-go" @@ -9,25 +11,54 @@ import ( ) type DirectMessaging interface { - Listen(target string, handler func(data []byte)) (Subscription, error) - Send(target string, data []byte) error + Listen(topic string, handler func(data []byte)) (Subscription, error) + SendToOther(topic string, data []byte) error + SendToOtherWithRetry(topic string, data []byte, config RetryConfig) error + SendToSelf(topic string, data []byte) error +} + +type RetryConfig 
struct { + RetryAttempt uint + ExponentialBackoff bool + Delay time.Duration + OnRetry func(n uint, err error) } type natsDirectMessaging struct { natsConn *nats.Conn + handlers map[string][]func([]byte) + mu sync.Mutex } func NewNatsDirectMessaging(natsConn *nats.Conn) DirectMessaging { return &natsDirectMessaging{ natsConn: natsConn, + handlers: make(map[string][]func([]byte)), + } +} + +// SendToSelf locally sends a message to the same node, invoking all handlers for the topic +// avoiding mediating through the message layer. +func (d *natsDirectMessaging) SendToSelf(topic string, message []byte) error { + d.mu.Lock() + handlers, ok := d.handlers[topic] + d.mu.Unlock() + + if !ok || len(handlers) == 0 { + return fmt.Errorf("no handlers found for topic %s", topic) + } + + for _, handler := range handlers { + handler(message) } + + return nil } -func (d *natsDirectMessaging) Send(id string, message []byte) error { - var retryCount = 0 - err := retry.Do( +func (d *natsDirectMessaging) SendToOther(topic string, message []byte) error { + return retry.Do( func() error { - _, err := d.natsConn.Request(id, message, 3*time.Second) + _, err := d.natsConn.Request(topic, message, 3*time.Second) if err != nil { return err } @@ -37,15 +68,43 @@ func (d *natsDirectMessaging) Send(id string, message []byte) error { retry.Delay(50*time.Millisecond), retry.DelayType(retry.FixedDelay), retry.OnRetry(func(n uint, err error) { - logger.Error("Failed to send direct message message", err, "retryCount", retryCount) + logger.Error("Failed to send direct message", err, "attempt", n+1, "topic", topic) }), ) +} + +func (d *natsDirectMessaging) SendToOtherWithRetry(topic string, message []byte, config RetryConfig) error { + opts := []retry.Option{ + retry.MaxJitter(80 * time.Millisecond), + } + + if config.RetryAttempt > 0 { + opts = append(opts, retry.Attempts(config.RetryAttempt)) + } + if config.ExponentialBackoff { + opts = append(opts, retry.DelayType(retry.BackOffDelay)) + } + if 
config.Delay > 0 { + opts = append(opts, retry.Delay(config.Delay)) + } + if config.OnRetry != nil { + opts = append(opts, retry.OnRetry(config.OnRetry)) + } - return err + return retry.Do( + func() error { + _, err := d.natsConn.Request(topic, message, 3*time.Second) + if err != nil { + return err + } + return nil + }, + opts..., + ) } -func (d *natsDirectMessaging) Listen(id string, handler func(data []byte)) (Subscription, error) { - sub, err := d.natsConn.Subscribe(id, func(m *nats.Msg) { +func (d *natsDirectMessaging) Listen(topic string, handler func(data []byte)) (Subscription, error) { + sub, err := d.natsConn.Subscribe(topic, func(m *nats.Msg) { handler(m.Data) if err := m.Respond([]byte("OK")); err != nil { logger.Error("Failed to respond to message", err) @@ -55,5 +114,17 @@ func (d *natsDirectMessaging) Listen(id string, handler func(data []byte)) (Subs return nil, err } + if err := d.natsConn.Flush(); err != nil { + err := sub.Unsubscribe() + if err != nil { + logger.Error("Failed to unsubscribe", err) + } + return nil, fmt.Errorf("flush after subscribe failed: %w", err) + } + + d.mu.Lock() + d.handlers[topic] = append(d.handlers[topic], handler) + d.mu.Unlock() + return &natsSubscription{subscription: sub}, nil } diff --git a/pkg/messaging/pubsub.go b/pkg/messaging/pubsub.go index 57547ea..8e4fd0e 100644 --- a/pkg/messaging/pubsub.go +++ b/pkg/messaging/pubsub.go @@ -51,8 +51,7 @@ func (n *natsPubSub) PublishWithReply(topic, reply string, data []byte, headers } func (n *natsPubSub) Subscribe(topic string, handler func(msg *nats.Msg)) (Subscription, error) { - // TODO: Handle subscription - // handle more fields in msg + //Handle subscription: handle more fields in msg sub, err := n.natsConn.Subscribe(topic, func(msg *nats.Msg) { handler(msg) }) diff --git a/pkg/mpc/ecdsa_keygen_session.go b/pkg/mpc/ecdsa_keygen_session.go index e9ea00a..acbc497 100644 --- a/pkg/mpc/ecdsa_keygen_session.go +++ b/pkg/mpc/ecdsa_keygen_session.go @@ -61,8 +61,8 @@ func 
newECDSAKeygenSession( ComposeBroadcastTopic: func() string { return fmt.Sprintf("keygen:broadcast:ecdsa:%s", walletID) }, - ComposeDirectTopic: func(nodeID string) string { - return fmt.Sprintf("keygen:direct:ecdsa:%s:%s", nodeID, walletID) + ComposeDirectTopic: func(fromID string, toID string) string { + return fmt.Sprintf("keygen:direct:ecdsa:%s:%s:%s", fromID, toID, walletID) }, }, composeKey: func(walletID string) string { diff --git a/pkg/mpc/ecdsa_resharing_session.go b/pkg/mpc/ecdsa_resharing_session.go index cd44fab..4cf3845 100644 --- a/pkg/mpc/ecdsa_resharing_session.go +++ b/pkg/mpc/ecdsa_resharing_session.go @@ -21,11 +21,13 @@ type ReshareSession interface { Init() error Reshare(done func()) GetPubKeyResult() []byte + GetLegacyCommitteePeers() []string } type ecdsaReshareSession struct { *session isNewParty bool + oldPeerIDs []string newPeerIDs []string reshareParams *tss.ReSharingParameters endCh chan *keygen.LocalPartySaveData @@ -50,6 +52,12 @@ func NewECDSAReshareSession( isNewParty bool, version int, ) *ecdsaReshareSession { + + realPartyIDs := oldPartyIDs + if isNewParty { + realPartyIDs = newPartyIDs + } + session := session{ walletID: walletID, pubSub: pubSub, @@ -57,7 +65,7 @@ func NewECDSAReshareSession( threshold: threshold, participantPeerIDs: participantPeerIDs, selfPartyID: selfID, - partyIDs: newPartyIDs, + partyIDs: realPartyIDs, outCh: make(chan tss.Message), ErrCh: make(chan error), preParams: preParams, @@ -68,8 +76,8 @@ func NewECDSAReshareSession( ComposeBroadcastTopic: func() string { return fmt.Sprintf("resharing:broadcast:ecdsa:%s", walletID) }, - ComposeDirectTopic: func(nodeID string) string { - return fmt.Sprintf("resharing:direct:ecdsa:%s:%s", nodeID, walletID) + ComposeDirectTopic: func(fromID string, toID string) string { + return fmt.Sprintf("resharing:direct:ecdsa:%s:%s:%s", fromID, toID, walletID) }, }, composeKey: func(walletID string) string { @@ -90,17 +98,46 @@ func NewECDSAReshareSession( len(newPartyIDs), 
newThreshold, ) + + var oldPeerIDs []string + for _, partyId := range oldPartyIDs { + oldPeerIDs = append(oldPeerIDs, partyIDToNodeID(partyId)) + } + return &ecdsaReshareSession{ session: &session, reshareParams: reshareParams, isNewParty: isNewParty, + oldPeerIDs: oldPeerIDs, newPeerIDs: newPeerIDs, endCh: make(chan *keygen.LocalPartySaveData), } } +// GetLegacyCommitteePeers returns peer IDs that were part of the old committee +// but are NOT part of the new committee after resharing. +// These peers are still relevant during resharing because +// they must send final share data to the new committee. +func (s *ecdsaReshareSession) GetLegacyCommitteePeers() []string { + difference := func(A, B []string) []string { + seen := make(map[string]bool) + for _, b := range B { + seen[b] = true + } + var result []string + for _, a := range A { + if !seen[a] { + result = append(result, a) + } + } + return result + } + + return difference(s.oldPeerIDs, s.newPeerIDs) +} + func (s *ecdsaReshareSession) Init() error { - logger.Infof("Initializing resharing session with partyID: %s, newPartyIDs %s", s.selfPartyID, s.partyIDs) + logger.Infof("Initializing ecdsa resharing session with partyID: %s, newPartyIDs %s", s.selfPartyID, s.partyIDs) var share keygen.LocalPartySaveData if s.isNewParty { diff --git a/pkg/mpc/ecdsa_signing_session.go b/pkg/mpc/ecdsa_signing_session.go index ba45ce3..d8fc3b5 100644 --- a/pkg/mpc/ecdsa_signing_session.go +++ b/pkg/mpc/ecdsa_signing_session.go @@ -72,8 +72,8 @@ func newECDSASigningSession( ComposeBroadcastTopic: func() string { return fmt.Sprintf("sign:ecdsa:broadcast:%s:%s", walletID, txID) }, - ComposeDirectTopic: func(nodeID string) string { - return fmt.Sprintf("sign:ecdsa:direct:%s:%s", nodeID, txID) + ComposeDirectTopic: func(fromID string, toID string) string { + return fmt.Sprintf("sign:ecdsa:direct:%s:%s:%s", fromID, toID, txID) }, }, composeKey: func(waleltID string) string { diff --git a/pkg/mpc/eddsa_keygen_session.go 
b/pkg/mpc/eddsa_keygen_session.go index a4fe030..fd832e3 100644 --- a/pkg/mpc/eddsa_keygen_session.go +++ b/pkg/mpc/eddsa_keygen_session.go @@ -49,8 +49,8 @@ func newEDDSAKeygenSession( ComposeBroadcastTopic: func() string { return fmt.Sprintf("keygen:broadcast:eddsa:%s", walletID) }, - ComposeDirectTopic: func(nodeID string) string { - return fmt.Sprintf("keygen:direct:eddsa:%s:%s", nodeID, walletID) + ComposeDirectTopic: func(fromID string, toID string) string { + return fmt.Sprintf("keygen:direct:eddsa:%s:%s:%s", fromID, toID, walletID) }, }, composeKey: func(waleltID string) string { diff --git a/pkg/mpc/eddsa_resharing_session.go b/pkg/mpc/eddsa_resharing_session.go index 5694624..70a59c2 100644 --- a/pkg/mpc/eddsa_resharing_session.go +++ b/pkg/mpc/eddsa_resharing_session.go @@ -18,6 +18,7 @@ import ( type eddsaReshareSession struct { *session isNewParty bool + oldPeerIDs []string newPeerIDs []string reshareParams *tss.ReSharingParameters endCh chan *keygen.LocalPartySaveData @@ -41,6 +42,12 @@ func NewEDDSAReshareSession( isNewParty bool, version int, ) *eddsaReshareSession { + + realPartyIDs := oldPartyIDs + if isNewParty { + realPartyIDs = newPartyIDs + } + session := session{ walletID: walletID, pubSub: pubSub, @@ -49,7 +56,7 @@ func NewEDDSAReshareSession( version: version, participantPeerIDs: participantPeerIDs, selfPartyID: selfID, - partyIDs: newPartyIDs, + partyIDs: realPartyIDs, outCh: make(chan tss.Message), ErrCh: make(chan error), kvstore: kvstore, @@ -58,8 +65,8 @@ func NewEDDSAReshareSession( ComposeBroadcastTopic: func() string { return fmt.Sprintf("reshare:broadcast:eddsa:%s", walletID) }, - ComposeDirectTopic: func(nodeID string) string { - return fmt.Sprintf("reshare:direct:eddsa:%s:%s", nodeID, walletID) + ComposeDirectTopic: func(fromID string, toID string) string { + return fmt.Sprintf("reshare:direct:eddsa:%s:%s:%s", fromID, toID, walletID) }, }, composeKey: func(walletID string) string { @@ -82,17 +89,45 @@ func NewEDDSAReshareSession( 
newThreshold, ) + var oldPeerIDs []string + for _, partyId := range oldPartyIDs { + oldPeerIDs = append(oldPeerIDs, partyIDToNodeID(partyId)) + } + return &eddsaReshareSession{ session: &session, reshareParams: reshareParams, isNewParty: isNewParty, + oldPeerIDs: oldPeerIDs, newPeerIDs: newPeerIDs, endCh: make(chan *keygen.LocalPartySaveData), } } +// GetLegacyCommitteePeers returns peer IDs that were part of the old committee +// but are NOT part of the new committee after resharing. +// These peers are still relevant during resharing because +// they must send final share data to the new committee. +func (s *eddsaReshareSession) GetLegacyCommitteePeers() []string { + difference := func(A, B []string) []string { + seen := make(map[string]bool) + for _, b := range B { + seen[b] = true + } + var result []string + for _, a := range A { + if !seen[a] { + result = append(result, a) + } + } + return result + } + + return difference(s.oldPeerIDs, s.newPeerIDs) +} + func (s *eddsaReshareSession) Init() error { - logger.Infof("Initializing resharing session with partyID: %s, peerIDs %s", s.selfPartyID, s.partyIDs) + logger.Infof("Initializing eddsa resharing session with partyID: %s, peerIDs %s", s.selfPartyID, s.partyIDs) var share keygen.LocalPartySaveData if s.isNewParty { // Initialize empty share data for new party @@ -104,7 +139,7 @@ func (s *eddsaReshareSession) Init() error { } } s.party = resharing.NewLocalParty(s.reshareParams, share, s.outCh, s.endCh) - logger.Infof("[INITIALIZED] Initialized resharing session successfully partyID: %s, peerIDs %s, walletID %s, oldThreshold = %d, newThreshold = %d", + logger.Infof("[INITIALIZED] Initialized eddsa resharing session successfully partyID: %s, peerIDs %s, walletID %s, oldThreshold = %d, newThreshold = %d", s.selfPartyID, s.partyIDs, s.walletID, s.threshold, s.reshareParams.NewThreshold()) return nil diff --git a/pkg/mpc/eddsa_signing_session.go b/pkg/mpc/eddsa_signing_session.go index fe037b8..d70b242 100644 --- 
a/pkg/mpc/eddsa_signing_session.go +++ b/pkg/mpc/eddsa_signing_session.go @@ -63,8 +63,8 @@ func newEDDSASigningSession( ComposeBroadcastTopic: func() string { return fmt.Sprintf("sign:eddsa:broadcast:%s:%s", walletID, txID) }, - ComposeDirectTopic: func(nodeID string) string { - return fmt.Sprintf("sign:eddsa:direct:%s:%s", nodeID, txID) + ComposeDirectTopic: func(fromID string, toID string) string { + return fmt.Sprintf("sign:eddsa:direct:%s:%s:%s", fromID, toID, txID) }, }, composeKey: func(waleltID string) string { @@ -73,7 +73,7 @@ func newEDDSASigningSession( getRoundFunc: GetEddsaMsgRound, resultQueue: resultQueue, identityStore: identityStore, - idempotentKey: idempotentKey, + idempotentKey: idempotentKey, }, endCh: make(chan *common.SignatureData), txID: txID, diff --git a/pkg/mpc/key_exchange_session.go b/pkg/mpc/key_exchange_session.go new file mode 100644 index 0000000..2065f03 --- /dev/null +++ b/pkg/mpc/key_exchange_session.go @@ -0,0 +1,184 @@ +package mpc + +import ( + "crypto/ecdh" + "crypto/rand" + "crypto/sha256" + + "golang.org/x/crypto/hkdf" + + "fmt" + "time" + + "github.com/fystack/mpcium/pkg/identity" + "github.com/fystack/mpcium/pkg/logger" + "github.com/fystack/mpcium/pkg/messaging" + "github.com/fystack/mpcium/pkg/types" + + "encoding/json" + + "github.com/nats-io/nats.go" +) + +const ( + ECDHExchangeTopic = "ecdh:exchange" + ECDHExchangeTimeout = 2 * time.Minute +) + +type ECDHSession interface { + ListenKeyExchange() error + BroadcastPublicKey() error + RemovePeer(peerID string) + GetReadyPeersCount() int + ErrChan() <-chan error + Close() error +} + +type ecdhSession struct { + nodeID string + peerIDs []string + pubSub messaging.PubSub + ecdhSub messaging.Subscription + identityStore identity.Store + privateKey *ecdh.PrivateKey + publicKey *ecdh.PublicKey + errCh chan error +} + +func NewECDHSession( + nodeID string, + peerIDs []string, + pubSub messaging.PubSub, + identityStore identity.Store, +) *ecdhSession { + return &ecdhSession{ 
+ nodeID: nodeID, + peerIDs: peerIDs, + pubSub: pubSub, + identityStore: identityStore, + errCh: make(chan error, 1), + } +} + +func (e *ecdhSession) RemovePeer(peerID string) { + e.identityStore.RemoveSymmetricKey(peerID) +} + +func (e *ecdhSession) GetReadyPeersCount() int { + return e.identityStore.GetSymetricKeyCount() +} + +func (e *ecdhSession) ErrChan() <-chan error { + return e.errCh +} + +func (e *ecdhSession) ListenKeyExchange() error { + // Generate an ephemeral ECDH key pair + privateKey, err := ecdh.X25519().GenerateKey(rand.Reader) + if err != nil { + return fmt.Errorf("failed to generate ECDH key pair: %w", err) + } + + e.privateKey = privateKey + e.publicKey = privateKey.PublicKey() + + // Subscribe to ECDH broadcast + sub, err := e.pubSub.Subscribe(ECDHExchangeTopic, func(natMsg *nats.Msg) { + var ecdhMsg types.ECDHMessage + if err := json.Unmarshal(natMsg.Data, &ecdhMsg); err != nil { + return + } + + if ecdhMsg.From == e.nodeID { + return + } + + //TODO: consider how to avoid replay attack + if err := e.identityStore.VerifySignature(&ecdhMsg); err != nil { + e.errCh <- err + return + } + + peerPublicKey, err := ecdh.X25519().NewPublicKey(ecdhMsg.PublicKey) + if err != nil { + e.errCh <- err + return + } + sharedSecret, err := e.privateKey.ECDH(peerPublicKey) + if err != nil { + e.errCh <- err + return + } + + // Derive symmetric key using HKDF + symmetricKey := e.deriveSymmetricKey(sharedSecret, ecdhMsg.From) + e.identityStore.SetSymmetricKey(ecdhMsg.From, symmetricKey) + logger.Debug("ECDH progress", "peer", ecdhMsg.From, "current", e.identityStore.GetSymetricKeyCount()) + }) + + e.ecdhSub = sub + if err != nil { + return fmt.Errorf("failed to subscribe to ECDH topic: %w", err) + } + return nil +} + +func (s *ecdhSession) Close() error { + err := s.ecdhSub.Unsubscribe() + if err != nil { + return err + } + + return nil +} + +func (e *ecdhSession) BroadcastPublicKey() error { + publicKeyBytes := e.publicKey.Bytes() + msg := types.ECDHMessage{ + 
From: e.nodeID, + PublicKey: publicKeyBytes, + Timestamp: time.Now(), + } + //Sign the message using existing identity store + signature, err := e.identityStore.SignEcdhMessage(&msg) + if err != nil { + return fmt.Errorf("failed to sign ECDH message: %w", err) + } + msg.Signature = signature + signedMsgBytes, _ := json.Marshal(msg) + + logger.Info("Starting to broadcast DH key", "nodeID", e.nodeID) + if err := e.pubSub.Publish(ECDHExchangeTopic, signedMsgBytes); err != nil { + return fmt.Errorf("%s failed to publish DH message because %w", e.nodeID, err) + } + return nil +} + +func deriveConsistentInfo(a, b string) []byte { + if a < b { + return []byte(a + b) + } + return []byte(b + a) +} + +// derives a symmetric key from the shared secret and peer ID using HKDF. +func (e *ecdhSession) deriveSymmetricKey(sharedSecret []byte, peerID string) []byte { + hash := sha256.New + + // Info parameter can include context-specific data; here we use a pair of party IDs + info := deriveConsistentInfo(e.nodeID, peerID) + + // Salt can be nil or a random value; here we use nil + var salt []byte + + hkdf := hkdf.New(hash, sharedSecret, salt, info) + + // Derive a 32-byte symmetric key (suitable for AES-256) + symmetricKey := make([]byte, 32) + _, err := hkdf.Read(symmetricKey) + if err != nil { + e.errCh <- err + return nil + } + return symmetricKey +} diff --git a/pkg/mpc/node.go b/pkg/mpc/node.go index b60b5b9..d615444 100644 --- a/pkg/mpc/node.go +++ b/pkg/mpc/node.go @@ -1,12 +1,9 @@ package mpc import ( - "bytes" "encoding/json" "fmt" - "math/big" "slices" - "strconv" "time" "github.com/bnb-chain/tss-lib/v2/ecdsa/keygen" @@ -17,7 +14,6 @@ import ( "github.com/fystack/mpcium/pkg/kvstore" "github.com/fystack/mpcium/pkg/logger" "github.com/fystack/mpcium/pkg/messaging" - "github.com/google/uuid" ) const ( @@ -45,18 +41,6 @@ type Node struct { peerRegistry PeerRegistry } -func PartyIDToRoutingDest(partyID *tss.PartyID) string { - return string(partyID.KeyInt().Bytes()) -} - -func 
ComparePartyIDs(x, y *tss.PartyID) bool { - return bytes.Equal(x.KeyInt().Bytes(), y.KeyInt().Bytes()) -} - -func ComposeReadyKey(nodeID string) string { - return fmt.Sprintf("ready/%s", nodeID) -} - func NewNode( nodeID string, peerIDs []string, @@ -83,6 +67,7 @@ func NewNode( } node.ecdsaPreParams = node.generatePreParams() + // Start watching peers - ECDH is now handled by the registry go peerRegistry.WatchPeersReady() return node } @@ -98,11 +83,7 @@ func (p *Node) CreateKeyGenSession( resultQueue messaging.MessageQueue, ) (KeyGenSession, error) { if !p.peerRegistry.ArePeersReady() { - return nil, fmt.Errorf( - "Not enough peers to create gen session! Expected %d, got %d", - p.peerRegistry.GetTotalPeersCount(), - p.peerRegistry.GetReadyPeersCount(), - ) + return nil, errors.New("All nodes are not ready!") } keyInfo, _ := p.getKeyInfo(sessionType, walletID) @@ -417,44 +398,8 @@ func (p *Node) CreateReshareSession( } } -// generatePartyIDs generates the party IDs for the given purpose and version -// It returns the self party ID and all party IDs -// It also sorts the party IDs in place -func (n *Node) generatePartyIDs( - label string, - readyPeerIDs []string, - version int, -) (self *tss.PartyID, all []*tss.PartyID) { - // Pre-allocate slice with exact size needed - partyIDs := make([]*tss.PartyID, 0, len(readyPeerIDs)) - - // Create all party IDs in one pass - for _, peerID := range readyPeerIDs { - partyID := createPartyID(peerID, label, version) - if peerID == n.nodeID { - self = partyID - } - partyIDs = append(partyIDs, partyID) - } - - // Sort party IDs in place - all = tss.SortPartyIDs(partyIDs, 0) - return -} - -// createPartyID creates a new party ID for the given node ID, label and version -// It returns the party ID: random string -// Moniker: for routing messages -// Key: for mpc internal use (need persistent storage) -func createPartyID(nodeID string, label string, version int) *tss.PartyID { - partyID := uuid.NewString() - var key *big.Int - if 
version == BackwardCompatibleVersion { - key = big.NewInt(0).SetBytes([]byte(nodeID)) - } else { - key = big.NewInt(0).SetBytes([]byte(nodeID + ":" + strconv.Itoa(version))) - } - return tss.NewPartyID(partyID, label, key) +func ComposeReadyKey(nodeID string) string { + return fmt.Sprintf("ready/%s", nodeID) } func (p *Node) Close() { diff --git a/pkg/mpc/node_test.go b/pkg/mpc/node_test.go index cfd7294..37792e9 100644 --- a/pkg/mpc/node_test.go +++ b/pkg/mpc/node_test.go @@ -6,12 +6,6 @@ import ( "github.com/stretchr/testify/assert" ) -func TestPartyIDToNodeID(t *testing.T) { - partyID := createPartyID("4d8cb873-dc86-4776-b6f6-cf5c668f6468", "keygen", 1) - nodeID := PartyIDToRoutingDest(partyID) - assert.Equal(t, nodeID, "4d8cb873-dc86-4776-b6f6-cf5c668f6468:1", "NodeID should be equal") -} - func TestCreatePartyID_Structure(t *testing.T) { sessionID := "test-session-123" keyType := "keygen" @@ -46,29 +40,6 @@ func TestCreatePartyID_DifferentVersions(t *testing.T) { assert.NotEqual(t, partyID0.Key, partyID1.Key) } -func TestPartyIDToRoutingDest_BackwardCompatible(t *testing.T) { - sessionID := "test-session-789" - keyType := "signing" - - partyID := createPartyID(sessionID, keyType, BackwardCompatibleVersion) - nodeID := PartyIDToRoutingDest(partyID) - - // For backward compatible version, should just be the sessionID - assert.Equal(t, sessionID, nodeID) -} - -func TestPartyIDToRoutingDest_DefaultVersion(t *testing.T) { - sessionID := "test-session-999" - keyType := "signing" - - partyID := createPartyID(sessionID, keyType, DefaultVersion) - nodeID := PartyIDToRoutingDest(partyID) - - // For default version, should include the version number - expected := sessionID + ":1" - assert.Equal(t, expected, nodeID) -} - func TestCreatePartyID_EmptyValues(t *testing.T) { // Test with empty session ID partyID := createPartyID("", "keygen", 0) @@ -81,22 +52,6 @@ func TestCreatePartyID_EmptyValues(t *testing.T) { assert.Equal(t, "", partyID.Moniker) } -func 
TestPartyIDToRoutingDest_Consistency(t *testing.T) { - sessionID := "consistent-session" - keyType := "keygen" - version := 3 - - // Create the same party ID multiple times - partyID1 := createPartyID(sessionID, keyType, version) - partyID2 := createPartyID(sessionID, keyType, version) - - nodeID1 := PartyIDToRoutingDest(partyID1) - nodeID2 := PartyIDToRoutingDest(partyID2) - - // Should produce consistent results based on sessionID and version - assert.Equal(t, nodeID1, nodeID2, "Same parameters should produce same routing destinations") -} - func TestCreatePartyID_UniqueIDs(t *testing.T) { sessionID := "test-session" keyType := "keygen" diff --git a/pkg/mpc/party_id.go b/pkg/mpc/party_id.go new file mode 100644 index 0000000..e0c6b19 --- /dev/null +++ b/pkg/mpc/party_id.go @@ -0,0 +1,71 @@ +package mpc + +import ( + "bytes" + "fmt" + "math/big" + "strings" + + "github.com/bnb-chain/tss-lib/v2/tss" + "github.com/google/uuid" +) + +// generatePartyIDs generates the party IDs for the given purpose and version +// It returns the self party ID and all party IDs +// It also sorts the party IDs in place +func (n *Node) generatePartyIDs( + label string, + readyPeerIDs []string, + version int, +) (self *tss.PartyID, all []*tss.PartyID) { + // Pre-allocate slice with exact size needed + partyIDs := make([]*tss.PartyID, 0, len(readyPeerIDs)) + + // Create all party IDs in one pass + for _, peerID := range readyPeerIDs { + partyID := createPartyID(peerID, label, version) + if peerID == n.nodeID { + self = partyID + } + partyIDs = append(partyIDs, partyID) + } + + // Sort party IDs in place + all = tss.SortPartyIDs(partyIDs, 0) + return +} + +// createPartyID creates a new party ID for the given node ID, label and version +// It returns the party ID: random string +// Moniker: for routing messages +// Key: for mpc internal use (need persistent storage) +func createPartyID(nodeID string, label string, version int) *tss.PartyID { + partyID := uuid.NewString() + var key *big.Int 
+ if version == BackwardCompatibleVersion { + key = new(big.Int).SetBytes([]byte(nodeID)) + } else { + key = new(big.Int).SetBytes([]byte(fmt.Sprintf("%s:%d", nodeID, version))) + } + return tss.NewPartyID(partyID, label, key) +} + +func partyIDToNodeID(partyID *tss.PartyID) string { + if partyID == nil { + return "" + } + nodeID, _, _ := strings.Cut(string(partyID.KeyInt().Bytes()), ":") + return strings.TrimSpace(nodeID) +} + +func partyIDsToNodeIDs(pids []*tss.PartyID) []string { + out := make([]string, 0, len(pids)) + for _, p := range pids { + out = append(out, partyIDToNodeID(p)) + } + return out +} + +func comparePartyIDs(x, y *tss.PartyID) bool { + return bytes.Equal(x.KeyInt().Bytes(), y.KeyInt().Bytes()) +} diff --git a/pkg/mpc/registry.go b/pkg/mpc/registry.go index 688a306..92c2cf6 100644 --- a/pkg/mpc/registry.go +++ b/pkg/mpc/registry.go @@ -2,14 +2,19 @@ package mpc import ( "fmt" + "strconv" + "strings" "sync" "sync/atomic" "time" + "github.com/fystack/mpcium/pkg/identity" "github.com/fystack/mpcium/pkg/infra" "github.com/fystack/mpcium/pkg/logger" + "github.com/fystack/mpcium/pkg/messaging" "github.com/hashicorp/consul/api" "github.com/samber/lo" + "github.com/spf13/viper" ) const ( @@ -19,12 +24,18 @@ const ( type PeerRegistry interface { Ready() error ArePeersReady() bool + AreMajorityReady() bool WatchPeersReady() // Resign is called by the node when it is going to shutdown Resign() error GetReadyPeersCount() int64 + GetReadyPeersCountExcludeSelf() int64 GetReadyPeersIncludeSelf() []string // get ready peers include self GetTotalPeersCount() int64 + + OnPeerConnected(callback func(peerID string)) + OnPeerDisconnected(callback func(peerID string)) + OnPeerReConnected(callback func(peerID string)) } type registry struct { @@ -35,20 +46,43 @@ type registry struct { mu sync.RWMutex ready bool // ready is true when all peers are ready - consulKV infra.ConsulKV + consulKV infra.ConsulKV + healthCheck messaging.DirectMessaging + pubSub messaging.PubSub 
+ identityStore identity.Store + ecdhSession ECDHSession + mpcThreshold int + + onPeerConnected func(peerID string) + onPeerDisconnected func(peerID string) + onPeerReConnected func(peerID string) } func NewRegistry( nodeID string, peerNodeIDs []string, consulKV infra.ConsulKV, + directMessaging messaging.DirectMessaging, + pubSub messaging.PubSub, + identityStore identity.Store, ) *registry { + ecdhSession := NewECDHSession(nodeID, peerNodeIDs, pubSub, identityStore) + mpcThreshold := viper.GetInt("mpc_threshold") + if mpcThreshold < 1 { + logger.Fatal("mpc_threshold must be greater than 0", nil) + } + return ®istry{ - consulKV: consulKV, - nodeID: nodeID, - peerNodeIDs: getPeerIDsExceptSelf(nodeID, peerNodeIDs), - readyMap: make(map[string]bool), - readyCount: 1, // self + consulKV: consulKV, + nodeID: nodeID, + peerNodeIDs: getPeerIDsExceptSelf(nodeID, peerNodeIDs), + readyMap: make(map[string]bool), + readyCount: 1, // self + healthCheck: directMessaging, + pubSub: pubSub, + identityStore: identityStore, + ecdhSession: ecdhSession, + mpcThreshold: mpcThreshold, } } @@ -72,9 +106,17 @@ func (r *registry) registerReadyPairs(peerIDs []string) { if !exist { atomic.AddInt64(&r.readyCount, 1) logger.Info("Register", "peerID", peerID) + if r.onPeerConnected != nil { + r.onPeerConnected(peerID) + } + go r.triggerECDHExchange() } else if !ready { atomic.AddInt64(&r.readyCount, 1) logger.Info("Reconnecting...", "peerID", peerID) + if r.onPeerReConnected != nil { + r.onPeerReConnected(peerID) + } + go r.triggerECDHExchange() } r.readyMap[peerID] = true @@ -84,14 +126,26 @@ func (r *registry) registerReadyPairs(peerIDs []string) { r.mu.Lock() r.ready = true r.mu.Unlock() - logger.Info("ALL PEERS ARE READY! 
Starting to accept MPC requests") + logger.Info("All peers are ready including ECDH exchange completion") } +} +// triggerECDHExchange safely triggers ECDH key exchange +func (r *registry) triggerECDHExchange() { + logger.Info("Triggering ECDH key exchange") + if err := r.ecdhSession.BroadcastPublicKey(); err != nil { + logger.Error("Failed to trigger ECDH exchange", err) + } } // Ready is called by the node when it complete generate preparams and starting to accept // incoming requests func (r *registry) Ready() error { + // Start ECDH exchange first + if err := r.startECDHExchange(); err != nil { + return fmt.Errorf("failed to start ECDH exchange: %w", err) + } + k := r.readyKey(r.nodeID) kv := &api.KVPair{ @@ -104,12 +158,25 @@ func (r *registry) Ready() error { return fmt.Errorf("Put ready key failed: %w", err) } + _, err = r.healthCheck.Listen(r.composeHealthCheckTopic(r.nodeID), func(data []byte) { + peerID, ecdhReadyPeersCount, _ := parseHealthDataSplit(string(data)) + logger.Debug("Health check ok", "peerID", peerID) + if ecdhReadyPeersCount < int(r.GetReadyPeersCountExcludeSelf()) { + logger.Info("[ECDH exchange retriggerd] not all peers are ready", "peerID", peerID) + go r.triggerECDHExchange() + + } + }) + if err != nil { + return fmt.Errorf("Listen health check failed: %w", err) + } return nil } func (r *registry) WatchPeersReady() { + go r.checkPeersHealth() + ticker := time.NewTicker(ReadinessCheckPeriod) - go r.logReadyStatus() // first tick is executed immediately for ; true; <-ticker.C { pairs, _, err := r.consulKV.List("ready/", nil) @@ -136,6 +203,13 @@ func (r *registry) WatchPeersReady() { logger.Warn("Peer disconnected!", "peerID", peerID) r.readyMap[peerID] = false atomic.AddInt64(&r.readyCount, -1) + + // Remove ECDH key for disconnected peer + r.ecdhSession.RemovePeer(peerID) + + if r.onPeerDisconnected != nil { + r.onPeerDisconnected(peerID) + } } } @@ -146,15 +220,36 @@ func (r *registry) WatchPeersReady() { } -func (r *registry) 
logReadyStatus() { +func (r *registry) checkPeersHealth() { for { time.Sleep(5 * time.Second) if !r.ArePeersReady() { logger.Info("Peers are not ready yet", "ready", r.GetReadyPeersCount(), "expected", len(r.peerNodeIDs)+1) } + + pairs, _, err := r.consulKV.List("ready/", nil) + if err != nil { + logger.Error("List ready keys failed", err) + continue + } + readyPeerIDs := r.getReadyPeersFromKVStore(pairs) + for _, peerID := range readyPeerIDs { + err := r.healthCheck.SendToOtherWithRetry(r.composeHealthCheckTopic(peerID), []byte(r.composeHealthData()), messaging.RetryConfig{ + RetryAttempt: 2, + }) + if err != nil && strings.Contains(err.Error(), "no responders") { + logger.Info("No response from peer", "peerID", peerID) + _, err := r.consulKV.Delete(r.readyKey(peerID), nil) + if err != nil { + logger.Error("Delete ready key failed", err) + } + } + } } } +// GetReadyPeersCount returns the number of ready peers including self +// should -1 if want to exclude self func (r *registry) GetReadyPeersCount() int64 { return atomic.LoadInt64(&r.readyCount) } @@ -193,7 +288,17 @@ func (r *registry) ArePeersReady() bool { r.mu.RLock() defer r.mu.RUnlock() - return r.ready + // Check both peer connectivity and ECDH completion + return r.ready && r.isECDHReady() +} + +// AreMajorityReady checks if a majority of peers are ready. +// Returns true only if: +// 1. The number of ready peers (including self) is greater than mpcThreshold+1 +// 2. Symmetric keys are fully established among all ready peers (excluding self). 
+func (r *registry) AreMajorityReady() bool { + readyCount := r.GetReadyPeersCount() + return int(readyCount) >= r.mpcThreshold+1 && r.isECDHReady() } func (r *registry) GetTotalPeersCount() int64 { @@ -211,3 +316,60 @@ func (r *registry) Resign() error { return nil } + +func (r *registry) OnPeerConnected(callback func(peerID string)) { + r.onPeerConnected = callback +} + +func (r *registry) OnPeerDisconnected(callback func(peerID string)) { + r.onPeerDisconnected = callback +} + +func (r *registry) OnPeerReConnected(callback func(peerID string)) { + r.onPeerReConnected = callback +} + +// StartECDHExchange starts the ECDH key exchange process +func (r *registry) startECDHExchange() error { + if err := r.ecdhSession.ListenKeyExchange(); err != nil { + return fmt.Errorf("failed to start ECDH listener: %w", err) + } + + if err := r.ecdhSession.BroadcastPublicKey(); err != nil { + return fmt.Errorf("failed to broadcast ECDH public key: %w", err) + } + + return nil +} + +func (r *registry) GetReadyPeersCountExcludeSelf() int64 { + return r.GetReadyPeersCount() - 1 +} + +func (r *registry) isECDHReady() bool { + requiredKeyCount := r.GetReadyPeersCountExcludeSelf() + return r.identityStore.CheckSymmetricKeyComplete(int(requiredKeyCount)) +} + +func (r *registry) composeHealthCheckTopic(nodeID string) string { + return fmt.Sprintf("healthcheck:%s", nodeID) +} + +func (r *registry) composeHealthData() string { + return fmt.Sprintf("%s,%d", r.nodeID, r.ecdhSession.GetReadyPeersCount()) +} + +func parseHealthDataSplit(s string) (peerID string, readyCount int, err error) { + parts := strings.SplitN(s, ",", 2) + if len(parts) != 2 { + return "", 0, fmt.Errorf("invalid format: %q", s) + } + + peerID = parts[0] + readyCount, err = strconv.Atoi(parts[1]) + if err != nil { + return "", 0, err + } + return peerID, readyCount, nil + +} diff --git a/pkg/mpc/session.go b/pkg/mpc/session.go index 66a6f8e..b1a76b5 100644 --- a/pkg/mpc/session.go +++ b/pkg/mpc/session.go @@ -3,6 +3,7 @@ 
package mpc import ( "encoding/json" "fmt" + "strings" "sync" "github.com/bnb-chain/tss-lib/v2/ecdsa/keygen" @@ -35,20 +36,22 @@ var ( type TopicComposer struct { ComposeBroadcastTopic func() string - ComposeDirectTopic func(nodeID string) string + ComposeDirectTopic func(fromID string, toID string) string } type KeyComposerFn func(id string) string type Session interface { ListenToIncomingMessageAsync() + ListenToPeersAsync(peerIDs []string) ErrChan() <-chan error } type session struct { - walletID string - pubSub messaging.PubSub - direct messaging.DirectMessaging + walletID string + pubSub messaging.PubSub + direct messaging.DirectMessaging + threshold int participantPeerIDs []string selfPartyID *tss.PartyID @@ -60,11 +63,12 @@ type session struct { version int // preParams is nil for EDDSA session - preParams *keygen.LocalPreParams - kvstore kvstore.KVStore - keyinfoStore keyinfo.Store - broadcastSub messaging.Subscription - directSub messaging.Subscription + preParams *keygen.LocalPreParams + kvstore kvstore.KVStore + keyinfoStore keyinfo.Store + broadcastSub messaging.Subscription + directSubs []messaging.Subscription + resultQueue messaging.MessageQueue identityStore identity.Store @@ -90,6 +94,7 @@ func (s *session) PartyCount() int { return len(s.partyIDs) } +// update: use AEAD encryption for each message so NATs server learns nothing func (s *session) handleTssMessage(keyshare tss.Message) { data, routing, err := keyshare.WireBytes() if err != nil { @@ -98,55 +103,121 @@ func (s *session) handleTssMessage(keyshare tss.Message) { } tssMsg := types.NewTssMessage(s.walletID, data, routing.IsBroadcast, routing.From, routing.To) - signature, err := s.identityStore.SignMessage(&tssMsg) - if err != nil { - s.ErrCh <- fmt.Errorf("failed to sign message: %w", err) - return - } - tssMsg.Signature = signature - msg, err := types.MarshalTssMessage(&tssMsg) - if err != nil { - s.ErrCh <- fmt.Errorf("failed to marshal tss message: %w", err) - return - } + toIDs := 
make([]string, len(routing.To)) for i, id := range routing.To { toIDs[i] = id.String() } - logger.Debug(fmt.Sprintf("%s Sending message", s.sessionType), "from", s.selfPartyID.String(), "to", toIDs, "isBroadcast", routing.IsBroadcast) + logger.Debug( + fmt.Sprintf("%s Sending message", s.sessionType), + "from", + s.selfPartyID.String(), + "to", + toIDs, + "isBroadcast", + routing.IsBroadcast, + ) + + // Broadcast message if routing.IsBroadcast && len(routing.To) == 0 { - err := s.pubSub.Publish(s.topicComposer.ComposeBroadcastTopic(), msg) + signature, err := s.identityStore.SignMessage(&tssMsg) // attach signature + if err != nil { + s.ErrCh <- fmt.Errorf("failed to sign message: %w", err) + return + } + tssMsg.Signature = signature + msg, err := types.MarshalTssMessage(&tssMsg) + if err != nil { + s.ErrCh <- fmt.Errorf("failed to marshal tss message: %w", err) + return + } + + err = s.pubSub.Publish(s.topicComposer.ComposeBroadcastTopic(), msg) if err != nil { s.ErrCh <- err return } } else { + // p2p message + msg, err := types.MarshalTssMessage(&tssMsg) // without signature + if err != nil { + s.ErrCh <- fmt.Errorf("failed to marshal tss message: %w", err) + return + } + + selfID := partyIDToNodeID(s.selfPartyID) for _, to := range routing.To { - nodeID := PartyIDToRoutingDest(to) - topic := s.topicComposer.ComposeDirectTopic(nodeID) - err := s.direct.Send(topic, msg) - if err != nil { - logger.Error("Failed to send direct message to", err, "topic", topic) - s.ErrCh <- fmt.Errorf("Failed to send direct message to %s", topic) + toNodeID := partyIDToNodeID(to) + topic := s.topicComposer.ComposeDirectTopic(selfID, toNodeID) + if selfID == toNodeID { + err := s.direct.SendToSelf(topic, msg) + if err != nil { + logger.Error("Failed in SendToSelf direct message", err, "topic", topic) + s.ErrCh <- fmt.Errorf("failed to send direct message to %s", topic) + } + } else { + cipher, err := s.identityStore.EncryptMessage(msg, toNodeID) + if err != nil { + s.ErrCh <- 
fmt.Errorf("encrypt tss message error %w", err) + logger.Error("Encrypt tss message error", err, "topic", topic) + } + err = s.direct.SendToOther(topic, cipher) + if err != nil { + logger.Error("Failed in SendToOther direct message", err, "topic", topic) + s.ErrCh <- fmt.Errorf("failed to send direct message to %w", err) + } } - } + } +} + +func (s *session) receiveP2PTssMessage(topic string, cipher []byte) { + senderID := extractSenderIDFromDirectTopic(topic) + if senderID == "" { + s.ErrCh <- fmt.Errorf("failed to extract senderID from direct topic: the direct topic format is wrong") + return + } + var plaintext []byte + var err error + + if senderID == partyIDToNodeID(s.selfPartyID) { + plaintext = cipher // to self, no decryption needed + } else { + plaintext, err = s.identityStore.DecryptMessage(cipher, senderID) + if err != nil { + s.ErrCh <- fmt.Errorf("failed to decrypt message: %w, tampered message", err) + return + } } + msg, err := types.UnmarshalTssMessage(plaintext) + if err != nil { + s.ErrCh <- fmt.Errorf("failed to unmarshal message: %w", err) + return + } + + s.receiveTssMessage(msg) } -func (s *session) receiveTssMessage(rawMsg []byte) { +func (s *session) receiveBroadcastTssMessage(rawMsg []byte) { + msg, err := types.UnmarshalTssMessage(rawMsg) if err != nil { - s.ErrCh <- fmt.Errorf("Failed to unmarshal message: %w", err) + s.ErrCh <- fmt.Errorf("failed to unmarshal message: %w", err) return } + err = s.identityStore.VerifyMessage(msg) if err != nil { s.ErrCh <- fmt.Errorf("Failed to verify message: %w, tampered message", err) return } + s.receiveTssMessage(msg) +} + +// update: the logic of receiving message should be modified +func (s *session) receiveTssMessage(msg *types.TssMessage) { toIDs := make([]string, len(msg.To)) for i, id := range msg.To { toIDs[i] = id.String() @@ -157,11 +228,23 @@ func (s *session) receiveTssMessage(rawMsg []byte) { s.ErrCh <- errors.Wrap(err, "Broken TSS Share") return } - logger.Debug("Received message", 
"round", round.RoundMsg, "isBroadcast", msg.IsBroadcast, "to", toIDs, "from", msg.From.String(), "self", s.selfPartyID.String()) + logger.Debug( + "Received message", + "round", + round.RoundMsg, + "isBroadcast", + msg.IsBroadcast, + "to", + toIDs, + "from", + msg.From.String(), + "self", + s.selfPartyID.String(), + ) isBroadcast := msg.IsBroadcast && len(msg.To) == 0 var isToSelf bool for _, to := range msg.To { - if ComparePartyIDs(to, s.selfPartyID) { + if comparePartyIDs(to, s.selfPartyID) { isToSelf = true break } @@ -175,35 +258,56 @@ func (s *session) receiveTssMessage(rawMsg []byte) { logger.Error("Failed to update party", err, "walletID", s.walletID) return } + } +} + +func (s *session) subscribeDirectTopicAsync(topic string) error { + t := topic // avoid capturing the changing loop variable + sub, err := s.direct.Listen(t, func(cipher []byte) { + // async to avoid timeouts in handlers + go s.receiveP2PTssMessage(t, cipher) + }) + if err != nil { + return fmt.Errorf("Failed to subscribe to direct topic %s: %w", t, err) + } + s.directSubs = append(s.directSubs, sub) + return nil +} +func (s *session) subscribeFromPeersAsync(fromIDs []string) { + toID := partyIDToNodeID(s.selfPartyID) + for _, fromID := range fromIDs { + topic := s.topicComposer.ComposeDirectTopic(fromID, toID) + if err := s.subscribeDirectTopicAsync(topic); err != nil { + s.ErrCh <- err + } } } -func (s *session) ListenToIncomingMessageAsync() { +func (s *session) subscribeBroadcastAsync() { go func() { - sub, err := s.pubSub.Subscribe(s.topicComposer.ComposeBroadcastTopic(), func(natMsg *nats.Msg) { - msg := natMsg.Data - s.receiveTssMessage(msg) + topic := s.topicComposer.ComposeBroadcastTopic() + sub, err := s.pubSub.Subscribe(topic, func(natMsg *nats.Msg) { + s.receiveBroadcastTssMessage(natMsg.Data) }) - if err != nil { - s.ErrCh <- fmt.Errorf("Failed to subscribe to broadcast topic %s: %w", s.topicComposer.ComposeBroadcastTopic(), err) + s.ErrCh <- fmt.Errorf("Failed to subscribe to 
broadcast topic %s: %w", topic, err) return } - s.broadcastSub = sub }() +} - nodeID := PartyIDToRoutingDest(s.selfPartyID) - targetID := s.topicComposer.ComposeDirectTopic(nodeID) - sub, err := s.direct.Listen(targetID, func(msg []byte) { - go s.receiveTssMessage(msg) // async for avoid timeout - }) - if err != nil { - s.ErrCh <- fmt.Errorf("Failed to subscribe to direct topic %s: %w", targetID, err) - } - s.directSub = sub +func (s *session) ListenToIncomingMessageAsync() { + // 1) broadcast + s.subscribeBroadcastAsync() + // 2) direct from peers in this session's partyIDs (includes self) + s.subscribeFromPeersAsync(partyIDsToNodeIDs(s.partyIDs)) +} + +func (s *session) ListenToPeersAsync(peerIDs []string) { + s.subscribeFromPeersAsync(peerIDs) } func (s *session) Close() error { @@ -211,10 +315,14 @@ func (s *session) Close() error { if err != nil { return err } - err = s.directSub.Unsubscribe() - if err != nil { - return err + + for _, sub := range s.directSubs { + err = sub.Unsubscribe() + if err != nil { + return err + } } + return nil } @@ -273,3 +381,13 @@ func walletIDWithVersion(walletID string, version int) string { } return walletID } + +func extractSenderIDFromDirectTopic(topic string) string { + // E.g: keygen:direct:ecdsa::: + parts := strings.SplitN(topic, ":", 5) + if len(parts) >= 4 { + return parts[3] + } + + return "" +} diff --git a/pkg/types/ecdh.go b/pkg/types/ecdh.go new file mode 100644 index 0000000..cfc4295 --- /dev/null +++ b/pkg/types/ecdh.go @@ -0,0 +1,26 @@ +package types + +import ( + "encoding/json" + "time" +) + +type ECDHMessage struct { + From string `json:"from"` + PublicKey []byte `json:"public_key"` + Timestamp time.Time `json:"timestamp"` + Signature []byte `json:"signature"` +} + +// MarshalForSigning returns the deterministic JSON bytes for signing +func (msg *ECDHMessage) MarshalForSigning() ([]byte, error) { + // Create a map with ordered keys + signingData := map[string]interface{}{ + "from": msg.From, + "publicKey": 
msg.PublicKey, + "timestamp": msg.Timestamp, + } + + // Use json.Marshal with sorted keys + return json.Marshal(signingData) +} diff --git a/pkg/types/tss.go b/pkg/types/tss.go index c2d54df..ed574d7 100644 --- a/pkg/types/tss.go +++ b/pkg/types/tss.go @@ -1,10 +1,7 @@ -// The Licensed Work is (c) 2022 Sygma -// SPDX-License-Identifier: LGPL-3.0-only package types import ( "encoding/json" - "sort" "github.com/bnb-chain/tss-lib/v2/tss" ) @@ -126,13 +123,3 @@ func (msg *TssMessage) MarshalForSigning() ([]byte, error) { // Use json.Marshal with sorted keys return json.Marshal(signingData) } - -// Helper function to get sorted party IDs -func getPartyIDs(parties []*tss.PartyID) []string { - ids := make([]string, len(parties)) - for i, party := range parties { - ids[i] = party.Id - } - sort.Strings(ids) // Ensure deterministic order - return ids -} diff --git a/pkg/types/tss_test.go b/pkg/types/tss_test.go index 702ac07..9fa7fc6 100644 --- a/pkg/types/tss_test.go +++ b/pkg/types/tss_test.go @@ -170,27 +170,3 @@ func TestUnmarshalStartMessage_InvalidJSON(t *testing.T) { _, err := UnmarshalStartMessage(invalidJSON) assert.Error(t, err) } - -func TestGetPartyIDs(t *testing.T) { - parties := []*tss.PartyID{ - { - MessageWrapper_PartyID: &tss.MessageWrapper_PartyID{ - Id: "party3", - }, - }, - { - MessageWrapper_PartyID: &tss.MessageWrapper_PartyID{ - Id: "party1", - }, - }, - { - MessageWrapper_PartyID: &tss.MessageWrapper_PartyID{ - Id: "party2", - }, - }, - } - - ids := getPartyIDs(parties) - expected := []string{"party1", "party2", "party3"} - assert.Equal(t, expected, ids) -}