diff --git a/cmd/mpcium/main.go b/cmd/mpcium/main.go index cff7b70..11808e2 100644 --- a/cmd/mpcium/main.go +++ b/cmd/mpcium/main.go @@ -7,6 +7,7 @@ import ( "os/signal" "path/filepath" "syscall" + "time" "github.com/fystack/mpcium/pkg/config" "github.com/fystack/mpcium/pkg/constant" @@ -130,14 +131,17 @@ func runNode(ctx context.Context, c *cli.Command) error { directMessaging := messaging.NewNatsDirectMessaging(natsConn) mqManager := messaging.NewNATsMessageQueueManager("mpc", []string{ - "mpc.mpc_keygen_success.*", + "mpc.mpc_keygen_result.*", event.SigningResultTopic, + "mpc.mpc_reshare_result.*", }, natsConn) - genKeySuccessQueue := mqManager.NewMessageQueue("mpc_keygen_success") - defer genKeySuccessQueue.Close() - singingResultQueue := mqManager.NewMessageQueue("signing_result") + genKeyResultQueue := mqManager.NewMessageQueue("mpc_keygen_result") + defer genKeyResultQueue.Close() + singingResultQueue := mqManager.NewMessageQueue("mpc_signing_result") defer singingResultQueue.Close() + reshareResultQueue := mqManager.NewMessageQueue("mpc_reshare_result") + defer reshareResultQueue.Close() logger.Info("Node is running", "peerID", nodeID, "name", nodeName) @@ -159,8 +163,9 @@ func runNode(ctx context.Context, c *cli.Command) error { eventConsumer := eventconsumer.NewEventConsumer( mpcNode, pubsub, - genKeySuccessQueue, + genKeyResultQueue, singingResultQueue, + reshareResultQueue, identityStore, ) eventConsumer.Run() @@ -173,7 +178,7 @@ func runNode(ctx context.Context, c *cli.Command) error { timeoutConsumer.Run() defer timeoutConsumer.Close() - signingConsumer := eventconsumer.NewSigningConsumer(natsConn, signingStream, pubsub) + signingConsumer := eventconsumer.NewSigningConsumer(natsConn, signingStream, pubsub, peerRegistry) // Make the node ready before starting the signing consumer peerRegistry.Ready() @@ -340,16 +345,32 @@ func NewBadgerKV(nodeName string) *kvstore.BadgerKVStore { } func GetNATSConnection(environment string) (*nats.Conn, error) { - if environment != constant.EnvProduction { - return nats.Connect(viper.GetString("nats.url")) + url := viper.GetString("nats.url") + opts := []nats.Option{ + nats.MaxReconnects(-1), // retry forever + nats.ReconnectWait(2 * time.Second), + nats.DisconnectHandler(func(nc *nats.Conn) { + logger.Warn("Disconnected from NATS") + }), + nats.ReconnectHandler(func(nc *nats.Conn) { + logger.Info("Reconnected to NATS", "url", nc.ConnectedUrl()) + }), + nats.ClosedHandler(func(nc *nats.Conn) { + logger.Info("NATS connection closed!") + }), } - clientCert := filepath.Join(".", "certs", "client-cert.pem") - clientKey := filepath.Join(".", "certs", "client-key.pem") - caCert := filepath.Join(".", "certs", "rootCA.pem") - - return nats.Connect(viper.GetString("nats.url"), - nats.ClientCert(clientCert, clientKey), - nats.RootCAs(caCert), - nats.UserInfo(viper.GetString("nats.username"), viper.GetString("nats.password")), - ) + + if environment == constant.EnvProduction { + clientCert := filepath.Join(".", "certs", "client-cert.pem") + clientKey := filepath.Join(".", "certs", "client-key.pem") + caCert := filepath.Join(".", "certs", "rootCA.pem") + + opts = append(opts, + nats.ClientCert(clientCert, clientKey), + nats.RootCAs(caCert), + nats.UserInfo(viper.GetString("nats.username"), viper.GetString("nats.password")), + ) + } + + return nats.Connect(url, opts...) } diff --git a/examples/generate/main.go b/examples/generate/main.go index 9511418..343a418 100644 --- a/examples/generate/main.go +++ b/examples/generate/main.go @@ -60,7 +60,7 @@ func main() { } // STEP 2: Register the result handler AFTER all walletIDs are stored - err = mpcClient.OnWalletCreationResult(func(event event.KeygenSuccessEvent) { + err = mpcClient.OnWalletCreationResult(func(event event.KeygenResultEvent) { now := time.Now() startTimeAny, ok := walletStartTimes.Load(event.WalletID) if ok { diff --git a/examples/reshare/main.go b/examples/reshare/main.go new file mode 100644 index 0000000..880868e --- /dev/null +++ b/examples/reshare/main.go @@ -0,0 +1,69 @@ +package main + +import ( + "fmt" + "os" + "os/signal" + "syscall" + + "github.com/fystack/mpcium/pkg/client" + "github.com/fystack/mpcium/pkg/config" + "github.com/fystack/mpcium/pkg/event" + "github.com/fystack/mpcium/pkg/logger" + "github.com/fystack/mpcium/pkg/types" + "github.com/google/uuid" + "github.com/nats-io/nats.go" + "github.com/spf13/viper" +) + +func main() { + const environment = "dev" + config.InitViperConfig() + logger.Init(environment, true) + + natsURL := viper.GetString("nats.url") + natsConn, err := nats.Connect(natsURL) + if err != nil { + logger.Fatal("Failed to connect to NATS", err) + } + defer natsConn.Drain() + defer natsConn.Close() + + mpcClient := client.NewMPCClient(client.Options{ + NatsConn: natsConn, + KeyPath: "./event_initiator.key", + }) + + // 3) Listen for signing results + err = mpcClient.OnResharingResult(func(evt event.ResharingResultEvent) { + logger.Info("Resharing result received", + "walletID", evt.WalletID, + "pubKey", fmt.Sprintf("%x", evt.PubKey), + "newThreshold", evt.NewThreshold, + "keyType", evt.KeyType, + ) + }) + if err != nil { + logger.Fatal("Failed to subscribe to OnResharingResult", err) + } + + resharingMsg := &types.ResharingMessage{ + SessionID: uuid.NewString(), + WalletID: "bf2cc849-8e55-47e4-ab73-e17fb1eb690c", + NodeIDs: []string{"d926fa75-72c7-4538-9052-4a064a84981d", "7b1090cd-ffe3-46ff-8375-594dd3204169"}, // new peer IDs + + NewThreshold: 2, // t+1 <= len(NodeIDs) + KeyType: types.KeyTypeEd25519, + } + err = mpcClient.Resharing(resharingMsg) + if err != nil { + logger.Fatal("Resharing failed", err) + } + fmt.Printf("Resharing(%q) sent, awaiting result...\n", resharingMsg.WalletID) + + stop := make(chan os.Signal, 1) + signal.Notify(stop, syscall.SIGINT, syscall.SIGTERM) + <-stop + + fmt.Println("Shutting down.") +} diff --git a/go.mod b/go.mod index fb45e42..2d96a3e 100644 --- a/go.mod +++ b/go.mod @@ -19,7 +19,6 @@ require ( github.com/spf13/viper v1.18.0 github.com/stretchr/testify v1.10.0 github.com/urfave/cli/v3 v3.3.2 - go.uber.org/mock v0.3.0 golang.org/x/term v0.31.0 ) @@ -68,7 +67,7 @@ require ( github.com/pelletier/go-toml/v2 v2.1.0 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect - github.com/rogpeppe/go-internal v1.13.1 // indirect + github.com/rogpeppe/go-internal v1.14.1 // indirect github.com/sagikazarmark/locafero v0.4.0 // indirect github.com/sagikazarmark/slog-shim v0.1.0 // indirect github.com/sourcegraph/conc v0.3.0 // indirect diff --git a/go.sum b/go.sum index 3c448e3..df1895b 100644 --- a/go.sum +++ b/go.sum @@ -310,6 +310,8 @@ github.com/prometheus/procfs v0.0.8/go.mod h1:7Qr8sr6344vo1JqZ6HhLceV9o3AJ1Ff+Gx github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= +github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= +github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= github.com/rs/zerolog v1.31.0 h1:FcTR3NnLWW+NnTwwhFWiJSZr4ECLpqCm6QsEnyvbV4A= github.com/rs/zerolog v1.31.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= @@ -373,8 +375,6 @@ go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/goleak v1.1.11/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= -go.uber.org/mock v0.3.0 h1:3mUxI1No2/60yUYax92Pt8eNOEecx2D3lcXZh2NEZJo= -go.uber.org/mock v0.3.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= go.uber.org/multierr v1.5.0/go.mod h1:FeouvMocqHpRaaGuG9EjoKcStLC43Zu/fmqdUMPcKYU= go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI= diff --git a/pkg/client/client.go b/pkg/client/client.go index bccd748..fe61bb4 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -20,23 +20,28 @@ import ( ) const ( - GenerateWalletSuccessTopic = "mpc.mpc_keygen_success.*" // wildcard to listen to all success events + GenerateWalletSuccessTopic = "mpc.mpc_keygen_result.*" // wildcard to listen to all success events + ResharingSuccessTopic = "mpc.mpc_reshare_result.*" // wildcard to listen to all success events ) type MPCClient interface { CreateWallet(walletID string) error - OnWalletCreationResult(callback func(event event.KeygenSuccessEvent)) error + OnWalletCreationResult(callback func(event event.KeygenResultEvent)) error SignTransaction(msg *types.SignTxMessage) error OnSignResult(callback func(event event.SigningResultEvent)) error + + Resharing(msg *types.ResharingMessage) error + OnResharingResult(callback func(event event.ResharingResultEvent)) error } type mpcClient struct { - signingStream messaging.StreamPubsub - pubsub messaging.PubSub - genKeySuccessQueue messaging.MessageQueue - signResultQueue messaging.MessageQueue - privKey ed25519.PrivateKey + signingStream messaging.StreamPubsub + pubsub messaging.PubSub + genKeySuccessQueue messaging.MessageQueue + signResultQueue messaging.MessageQueue + reshareSuccessQueue messaging.MessageQueue + privKey ed25519.PrivateKey } // Options defines configuration options for creating a new MPCClient @@ -120,19 +125,22 @@ func NewMPCClient(opts Options) MPCClient { pubsub := messaging.NewNATSPubSub(opts.NatsConn) manager := messaging.NewNATsMessageQueueManager("mpc", []string{ - "mpc.mpc_keygen_success.*", - "mpc.signing_result.*", + "mpc.mpc_keygen_result.*", + "mpc.mpc_signing_result.*", + "mpc.mpc_reshare_result.*", }, opts.NatsConn) - genKeySuccessQueue := manager.NewMessageQueue("mpc_keygen_success") - signResultQueue := manager.NewMessageQueue("signing_result") + genKeySuccessQueue := manager.NewMessageQueue("mpc_keygen_result") + signResultQueue := manager.NewMessageQueue("mpc_signing_result") + reshareSuccessQueue := manager.NewMessageQueue("mpc_reshare_result") return &mpcClient{ - signingStream: signingStream, - pubsub: pubsub, - genKeySuccessQueue: genKeySuccessQueue, - signResultQueue: signResultQueue, - privKey: priv, + signingStream: signingStream, + pubsub: pubsub, + genKeySuccessQueue: genKeySuccessQueue, + signResultQueue: signResultQueue, + reshareSuccessQueue: reshareSuccessQueue, + privKey: priv, } } @@ -185,9 +193,9 @@ func (c *mpcClient) CreateWallet(walletID string) error { } // The callback will be invoked whenever a wallet creation result is received. -func (c *mpcClient) OnWalletCreationResult(callback func(event event.KeygenSuccessEvent)) error { +func (c *mpcClient) OnWalletCreationResult(callback func(event event.KeygenResultEvent)) error { err := c.genKeySuccessQueue.Dequeue(GenerateWalletSuccessTopic, func(msg []byte) error { - var event event.KeygenSuccessEvent + var event event.KeygenResultEvent err := json.Unmarshal(msg, &event) if err != nil { return err @@ -241,3 +249,45 @@ func (c *mpcClient) OnSignResult(callback func(event event.SigningResultEvent)) return nil } + +func (c *mpcClient) Resharing(msg *types.ResharingMessage) error { + // compute the canonical raw bytes + raw, err := msg.Raw() + if err != nil { + return fmt.Errorf("Resharing: raw payload error: %w", err) + } + // sign + msg.Signature = ed25519.Sign(c.privKey, raw) + + bytes, err := json.Marshal(msg) + if err != nil { + return fmt.Errorf("Resharing: marshal error: %w", err) + } + + if err := c.pubsub.Publish(eventconsumer.MPCReshareEvent, bytes); err != nil { + return fmt.Errorf("Resharing: publish error: %w", err) + } + return nil +} + +func (c *mpcClient) OnResharingResult(callback func(event event.ResharingResultEvent)) error { + + err := c.reshareSuccessQueue.Dequeue(ResharingSuccessTopic, func(msg []byte) error { + logger.Info("Received reshare success message", "raw", string(msg)) + var event event.ResharingResultEvent + err := json.Unmarshal(msg, &event) + if err != nil { + logger.Error("Failed to unmarshal reshare success event", err, "raw", string(msg)) + return err + } + logger.Info("Deserialized reshare success event", "event", event) + callback(event) + return nil + }) + + if err != nil { + return fmt.Errorf("OnResharingResult: subscribe error: %w", err) + } + + return nil +} diff --git a/pkg/event/generate.go b/pkg/event/generate.go deleted file mode 100644 index eba2bcc..0000000 --- a/pkg/event/generate.go +++ /dev/null @@ -1,7 +0,0 @@ -package event - -type KeygenSuccessEvent struct { - WalletID string `json:"wallet_id"` - ECDSAPubKey []byte `json:"ecdsa_pub_key"` - EDDSAPubKey []byte `json:"eddsa_pub_key"` -} diff --git a/pkg/event/keygen.go b/pkg/event/keygen.go new file mode 100644 index 0000000..9aa86da --- /dev/null +++ b/pkg/event/keygen.go @@ -0,0 +1,11 @@ +package event + +type KeygenResultEvent struct { + WalletID string `json:"wallet_id"` + ECDSAPubKey []byte `json:"ecdsa_pub_key"` + EDDSAPubKey []byte `json:"eddsa_pub_key"` + + ResultType ResultType `json:"result_type"` + ErrorReason string `json:"error_reason"` + ErrorCode string `json:"error_code"` +} diff --git a/pkg/event/reshare.go b/pkg/event/reshare.go new file mode 100644 index 0000000..3615388 --- /dev/null +++ b/pkg/event/reshare.go @@ -0,0 +1,14 @@ +package event + +import "github.com/fystack/mpcium/pkg/types" + +type ResharingResultEvent struct { + WalletID string `json:"wallet_id"` + NewThreshold int `json:"new_threshold"` + KeyType types.KeyType `json:"key_type"` + PubKey []byte `json:"pub_key"` + + ResultType ResultType `json:"result_type"` + ErrorReason string `json:"error_reason"` + ErrorCode string `json:"error_code"` +} diff --git a/pkg/event/sign.go b/pkg/event/sign.go index cb8d53d..4a376c4 100644 --- a/pkg/event/sign.go +++ b/pkg/event/sign.go @@ -4,30 +4,23 @@ const ( SigningPublisherStream = "mpc-signing" SigningConsumerStream = "mpc-signing-consumer" SigningRequestTopic = "mpc.signing_request.*" - SigningResultTopic = "mpc.signing_result.*" - SigningResultCompleteTopic = "mpc.signing_result.complete" + SigningResultTopic = "mpc.mpc_signing_result.*" + SigningResultCompleteTopic = "mpc.mpc_signing_result.complete" MPCSigningEventTopic = "mpc:sign" SigningRequestEventTopic = "mpc.signing_request.event" ) -type SigningResultType int - -const ( - SigningResultTypeUnknown SigningResultType = iota - SigningResultTypeSuccess - SigningResultTypeError -) - type SigningResultEvent struct { - ResultType SigningResultType `json:"result_type"` - ErrorReason string `json:"error_reason"` - IsTimeout bool `json:"is_timeout"` - NetworkInternalCode string `json:"network_internal_code"` - WalletID string `json:"wallet_id"` - TxID string `json:"tx_id"` - R []byte `json:"r"` - S []byte `json:"s"` - SignatureRecovery []byte `json:"signature_recovery"` + ResultType ResultType `json:"result_type"` + ErrorCode ErrorCode `json:"error_code"` + ErrorReason string `json:"error_reason"` + IsTimeout bool `json:"is_timeout"` + NetworkInternalCode string `json:"network_internal_code"` + WalletID string `json:"wallet_id"` + TxID string `json:"tx_id"` + R []byte `json:"r"` + S []byte `json:"s"` + SignatureRecovery []byte `json:"signature_recovery"` // TODO: define two separate events for eddsa and ecdsa Signature []byte `json:"signature"` @@ -46,9 +39,10 @@ type SigningResultSuccessEvent struct { } type SigningResultErrorEvent struct { - NetworkInternalCode string `json:"network_internal_code"` - WalletID string `json:"wallet_id"` - TxID string `json:"tx_id"` - ErrorReason string `json:"error_reason"` - IsTimeout bool `json:"is_timeout"` + NetworkInternalCode string `json:"network_internal_code"` + WalletID string `json:"wallet_id"` + TxID string `json:"tx_id"` + ErrorCode ErrorCode `json:"error_code"` + ErrorReason string `json:"error_reason"` + IsTimeout bool `json:"is_timeout"` } diff --git a/pkg/event/types.go b/pkg/event/types.go new file mode 100644 index 0000000..c24c28a --- /dev/null +++ b/pkg/event/types.go @@ -0,0 +1,151 @@ +package event + +import "strings" + +type ResultType string + +const ( + ResultTypeSuccess ResultType = "success" + ResultTypeError ResultType = "error" +) + +// ErrorCode defines specific error types that can occur in MPC operations +type ErrorCode string + +const ( + // Generic/Unknown errors + ErrorCodeUnknown ErrorCode = "ERROR_UNKNOWN" + + // Network and connectivity errors + ErrorCodeNetworkTimeout ErrorCode = "ERROR_NETWORK_TIMEOUT" + ErrorCodeNetworkConnection ErrorCode = "ERROR_NETWORK_CONNECTION" + ErrorCodeNetworkSubscription ErrorCode = "ERROR_NETWORK_SUBSCRIPTION" + ErrorCodeMessageRouting ErrorCode = "ERROR_MESSAGE_ROUTING" + ErrorCodeDirectMessaging ErrorCode = "ERROR_DIRECT_MESSAGING" + + // Session errors + ErrorCodeSessionTimeout ErrorCode = "ERROR_SESSION_TIMEOUT" + ErrorCodeSessionCreation ErrorCode = "ERROR_SESSION_CREATION" + ErrorCodeSessionInitialization ErrorCode = "ERROR_SESSION_INITIALIZATION" + ErrorCodeSessionCleanup ErrorCode = "ERROR_SESSION_CLEANUP" + ErrorCodeSessionDuplicate ErrorCode = "ERROR_SESSION_DUPLICATE" + ErrorCodeSessionStale ErrorCode = "ERROR_SESSION_STALE" + + // Participant and peer errors + ErrorCodeInsufficientParticipants ErrorCode = "ERROR_INSUFFICIENT_PARTICIPANTS" + ErrorCodeIncompatiblePeerIDs ErrorCode = "ERROR_INCOMPATIBLE_PEER_IDS" + ErrorCodePeerNotReady ErrorCode = "ERROR_PEER_NOT_READY" + ErrorCodePeerUnavailable ErrorCode = "ERROR_PEER_UNAVAILABLE" + ErrorCodeParticipantNotFound ErrorCode = "ERROR_PARTICIPANT_NOT_FOUND" + + // Key management errors + ErrorCodeKeyNotFound ErrorCode = "ERROR_KEY_NOT_FOUND" + ErrorCodeKeyAlreadyExists ErrorCode = "ERROR_KEY_ALREADY_EXISTS" + ErrorCodeKeyGeneration ErrorCode = "ERROR_KEY_GENERATION" + ErrorCodeKeySave ErrorCode = "ERROR_KEY_SAVE" + ErrorCodeKeyLoad ErrorCode = "ERROR_KEY_LOAD" + ErrorCodeKeyInfoSave ErrorCode = "ERROR_KEY_INFO_SAVE" + ErrorCodeKeyInfoLoad ErrorCode = "ERROR_KEY_INFO_LOAD" + ErrorCodeKeyEncoding ErrorCode = "ERROR_KEY_ENCODING" + ErrorCodeKeyDecoding ErrorCode = "ERROR_KEY_DECODING" + ErrorCodeMsgValidation ErrorCode = "ERROR_MSG_VALIDATION" + + // Cryptographic operation errors + ErrorCodeSignatureGeneration ErrorCode = "ERROR_SIGNATURE_GENERATION" + ErrorCodeSignatureVerification ErrorCode = "ERROR_SIGNATURE_VERIFICATION" + ErrorCodePreParamsGeneration ErrorCode = "ERROR_PRE_PARAMS_GENERATION" + ErrorCodeTSSPartyCreation ErrorCode = "ERROR_TSS_PARTY_CREATION" + + // Data serialization errors + ErrorCodeMarshalFailure ErrorCode = "ERROR_MARSHAL_FAILURE" + ErrorCodeUnmarshalFailure ErrorCode = "ERROR_UNMARSHAL_FAILURE" + ErrorCodeDataCorruption ErrorCode = "ERROR_DATA_CORRUPTION" + + // Storage errors + ErrorCodeStorageRead ErrorCode = "ERROR_STORAGE_READ" + ErrorCodeStorageWrite ErrorCode = "ERROR_STORAGE_WRITE" + ErrorCodeStorageInit ErrorCode = "ERROR_STORAGE_INIT" + + // Message and verification errors + ErrorCodeMessageVerification ErrorCode = "ERROR_MESSAGE_VERIFICATION" + ErrorCodeMessageFormat ErrorCode = "ERROR_MESSAGE_FORMAT" + ErrorCodeMessageDelivery ErrorCode = "ERROR_MESSAGE_DELIVERY" + ErrorCodeMaxDeliveryAttempts ErrorCode = "ERROR_MAX_DELIVERY_ATTEMPTS" + + // Configuration errors + ErrorCodeInvalidConfiguration ErrorCode = "ERROR_INVALID_CONFIGURATION" + ErrorCodeInvalidThreshold ErrorCode = "ERROR_INVALID_THRESHOLD" + ErrorCodeInvalidSessionType ErrorCode = "ERROR_INVALID_SESSION_TYPE" + + // Resource errors + ErrorCodeResourceExhausted ErrorCode = "ERROR_RESOURCE_EXHAUSTED" + ErrorCodeMemoryAllocation ErrorCode = "ERROR_MEMORY_ALLOCATION" + ErrorCodeConcurrencyLimit ErrorCode = "ERROR_CONCURRENCY_LIMIT" + + // Operation-specific errors + ErrorCodeKeygenFailure ErrorCode = "ERROR_KEYGEN_FAILURE" + ErrorCodeSigningFailure ErrorCode = "ERROR_SIGNING_FAILURE" + ErrorCodeReshareFailure ErrorCode = "ERROR_RESHARE_FAILURE" + + // Context and cancellation errors + ErrorCodeContextCancelled ErrorCode = "ERROR_CONTEXT_CANCELLED" + ErrorCodeOperationAborted ErrorCode = "ERROR_OPERATION_ABORTED" +) + +// GetErrorCodeFromError attempts to categorize a generic error into a specific error code +func GetErrorCodeFromError(err error) ErrorCode { + if err == nil { + return "" + } + + errStr := err.Error() + + // Check for specific error patterns + switch { + case contains(errStr, "validation"): + return ErrorCodeMsgValidation + case contains(errStr, "timeout", "timed out"): + return ErrorCodeNetworkTimeout + case contains(errStr, "connection", "connect"): + return ErrorCodeNetworkConnection + case contains(errStr, "send"): + return ErrorCodePeerUnavailable + case contains(errStr, "not enough", "insufficient"): + return ErrorCodeInsufficientParticipants + case contains(errStr, "incompatible"): + return ErrorCodeIncompatiblePeerIDs + case contains(errStr, "key not found", "no such key"): + return ErrorCodeKeyNotFound + case contains(errStr, "exists"): + return ErrorCodeKeyAlreadyExists + case contains(errStr, "marshal"): + return ErrorCodeMarshalFailure + case contains(errStr, "unmarshal"): + return ErrorCodeUnmarshalFailure + case contains(errStr, "storage", "kvstore"): + return ErrorCodeStorageRead + case contains(errStr, "save", "put"): + return ErrorCodeStorageWrite + case contains(errStr, "session"): + return ErrorCodeSessionCreation + case contains(errStr, "verify", "verification"): + return ErrorCodeMessageVerification + case contains(errStr, "delivery", "deliver"): + return ErrorCodeMessageDelivery + case contains(errStr, "context", "cancelled"): + return ErrorCodeContextCancelled + default: + return ErrorCodeUnknown + } +} + +// Helper function for case-insensitive string matching +func contains(str string, patterns ...string) bool { + str = strings.ToLower(str) + for _, pattern := range patterns { + if strings.Contains(str, strings.ToLower(pattern)) { + return true + } + } + return false +} diff --git a/pkg/eventconsumer/event_consumer.go b/pkg/eventconsumer/event_consumer.go index c42be22..085699e 100644 --- a/pkg/eventconsumer/event_consumer.go +++ b/pkg/eventconsumer/event_consumer.go @@ -23,9 +23,12 @@ import ( const ( MPCGenerateEvent = "mpc:generate" MPCSignEvent = "mpc:sign" + MPCReshareEvent = "mpc:reshare" - DefaultConcurrentKeygen = 2 - DefaultKeyGenStartupDelayMs = 500 + DefaultConcurrentKeygen = 2 + DefaultSessionStartupDelay = 500 + + KeyGenTimeOut = 30 * time.Second ) type EventConsumer interface { @@ -38,11 +41,13 @@ type eventConsumer struct { pubsub messaging.PubSub mpcThreshold int - genKeySucecssQueue messaging.MessageQueue + genKeyResultQueue messaging.MessageQueue signingResultQueue messaging.MessageQueue + reshareResultQueue messaging.MessageQueue keyGenerationSub messaging.Subscription signingSub messaging.Subscription + reshareSub messaging.Subscription identityStore identity.Store msgBuffer chan *nats.Msg @@ -59,8 +64,9 @@ type eventConsumer struct { func NewEventConsumer( node *mpc.Node, pubsub messaging.PubSub, - genKeySucecssQueue messaging.MessageQueue, + genKeyResultQueue messaging.MessageQueue, signingResultQueue messaging.MessageQueue, + reshareResultQueue messaging.MessageQueue, identityStore identity.Store, ) EventConsumer { maxConcurrentKeygen := viper.GetInt("max_concurrent_keygen") @@ -71,8 +77,9 @@ func NewEventConsumer( ec := &eventConsumer{ node: node, pubsub: pubsub, - genKeySucecssQueue: genKeySucecssQueue, + genKeyResultQueue: genKeyResultQueue, signingResultQueue: signingResultQueue, + reshareResultQueue: reshareResultQueue, activeSessions: make(map[string]time.Time), cleanupInterval: 5 * time.Minute, // Run cleanup every 5 minutes sessionTimeout: 30 * time.Minute, // Consider sessions older than 30 minutes stale @@ -101,117 +108,180 @@ func (ec *eventConsumer) Run() { log.Fatal("Failed to consume tx signing event", err) } + err = ec.consumeReshareEvent() + if err != nil { + log.Fatal("Failed to consume reshare event", err) + } + logger.Info("MPC Event consumer started...!") } func (ec *eventConsumer) handleKeyGenEvent(natMsg *nats.Msg) { + baseCtx, baseCancel := context.WithTimeout(context.Background(), KeyGenTimeOut) + defer baseCancel() + raw := natMsg.Data var msg types.GenerateKeyMessage - err := json.Unmarshal(raw, &msg) - if err != nil { - logger.Error("Failed to unmarshal signing message", err) + if err := json.Unmarshal(raw, &msg); err != nil { + logger.Error("Failed to unmarshal keygen message", err) + ec.handleKeygenSessionError(msg.WalletID, err, "Failed to unmarshal keygen message") return } - err = ec.identityStore.VerifyInitiatorMessage(&msg) - if err != nil { + if err := ec.identityStore.VerifyInitiatorMessage(&msg); err != nil { logger.Error("Failed to verify initiator message", err) + ec.handleKeygenSessionError(msg.WalletID, err, "Failed to verify initiator message") return } walletID := msg.WalletID - ecdsaSession, err := ec.node.CreateKeyGenSession(mpc.SessionTypeECDSA, walletID, ec.mpcThreshold, ec.genKeySucecssQueue) + ecdsaSession, err := ec.node.CreateKeyGenSession(mpc.SessionTypeECDSA, walletID, ec.mpcThreshold, ec.genKeyResultQueue) if err != nil { - logger.Error("Failed to create key generation session", err, "walletID", walletID) + logger.Error("Failed to create ECDSA key generation session", err, "walletID", walletID) + ec.handleKeygenSessionError(walletID, err, "Failed to create ECDSA key generation session") return } - eddsaSession, err := ec.node.CreateKeyGenSession(mpc.SessionTypeEDDSA, walletID, ec.mpcThreshold, ec.genKeySucecssQueue) + eddsaSession, err := ec.node.CreateKeyGenSession(mpc.SessionTypeEDDSA, walletID, ec.mpcThreshold, ec.genKeyResultQueue) if err != nil { - logger.Error("Failed to create key generation session", err, "walletID", walletID) + logger.Error("Failed to create EdDSA key generation session", err, "walletID", walletID) + ec.handleKeygenSessionError(walletID, err, "Failed to create EdDSA key generation session") return } - ecdsaSession.Init() eddsaSession.Init() - ctx := context.Background() - ctxEcdsa, doneEcdsa := context.WithCancel(ctx) - ctxEddsa, doneEddsa := context.WithCancel(ctx) - - successEvent := &event.KeygenSuccessEvent{ - WalletID: walletID, - } + ctxEcdsa, doneEcdsa := context.WithCancel(baseCtx) + ctxEddsa, doneEddsa := context.WithCancel(baseCtx) + successEvent := &event.KeygenResultEvent{WalletID: walletID, ResultType: event.ResultTypeSuccess} var wg sync.WaitGroup wg.Add(2) + + // Channel to communicate errors from goroutines to main function + errorChan := make(chan error, 2) + go func() { - for { - select { - case <-ctxEcdsa.Done(): - successEvent.ECDSAPubKey = ecdsaSession.GetPubKeyResult() - wg.Done() - return - case err := <-ecdsaSession.ErrChan(): - logger.Error("Keygen session error", err) - } + defer wg.Done() + select { + case <-ctxEcdsa.Done(): + successEvent.ECDSAPubKey = ecdsaSession.GetPubKeyResult() + case err := <-ecdsaSession.ErrChan(): + logger.Error("ECDSA keygen session error", err) + ec.handleKeygenSessionError(walletID, err, "ECDSA keygen session error") + errorChan <- err + doneEcdsa() } }() - go func() { - for { - select { - case <-ctxEddsa.Done(): - successEvent.EDDSAPubKey = eddsaSession.GetPubKeyResult() - wg.Done() - return - case err := <-eddsaSession.ErrChan(): - logger.Error("Keygen session error", err) - } + defer wg.Done() + select { + case <-ctxEddsa.Done(): + successEvent.EDDSAPubKey = eddsaSession.GetPubKeyResult() + case err := <-eddsaSession.ErrChan(): + logger.Error("EdDSA keygen session error", err) + ec.handleKeygenSessionError(walletID, err, "EdDSA keygen session error") + errorChan <- err + doneEddsa() } }() ecdsaSession.ListenToIncomingMessageAsync() eddsaSession.ListenToIncomingMessageAsync() - // Temporary delay to allow peer nodes to subscribe and prepare before starting key generation. - // This should be replaced with a proper distributed coordination mechanism later (e.g., Consul lock). - time.Sleep(DefaultKeyGenStartupDelayMs * time.Millisecond) - + // Temporary delay for peer setup + time.Sleep(DefaultSessionStartupDelay * time.Millisecond) go ecdsaSession.GenerateKey(doneEcdsa) go eddsaSession.GenerateKey(doneEddsa) - wg.Wait() - logger.Info("Closing session successfully!", "event", successEvent) + // Wait for completion or timeout + doneAll := make(chan struct{}) + go func() { + wg.Wait() + close(doneAll) + }() + + select { + case <-doneAll: + // Check if any errors occurred during execution + select { + case <-errorChan: + // Error already handled by the goroutine, just return early + return + default: + // No errors, continue with success + } + case <-baseCtx.Done(): + // timeout occurred + logger.Warn("Key generation timed out", "walletID", walletID, "timeout", KeyGenTimeOut) + ec.handleKeygenSessionError(walletID, fmt.Errorf("keygen session timed out after %v", KeyGenTimeOut), "Key generation timed out") + return + } - successEventBytes, err := json.Marshal(successEvent) + logger.Info("Closing session successfully!", "event", successEvent) + payload, err := json.Marshal(successEvent) if err != nil { logger.Error("Failed to marshal keygen success event", err) + ec.handleKeygenSessionError(walletID, err, "Failed to marshal keygen success event") return } - err = ec.genKeySucecssQueue.Enqueue(fmt.Sprintf(mpc.TypeGenerateWalletSuccess, walletID), successEventBytes, &messaging.EnqueueOptions{ - IdempotententKey: fmt.Sprintf(mpc.TypeGenerateWalletSuccess, walletID), - }) - if err != nil { + key := fmt.Sprintf(mpc.TypeGenerateWalletResultFmt, walletID) + if err := ec.genKeyResultQueue.Enqueue(key, payload, &messaging.EnqueueOptions{IdempotententKey: key}); err != nil { logger.Error("Failed to publish key generation success message", err) + ec.handleKeygenSessionError(walletID, err, "Failed to publish key generation success message") return } - logger.Info("[COMPLETED KEY GEN] Key generation completed successfully", "walletID", walletID) } +// handleKeygenSessionError handles errors that occur during key generation +func (ec *eventConsumer) handleKeygenSessionError(walletID string, err error, contextMsg string) { + fullErrMsg := fmt.Sprintf("%s: %v", contextMsg, err) + errorCode := event.GetErrorCodeFromError(err) + + logger.Warn("Keygen session error", + "walletID", walletID, + "error", err.Error(), + "errorCode", errorCode, + "context", contextMsg, + ) + + keygenResult := event.KeygenResultEvent{ + ResultType: event.ResultTypeError, + ErrorCode: string(errorCode), + WalletID: walletID, + ErrorReason: fullErrMsg, + } + + keygenResultBytes, err := json.Marshal(keygenResult) + if err != nil { + logger.Error("Failed to marshal keygen result event", err, + "walletID", walletID, + ) + return + } + + key := fmt.Sprintf(mpc.TypeGenerateWalletResultFmt, walletID) + err = ec.genKeyResultQueue.Enqueue(key, keygenResultBytes, &messaging.EnqueueOptions{ + IdempotententKey: key, + }) + if err != nil { + logger.Error("Failed to enqueue keygen result event", err, + "walletID", walletID, + "payload", string(keygenResultBytes), + ) + } +} + func (ec *eventConsumer) startKeyGenEventWorker() { // semaphore to limit concurrency semaphore := make(chan struct{}, ec.maxConcurrentKeygen) - for { - select { - case natMsg := <-ec.msgBuffer: - semaphore <- struct{}{} // acquire a slot - go func(msg *nats.Msg) { - defer func() { <-semaphore }() // release the slot when done - ec.handleKeyGenEvent(msg) - }(natMsg) - } + for natMsg := range ec.msgBuffer { + semaphore <- struct{}{} // acquire a slot + go func(msg *nats.Msg) { + defer func() { <-semaphore }() // release the slot when done + ec.handleKeyGenEvent(msg) + }(natMsg) } } @@ -257,7 +327,6 @@ func (ec *eventConsumer) consumeTxSigningEvent() error { // Check for duplicate session and track if new if ec.checkDuplicateSession(msg.WalletID, msg.TxID) { - natMsg.Term() return } @@ -269,29 +338,26 @@ func (ec *eventConsumer) consumeTxSigningEvent() error { msg.WalletID, msg.TxID, msg.NetworkInternalCode, - ec.mpcThreshold, ec.signingResultQueue, ) case types.KeyTypeEd25519: session, err = ec.node.CreateSigningSession( - mpc.SessionTypeECDSA, + mpc.SessionTypeEDDSA, msg.WalletID, msg.TxID, msg.NetworkInternalCode, - ec.mpcThreshold, ec.signingResultQueue, ) } - if err != nil { + logger.Error("Failed to create signing session", err) ec.handleSigningSessionError( msg.WalletID, msg.TxID, msg.NetworkInternalCode, err, "Failed to create signing session", - natMsg, ) return } @@ -310,7 +376,6 @@ func (ec *eventConsumer) consumeTxSigningEvent() error { msg.NetworkInternalCode, err, "Failed to init signing session", - natMsg, ) return } @@ -332,7 +397,6 @@ func (ec *eventConsumer) consumeTxSigningEvent() error { msg.NetworkInternalCode, err, "Failed to sign tx", - natMsg, ) return } @@ -348,7 +412,7 @@ func (ec *eventConsumer) consumeTxSigningEvent() error { // One solution: // The messaging includes mechanisms for direct point-to-point communication (in point2point.go). // The nodes could explicitly coordinate through request-response patterns before starting signing - time.Sleep(1 * time.Second) + time.Sleep(DefaultSessionStartupDelay * time.Millisecond) onSuccess := func(data []byte) { done() @@ -371,31 +435,254 @@ func (ec *eventConsumer) consumeTxSigningEvent() error { return nil } +func (ec *eventConsumer) handleSigningSessionError(walletID, txID, networkInternalCode string, err error, contextMsg string) { + fullErrMsg := fmt.Sprintf("%s: %v", contextMsg, err) + errorCode := event.GetErrorCodeFromError(err) + + logger.Warn("Signing session error", + "walletID", walletID, + "txID", txID, + "networkInternalCode", networkInternalCode, + "error", err.Error(), + "errorCode", errorCode, + "context", contextMsg, + ) -func (ec *eventConsumer) handleSigningSessionError(walletID, txID, NetworkInternalCode string, err error, errMsg string, natMsg *nats.Msg) { - logger.Error("Signing session error", err, "walletID", walletID, "txID", txID, "error", errMsg) signingResult := event.SigningResultEvent{ - ResultType: event.SigningResultTypeError, - NetworkInternalCode: NetworkInternalCode, + ResultType: event.ResultTypeError, + ErrorCode: errorCode, + NetworkInternalCode: networkInternalCode, WalletID: walletID, TxID: txID, - ErrorReason: errMsg, + ErrorReason: fullErrMsg, } signingResultBytes, err := json.Marshal(signingResult) if err != nil { - logger.Error("Failed to marshal signing result event", err) + logger.Error("Failed to marshal signing result event", err, + "walletID", walletID, + "txID", txID, + ) return } - natMsg.Ack() err = ec.signingResultQueue.Enqueue(event.SigningResultCompleteTopic, signingResultBytes, &messaging.EnqueueOptions{ IdempotententKey: txID, }) if err != nil { - logger.Error("Failed to publish signing result event", err) + logger.Error("Failed to enqueue signing result event", err, + "walletID", walletID, + "txID", txID, + "payload", string(signingResultBytes), + ) + } +} +func (ec *eventConsumer) consumeReshareEvent() error { + sub, err := ec.pubsub.Subscribe(MPCReshareEvent, func(natMsg *nats.Msg) { + var msg types.ResharingMessage + if err := json.Unmarshal(natMsg.Data, &msg); err != nil { + logger.Error("Failed to unmarshal resharing message", err) + ec.handleReshareSessionError(msg.WalletID, msg.KeyType, msg.NewThreshold, err, "Failed to unmarshal resharing message") + return + } + + if msg.SessionID == "" { + ec.handleReshareSessionError( + msg.WalletID, + msg.KeyType, + msg.NewThreshold, + errors.New("validation: session ID is empty"), + "Session ID is empty", + ) + return + } + + if err := ec.identityStore.VerifyInitiatorMessage(&msg); err != nil { + logger.Error("Failed to verify initiator message", err) + ec.handleReshareSessionError(msg.WalletID, msg.KeyType, msg.NewThreshold, err, "Failed to verify initiator message") + return + } + + walletID := msg.WalletID + keyType := msg.KeyType + + sessionType, err := sessionTypeFromKeyType(keyType) + if err != nil { + logger.Error("Failed to get session type", err) + ec.handleReshareSessionError(walletID, keyType, msg.NewThreshold, err, "Failed to get session type") + return + } + + createSession := func(isNewPeer bool) (mpc.ReshareSession, error) { + return ec.node.CreateReshareSession( + sessionType, + walletID, + ec.mpcThreshold, + msg.NewThreshold, + msg.NodeIDs, + isNewPeer, + ec.reshareResultQueue, + ) + } + + oldSession, err := createSession(false) + if err != nil { + logger.Error("Failed to create old reshare session", err, "walletID", walletID) + ec.handleReshareSessionError(walletID, keyType, msg.NewThreshold, err, "Failed to create old reshare session") + return + } + newSession, err := createSession(true) + if err != nil { + logger.Error("Failed to create new reshare session", err, "walletID", walletID) + ec.handleReshareSessionError(walletID, keyType, msg.NewThreshold, err, "Failed to create new reshare session") + return + } + + if oldSession == nil && newSession == nil { + logger.Info("Node is not participating in this reshare (neither old nor new)", "walletID", walletID) + return + } + + successEvent := &event.ResharingResultEvent{ + WalletID: walletID, + NewThreshold: msg.NewThreshold, + KeyType: msg.KeyType, + ResultType: event.ResultTypeSuccess, + } + + var wg sync.WaitGroup + ctx := context.Background() + + time.Sleep(DefaultSessionStartupDelay * time.Millisecond) + + if oldSession != nil { + ctxOld, doneOld := context.WithCancel(ctx) + oldSession.Init() + oldSession.ListenToIncomingMessageAsync() + go oldSession.Reshare(doneOld) + + wg.Add(1) + go func() { + defer wg.Done() + for { + select { + case <-ctxOld.Done(): + return + case err := <-oldSession.ErrChan(): + logger.Error("Old reshare session error", err) + ec.handleReshareSessionError(walletID, keyType, msg.NewThreshold, err, "Old reshare session error") + doneOld() // Cancel the context to stop this session + return + } + } + }() + } + + if newSession != nil { + ctxNew, doneNew := context.WithCancel(ctx) + newSession.Init() + newSession.ListenToIncomingMessageAsync() + go newSession.Reshare(doneNew) + + wg.Add(1) + go func() { + defer wg.Done() + for { + select { + case <-ctxNew.Done(): + successEvent.PubKey = newSession.GetPubKeyResult() + return + case err := <-newSession.ErrChan(): + logger.Error("New reshare session error", err) + ec.handleReshareSessionError(walletID, keyType, msg.NewThreshold, err, "New reshare session error") + doneNew() // Cancel the context to stop this session + return + } + } + }() + } + + wg.Wait() + + logger.Info("Reshare session finished", "walletID", walletID, "pubKey", fmt.Sprintf("%x", successEvent.PubKey)) + + if newSession != nil { + successBytes, err := json.Marshal(successEvent) + if err != nil { + logger.Error("Failed to marshal reshare success event", err) + ec.handleReshareSessionError(walletID, keyType, msg.NewThreshold, err, "Failed to marshal reshare success event") + return + } + + key := fmt.Sprintf(mpc.TypeReshareWalletResultFmt, msg.SessionID) + err = ec.reshareResultQueue.Enqueue( + key, + successBytes, + &messaging.EnqueueOptions{ + IdempotententKey: key, + }) + if err != nil { + logger.Error("Failed to publish reshare success message", err) + ec.handleReshareSessionError(walletID, keyType, msg.NewThreshold, err, "Failed to publish reshare success message") + return + } + logger.Info("[COMPLETED RESHARE] Successfully published", "walletID", walletID) + } else { + logger.Info("[COMPLETED RESHARE] Done (not a new party)", "walletID", walletID) + } + }) + + ec.reshareSub = sub + return err +} + +// handleReshareSessionError handles errors that occur during reshare operations +func (ec *eventConsumer) handleReshareSessionError( + walletID string, + keyType types.KeyType, + newThreshold int, + err error, + contextMsg string, +) { + fullErrMsg := fmt.Sprintf("%s: %v", contextMsg, err) + errorCode := event.GetErrorCodeFromError(err) + + logger.Warn("Reshare session error", + "walletID", walletID, + "keyType", keyType, + "newThreshold", newThreshold, + "error", err.Error(), + "errorCode", errorCode, + "context", contextMsg, + ) + + reshareResult := event.ResharingResultEvent{ + ResultType: event.ResultTypeError, + ErrorCode: string(errorCode), + WalletID: walletID, + KeyType: keyType, + NewThreshold: newThreshold, + ErrorReason: fullErrMsg, + } + + reshareResultBytes, err := json.Marshal(reshareResult) + if err != nil { + logger.Error("Failed to marshal reshare result event", err, + "walletID", walletID, + ) return } + + key := fmt.Sprintf(mpc.TypeReshareWalletResultFmt, walletID) + err = ec.reshareResultQueue.Enqueue(key, reshareResultBytes, &messaging.EnqueueOptions{ + IdempotententKey: key, + }) + if err != nil { + logger.Error("Failed to enqueue reshare result event", err, + "walletID", walletID, + "payload", string(reshareResultBytes), + ) + } } // Add a cleanup routine that runs periodically @@ -421,7 +708,6 @@ func (ec *eventConsumer) cleanupStaleSessions() { for sessionID, creationTime := range ec.activeSessions { if now.Sub(creationTime) > ec.sessionTimeout { - logger.Info("Cleaning up stale session", "sessionID", sessionID, "age", now.Sub(creationTime)) delete(ec.activeSessions, sessionID) } } @@ -474,6 +760,22 @@ func (ec *eventConsumer) Close() error { if err != nil { return err } + err = ec.reshareSub.Unsubscribe() + if err != nil { + return err + } return nil } + +func sessionTypeFromKeyType(keyType types.KeyType) (mpc.SessionType, error) { + switch keyType { + case types.KeyTypeSecp256k1: + return mpc.SessionTypeECDSA, nil + case types.KeyTypeEd25519: + return mpc.SessionTypeEDDSA, nil + default: + logger.Warn("Unsupported key type", "keyType", keyType) + return "", fmt.Errorf("unsupported key type: %v", keyType) + } +} diff --git a/pkg/eventconsumer/events.go b/pkg/eventconsumer/events.go index 5b9ca06..4d71714 100644 --- a/pkg/eventconsumer/events.go +++ b/pkg/eventconsumer/events.go @@ -6,7 +6,7 @@ type KeyType string const ( KeyTypeSecp256k1 KeyType = "secp256k1" - KeyTypeEd25519 = "ed25519" + KeyTypeEd25519 KeyType = "ed25519" ) // InitiatorMessage is anything that carries a payload to verify and its signature. diff --git a/pkg/eventconsumer/sign_consumer.go b/pkg/eventconsumer/sign_consumer.go index 83387a6..deff686 100644 --- a/pkg/eventconsumer/sign_consumer.go +++ b/pkg/eventconsumer/sign_consumer.go @@ -8,8 +8,10 @@ import ( "github.com/fystack/mpcium/pkg/event" "github.com/fystack/mpcium/pkg/logger" "github.com/fystack/mpcium/pkg/messaging" + "github.com/fystack/mpcium/pkg/mpc" "github.com/nats-io/nats.go" "github.com/nats-io/nats.go/jetstream" + "github.com/spf13/viper" ) const ( @@ -17,6 +19,8 @@ const ( signingResponseTimeout = 30 * time.Second // How often to poll for the reply message. signingPollingInterval = 500 * time.Millisecond + // How often to check if enough peers are ready + readinessCheckInterval = 2 * time.Second ) // SigningConsumer represents a consumer that processes signing events. @@ -29,25 +33,65 @@ type SigningConsumer interface { // signingConsumer implements SigningConsumer. type signingConsumer struct { - natsConn *nats.Conn - pubsub messaging.PubSub - jsPubsub messaging.StreamPubsub + natsConn *nats.Conn + pubsub messaging.PubSub + jsPubsub messaging.StreamPubsub + peerRegistry mpc.PeerRegistry + mpcThreshold int // jsSub holds the JetStream subscription, so it can be cleaned up during Close(). jsSub messaging.Subscription } // NewSigningConsumer returns a new instance of SigningConsumer. -func NewSigningConsumer(natsConn *nats.Conn, jsPubsub messaging.StreamPubsub, pubsub messaging.PubSub) SigningConsumer { +func NewSigningConsumer(natsConn *nats.Conn, jsPubsub messaging.StreamPubsub, pubsub messaging.PubSub, peerRegistry mpc.PeerRegistry) SigningConsumer { + mpcThreshold := viper.GetInt("mpc_threshold") return &signingConsumer{ - natsConn: natsConn, - pubsub: pubsub, - jsPubsub: jsPubsub, + natsConn: natsConn, + pubsub: pubsub, + jsPubsub: jsPubsub, + peerRegistry: peerRegistry, + mpcThreshold: mpcThreshold, + } +} + +// waitForSufficientPeers waits until enough peers are ready to handle signing requests +func (sc *signingConsumer) waitForSufficientPeers(ctx context.Context) error { + requiredPeers := int64(sc.mpcThreshold + 1) // t+1 peers needed for signing + + logger.Info("SigningConsumer: Waiting for sufficient peers before consuming messages", + "required", requiredPeers, + "threshold", sc.mpcThreshold) + + ticker := time.NewTicker(readinessCheckInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-ticker.C: + readyPeers := sc.peerRegistry.GetReadyPeersCount() + if readyPeers >= requiredPeers { + logger.Info("SigningConsumer: Sufficient peers ready, starting message consumption", + "ready", readyPeers, + "t+1", requiredPeers) + return nil + } + logger.Info("SigningConsumer: Waiting for more peers to be ready", + "ready", readyPeers, + "t+1", requiredPeers) + } } } // Run subscribes to signing events and processes them until the context is canceled. func (sc *signingConsumer) Run(ctx context.Context) error { + // Wait for sufficient peers before starting to consume messages + if err := sc.waitForSufficientPeers(ctx); err != nil { + return fmt.Errorf("failed to wait for sufficient peers: %w", err) + } + sub, err := sc.jsPubsub.Subscribe( event.SigningConsumerStream, event.SigningRequestEventTopic, @@ -83,6 +127,18 @@ func (sc *signingConsumer) Run(ctx context.Context) error { // When signing completes, the session publishes the result to a queue and calls the onSuccess callback, which sends a reply to the inbox that the SigningConsumer is monitoring. // The reply signals completion, allowing the SigningConsumer to acknowledge the original message. func (sc *signingConsumer) handleSigningEvent(msg jetstream.Msg) { + // Check if we still have enough peers before processing the message + requiredPeers := int64(sc.mpcThreshold + 1) + readyPeers := sc.peerRegistry.GetReadyPeersCount() + + if readyPeers < requiredPeers { + logger.Warn("SigningConsumer: Not enough peers to process signing request, rejecting message", + "ready", readyPeers, + "required", requiredPeers) + // Immediately return and let nats redeliver the message with backoff + return + } + // Create a reply inbox to receive the signing event response. replyInbox := nats.NewInbox() @@ -95,7 +151,7 @@ func (sc *signingConsumer) handleSigningEvent(msg jetstream.Msg) { } defer func() { if err := replySub.Unsubscribe(); err != nil { - logger.Warn("SigningConsumer: Failed to unsubscribe from reply inbox", err) + logger.Warn("SigningConsumer: Failed to unsubscribe from reply inbox", "error", err) } }() diff --git a/pkg/eventconsumer/timeout_consumer.go b/pkg/eventconsumer/timeout_consumer.go index 4c5c31c..bd91170 100644 --- a/pkg/eventconsumer/timeout_consumer.go +++ b/pkg/eventconsumer/timeout_consumer.go @@ -2,7 +2,6 @@ package eventconsumer import ( "encoding/json" - "fmt" "github.com/fystack/mpcium/pkg/event" "github.com/fystack/mpcium/pkg/logger" @@ -61,9 +60,10 @@ func (tc *timeOutConsumer) Run() { return } - signErrorResult.ResultType = event.SigningResultTypeError + signErrorResult.ResultType = event.ResultTypeError + signErrorResult.ErrorCode = event.ErrorCodeMaxDeliveryAttempts signErrorResult.IsTimeout = true - signErrorResult.ErrorReason = fmt.Sprintf("Message delivery exceeded for stream %s", advisory.Stream) + signErrorResult.ErrorReason = "Signing failed: maximum delivery attempts exceeded" signErrorResultBytes, err := json.Marshal(signErrorResult) if err != nil { diff --git a/pkg/identity/identity.go b/pkg/identity/identity.go index 863a696..4d281c2 100644 --- a/pkg/identity/identity.go +++ b/pkg/identity/identity.go @@ -8,6 +8,7 @@ import ( "io" "os" "path/filepath" + "strings" "sync" "syscall" @@ -264,5 +265,5 @@ func (s *fileStore) VerifyInitiatorMessage(msg types.InitiatorMessage) error { } func partyIDToNodeID(partyID *tss.PartyID) string { - return string(partyID.KeyInt().Bytes()) + return strings.Split(string(partyID.KeyInt().Bytes()), ":")[0] } diff --git a/pkg/keyinfo/keyinfo.go b/pkg/keyinfo/keyinfo.go index 6952e7d..49d4c7f 100644 --- a/pkg/keyinfo/keyinfo.go +++ b/pkg/keyinfo/keyinfo.go @@ -11,6 +11,7 @@ import ( type KeyInfo struct { ParticipantPeerIDs []string `json:"participant_peer_ids"` Threshold int `json:"threshold"` + Version int `json:"version"` } type store struct { diff --git a/pkg/messaging/message_queue.go b/pkg/messaging/message_queue.go index d5a1c7f..c2aeef9 100644 --- a/pkg/messaging/message_queue.go +++ b/pkg/messaging/message_queue.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "time" "github.com/fystack/mpcium/pkg/logger" "github.com/nats-io/nats.go" @@ -57,14 +58,14 @@ func NewNATsMessageQueueManager(queueName string, subjectWildCards []string, nc Name: queueName, Description: "Stream for " + queueName, Subjects: subjectWildCards, - MaxBytes: 1024, + MaxBytes: 10_485_760, // Light Production (Low Traffic) (10 MB) Storage: jetstream.FileStorage, Retention: jetstream.WorkQueuePolicy, }) if err != nil { logger.Fatal("Error creating JetStream stream: ", err) } - logger.Info("Creating apex NATs Jetstream context successfully!") + logger.Info("Creating apex NATs Jetstream context successfully!", "streamName", queueName, "subjects", subjectWildCards) return &NATsMessageQueueManager{ queueName: queueName, @@ -81,13 +82,16 @@ func (m *NATsMessageQueueManager) NewMessageQueue(consumerName string) MessageQu cfg := jetstream.ConsumerConfig{ Name: consumerName, Durable: consumerName, - MaxAckPending: 4, + MaxAckPending: 1000, + // If a message isn't acked within AckWait, it will be redelivered up to MaxDelive + AckWait: 30 * time.Second, + AckPolicy: jetstream.AckExplicitPolicy, FilterSubjects: []string{ consumerWildCard, }, MaxDeliver: 3, } - logger.Info("Creating consumer for subject", "config", cfg) + logger.Info("Creating consumer for subject", "consumerName", consumerName, "queueName", m.queueName, "filterSubject", consumerWildCard, "config", cfg) consumer, err := m.js.CreateOrUpdateConsumer(context.Background(), m.queueName, cfg) if err != nil { logger.Fatal("Error creating JetStream consumer: ", err) @@ -103,7 +107,7 @@ func (mq *msgQueue) Enqueue(topic string, message []byte, options *EnqueueOption header.Add("Nats-Msg-Id", options.IdempotententKey) } - logger.Info("Publishing message", "topic", topic) + logger.Info("Publishing message", "topic", topic, "consumerName", mq.consumerName) _, err := mq.js.PublishMsg(context.Background(), &nats.Msg{ Subject: topic, Data: message, @@ -111,7 +115,8 @@ func (mq *msgQueue) Enqueue(topic string, message []byte, options *EnqueueOption }) if err != nil { - return fmt.Errorf("Error enqueueing message: %w", err) + logger.Error("Failed to publish message to JetStream", err, "topic", topic, "consumerName", mq.consumerName) + return fmt.Errorf("error enqueueing message: %w", err) } return nil diff --git a/pkg/messaging/pubsub.go b/pkg/messaging/pubsub.go index 9860e02..27750fa 100644 --- a/pkg/messaging/pubsub.go +++ b/pkg/messaging/pubsub.go @@ -223,10 +223,11 @@ func (j *jetStreamPubSub) Subscribe(name string, topic string, handler func(msg logger.Info("Subscribing to topic", sanitizeConsumerName(name), topic) consumerConfig := jetstream.ConsumerConfig{ - Name: sanitizeConsumerName(name), - Durable: sanitizeConsumerName(name), - AckPolicy: jetstream.AckExplicitPolicy, - MaxDeliver: 4, + Name: sanitizeConsumerName(name), + Durable: sanitizeConsumerName(name), + AckPolicy: jetstream.AckExplicitPolicy, + MaxDeliver: 4, + // backoff is NOT applied to naked messages. BackOff: []time.Duration{30 * time.Second, 30 * time.Second, 30 * time.Second}, DeliverPolicy: jetstream.DeliverAllPolicy, // Deliver all messages FilterSubject: topic, @@ -245,7 +246,7 @@ func (j *jetStreamPubSub) Subscribe(name string, topic string, handler func(msg } _, err = consumer.Consume(func(msg jetstream.Msg) { - logger.Info("Received jetStreamPubSub message", "subject", msg.Data()) + logger.Info("Received jetStreamPubSub message") handler(msg) }) diff --git a/pkg/mpc/ecdsa_keygen_session.go b/pkg/mpc/ecdsa_keygen_session.go index c90cd79..e9ea00a 100644 --- a/pkg/mpc/ecdsa_keygen_session.go +++ b/pkg/mpc/ecdsa_keygen_session.go @@ -48,6 +48,7 @@ func newECDSAKeygenSession( pubSub: pubSub, direct: direct, threshold: threshold, + version: DefaultVersion, participantPeerIDs: participantPeerIDs, selfPartyID: selfID, partyIDs: partyIDs, @@ -103,7 +104,7 @@ func (s *ecdsaKeygenSession) GenerateKey(done func()) { return } - err = s.kvstore.Put(s.composeKey(s.walletID), keyBytes) + err = s.kvstore.Put(s.composeKey(walletIDWithVersion(s.walletID, s.GetVersion())), keyBytes) if err != nil { logger.Error("Failed to save key", err, "walletID", s.walletID) s.ErrCh <- err @@ -113,6 +114,7 @@ func (s *ecdsaKeygenSession) GenerateKey(done func()) { keyInfo := keyinfo.KeyInfo{ ParticipantPeerIDs: s.participantPeerIDs, Threshold: s.threshold, + Version: s.GetVersion(), } err = s.keyinfoStore.Save(s.composeKey(s.walletID), &keyInfo) diff --git a/pkg/mpc/ecdsa_resharing_session.go b/pkg/mpc/ecdsa_resharing_session.go index 1e3d1d8..0b19be7 100644 --- a/pkg/mpc/ecdsa_resharing_session.go +++ b/pkg/mpc/ecdsa_resharing_session.go @@ -1 +1,195 @@ package mpc + +import ( + "crypto/ecdsa" + "encoding/json" + "fmt" + + "github.com/bnb-chain/tss-lib/v2/ecdsa/keygen" + "github.com/bnb-chain/tss-lib/v2/ecdsa/resharing" + "github.com/bnb-chain/tss-lib/v2/tss" + "github.com/fystack/mpcium/pkg/encoding" + "github.com/fystack/mpcium/pkg/identity" + "github.com/fystack/mpcium/pkg/keyinfo" + "github.com/fystack/mpcium/pkg/kvstore" + "github.com/fystack/mpcium/pkg/logger" + "github.com/fystack/mpcium/pkg/messaging" +) + +type ReshareSession interface { + Session + Init() + Reshare(done func()) + GetPubKeyResult() []byte +} + +type ecdsaReshareSession struct { + *session + isNewParty bool + newPeerIDs []string + reshareParams *tss.ReSharingParameters + endCh chan *keygen.LocalPartySaveData +} + +func NewECDSAReshareSession( + walletID string, + pubSub messaging.PubSub, + direct messaging.DirectMessaging, + participantPeerIDs []string, + selfID *tss.PartyID, + oldPartyIDs []*tss.PartyID, + newPartyIDs []*tss.PartyID, + threshold int, + newThreshold int, + preParams *keygen.LocalPreParams, + kvstore kvstore.KVStore, + keyinfoStore keyinfo.Store, + resultQueue messaging.MessageQueue, + identityStore identity.Store, + newPeerIDs []string, + isNewParty bool, + version int, +) *ecdsaReshareSession { + session := session{ + walletID: walletID, + pubSub: pubSub, + direct: direct, + threshold: threshold, + participantPeerIDs: participantPeerIDs, + selfPartyID: selfID, + partyIDs: newPartyIDs, + outCh: make(chan tss.Message), + ErrCh: make(chan error), + preParams: preParams, + kvstore: kvstore, + keyinfoStore: keyinfoStore, + version: version, + topicComposer: &TopicComposer{ + ComposeBroadcastTopic: func() string { + return fmt.Sprintf("resharing:broadcast:ecdsa:%s", walletID) + }, + ComposeDirectTopic: func(nodeID string) string { + return fmt.Sprintf("resharing:direct:ecdsa:%s:%s", nodeID, walletID) + }, + }, + composeKey: func(walletID string) string { + return fmt.Sprintf("ecdsa:%s", walletID) + }, + getRoundFunc: GetEcdsaMsgRound, + resultQueue: resultQueue, + sessionType: SessionTypeECDSA, + identityStore: identityStore, + } + reshareParams := tss.NewReSharingParameters( + tss.S256(), + tss.NewPeerContext(oldPartyIDs), + tss.NewPeerContext(newPartyIDs), + selfID, + len(oldPartyIDs), + threshold, + len(newPartyIDs), + newThreshold, + ) + return &ecdsaReshareSession{ + session: &session, + reshareParams: reshareParams, + isNewParty: isNewParty, + newPeerIDs: newPeerIDs, + endCh: make(chan *keygen.LocalPartySaveData), + } +} + +func (s *ecdsaReshareSession) Init() { + logger.Infof("Initializing resharing session with partyID: %s, newPartyIDs %s", s.selfPartyID, s.partyIDs) + var share keygen.LocalPartySaveData + + if s.isNewParty { + // New party → generate empty share + share = keygen.NewLocalPartySaveData(len(s.partyIDs)) + share.LocalPreParams = *s.preParams + } else { + err := s.loadOldShareDataGeneric(s.walletID, s.GetVersion(), &share) + if err != nil { + s.ErrCh <- err + return + } + } + + s.party = resharing.NewLocalParty(s.reshareParams, share, s.outCh, s.endCh) + + logger.Infof("[INITIALIZED] Initialized resharing session successfully partyID: %s, peerIDs %s, walletID %s, oldThreshold = %d, newThreshold = %d", + s.selfPartyID, s.partyIDs, s.walletID, s.threshold, s.reshareParams.NewThreshold()) +} + +func (s *ecdsaReshareSession) Reshare(done func()) { + logger.Info("Starting resharing", "walletID", s.walletID, "partyID", s.selfPartyID) + go func() { + if err := s.party.Start(); err != nil { + s.ErrCh <- err + } + }() + + for { + select { + case saveData := <-s.endCh: + // skip for old committee + if saveData.ECDSAPub != nil { + + keyBytes, err := json.Marshal(saveData) + if err != nil { + s.ErrCh <- err + return + } + + newVersion := s.GetVersion() + 1 + key := s.composeKey(walletIDWithVersion(s.walletID, newVersion)) + if err := s.kvstore.Put(key, keyBytes); err != nil { + s.ErrCh <- err + return + } + + keyInfo := keyinfo.KeyInfo{ + ParticipantPeerIDs: s.newPeerIDs, + Threshold: s.reshareParams.NewThreshold(), + Version: newVersion, + } + + // Save key info with resharing flag + if err := s.keyinfoStore.Save(s.composeKey(s.walletID), &keyInfo); err != nil { + s.ErrCh <- err + return + } + // Get public key + publicKey := saveData.ECDSAPub + pubKey := &ecdsa.PublicKey{ + Curve: publicKey.Curve(), + X: publicKey.X(), + Y: publicKey.Y(), + } + + pubKeyBytes, err := encoding.EncodeS256PubKey(pubKey) + if err != nil { + logger.Error("failed to encode public key", err) + s.ErrCh <- fmt.Errorf("failed to encode public key: %w", err) + return + } + + // Set the public key bytes + s.pubkeyBytes = pubKeyBytes + logger.Info("Generated public key bytes", + "walletID", s.walletID, + "pubKeyBytes", pubKeyBytes) + } + + done() + err := s.Close() + if err != nil { + logger.Error("Failed to close session", err) + } + return + case msg := <-s.outCh: + // Handle the message + s.handleTssMessage(msg) + } + } +} diff --git a/pkg/mpc/ecdsa_rounds.go b/pkg/mpc/ecdsa_rounds.go index 8e70f7a..b7ddd2e 100644 --- a/pkg/mpc/ecdsa_rounds.go +++ b/pkg/mpc/ecdsa_rounds.go @@ -2,26 +2,35 @@ package mpc import ( "github.com/bnb-chain/tss-lib/v2/ecdsa/keygen" + "github.com/bnb-chain/tss-lib/v2/ecdsa/resharing" "github.com/bnb-chain/tss-lib/v2/ecdsa/signing" "github.com/bnb-chain/tss-lib/v2/tss" "github.com/fystack/mpcium/pkg/common/errors" ) const ( - KEYGEN1 = "KGRound1Message" - KEYGEN2aUnicast = "KGRound2Message1" - KEYGEN2b = "KGRound2Message2" - KEYGEN3 = "KGRound3Message" - KEYSIGN1aUnicast = "SignRound1Message1" - KEYSIGN1b = "SignRound1Message2" - KEYSIGN2Unicast = "SignRound2Message" - KEYSIGN3 = "SignRound3Message" - KEYSIGN4 = "SignRound4Message" - KEYSIGN5 = "SignRound5Message" - KEYSIGN6 = "SignRound6Message" - KEYSIGN7 = "SignRound7Message" - KEYSIGN8 = "SignRound8Message" - KEYSIGN9 = "SignRound9Message" + KEYGEN1 = "KGRound1Message" + KEYGEN2aUnicast = "KGRound2Message1" + KEYGEN2b = "KGRound2Message2" + KEYGEN3 = "KGRound3Message" + KEYSIGN1aUnicast = "SignRound1Message1" + KEYSIGN1b = "SignRound1Message2" + KEYSIGN2Unicast = "SignRound2Message" + KEYSIGN3 = "SignRound3Message" + KEYSIGN4 = "SignRound4Message" + KEYSIGN5 = "SignRound5Message" + KEYSIGN6 = "SignRound6Message" + KEYSIGN7 = "SignRound7Message" + KEYSIGN8 = "SignRound8Message" + KEYSIGN9 = "SignRound9Message" + KEYRESHARING1Unicast = "DGRound1Message" + KEYRESHARING2aUnicast = "DGRound2Message1" + KEYRESHARING2bUnicast = "DGRound2Message2" + KEYRESHARING3aUnicast = "DGRound3Message1" + KEYRESHARING3b = "DGRound3Message2" + KEYRESHARING4a = "DGRound4Message1" + KEYRESHARING4bUnicast = "DGRound4Message2" + TSSKEYGENROUNDS = 4 TSSKEYSIGNROUNDS = 10 ) @@ -113,8 +122,46 @@ func GetEcdsaMsgRound(msg []byte, partyID *tss.PartyID, isBroadcast bool) (Round Index: 9, RoundMsg: KEYSIGN9, }, nil - + case *resharing.DGRound1Message: + return RoundInfo{ + Index: 0, + RoundMsg: KEYRESHARING1Unicast, + }, nil + case *resharing.DGRound2Message1: + return RoundInfo{ + Index: 1, + RoundMsg: KEYRESHARING2aUnicast, + }, nil + case *resharing.DGRound2Message2: + return RoundInfo{ + Index: 2, + RoundMsg: KEYRESHARING2bUnicast, + }, nil + case *resharing.DGRound3Message1: + return RoundInfo{ + Index: 3, + RoundMsg: KEYRESHARING3aUnicast, + }, nil + case *resharing.DGRound3Message2: + return RoundInfo{ + Index: 4, + RoundMsg: KEYRESHARING3b, + }, nil + case *resharing.DGRound4Message1: + return RoundInfo{ + Index: 5, + RoundMsg: KEYRESHARING4a, + }, nil + case *resharing.DGRound4Message2: + return RoundInfo{ + Index: 6, + RoundMsg: KEYRESHARING4bUnicast, + }, nil default: return RoundInfo{}, errors.New("unknown round") } } + +func IsReshareRound(roundMsg string) bool { + return roundMsg == KEYRESHARING1Unicast || roundMsg == KEYRESHARING2aUnicast || roundMsg == KEYRESHARING2bUnicast || roundMsg == KEYRESHARING3aUnicast || roundMsg == KEYRESHARING3b || roundMsg == KEYRESHARING4a || roundMsg == KEYRESHARING4bUnicast +} diff --git a/pkg/mpc/ecdsa_signing_session.go b/pkg/mpc/ecdsa_signing_session.go index 9c8af7d..76cba2c 100644 --- a/pkg/mpc/ecdsa_signing_session.go +++ b/pkg/mpc/ecdsa_signing_session.go @@ -93,11 +93,6 @@ func (s *ecdsaSigningSession) Init(tx *big.Int) error { ctx := tss.NewPeerContext(s.partyIDs) params := tss.NewParameters(tss.S256(), ctx, s.selfPartyID, len(s.partyIDs), s.threshold) - keyData, err := s.kvstore.Get(s.composeKey(s.walletID)) - if err != nil { - return errors.Wrap(err, "Failed to get wallet data from KVStore") - } - keyInfo, err := s.keyinfoStore.Get(s.composeKey(s.walletID)) if err != nil { return errors.Wrap(err, "Failed to get key info data") @@ -119,6 +114,11 @@ func (s *ecdsaSigningSession) Init(tx *big.Int) error { } logger.Info("Have enough participants to sign", "participants", s.participantPeerIDs) + + keyData, err := s.kvstore.Get(s.composeKey(walletIDWithVersion(s.walletID, keyInfo.Version))) + if err != nil { + return errors.Wrap(err, "Failed to get wallet data from KVStore") + } // Check if all the participants of the key are present var data keygen.LocalPartySaveData err = json.Unmarshal(keyData, &data) @@ -128,6 +128,7 @@ func (s *ecdsaSigningSession) Init(tx *big.Int) error { s.party = signing.NewLocalParty(tx, params, data, s.outCh, s.endCh) s.data = &data + s.version = keyInfo.Version s.tx = tx logger.Info("Initialized sigining session successfully!") return nil @@ -161,7 +162,7 @@ func (s *ecdsaSigningSession) Sign(onSuccess func(data []byte)) { } r := event.SigningResultEvent{ - ResultType: event.SigningResultTypeSuccess, + ResultType: event.ResultTypeSuccess, NetworkInternalCode: s.networkInternalCode, WalletID: s.walletID, TxID: s.txID, diff --git a/pkg/mpc/eddsa_keygen_session.go b/pkg/mpc/eddsa_keygen_session.go index d3489ac..a4fe030 100644 --- a/pkg/mpc/eddsa_keygen_session.go +++ b/pkg/mpc/eddsa_keygen_session.go @@ -37,6 +37,7 @@ func newEDDSAKeygenSession( pubSub: pubSub, direct: direct, threshold: threshold, + version: DefaultVersion, participantPeerIDs: participantPeerIDs, selfPartyID: selfID, partyIDs: partyIDs, @@ -91,7 +92,7 @@ func (s *eddsaKeygenSession) GenerateKey(done func()) { return } - err = s.kvstore.Put(s.composeKey(s.walletID), keyBytes) + err = s.kvstore.Put(s.composeKey(walletIDWithVersion(s.walletID, s.GetVersion())), keyBytes) if err != nil { logger.Error("Failed to save key", err, "walletID", s.walletID) s.ErrCh <- err @@ -101,6 +102,7 @@ func (s *eddsaKeygenSession) GenerateKey(done func()) { keyInfo := keyinfo.KeyInfo{ ParticipantPeerIDs: s.participantPeerIDs, Threshold: s.threshold, + Version: s.GetVersion(), } err = s.keyinfoStore.Save(s.composeKey(s.walletID), &keyInfo) diff --git a/pkg/mpc/eddsa_resharing_session.go b/pkg/mpc/eddsa_resharing_session.go index 1e3d1d8..9135fbc 100644 --- a/pkg/mpc/eddsa_resharing_session.go +++ b/pkg/mpc/eddsa_resharing_session.go @@ -1 +1,180 @@ package mpc + +import ( + "encoding/json" + "fmt" + + "github.com/bnb-chain/tss-lib/v2/eddsa/keygen" + "github.com/bnb-chain/tss-lib/v2/eddsa/resharing" + "github.com/bnb-chain/tss-lib/v2/tss" + "github.com/decred/dcrd/dcrec/edwards/v2" + "github.com/fystack/mpcium/pkg/identity" + "github.com/fystack/mpcium/pkg/keyinfo" + "github.com/fystack/mpcium/pkg/kvstore" + "github.com/fystack/mpcium/pkg/logger" + "github.com/fystack/mpcium/pkg/messaging" +) + +type eddsaReshareSession struct { + *session + isNewParty bool + newPeerIDs []string + reshareParams *tss.ReSharingParameters + endCh chan *keygen.LocalPartySaveData +} + +func NewEDDSAReshareSession( + walletID string, + pubSub messaging.PubSub, + direct messaging.DirectMessaging, + participantPeerIDs []string, + selfID *tss.PartyID, + oldPartyIDs []*tss.PartyID, + newPartyIDs []*tss.PartyID, + threshold int, + newThreshold int, + kvstore kvstore.KVStore, + keyinfoStore keyinfo.Store, + resultQueue messaging.MessageQueue, + identityStore identity.Store, + newPeerIDs []string, + isNewParty bool, + version int, +) *eddsaReshareSession { + session := session{ + walletID: walletID, + pubSub: pubSub, + direct: direct, + threshold: threshold, + version: version, + participantPeerIDs: participantPeerIDs, + selfPartyID: selfID, + partyIDs: newPartyIDs, + outCh: make(chan tss.Message), + ErrCh: make(chan error), + kvstore: kvstore, + keyinfoStore: keyinfoStore, + topicComposer: &TopicComposer{ + ComposeBroadcastTopic: func() string { + return fmt.Sprintf("reshare:broadcast:eddsa:%s", walletID) + }, + ComposeDirectTopic: func(nodeID string) string { + return fmt.Sprintf("reshare:direct:eddsa:%s:%s", nodeID, walletID) + }, + }, + composeKey: func(walletID string) string { + return fmt.Sprintf("eddsa:%s", walletID) + }, + getRoundFunc: GetEddsaMsgRound, + resultQueue: resultQueue, + sessionType: SessionTypeEDDSA, + identityStore: identityStore, + } + + reshareParams := tss.NewReSharingParameters( + tss.Edwards(), + tss.NewPeerContext(oldPartyIDs), + tss.NewPeerContext(newPartyIDs), + selfID, + len(oldPartyIDs), + threshold, + len(newPartyIDs), + newThreshold, + ) + + return &eddsaReshareSession{ + session: &session, + reshareParams: reshareParams, + isNewParty: isNewParty, + newPeerIDs: newPeerIDs, + endCh: make(chan *keygen.LocalPartySaveData), + } +} + +func (s *eddsaReshareSession) Init() { + logger.Infof("Initializing resharing session with partyID: %s, peerIDs %s", s.selfPartyID, s.partyIDs) + var share keygen.LocalPartySaveData + if s.isNewParty { + // Initialize empty share data for new party + share = keygen.NewLocalPartySaveData(len(s.partyIDs)) + } else { + err := s.loadOldShareDataGeneric(s.walletID, s.GetVersion(), &share) + if err != nil { + s.ErrCh <- err + return + } + } + s.party = resharing.NewLocalParty(s.reshareParams, share, s.outCh, s.endCh) + logger.Infof("[INITIALIZED] Initialized resharing session successfully partyID: %s, peerIDs %s, walletID %s, oldThreshold = %d, newThreshold = %d", + s.selfPartyID, s.partyIDs, s.walletID, s.threshold, s.reshareParams.NewThreshold()) +} + +func (s *eddsaReshareSession) Reshare(done func()) { + logger.Info("Starting resharing", "walletID", s.walletID, "partyID", s.selfPartyID) + go func() { + if err := s.party.Start(); err != nil { + s.ErrCh <- err + } + }() + + for { + select { + case saveData := <-s.endCh: + if saveData.EDDSAPub != nil { + keyBytes, err := json.Marshal(saveData) + if err != nil { + s.ErrCh <- err + return + } + + newVersion := s.GetVersion() + 1 + key := s.composeKey(walletIDWithVersion(s.walletID, newVersion)) + if err := s.kvstore.Put(key, keyBytes); err != nil { + s.ErrCh <- err + return + } + + keyInfo := keyinfo.KeyInfo{ + ParticipantPeerIDs: s.newPeerIDs, + Threshold: s.reshareParams.NewThreshold(), + Version: newVersion, + } + + // Save key info with resharing flag + if err := s.keyinfoStore.Save(s.composeKey(s.walletID), &keyInfo); err != nil { + s.ErrCh <- err + return + } + + // skip for old committee + if saveData.EDDSAPub != nil { + + // Get public key + publicKey := saveData.EDDSAPub + pkX, pkY := publicKey.X(), publicKey.Y() + pk := edwards.PublicKey{ + Curve: tss.Edwards(), + X: pkX, + Y: pkY, + } + + pubKeyBytes := pk.SerializeCompressed() + s.pubkeyBytes = pubKeyBytes + + logger.Info("Generated public key bytes", + "walletID", s.walletID, + "pubKeyBytes", pubKeyBytes) + } + } + done() + err := s.Close() + if err != nil { + logger.Error("Failed to close session", err) + } + return + case msg := <-s.outCh: + // Handle the message + s.handleTssMessage(msg) + } + } +} diff --git a/pkg/mpc/eddsa_rounds.go b/pkg/mpc/eddsa_rounds.go index 01519d0..88864f3 100644 --- a/pkg/mpc/eddsa_rounds.go +++ b/pkg/mpc/eddsa_rounds.go @@ -2,6 +2,7 @@ package mpc import ( "github.com/bnb-chain/tss-lib/v2/eddsa/keygen" + "github.com/bnb-chain/tss-lib/v2/eddsa/resharing" "github.com/bnb-chain/tss-lib/v2/eddsa/signing" "github.com/bnb-chain/tss-lib/v2/tss" "github.com/fystack/mpcium/pkg/common/errors" @@ -16,14 +17,21 @@ type RoundInfo struct { } const ( - EDDSA_KEYGEN1 = "KGRound1Message" - EDDSA_KEYGEN2aUnicast = "KGRound2Message1" - EDDSA_KEYGEN2b = "KGRound2Message2" - EDDSA_KEYSIGN1 = "SignRound1Message" - EDDSA_KEYSIGN2 = "SignRound2Message" - EDDSA_KEYSIGN3 = "SignRound3Message" + EDDSA_KEYGEN1 = "KGRound1Message" + EDDSA_KEYGEN2aUnicast = "KGRound2Message1" + EDDSA_KEYGEN2b = "KGRound2Message2" + EDDSA_KEYSIGN1 = "SignRound1Message" + EDDSA_KEYSIGN2 = "SignRound2Message" + EDDSA_KEYSIGN3 = "SignRound3Message" + EDDSA_RESHARING1 = "DGRound1Message" + EDDSA_RESHARING2 = "DGRound2Message" + EDDSA_RESHARING3aUnicast = "DGRound3Message1" + EDDSA_RESHARING3bUnicast = "DGRound3Message2" + EDDSA_RESHARING4 = "DGRound4Message" + EDDSA_TSSKEYGENROUNDS = 3 EDDSA_TSSKEYSIGNROUNDS = 3 + EDDSA_RESHARINGROUNDS = 4 ) func GetEddsaMsgRound(msg []byte, partyID *tss.PartyID, isBroadcast bool) (RoundInfo, error) { @@ -68,6 +76,36 @@ func GetEddsaMsgRound(msg []byte, partyID *tss.PartyID, isBroadcast bool) (Round RoundMsg: EDDSA_KEYSIGN3, }, nil + case *resharing.DGRound1Message: + return RoundInfo{ + Index: 0, + RoundMsg: EDDSA_RESHARING1, + }, nil + + case *resharing.DGRound2Message: + return RoundInfo{ + Index: 1, + RoundMsg: EDDSA_RESHARING2, + }, nil + + case *resharing.DGRound3Message1: + return RoundInfo{ + Index: 2, + RoundMsg: EDDSA_RESHARING3aUnicast, + }, nil + + case *resharing.DGRound3Message2: + return RoundInfo{ + Index: 3, + RoundMsg: EDDSA_RESHARING3bUnicast, + }, nil + + case *resharing.DGRound4Message: + return RoundInfo{ + Index: 4, + RoundMsg: EDDSA_RESHARING4, + }, nil + default: return RoundInfo{}, errors.New("unknown round") } diff --git a/pkg/mpc/eddsa_signing_session.go b/pkg/mpc/eddsa_signing_session.go index 1c586da..4538730 100644 --- a/pkg/mpc/eddsa_signing_session.go +++ b/pkg/mpc/eddsa_signing_session.go @@ -29,7 +29,7 @@ type eddsaSigningSession struct { networkInternalCode string } -func NewEDDSASigningSession( +func newEDDSASigningSession( walletID string, txID string, networkInternalCode string, @@ -84,11 +84,6 @@ func (s *eddsaSigningSession) Init(tx *big.Int) error { ctx := tss.NewPeerContext(s.partyIDs) params := tss.NewParameters(tss.Edwards(), ctx, s.selfPartyID, len(s.partyIDs), s.threshold) - keyData, err := s.kvstore.Get(s.composeKey(s.walletID)) - if err != nil { - return errors.Wrap(err, "Failed to get wallet data from KVStore") - } - keyInfo, err := s.keyinfoStore.Get(s.composeKey(s.walletID)) if err != nil { return errors.Wrap(err, "Failed to get key info data") @@ -110,6 +105,11 @@ func (s *eddsaSigningSession) Init(tx *big.Int) error { } logger.Info("Have enough participants to sign", "participants", s.participantPeerIDs) + key := s.composeKey(walletIDWithVersion(s.walletID, keyInfo.Version)) + keyData, err := s.kvstore.Get(key) + if err != nil { + return errors.Wrap(err, "Failed to get wallet data from KVStore") + } // Check if all the participants of the key are present var data keygen.LocalPartySaveData err = json.Unmarshal(keyData, &data) @@ -119,6 +119,7 @@ func (s *eddsaSigningSession) Init(tx *big.Int) error { s.party = signing.NewLocalParty(tx, params, data, s.outCh, s.endCh) s.data = &data + s.version = keyInfo.Version s.tx = tx logger.Info("Initialized sigining session successfully!") return nil @@ -152,7 +153,7 @@ func (s *eddsaSigningSession) Sign(onSuccess func(data []byte)) { } r := event.SigningResultEvent{ - ResultType: event.SigningResultTypeSuccess, + ResultType: event.ResultTypeSuccess, NetworkInternalCode: s.networkInternalCode, WalletID: s.walletID, TxID: s.txID, diff --git a/pkg/mpc/key_type.go b/pkg/mpc/key_type.go index 756efa8..96add07 100644 --- a/pkg/mpc/key_type.go +++ b/pkg/mpc/key_type.go @@ -4,5 +4,5 @@ type KeyType string const ( KeyTypeSecp256k1 KeyType = "secp256k1" - KeyTypeEd25519 = "ed25519" + KeyTypeEd25519 KeyType = "ed25519" ) diff --git a/pkg/mpc/node.go b/pkg/mpc/node.go index cbdb5bc..720741d 100644 --- a/pkg/mpc/node.go +++ b/pkg/mpc/node.go @@ -2,8 +2,11 @@ package mpc import ( "bytes" + "encoding/json" "fmt" "math/big" + "slices" + "strconv" "time" "github.com/bnb-chain/tss-lib/v2/ecdsa/keygen" @@ -18,8 +21,12 @@ import ( ) const ( - PurposeKeygen string = "keygen" - PurposeSign string = "sign" + PurposeKeygen string = "keygen" + PurposeSign string = "sign" + PurposeReshare string = "reshare" + + BackwardCompatibleVersion int = 0 + DefaultVersion int = 1 ) type ID string @@ -32,19 +39,13 @@ type Node struct { direct messaging.DirectMessaging kvstore kvstore.KVStore keyinfoStore keyinfo.Store - ecdsaPreParams *keygen.LocalPreParams + ecdsaPreParams []*keygen.LocalPreParams identityStore identity.Store peerRegistry PeerRegistry } -func CreatePartyID(nodeID string, label string) *tss.PartyID { - partyID := uuid.NewString() - key := big.NewInt(0).SetBytes([]byte(nodeID)) - return tss.NewPartyID(partyID, label, key) -} - -func PartyIDToNodeID(partyID *tss.PartyID) string { +func PartyIDToRoutingDest(partyID *tss.PartyID) string { return string(partyID.KeyInt().Bytes()) } @@ -67,26 +68,23 @@ func NewNode( identityStore identity.Store, ) *Node { start := time.Now() - preParams, err := keygen.GeneratePreParams(5 * time.Minute) - if err != nil { - logger.Fatal("Generate pre params failed", err) - } elapsed := time.Since(start) logger.Info("Starting new node, preparams is generated successfully!", "elapsed", elapsed.Milliseconds()) - go peerRegistry.WatchPeersReady() - - return &Node{ - nodeID: nodeID, - peerIDs: peerIDs, - pubSub: pubSub, - direct: direct, - kvstore: kvstore, - keyinfoStore: keyinfoStore, - ecdsaPreParams: preParams, - peerRegistry: peerRegistry, - identityStore: identityStore, + node := &Node{ + nodeID: nodeID, + peerIDs: peerIDs, + pubSub: pubSub, + direct: direct, + kvstore: kvstore, + keyinfoStore: keyinfoStore, + peerRegistry: peerRegistry, + identityStore: identityStore, } + node.ecdsaPreParams = node.generatePreParams() + + go peerRegistry.WatchPeersReady() + return node } func (p *Node) ID() string { @@ -97,25 +95,34 @@ func (p *Node) CreateKeyGenSession( sessionType SessionType, walletID string, threshold int, - successQueue messaging.MessageQueue, + resultQueue messaging.MessageQueue, ) (KeyGenSession, error) { if !p.peerRegistry.ArePeersReady() { - return nil, fmt.Errorf("Not enough peers to create gen session! Expected %d, got %d", threshold+1, p.peerRegistry.GetReadyPeersCount()) + return nil, fmt.Errorf( + "Not enough peers to create gen session! Expected %d, got %d", + p.peerRegistry.GetTotalPeersCount(), + p.peerRegistry.GetReadyPeersCount(), + ) + } + + keyInfo, _ := p.getKeyInfo(sessionType, walletID) + if keyInfo != nil { + return nil, fmt.Errorf("Key already exists: %s", walletID) } switch sessionType { case SessionTypeECDSA: - return p.createECDSAKeyGenSession(walletID, threshold, successQueue) + return p.createECDSAKeyGenSession(walletID, threshold, DefaultVersion, resultQueue) case SessionTypeEDDSA: - return p.createEDDSAKeyGenSession(walletID, threshold, successQueue) + return p.createEDDSAKeyGenSession(walletID, threshold, DefaultVersion, resultQueue) default: return nil, fmt.Errorf("Unknown session type: %s", sessionType) } } -func (p *Node) createECDSAKeyGenSession(walletID string, threshold int, successQueue messaging.MessageQueue) (KeyGenSession, error) { +func (p *Node) createECDSAKeyGenSession(walletID string, threshold int, version int, resultQueue messaging.MessageQueue) (KeyGenSession, error) { readyPeerIDs := p.peerRegistry.GetReadyPeersIncludeSelf() - selfPartyID, allPartyIDs := p.generatePartyIDs(PurposeKeygen, readyPeerIDs) + selfPartyID, allPartyIDs := p.generatePartyIDs(PurposeKeygen, readyPeerIDs, version) session := newECDSAKeygenSession( walletID, p.pubSub, @@ -124,18 +131,18 @@ func (p *Node) createECDSAKeyGenSession(walletID string, threshold int, successQ selfPartyID, allPartyIDs, threshold, - p.ecdsaPreParams, + p.ecdsaPreParams[0], p.kvstore, p.keyinfoStore, - successQueue, + resultQueue, p.identityStore, ) return session, nil } -func (p *Node) createEDDSAKeyGenSession(walletID string, threshold int, successQueue messaging.MessageQueue) (KeyGenSession, error) { +func (p *Node) createEDDSAKeyGenSession(walletID string, threshold int, version int, resultQueue messaging.MessageQueue) (KeyGenSession, error) { readyPeerIDs := p.peerRegistry.GetReadyPeersIncludeSelf() - selfPartyID, allPartyIDs := p.generatePartyIDs(PurposeKeygen, readyPeerIDs) + selfPartyID, allPartyIDs := p.generatePartyIDs(PurposeKeygen, readyPeerIDs, version) session := newEDDSAKeygenSession( walletID, p.pubSub, @@ -146,7 +153,7 @@ func (p *Node) createEDDSAKeyGenSession(walletID string, threshold int, successQ threshold, p.kvstore, p.keyinfoStore, - successQueue, + resultQueue, p.identityStore, ) return session, nil @@ -157,11 +164,36 @@ func (p *Node) CreateSigningSession( walletID string, txID string, networkInternalCode string, - threshold int, resultQueue messaging.MessageQueue, ) (SigningSession, error) { - readyPeerIDs := p.peerRegistry.GetReadyPeersIncludeSelf() - selfPartyID, allPartyIDs := p.generatePartyIDs(PurposeKeygen, readyPeerIDs) + version := p.getVersion(sessionType, walletID) + keyInfo, err := p.getKeyInfo(sessionType, walletID) + if err != nil { + return nil, err + } + + readyPeers := p.peerRegistry.GetReadyPeersIncludeSelf() + readyParticipantIDs := p.getReadyPeersForSession(keyInfo, readyPeers) + + logger.Info("Creating signing session", + "type", sessionType, + "readyPeers", readyPeers, + "participantPeerIDs", keyInfo.ParticipantPeerIDs, + "ready count", len(readyParticipantIDs), + "min ready", keyInfo.Threshold+1, + "version", version, + ) + + if len(readyParticipantIDs) < keyInfo.Threshold+1 { + return nil, fmt.Errorf("not enough peers to create signing session! expected %d, got %d", keyInfo.Threshold+1, len(readyParticipantIDs)) + } + + if err := p.ensureNodeIsParticipant(keyInfo); err != nil { + return nil, err + } + + selfPartyID, allPartyIDs := p.generatePartyIDs(PurposeKeygen, readyParticipantIDs, version) + switch sessionType { case SessionTypeECDSA: return newECDSASigningSession( @@ -170,27 +202,28 @@ func (p *Node) CreateSigningSession( networkInternalCode, p.pubSub, p.direct, - readyPeerIDs, + readyParticipantIDs, selfPartyID, allPartyIDs, - threshold, - p.ecdsaPreParams, + keyInfo.Threshold, + p.ecdsaPreParams[0], p.kvstore, p.keyinfoStore, resultQueue, p.identityStore, ), nil + case SessionTypeEDDSA: - return NewEDDSASigningSession( + return newEDDSASigningSession( walletID, txID, networkInternalCode, p.pubSub, p.direct, - readyPeerIDs, + readyParticipantIDs, selfPartyID, allPartyIDs, - threshold, + keyInfo.Threshold, p.kvstore, p.keyinfoStore, resultQueue, @@ -198,22 +231,225 @@ func (p *Node) CreateSigningSession( ), nil } - return nil, errors.New("Unknown session type") + return nil, errors.New("unknown session type") +} + +func (p *Node) getKeyInfo(sessionType SessionType, walletID string) (*keyinfo.KeyInfo, error) { + var keyID string + switch sessionType { + case SessionTypeECDSA: + keyID = fmt.Sprintf("ecdsa:%s", walletID) + case SessionTypeEDDSA: + keyID = fmt.Sprintf("eddsa:%s", walletID) + default: + return nil, errors.New("unsupported session type") + } + return p.keyinfoStore.Get(keyID) +} + +func (p *Node) getReadyPeersForSession(keyInfo *keyinfo.KeyInfo, readyPeers []string) []string { + // Ensure all participants are ready + readyParticipantIDs := make([]string, 0, len(keyInfo.ParticipantPeerIDs)) + for _, peerID := range keyInfo.ParticipantPeerIDs { + if slices.Contains(readyPeers, peerID) { + readyParticipantIDs = append(readyParticipantIDs, peerID) + } + } + + return readyParticipantIDs } -func (p *Node) generatePartyIDs(purpose string, readyPeerIDs []string) (self *tss.PartyID, all []*tss.PartyID) { +func (p *Node) ensureNodeIsParticipant(keyInfo *keyinfo.KeyInfo) error { + if !slices.Contains(keyInfo.ParticipantPeerIDs, p.nodeID) { + return fmt.Errorf("this node %s is not in the participant list", p.nodeID) + } + return nil +} + +func (p *Node) CreateReshareSession( + sessionType SessionType, + walletID string, + oldThreshold int, + newThreshold int, + newPeerIDs []string, + isNewPeer bool, + resultQueue messaging.MessageQueue, +) (ReshareSession, error) { + // 1. Check peer readiness + count := p.peerRegistry.GetReadyPeersCount() + if count < int64(newThreshold)+1 { + return nil, fmt.Errorf( + "not enough peers to create reshare session! Expected at least %d, got %d", + newThreshold+1, + count, + ) + } + + if len(newPeerIDs) < newThreshold+1 { + return nil, fmt.Errorf("new peer list is smaller than required t+1") + } + + // 2. Make sure all new peers are ready + readyNewPeerIDs := p.peerRegistry.GetReadyPeersIncludeSelf() + for _, peerID := range newPeerIDs { + if !slices.Contains(readyNewPeerIDs, peerID) { + return nil, fmt.Errorf("new peer %s is not ready", peerID) + } + } + + // 3. Load old key info + keyPrefix, err := sessionKeyPrefix(sessionType) + if err != nil { + return nil, fmt.Errorf("failed to get session key prefix: %w", err) + } + keyInfoKey := fmt.Sprintf("%s:%s", keyPrefix, walletID) + oldKeyInfo, err := p.keyinfoStore.Get(keyInfoKey) + if err != nil { + return nil, fmt.Errorf("failed to get old key info: %w", err) + } + + readyPeers := p.peerRegistry.GetReadyPeersIncludeSelf() + readyOldParticipantIDs := p.getReadyPeersForSession(oldKeyInfo, readyPeers) + + isInOldCommittee := slices.Contains(oldKeyInfo.ParticipantPeerIDs, p.nodeID) + isInNewCommittee := slices.Contains(newPeerIDs, p.nodeID) + + // 4. Skip if not relevant + if isNewPeer && !isInNewCommittee { + logger.Info("Skipping new session: node is not in new committee", "walletID", walletID, "nodeID", p.nodeID) + return nil, nil + } + if !isNewPeer && !isInOldCommittee { + logger.Info("Skipping old session: node is not in old committee", "walletID", walletID, "nodeID", p.nodeID) + return nil, nil + } + + logger.Info("Creating resharing session", + "type", sessionType, + "readyPeers", readyPeers, + "participantPeerIDs", oldKeyInfo.ParticipantPeerIDs, + "ready count", len(readyOldParticipantIDs), + "min ready", oldKeyInfo.Threshold+1, + "version", oldKeyInfo.Version, + ) + + if len(readyOldParticipantIDs) < oldKeyInfo.Threshold+1 { + return nil, fmt.Errorf("not enough peers to create resharing session! expected %d, got %d", oldKeyInfo.Threshold+1, len(readyOldParticipantIDs)) + } + + if err := p.ensureNodeIsParticipant(oldKeyInfo); err != nil { + return nil, err + } + + // 5. Generate party IDs + version := p.getVersion(sessionType, walletID) + oldSelf, oldAllPartyIDs := p.generatePartyIDs(PurposeKeygen, readyOldParticipantIDs, version) + newSelf, newAllPartyIDs := p.generatePartyIDs(PurposeReshare, newPeerIDs, version+1) + + // 6. Pick identity and call session constructor var selfPartyID *tss.PartyID - partyIDs := make([]*tss.PartyID, len(readyPeerIDs)) - for i, peerID := range readyPeerIDs { - if peerID == p.nodeID { - selfPartyID = CreatePartyID(peerID, purpose) - partyIDs[i] = selfPartyID + var participantPeerIDs []string + if isNewPeer { + selfPartyID = newSelf + participantPeerIDs = newPeerIDs + } else { + selfPartyID = oldSelf + participantPeerIDs = readyOldParticipantIDs + } + + switch sessionType { + case SessionTypeECDSA: + preParams := p.ecdsaPreParams[0] + if isNewPeer { + preParams = p.ecdsaPreParams[1] + participantPeerIDs = newPeerIDs } else { - partyIDs[i] = CreatePartyID(peerID, purpose) + participantPeerIDs = oldKeyInfo.ParticipantPeerIDs + } + + return NewECDSAReshareSession( + walletID, + p.pubSub, + p.direct, + participantPeerIDs, + selfPartyID, + oldAllPartyIDs, + newAllPartyIDs, + oldThreshold, + newThreshold, + preParams, + p.kvstore, + p.keyinfoStore, + resultQueue, + p.identityStore, + newPeerIDs, + isNewPeer, + oldKeyInfo.Version, + ), nil + + case SessionTypeEDDSA: + return NewEDDSAReshareSession( + walletID, + p.pubSub, + p.direct, + participantPeerIDs, + selfPartyID, + oldAllPartyIDs, + newAllPartyIDs, + oldThreshold, + newThreshold, + p.kvstore, + p.keyinfoStore, + resultQueue, + p.identityStore, + newPeerIDs, + isNewPeer, + oldKeyInfo.Version, + ), nil + + default: + return nil, fmt.Errorf("unsupported session type: %v", sessionType) + } +} + +// generatePartyIDs generates the party IDs for the given purpose and version +// It returns the self party ID and all party IDs +// It also sorts the party IDs in place +func (n *Node) generatePartyIDs( + label string, + readyPeerIDs []string, + version int, +) (self *tss.PartyID, all []*tss.PartyID) { + // Pre-allocate slice with exact size needed + partyIDs := make([]*tss.PartyID, 0, len(readyPeerIDs)) + + // Create all party IDs in one pass + for _, peerID := range readyPeerIDs { + partyID := createPartyID(peerID, label, version) + if peerID == n.nodeID { + self = partyID } + partyIDs = append(partyIDs, partyID) } - allPartyIDs := tss.SortPartyIDs(partyIDs, 0) - return selfPartyID, allPartyIDs + + // Sort party IDs in place + all = tss.SortPartyIDs(partyIDs, 0) + return +} + +// createPartyID creates a new party ID for the given node ID, label and version +// It returns the party ID: random string +// Moniker: for routing messages +// Key: for mpc internal use (need persistent storage) +func createPartyID(nodeID string, label string, version int) *tss.PartyID { + partyID := uuid.NewString() + var key *big.Int + if version == BackwardCompatibleVersion { + key = big.NewInt(0).SetBytes([]byte(nodeID)) + } else { + key = big.NewInt(0).SetBytes([]byte(nodeID + ":" + strconv.Itoa(version))) + } + return tss.NewPartyID(partyID, label, key) } func (p *Node) Close() { @@ -222,3 +458,66 @@ func (p *Node) Close() { logger.Error("Resign failed", err) } } + +func (p *Node) generatePreParams() []*keygen.LocalPreParams { + start := time.Now() + // Try to load from kvstore + preParams := make([]*keygen.LocalPreParams, 2) + for i := 0; i < 2; i++ { + key := fmt.Sprintf("pre_params_%d", i) + val, err := p.kvstore.Get(key) + if err == nil && val != nil { + preParams[i] = &keygen.LocalPreParams{} + err = json.Unmarshal(val, preParams[i]) + if err != nil { + logger.Fatal("Unmarshal pre params failed", err) + } + continue + } + // Not found, generate and save + params, err := keygen.GeneratePreParams(5 * time.Minute) + if err != nil { + logger.Fatal("Generate pre params failed", err) + } + bytes, err := json.Marshal(params) + if err != nil { + logger.Fatal("Marshal pre params failed", err) + } + err = p.kvstore.Put(key, bytes) + if err != nil { + logger.Fatal("Save pre params failed", err) + } + preParams[i] = params + } + logger.Info("Generate pre params successfully!", "elapsed", time.Since(start).Milliseconds()) + return preParams +} + +func (p *Node) getVersion(sessionType SessionType, walletID string) int { + var composeKey string + switch sessionType { + case SessionTypeECDSA: + composeKey = fmt.Sprintf("ecdsa:%s", walletID) + case SessionTypeEDDSA: + composeKey = fmt.Sprintf("eddsa:%s", walletID) + default: + logger.Fatal("Unknown session type", errors.New("Unknown session type")) + } + keyinfo, err := p.keyinfoStore.Get(composeKey) + if err != nil { + logger.Error("Get keyinfo failed", err, "walletID", walletID) + return DefaultVersion + } + return keyinfo.Version +} + +func sessionKeyPrefix(sessionType SessionType) (string, error) { + switch sessionType { + case SessionTypeECDSA: + return "ecdsa", nil + case SessionTypeEDDSA: + return "eddsa", nil + default: + return "", fmt.Errorf("unsupported session type: %v", sessionType) + } +} diff --git a/pkg/mpc/node_test.go b/pkg/mpc/node_test.go index dafacda..3e9f889 100644 --- a/pkg/mpc/node_test.go +++ b/pkg/mpc/node_test.go @@ -35,7 +35,7 @@ import ( // } func TestPartyIDToNodeID(t *testing.T) { - partyID := CreatePartyID("4d8cb873-dc86-4776-b6f6-cf5c668f6468", "keygen") - nodeID := PartyIDToNodeID(partyID) - assert.Equal(t, nodeID, "4d8cb873-dc86-4776-b6f6-cf5c668f6468", "NodeID should be equal") + partyID := createPartyID("4d8cb873-dc86-4776-b6f6-cf5c668f6468", "keygen", 1) + nodeID := PartyIDToRoutingDest(partyID) + assert.Equal(t, nodeID, "4d8cb873-dc86-4776-b6f6-cf5c668f6468:1", "NodeID should be equal") } diff --git a/pkg/mpc/session.go b/pkg/mpc/session.go index e307c3d..9b73e71 100644 --- a/pkg/mpc/session.go +++ b/pkg/mpc/session.go @@ -1,8 +1,8 @@ package mpc import ( + "encoding/json" "fmt" - "strings" "sync" "github.com/bnb-chain/tss-lib/v2/ecdsa/keygen" @@ -20,9 +20,11 @@ import ( type SessionType string const ( - TypeGenerateWalletSuccess = "mpc.mpc_keygen_success.%s" - SessionTypeECDSA SessionType = "session_ecdsa" - SessionTypeEDDSA SessionType = "session_eddsa" + TypeGenerateWalletResultFmt = "mpc.mpc_keygen_result.%s" + TypeReshareWalletResultFmt = "mpc.mpc_reshare_result.%s" + + SessionTypeECDSA SessionType = "session_ecdsa" + SessionTypeEDDSA SessionType = "session_eddsa" ) var ( @@ -53,6 +55,7 @@ type session struct { outCh chan tss.Message ErrCh chan error party tss.Party + version int // preParams is nil for EDDSA session preParams *keygen.LocalPreParams @@ -103,7 +106,11 @@ func (s *session) handleTssMessage(keyshare tss.Message) { s.ErrCh <- fmt.Errorf("failed to marshal tss message: %w", err) return } - + toIDs := make([]string, len(routing.To)) + for i, id := range routing.To { + toIDs[i] = id.String() + } + logger.Debug(fmt.Sprintf("%s Sending message", s.sessionType), "from", s.selfPartyID.String(), "to", toIDs, "isBroadcast", routing.IsBroadcast) if routing.IsBroadcast && len(routing.To) == 0 { err := s.pubSub.Publish(s.topicComposer.ComposeBroadcastTopic(), msg) if err != nil { @@ -112,11 +119,12 @@ func (s *session) handleTssMessage(keyshare tss.Message) { } } else { for _, to := range routing.To { - nodeID := PartyIDToNodeID(to) + nodeID := PartyIDToRoutingDest(to) topic := s.topicComposer.ComposeDirectTopic(nodeID) err := s.direct.Send(topic, msg) if err != nil { - s.ErrCh <- fmt.Errorf("Failed to send direct message to %s: %w", topic, err) + logger.Error("Failed to send direct message to", err, "topic", topic) + s.ErrCh <- fmt.Errorf("Failed to send direct message to %s", topic) } } @@ -146,10 +154,15 @@ func (s *session) receiveTssMessage(rawMsg []byte) { s.ErrCh <- errors.Wrap(err, "Broken TSS Share") return } - - logger.Debug(fmt.Sprintf("%s Received message", s.sessionType), "from", msg.From.String(), "to", strings.Join(toIDs, ","), "isBroadcast", msg.IsBroadcast, "round", round.RoundMsg) + logger.Debug("Received message", "round", round.RoundMsg, "isBroadcast", msg.IsBroadcast, "to", toIDs, "from", msg.From.String(), "self", s.selfPartyID.String()) isBroadcast := msg.IsBroadcast && len(msg.To) == 0 - isToSelf := len(msg.To) == 1 && ComparePartyIDs(msg.To[0], s.selfPartyID) + var isToSelf bool + for _, to := range msg.To { + if ComparePartyIDs(to, s.selfPartyID) { + isToSelf = true + break + } + } if isBroadcast || isToSelf { s.mu.Lock() @@ -191,7 +204,7 @@ func (s *session) ListenToIncomingMessageAsync() { s.broadcastSub = sub }() - nodeID := PartyIDToNodeID(s.selfPartyID) + nodeID := PartyIDToRoutingDest(s.selfPartyID) targetID := s.topicComposer.ComposeDirectTopic(nodeID) sub, err := s.direct.Listen(targetID, func(msg []byte) { go s.receiveTssMessage(msg) // async for avoid timeout @@ -222,3 +235,45 @@ func (s *session) GetPubKeyResult() []byte { func (s *session) ErrChan() <-chan error { return s.ErrCh } + +func (s *session) GetVersion() int { + return s.version +} + +// loadOldShareDataGeneric loads the old share data from kvstore with backward compatibility (versioned and unversioned keys) +func (s *session) loadOldShareDataGeneric(walletID string, version int, dest interface{}) error { + var ( + key string + keyData []byte + err error + ) + + // Try versioned key first if version > 0 + if version > 0 { + key = s.composeKey(walletIDWithVersion(walletID, version)) + keyData, err = s.kvstore.Get(key) + } + + // If version == 0 or previous key not found, fall back to unversioned key + if err != nil || version == 0 { + key = s.composeKey(walletID) + keyData, err = s.kvstore.Get(key) + } + + if err != nil { + return fmt.Errorf("failed to get wallet data from KVStore (key=%s): %w", key, err) + } + + if err := json.Unmarshal(keyData, dest); err != nil { + return fmt.Errorf("failed to unmarshal wallet data: %w", err) + } + return nil +} + +// walletIDWithVersion is used to compose the key for the kvstore +func walletIDWithVersion(walletID string, version int) string { + if version > 0 { + return fmt.Sprintf("%s_v%d", walletID, version) + } + return walletID +} diff --git a/pkg/types/initiator_msg.go b/pkg/types/initiator_msg.go index edd0bf4..6684058 100644 --- a/pkg/types/initiator_msg.go +++ b/pkg/types/initiator_msg.go @@ -33,6 +33,15 @@ type SignTxMessage struct { Signature []byte `json:"signature"` } +type ResharingMessage struct { + SessionID string `json:"session_id"` + NodeIDs []string `json:"node_ids"` // new peer IDs + NewThreshold int `json:"new_threshold"` + KeyType KeyType `json:"key_type"` + WalletID string `json:"wallet_id"` + Signature []byte `json:"signature,omitempty"` +} + func (m *SignTxMessage) Raw() ([]byte, error) { // omit the Signature field itself when computing the signed‐over data payload := struct { @@ -70,3 +79,17 @@ func (m *GenerateKeyMessage) Sig() []byte { func (m *GenerateKeyMessage) InitiatorID() string { return m.WalletID } + +func (m *ResharingMessage) Raw() ([]byte, error) { + copy := *m // create a shallow copy + copy.Signature = nil // modify only the copy + return json.Marshal(©) +} + +func (m *ResharingMessage) Sig() []byte { + return m.Signature +} + +func (m *ResharingMessage) InitiatorID() string { + return m.WalletID +}