diff --git a/cmd/createtree/main.go b/cmd/createtree/main.go index efc90961ab..3521168af3 100644 --- a/cmd/createtree/main.go +++ b/cmd/createtree/main.go @@ -35,8 +35,6 @@ import ( "flag" "fmt" "io/ioutil" - "os" - "time" "github.com/golang/glog" "github.com/golang/protobuf/ptypes" @@ -57,7 +55,7 @@ var ( treeType = flag.String("tree_type", trillian.TreeType_LOG.String(), "Type of the new tree") hashStrategy = flag.String("hash_strategy", trillian.HashStrategy_RFC6962_SHA256.String(), "Hash strategy (aka preimage protection) of the new tree") hashAlgorithm = flag.String("hash_algorithm", sigpb.DigitallySigned_SHA256.String(), "Hash algorithm of the new tree") - signatureAlgorithm = flag.String("signature_algorithm", sigpb.DigitallySigned_RSA.String(), "Signature algorithm of the new tree") + signatureAlgorithm = flag.String("signature_algorithm", sigpb.DigitallySigned_ECDSA.String(), "Signature algorithm of the new tree") displayName = flag.String("display_name", "", "Display name of the new tree") description = flag.String("description", "", "Description of the new tree") maxRootDuration = flag.Duration("max_root_duration", 0, "Interval after which a new signed root is produced despite no submissions; zero means never") @@ -70,26 +68,17 @@ var ( configFile = flag.String("config", "", "Config file containing flags, file contents can be overridden by command line flags") ) -// createOpts contains all user-supplied options required to run the program. -// It's meant to facilitate tests and focus flag reads to a single point. -type createOpts struct { - addr string - treeState, treeType, hashStrategy, hashAlgorithm, sigAlgorithm, displayName, description string - maxRootDuration time.Duration - privateKeyType, pemKeyPath, pemKeyPass, pkcs11ConfigPath string -} - -func createTree(ctx context.Context, opts *createOpts) (*trillian.Tree, error) { - if opts.addr == "" { +func createTree(ctx context.Context) (*trillian.Tree, error) { + if *adminServerAddr == "" { return nil, errors.New("empty --admin_server, please provide the Admin server host:port") } - req, err := newRequest(opts) + req, err := newRequest() if err != nil { return nil, err } - conn, err := grpc.Dial(opts.addr, grpc.WithInsecure()) + conn, err := grpc.Dial(*adminServerAddr, grpc.WithInsecure()) if err != nil { return nil, err } @@ -102,30 +91,30 @@ func createTree(ctx context.Context, opts *createOpts) (*trillian.Tree, error) { return tree, nil } -func newRequest(opts *createOpts) (*trillian.CreateTreeRequest, error) { - ts, ok := trillian.TreeState_value[opts.treeState] +func newRequest() (*trillian.CreateTreeRequest, error) { + ts, ok := trillian.TreeState_value[*treeState] if !ok { - return nil, fmt.Errorf("unknown TreeState: %v", opts.treeState) + return nil, fmt.Errorf("unknown TreeState: %v", *treeState) } - tt, ok := trillian.TreeType_value[opts.treeType] + tt, ok := trillian.TreeType_value[*treeType] if !ok { - return nil, fmt.Errorf("unknown TreeType: %v", opts.treeType) + return nil, fmt.Errorf("unknown TreeType: %v", *treeType) } - hs, ok := trillian.HashStrategy_value[opts.hashStrategy] + hs, ok := trillian.HashStrategy_value[*hashStrategy] if !ok { - return nil, fmt.Errorf("unknown HashStrategy: %v", opts.hashStrategy) + return nil, fmt.Errorf("unknown HashStrategy: %v", *hashStrategy) } - ha, ok := sigpb.DigitallySigned_HashAlgorithm_value[opts.hashAlgorithm] + ha, ok := sigpb.DigitallySigned_HashAlgorithm_value[*hashAlgorithm] if !ok { - return nil, fmt.Errorf("unknown HashAlgorithm: %v", opts.hashAlgorithm) + return nil, fmt.Errorf("unknown HashAlgorithm: %v", *hashAlgorithm) } - sa, ok := sigpb.DigitallySigned_SignatureAlgorithm_value[opts.sigAlgorithm] + sa, ok := sigpb.DigitallySigned_SignatureAlgorithm_value[*signatureAlgorithm] if !ok { - return nil, fmt.Errorf("unknown SignatureAlgorithm: %v", opts.sigAlgorithm) + return nil, fmt.Errorf("unknown SignatureAlgorithm: %v", *signatureAlgorithm) } ctr := &trillian.CreateTreeRequest{Tree: &trillian.Tree{ @@ -134,13 +123,13 @@ func newRequest(opts *createOpts) (*trillian.CreateTreeRequest, error) { HashStrategy: trillian.HashStrategy(hs), HashAlgorithm: sigpb.DigitallySigned_HashAlgorithm(ha), SignatureAlgorithm: sigpb.DigitallySigned_SignatureAlgorithm(sa), - DisplayName: opts.displayName, - Description: opts.description, - MaxRootDuration: ptypes.DurationProto(opts.maxRootDuration), + DisplayName: *displayName, + Description: *description, + MaxRootDuration: ptypes.DurationProto(*maxRootDuration), }} - if opts.privateKeyType != "" { - pk, err := newPK(opts) + if *privateKeyFormat != "" { + pk, err := newPK(*privateKeyFormat) if err != nil { return nil, err } @@ -149,7 +138,7 @@ func newRequest(opts *createOpts) (*trillian.CreateTreeRequest, error) { // Cannot continue if options specifying a key were provided but // privateKeyType is not set, as there's no way to know what protobuf // message type was intended. - if opts.pemKeyPath != "" || opts.pemKeyPass != "" || opts.pkcs11ConfigPath != "" { + if *pemKeyPath != "" || *pemKeyPassword != "" || *pkcs11ConfigPath != "" { return nil, errors.New("must specify private key format") } @@ -173,26 +162,26 @@ func newRequest(opts *createOpts) (*trillian.CreateTreeRequest, error) { return ctr, nil } -func newPK(opts *createOpts) (*any.Any, error) { - switch opts.privateKeyType { +func newPK(keyFormat string) (*any.Any, error) { + switch keyFormat { case "PEMKeyFile": - if opts.pemKeyPath == "" { + if *pemKeyPath == "" { return nil, errors.New("empty pem_key_path") } - if opts.pemKeyPass == "" { - return nil, fmt.Errorf("empty password for PEM key file %q", opts.pemKeyPath) + if *pemKeyPassword == "" { + return nil, fmt.Errorf("empty password for PEM key file %q", *pemKeyPath) } pemKey := &keyspb.PEMKeyFile{ - Path: opts.pemKeyPath, - Password: opts.pemKeyPass, + Path: *pemKeyPath, + Password: *pemKeyPassword, } return ptypes.MarshalAny(pemKey) case "PrivateKey": - if opts.pemKeyPath == "" { + if *pemKeyPath == "" { return nil, errors.New("empty pem_key_path") } pemSigner, err := keys.NewFromPrivatePEMFile( - opts.pemKeyPath, opts.pemKeyPass) + *pemKeyPath, *pemKeyPassword) if err != nil { return nil, err } @@ -202,10 +191,10 @@ func newPK(opts *createOpts) (*any.Any, error) { } return ptypes.MarshalAny(&keyspb.PrivateKey{Der: der}) case "PKCS11ConfigFile": - if opts.pkcs11ConfigPath == "" { + if *pkcs11ConfigPath == "" { return nil, errors.New("empty PKCS11 config file path") } - configBytes, err := ioutil.ReadFile(opts.pkcs11ConfigPath) + configBytes, err := ioutil.ReadFile(*pkcs11ConfigPath) if err != nil { return nil, err } @@ -223,25 +212,7 @@ func newPK(opts *createOpts) (*any.Any, error) { PublicKey: string(pubKeyBytes), }) default: - return nil, fmt.Errorf("unknown private key type: %v", opts.privateKeyType) - } -} - -func newOptsFromFlags() *createOpts { - return &createOpts{ - addr: *adminServerAddr, - treeState: *treeState, - treeType: *treeType, - hashStrategy: *hashStrategy, - hashAlgorithm: *hashAlgorithm, - sigAlgorithm: *signatureAlgorithm, - displayName: *displayName, - description: *description, - maxRootDuration: *maxRootDuration, - privateKeyType: *privateKeyFormat, - pemKeyPath: *pemKeyPath, - pemKeyPass: *pemKeyPassword, - pkcs11ConfigPath: *pkcs11ConfigPath, + return nil, fmt.Errorf("unknown private key type: %v", keyFormat) } } @@ -255,10 +226,9 @@ func main() { } ctx := context.Background() - tree, err := createTree(ctx, newOptsFromFlags()) + tree, err := createTree(ctx) if err != nil { - fmt.Fprintf(os.Stderr, "Failed to create tree: %v\n", err) - os.Exit(1) + glog.Exitf("Failed to create tree: %v", err) } // DO NOT change the output format, scripts are meant to depend on it. diff --git a/cmd/createtree/main_test.go b/cmd/createtree/main_test.go index df0cedfc08..4c895e62b6 100644 --- a/cmd/createtree/main_test.go +++ b/cmd/createtree/main_test.go @@ -15,10 +15,9 @@ package main import ( + "context" "errors" - "fmt" - "net" - "os" + "flag" "testing" "time" @@ -27,15 +26,32 @@ import ( "github.com/golang/protobuf/ptypes/any" "github.com/golang/protobuf/ptypes/empty" "github.com/google/trillian" - "github.com/google/trillian/crypto/keys" - "github.com/google/trillian/crypto/keyspb" + "github.com/google/trillian/cmd/createtree/testonly" "github.com/google/trillian/crypto/sigpb" + "github.com/google/trillian/util/flagsaver" "github.com/kylelemons/godebug/pretty" - "golang.org/x/net/context" - "google.golang.org/grpc" ) -func marshalAny(p proto.Message) *any.Any { +// defaultTree reflects all flag defaults with the addition of a valid private key. +var defaultTree = &trillian.Tree{ + TreeState: trillian.TreeState_ACTIVE, + TreeType: trillian.TreeType_LOG, + HashStrategy: trillian.HashStrategy_RFC6962_SHA256, + HashAlgorithm: sigpb.DigitallySigned_SHA256, + SignatureAlgorithm: sigpb.DigitallySigned_ECDSA, + PrivateKey: mustMarshalAny(&empty.Empty{}), + MaxRootDuration: ptypes.DurationProto(0 * time.Millisecond), +} + +type testCase struct { + desc string + setFlags func() + createErr error + wantErr bool + wantTree *trillian.Tree +} + +func mustMarshalAny(p proto.Message) *any.Any { anyKey, err := ptypes.MarshalAny(p) if err != nil { panic(err) @@ -43,233 +59,103 @@ func marshalAny(p proto.Message) *any.Any { return anyKey } -func TestRun(t *testing.T) { - err := os.Chdir("../..") - if err != nil { - t.Fatalf("Unable to change working directory to ../..: %s", err) - } - - pemPath, pemPassword := "testdata/log-rpc-server.privkey.pem", "towel" - pemSigner, err := keys.NewFromPrivatePEMFile(pemPath, pemPassword) - if err != nil { - t.Fatalf("NewFromPrivatePEM(): %v", err) - } - pemDer, err := keys.MarshalPrivateKey(pemSigner) - if err != nil { - t.Fatalf("MarshalPrivateKey(): %v", err) - } - anyPrivKey, err := ptypes.MarshalAny(&keyspb.PrivateKey{Der: pemDer}) - if err != nil { - t.Fatalf("MarshalAny(%v): %v", pemDer, err) - } - - // defaultTree reflects all flag defaults with the addition of a valid pk - defaultTree := &trillian.Tree{ - TreeState: trillian.TreeState_ACTIVE, - TreeType: trillian.TreeType_LOG, - HashStrategy: trillian.HashStrategy_RFC6962_SHA256, - HashAlgorithm: sigpb.DigitallySigned_SHA256, - SignatureAlgorithm: sigpb.DigitallySigned_RSA, - PrivateKey: anyPrivKey, - MaxRootDuration: ptypes.DurationProto(0 * time.Millisecond), - } - - server, lis, stopFn, err := startFakeServer() - if err != nil { - t.Fatalf("Error starting fake server: %v", err) - } - defer stopFn() - server.generatedKey = anyPrivKey - - validOpts := newOptsFromFlags() - validOpts.addr = lis.Addr().String() - +func TestCreateTree(t *testing.T) { nonDefaultTree := *defaultTree nonDefaultTree.TreeType = trillian.TreeType_MAP - nonDefaultTree.SignatureAlgorithm = sigpb.DigitallySigned_ECDSA + nonDefaultTree.SignatureAlgorithm = sigpb.DigitallySigned_RSA nonDefaultTree.DisplayName = "Llamas Map" nonDefaultTree.Description = "For all your digital llama needs!" - nonDefaultOpts := *validOpts - nonDefaultOpts.treeType = nonDefaultTree.TreeType.String() - nonDefaultOpts.sigAlgorithm = nonDefaultTree.SignatureAlgorithm.String() - nonDefaultOpts.displayName = nonDefaultTree.DisplayName - nonDefaultOpts.description = nonDefaultTree.Description - - emptyAddr := *validOpts - emptyAddr.addr = "" - - invalidEnumOpts := *validOpts - invalidEnumOpts.treeType = "LLAMA!" - - privateKeyOpts := *validOpts - privateKeyOpts.privateKeyType = "PrivateKey" - privateKeyOpts.pemKeyPath = pemPath - privateKeyOpts.pemKeyPass = pemPassword - - emptyKeyTypeOpts := privateKeyOpts - emptyKeyTypeOpts.privateKeyType = "" - - invalidKeyTypeOpts := privateKeyOpts - invalidKeyTypeOpts.privateKeyType = "LLAMA!!" - - emptyPEMPath := privateKeyOpts - emptyPEMPath.pemKeyPath = "" - - emptyPEMPass := privateKeyOpts - emptyPEMPass.pemKeyPass = "" - - pemKeyOpts := privateKeyOpts - pemKeyOpts.privateKeyType = "PEMKeyFile" - pemKeyTree := *defaultTree - pemKeyTree.PrivateKey, err = ptypes.MarshalAny(&keyspb.PEMKeyFile{ - Path: pemPath, - Password: pemPassword, + runTest(t, []*testCase{ + { + desc: "validOpts", + wantTree: defaultTree, + }, + { + desc: "nonDefaultOpts", + setFlags: func() { + *treeType = nonDefaultTree.TreeType.String() + *signatureAlgorithm = nonDefaultTree.SignatureAlgorithm.String() + *displayName = nonDefaultTree.DisplayName + *description = nonDefaultTree.Description + }, + wantTree: &nonDefaultTree, + }, + { + desc: "defaultOptsOnly", + setFlags: resetFlags, + wantErr: true, + }, + { + desc: "emptyAddr", + setFlags: func() { *adminServerAddr = "" }, + wantErr: true, + }, + { + desc: "invalidEnumOpts", + setFlags: func() { *treeType = "LLAMA!" }, + wantErr: true, + }, + { + desc: "invalidKeyTypeOpts", + setFlags: func() { *privateKeyFormat = "LLAMA!!" }, + wantErr: true, + }, + { + desc: "createErr", + createErr: errors.New("create tree failed"), + wantErr: true, + }, }) - if err != nil { - t.Fatalf("MarshalAny(PEMKeyFile): %v", err) - } +} - pkcs11Opts := *validOpts - pkcs11Opts.privateKeyType = "PKCS11ConfigFile" - pkcs11Opts.pkcs11ConfigPath = "testdata/pkcs11-conf.json" - pkcs11Tree := *defaultTree - pkcs11Tree.PrivateKey, err = ptypes.MarshalAny(&keyspb.PKCS11Config{ - TokenLabel: "log", - Pin: "1234", - PublicKey: `-----BEGIN PUBLIC KEY----- -MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQC7/tWwqUXZJaNfnpvnqiaeNMkn -hKusCsyAidrHxvuL+t54XFCHJwsB3wIlQZ4mMwb8mC/KRYhCqECBEoCAf/b0m3j/ -ASuEPLyYOrz/aEs3wP02IZQLGmihmjMk7T/ouNCuX7y1fTjX3GeVQ06U/EePwZFC -xToc6NWBri0N3VVsswIDAQAB ------END PUBLIC KEY----- -`, - }) - if err != nil { - t.Fatalf("MarshalAny(PKCS11Config): %v", err) +// runTest executes the createtree command against a fake TrillianAdminServer +// for each of the provided tests, and checks that the tree in the request is +// as expected, or an expected error occurs. +// Prior to each test case, it: +// 1. Resets all flags to their original values. +// 2. Sets the adminServerAddr flag to point to the fake server. +// 3. Calls the test's setFlags func (if provided) to allow it to change flags specific to the test. +func runTest(t *testing.T, tests []*testCase) { + server := &testonly.FakeAdminServer{ + GeneratedKey: defaultTree.PrivateKey, } - emptyPKCS11Path := pkcs11Opts - emptyPKCS11Path.pkcs11ConfigPath = "" - - tests := []struct { - desc string - opts *createOpts - createErr error - wantErr bool - wantTree *trillian.Tree - }{ - {desc: "validOpts", opts: validOpts, wantTree: defaultTree}, - {desc: "nonDefaultOpts", opts: &nonDefaultOpts, wantTree: &nonDefaultTree}, - {desc: "defaultOptsOnly", opts: newOptsFromFlags(), wantErr: true}, // No mandatory opts provided - {desc: "emptyAddr", opts: &emptyAddr, wantErr: true}, - {desc: "invalidEnumOpts", opts: &invalidEnumOpts, wantErr: true}, - {desc: "emptyKeyTypeOpts", opts: &emptyKeyTypeOpts, wantErr: true}, - {desc: "invalidKeyTypeOpts", opts: &invalidKeyTypeOpts, wantErr: true}, - {desc: "emptyPEMPath", opts: &emptyPEMPath, wantErr: true}, - {desc: "emptyPEMPass", opts: &emptyPEMPass, wantErr: true}, - {desc: "PrivateKey", opts: &privateKeyOpts, wantTree: defaultTree}, - {desc: "PEMKeyFile", opts: &pemKeyOpts, wantTree: &pemKeyTree}, - {desc: "createErr", opts: validOpts, createErr: errors.New("create tree failed"), wantErr: true}, - {desc: "PKCS11Config", opts: &pkcs11Opts, wantTree: &pkcs11Tree}, - {desc: "emptyPKCS11Path", opts: &emptyPKCS11Path, wantErr: true}, + lis, stopFakeServer, err := testonly.StartFakeAdminServer(server) + if err != nil { + t.Fatalf("Error starting fake server: %v", err) } + defer stopFakeServer() ctx := context.Background() for _, test := range tests { - server.err = test.createErr - - tree, err := createTree(ctx, test.opts) - switch hasErr := err != nil; { - case hasErr != test.wantErr: - t.Errorf("%v: createTree() returned err = '%v', wantErr = %v", test.desc, err, test.wantErr) - continue - case hasErr: - continue - } - - if !proto.Equal(tree, test.wantTree) { - t.Errorf("%v: post-createTree diff -got +want:\n%v", test.desc, pretty.Compare(tree, test.wantTree)) - } + t.Run(test.desc, func(t *testing.T) { + defer flagsaver.Save().Restore() + *adminServerAddr = lis.Addr().String() + if test.setFlags != nil { + test.setFlags() + } + + server.Err = test.createErr + + tree, err := createTree(ctx) + switch hasErr := err != nil; { + case hasErr != test.wantErr: + t.Errorf("createTree() returned err = '%v', wantErr = %v", err, test.wantErr) + return + case hasErr: + return + } + + if !proto.Equal(tree, test.wantTree) { + t.Errorf("post-createTree diff -got +want:\n%v", pretty.Compare(tree, test.wantTree)) + } + }) } } -// fakeAdminServer that implements CreateTree. -// If err is not nil, it will be returned in response to CreateTree requests. -// If generatedKey is not nil, and a request has a KeySpec set, the response -// will contain generatedKey. -// The response to a CreateTree request will otherwise contain an identical copy -// of the tree sent in the request. -// The remaining methods are not implemented. -type fakeAdminServer struct { - err error - generatedKey *any.Any -} - -// startFakeServer starts a fakeAdminServer on a random port. -// Returns the started server, the listener it's using for connection and a -// close function that must be defer-called on the scope the server is meant to -// stop. -func startFakeServer() (*fakeAdminServer, net.Listener, func(), error) { - grpcServer := grpc.NewServer() - fakeServer := &fakeAdminServer{} - trillian.RegisterTrillianAdminServer(grpcServer, fakeServer) - - lis, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - return nil, nil, nil, err - } - go grpcServer.Serve(lis) - - stopFn := func() { - grpcServer.Stop() - lis.Close() - } - return fakeServer, lis, stopFn, nil -} - -func (s *fakeAdminServer) CreateTree(ctx context.Context, req *trillian.CreateTreeRequest) (*trillian.Tree, error) { - if s.err != nil { - return nil, s.err - } - resp := *req.Tree - if req.KeySpec != nil { - if s.generatedKey == nil { - panic("fakeAdminServer.generatedKey == nil but CreateTreeRequest requests generated key") - } - - var keySigAlgo sigpb.DigitallySigned_SignatureAlgorithm - switch req.KeySpec.Params.(type) { - case *keyspb.Specification_EcdsaParams: - keySigAlgo = sigpb.DigitallySigned_ECDSA - case *keyspb.Specification_RsaParams: - keySigAlgo = sigpb.DigitallySigned_RSA - default: - return nil, fmt.Errorf("got unsupported type of key_spec.params: %T", req.KeySpec.Params) - } - if treeSigAlgo := req.Tree.GetSignatureAlgorithm(); treeSigAlgo != keySigAlgo { - return nil, fmt.Errorf("got tree.SignatureAlgorithm = %v but key_spec.Params of type %T", treeSigAlgo, req.KeySpec.Params) - } - - resp.PrivateKey = s.generatedKey - } - return &resp, nil -} - -var errUnimplemented = errors.New("unimplemented") - -func (s *fakeAdminServer) ListTrees(context.Context, *trillian.ListTreesRequest) (*trillian.ListTreesResponse, error) { - return nil, errUnimplemented -} - -func (s *fakeAdminServer) GetTree(context.Context, *trillian.GetTreeRequest) (*trillian.Tree, error) { - return nil, errUnimplemented -} - -func (s *fakeAdminServer) UpdateTree(context.Context, *trillian.UpdateTreeRequest) (*trillian.Tree, error) { - return nil, errUnimplemented -} - -func (s *fakeAdminServer) DeleteTree(context.Context, *trillian.DeleteTreeRequest) (*empty.Empty, error) { - return nil, errUnimplemented +func resetFlags() { + flag.Visit(func(f *flag.Flag) { + f.Value.Set(f.DefValue) + }) } diff --git a/cmd/createtree/pem_test.go b/cmd/createtree/pem_test.go new file mode 100644 index 0000000000..606b0ce2b5 --- /dev/null +++ b/cmd/createtree/pem_test.go @@ -0,0 +1,111 @@ +// Copyright 2017 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "testing" + + "github.com/google/trillian/crypto/keys" + "github.com/google/trillian/crypto/keyspb" +) + +func TestWithPEMKeyFile(t *testing.T) { + pemPath, pemPassword := "../../testdata/log-rpc-server.privkey.pem", "towel" + + wantTree := *defaultTree + wantTree.PrivateKey = mustMarshalAny(&keyspb.PEMKeyFile{ + Path: pemPath, + Password: pemPassword, + }) + + runTest(t, []*testCase{ + { + desc: "empty pemKeyPath", + setFlags: func() { + *privateKeyFormat = "PEMKeyFile" + *pemKeyPath = "" + *pemKeyPassword = pemPassword + }, + wantErr: true, + }, + { + desc: "empty pemKeyPass", + setFlags: func() { + *privateKeyFormat = "PEMKeyFile" + *pemKeyPath = pemPath + *pemKeyPassword = "" + }, + wantErr: true, + }, + { + desc: "valid pemKeyPath and pemKeyPass", + setFlags: func() { + *privateKeyFormat = "PEMKeyFile" + *pemKeyPath = pemPath + *pemKeyPassword = pemPassword + }, + wantTree: &wantTree, + }, + }) +} + +func TestWithPrivateKey(t *testing.T) { + pemPath, pemPassword := "../../testdata/log-rpc-server.privkey.pem", "towel" + + key, err := keys.NewFromPrivatePEMFile(pemPath, pemPassword) + if err != nil { + t.Fatalf("Error reading test private key file: %v", err) + } + + keyDER, err := keys.MarshalPrivateKey(key) + if err != nil { + t.Fatalf("Error marshaling test private key to DER: %v", err) + } + + wantTree := *defaultTree + wantTree.PrivateKey = mustMarshalAny(&keyspb.PrivateKey{ + Der: keyDER, + }) + + runTest(t, []*testCase{ + { + desc: "empty pemKeyPath", + setFlags: func() { + *privateKeyFormat = "PrivateKey" + *pemKeyPath = "" + *pemKeyPassword = pemPassword + }, + wantErr: true, + }, + { + desc: "empty pemKeyPass", + setFlags: func() { + *privateKeyFormat = "PrivateKey" + *pemKeyPath = pemPath + *pemKeyPassword = "" + }, + wantErr: true, + }, + { + desc: "valid pemKeyPath and pemKeyPass", + setFlags: func() { + *privateKeyFormat = "PrivateKey" + *pemKeyPath = pemPath + *pemKeyPassword = pemPassword + }, + wantTree: &wantTree, + }, + }) +} diff --git a/cmd/createtree/pkcs11_test.go b/cmd/createtree/pkcs11_test.go new file mode 100644 index 0000000000..081b09d887 --- /dev/null +++ b/cmd/createtree/pkcs11_test.go @@ -0,0 +1,63 @@ +// Copyright 2017 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "os" + "testing" + + "github.com/google/trillian/crypto/keyspb" +) + +func TestRunPkcs11(t *testing.T) { + err := os.Chdir("../..") + if err != nil { + t.Fatalf("Unable to change working directory to ../..: %s", err) + } + defer os.Chdir("cmd/createtree") + + pkcs11Tree := *defaultTree + pkcs11Tree.PrivateKey = mustMarshalAny(&keyspb.PKCS11Config{ + TokenLabel: "log", + Pin: "1234", + PublicKey: `-----BEGIN PUBLIC KEY----- +MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQC7/tWwqUXZJaNfnpvnqiaeNMkn +hKusCsyAidrHxvuL+t54XFCHJwsB3wIlQZ4mMwb8mC/KRYhCqECBEoCAf/b0m3j/ +ASuEPLyYOrz/aEs3wP02IZQLGmihmjMk7T/ouNCuX7y1fTjX3GeVQ06U/EePwZFC +xToc6NWBri0N3VVsswIDAQAB +-----END PUBLIC KEY----- +`, + }) + + runTest(t, []*testCase{ + { + desc: "PKCS11Config", + setFlags: func() { + *privateKeyFormat = "PKCS11ConfigFile" + *pkcs11ConfigPath = "testdata/pkcs11-conf.json" + }, + wantErr: false, + wantTree: &pkcs11Tree, + }, + { + desc: "emptyPKCS11Path", + setFlags: func() { + *privateKeyFormat = "PKCS11ConfigFile" + *pkcs11ConfigPath = "" + }, + wantErr: true, + }, + }) +} diff --git a/cmd/createtree/testonly/fake_server.go b/cmd/createtree/testonly/fake_server.go new file mode 100644 index 0000000000..15634e11c6 --- /dev/null +++ b/cmd/createtree/testonly/fake_server.go @@ -0,0 +1,94 @@ +// Copyright 2017 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testonly + +import ( + "fmt" + "net" + + "github.com/golang/protobuf/ptypes/any" + "github.com/google/trillian" + "github.com/google/trillian/crypto/keyspb" + "github.com/google/trillian/crypto/sigpb" + "golang.org/x/net/context" + "google.golang.org/grpc" +) + +// FakeAdminServer implements the TrillianAdminServer CreateTree RPC. +// The remaining RPCs are not implemented. +type FakeAdminServer struct { + trillian.TrillianAdminServer + + // Err will be returned by CreateTree if not nil. + Err error + // GeneratedKey will be used to set a tree's PrivateKey if a CreateTree request has a KeySpec. + // This is for simulating key generation. + GeneratedKey *any.Any +} + +// StartFakeAdminServer starts a FakeAdminServer on a random port. +// Returns the started server, the listener it's using for connection and a +// close function that must be defer-called on the scope the server is meant to +// stop. +func StartFakeAdminServer(server *FakeAdminServer) (net.Listener, func(), error) { + grpcServer := grpc.NewServer() + trillian.RegisterTrillianAdminServer(grpcServer, server) + + lis, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return nil, nil, err + } + go grpcServer.Serve(lis) + + stopFn := func() { + grpcServer.Stop() + lis.Close() + } + + return lis, stopFn, nil +} + +// CreateTree returns req.Tree, unless s.Err is not nil, in which case it +// returns s.Err. This allows tests to examine the requested tree and check +// behavior under error conditions. +// If s.GeneratedKey and req.KeySpec are not nil, the returned tree will have +// its PrivateKey field set to s.GeneratedKey. +func (s *FakeAdminServer) CreateTree(ctx context.Context, req *trillian.CreateTreeRequest) (*trillian.Tree, error) { + if s.Err != nil { + return nil, s.Err + } + resp := *req.Tree + if req.KeySpec != nil { + if s.GeneratedKey == nil { + panic("fakeAdminServer.GeneratedKey == nil but CreateTreeRequest requests generated key") + } + + var keySigAlgo sigpb.DigitallySigned_SignatureAlgorithm + switch req.KeySpec.Params.(type) { + case *keyspb.Specification_EcdsaParams: + keySigAlgo = sigpb.DigitallySigned_ECDSA + case *keyspb.Specification_RsaParams: + keySigAlgo = sigpb.DigitallySigned_RSA + default: + return nil, fmt.Errorf("got unsupported type of key_spec.params: %T", req.KeySpec.Params) + } + if treeSigAlgo := req.Tree.GetSignatureAlgorithm(); treeSigAlgo != keySigAlgo { + return nil, fmt.Errorf("got tree.SignatureAlgorithm = %v but key_spec.Params of type %T", treeSigAlgo, req.KeySpec.Params) + } + + resp.PrivateKey = s.GeneratedKey + } + return &resp, nil +} diff --git a/util/flagsaver/flagsaver.go b/util/flagsaver/flagsaver.go new file mode 100644 index 0000000000..6cb869cbe1 --- /dev/null +++ b/util/flagsaver/flagsaver.go @@ -0,0 +1,50 @@ +// Copyright 2017 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package flagsaver provides a simple way to save and restore flag values. +// TODO(RJPercival): Move this to its own GitHub project. +// +// Example: +// func TestFoo(t *testing.T) { +// defer flagsaver.Save().Restore() +// // Test code that changes flags +// } // flags are reset to their original values here. +package flagsaver + +import "flag" + +// Stash holds flag values so that they can be restored at the end of a test. +type Stash struct { + flags map[string]string +} + +// Restore sets all non-hidden flags to the values they had when the Stash was created. +func (s *Stash) Restore() { + for name, value := range s.flags { + flag.Set(name, value) + } +} + +// Save returns a Stash that captures the current value of all non-hidden flags. +func Save() *Stash { + s := Stash{ + flags: make(map[string]string, flag.NFlag()), + } + + flag.VisitAll(func(f *flag.Flag) { + s.flags[f.Name] = f.Value.String() + }) + + return &s +} diff --git a/util/flagsaver/flagsaver_test.go b/util/flagsaver/flagsaver_test.go new file mode 100644 index 0000000000..392f5eb62f --- /dev/null +++ b/util/flagsaver/flagsaver_test.go @@ -0,0 +1,106 @@ +// Copyright 2017 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package flagsaver + +import ( + "flag" + "testing" + "time" +) + +var ( + intFlag = flag.Int("int_flag", 123, "test integer flag") + strFlag = flag.String("str_flag", "foo", "test string flag") + durationFlag = flag.Duration("duration_flag", 5*time.Second, "test duration flag") +) + +// TestRestore checks that flags are saved and restore correctly. +// Checks are performed on flags with both their default values and with explicit values set. +// Only a subset of the possible flag types are currently tested. +func TestRestore(t *testing.T) { + tests := []struct { + desc string + // flag is the name of the flag to save and restore. + flag string + // oldValue is the value the flag should have when saved. If empty, this indicates the flag should have its default value. + oldValue string + // newValue ist he value the flag should have just before being restored to oldValue. + newValue string + }{ + { + desc: "RestoreDefaultIntValue", + flag: "int_flag", + newValue: "666", + }, + { + desc: "RestoreDefaultStrValue", + flag: "str_flag", + newValue: "baz", + }, + { + desc: "RestoreDefaultDurationValue", + flag: "duration_flag", + newValue: "1m0s", + }, + { + desc: "RestoreSetIntValue", + flag: "int_flag", + oldValue: "555", + newValue: "666", + }, + { + desc: "RestoreSetStrValue", + flag: "str_flag", + oldValue: "bar", + newValue: "baz", + }, + { + desc: "RestoreSetDurationValue", + flag: "duration_flag", + oldValue: "10s", + newValue: "1m0s", + }, + } + + for _, test := range tests { + f := flag.Lookup(test.flag) + if f == nil { + t.Errorf("%v: flag.Lookup(%q) = nil, want not nil", test.desc, test.flag) + continue + } + + if test.oldValue != "" { + if err := flag.Set(test.flag, test.oldValue); err != nil { + t.Errorf("%v: flag.Set(%q, %q) = %q, want nil", test.desc, test.flag, test.oldValue, err) + continue + } + } else { + // Use the default value of the flag as the oldValue if none was set. + test.oldValue = f.DefValue + } + + func() { + defer Save().Restore() + flag.Set(test.flag, test.newValue) + if gotValue := f.Value.String(); gotValue != test.newValue { + t.Errorf("%v: flag.Lookup(%q).Value.String() = %q, want %q", test.desc, test.flag, gotValue, test.newValue) + } + }() + + if gotValue := f.Value.String(); gotValue != test.oldValue { + t.Errorf("%v: flag.Lookup(%q).Value.String() = %q, want %q", test.desc, test.flag, gotValue, test.oldValue) + } + } +}