diff --git a/cmd/start.go b/cmd/start.go index d614a70e8..ba815c62d 100644 --- a/cmd/start.go +++ b/cmd/start.go @@ -16,6 +16,7 @@ const ( bearerTokenFlagName = "bearer-token" corsFlagName = "cors-origin" evaluatorFlagName = "evaluator" + grpcCertPathFlagName = "grpc-sync-cert-path" logFormatFlagName = "log-format" metricsPortFlagName = "metrics-port" portFlagName = "port" @@ -57,10 +58,13 @@ func init() { syncProviderFlagName, "y", "", "DEPRECATED: Set a sync provider e.g. filepath or remote", ) flags.StringP(logFormatFlagName, "z", "console", "Set the logging format, e.g. console or json ") + flags.StringP(grpcCertPathFlagName, "g", "", "Path to root certificate to be used by TLS enabled grpc"+ + " sync (grpcs://). If TLS is used and this configuration is ignored, TLS uses the host's root CA set.") _ = viper.BindPFlag(bearerTokenFlagName, flags.Lookup(bearerTokenFlagName)) _ = viper.BindPFlag(corsFlagName, flags.Lookup(corsFlagName)) _ = viper.BindPFlag(evaluatorFlagName, flags.Lookup(evaluatorFlagName)) + _ = viper.BindPFlag(grpcCertPathFlagName, flags.Lookup(grpcCertPathFlagName)) _ = viper.BindPFlag(logFormatFlagName, flags.Lookup(logFormatFlagName)) _ = viper.BindPFlag(metricsPortFlagName, flags.Lookup(metricsPortFlagName)) _ = viper.BindPFlag(portFlagName, flags.Lookup(portFlagName)) @@ -105,6 +109,7 @@ var startCmd = &cobra.Command{ // Build Runtime ----------------------------------------------------------- rt, err := runtime.FromConfig(logger, runtime.Config{ CORS: viper.GetStringSlice(corsFlagName), + GrpcCertPath: viper.GetString(grpcCertPathFlagName), MetricsPort: viper.GetInt32(metricsPortFlagName), ProviderArgs: viper.GetStringMapString(providerArgsFlagName), ServiceCertPath: viper.GetString(serverCertPathFlagName), diff --git a/docs/configuration/configuration.md b/docs/configuration/configuration.md index 2a209bdef..a580d11aa 100644 --- a/docs/configuration/configuration.md +++ b/docs/configuration/configuration.md @@ -14,12 +14,12 @@ Config file expects the keys to have the exact naming as the flags. Any URI passed to flagd via the `--uri` flag must follow one of the 4 following patterns to ensure that it is passed to the correct implementation: -| Sync | Pattern | Example | -|------------|------------------------------------|---------------------------------------| +| Sync | Pattern | Example | +|------------|---------------------------------------|---------------------------------------| | Kubernetes | `core.openfeature.dev/namespace/name` | `core.openfeature.dev/default/my-crd` | -| Filepath | `file:path/to/my/flag` | `file:etc/flagd/my-flags.json` | -| Remote | `http(s)://flag-source-url` | `https://my-flags.com/flags` | -| Grpc | `grpc://flag-source-url` | `grpc://my-flags-server` | +| Filepath | `file:path/to/my/flag` | `file:etc/flagd/my-flags.json` | +| Remote | `http(s)://flag-source-url` | `https://my-flags.com/flags` | +| Grpc | `grpc(s)://flag-source-url` | `grpc://my-flags-server` | diff --git a/docs/configuration/flagd_start.md b/docs/configuration/flagd_start.md index 0cc8945f3..ea1ddd109 100644 --- a/docs/configuration/flagd_start.md +++ b/docs/configuration/flagd_start.md @@ -12,6 +12,7 @@ flagd start [flags] -b, --bearer-token string Set a bearer token to use for remote sync -C, --cors-origin strings CORS allowed origins, * will allow all origins -e, --evaluator string DEPRECATED: Set an evaluator e.g. json, yaml/yml.Please note that yaml/yml and json evaluations work the same (yaml/yml files are converted to json internally) (default "json") + -g, --grpc-sync-cert-path string Path to root certificate to be used by TLS enabled grpc sync (grpcs://). If TLS is used and this configuration is ignored, TLS uses the host's root CA set. -h, --help help for start -z, --log-format string Set the logging format, e.g. console or json (default "console") -m, --metrics-port int32 Port to serve metrics on (default 8014) diff --git a/pkg/runtime/from_config.go b/pkg/runtime/from_config.go index 4ff8c4e12..9ef00e590 100644 --- a/pkg/runtime/from_config.go +++ b/pkg/runtime/from_config.go @@ -19,16 +19,18 @@ import ( ) var ( - regCrd *regexp.Regexp - regURL *regexp.Regexp - regGRPC *regexp.Regexp - regFile *regexp.Regexp + regCrd *regexp.Regexp + regGRPC *regexp.Regexp + regGRPCS *regexp.Regexp + regFile *regexp.Regexp + regURL *regexp.Regexp ) func init() { regCrd = regexp.MustCompile("^core.openfeature.dev/") regURL = regexp.MustCompile("^https?://") regGRPC = regexp.MustCompile("^" + grpc.Prefix) + regGRPCS = regexp.MustCompile("^" + grpc.PrefixSecure) regFile = regexp.MustCompile("^file:") } @@ -101,17 +103,18 @@ func (r *Runtime) setSyncImplFromConfig(logger *logger.Logger) error { Cron: cron.New(), }) rtLogger.Debug(fmt.Sprintf("using remote sync-provider for: %q", uri)) - case regGRPC.Match(uriB): + case regGRPC.Match(uriB), regGRPCS.Match(uriB): r.SyncImpl = append(r.SyncImpl, &grpc.Sync{ - Target: grpc.URLToGRPCTarget(uri), + CertPath: r.config.GrpcCertPath, + Source: uri, Logger: logger.WithFields( zap.String("component", "sync"), zap.String("sync", "grpc"), ), }) default: - return fmt.Errorf("invalid sync uri argument: %s, must start with 'file:', 'http(s)://', 'grpc://',"+ - " or 'core.openfeature.dev'", uri) + return fmt.Errorf("invalid sync uri argument: %s, must start with 'file:', 'http(s)://', 'grpc(s)://'"+ + ", or 'core.openfeature.dev'", uri) } } return nil diff --git a/pkg/runtime/runtime.go b/pkg/runtime/runtime.go index 0b9ffefbb..63f0b3fb1 100644 --- a/pkg/runtime/runtime.go +++ b/pkg/runtime/runtime.go @@ -38,6 +38,8 @@ type Config struct { RemoteSyncType string SyncBearerToken string + GrpcCertPath string + CORS []string } diff --git a/pkg/sync/grpc/grpc_sync.go b/pkg/sync/grpc/grpc_sync.go index 20d96f485..24c8c82f9 100644 --- a/pkg/sync/grpc/grpc_sync.go +++ b/pkg/sync/grpc/grpc_sync.go @@ -2,11 +2,16 @@ package grpc import ( "context" + "crypto/tls" + "crypto/x509" "fmt" "math" + "os" "strings" "time" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/insecure" "buf.build/gen/go/open-feature/flagd/grpc/go/sync/v1/syncv1grpc" @@ -18,9 +23,10 @@ import ( ) const ( - // Prefix for GRPC URL inputs. GRPC does not define a prefix through standard. This prefix helps to differentiate - // remote URLs for REST APIs (i.e - HTTP) from GRPC endpoints. - Prefix = "grpc://" + // Prefix for GRPC URL inputs. GRPC does not define a standard prefix. This prefix helps to differentiate remote + // URLs for REST APIs (i.e - HTTP) from GRPC endpoints. + Prefix = "grpc://" + PrefixSecure = "grpcs://" // Connection retry constants // Back off period is calculated with backOffBase ^ #retry-iteration. However, when #retry-iteration count reach @@ -28,40 +34,66 @@ const ( backOffLimit = 3 backOffBase = 4 constantBackOffDelay = 60 + + tlsVersion = tls.VersionTLS12 ) type Sync struct { - Target string - ProviderID string + CertPath string Logger *logger.Logger + ProviderID string + Source string + + // rpcCon is a reusable grpc client connection. Lazy initialization by waiting for runtime to call Sync + rpcCon *grpc.ClientConn } +// Sync initialize internals and start internal sync implementation func (g *Sync) Sync(ctx context.Context, dataSync chan<- sync.DataSync) error { - options := []grpc.DialOption{ - grpc.WithTransportCredentials(insecure.NewCredentials()), + tCredentials, err := buildTransportCredentials(g.Source, g.CertPath) + if err != nil { + g.Logger.Error(fmt.Sprintf("error building transport credentials: %s", err.Error())) + return err + } + + target, ok := sourceToGRPCTarget(g.Source) + if !ok { + return fmt.Errorf("invalid grpc source: %s", g.Source) } - // initial dial and connection. Failure here must result in a startup failure - dial, err := grpc.DialContext(ctx, g.Target, options...) + // Derive reusable client connection + g.rpcCon, err = grpc.DialContext(ctx, target, grpc.WithTransportCredentials(tCredentials)) if err != nil { - g.Logger.Error(fmt.Sprintf("error establishing grpc connection: %s", err.Error())) + g.Logger.Error(fmt.Sprintf("error initiating grpc client connection: %s", err.Error())) return err } - serviceClient := syncv1grpc.NewFlagSyncServiceClient(dial) + // Cleanup when exiting the sync + defer func(rpcCon *grpc.ClientConn) { + err := rpcCon.Close() + if err != nil { + g.Logger.Warn(fmt.Sprintf("error while closing the client connection: %s", err.Error())) + } + }(g.rpcCon) + return g.syncInternal(ctx, dataSync) +} + +// syncInternal connects to grpc stream and push updates through sync channel. It attempts to reconnect if connection +// fails. However, initial connection must be successful. This makes sure provided configurations are valid. +func (g *Sync) syncInternal(ctx context.Context, dataSync chan<- sync.DataSync) error { + serviceClient := syncv1grpc.NewFlagSyncServiceClient(g.rpcCon) syncClient, err := serviceClient.SyncFlags(ctx, &v1.SyncFlagsRequest{ProviderId: g.ProviderID}) if err != nil { - g.Logger.Error(fmt.Sprintf("error calling streaming operation: %s", err.Error())) + g.Logger.Error(fmt.Sprintf("error initializing the client: %s", err.Error())) return err } - // initial stream listening err = g.handleFlagSync(syncClient, dataSync) g.Logger.Warn(fmt.Sprintf("error with stream listener: %s", err.Error())) // retry connection establishment for { - syncClient, ok := g.connectWithRetry(ctx, options...) + syncClient, ok := g.connectWithRetry(ctx) if !ok { // We shall exit return nil @@ -79,9 +111,7 @@ func (g *Sync) Sync(ctx context.Context, dataSync chan<- sync.DataSync) error { // a successful connection is established. Caller must not expect an error. Hence, errors are handled, logged // internally. However, if the provided context is done, method exit with a non-ok state which must be verified by the // caller -func (g *Sync) connectWithRetry( - ctx context.Context, options ...grpc.DialOption, -) (syncv1grpc.FlagSyncService_SyncFlagsClient, bool) { +func (g *Sync) connectWithRetry(ctx context.Context) (syncv1grpc.FlagSyncService_SyncFlagsClient, bool) { var iteration int for { @@ -102,22 +132,16 @@ func (g *Sync) connectWithRetry( return nil, false } - g.Logger.Warn(fmt.Sprintf("connection re-establishment attempt in-progress for grpc target: %s", g.Target)) - - dial, err := grpc.DialContext(ctx, g.Target, options...) - if err != nil { - g.Logger.Debug(fmt.Sprintf("error dialing target: %s", err.Error())) - continue - } + g.Logger.Warn(fmt.Sprintf("connection re-establishment attempt in-progress for grpc source: %s", g.Source)) - serviceClient := syncv1grpc.NewFlagSyncServiceClient(dial) + serviceClient := syncv1grpc.NewFlagSyncServiceClient(g.rpcCon) syncClient, err := serviceClient.SyncFlags(ctx, &v1.SyncFlagsRequest{ProviderId: g.ProviderID}) if err != nil { g.Logger.Debug(fmt.Sprintf("error opening service client: %s", err.Error())) continue } - g.Logger.Info(fmt.Sprintf("connection re-established with grpc target: %s", g.Target)) + g.Logger.Info(fmt.Sprintf("connection re-established with grpc source: %s", g.Source)) return syncClient, true } } @@ -134,7 +158,7 @@ func (g *Sync) handleFlagSync(stream syncv1grpc.FlagSyncService_SyncFlagsClient, case v1.SyncState_SYNC_STATE_ALL: dataSync <- sync.DataSync{ FlagData: data.FlagConfiguration, - Source: g.Target, + Source: g.Source, Type: sync.ALL, } @@ -142,7 +166,7 @@ func (g *Sync) handleFlagSync(stream syncv1grpc.FlagSyncService_SyncFlagsClient, case v1.SyncState_SYNC_STATE_ADD: dataSync <- sync.DataSync{ FlagData: data.FlagConfiguration, - Source: g.Target, + Source: g.Source, Type: sync.ADD, } @@ -150,7 +174,7 @@ func (g *Sync) handleFlagSync(stream syncv1grpc.FlagSyncService_SyncFlagsClient, case v1.SyncState_SYNC_STATE_UPDATE: dataSync <- sync.DataSync{ FlagData: data.FlagConfiguration, - Source: g.Target, + Source: g.Source, Type: sync.UPDATE, } @@ -158,7 +182,7 @@ func (g *Sync) handleFlagSync(stream syncv1grpc.FlagSyncService_SyncFlagsClient, case v1.SyncState_SYNC_STATE_DELETE: dataSync <- sync.DataSync{ FlagData: data.FlagConfiguration, - Source: g.Target, + Source: g.Source, Type: sync.DELETE, } @@ -171,14 +195,57 @@ func (g *Sync) handleFlagSync(stream syncv1grpc.FlagSyncService_SyncFlagsClient, } } -// URLToGRPCTarget is a helper to derive GRPC target from a provided URL +// buildTransportCredentials is a helper to build grpc credentials.TransportCredentials based on source and cert path +func buildTransportCredentials(source string, certPath string) (credentials.TransportCredentials, error) { + if strings.Contains(source, Prefix) { + return insecure.NewCredentials(), nil + } + + if !strings.Contains(source, PrefixSecure) { + return nil, fmt.Errorf("invalid source. grpc source must contain prefix %s or %s", Prefix, PrefixSecure) + } + + if certPath == "" { + // Rely on CA certs provided from system + return credentials.NewTLS(&tls.Config{MinVersion: tlsVersion}), nil + } + + // Rely on provided certificate + certBytes, err := os.ReadFile(certPath) + if err != nil { + return nil, err + } + + cp := x509.NewCertPool() + if !cp.AppendCertsFromPEM(certBytes) { + return nil, fmt.Errorf("invalid certificate provided at path: %s", certPath) + } + + return credentials.NewTLS(&tls.Config{ + MinVersion: tlsVersion, + RootCAs: cp, + }), nil +} + +// sourceToGRPCTarget is a helper to derive GRPC target from a provided URL // For example, function returns the target localhost:9090 for the input grpc://localhost:9090 -func URLToGRPCTarget(url string) string { - index := strings.Split(url, Prefix) +func sourceToGRPCTarget(url string) (string, bool) { + var separator string + + switch { + case strings.Contains(url, Prefix): + separator = Prefix + case strings.Contains(url, PrefixSecure): + separator = PrefixSecure + default: + return "", false + } + + index := strings.Split(url, separator) - if len(index) == 2 { - return index[1] + if len(index) == 2 && len(index[1]) != 0 { + return index[1], true } - return index[0] + return "", false } diff --git a/pkg/sync/grpc/grpc_sync_test.go b/pkg/sync/grpc/grpc_sync_test.go index d521372c8..44899740a 100644 --- a/pkg/sync/grpc/grpc_sync_test.go +++ b/pkg/sync/grpc/grpc_sync_test.go @@ -6,7 +6,13 @@ import ( "io" "log" "net" + "os" "testing" + "time" + + "golang.org/x/sync/errgroup" + + "google.golang.org/grpc/credentials/insecure" "buf.build/gen/go/open-feature/flagd/grpc/go/sync/v1/syncv1grpc" v1 "buf.build/gen/go/open-feature/flagd/protocolbuffers/go/sync/v1" @@ -14,36 +20,91 @@ import ( "github.com/open-feature/flagd/pkg/logger" "github.com/open-feature/flagd/pkg/sync" "google.golang.org/grpc" - "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/test/bufconn" ) -func TestUrlToGRPCTarget(t *testing.T) { +const sampleCert = `-----BEGIN CERTIFICATE----- +MIIEnDCCAoQCCQCHcl3hGXwRQzANBgkqhkiG9w0BAQsFADAQMQ4wDAYDVQQDDAVm +bGFnZDAeFw0yMzAyMTAxODM1NDVaFw0zMzAyMDcxODM1NDVaMBAxDjAMBgNVBAMM +BWZsYWdkMIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEAwDLEAUti/kG9 +MhJLtO7oAy7diHxWKDFmsIHrE+z2IzTxjXxVHQLv1HiYB/UN75y7qlb3MwvzSc+C +BoLuoiM0PDiMio9/o9X5j0U+v3H1JpUU5LardkvsprFqJWmHF+D7aRdM0LBLn2X6 +HQOhSnPyH9Qjl2l2tyPiPTZ6g0i2+rXZsNUoTs4fm6ThhZ0LeXR8KDmCTun3ze1d +hXA7ydxwILH2OVc+Wnzl30+BRvOiLQbc9nYnwSREFeIy8sFbhrTHqSNn3eY79ssZ +T6f4tN3jEV1d7NqoFk9KFLJKJhMt7smMB9NLwVWi581Zj1krYirNlP6mtmPrn3kJ +lsgT15kFftShMVcYFSHqOSLiy4SspHGK8KJaFoEVx0wp/weRwrWXi6vWg7tuHATH +fw7gW/9CyV+ylc0pJ002wtPAgzJYUaOrna0R2r3yQsSzRcDnqsm4FLkPHLoyjrwQ +vshKcEqjhGml1M+lTDEo3RO5ZoQ3ZN2AZKPDrK2zGG4wFJjHRu9FtutOEZkYYOzA +emTQWW8US3q8WVQqGl/EwQqzXk9Lco7uhLdXmqVOvAi6z01gehQJPnjhH7iqAPVp +1tlOBHit1F3sTAQIO/2zff3LCKiD2d27KINh4aFEyDbDmglPA8VPO3BMQVSjFlxj +K1s2G1IDBixXK76VmBP+ZpvxOaQtYIUCAwEAATANBgkqhkiG9w0BAQsFAAOCAgEA +K9+wnl5gpkfNBa+OSxlhOn3CKhcaW/SWZ4aLw2yK1NZNnNjpwUcLQScUDBKDoJJR +5roc3PIImX7hdnobZWqFhD23laaAlu5XLk9P7n51uMEiNjQQc2WaaBZDTRJfki1C +MvPskXqptgPsVyuPJc0DxfaCz7pDYjq/CtJ+osaj404P5mlO1QJ8W91QSx+aq2x4 +uUTUWuyr/8flIcxiX0o8VTb2LcUvWZBMGa3CdeLnPHrOjovfjJFy0Ysk3SGEACLL +9mpbNbv23v9UXVfyFffHpyzvyUJIOsNXG0O1AYf5t9bukqHolGR/RQUN4yGd3M62 +mFR5bOST36DjNSzTrx1eyCLv22+h9VVlWFPrebFnq1W5SSi8PtsGSMjhvX7dB1kS +t0yJtlj2HwBAvI1zVKG76q6neSU51UXFQUbO0OA0sxjicEOlNfXnShM/kY2lobpX +hrCysWpqoSS0S3UBvmuRiraLWkP1KueC0XHoAi8yuwMAdM6Y+h2OJpnO0PdpUmrp +lAqdxbyICnB1Nsm5QGGm6Pxd8lEbQ9ZSwFjgqApjT2zVhuaaUC7jdlEP1H5snt9n +8FQR06lrzGyW04ud9pd6MXJup1oghAlvnzXioAH2Az0IXcHvqUGZQattFv27OXqj +QZ6ayNO119SNscvC6Qe9GLlbBEHDQWKPiftnS2Mh6Do= +-----END CERTIFICATE-----` + +func TestSourceToGRPCTarget(t *testing.T) { tests := []struct { name string url string want string + ok bool }{ { name: "With Prefix", url: "grpc://test.com/endpoint", want: "test.com/endpoint", + ok: true, }, { - name: "Without Prefix", - url: "test.com/endpoint", + name: "With secure Prefix", + url: "grpcs://test.com/endpoint", want: "test.com/endpoint", + ok: true, }, { - name: "Empty is empty", + name: "Empty is error", url: "", want: "", + ok: false, + }, + { + name: "Invalid is error", + url: "https://test.com/endpoint", + want: "", + ok: false, + }, + { + name: "Prefix is not enough I", + url: Prefix, + want: "", + ok: false, + }, + { + name: "Prefix is not enough II", + url: PrefixSecure, + want: "", + ok: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := URLToGRPCTarget(tt.url); got != tt.want { - t.Errorf("URLToGRPCTarget() = %v, want %v", got, tt.want) + got, ok := sourceToGRPCTarget(tt.url) + + if tt.ok != ok { + t.Errorf("URLToGRPCTarget() returned = %v, want %v", ok, tt.ok) + } + + if got != tt.want { + t.Errorf("URLToGRPCTarget() returned = %v, want %v", got, tt.want) } }) } @@ -51,7 +112,7 @@ func TestUrlToGRPCTarget(t *testing.T) { func TestSync_BasicFlagSyncStates(t *testing.T) { grpcSyncImpl := Sync{ - Target: "grpc://test", + Source: "grpc://test", ProviderID: "", Logger: logger.NewLogger(nil, false), } @@ -123,7 +184,7 @@ func TestSync_BasicFlagSyncStates(t *testing.T) { } func Test_StreamListener(t *testing.T) { - const target = "localBufCon" + const target = "grpc://test" tests := []struct { name string @@ -228,7 +289,7 @@ func Test_StreamListener(t *testing.T) { go serve(&bufServer) grpcSync := Sync{ - Target: target, + Source: target, ProviderID: "", Logger: logger.NewLogger(nil, false), } @@ -281,6 +342,251 @@ func Test_StreamListener(t *testing.T) { } } +func Test_BuildTCredentials(t *testing.T) { + // "insecure" is a hardcoded term at insecure.NewCredentials + const insecure = "insecure" + // "tls" is a hardcoded term at tlsCreds.Info + const tls = "tls" + // local test file with valid certificate + const validCertFile = "valid.cert" + // local test file with invalid certificate + const invalidCertFile = "invalid.cert" + + // init cert files for tests & cleanup with a deffer + err := os.WriteFile(validCertFile, []byte(sampleCert), 0o600) + if err != nil { + t.Errorf("error creating valid certificate file: %s", err) + } + + err = os.WriteFile(invalidCertFile, []byte("--certificate--"), 0o600) + if err != nil { + t.Errorf("error creating invalid certificate file: %s", err) + } + + defer func() { + errV := os.Remove(validCertFile) + errI := os.Remove(invalidCertFile) + if errV != nil || errI != nil { + t.Errorf("error removing cerificate files: %v, %v", errV, errI) + } + }() + + tests := []struct { + name string + source string + certPath string + expectSecProto string + error bool + }{ + { + name: "Insecure source results in insecure connection", + source: Prefix + "some.domain", + certPath: "", + expectSecProto: insecure, + }, + { + name: "Secure source results in secure connection", + source: PrefixSecure + "some.domain", + certPath: validCertFile, + expectSecProto: tls, + }, + { + name: "Secure source with no certificate results in secure connection", + source: PrefixSecure + "some.domain", + expectSecProto: tls, + }, + { + name: "Invalid cert path results in error", + source: PrefixSecure + "some.domain", + certPath: "invalid/path", + error: true, + }, + { + name: "Invalid certificate results in error", + source: PrefixSecure + "some.domain", + certPath: invalidCertFile, + error: true, + }, + { + name: "Invalid prefix results in error", + source: "http://some.domain", + error: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + tCred, err := buildTransportCredentials(test.source, test.certPath) + + if test.error { + if err == nil { + t.Errorf("test expected non error execution. But resulted in an error: %s", err.Error()) + } + + // Test expected an error. Nothing to validate further + return + } + + // check for errors to be certain + if err != nil { + t.Errorf("unexpected error: %s", err.Error()) + } + + protoc := tCred.Info().SecurityProtocol + if protoc != test.expectSecProto { + t.Errorf("buildTransportCredentials() returned protocol= %v, want %v", protoc, test.expectSecProto) + } + }) + } +} + +func Test_SyncInternal(t *testing.T) { + responseState := sync.ALL + target := "local" + + bufListener := bufconn.Listen(1) + // buffer based server. response ignored purposefully + bServer := bufferedServer{listener: bufListener, mockResponses: []serverPayload{ + { + flags: "{}", + state: v1.SyncState_SYNC_STATE_ALL, + }, + }} + + // generate a client connection backed with bufconn + clientConn, err := grpc.Dial(target, + grpc.WithContextDialer(func(ctx context.Context, s string) (net.Conn, error) { + return bufListener.DialContext(ctx) + }), + grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + t.Errorf("error initiating the connection: %s", err.Error()) + } + + // minimal sync provider + grpcSync := Sync{ + Logger: logger.NewLogger(nil, false), + rpcCon: clientConn, + } + + // error group with a timeout to disconnect the server + tCtx, cancelFunc := context.WithTimeout(context.Background(), 2*time.Second) + defer cancelFunc() + + group, _ := errgroup.WithContext(tCtx) + + group.Go(func() error { + serve(&bServer) + return nil + }) + + syncChan := make(chan sync.DataSync) + + // Start the grpc sync + go func() { + err := grpcSync.syncInternal(context.Background(), syncChan) + if err != nil { + t.Errorf("Error: %s", err.Error()) + } + }() + + select { + case <-tCtx.Done(): + cancelFunc() + t.Error("Context timeout") + case rsp := <-syncChan: + if rsp.Type != responseState { + t.Errorf("expected response: %s, but got: %s", responseState, rsp.Type) + } + } + + // test must complete within an acceptable timeframe + tCtx, cancelFunc = context.WithTimeout(context.Background(), 5*time.Second) + defer cancelFunc() + + // Restart the server + go serve(&bServer) + + // validate connection re-establishment + select { + case <-tCtx.Done(): + cancelFunc() + t.Error("Context timeout") + case rsp := <-syncChan: + if rsp.Type != responseState { + t.Errorf("expected response: %s, but got: %s", responseState, rsp.Type) + } + } +} + +func Test_ConnectWithRetry(t *testing.T) { + target := "local" + bufListener := bufconn.Listen(1) + // buffer based server. response ignored purposefully + bServer := bufferedServer{listener: bufListener} + + // generate a client connection backed with bufconn + clientConn, err := grpc.Dial(target, + grpc.WithContextDialer(func(ctx context.Context, s string) (net.Conn, error) { + return bufListener.DialContext(ctx) + }), + grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + t.Errorf("error initiating the connection: %s", err.Error()) + } + + // minimal sync provider + grpcSync := Sync{ + Logger: logger.NewLogger(nil, false), + rpcCon: clientConn, + } + + // test must complete within an acceptable timeframe + tCtx, cancelFunc := context.WithTimeout(context.Background(), 10*time.Second) + defer cancelFunc() + + // channel for connection + clientChan := make(chan syncv1grpc.FlagSyncService_SyncFlagsClient) + + // start connection retry attempts + go func() { + client, ok := grpcSync.connectWithRetry(tCtx) + if !ok { + clientChan <- nil + } + + clientChan <- client + }() + + // Wait for retries in the background + select { + case <-time.After(1 * time.Second): + break + case <-tCtx.Done(): + // We should not reach this with correct test setup, but in case we do + cancelFunc() + t.Errorf("timeout occurred while waiting for conditions to fulfil") + } + + // start the server - fulfill connection after the wait + go serve(&bServer) + + // Wait for client connection + var con syncv1grpc.FlagSyncService_SyncFlagsClient + + select { + case con = <-clientChan: + break + case <-tCtx.Done(): + cancelFunc() + t.Errorf("timeout occurred while waiting for conditions to fulfil") + } + + if con == nil { + t.Errorf("received a nil value when expecting a non-nil return") + } +} + // Mock implementations type SimpleRecvMock struct { @@ -292,7 +598,7 @@ func (s *SimpleRecvMock) Recv() (*v1.SyncFlagsResponse, error) { return &s.mockResponse, nil } -// serve serves a bufferedServer +// serve serves a bufferedServer. This is a blocking call func serve(bServer *bufferedServer) { server := grpc.NewServer()