diff --git a/.gitignore b/.gitignore index 5832d41..14216ed 100644 --- a/.gitignore +++ b/.gitignore @@ -21,3 +21,4 @@ node0 node1 node2 config.yaml +.vscode \ No newline at end of file diff --git a/README.md b/README.md index fd3c791..fcff7c6 100644 --- a/README.md +++ b/README.md @@ -153,41 +153,261 @@ $ mpcium start -n node2 - **Go**: Available in the `pkg/client` directory. Check the `examples` folder for usage samples. - **TypeScript**: Available at [github.com/fystack/mpcium-client-ts](https://github.com/fystack/mpcium-client-ts) -### Client +### Client Usage -```go +Mpcium supports flexible client authentication through a signer interface, allowing you to use either local keys or AWS KMS for signing operations. + +#### Local Signer (Ed25519) +```go import ( - "github.com/fystack/mpcium/client" + "github.com/fystack/mpcium/pkg/client" + "github.com/fystack/mpcium/pkg/event" + "github.com/fystack/mpcium/pkg/types" + "github.com/google/uuid" "github.com/nats-io/nats.go" ) +func main() { + // Connect to NATS + natsConn, err := nats.Connect(natsURL) + if err != nil { + logger.Fatal("Failed to connect to NATS", err) + } + defer natsConn.Close() + + // Create local signer with Ed25519 key + localSigner, err := client.NewLocalSigner(types.EventInitiatorKeyTypeEd25519, client.LocalSignerOptions{ + KeyPath: "./event_initiator.key", + }) + if err != nil { + logger.Fatal("Failed to create local signer", err) + } + + // Create MPC client with signer + mpcClient := client.NewMPCClient(client.Options{ + NatsConn: natsConn, + Signer: localSigner, + }) + + // Handle wallet creation results + err = mpcClient.OnWalletCreationResult(func(event event.KeygenResultEvent) { + logger.Info("Received wallet creation result", "event", event) + }) + if err != nil { + logger.Fatal("Failed to subscribe to wallet-creation results", err) + } + + // Create a wallet + walletID := uuid.New().String() + if err := mpcClient.CreateWallet(walletID); err != nil { + logger.Fatal("CreateWallet failed", err) + } + logger.Info("CreateWallet sent, awaiting result...", "walletID", walletID) +} +``` + +#### Local Signer (P256 with encrypted key) + +```go +// Create local signer with P256 key (encrypted with age) +localSigner, err := client.NewLocalSigner(types.EventInitiatorKeyTypeP256, client.LocalSignerOptions{ + KeyPath: "./event_initiator_p256.key.age", + Encrypted: true, + Password: "your-encryption-password", +}) +``` + +#### AWS KMS Signer + +##### Production (IAM Role-based Authentication) + +For production environments using IAM roles (recommended): + +```go +import ( + "github.com/fystack/mpcium/pkg/client" + "github.com/fystack/mpcium/pkg/types" +) + +func main() { + // KMS signer with role-based authentication (no static credentials) + kmsSigner, err := client.NewKMSSigner(types.EventInitiatorKeyTypeP256, client.KMSSignerOptions{ + Region: "us-east-1", + KeyID: "arn:aws:kms:us-east-1:123456789012:key/12345678-1234-1234-1234-123456789012", + // No AccessKeyID/SecretAccessKey - uses IAM role + }) + if err != nil { + logger.Fatal("Failed to create KMS signer", err) + } + + mpcClient := client.NewMPCClient(client.Options{ + NatsConn: natsConn, + Signer: kmsSigner, + }) + // ... rest of the client code +} +``` + +##### Development with Static Credentials + +```go +// KMS signer with static credentials (development only) +kmsSigner, err := client.NewKMSSigner(types.EventInitiatorKeyTypeP256, client.KMSSignerOptions{ + Region: "us-west-2", + KeyID: "12345678-1234-1234-1234-123456789012", + AccessKeyID: "AKIA...", + SecretAccessKey: "...", +}) +``` + +##### LocalStack Development + +```go +// KMS signer with LocalStack for local development +kmsSigner, err := client.NewKMSSigner(types.EventInitiatorKeyTypeP256, client.KMSSignerOptions{ + Region: "us-east-1", + KeyID: "48e76117-fd08-4dc0-bd10-b1c7d01de748", + EndpointURL: "http://localhost:4566", // LocalStack endpoint + AccessKeyID: "test", // LocalStack dummy credentials + SecretAccessKey: "test", +}) +``` + +##### AWS Cloud Config Variations -func main () { - natsConn, err := nats.Connect(natsURL) - if err != nil { - logger.Fatal("Failed to connect to NATS", err) - } - defer natsConn.Close() - mpcClient := client.NewMPCClient(client.Options{ - NatsConn: natsConn, - KeyPath: "./event_initiator.key", - }) - err = mpcClient.OnWalletCreationResult(func(event event.KeygenSuccessEvent) { - logger.Info("Received wallet creation result", "event", event) - }) - if err != nil { - logger.Fatal("Failed to subscribe to wallet-creation results", err) - } - - walletID := uuid.New().String() - if err := mpcClient.CreateWallet(walletID); err != nil { - logger.Fatal("CreateWallet failed", err) - } - logger.Info("CreateWallet sent, awaiting result...", "walletID", walletID) +```go +// Different regions and key formats +configs := []client.KMSSignerOptions{ + // Key ID only + { + Region: "eu-west-1", + KeyID: "12345678-1234-1234-1234-123456789012", + }, + // Full ARN + { + Region: "ap-southeast-1", + KeyID: "arn:aws:kms:ap-southeast-1:123456789012:key/12345678-1234-1234-1234-123456789012", + }, + // Key alias + { + Region: "us-east-2", + KeyID: "alias/mpcium-signing-key", + }, } ``` +**Note**: AWS KMS only supports P256 (ECDSA) keys, not Ed25519. If you need Ed25519, use the local signer. + +## Test with AWS KMS (LocalStack) + +For local development and testing with AWS KMS functionality, you can use LocalStack to simulate AWS KMS services. + +### Setup LocalStack + +1. **Install and start LocalStack:** + ```bash + # Using Docker + docker run -d \ + -p 4566:4566 \ + -p 4510-4559:4510-4559 \ + localstack/localstack + + # Or using LocalStack CLI + pip install localstack + localstack start + ``` + +2. **Configure AWS CLI for LocalStack:** + ```bash + aws configure set aws_access_key_id test + aws configure set aws_secret_access_key test + aws configure set region us-east-1 + ``` + +### Create P256 Key in LocalStack + +1. **Create a P256 keypair in AWS KMS:** + ```bash + aws kms create-key \ + --endpoint-url=http://localhost:4566 \ + --description "Test P-256 keypair for Mpcium" \ + --key-usage SIGN_VERIFY \ + --customer-master-key-spec ECC_NIST_P256 + ``` + + Expected response: + ```json + { + "KeyMetadata": { + "AWSAccountId": "000000000000", + "KeyId": "330a9df7-4fd9-4e86-bfc5-f360b4c4be39", + "Arn": "arn:aws:kms:us-east-1:000000000000:key/330a9df7-4fd9-4e86-bfc5-f360b4c4be39", + "CreationDate": "2025-08-28T16:42:18.487655+07:00", + "Enabled": true, + "Description": "Test P-256 keypair for Mpcium", + "KeyUsage": "SIGN_VERIFY", + "KeyState": "Enabled", + "Origin": "AWS_KMS", + "KeyManager": "CUSTOMER", + "CustomerMasterKeySpec": "ECC_NIST_P256", + "KeySpec": "ECC_NIST_P256", + "SigningAlgorithms": [ + "ECDSA_SHA_256" + ], + "MultiRegion": false + } + } + ``` + +2. **Get the public key (save the KeyId from step 1):** + ```bash + export KMS_KEY_ID="330a9df7-4fd9-4e86-bfc5-f360b4c4be39" # Replace with your KeyId + + aws kms get-public-key \ + --endpoint-url=http://localhost:4566 \ + --key-id $KMS_KEY_ID \ + --query PublicKey \ + --output text | base64 -d | xxd -p -c 256 + ``` + + Expected response (hex-encoded public key): + ``` + 3059301306072a8648ce3d020106082a8648ce3d030107034200042b7539fc51123c3ba53c71e244be71d2d3138cbed4909fa259b924b56c92148cadd410cf98b789269d7f672c3ba978e99fc1f01c87daee97292d3666357738fd + ``` + +### Configure Mpcium for LocalStack KMS + +Update your `config.yaml` file with the KMS public key and algorithm: + +```yaml +# MPC Configuration +mpc_threshold: 2 +event_initiator_pubkey: "3059301306072a8648ce3d020106082a8648ce3d030107034200042b7539fc51123c3ba53c71e244be71d2d3138cbed4909fa259b924b56c92148cadd410cf98b789269d7f672c3ba978e99fc1f01c87daee97292d3666357738fd" +event_initiator_algorithm: "p256" + +# Other configuration... +nats: + url: "nats://localhost:4222" +consul: + address: "localhost:8500" +``` + +### Test KMS Integration + +Run the KMS example: + +```bash +# Run the KMS example directly +go run examples/generate/kms/main.go -n 1 +``` + +The example will: +1. Connect to LocalStack KMS endpoint +2. Load the P256 public key from KMS +3. Use KMS for signing wallet creation events +4. Generate wallets using the MPC cluster + ### Testing ## 1. Unit tests diff --git a/cmd/mpcium-cli/generate-initiator.go b/cmd/mpcium-cli/generate-initiator.go index de87e57..4e8af7e 100644 --- a/cmd/mpcium-cli/generate-initiator.go +++ b/cmd/mpcium-cli/generate-initiator.go @@ -2,25 +2,26 @@ package main import ( "context" - "crypto/ed25519" - "crypto/rand" - "encoding/hex" "encoding/json" "fmt" "os" "os/user" "path/filepath" "runtime" + "slices" "time" "filippo.io/age" "github.com/fystack/mpcium/pkg/common/pathutil" + "github.com/fystack/mpcium/pkg/encryption" + "github.com/fystack/mpcium/pkg/types" "github.com/urfave/cli/v3" ) // Identity struct to store node metadata type InitiatorIdentity struct { NodeName string `json:"node_name"` + Algorithm string `json:"algorithm,omitempty"` PublicKey string `json:"public_key"` CreatedAt string `json:"created_at"` CreatedBy string `json:"created_by"` @@ -33,6 +34,22 @@ func generateInitiatorIdentity(ctx context.Context, c *cli.Command) error { outputDir := c.String("output-dir") encrypt := c.Bool("encrypt") overwrite := c.Bool("overwrite") + algorithm := c.String("algorithm") + + if algorithm == "" { + algorithm = string(types.EventInitiatorKeyTypeEd25519) + } + + if !slices.Contains( + []string{string(types.EventInitiatorKeyTypeEd25519), string(types.EventInitiatorKeyTypeP256)}, + algorithm, + ) { + return fmt.Errorf("invalid algorithm: %s. Must be %s or %s", + algorithm, + types.EventInitiatorKeyTypeEd25519, + types.EventInitiatorKeyTypeP256, + ) + } // Create output directory if it doesn't exist if err := os.MkdirAll(outputDir, 0750); err != nil { @@ -46,7 +63,10 @@ func generateInitiatorIdentity(ctx context.Context, c *cli.Command) error { // Check for existing identity file if _, err := os.Stat(identityPath); err == nil && !overwrite { - return fmt.Errorf("identity file already exists: %s (use --overwrite to force)", identityPath) + return fmt.Errorf( + "identity file already exists: %s (use --overwrite to force)", + identityPath, + ) } // Check for existing key files @@ -56,19 +76,26 @@ func generateInitiatorIdentity(ctx context.Context, c *cli.Command) error { if encrypt { if _, err := os.Stat(encKeyPath); err == nil && !overwrite { - return fmt.Errorf("encrypted key file already exists: %s (use --overwrite to force)", encKeyPath) + return fmt.Errorf( + "encrypted key file already exists: %s (use --overwrite to force)", + encKeyPath, + ) } } - // Generate Ed25519 keypair - pubKey, privKeyFull, err := ed25519.GenerateKey(rand.Reader) - if err != nil { - return fmt.Errorf("failed to generate Ed25519 keypair: %w", err) + // Generate keys based on algorithm + var keyData encryption.KeyData + var err error + + if algorithm == string(types.EventInitiatorKeyTypeEd25519) { + keyData, err = encryption.GenerateEd25519Keys() + } else if algorithm == string(types.EventInitiatorKeyTypeP256) { + keyData, err = encryption.GenerateP256Keys() } - // Extract 32-byte seed - privKeySeed := privKeyFull.Seed() - privHex := hex.EncodeToString(privKeySeed) + if err != nil { + return fmt.Errorf("failed to generate %s keys: %w", algorithm, err) + } // Get current user currentUser, err := user.Current() @@ -85,7 +112,8 @@ func generateInitiatorIdentity(ctx context.Context, c *cli.Command) error { // Create Identity object identity := InitiatorIdentity{ NodeName: nodeName, - PublicKey: hex.EncodeToString(pubKey), + Algorithm: algorithm, + PublicKey: keyData.PublicKeyHex, CreatedAt: time.Now().UTC().Format(time.RFC3339), CreatedBy: currentUser.Username, MachineOS: runtime.GOOS, @@ -136,7 +164,7 @@ func generateInitiatorIdentity(ctx context.Context, c *cli.Command) error { } // Write the encrypted private key - if _, err := identityWriter.Write([]byte(privHex)); err != nil { + if _, err := identityWriter.Write([]byte(keyData.PrivateKeyHex)); err != nil { return fmt.Errorf("failed to write encrypted private key: %w", err) } @@ -152,7 +180,7 @@ func generateInitiatorIdentity(ctx context.Context, c *cli.Command) error { fmt.Println("WARNING: You are generating the private key without encryption.") fmt.Println("This is less secure. Consider using --encrypt flag for better security.") - if err := os.WriteFile(keyPath, []byte(privHex), 0600); err != nil { + if err := os.WriteFile(keyPath, []byte(keyData.PrivateKeyHex), 0600); err != nil { return fmt.Errorf("failed to save private key: %w", err) } } diff --git a/cmd/mpcium-cli/main.go b/cmd/mpcium-cli/main.go index 910d261..8c9d0ad 100644 --- a/cmd/mpcium-cli/main.go +++ b/cmd/mpcium-cli/main.go @@ -122,6 +122,12 @@ func main() { Value: false, Usage: "Overwrite identity files if they already exist", }, + &cli.StringFlag{ + Name: "algorithm", + Aliases: []string{"a"}, + Value: "ed25519", + Usage: "Algorithm to use for key generation (ed25519,p256)", + }, }, Action: generateInitiatorIdentity, }, diff --git a/config.yaml.template b/config.yaml.template index 694c90c..35be692 100644 --- a/config.yaml.template +++ b/config.yaml.template @@ -6,6 +6,7 @@ consul: mpc_threshold: 2 environment: development badger_password: "F))ysJp?E]ol&I;^" +event_initiator_algorithm: "ed25519" # or "ed25519", default: ed25519 event_initiator_pubkey: "event_initiator_pubkey" db_path: "." backup_enabled: true diff --git a/e2e/base_test.go b/e2e/base_test.go index 6d38588..5c209c4 100644 --- a/e2e/base_test.go +++ b/e2e/base_test.go @@ -18,6 +18,7 @@ import ( "github.com/fystack/mpcium/pkg/client" "github.com/fystack/mpcium/pkg/event" "github.com/fystack/mpcium/pkg/kvstore" + "github.com/fystack/mpcium/pkg/types" "github.com/hashicorp/consul/api" "github.com/nats-io/nats.go" "github.com/stretchr/testify/require" @@ -184,9 +185,17 @@ func (s *E2ETestSuite) SetupMPCClient(t *testing.T) { t.Fatalf("Key file does not exist: %s. Make sure setupTestNodes ran successfully.", keyPath) } + // Create local signer for Ed25519 (default for E2E tests) + localSigner, err := client.NewLocalSigner(types.EventInitiatorKeyTypeEd25519, client.LocalSignerOptions{ + KeyPath: keyPath, + }) + if err != nil { + t.Fatalf("Failed to create local signer: %v", err) + } + mpcClient := client.NewMPCClient(client.Options{ NatsConn: s.natsConn, - KeyPath: keyPath, + Signer: localSigner, }) s.mpcClient = mpcClient t.Log("MPC client created") diff --git a/e2e/go.mod b/e2e/go.mod index 03304de..4112bb3 100644 --- a/e2e/go.mod +++ b/e2e/go.mod @@ -17,6 +17,20 @@ require ( github.com/agl/ed25519 v0.0.0-20200225211852-fd4d107ace12 // indirect github.com/armon/go-metrics v0.4.1 // indirect github.com/avast/retry-go v3.0.0+incompatible // indirect + github.com/aws/aws-sdk-go-v2 v1.38.2 // indirect + github.com/aws/aws-sdk-go-v2/config v1.31.4 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.18.8 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.5 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.5 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.5 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.1 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.5 // indirect + github.com/aws/aws-sdk-go-v2/service/kms v1.45.0 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.28.3 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.34.1 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.38.1 // indirect + github.com/aws/smithy-go v1.23.0 // indirect github.com/bnb-chain/tss-lib/v2 v2.0.2 // indirect github.com/btcsuite/btcd v0.24.2 // indirect github.com/btcsuite/btcd/btcec/v2 v2.3.2 // indirect diff --git a/e2e/go.sum b/e2e/go.sum index 5acaf08..804a2c4 100644 --- a/e2e/go.sum +++ b/e2e/go.sum @@ -17,6 +17,34 @@ github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj github.com/armon/go-radix v1.0.0/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= github.com/avast/retry-go v3.0.0+incompatible h1:4SOWQ7Qs+oroOTQOYnAHqelpCO0biHSxpiH9JdtuBj0= github.com/avast/retry-go v3.0.0+incompatible/go.mod h1:XtSnn+n/sHqQIpZ10K1qAevBhOOCWBLXXy3hyiqqBrY= +github.com/aws/aws-sdk-go-v2 v1.38.2 h1:QUkLO1aTW0yqW95pVzZS0LGFanL71hJ0a49w4TJLMyM= +github.com/aws/aws-sdk-go-v2 v1.38.2/go.mod h1:sDioUELIUO9Znk23YVmIk86/9DOpkbyyVb1i/gUNFXY= +github.com/aws/aws-sdk-go-v2/config v1.31.4 h1:aY2IstXOfjdLtr1lDvxFBk5DpBnHgS5GS3jgR/0BmPw= +github.com/aws/aws-sdk-go-v2/config v1.31.4/go.mod h1:1IAykiegrTp6n+CbZoCpW6kks1I74fEDgl2BPQSkLSU= +github.com/aws/aws-sdk-go-v2/credentials v1.18.8 h1:0FfdP0I9gs/f1rwtEdkcEdsclTEkPB8o6zWUG2Z8+IM= +github.com/aws/aws-sdk-go-v2/credentials v1.18.8/go.mod h1:9UReQ1UmGooX93JKzHyr7PRF3F+p3r+PmRwR7+qHJYA= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.5 h1:ul7hICbZ5Z/Pp9VnLVGUVe7rqYLXCyIiPU7hQ0sRkow= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.5/go.mod h1:5cIWJ0N6Gjj+72Q6l46DeaNtcxXHV42w/Uq3fIfeUl4= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.5 h1:d45S2DqHZOkHu0uLUW92VdBoT5v0hh3EyR+DzMEh3ag= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.5/go.mod h1:G6e/dR2c2huh6JmIo9SXysjuLuDDGWMeYGibfW2ZrXg= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.5 h1:ENhnQOV3SxWHplOqNN1f+uuCNf9n4Y/PKpl6b1WRP0Q= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.5/go.mod h1:csQLMI+odbC0/J+UecSTztG70Dc4aTCOu4GyPNDNpVo= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 h1:bIqFDwgGXXN1Kpp99pDOdKMTTb5d2KyU5X/BZxjOkRo= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3/go.mod h1:H5O/EsxDWyU+LP/V8i5sm8cxoZgc2fdNR9bxlOFrQTo= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.1 h1:oegbebPEMA/1Jny7kvwejowCaHz1FWZAQ94WXFNCyTM= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.1/go.mod h1:kemo5Myr9ac0U9JfSjMo9yHLtw+pECEHsFtJ9tqCEI8= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.5 h1:Cx1M/UUgYu9UCQnIMKaOhkVaFvLy1HneD6T4sS/DlKg= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.5/go.mod h1:fTRNLgrTvPpEzGqc9QkeO4hu/3ng+mdtUbL8shUwXz4= +github.com/aws/aws-sdk-go-v2/service/kms v1.45.0 h1:WYQcp4o0/X+Xd50dSFluzKk3Lee2mP+tP39uMI60s1M= +github.com/aws/aws-sdk-go-v2/service/kms v1.45.0/go.mod h1:le5DfWrncVIxOWL2Q0NnDqvhH8ULiGYgC9iS8BtwcZE= +github.com/aws/aws-sdk-go-v2/service/sso v1.28.3 h1:z6lajFT/qGlLRB/I8V5CCklqSuWZKUkdwRAn9leIkiQ= +github.com/aws/aws-sdk-go-v2/service/sso v1.28.3/go.mod h1:BnyjuIX0l+KXJVl2o9Ki3Zf0M4pA2hQYopFCRUj9ADU= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.34.1 h1:8yI3jK5JZ310S8RpgdZdzwvlvBu3QbG8DP7Be/xJ6yo= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.34.1/go.mod h1:HPzXfFgrLd02lYpcFYdDz5xZs94LOb+lWlvbAGaeMsk= +github.com/aws/aws-sdk-go-v2/service/sts v1.38.1 h1:3kWmIg5iiWPMBJyq/I55Fki5fyfoMtrn/SkUIpxPwHQ= +github.com/aws/aws-sdk-go-v2/service/sts v1.38.1/go.mod h1:yi0b3Qez6YamRVJ+Rbi19IgvjfjPODgVRhkWA6RTMUM= +github.com/aws/smithy-go v1.23.0 h1:8n6I3gXzWJB2DxBDnfxgBaSX6oe0d/t10qGz7OKqMCE= +github.com/aws/smithy-go v1.23.0/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI= github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8= github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= diff --git a/examples/generate/kms/main.go b/examples/generate/kms/main.go new file mode 100644 index 0000000..2ccb177 --- /dev/null +++ b/examples/generate/kms/main.go @@ -0,0 +1,167 @@ +package main + +import ( + "encoding/json" + "flag" + "fmt" + "os" + "os/signal" + "sync" + "sync/atomic" + "syscall" + "time" + + "github.com/fystack/mpcium/pkg/client" + "github.com/fystack/mpcium/pkg/config" + "github.com/fystack/mpcium/pkg/event" + "github.com/fystack/mpcium/pkg/logger" + "github.com/fystack/mpcium/pkg/types" + "github.com/google/uuid" + "github.com/nats-io/nats.go" + "github.com/spf13/viper" +) + +func main() { + const environment = "development" + const awsRegion = "ap-southeast-1" + const kmsKeyID = "48e76117-fd08-4dc0-bd10-b1c7d01de748" + + numWallets := flag.Int("n", 1, "Number of wallets to generate") + + flag.Parse() + + config.InitViperConfig() + logger.Init(environment, false) + + // KMS signer only supports P256 + + natsURL := viper.GetString("nats.url") + natsConn, err := nats.Connect(natsURL) + if err != nil { + logger.Fatal("Failed to connect to NATS", err) + } + defer natsConn.Drain() + defer natsConn.Close() + + // For AWS production, use: + kmsSigner, err := client.NewKMSSigner(types.EventInitiatorKeyTypeP256, client.KMSSignerOptions{ + Region: awsRegion, + KeyID: kmsKeyID, + EndpointURL: "http://localhost:4566", // LocalStack endpoint + AccessKeyID: "test", // LocalStack dummy credentials + SecretAccessKey: "test", // LocalStack dummy credentials + }) + if err != nil { + logger.Fatal("Failed to create KMS signer", err) + } + + // Log the public key for verification + pubKey, err := kmsSigner.PublicKey() + if err != nil { + logger.Fatal("Failed to get public key from KMS signer", err) + } + logger.Info("Public key", "key", pubKey) + + mpcClient := client.NewMPCClient(client.Options{ + NatsConn: natsConn, + Signer: kmsSigner, + }) + + var walletStartTimes sync.Map + var walletIDs []string + var walletIDsMu sync.Mutex + var wg sync.WaitGroup + var completedCount int32 + + startAll := time.Now() + + // STEP 1: Pre-generate wallet IDs and store start times + for i := 0; i < *numWallets; i++ { + walletID := uuid.New().String() + walletStartTimes.Store(walletID, time.Now()) + + walletIDsMu.Lock() + walletIDs = append(walletIDs, walletID) + walletIDsMu.Unlock() + } + + // STEP 2: Register the result handler AFTER all walletIDs are stored + err = mpcClient.OnWalletCreationResult(func(event event.KeygenResultEvent) { + logger.Info("Received wallet creation result", "event", event) + now := time.Now() + startTimeAny, ok := walletStartTimes.Load(event.WalletID) + if ok { + startTime := startTimeAny.(time.Time) + duration := now.Sub(startTime).Seconds() + accumulated := now.Sub(startAll).Seconds() + countSoFar := atomic.AddInt32(&completedCount, 1) + + logger.Info("Wallet created", + "walletID", event.WalletID, + "duration_seconds", fmt.Sprintf("%.3f", duration), + "accumulated_time_seconds", fmt.Sprintf("%.3f", accumulated), + "count_so_far", countSoFar, + ) + + walletStartTimes.Delete(event.WalletID) + } else { + logger.Warn("Received wallet result but no start time found", "walletID", event.WalletID) + } + wg.Done() + }) + if err != nil { + logger.Fatal("Failed to subscribe to wallet-creation results", err) + } + + // STEP 3: Create wallets + for _, walletID := range walletIDs { + wg.Add(1) // Add to WaitGroup BEFORE attempting to create wallet + + if err := mpcClient.CreateWallet(walletID); err != nil { + logger.Error("CreateWallet failed", err) + walletStartTimes.Delete(walletID) + wg.Done() // Now this is safe since we added 1 above + continue + } + + logger.Info("CreateWallet sent, awaiting result...", "walletID", walletID) + } + + // Wait until all wallet creations complete + go func() { + wg.Wait() + totalDuration := time.Since(startAll).Seconds() + logger.Info( + "All wallets generated using KMS signer", + "count", + completedCount, + "total_duration_seconds", + fmt.Sprintf("%.3f", totalDuration), + "kms_key_id", + kmsKeyID, + ) + + // Save wallet IDs to wallets.json + walletIDsMu.Lock() + data, err := json.MarshalIndent(walletIDs, "", " ") + walletIDsMu.Unlock() + if err != nil { + logger.Error("Failed to marshal wallet IDs", err) + } else { + err = os.WriteFile("wallets.json", data, 0600) + if err != nil { + logger.Error("Failed to write wallets.json", err) + } else { + logger.Info("wallets.json written", "count", len(walletIDs)) + } + } + os.Exit(0) + }() + + // Block on SIGINT/SIGTERM (Ctrl+C etc.) + stop := make(chan os.Signal, 1) + signal.Notify(stop, syscall.SIGINT, syscall.SIGTERM) + <-stop + + fmt.Println("Shutting down.") +} diff --git a/examples/generate/main.go b/examples/generate/main.go index 952935d..6f54c31 100644 --- a/examples/generate/main.go +++ b/examples/generate/main.go @@ -6,6 +6,7 @@ import ( "fmt" "os" "os/signal" + "slices" "sync" "sync/atomic" "syscall" @@ -15,6 +16,7 @@ import ( "github.com/fystack/mpcium/pkg/config" "github.com/fystack/mpcium/pkg/event" "github.com/fystack/mpcium/pkg/logger" + "github.com/fystack/mpcium/pkg/types" "github.com/google/uuid" "github.com/nats-io/nats.go" "github.com/spf13/viper" @@ -23,11 +25,35 @@ import ( func main() { const environment = "development" numWallets := flag.Int("n", 1, "Number of wallets to generate") + flag.Parse() config.InitViperConfig() logger.Init(environment, false) + algorithm := viper.GetString("event_initiator_algorithm") + if algorithm == "" { + algorithm = string(types.EventInitiatorKeyTypeEd25519) + } + + if !slices.Contains( + []string{ + string(types.EventInitiatorKeyTypeEd25519), + string(types.EventInitiatorKeyTypeP256), + }, + algorithm, + ) { + logger.Fatal( + fmt.Sprintf( + "invalid algorithm: %s. Must be %s or %s", + algorithm, + types.EventInitiatorKeyTypeEd25519, + types.EventInitiatorKeyTypeP256, + ), + nil, + ) + } + natsURL := viper.GetString("nats.url") natsConn, err := nats.Connect(natsURL) if err != nil { @@ -36,9 +62,16 @@ func main() { defer natsConn.Drain() defer natsConn.Close() + localSigner, err := client.NewLocalSigner(types.EventInitiatorKeyType(algorithm), client.LocalSignerOptions{ + KeyPath: "./event_initiator.key", + }) + if err != nil { + logger.Fatal("Failed to create local signer", err) + } + mpcClient := client.NewMPCClient(client.Options{ NatsConn: natsConn, - KeyPath: "./event_initiator.key", + Signer: localSigner, }) var walletStartTimes sync.Map @@ -89,13 +122,15 @@ func main() { // STEP 3: Create wallets for _, walletID := range walletIDs { - wg.Add(1) + wg.Add(1) // Add to WaitGroup BEFORE attempting to create wallet + if err := mpcClient.CreateWallet(walletID); err != nil { logger.Error("CreateWallet failed", err) walletStartTimes.Delete(walletID) - wg.Done() + wg.Done() // Now this is safe since we added 1 above continue } + logger.Info("CreateWallet sent, awaiting result...", "walletID", walletID) } @@ -103,7 +138,13 @@ func main() { go func() { wg.Wait() totalDuration := time.Since(startAll).Seconds() - logger.Info("All wallets generated", "count", completedCount, "total_duration_seconds", fmt.Sprintf("%.3f", totalDuration)) + logger.Info( + "All wallets generated", + "count", + completedCount, + "total_duration_seconds", + fmt.Sprintf("%.3f", totalDuration), + ) // Save wallet IDs to wallets.json walletIDsMu.Lock() diff --git a/examples/reshare/main.go b/examples/reshare/main.go index 3de3174..68ea786 100644 --- a/examples/reshare/main.go +++ b/examples/reshare/main.go @@ -4,6 +4,7 @@ import ( "fmt" "os" "os/signal" + "slices" "syscall" "github.com/fystack/mpcium/pkg/client" @@ -21,6 +22,29 @@ func main() { config.InitViperConfig() logger.Init(environment, true) + algorithm := viper.GetString("event_initiator_algorithm") + if algorithm == "" { + algorithm = string(types.EventInitiatorKeyTypeEd25519) + } + + // Validate algorithm + if !slices.Contains( + []string{ + string(types.EventInitiatorKeyTypeEd25519), + string(types.EventInitiatorKeyTypeP256), + }, + algorithm, + ) { + logger.Fatal( + fmt.Sprintf( + "invalid algorithm: %s. Must be %s or %s", + algorithm, + types.EventInitiatorKeyTypeEd25519, + types.EventInitiatorKeyTypeP256, + ), + nil, + ) + } natsURL := viper.GetString("nats.url") natsConn, err := nats.Connect(natsURL) if err != nil { @@ -29,9 +53,16 @@ func main() { defer natsConn.Drain() defer natsConn.Close() + localSigner, err := client.NewLocalSigner(types.EventInitiatorKeyType(algorithm), client.LocalSignerOptions{ + KeyPath: "./event_initiator.key", + }) + if err != nil { + logger.Fatal("Failed to create local signer", err) + } + mpcClient := client.NewMPCClient(client.Options{ NatsConn: natsConn, - KeyPath: "./event_initiator.key", + Signer: localSigner, }) // 3) Listen for signing results @@ -50,7 +81,10 @@ func main() { resharingMsg := &types.ResharingMessage{ SessionID: uuid.NewString(), WalletID: "506d2d40-483a-49f1-93c8-27dd4fe9740c", - NodeIDs: []string{"c95c340e-5a18-472d-b9b0-5ac68218213a", "ac37e85f-caca-4bee-8a3a-49a0fe35abff"}, // new peer IDs + NodeIDs: []string{ + "c95c340e-5a18-472d-b9b0-5ac68218213a", + "ac37e85f-caca-4bee-8a3a-49a0fe35abff", + }, // new peer IDs NewThreshold: 1, // t+1 <= len(NodeIDs) KeyType: types.KeyTypeEd25519, diff --git a/examples/sign/main.go b/examples/sign/main.go index 62031c5..4cc4aa1 100644 --- a/examples/sign/main.go +++ b/examples/sign/main.go @@ -4,6 +4,7 @@ import ( "fmt" "os" "os/signal" + "slices" "syscall" "github.com/fystack/mpcium/pkg/client" @@ -21,6 +22,29 @@ func main() { config.InitViperConfig() logger.Init(environment, true) + algorithm := viper.GetString("event_initiator_algorithm") + if algorithm == "" { + algorithm = string(types.EventInitiatorKeyTypeEd25519) + } + + // Validate algorithm + if !slices.Contains( + []string{ + string(types.EventInitiatorKeyTypeEd25519), + string(types.EventInitiatorKeyTypeP256), + }, + algorithm, + ) { + logger.Fatal( + fmt.Sprintf( + "invalid algorithm: %s. Must be %s or %s", + algorithm, + types.EventInitiatorKeyTypeEd25519, + types.EventInitiatorKeyTypeP256, + ), + nil, + ) + } natsURL := viper.GetString("nats.url") natsConn, err := nats.Connect(natsURL) if err != nil { @@ -29,9 +53,16 @@ func main() { defer natsConn.Drain() defer natsConn.Close() + localSigner, err := client.NewLocalSigner(types.EventInitiatorKeyType(algorithm), client.LocalSignerOptions{ + KeyPath: "./event_initiator.key", + }) + if err != nil { + logger.Fatal("Failed to create local signer", err) + } + mpcClient := client.NewMPCClient(client.Options{ NatsConn: natsConn, - KeyPath: "./event_initiator.key", + Signer: localSigner, }) // 2) Once wallet exists, immediately fire a SignTransaction @@ -40,7 +71,7 @@ func main() { txMsg := &types.SignTxMessage{ KeyType: types.KeyTypeEd25519, - WalletID: "c47cd6f4-8ef4-4d77-9d2b-37f9d062e615", + WalletID: "ad24f678-b04b-4149-bcf6-bf9c90df8e63", // Use the generated wallet ID NetworkInternalCode: "solana-devnet", TxID: txID, Tx: dummyTx, diff --git a/go.mod b/go.mod index 15b45be..4b75cb9 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,8 @@ toolchain go1.23.5 require ( filippo.io/age v1.2.1 github.com/avast/retry-go v3.0.0+incompatible + github.com/aws/aws-sdk-go-v2/config v1.31.4 + github.com/aws/aws-sdk-go-v2/service/kms v1.45.0 github.com/bnb-chain/tss-lib/v2 v2.0.2 github.com/decred/dcrd/dcrec/edwards/v2 v2.0.3 github.com/dgraph-io/badger/v4 v4.7.0 @@ -19,12 +21,25 @@ require ( github.com/spf13/viper v1.18.0 github.com/stretchr/testify v1.10.0 github.com/urfave/cli/v3 v3.3.2 + golang.org/x/crypto v0.37.0 golang.org/x/term v0.31.0 ) require ( github.com/agl/ed25519 v0.0.0-20200225211852-fd4d107ace12 // indirect github.com/armon/go-metrics v0.4.1 // indirect + github.com/aws/aws-sdk-go-v2 v1.38.2 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.18.8 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.5 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.5 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.5 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.1 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.5 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.28.3 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.34.1 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.38.1 // indirect + github.com/aws/smithy-go v1.23.0 // indirect github.com/btcsuite/btcd v0.24.2 // indirect github.com/btcsuite/btcd/btcec/v2 v2.3.2 // indirect github.com/btcsuite/btcd/chaincfg/chainhash v1.1.0 // indirect @@ -71,6 +86,7 @@ require ( github.com/spf13/afero v1.11.0 // indirect github.com/spf13/cast v1.6.0 // indirect github.com/spf13/pflag v1.0.6 // indirect + github.com/stretchr/objx v0.5.2 // indirect github.com/subosito/gotenv v1.6.0 // indirect go.opentelemetry.io/auto/sdk v1.1.0 // indirect go.opentelemetry.io/otel v1.35.0 // indirect @@ -80,7 +96,6 @@ require ( go.uber.org/goleak v1.3.0 // indirect go.uber.org/multierr v1.9.0 // indirect go.uber.org/zap v1.21.0 // indirect - golang.org/x/crypto v0.37.0 // indirect golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect golang.org/x/net v0.39.0 // indirect golang.org/x/sys v0.33.0 // indirect diff --git a/go.sum b/go.sum index 2b58b3e..125692d 100644 --- a/go.sum +++ b/go.sum @@ -17,6 +17,34 @@ github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj github.com/armon/go-radix v1.0.0/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= github.com/avast/retry-go v3.0.0+incompatible h1:4SOWQ7Qs+oroOTQOYnAHqelpCO0biHSxpiH9JdtuBj0= github.com/avast/retry-go v3.0.0+incompatible/go.mod h1:XtSnn+n/sHqQIpZ10K1qAevBhOOCWBLXXy3hyiqqBrY= +github.com/aws/aws-sdk-go-v2 v1.38.2 h1:QUkLO1aTW0yqW95pVzZS0LGFanL71hJ0a49w4TJLMyM= +github.com/aws/aws-sdk-go-v2 v1.38.2/go.mod h1:sDioUELIUO9Znk23YVmIk86/9DOpkbyyVb1i/gUNFXY= +github.com/aws/aws-sdk-go-v2/config v1.31.4 h1:aY2IstXOfjdLtr1lDvxFBk5DpBnHgS5GS3jgR/0BmPw= +github.com/aws/aws-sdk-go-v2/config v1.31.4/go.mod h1:1IAykiegrTp6n+CbZoCpW6kks1I74fEDgl2BPQSkLSU= +github.com/aws/aws-sdk-go-v2/credentials v1.18.8 h1:0FfdP0I9gs/f1rwtEdkcEdsclTEkPB8o6zWUG2Z8+IM= +github.com/aws/aws-sdk-go-v2/credentials v1.18.8/go.mod h1:9UReQ1UmGooX93JKzHyr7PRF3F+p3r+PmRwR7+qHJYA= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.5 h1:ul7hICbZ5Z/Pp9VnLVGUVe7rqYLXCyIiPU7hQ0sRkow= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.5/go.mod h1:5cIWJ0N6Gjj+72Q6l46DeaNtcxXHV42w/Uq3fIfeUl4= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.5 h1:d45S2DqHZOkHu0uLUW92VdBoT5v0hh3EyR+DzMEh3ag= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.5/go.mod h1:G6e/dR2c2huh6JmIo9SXysjuLuDDGWMeYGibfW2ZrXg= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.5 h1:ENhnQOV3SxWHplOqNN1f+uuCNf9n4Y/PKpl6b1WRP0Q= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.5/go.mod h1:csQLMI+odbC0/J+UecSTztG70Dc4aTCOu4GyPNDNpVo= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 h1:bIqFDwgGXXN1Kpp99pDOdKMTTb5d2KyU5X/BZxjOkRo= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3/go.mod h1:H5O/EsxDWyU+LP/V8i5sm8cxoZgc2fdNR9bxlOFrQTo= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.1 h1:oegbebPEMA/1Jny7kvwejowCaHz1FWZAQ94WXFNCyTM= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.1/go.mod h1:kemo5Myr9ac0U9JfSjMo9yHLtw+pECEHsFtJ9tqCEI8= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.5 h1:Cx1M/UUgYu9UCQnIMKaOhkVaFvLy1HneD6T4sS/DlKg= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.5/go.mod h1:fTRNLgrTvPpEzGqc9QkeO4hu/3ng+mdtUbL8shUwXz4= +github.com/aws/aws-sdk-go-v2/service/kms v1.45.0 h1:WYQcp4o0/X+Xd50dSFluzKk3Lee2mP+tP39uMI60s1M= +github.com/aws/aws-sdk-go-v2/service/kms v1.45.0/go.mod h1:le5DfWrncVIxOWL2Q0NnDqvhH8ULiGYgC9iS8BtwcZE= +github.com/aws/aws-sdk-go-v2/service/sso v1.28.3 h1:z6lajFT/qGlLRB/I8V5CCklqSuWZKUkdwRAn9leIkiQ= +github.com/aws/aws-sdk-go-v2/service/sso v1.28.3/go.mod h1:BnyjuIX0l+KXJVl2o9Ki3Zf0M4pA2hQYopFCRUj9ADU= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.34.1 h1:8yI3jK5JZ310S8RpgdZdzwvlvBu3QbG8DP7Be/xJ6yo= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.34.1/go.mod h1:HPzXfFgrLd02lYpcFYdDz5xZs94LOb+lWlvbAGaeMsk= +github.com/aws/aws-sdk-go-v2/service/sts v1.38.1 h1:3kWmIg5iiWPMBJyq/I55Fki5fyfoMtrn/SkUIpxPwHQ= +github.com/aws/aws-sdk-go-v2/service/sts v1.38.1/go.mod h1:yi0b3Qez6YamRVJ+Rbi19IgvjfjPODgVRhkWA6RTMUM= +github.com/aws/smithy-go v1.23.0 h1:8n6I3gXzWJB2DxBDnfxgBaSX6oe0d/t10qGz7OKqMCE= +github.com/aws/smithy-go v1.23.0/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI= github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8= github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= diff --git a/pkg/client/client.go b/pkg/client/client.go index 94bc1df..3121bdb 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -2,16 +2,9 @@ package client import ( "context" - "crypto/ed25519" - "encoding/hex" "encoding/json" "fmt" - "io" - "os" - "path/filepath" - "strings" - "filippo.io/age" "github.com/fystack/mpcium/pkg/event" "github.com/fystack/mpcium/pkg/eventconsumer" "github.com/fystack/mpcium/pkg/logger" @@ -43,7 +36,7 @@ type mpcClient struct { genKeySuccessQueue messaging.MessageQueue signResultQueue messaging.MessageQueue reshareSuccessQueue messaging.MessageQueue - privKey ed25519.PrivateKey + signer Signer } // Options defines configuration options for creating a new MPCClient @@ -51,81 +44,37 @@ type Options struct { // NATS connection NatsConn *nats.Conn - // Key path options - KeyPath string // Path to unencrypted key (default: "./event_initiator.key") - - // Encryption options - Encrypted bool // Whether the key is encrypted - Password string // Password for encrypted key + // Signer for signing messages + Signer Signer } // NewMPCClient creates a new MPC client using the provided options. -// It reads the Ed25519 private key from disk and sets up messaging connections. -// If the key is encrypted (.age file), decryption options must be provided in the config. +// The signer must be provided to handle message signing. func NewMPCClient(opts Options) MPCClient { - // Set default paths if not provided - if opts.KeyPath == "" { - opts.KeyPath = filepath.Join(".", "event_initiator.key") - } - - if strings.HasSuffix(opts.KeyPath, ".age") { - opts.Encrypted = true - } - - var privHexBytes []byte - var err error - - // Check if key file exists - if _, err := os.Stat(opts.KeyPath); err == nil { - if opts.Encrypted { - // Encrypted key exists, try to decrypt it - if opts.Password == "" { - logger.Fatal("Encrypted key found but no decryption option provided", nil) - } - - // Read encrypted file - encryptedBytes, err := os.ReadFile(opts.KeyPath) - if err != nil { - logger.Fatal("Failed to read encrypted private key file", err) - } - - // Decrypt the key using the provided password - privHexBytes, err = decryptPrivateKey(encryptedBytes, opts.Password) - if err != nil { - logger.Fatal("Failed to decrypt private key", err) - } - } else { - // Unencrypted key exists, read it normally - privHexBytes, err = os.ReadFile(opts.KeyPath) - if err != nil { - logger.Fatal("Failed to read private key file", err) - } - } - } else { - logger.Fatal("No private key file found", nil) + if opts.Signer == nil { + logger.Fatal("Signer is required", nil) } - privHex := string(privHexBytes) - // Decode private key from hex - privSeed, err := hex.DecodeString(privHex) - if err != nil { - fmt.Println("Failed to decode private key hex:", err) - os.Exit(1) - } - - // Reconstruct full Ed25519 private key from seed - priv := ed25519.NewKeyFromSeed(privSeed) - // 2) Create the PubSub for both publish & subscribe - signingBroker, err := messaging.NewJetStreamBroker(context.Background(), opts.NatsConn, "mpc-signing", []string{ - "mpc.signing_request.*", - }) + signingBroker, err := messaging.NewJetStreamBroker( + context.Background(), + opts.NatsConn, + "mpc-signing", + []string{ + "mpc.signing_request.*", + }, + ) if err != nil { logger.Fatal("Failed to create signing jetstream broker", err) } - keygenBroker, err := messaging.NewJetStreamBroker(context.Background(), opts.NatsConn, "mpc-keygen", []string{ - "mpc.keygen_request.*", - }) + keygenBroker, err := messaging.NewJetStreamBroker( + context.Background(), + opts.NatsConn, + "mpc-keygen", + []string{ + "mpc.keygen_request.*", + }, + ) if err != nil { logger.Fatal("Failed to create keygen jetstream broker", err) } @@ -149,33 +98,10 @@ func NewMPCClient(opts Options) MPCClient { genKeySuccessQueue: genKeySuccessQueue, signResultQueue: signResultQueue, reshareSuccessQueue: reshareSuccessQueue, - privKey: priv, + signer: opts.Signer, } } -// decryptPrivateKey decrypts the encrypted private key using the provided password -func decryptPrivateKey(encryptedData []byte, password string) ([]byte, error) { - // Create an age identity (decryption key) from the password - identity, err := age.NewScryptIdentity(password) - if err != nil { - return nil, fmt.Errorf("failed to create identity from password: %w", err) - } - - // Create a reader from the encrypted data - decrypter, err := age.Decrypt(strings.NewReader(string(encryptedData)), identity) - if err != nil { - return nil, fmt.Errorf("failed to create decrypter: %w", err) - } - - // Read the decrypted data - decryptedData, err := io.ReadAll(decrypter) - if err != nil { - return nil, fmt.Errorf("failed to read decrypted data: %w", err) - } - - return decryptedData, nil -} - // CreateWallet generates a GenerateKeyMessage, signs it, and publishes it. func (c *mpcClient) CreateWallet(walletID string) error { // build the message @@ -187,8 +113,11 @@ func (c *mpcClient) CreateWallet(walletID string) error { if err != nil { return fmt.Errorf("CreateWallet: raw payload error: %w", err) } - // sign - msg.Signature = ed25519.Sign(c.privKey, raw) + signature, err := c.signer.Sign(raw) + if err != nil { + return fmt.Errorf("CreateWallet: failed to sign message: %w", err) + } + msg.Signature = signature bytes, err := json.Marshal(msg) if err != nil { @@ -227,8 +156,11 @@ func (c *mpcClient) SignTransaction(msg *types.SignTxMessage) error { if err != nil { return fmt.Errorf("SignTransaction: raw payload error: %w", err) } - // sign - msg.Signature = ed25519.Sign(c.privKey, raw) + signature, err := c.signer.Sign(raw) + if err != nil { + return fmt.Errorf("SignTransaction: failed to sign message: %w", err) + } + msg.Signature = signature bytes, err := json.Marshal(msg) if err != nil { @@ -265,8 +197,11 @@ func (c *mpcClient) Resharing(msg *types.ResharingMessage) error { if err != nil { return fmt.Errorf("Resharing: raw payload error: %w", err) } - // sign - msg.Signature = ed25519.Sign(c.privKey, raw) + signature, err := c.signer.Sign(raw) + if err != nil { + return fmt.Errorf("Resharing: failed to sign message: %w", err) + } + msg.Signature = signature bytes, err := json.Marshal(msg) if err != nil { diff --git a/pkg/client/client_test.go b/pkg/client/client_test.go new file mode 100644 index 0000000..36ac514 --- /dev/null +++ b/pkg/client/client_test.go @@ -0,0 +1,294 @@ +package client + +import ( + "errors" + "testing" + + "github.com/fystack/mpcium/pkg/types" + "github.com/nats-io/nats.go" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +// MockSigner is a mock implementation of the Signer interface +type MockSigner struct { + mock.Mock +} + +func (m *MockSigner) Sign(data []byte) ([]byte, error) { + args := m.Called(data) + return args.Get(0).([]byte), args.Error(1) +} + +func (m *MockSigner) Algorithm() types.EventInitiatorKeyType { + args := m.Called() + return args.Get(0).(types.EventInitiatorKeyType) +} + +func (m *MockSigner) PublicKey() (string, error) { + args := m.Called() + return args.String(0), args.Error(1) +} + +// MockNATSConn creates a mock NATS connection for testing +func MockNATSConn() *nats.Conn { + // For unit tests, we can return nil and handle it appropriately in tests + // In a real test environment, you would use nats-server for testing + return nil +} + +func TestNewMPCClient_Success(t *testing.T) { + mockSigner := &MockSigner{} + mockSigner.On("Algorithm").Return(types.EventInitiatorKeyTypeEd25519) + + // Since we can't easily create a real NATS connection in unit tests, + // we'll test the Options validation logic + opts := Options{ + NatsConn: MockNATSConn(), // This would normally be a real connection + Signer: mockSigner, + } + + // Test that signer is required + assert.NotNil(t, opts.Signer) +} + +func TestNewMPCClient_NoSigner(t *testing.T) { + // Test that client creation fails without signer + // This test would require mocking the logger.Fatal call or refactoring to return error + opts := Options{ + NatsConn: MockNATSConn(), + Signer: nil, + } + + assert.Nil(t, opts.Signer, "Signer should be nil to test error case") +} + +func TestMPCClient_CreateWallet(t *testing.T) { + mockSigner := &MockSigner{} + + // Set up expectations + testSignature := []byte("test-signature") + mockSigner.On("Sign", mock.AnythingOfType("[]uint8")).Return(testSignature, nil) + + // Create a client instance directly for testing (bypassing NATS setup) + client := &mpcClient{ + signer: mockSigner, + } + + // Test CreateWallet - this will test the signing logic + // Note: This test would require mocking the messaging broker as well + // For now, we test that the signer is called correctly + + walletID := "test-wallet-123" + + // We can't fully test CreateWallet without mocking the broker, + // but we can test the signing part by calling it directly + + // Simulate what CreateWallet does with signing + msg := &types.GenerateKeyMessage{ + WalletID: walletID, + } + + raw, err := msg.Raw() + require.NoError(t, err) + + signature, err := client.signer.Sign(raw) + require.NoError(t, err) + assert.Equal(t, testSignature, signature) + + // Verify mock expectations + mockSigner.AssertExpectations(t) +} + +func TestMPCClient_CreateWallet_SigningError(t *testing.T) { + mockSigner := &MockSigner{} + + // Set up signer to return error + mockSigner.On("Sign", mock.AnythingOfType("[]uint8")).Return([]byte(nil), errors.New("signing failed")) + + client := &mpcClient{ + signer: mockSigner, + } + + // Simulate the signing part that would happen in CreateWallet + msg := &types.GenerateKeyMessage{ + WalletID: "test-wallet", + } + + raw, err := msg.Raw() + require.NoError(t, err) + + signature, err := client.signer.Sign(raw) + assert.Error(t, err) + assert.Nil(t, signature) + assert.Contains(t, err.Error(), "signing failed") + + mockSigner.AssertExpectations(t) +} + +func TestMPCClient_SignTransaction(t *testing.T) { + mockSigner := &MockSigner{} + + // Set up expectations + testSignature := []byte("test-transaction-signature") + mockSigner.On("Sign", mock.AnythingOfType("[]uint8")).Return(testSignature, nil) + + client := &mpcClient{ + signer: mockSigner, + } + + // Test signing part of SignTransaction + msg := &types.SignTxMessage{ + KeyType: types.KeyTypeSecp256k1, + WalletID: "test-wallet", + NetworkInternalCode: "btc-mainnet", + TxID: "test-tx-123", + Tx: []byte("test transaction data"), + } + + raw, err := msg.Raw() + require.NoError(t, err) + + signature, err := client.signer.Sign(raw) + require.NoError(t, err) + assert.Equal(t, testSignature, signature) + + mockSigner.AssertExpectations(t) +} + +func TestMPCClient_Resharing(t *testing.T) { + mockSigner := &MockSigner{} + + // Set up expectations + testSignature := []byte("test-resharing-signature") + mockSigner.On("Sign", mock.AnythingOfType("[]uint8")).Return(testSignature, nil) + + client := &mpcClient{ + signer: mockSigner, + } + + // Test signing part of Resharing + msg := &types.ResharingMessage{ + SessionID: "reshare-session-123", + NodeIDs: []string{"node1", "node2", "node3"}, + NewThreshold: 2, + KeyType: types.KeyTypeSecp256k1, + WalletID: "test-wallet", + } + + raw, err := msg.Raw() + require.NoError(t, err) + + signature, err := client.signer.Sign(raw) + require.NoError(t, err) + assert.Equal(t, testSignature, signature) + + mockSigner.AssertExpectations(t) +} + +func TestSignerInterface_Compliance(t *testing.T) { + // Test that our mock signer implements the interface correctly + mockSigner := &MockSigner{} + + // Set up mock expectations + mockSigner.On("Algorithm").Return(types.EventInitiatorKeyTypeP256) + mockSigner.On("PublicKey").Return("mock-public-key-hex", nil) + mockSigner.On("Sign", []byte("test")).Return([]byte("mock-signature"), nil) + + // Test interface compliance + var signer Signer = mockSigner + + algorithm := signer.Algorithm() + assert.Equal(t, types.EventInitiatorKeyTypeP256, algorithm) + + pubKey, err := signer.PublicKey() + require.NoError(t, err) + assert.Equal(t, "mock-public-key-hex", pubKey) + + signature, err := signer.Sign([]byte("test")) + require.NoError(t, err) + assert.Equal(t, []byte("mock-signature"), signature) + + mockSigner.AssertExpectations(t) +} + +func TestSignerInterface_ErrorHandling(t *testing.T) { + mockSigner := &MockSigner{} + + // Set up error cases + mockSigner.On("PublicKey").Return("", errors.New("public key error")) + mockSigner.On("Sign", mock.Anything).Return([]byte(nil), errors.New("signing error")) + + var signer Signer = mockSigner + + // Test public key error + pubKey, err := signer.PublicKey() + assert.Error(t, err) + assert.Empty(t, pubKey) + assert.Contains(t, err.Error(), "public key error") + + // Test signing error + signature, err := signer.Sign([]byte("test")) + assert.Error(t, err) + assert.Nil(t, signature) + assert.Contains(t, err.Error(), "signing error") + + mockSigner.AssertExpectations(t) +} + +// Integration test helpers +func TestOptionsValidation(t *testing.T) { + t.Run("valid options", func(t *testing.T) { + mockSigner := &MockSigner{} + opts := Options{ + NatsConn: MockNATSConn(), + Signer: mockSigner, + } + + assert.NotNil(t, opts.Signer) + // In real implementation, would also check NatsConn is not nil + }) + + t.Run("missing signer", func(t *testing.T) { + opts := Options{ + NatsConn: MockNATSConn(), + Signer: nil, + } + + assert.Nil(t, opts.Signer) + // This would trigger a fatal error in NewMPCClient + }) +} + +// Benchmark tests for signing operations +func BenchmarkMockSigner_Sign(b *testing.B) { + mockSigner := &MockSigner{} + testSignature := []byte("benchmark-signature") + mockSigner.On("Sign", mock.AnythingOfType("[]uint8")).Return(testSignature, nil) + + data := []byte("benchmark test data") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := mockSigner.Sign(data) + if err != nil { + b.Fatalf("Sign failed: %v", err) + } + } +} + +// Test helper functions +func createTestMPCClient(signer Signer) *mpcClient { + return &mpcClient{ + signer: signer, + } +} + +func TestCreateTestMPCClient(t *testing.T) { + mockSigner := &MockSigner{} + client := createTestMPCClient(mockSigner) + + assert.NotNil(t, client) + assert.Equal(t, mockSigner, client.signer) +} \ No newline at end of file diff --git a/pkg/client/kms_signer.go b/pkg/client/kms_signer.go new file mode 100644 index 0000000..0693c29 --- /dev/null +++ b/pkg/client/kms_signer.go @@ -0,0 +1,162 @@ +package client + +import ( + "context" + "crypto/ecdsa" + "crypto/x509" + "encoding/hex" + "fmt" + + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/kms" + kmstypes "github.com/aws/aws-sdk-go-v2/service/kms/types" + "github.com/fystack/mpcium/pkg/encryption" + "github.com/fystack/mpcium/pkg/types" +) + +// KMSSigner implements the Signer interface for AWS KMS-based signing +type KMSSigner struct { + keyType types.EventInitiatorKeyType + client *kms.Client + keyID string + publicKey *ecdsa.PublicKey +} + +// KMSSignerOptions defines options for creating a KMSSigner +type KMSSignerOptions struct { + Region string // AWS region (e.g., "us-east-1", "us-west-2") - Required + KeyID string // AWS KMS key ID or ARN - Required + EndpointURL string // Custom endpoint URL (optional, for LocalStack/custom services) + AccessKeyID string // AWS access key ID (optional, uses default credential chain if not provided) + SecretAccessKey string // AWS secret access key (optional, uses default credential chain if not provided) +} + +// NewKMSSigner creates a new KMSSigner using AWS KMS +// Note: AWS KMS supports P256, not Ed25519 +func NewKMSSigner(keyType types.EventInitiatorKeyType, opts KMSSignerOptions) (Signer, error) { + // AWS KMS only supports P256 for ECDSA + if keyType != types.EventInitiatorKeyTypeP256 { + return nil, fmt.Errorf("AWS KMS only supports P256 keys, not %s", keyType) + } + + // Validate required options + if opts.KeyID == "" { + return nil, fmt.Errorf("KeyID is required for KMS signer") + } + if opts.Region == "" { + return nil, fmt.Errorf("Region is required for KMS signer") + } + + // Create AWS config + ctx := context.Background() + var configOptions []func(*config.LoadOptions) error + + // Set region + configOptions = append(configOptions, config.WithRegion(opts.Region)) + + // Set custom credentials if provided + if opts.AccessKeyID != "" && opts.SecretAccessKey != "" { + credProvider := credentials.NewStaticCredentialsProvider(opts.AccessKeyID, opts.SecretAccessKey, "") + configOptions = append(configOptions, config.WithCredentialsProvider(credProvider)) + } + + cfg, err := config.LoadDefaultConfig(ctx, configOptions...) + if err != nil { + return nil, fmt.Errorf("failed to load AWS config: %w", err) + } + + // Create KMS client with optional custom endpoint + var clientOptions []func(*kms.Options) + if opts.EndpointURL != "" { + clientOptions = append(clientOptions, func(o *kms.Options) { + o.BaseEndpoint = &opts.EndpointURL + }) + } + + client := kms.NewFromConfig(cfg, clientOptions...) + + signer := &KMSSigner{ + keyType: keyType, + client: client, + keyID: opts.KeyID, + } + + // Retrieve and cache the public key + if err := signer.loadPublicKey(ctx); err != nil { + return nil, fmt.Errorf("failed to load public key from KMS: %w", err) + } + + return signer, nil +} + +// loadPublicKey retrieves the public key from AWS KMS and caches it +func (k *KMSSigner) loadPublicKey(ctx context.Context) error { + input := &kms.GetPublicKeyInput{ + KeyId: &k.keyID, + } + + resp, err := k.client.GetPublicKey(ctx, input) + if err != nil { + return fmt.Errorf("failed to get public key from AWS KMS: %w", err) + } + + // Parse DER encoded public key + publicKeyInterface, err := x509.ParsePKIXPublicKey(resp.PublicKey) + if err != nil { + return fmt.Errorf("failed to parse public key from KMS response: %w", err) + } + + publicKey, ok := publicKeyInterface.(*ecdsa.PublicKey) + if !ok { + return fmt.Errorf("KMS public key is not an ECDSA key") + } + + // Validate it's P256 + if err := encryption.ValidateP256PublicKey(publicKey); err != nil { + return fmt.Errorf("KMS public key is not a valid P256 key: %w", err) + } + + k.publicKey = publicKey + return nil +} + +// Sign implements the Signer interface for KMSSigner +func (k *KMSSigner) Sign(data []byte) ([]byte, error) { + ctx := context.Background() + + // Create the signing request + input := &kms.SignInput{ + KeyId: &k.keyID, + Message: data, + MessageType: kmstypes.MessageTypeRaw, + SigningAlgorithm: kmstypes.SigningAlgorithmSpecEcdsaSha256, + } + + // Call AWS KMS to sign the data + resp, err := k.client.Sign(ctx, input) + if err != nil { + return nil, fmt.Errorf("failed to sign with AWS KMS: %w", err) + } + + return resp.Signature, nil +} + +// Algorithm implements the Signer interface for KMSSigner +func (k *KMSSigner) Algorithm() types.EventInitiatorKeyType { + return k.keyType +} + +// PublicKey implements the Signer interface for KMSSigner +func (k *KMSSigner) PublicKey() (string, error) { + if k.publicKey == nil { + return "", fmt.Errorf("public key not loaded") + } + + pubKeyBytes, err := encryption.MarshalP256PublicKey(k.publicKey) + if err != nil { + return "", fmt.Errorf("failed to marshal P256 public key: %w", err) + } + + return hex.EncodeToString(pubKeyBytes), nil +} diff --git a/pkg/client/kms_signer_test.go b/pkg/client/kms_signer_test.go new file mode 100644 index 0000000..bf6bc3b --- /dev/null +++ b/pkg/client/kms_signer_test.go @@ -0,0 +1,229 @@ +package client + +import ( + "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/x509" + "encoding/hex" + "testing" + + "github.com/aws/aws-sdk-go-v2/service/kms" + "github.com/fystack/mpcium/pkg/encryption" + "github.com/fystack/mpcium/pkg/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +// MockKMSClient is a mock implementation of the AWS KMS client +type MockKMSClient struct { + mock.Mock +} + +func (m *MockKMSClient) GetPublicKey(ctx context.Context, params *kms.GetPublicKeyInput, optFns ...func(*kms.Options)) (*kms.GetPublicKeyOutput, error) { + args := m.Called(ctx, params) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*kms.GetPublicKeyOutput), args.Error(1) +} + +func (m *MockKMSClient) Sign(ctx context.Context, params *kms.SignInput, optFns ...func(*kms.Options)) (*kms.SignOutput, error) { + args := m.Called(ctx, params) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*kms.SignOutput), args.Error(1) +} + +func TestNewKMSSigner_Success(t *testing.T) { + // Generate a test P256 key for mock response + privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + + publicKeyBytes, err := x509.MarshalPKIXPublicKey(&privateKey.PublicKey) + require.NoError(t, err) + + // Create mock client + mockClient := &MockKMSClient{} + mockClient.On("GetPublicKey", mock.Anything, mock.MatchedBy(func(input *kms.GetPublicKeyInput) bool { + return *input.KeyId == "test-key-id" + })).Return(&kms.GetPublicKeyOutput{ + PublicKey: publicKeyBytes, + }, nil) + + // Test creating KMS signer - we'll need to inject the mock somehow + // For now, we'll test the validation logic + opts := KMSSignerOptions{ + Region: "us-east-1", + KeyID: "test-key-id", + } + + // Test validation + assert.NotEmpty(t, opts.KeyID) + assert.NotEmpty(t, opts.Region) +} + +func TestNewKMSSigner_ValidationErrors(t *testing.T) { + t.Run("unsupported key type", func(t *testing.T) { + _, err := NewKMSSigner(types.EventInitiatorKeyTypeEd25519, KMSSignerOptions{ + Region: "us-east-1", + KeyID: "test-key", + }) + assert.Error(t, err) + assert.Contains(t, err.Error(), "AWS KMS only supports P256 keys") + }) + + t.Run("missing key ID", func(t *testing.T) { + _, err := NewKMSSigner(types.EventInitiatorKeyTypeP256, KMSSignerOptions{ + Region: "us-east-1", + }) + assert.Error(t, err) + assert.Contains(t, err.Error(), "KeyID is required") + }) +} + +func TestKMSSigner_Algorithm(t *testing.T) { + // Create a KMS signer instance directly (bypassing AWS client creation for unit test) + signer := &KMSSigner{ + keyType: types.EventInitiatorKeyTypeP256, + } + + assert.Equal(t, types.EventInitiatorKeyTypeP256, signer.Algorithm()) +} + +func TestKMSSigner_PublicKey(t *testing.T) { + // Generate a test P256 key + privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + + // Create a KMS signer instance with a test public key + signer := &KMSSigner{ + keyType: types.EventInitiatorKeyTypeP256, + publicKey: &privateKey.PublicKey, + } + + pubKeyHex, err := signer.PublicKey() + require.NoError(t, err) + assert.NotEmpty(t, pubKeyHex) + + // Verify it's valid hex + pubKeyBytes, err := hex.DecodeString(pubKeyHex) + require.NoError(t, err) + assert.NotEmpty(t, pubKeyBytes) + + // Verify we can parse the public key back + parsedPubKey, err := encryption.ParseP256PublicKeyFromBytes(pubKeyBytes) + require.NoError(t, err) + assert.Equal(t, privateKey.PublicKey.X, parsedPubKey.X) + assert.Equal(t, privateKey.PublicKey.Y, parsedPubKey.Y) +} + +func TestKMSSigner_PublicKey_NotLoaded(t *testing.T) { + signer := &KMSSigner{ + keyType: types.EventInitiatorKeyTypeP256, + // publicKey is nil + } + + pubKeyHex, err := signer.PublicKey() + assert.Error(t, err) + assert.Empty(t, pubKeyHex) + assert.Contains(t, err.Error(), "public key not loaded") +} + +// TestKMSSignerIntegration tests the KMS signer with mocked AWS responses +func TestKMSSignerIntegration(t *testing.T) { + // This test would require more complex mocking to fully test the KMS integration + // For now, we test the structure and validation logic + + t.Run("options validation", func(t *testing.T) { + validOpts := KMSSignerOptions{ + Region: "us-west-2", + KeyID: "arn:aws:kms:us-west-2:123456789012:key/12345678-1234-1234-1234-123456789012", + } + + // Test that options are properly structured + assert.NotEmpty(t, validOpts.Region) + assert.NotEmpty(t, validOpts.KeyID) + assert.Contains(t, validOpts.KeyID, "arn:aws:kms") // Example ARN format + }) + + t.Run("key ID formats", func(t *testing.T) { + validKeyIDs := []string{ + "12345678-1234-1234-1234-123456789012", // Key ID + "alias/my-key", // Alias + "arn:aws:kms:us-west-2:123456789012:key/12345678-1234-1234-1234-123456789012", // Full ARN + "arn:aws:kms:us-west-2:123456789012:alias/my-key", // Alias ARN + } + + for _, keyID := range validKeyIDs { + opts := KMSSignerOptions{ + Region: "us-west-2", + KeyID: keyID, + } + assert.NotEmpty(t, opts.KeyID, "Key ID should not be empty: %s", keyID) + } + }) +} + +// TestKMSSignerMockIntegration demonstrates how to test KMS signer with proper mocking +func TestKMSSignerMockIntegration(t *testing.T) { + // Generate test key pair + privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + + // Note: For actual testing, we would use these variables + _ = privateKey // Avoid unused variable error + + // Create a KMS signer with test data (simulating successful initialization) + signer := &KMSSigner{ + keyType: types.EventInitiatorKeyTypeP256, + keyID: "test-key-id", + publicKey: &privateKey.PublicKey, + } + + // Test public key retrieval + pubKeyHex, err := signer.PublicKey() + require.NoError(t, err) + assert.NotEmpty(t, pubKeyHex) + + // Verify the hex can be decoded back to the correct public key + decodedBytes, err := hex.DecodeString(pubKeyHex) + require.NoError(t, err) + + parsedPubKey, err := encryption.ParseP256PublicKeyFromBytes(decodedBytes) + require.NoError(t, err) + assert.Equal(t, signer.publicKey.X, parsedPubKey.X) + assert.Equal(t, signer.publicKey.Y, parsedPubKey.Y) + + // Test algorithm + assert.Equal(t, types.EventInitiatorKeyTypeP256, signer.Algorithm()) + + // Note: Actual signing would require mocking the AWS client's Sign method + // This demonstrates the structure for such tests + t.Log("KMS signer structure validated successfully") +} + +// Helper function to create a test KMS signer (for integration tests) +func createTestKMSSigner(t *testing.T) *KMSSigner { + privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + + return &KMSSigner{ + keyType: types.EventInitiatorKeyTypeP256, + keyID: "test-key-id", + publicKey: &privateKey.PublicKey, + } +} + +func TestKMSSignerHelpers(t *testing.T) { + signer := createTestKMSSigner(t) + + assert.Equal(t, types.EventInitiatorKeyTypeP256, signer.Algorithm()) + + pubKey, err := signer.PublicKey() + require.NoError(t, err) + assert.NotEmpty(t, pubKey) +} diff --git a/pkg/client/local_signer.go b/pkg/client/local_signer.go new file mode 100644 index 0000000..88c8537 --- /dev/null +++ b/pkg/client/local_signer.go @@ -0,0 +1,188 @@ +package client + +import ( + "crypto/ecdsa" + "crypto/ed25519" + "encoding/hex" + "fmt" + "io" + "os" + "path/filepath" + "strings" + + "filippo.io/age" + "github.com/fystack/mpcium/pkg/encryption" + "github.com/fystack/mpcium/pkg/types" +) + +// LocalSigner implements the Signer interface for local key management +type LocalSigner struct { + keyType types.EventInitiatorKeyType + ed25519Key ed25519.PrivateKey + p256Key *ecdsa.PrivateKey +} + +// LocalSignerOptions defines options for creating a LocalSigner +type LocalSignerOptions struct { + KeyPath string // Path to the key file + Encrypted bool // Whether the key is encrypted + Password string // Password for decryption (required if encrypted) +} + +// NewLocalSigner creates a new LocalSigner for the specified key type +func NewLocalSigner(keyType types.EventInitiatorKeyType, opts LocalSignerOptions) (Signer, error) { + signer := &LocalSigner{ + keyType: keyType, + } + + // Set default path if not provided + if opts.KeyPath == "" { + opts.KeyPath = filepath.Join(".", "event_initiator.key") + } + + // Auto-detect encryption if .age extension + if strings.HasSuffix(opts.KeyPath, ".age") { + opts.Encrypted = true + } + + // Read the key file + keyData, err := readKeyFile(opts.KeyPath, opts.Encrypted, opts.Password) + if err != nil { + return nil, fmt.Errorf("failed to read key file: %w", err) + } + + // Parse the key based on type + switch keyType { + case types.EventInitiatorKeyTypeEd25519: + if err := signer.loadEd25519Key(keyData); err != nil { + return nil, fmt.Errorf("failed to load Ed25519 key: %w", err) + } + case types.EventInitiatorKeyTypeP256: + if err := signer.loadP256Key(keyData); err != nil { + return nil, fmt.Errorf("failed to load P256 key: %w", err) + } + default: + return nil, fmt.Errorf("unsupported key type: %s", keyType) + } + + return signer, nil +} + +// readKeyFile reads a key file, handling both encrypted and unencrypted files +func readKeyFile(keyPath string, encrypted bool, password string) ([]byte, error) { + // Check if key file exists + if _, err := os.Stat(keyPath); err != nil { + return nil, fmt.Errorf("key file not found: %s", keyPath) + } + + if encrypted { + if password == "" { + return nil, fmt.Errorf("encrypted key found but no password provided") + } + + // Read encrypted file + encryptedBytes, err := os.ReadFile(keyPath) + if err != nil { + return nil, fmt.Errorf("failed to read encrypted key file: %w", err) + } + + // Decrypt the key + return decryptPrivateKey(encryptedBytes, password) + } else { + // Read unencrypted key + return os.ReadFile(keyPath) + } +} + +// loadEd25519Key loads an Ed25519 private key from hex data +func (s *LocalSigner) loadEd25519Key(keyData []byte) error { + privHex := string(keyData) + privSeed, err := hex.DecodeString(strings.TrimSpace(privHex)) + if err != nil { + return fmt.Errorf("failed to decode Ed25519 private key hex: %w", err) + } + + s.ed25519Key = ed25519.NewKeyFromSeed(privSeed) + return nil +} + +// loadP256Key loads a P256 private key from various formats +func (s *LocalSigner) loadP256Key(keyData []byte) error { + privKey, err := encryption.ParseP256PrivateKey(keyData) + if err != nil { + return fmt.Errorf("failed to parse P256 private key: %w", err) + } + + s.p256Key = privKey + return nil +} + +// Sign implements the Signer interface for LocalSigner +func (s *LocalSigner) Sign(data []byte) ([]byte, error) { + switch s.keyType { + case types.EventInitiatorKeyTypeEd25519: + if s.ed25519Key == nil { + return nil, fmt.Errorf("Ed25519 private key not initialized") + } + return ed25519.Sign(s.ed25519Key, data), nil + + case types.EventInitiatorKeyTypeP256: + if s.p256Key == nil { + return nil, fmt.Errorf("P256 private key not initialized") + } + return encryption.SignWithP256(s.p256Key, data) + + default: + return nil, fmt.Errorf("unsupported key type: %s", s.keyType) + } +} + +// Algorithm implements the Signer interface for LocalSigner +func (s *LocalSigner) Algorithm() types.EventInitiatorKeyType { + return s.keyType +} + +// PublicKey implements the Signer interface for LocalSigner +func (s *LocalSigner) PublicKey() (string, error) { + switch s.keyType { + case types.EventInitiatorKeyTypeEd25519: + if s.ed25519Key == nil { + return "", fmt.Errorf("Ed25519 private key not initialized") + } + pubKey := s.ed25519Key.Public().(ed25519.PublicKey) + return hex.EncodeToString(pubKey), nil + + case types.EventInitiatorKeyTypeP256: + if s.p256Key == nil { + return "", fmt.Errorf("P256 private key not initialized") + } + pubKeyBytes, err := encryption.MarshalP256PublicKey(&s.p256Key.PublicKey) + if err != nil { + return "", fmt.Errorf("failed to marshal P256 public key: %w", err) + } + return hex.EncodeToString(pubKeyBytes), nil + + default: + return "", fmt.Errorf("unsupported key type: %s", s.keyType) + } +} + +// decryptPrivateKey decrypts an encrypted private key using age +func decryptPrivateKey(encryptedData []byte, password string) ([]byte, error) { + identity, err := age.NewScryptIdentity(password) + if err != nil { + return nil, fmt.Errorf("failed to create identity from password: %w", err) + } + + decrypter, err := age.Decrypt(strings.NewReader(string(encryptedData)), identity) + if err != nil { + return nil, fmt.Errorf("failed to create decrypter: %w", err) + } + + decryptedData, err := io.ReadAll(decrypter) + if err != nil { + return nil, fmt.Errorf("failed to read decrypted data: %w", err) + } + + return decryptedData, nil +} \ No newline at end of file diff --git a/pkg/client/local_signer_test.go b/pkg/client/local_signer_test.go new file mode 100644 index 0000000..0d5768b --- /dev/null +++ b/pkg/client/local_signer_test.go @@ -0,0 +1,307 @@ +package client + +import ( + "crypto/ed25519" + "crypto/rand" + "encoding/hex" + "os" + "path/filepath" + "testing" + + "filippo.io/age" + "github.com/fystack/mpcium/pkg/encryption" + "github.com/fystack/mpcium/pkg/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewLocalSigner_Ed25519(t *testing.T) { + // Generate a test Ed25519 key + _, privateKey, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + seed := privateKey.Seed() + privKeyHex := hex.EncodeToString(seed) + + // Create temporary key file + tmpDir := t.TempDir() + keyPath := filepath.Join(tmpDir, "test_ed25519.key") + + err = os.WriteFile(keyPath, []byte(privKeyHex), 0600) + require.NoError(t, err) + + // Test creating signer + signer, err := NewLocalSigner(types.EventInitiatorKeyTypeEd25519, LocalSignerOptions{ + KeyPath: keyPath, + }) + require.NoError(t, err) + require.NotNil(t, signer) + + localSigner, ok := signer.(*LocalSigner) + require.True(t, ok) + assert.Equal(t, types.EventInitiatorKeyTypeEd25519, localSigner.keyType) + assert.NotNil(t, localSigner.ed25519Key) + assert.Nil(t, localSigner.p256Key) +} + +func TestNewLocalSigner_P256(t *testing.T) { + // Generate a test P256 key + keyData, err := encryption.GenerateP256Keys() + require.NoError(t, err) + + // Create temporary key file + tmpDir := t.TempDir() + keyPath := filepath.Join(tmpDir, "test_p256.key") + + err = os.WriteFile(keyPath, []byte(keyData.PrivateKeyHex), 0600) + require.NoError(t, err) + + // Test creating signer + signer, err := NewLocalSigner(types.EventInitiatorKeyTypeP256, LocalSignerOptions{ + KeyPath: keyPath, + }) + require.NoError(t, err) + require.NotNil(t, signer) + + localSigner, ok := signer.(*LocalSigner) + require.True(t, ok) + assert.Equal(t, types.EventInitiatorKeyTypeP256, localSigner.keyType) + assert.Nil(t, localSigner.ed25519Key) + assert.NotNil(t, localSigner.p256Key) +} + +func TestNewLocalSigner_EncryptedKey(t *testing.T) { + // Generate a test Ed25519 key + _, privateKey, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + seed := privateKey.Seed() + privKeyHex := hex.EncodeToString(seed) + + // Create encrypted key file + tmpDir := t.TempDir() + keyPath := filepath.Join(tmpDir, "test_encrypted.key.age") + password := "test-password" + + // Encrypt the key using age + recipient, err := age.NewScryptRecipient(password) + require.NoError(t, err) + + tmpFile, err := os.Create(keyPath) + require.NoError(t, err) + defer tmpFile.Close() + + writer, err := age.Encrypt(tmpFile, recipient) + require.NoError(t, err) + + _, err = writer.Write([]byte(privKeyHex)) + require.NoError(t, err) + + err = writer.Close() + require.NoError(t, err) + + // Test creating signer with encrypted key + signer, err := NewLocalSigner(types.EventInitiatorKeyTypeEd25519, LocalSignerOptions{ + KeyPath: keyPath, + Encrypted: true, + Password: password, + }) + require.NoError(t, err) + require.NotNil(t, signer) + + localSigner, ok := signer.(*LocalSigner) + require.True(t, ok) + assert.Equal(t, types.EventInitiatorKeyTypeEd25519, localSigner.keyType) + assert.NotNil(t, localSigner.ed25519Key) +} + +func TestNewLocalSigner_Errors(t *testing.T) { + tmpDir := t.TempDir() + + t.Run("nonexistent key file", func(t *testing.T) { + signer, err := NewLocalSigner(types.EventInitiatorKeyTypeEd25519, LocalSignerOptions{ + KeyPath: filepath.Join(tmpDir, "nonexistent.key"), + }) + assert.Error(t, err) + assert.Nil(t, signer) + assert.Contains(t, err.Error(), "key file not found") + }) + + t.Run("encrypted key without password", func(t *testing.T) { + keyPath := filepath.Join(tmpDir, "test.key.age") + err := os.WriteFile(keyPath, []byte("dummy"), 0600) + require.NoError(t, err) + + signer, err := NewLocalSigner(types.EventInitiatorKeyTypeEd25519, LocalSignerOptions{ + KeyPath: keyPath, + Encrypted: true, + }) + assert.Error(t, err) + assert.Nil(t, signer) + assert.Contains(t, err.Error(), "no password provided") + }) + + t.Run("unsupported key type", func(t *testing.T) { + keyPath := filepath.Join(tmpDir, "test.key") + err := os.WriteFile(keyPath, []byte("dummy"), 0600) + require.NoError(t, err) + + signer, err := NewLocalSigner("unsupported", LocalSignerOptions{ + KeyPath: keyPath, + }) + assert.Error(t, err) + assert.Nil(t, signer) + assert.Contains(t, err.Error(), "unsupported key type") + }) +} + +func TestLocalSigner_Sign_Ed25519(t *testing.T) { + // Generate a test Ed25519 key + _, privateKey, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + seed := privateKey.Seed() + privKeyHex := hex.EncodeToString(seed) + + // Create temporary key file + tmpDir := t.TempDir() + keyPath := filepath.Join(tmpDir, "test_ed25519.key") + + err = os.WriteFile(keyPath, []byte(privKeyHex), 0600) + require.NoError(t, err) + + // Create signer + signer, err := NewLocalSigner(types.EventInitiatorKeyTypeEd25519, LocalSignerOptions{ + KeyPath: keyPath, + }) + require.NoError(t, err) + + // Test signing + data := []byte("test message to sign") + signature, err := signer.Sign(data) + require.NoError(t, err) + assert.NotEmpty(t, signature) + assert.Equal(t, ed25519.SignatureSize, len(signature)) + + // Verify signature + publicKey := privateKey.Public().(ed25519.PublicKey) + valid := ed25519.Verify(publicKey, data, signature) + assert.True(t, valid) +} + +func TestLocalSigner_Sign_P256(t *testing.T) { + // Generate a test P256 key + keyData, err := encryption.GenerateP256Keys() + require.NoError(t, err) + + // Create temporary key file + tmpDir := t.TempDir() + keyPath := filepath.Join(tmpDir, "test_p256.key") + + err = os.WriteFile(keyPath, []byte(keyData.PrivateKeyHex), 0600) + require.NoError(t, err) + + // Create signer + signer, err := NewLocalSigner(types.EventInitiatorKeyTypeP256, LocalSignerOptions{ + KeyPath: keyPath, + }) + require.NoError(t, err) + + // Test signing + data := []byte("test message to sign") + signature, err := signer.Sign(data) + require.NoError(t, err) + assert.NotEmpty(t, signature) + + // Verify signature using the encryption package + localSigner := signer.(*LocalSigner) + err = encryption.VerifyP256Signature(&localSigner.p256Key.PublicKey, data, signature) + assert.NoError(t, err) +} + +func TestLocalSigner_Algorithm(t *testing.T) { + tmpDir := t.TempDir() + + // Test Ed25519 + _, privateKey, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + seed := privateKey.Seed() + privKeyHex := hex.EncodeToString(seed) + keyPath := filepath.Join(tmpDir, "test_ed25519.key") + err = os.WriteFile(keyPath, []byte(privKeyHex), 0600) + require.NoError(t, err) + + signer, err := NewLocalSigner(types.EventInitiatorKeyTypeEd25519, LocalSignerOptions{ + KeyPath: keyPath, + }) + require.NoError(t, err) + assert.Equal(t, types.EventInitiatorKeyTypeEd25519, signer.Algorithm()) + + // Test P256 + keyData, err := encryption.GenerateP256Keys() + require.NoError(t, err) + keyPathP256 := filepath.Join(tmpDir, "test_p256.key") + err = os.WriteFile(keyPathP256, []byte(keyData.PrivateKeyHex), 0600) + require.NoError(t, err) + + signerP256, err := NewLocalSigner(types.EventInitiatorKeyTypeP256, LocalSignerOptions{ + KeyPath: keyPathP256, + }) + require.NoError(t, err) + assert.Equal(t, types.EventInitiatorKeyTypeP256, signerP256.Algorithm()) +} + +func TestLocalSigner_PublicKey(t *testing.T) { + tmpDir := t.TempDir() + + t.Run("Ed25519", func(t *testing.T) { + // Generate key + _, privateKey, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + seed := privateKey.Seed() + privKeyHex := hex.EncodeToString(seed) + keyPath := filepath.Join(tmpDir, "test_ed25519.key") + err = os.WriteFile(keyPath, []byte(privKeyHex), 0600) + require.NoError(t, err) + + // Create signer + signer, err := NewLocalSigner(types.EventInitiatorKeyTypeEd25519, LocalSignerOptions{ + KeyPath: keyPath, + }) + require.NoError(t, err) + + // Get public key + pubKeyHex, err := signer.PublicKey() + require.NoError(t, err) + assert.NotEmpty(t, pubKeyHex) + + // Verify it matches the expected public key + expectedPubKey := privateKey.Public().(ed25519.PublicKey) + expectedHex := hex.EncodeToString(expectedPubKey) + assert.Equal(t, expectedHex, pubKeyHex) + }) + + t.Run("P256", func(t *testing.T) { + // Generate key + keyData, err := encryption.GenerateP256Keys() + require.NoError(t, err) + keyPath := filepath.Join(tmpDir, "test_p256.key") + err = os.WriteFile(keyPath, []byte(keyData.PrivateKeyHex), 0600) + require.NoError(t, err) + + // Create signer + signer, err := NewLocalSigner(types.EventInitiatorKeyTypeP256, LocalSignerOptions{ + KeyPath: keyPath, + }) + require.NoError(t, err) + + // Get public key + pubKeyHex, err := signer.PublicKey() + require.NoError(t, err) + assert.NotEmpty(t, pubKeyHex) + + // Verify it's valid hex + _, err = hex.DecodeString(pubKeyHex) + assert.NoError(t, err) + }) +} diff --git a/pkg/client/signer.go b/pkg/client/signer.go new file mode 100644 index 0000000..3c30994 --- /dev/null +++ b/pkg/client/signer.go @@ -0,0 +1,13 @@ +package client + +import "github.com/fystack/mpcium/pkg/types" + +// Signer defines the interface for signing messages with different key types +type Signer interface { + // Sign signs the given data and returns the signature + Sign(data []byte) ([]byte, error) + // Algorithm returns the key algorithm used by this signer + Algorithm() types.EventInitiatorKeyType + // PublicKey returns the public key in hex format + PublicKey() (string, error) +} diff --git a/pkg/encryption/ed25519.go b/pkg/encryption/ed25519.go new file mode 100644 index 0000000..f0f6e40 --- /dev/null +++ b/pkg/encryption/ed25519.go @@ -0,0 +1,62 @@ +package encryption + +import ( + "crypto/ed25519" + "crypto/rand" + "encoding/hex" + "fmt" +) + +// generateEd25519Keys generates Ed25519 keypair +func GenerateEd25519Keys() (KeyData, error) { + pubKey, privKeyFull, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + return KeyData{}, err + } + + privKeySeed := privKeyFull.Seed() + return KeyData{ + PublicKeyHex: hex.EncodeToString(pubKey), + PrivateKeyHex: hex.EncodeToString(privKeySeed), + }, nil +} + +// ParseEd25519PublicKeyFromHex parses a hex-encoded Ed25519 public key and validates it. +// Returns the public key as []byte and an error if invalid. +func ParseEd25519PublicKeyFromHex(hexKey string) ([]byte, error) { + if hexKey == "" { + return nil, fmt.Errorf("public key hex string is empty") + } + + // Decode hex string to bytes + keyBytes, err := hex.DecodeString(hexKey) + if err != nil { + return nil, fmt.Errorf("invalid hex format: %w", err) + } + + // Validate the key + if err := ValidateEd25519PublicKey(keyBytes); err != nil { + return nil, err + } + + return keyBytes, nil +} + +// ValidateEd25519PublicKey validates an existing byte slice as a valid Ed25519 public key +func ValidateEd25519PublicKey(keyBytes []byte) error { + if len(keyBytes) != ed25519.PublicKeySize { + return fmt.Errorf("invalid Ed25519 public key length: expected %d bytes, got %d", + ed25519.PublicKeySize, len(keyBytes)) + } + + // Create and validate Ed25519 public key + pubKey := ed25519.PublicKey(keyBytes) + + // Basic validation - attempt to use the key + // Invalid curve points will cause verification to behave predictably + dummyMsg := []byte("validation_test") + dummySig := make([]byte, ed25519.SignatureSize) + ed25519.Verify(pubKey, dummyMsg, dummySig) // This won't panic on invalid keys + + return nil +} diff --git a/pkg/encryption/ed25519_test.go b/pkg/encryption/ed25519_test.go new file mode 100644 index 0000000..ca226a9 --- /dev/null +++ b/pkg/encryption/ed25519_test.go @@ -0,0 +1,159 @@ +package encryption + +import ( + "crypto/ed25519" + "crypto/rand" + "encoding/hex" + "strings" + "testing" +) + +var ( + // Test data shared across tests + testValidKey ed25519.PublicKey + testValidHex string + testAllZeros = make([]byte, 32) + testAllMax = func() []byte { + b := make([]byte, 32) + for i := range b { + b[i] = 0xFF + } + return b + }() +) + +func init() { + // Generate a single valid key for all tests + testValidKey, _, _ = ed25519.GenerateKey(rand.Reader) + testValidHex = hex.EncodeToString(testValidKey) +} + +// Helper function to check error expectations +func checkError(t *testing.T, err error, wantError bool, errorMsg string) { + t.Helper() + if wantError { + if err == nil { + t.Errorf("expected error but got none") + } else if errorMsg != "" && !strings.Contains(err.Error(), errorMsg) { + t.Errorf("error = %v, want error containing %v", err, errorMsg) + } + } else if err != nil { + t.Errorf("unexpected error = %v", err) + } +} + +func TestParseEd25519PublicKeyFromHex(t *testing.T) { + tests := []struct { + name string + hexKey string + wantError bool + errorMsg string + }{ + {"valid hex key", testValidHex, false, ""}, + {"empty hex string", "", true, "public key hex string is empty"}, + {"invalid hex characters", strings.Repeat("g", 64), true, "invalid hex format"}, + {"too short hex string", "abcdef1234567890", true, "invalid Ed25519 public key length: expected 32 bytes, got 8"}, + {"too long hex string", strings.Repeat("ab", 40), true, "invalid Ed25519 public key length: expected 32 bytes, got 40"}, + {"odd length hex string", "abc", true, "invalid hex format"}, + {"all zeros", hex.EncodeToString(testAllZeros), false, ""}, + {"all max bytes", hex.EncodeToString(testAllMax), false, ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := ParseEd25519PublicKeyFromHex(tt.hexKey) + + checkError(t, err, tt.wantError, tt.errorMsg) + + if !tt.wantError { + if result == nil { + t.Errorf("expected non-nil result") + } else if len(result) != ed25519.PublicKeySize { + t.Errorf("result length = %d, want %d", len(result), ed25519.PublicKeySize) + } + } else if result != nil { + t.Errorf("expected nil result on error, got %v", result) + } + }) + } +} + +func TestValidateEd25519PublicKey(t *testing.T) { + tests := []struct { + name string + keyBytes []byte + wantError bool + errorMsg string + }{ + {"valid public key", testValidKey, false, ""}, + {"nil key bytes", nil, true, "invalid Ed25519 public key length: expected 32 bytes, got 0"}, + {"empty key bytes", []byte{}, true, "invalid Ed25519 public key length: expected 32 bytes, got 0"}, + {"too short key", make([]byte, 16), true, "invalid Ed25519 public key length: expected 32 bytes, got 16"}, + {"too long key", make([]byte, 64), true, "invalid Ed25519 public key length: expected 32 bytes, got 64"}, + {"all zeros", testAllZeros, false, ""}, + {"all max bytes", testAllMax, false, ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateEd25519PublicKey(tt.keyBytes) + checkError(t, err, tt.wantError, tt.errorMsg) + }) + } +} + +func TestParseAndValidateIntegration(t *testing.T) { + testKeys := []ed25519.PublicKey{testValidKey} + + // Generate a few more keys for testing + for i := 0; i < 3; i++ { + pubKey, _, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatalf("Failed to generate test key %d: %v", i, err) + } + testKeys = append(testKeys, pubKey) + } + + for i, validPubKey := range testKeys { + validHex := hex.EncodeToString(validPubKey) + + parsedKey, err := ParseEd25519PublicKeyFromHex(validHex) + if err != nil { + t.Errorf("ParseEd25519PublicKeyFromHex() failed for key %d: %v", i, err) + continue + } + + if err := ValidateEd25519PublicKey(parsedKey); err != nil { + t.Errorf("ValidateEd25519PublicKey() failed for key %d: %v", i, err) + } + + if !compareBytes(validPubKey, parsedKey) { + t.Errorf("Key %d: parsed key differs from original", i) + } + } +} + +// Helper function to compare byte slices +func compareBytes(a, b []byte) bool { + if len(a) != len(b) { + return false + } + for i, v := range a { + if v != b[i] { + return false + } + } + return true +} + +func BenchmarkParseEd25519PublicKeyFromHex(b *testing.B) { + for i := 0; i < b.N; i++ { + _, _ = ParseEd25519PublicKeyFromHex(testValidHex) + } +} + +func BenchmarkValidateEd25519PublicKey(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = ValidateEd25519PublicKey(testValidKey) + } +} diff --git a/pkg/encryption/p256.go b/pkg/encryption/p256.go new file mode 100644 index 0000000..169ed32 --- /dev/null +++ b/pkg/encryption/p256.go @@ -0,0 +1,185 @@ +package encryption + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/sha256" + "crypto/x509" + "encoding/base64" + "encoding/hex" + "fmt" + "strings" +) + +type KeyData struct { + PublicKeyHex string + PrivateKeyHex string +} + +// ParseP256PrivateKey parses a P256 private key from either DER or hex format +func ParseP256PrivateKey(keyData []byte) (*ecdsa.PrivateKey, error) { + // Try to parse as DER first + if key, err := x509.ParsePKCS8PrivateKey(keyData); err == nil { + if ecdsaKey, ok := key.(*ecdsa.PrivateKey); ok { + return ecdsaKey, nil + } + } + + // Try to parse as EC private key + if key, err := x509.ParseECPrivateKey(keyData); err == nil { + return key, nil + } + + // Try to parse as hex string + keyStr := strings.TrimSpace(string(keyData)) + keyStr = strings.TrimPrefix(keyStr, "0x") + + keyBytes, err := hex.DecodeString(keyStr) + if err != nil { + return nil, fmt.Errorf("failed to decode hex string: %w", err) + } + + // Try to parse as DER again with decoded hex + if key, err := x509.ParsePKCS8PrivateKey(keyBytes); err == nil { + if ecdsaKey, ok := key.(*ecdsa.PrivateKey); ok { + return ecdsaKey, nil + } + } + + if key, err := x509.ParseECPrivateKey(keyBytes); err == nil { + return key, nil + } + + return nil, fmt.Errorf("failed to parse P256 private key from DER or hex format") +} + +// SignWithP256 signs data using a P256 private key +func SignWithP256(privateKey *ecdsa.PrivateKey, data []byte) ([]byte, error) { + if privateKey == nil { + return nil, fmt.Errorf("invalid private key: private key is nil") + } + + if privateKey.Curve == nil { + return nil, fmt.Errorf("invalid private key: curve is nil") + } + + hash := sha256.Sum256(data) + + signature, err := ecdsa.SignASN1(rand.Reader, privateKey, hash[:]) + if err != nil { + return nil, fmt.Errorf("failed to sign data: %w", err) + } + + return signature, nil +} + +func GenerateP256Keys() (KeyData, error) { + privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return KeyData{}, err + } + + // Convert private key to PEM format + privateKeyBytes, err := x509.MarshalECPrivateKey(privateKey) + if err != nil { + return KeyData{}, err + } + + // Convert public key to DER format + publicKeyBytes, err := x509.MarshalPKIXPublicKey(&privateKey.PublicKey) + if err != nil { + return KeyData{}, err + } + + return KeyData{ + PublicKeyHex: hex.EncodeToString(publicKeyBytes), + PrivateKeyHex: hex.EncodeToString(privateKeyBytes), + }, nil +} + +// VerifyP256Signature verifies a P256 signature +func VerifyP256Signature(publicKey *ecdsa.PublicKey, data []byte, signature []byte) error { + if publicKey == nil { + return fmt.Errorf("public key is nil") + } + if publicKey.Curve != elliptic.P256() { + return fmt.Errorf("public key is not on P-256 curve") + } + if len(signature) == 0 { + return fmt.Errorf("signature is empty") + } + + hash := sha256.Sum256(data) + + if !ecdsa.VerifyASN1(publicKey, hash[:], signature) { + return fmt.Errorf("invalid signature") + } + + return nil +} +func ParseP256PublicKeyFromBytes(keyBytes []byte) (*ecdsa.PublicKey, error) { + // Try to parse as DER first + if key, err := x509.ParsePKIXPublicKey(keyBytes); err == nil { + if ecdsaKey, ok := key.(*ecdsa.PublicKey); ok { + if ecdsaKey.Curve == elliptic.P256() { + return ecdsaKey, nil + } + } + } + + // Try to parse as EC public key + if key, err := x509.ParsePKIXPublicKey(keyBytes); err == nil { + if ecdsaKey, ok := key.(*ecdsa.PublicKey); ok { + if ecdsaKey.Curve == elliptic.P256() { + return ecdsaKey, nil + } + } + } + + return nil, fmt.Errorf("failed to parse P-256 public key from bytes") +} + +// ParseP256PublicKeyFromHex parses a P-256 public key from hex string +func ParseP256PublicKeyFromHex(hexString string) (*ecdsa.PublicKey, error) { + hexString = strings.TrimPrefix(hexString, "0x") + keyBytes, err := hex.DecodeString(hexString) + if err != nil { + return nil, fmt.Errorf("failed to decode hex string: %w", err) + } + + return ParseP256PublicKeyFromBytes(keyBytes) +} + +// ParseP256PublicKeyFromBase64 parses a P-256 public key from base64 string +func ParseP256PublicKeyFromBase64(base64String string) (*ecdsa.PublicKey, error) { + keyBytes, err := base64.StdEncoding.DecodeString(base64String) + if err != nil { + return nil, fmt.Errorf("failed to decode base64 string: %w", err) + } + + return ParseP256PublicKeyFromBytes(keyBytes) +} + +// ValidateP256PublicKey validates that a public key is P-256 +func ValidateP256PublicKey(publicKey *ecdsa.PublicKey) error { + if publicKey == nil { + return fmt.Errorf("public key is nil") + } + if publicKey.Curve == nil { + return fmt.Errorf("public key curve is nil") + } + if publicKey.Curve != elliptic.P256() { + return fmt.Errorf("public key is not P-256 curve (got: %s)", publicKey.Curve.Params().Name) + } + return nil +} + +// MarshalP256PublicKey marshals a P256 public key to DER format +func MarshalP256PublicKey(publicKey *ecdsa.PublicKey) ([]byte, error) { + if err := ValidateP256PublicKey(publicKey); err != nil { + return nil, fmt.Errorf("invalid P256 public key: %w", err) + } + + return x509.MarshalPKIXPublicKey(publicKey) +} diff --git a/pkg/encryption/p256_test.go b/pkg/encryption/p256_test.go new file mode 100644 index 0000000..cf99f01 --- /dev/null +++ b/pkg/encryption/p256_test.go @@ -0,0 +1,303 @@ +package encryption + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/sha256" + "crypto/x509" + "encoding/base64" + "encoding/hex" + "testing" +) + +// ---------------------- +// Helper functions +// ---------------------- + +func mustGenerateP256Key(t *testing.T) *ecdsa.PrivateKey { + t.Helper() + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("Failed to generate P256 key: %v", err) + } + return key +} + +func mustMarshalToDER(t *testing.T, key *ecdsa.PrivateKey) []byte { + t.Helper() + der, err := x509.MarshalECPrivateKey(key) + if err != nil { + t.Fatalf("Failed to marshal private key: %v", err) + } + return der +} + +func mustParsePrivateKey(t *testing.T, data []byte) *ecdsa.PrivateKey { + t.Helper() + key, err := ParseP256PrivateKey(data) + if err != nil { + t.Fatalf("Failed to parse private key: %v", err) + } + return key +} + +func mustParseHexPrivateKey(t *testing.T, hexKey string) *ecdsa.PrivateKey { + t.Helper() + keyBytes, err := hex.DecodeString(hexKey) + if err != nil { + t.Fatalf("Failed to decode hex: %v", err) + } + return mustParsePrivateKey(t, keyBytes) +} + +func mustSign(t *testing.T, key *ecdsa.PrivateKey, data []byte) []byte { + t.Helper() + sig, err := SignWithP256(key, data) + if err != nil { + t.Fatalf("Failed to sign: %v", err) + } + if len(sig) == 0 { + t.Fatal("Signature is empty") + } + return sig +} + +func mustVerify(t *testing.T, pub *ecdsa.PublicKey, data, sig []byte) { + t.Helper() + if err := VerifyP256Signature(pub, data, sig); err != nil { + t.Fatalf("Failed to verify signature: %v", err) + } +} + +// ---------------------- +// Actual tests +// ---------------------- + +func TestGenerateP256Keys(t *testing.T) { + keyData, err := GenerateP256Keys() + if err != nil { + t.Fatalf("Failed to generate P256 keys: %v", err) + } + + if _, err := hex.DecodeString(keyData.PublicKeyHex); err != nil { + t.Errorf("Public key is not valid hex: %v", err) + } + if _, err := hex.DecodeString(keyData.PrivateKeyHex); err != nil { + t.Errorf("Private key is not valid hex: %v", err) + } + + privateKey := mustParsePrivateKey(t, []byte(keyData.PrivateKeyHex)) + if privateKey.Curve != elliptic.P256() { + t.Error("Generated key is not P256 curve") + } +} + +func TestParseP256PrivateKey_DER(t *testing.T) { + original := mustGenerateP256Key(t) + der := mustMarshalToDER(t, original) + parsed := mustParsePrivateKey(t, der) + + if !original.Equal(parsed) { + t.Error("Parsed key is not equal to original key") + } +} + +func TestParseP256PrivateKey_Hex(t *testing.T) { + original := mustGenerateP256Key(t) + der := mustMarshalToDER(t, original) + hexStr := hex.EncodeToString(der) + + parsed := mustParsePrivateKey(t, []byte(hexStr)) + if !original.Equal(parsed) { + t.Error("Parsed key is not equal to original key") + } +} + +func TestParseP256PrivateKey_InvalidInput(t *testing.T) { + cases := [][]byte{ + []byte("invalid-hex"), + {}, + func() []byte { + b := make([]byte, 32) + _, err := rand.Read(b) + if err != nil { + t.Fatalf("Failed to generate random bytes: %v", err) + } + return b + }(), + } + + for _, in := range cases { + if _, err := ParseP256PrivateKey(in); err == nil { + t.Errorf("Expected error for input %q", in) + } + } +} + +func TestSignWithP256(t *testing.T) { + key := mustGenerateP256Key(t) + data := []byte("Hello, P256 signing!") + sig := mustSign(t, key, data) + + hash := sha256.Sum256(data) + if !ecdsa.VerifyASN1(&key.PublicKey, hash[:], sig) { + t.Error("Generated signature is invalid") + } +} + +func TestSignWithP256_InvalidKey(t *testing.T) { + if _, err := SignWithP256(nil, []byte("test")); err == nil { + t.Error("Expected error for nil private key") + } + + invalidKey := &ecdsa.PrivateKey{} + _, err := SignWithP256(invalidKey, []byte("test")) + if err == nil || err.Error() != "invalid private key: curve is nil" { + t.Errorf("Expected specific error, got: %v", err) + } +} + +func TestSignWithP256_EmptyData(t *testing.T) { + key := mustGenerateP256Key(t) + mustSign(t, key, []byte{}) +} + +func TestParseP256PrivateKey_With0xPrefix(t *testing.T) { + original := mustGenerateP256Key(t) + der := mustMarshalToDER(t, original) + hexStr := "0x" + hex.EncodeToString(der) + + parsed := mustParsePrivateKey(t, []byte(hexStr)) + if !original.Equal(parsed) { + t.Error("Parsed key is not equal to original key") + } +} + +func TestSignAndVerifyWithSpecificKey(t *testing.T) { + const privHex = "307702010104205dbfd209d750b8c501818d0075ce0c23d1c59dabc33f0a8d4d3e52b30cbdbb20a00a06082a8648ce3d030107a14403420004cd9f1b35c241103eb25dbdcf0c93d8cbb444150fde72acecea2eafcee97e3c03aad1c8a8170960dcc2b921822cc6ac1795f4692c22b3ed71dab1deb9aee53018" + key := mustParseHexPrivateKey(t, privHex) + + data := []byte("test-wallet-p256") + sig := mustSign(t, key, data) + mustVerify(t, &key.PublicKey, data, sig) +} + +func TestWalletIDSigningFlow(t *testing.T) { + const privHex = "307702010104205dbfd209d750b8c501818d0075ce0c23d1c59dabc33f0a8d4d3e52b30cbdbb20a00a06082a8648ce3d030107a14403420004cd9f1b35c241103eb25dbdcf0c93d8cbb444150fde72acecea2eafcee97e3c03aad1c8a8170960dcc2b921822cc6ac1795f4692c22b3ed71dab1deb9aee53018" + key := mustParseHexPrivateKey(t, privHex) + + for _, walletID := range []string{"test-wallet-p256", "aa7a8764-0899-45ad-9017-ec5a0ec5bfff", "another-test-wallet", "wallet-123"} { + t.Run(walletID, func(t *testing.T) { + data := []byte(walletID) + sig := mustSign(t, key, data) + mustVerify(t, &key.PublicKey, data, sig) + }) + } +} + +func TestParseP256PublicKeyFromHexAndBase64(t *testing.T) { + keyData, err := GenerateP256Keys() + if err != nil { + t.Fatalf("Failed to generate keys: %v", err) + } + + // Hex case + pubKey, err := ParseP256PublicKeyFromHex(keyData.PublicKeyHex) + if err != nil { + t.Fatalf("Failed to parse public key from hex: %v", err) + } + if pubKey.Curve != elliptic.P256() { + t.Errorf("Expected P-256 curve, got %s", pubKey.Curve.Params().Name) + } + + // With "0x" prefix + pubKey2, err := ParseP256PublicKeyFromHex("0x" + keyData.PublicKeyHex) + if err != nil { + t.Fatalf("Failed to parse public key with 0x prefix: %v", err) + } + if !pubKey.Equal(pubKey2) { + t.Error("Public key mismatch with 0x prefix") + } + + // Base64 case + pubBytes, _ := hex.DecodeString(keyData.PublicKeyHex) + pubB64 := base64.StdEncoding.EncodeToString(pubBytes) + + pubKey3, err := ParseP256PublicKeyFromBase64(pubB64) + if err != nil { + t.Fatalf("Failed to parse public key from base64: %v", err) + } + if !pubKey.Equal(pubKey3) { + t.Error("Public key mismatch with base64 parsing") + } +} + +func TestValidateP256PublicKey(t *testing.T) { + privKey := mustGenerateP256Key(t) + pubKey := &privKey.PublicKey + + // Valid case + if err := ValidateP256PublicKey(pubKey); err != nil { + t.Errorf("Unexpected error for valid key: %v", err) + } + + // Nil key + if err := ValidateP256PublicKey(nil); err == nil { + t.Error("Expected error for nil public key") + } + + // Nil curve + badKey := &ecdsa.PublicKey{} + if err := ValidateP256PublicKey(badKey); err == nil { + t.Error("Expected error for nil curve") + } + + // Wrong curve + otherPriv, _ := ecdsa.GenerateKey(elliptic.P384(), rand.Reader) + if err := ValidateP256PublicKey(&otherPriv.PublicKey); err == nil { + t.Error("Expected error for non-P256 curve") + } +} + +func TestVerifyP256Signature_InvalidCases(t *testing.T) { + privKey := mustGenerateP256Key(t) + data := []byte("verify test") + sig := mustSign(t, privKey, data) + + // Nil public key + if err := VerifyP256Signature(nil, data, sig); err == nil { + t.Error("Expected error for nil public key") + } + + // Wrong curve + otherPriv, _ := ecdsa.GenerateKey(elliptic.P384(), rand.Reader) + if err := VerifyP256Signature(&otherPriv.PublicKey, data, sig); err == nil { + t.Error("Expected error for wrong curve") + } + + // Empty signature + if err := VerifyP256Signature(&privKey.PublicKey, data, []byte{}); err == nil { + t.Error("Expected error for empty signature") + } + + // Tampered signature + tampered := append([]byte{}, sig...) + tampered[len(tampered)-1] ^= 0xFF + if err := VerifyP256Signature(&privKey.PublicKey, data, tampered); err == nil { + t.Error("Expected error for tampered signature") + } +} + +func TestParseP256PublicKeyFromBytes_Invalid(t *testing.T) { + // Random bytes that are not a public key + randomBytes := make([]byte, 64) + _, err := rand.Read(randomBytes) + if err != nil { + t.Fatalf("Failed to generate random bytes: %v", err) + } + + if _, err := ParseP256PublicKeyFromBytes(randomBytes); err == nil { + t.Error("Expected error for invalid public key bytes") + } +} diff --git a/pkg/eventconsumer/event_consumer.go b/pkg/eventconsumer/event_consumer.go index c84e684..691c712 100644 --- a/pkg/eventconsumer/event_consumer.go +++ b/pkg/eventconsumer/event_consumer.go @@ -401,7 +401,8 @@ func (ec *eventConsumer) handleSigningEvent(natMsg *nats.Msg) { ec.signingResultQueue, idempotentKey, ) - + default: + sessionErr = fmt.Errorf("unsupported key type: %v", msg.KeyType) } if sessionErr != nil { if errors.Is(sessionErr, mpc.ErrNotEnoughParticipants) { diff --git a/pkg/identity/identity.go b/pkg/identity/identity.go index 0d2329a..3dbccf6 100644 --- a/pkg/identity/identity.go +++ b/pkg/identity/identity.go @@ -1,6 +1,7 @@ package identity import ( + "crypto/ecdsa" "crypto/ed25519" "encoding/hex" "encoding/json" @@ -8,6 +9,7 @@ import ( "fmt" "io" "os" + "slices" "strings" "sync" "syscall" @@ -52,6 +54,12 @@ type Store interface { DecryptMessage(cipher []byte, peerID string) ([]byte, error) } +type InitiatorKey struct { + Algorithm types.EventInitiatorKeyType + Ed25519 []byte + P256 *ecdsa.PublicKey +} + // fileStore implements the Store interface using the filesystem type fileStore struct { identityDir string @@ -61,9 +69,9 @@ type fileStore struct { publicKeys map[string][]byte mu sync.RWMutex - privateKey []byte - initiatorPubKey []byte - symmetricKeys map[string][]byte + privateKey []byte + initiatorKey *InitiatorKey + symmetricKeys map[string][]byte } // NewFileStore creates a new identity store @@ -82,17 +90,11 @@ func NewFileStore(identityDir, nodeName string, decrypt bool) (*fileStore, error return nil, fmt.Errorf("invalid private key format: %w", err) } - pubKeyHex := viper.GetString("event_initiator_pubkey") - if pubKeyHex == "" { - return nil, fmt.Errorf("event_initiator_pubkey not found in quax config") - } - initiatorPubKey, err := hex.DecodeString(pubKeyHex) + initiatorKey, err := loadInitiatorKeys() if err != nil { - return nil, fmt.Errorf("invalid initiator public key format: %w", err) + return nil, err } - logger.Infof("Loaded initiator public key for node %s", pubKeyHex) - // Load peers.json to validate all nodes have identity files peersData, err := os.ReadFile("peers.json") if err != nil { @@ -109,7 +111,7 @@ func NewFileStore(identityDir, nodeName string, decrypt bool) (*fileStore, error currentNodeName: nodeName, publicKeys: make(map[string][]byte), privateKey: privateKey, - initiatorPubKey: initiatorPubKey, + initiatorKey: initiatorKey, symmetricKeys: make(map[string][]byte), } @@ -123,7 +125,12 @@ func NewFileStore(identityDir, nodeName string, decrypt bool) (*fileStore, error data, err := os.ReadFile(identityFilePath) if err != nil { - return nil, fmt.Errorf("missing identity file for node %s (%s): %w", nodeName, nodeID, err) + return nil, fmt.Errorf( + "missing identity file for node %s (%s): %w", + nodeName, + nodeID, + err, + ) } var identity NodeIdentity @@ -133,8 +140,12 @@ func NewFileStore(identityDir, nodeName string, decrypt bool) (*fileStore, error // Verify that the nodeID in peers.json matches the one in the identity file if identity.NodeID != nodeID { - return nil, fmt.Errorf("node ID mismatch for %s: %s in peers.json vs %s in identity file", - nodeName, nodeID, identity.NodeID) + return nil, fmt.Errorf( + "node ID mismatch for %s: %s in peers.json vs %s in identity file", + nodeName, + nodeID, + identity.NodeID, + ) } key, err := hex.DecodeString(identity.PublicKey) @@ -148,6 +159,95 @@ func NewFileStore(identityDir, nodeName string, decrypt bool) (*fileStore, error return store, nil } +func loadInitiatorKeys() (*InitiatorKey, error) { + // Get algorithm configuration with default + algorithm := viper.GetString("event_initiator_algorithm") + if algorithm == "" { + algorithm = string(types.KeyTypeEd25519) + } + + // Validate algorithm + if !slices.Contains( + []string{string(types.EventInitiatorKeyTypeEd25519), string(types.EventInitiatorKeyTypeP256)}, + algorithm, + ) { + return nil, fmt.Errorf("invalid algorithm: %s. Must be %s or %s", + algorithm, + types.EventInitiatorKeyTypeEd25519, + types.EventInitiatorKeyTypeP256, + ) + } + + var initiatorKey *InitiatorKey + + switch algorithm { + case string(types.EventInitiatorKeyTypeEd25519): + key, err := loadEd25519InitiatorKey() + if err != nil { + return nil, fmt.Errorf("failed to load Ed25519 initiator key: %w", err) + } + initiatorKey = &InitiatorKey{ + Algorithm: types.EventInitiatorKeyTypeEd25519, + Ed25519: key, + } + logger.Info("Loaded Ed25519 initiator public key") + + case string(types.EventInitiatorKeyTypeP256): + key, err := loadP256InitiatorKey() + if err != nil { + return nil, fmt.Errorf("failed to load P-256 initiator key: %w", err) + } + initiatorKey = &InitiatorKey{ + Algorithm: types.EventInitiatorKeyTypeP256, + P256: key, + } + logger.Info("Loaded P-256 initiator public key") + } + + return initiatorKey, nil +} + +// loadEd25519InitiatorKey loads Ed25519 initiator public key +func loadEd25519InitiatorKey() ([]byte, error) { + pubKeyHex := viper.GetString("event_initiator_pubkey") + if pubKeyHex == "" { + return nil, fmt.Errorf("event_initiator_pubkey not found in config") + } + + key, err := encryption.ParseEd25519PublicKeyFromHex(pubKeyHex) + + if err != nil { + return nil, fmt.Errorf("failed to decode event_initiator_pubkey as hex: %w", err) + } + + return key, nil + +} + +func loadP256InitiatorKey() (*ecdsa.PublicKey, error) { + pubKeyHex := viper.GetString("event_initiator_pubkey") + if pubKeyHex == "" { + return nil, fmt.Errorf("event_initiator_pubkey not found in config") + } + + // Use the new P256 functions from p256.go + publicKey, err := encryption.ParseP256PublicKeyFromHex(pubKeyHex) + if err == nil { + return publicKey, nil + } + + // If hex parsing fails, try base64 + publicKey, err = encryption.ParseP256PublicKeyFromBase64(pubKeyHex) + if err == nil { + return publicKey, nil + } + + return nil, fmt.Errorf( + "failed to decode event_initiator_pubkey as hex or base64: %w", + err, + ) +} + // loadPrivateKey loads the private key from file, decrypting if necessary func loadPrivateKey(identityDir, nodeName string, decrypt bool) (string, error) { // Check for encrypted or unencrypted private key @@ -374,28 +474,48 @@ func (s *fileStore) VerifySignature(msg *types.ECDHMessage) error { return nil } -// VerifyInitiatorMessage verifies that a message was signed by the known initiator func (s *fileStore) VerifyInitiatorMessage(msg types.InitiatorMessage) error { - // Get the raw message that was signed + algo := s.initiatorKey.Algorithm + + switch algo { + case types.EventInitiatorKeyTypeEd25519: + return s.verifyEd25519(msg) + case types.EventInitiatorKeyTypeP256: + return s.verifyP256(msg) + } + return fmt.Errorf("unsupported algorithm: %s", algo) +} + +func (s *fileStore) verifyEd25519(msg types.InitiatorMessage) error { msgBytes, err := msg.Raw() if err != nil { return fmt.Errorf("failed to get raw message data: %w", err) } - - // Get the signature signature := msg.Sig() if len(signature) == 0 { return errors.New("signature is empty") } - // Verify the signature using the initiator's public key - if !ed25519.Verify(s.initiatorPubKey, msgBytes, signature) { + if !ed25519.Verify(s.initiatorKey.Ed25519, msgBytes, signature) { return fmt.Errorf("invalid signature from initiator") } - return nil } +func (s *fileStore) verifyP256(msg types.InitiatorMessage) error { + msgBytes, err := msg.Raw() + if err != nil { + return fmt.Errorf("failed to get raw message data: %w", err) + } + signature := msg.Sig() + + if s.initiatorKey.P256 == nil { + return fmt.Errorf("initiator public key for secp256r1 is not set") + } + + return encryption.VerifyP256Signature(s.initiatorKey.P256, msgBytes, signature) +} + func partyIDToNodeID(partyID *tss.PartyID) string { return strings.Split(string(partyID.KeyInt().Bytes()), ":")[0] } diff --git a/pkg/types/initiator_msg.go b/pkg/types/initiator_msg.go index d56b554..d770e79 100644 --- a/pkg/types/initiator_msg.go +++ b/pkg/types/initiator_msg.go @@ -9,6 +9,13 @@ const ( KeyTypeEd25519 KeyType = "ed25519" ) +type EventInitiatorKeyType string + +const ( + EventInitiatorKeyTypeEd25519 EventInitiatorKeyType = "ed25519" + EventInitiatorKeyTypeP256 EventInitiatorKeyType = "p256" +) + // InitiatorMessage is anything that carries a payload to verify and its signature. type InitiatorMessage interface { // Raw returns the canonical byte‐slice that was signed.