From 51be4e05bba628717aea4b28208d8eed931dec2f Mon Sep 17 00:00:00 2001 From: Kavindu Dodanduwa Date: Thu, 16 Feb 2023 15:29:26 -0800 Subject: [PATCH 1/2] initialize server poc Signed-off-by: Kavindu Dodanduwa --- cmd/{start.go => provider.go} | 111 +++++++------ cmd/root.go | 3 +- cmd/server.go | 79 +++++++++ .../{runtime.go => providerRuntime.go} | 0 pkg/runtime/serverRuntime.go | 107 ++++++++++++ pkg/server/server.go | 156 ++++++++++++++++++ 6 files changed, 403 insertions(+), 53 deletions(-) rename cmd/{start.go => provider.go} (62%) create mode 100644 cmd/server.go rename pkg/runtime/{runtime.go => providerRuntime.go} (100%) create mode 100644 pkg/runtime/serverRuntime.go create mode 100644 pkg/server/server.go diff --git a/cmd/start.go b/cmd/provider.go similarity index 62% rename from cmd/start.go rename to cmd/provider.go index d614a70e8..a13fb804b 100644 --- a/cmd/start.go +++ b/cmd/provider.go @@ -1,15 +1,15 @@ package cmd import ( - "log" - "strings" - "github.com/open-feature/flagd/pkg/logger" "github.com/open-feature/flagd/pkg/runtime" "github.com/spf13/cobra" + "github.com/spf13/pflag" "github.com/spf13/viper" "go.uber.org/zap" "go.uber.org/zap/zapcore" + "log" + "strings" ) const ( @@ -27,9 +27,21 @@ const ( uriFlagName = "uri" ) -func init() { - flags := startCmd.Flags() +// NewProviderCmd is the command to start flagd as a provider +func NewProviderCmd() *cobra.Command { + flagd := &cobra.Command{ + Use: "start", + Short: "Start flagd", + Long: ``, + Run: runProvider, + } + + setupProvider(flagd.Flags()) + return flagd +} +// setupProvider setup flags of the command +func setupProvider(flags *pflag.FlagSet) { // allows environment variables to use _ instead of - viper.SetEnvKeyReplacer(strings.NewReplacer("-", "_")) // sync-provider-args becomes SYNC_PROVIDER_ARGS viper.SetEnvPrefix("FLAGD") // port becomes FLAGD_PORT @@ -72,54 +84,49 @@ func init() { _ = viper.BindPFlag(uriFlagName, flags.Lookup(uriFlagName)) } -// startCmd represents the start command -var startCmd = &cobra.Command{ - Use: "start", - Short: "Start flagd", - Long: ``, - Run: func(cmd *cobra.Command, args []string) { - // Configure loggers ------------------------------------------------------- - var level zapcore.Level - var err error - if Debug { - level = zapcore.DebugLevel - } else { - level = zapcore.InfoLevel - } - l, err := logger.NewZapLogger(level, viper.GetString(logFormatFlagName)) - if err != nil { - log.Fatalf("can't initialize zap logger: %v", err) - } - logger := logger.NewLogger(l, Debug) - rtLogger := logger.WithFields(zap.String("component", "start")) +// runProvider starts the provider implementation +func runProvider(cmd *cobra.Command, args []string) { + // Configure loggers ------------------------------------------------------- + var level zapcore.Level + var err error + if Debug { + level = zapcore.DebugLevel + } else { + level = zapcore.InfoLevel + } + l, err := logger.NewZapLogger(level, viper.GetString(logFormatFlagName)) + if err != nil { + log.Fatalf("can't initialize zap logger: %v", err) + } + logger := logger.NewLogger(l, Debug) + rtLogger := logger.WithFields(zap.String("component", "start")) - if viper.GetString(syncProviderFlagName) != "" { - rtLogger.Warn("DEPRECATED: The --sync-provider flag has been deprecated. " + - "Docs: https://github.com/open-feature/flagd/blob/main/docs/configuration/configuration.md") - } + if viper.GetString(syncProviderFlagName) != "" { + rtLogger.Warn("DEPRECATED: The --sync-provider flag has been deprecated. " + + "Docs: https://github.com/open-feature/flagd/blob/main/docs/configuration/configuration.md") + } - if viper.GetString(evaluatorFlagName) != "json" { - rtLogger.Warn("DEPRECATED: The --evaluator flag has been deprecated. " + - "Docs: https://github.com/open-feature/flagd/blob/main/docs/configuration/configuration.md") - } - // Build Runtime ----------------------------------------------------------- - rt, err := runtime.FromConfig(logger, runtime.Config{ - CORS: viper.GetStringSlice(corsFlagName), - MetricsPort: viper.GetInt32(metricsPortFlagName), - ProviderArgs: viper.GetStringMapString(providerArgsFlagName), - ServiceCertPath: viper.GetString(serverCertPathFlagName), - ServiceKeyPath: viper.GetString(serverKeyPathFlagName), - ServicePort: viper.GetInt32(portFlagName), - ServiceSocketPath: viper.GetString(socketPathFlagName), - SyncBearerToken: viper.GetString(bearerTokenFlagName), - SyncURI: viper.GetStringSlice(uriFlagName), - }) - if err != nil { - rtLogger.Fatal(err.Error()) - } + if viper.GetString(evaluatorFlagName) != "json" { + rtLogger.Warn("DEPRECATED: The --evaluator flag has been deprecated. " + + "Docs: https://github.com/open-feature/flagd/blob/main/docs/configuration/configuration.md") + } + // Build Runtime ----------------------------------------------------------- + rt, err := runtime.FromConfig(logger, runtime.Config{ + CORS: viper.GetStringSlice(corsFlagName), + MetricsPort: viper.GetInt32(metricsPortFlagName), + ProviderArgs: viper.GetStringMapString(providerArgsFlagName), + ServiceCertPath: viper.GetString(serverCertPathFlagName), + ServiceKeyPath: viper.GetString(serverKeyPathFlagName), + ServicePort: viper.GetInt32(portFlagName), + ServiceSocketPath: viper.GetString(socketPathFlagName), + SyncBearerToken: viper.GetString(bearerTokenFlagName), + SyncURI: viper.GetStringSlice(uriFlagName), + }) + if err != nil { + rtLogger.Fatal(err.Error()) + } - if err := rt.Start(); err != nil { - rtLogger.Fatal(err.Error()) - } - }, + if err := rt.Start(); err != nil { + rtLogger.Fatal(err.Error()) + } } diff --git a/cmd/root.go b/cmd/root.go index c70638f9b..dd15746d8 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -48,7 +48,8 @@ func init() { // will be global for your application. rootCmd.PersistentFlags().BoolVarP(&Debug, "debug", "x", false, "verbose logging") rootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", "config file (default is $HOME/.agent.yaml)") - rootCmd.AddCommand(startCmd) + rootCmd.AddCommand(NewProviderCmd()) + rootCmd.AddCommand(NewServerCmd()) rootCmd.AddCommand(versionCmd) } diff --git a/cmd/server.go b/cmd/server.go new file mode 100644 index 000000000..66ac7bdcb --- /dev/null +++ b/cmd/server.go @@ -0,0 +1,79 @@ +package cmd + +import ( + "fmt" + "github.com/open-feature/flagd/pkg/logger" + "github.com/open-feature/flagd/pkg/runtime" + "github.com/spf13/cobra" + "github.com/spf13/pflag" + "github.com/spf13/viper" + "go.uber.org/zap/zapcore" + "log" +) + +const ( + address = "address" + secure = "secure" + certPath = "cert-path" + keyPath = "key-path" + sources = "source" +) + +// NewServerCmd is the command to start flagd in server mode +func NewServerCmd() *cobra.Command { + flagdCmd := &cobra.Command{ + Use: "server", + Short: "Start flagd as a server", + Run: runServer, + } + + setupServer(flagdCmd.Flags()) + return flagdCmd +} + +// setupServer setup flags of the command +func setupServer(flags *pflag.FlagSet) { + flags.StringP(address, "p", "localhost:9090", "Path this server binds to") + flags.BoolP(secure, "s", false, "Start secure server") + flags.StringP(certPath, "c", "", "TLS certificate path") + flags.StringP(keyPath, "k", "", "TLS key path of the certificate") + flags.StringP(sources, "f", "", "CRD with feature flag configurations") + + _ = viper.BindPFlag(address, flags.Lookup(address)) + _ = viper.BindPFlag(secure, flags.Lookup(secure)) + _ = viper.BindPFlag(certPath, flags.Lookup(certPath)) + _ = viper.BindPFlag(keyPath, flags.Lookup(keyPath)) + _ = viper.BindPFlag(sources, flags.Lookup(sources)) +} + +func runServer(cmd *cobra.Command, args []string) { + zapLogger, err := logger.NewZapLogger(zapcore.DebugLevel, "console") + if err != nil { + log.Fatalf("error setting up the logger: %s", err) + } + + logWrapper := logger.NewLogger(zapLogger, true) + + err = viper.BindPFlags(pflag.CommandLine) + if err != nil { + logWrapper.Fatal(fmt.Sprintf("error parsing flags: %s", err.Error())) + } + + serverConfig := runtime.ServerConfig{ + Address: viper.GetString(address), + Secure: viper.GetBool(secure), + CertPath: viper.GetString(certPath), + KeyPath: viper.GetString(keyPath), + SyncSources: viper.GetStringSlice(sources), + } + + serverRuntime, err := runtime.NewServerRuntime(serverConfig, logWrapper) + if err != nil { + logWrapper.Fatal(fmt.Sprintf("error creating the server runtime: %s", err.Error())) + } + + err = serverRuntime.Start() + if err != nil { + logWrapper.Fatal(fmt.Sprintf("error from server runtime: %s", err.Error())) + } +} diff --git a/pkg/runtime/runtime.go b/pkg/runtime/providerRuntime.go similarity index 100% rename from pkg/runtime/runtime.go rename to pkg/runtime/providerRuntime.go diff --git a/pkg/runtime/serverRuntime.go b/pkg/runtime/serverRuntime.go new file mode 100644 index 000000000..f2053af65 --- /dev/null +++ b/pkg/runtime/serverRuntime.go @@ -0,0 +1,107 @@ +package runtime + +import ( + "context" + "errors" + "fmt" + "github.com/open-feature/flagd/pkg/server" + "github.com/open-feature/flagd/pkg/sync/kubernetes" + "go.uber.org/zap" + "os" + "os/signal" + "syscall" + + "golang.org/x/sync/errgroup" + + "github.com/open-feature/flagd/pkg/logger" + "github.com/open-feature/flagd/pkg/sync" +) + +type ServerConfig struct { + Address string + Secure bool + CertPath string + KeyPath string + SyncSources []string +} + +type ServerRuntime struct { + syncImpl []sync.ISync + logger *logger.Logger + config ServerConfig +} + +func NewServerRuntime(config ServerConfig, rootLogger *logger.Logger) (*ServerRuntime, error) { + impls, err := buildSyncImpls(config.SyncSources, rootLogger) + if err != nil { + return nil, err + } + + return &ServerRuntime{ + syncImpl: impls, + logger: rootLogger.WithFields(zap.String("component", "Server Runtime")), + config: config, + }, nil +} + +func (sr *ServerRuntime) Start() error { + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer cancel() + + // Build server + s := server.Server{ + Logger: sr.logger.WithFields(zap.String("component", "Server")), + Secure: sr.config.Secure, + CertPath: sr.config.CertPath, + KeyPath: sr.config.KeyPath, + Address: sr.config.Address, + } + + g, gCtx := errgroup.WithContext(ctx) + dataSync := make(chan sync.DataSync, len(sr.syncImpl)) + + // Start server + g.Go(func() error { + return s.Listen(gCtx, dataSync) + }) + + // Start sync providers + for _, s := range sr.syncImpl { + p := s + g.Go(func() error { + return p.Sync(gCtx, dataSync) + }) + } + + <-gCtx.Done() + if err := g.Wait(); err != nil { + return err + } + + return nil +} + +func buildSyncImpls(sources []string, rootLogger *logger.Logger) ([]sync.ISync, error) { + if len(sources) == 0 { + return nil, errors.New("no sync provider sources provided") + } + + var syncs []sync.ISync + for _, source := range sources { + switch sourceBytes := []byte(source); { + case regCrd.Match(sourceBytes): + syncs = append(syncs, &kubernetes.Sync{ + Logger: rootLogger.WithFields( + zap.String("component", "sync"), + zap.String("sync", "kubernetes"), + ), + URI: regCrd.ReplaceAllString(source, ""), + }) + rootLogger.Debug(fmt.Sprintf("using kubernetes sync-provider for: %s", source)) + default: + return nil, fmt.Errorf("server supports only crd sync providers. recieved : %s", source) + } + } + + return syncs, nil +} diff --git a/pkg/server/server.go b/pkg/server/server.go new file mode 100644 index 000000000..d1a8a76b2 --- /dev/null +++ b/pkg/server/server.go @@ -0,0 +1,156 @@ +package server + +import ( + "buf.build/gen/go/open-feature/flagd/grpc/go/sync/v1/syncv1grpc" + v1 "buf.build/gen/go/open-feature/flagd/protocolbuffers/go/sync/v1" + "context" + "crypto/tls" + "fmt" + "github.com/open-feature/flagd/pkg/logger" + "github.com/open-feature/flagd/pkg/sync" + "golang.org/x/sync/errgroup" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" + "net" +) + +type Server struct { + Logger *logger.Logger + + Secure bool + CertPath string + KeyPath string + Address string +} + +func (s *Server) Listen(ctx context.Context, sync <-chan sync.DataSync) error { + listen, err := net.Listen("tcp", s.Address) + if err != nil { + s.Logger.Error(fmt.Sprintf("Error when listening to address : %s\n", err.Error())) + return err + } + + options, err := s.buildOptions() + if err != nil { + s.Logger.Error(fmt.Sprintf("Error building dial options : %s\n", err.Error())) + return err + } + + server := grpc.NewServer(options...) + defer server.Stop() + + listener := NewListener() + + group, lcCtxt := errgroup.WithContext(ctx) + + group.Go(func() error { + for { + select { + case data := <-sync: + fmt.Printf("New data :%s", data) + listener.persist(data.FlagData) + case <-ctx.Done(): + return nil + } + } + }) + + syncv1grpc.RegisterFlagSyncServiceServer(server, &internal{ + Logger: s.Logger, + Ls: &listener, + }) + + group.Go(func() error { + err = server.Serve(listen) + if err != nil { + s.Logger.Error(fmt.Sprintf("Error when starting the server : %s\n", err.Error())) + return err + } + + return nil + }) + + <-lcCtxt.Done() + err = group.Wait() + if err != nil { + return err + } + + return nil +} + +func (s *Server) buildOptions() ([]grpc.ServerOption, error) { + var options []grpc.ServerOption + + if !s.Secure { + return options, nil + } + + keyPair, err := tls.LoadX509KeyPair(s.CertPath, s.KeyPath) + if err != nil { + return nil, err + } + + options = append(options, grpc.Creds(credentials.NewServerTLSFromCert(&keyPair))) + + return options, nil +} + +type internal struct { + Logger *logger.Logger + Ls *Listener +} + +func (i *internal) SyncFlags(req *v1.SyncFlagsRequest, stream syncv1grpc.FlagSyncService_SyncFlagsServer) error { + i.Logger.Info(fmt.Sprintf("Request with ID: %s", req.ProviderId)) + + // Initially send the current state + err := stream.Send(&v1.SyncFlagsResponse{ + FlagConfiguration: i.Ls.currentState(), + State: v1.SyncState_SYNC_STATE_ALL, + }) + if err != nil { + return err + } + + emit := i.Ls.getEmit() + + // Then wait for updates + for { + select { + case _ = <-emit: + stream.Send(&v1.SyncFlagsResponse{ + FlagConfiguration: i.Ls.currentState(), + State: v1.SyncState_SYNC_STATE_ALL, + }) + } + } + +} + +// todo we need a sync mechanism better than listener + +type Listener struct { + emit chan string + data string +} + +func NewListener() Listener { + return Listener{ + emit: make(chan string), + data: "", + } +} + +func (s *Listener) persist(input string) { + s.data = input + s.emit <- s.data +} + +func (s *Listener) getEmit() <-chan string { + return s.emit +} + +func (s *Listener) currentState() string { + return s.data +} From 62dd7bba1dcdf38b56f48a284703590b6dc4e21f Mon Sep 17 00:00:00 2001 From: Kavindu Dodanduwa Date: Tue, 21 Feb 2023 14:22:02 -0800 Subject: [PATCH 2/2] finalize POC Signed-off-by: Kavindu Dodanduwa --- cmd/provider.go | 5 +- cmd/server.go | 21 +++-- docs/configuration/flagd.md | 1 + docs/configuration/flagd_server.md | 30 ++++++++ pkg/runtime/serverRuntime.go | 70 ++++++++--------- pkg/server/server.go | 118 ++++++++++++++--------------- pkg/server/store.go | 66 ++++++++++++++++ 7 files changed, 203 insertions(+), 108 deletions(-) create mode 100644 docs/configuration/flagd_server.md create mode 100644 pkg/server/store.go diff --git a/cmd/provider.go b/cmd/provider.go index a13fb804b..c6ecab0de 100644 --- a/cmd/provider.go +++ b/cmd/provider.go @@ -1,6 +1,9 @@ package cmd import ( + "log" + "strings" + "github.com/open-feature/flagd/pkg/logger" "github.com/open-feature/flagd/pkg/runtime" "github.com/spf13/cobra" @@ -8,8 +11,6 @@ import ( "github.com/spf13/viper" "go.uber.org/zap" "go.uber.org/zap/zapcore" - "log" - "strings" ) const ( diff --git a/cmd/server.go b/cmd/server.go index 66ac7bdcb..35efac5b2 100644 --- a/cmd/server.go +++ b/cmd/server.go @@ -2,13 +2,14 @@ package cmd import ( "fmt" + "log" + "github.com/open-feature/flagd/pkg/logger" "github.com/open-feature/flagd/pkg/runtime" "github.com/spf13/cobra" "github.com/spf13/pflag" "github.com/spf13/viper" "go.uber.org/zap/zapcore" - "log" ) const ( @@ -16,7 +17,7 @@ const ( secure = "secure" certPath = "cert-path" keyPath = "key-path" - sources = "source" + source = "source" ) // NewServerCmd is the command to start flagd in server mode @@ -27,26 +28,32 @@ func NewServerCmd() *cobra.Command { Run: runServer, } - setupServer(flagdCmd.Flags()) + setupServer(flagdCmd) return flagdCmd } // setupServer setup flags of the command -func setupServer(flags *pflag.FlagSet) { +func setupServer(cmd *cobra.Command) { + flags := cmd.Flags() + flags.StringP(address, "p", "localhost:9090", "Path this server binds to") + flags.BoolP(secure, "s", false, "Start secure server") flags.StringP(certPath, "c", "", "TLS certificate path") flags.StringP(keyPath, "k", "", "TLS key path of the certificate") - flags.StringP(sources, "f", "", "CRD with feature flag configurations") + cmd.MarkFlagsRequiredTogether(secure, certPath, keyPath) + + flags.StringP(source, "f", "", "CRD with feature flag configurations") _ = viper.BindPFlag(address, flags.Lookup(address)) _ = viper.BindPFlag(secure, flags.Lookup(secure)) _ = viper.BindPFlag(certPath, flags.Lookup(certPath)) _ = viper.BindPFlag(keyPath, flags.Lookup(keyPath)) - _ = viper.BindPFlag(sources, flags.Lookup(sources)) + _ = viper.BindPFlag(source, flags.Lookup(source)) } func runServer(cmd *cobra.Command, args []string) { + // todo align log format with provider runtime zapLogger, err := logger.NewZapLogger(zapcore.DebugLevel, "console") if err != nil { log.Fatalf("error setting up the logger: %s", err) @@ -64,7 +71,7 @@ func runServer(cmd *cobra.Command, args []string) { Secure: viper.GetBool(secure), CertPath: viper.GetString(certPath), KeyPath: viper.GetString(keyPath), - SyncSources: viper.GetStringSlice(sources), + SyncSources: viper.GetString(source), } serverRuntime, err := runtime.NewServerRuntime(serverConfig, logWrapper) diff --git a/docs/configuration/flagd.md b/docs/configuration/flagd.md index 90cd032fd..0e49ac803 100644 --- a/docs/configuration/flagd.md +++ b/docs/configuration/flagd.md @@ -12,6 +12,7 @@ Flagd is a simple command line tool for fetching and presenting feature flags to ### SEE ALSO +* [flagd server](flagd_server.md) - Start flagd as a server * [flagd start](flagd_start.md) - Start flagd * [flagd version](flagd_version.md) - Print the version number of FlagD diff --git a/docs/configuration/flagd_server.md b/docs/configuration/flagd_server.md new file mode 100644 index 000000000..712d41262 --- /dev/null +++ b/docs/configuration/flagd_server.md @@ -0,0 +1,30 @@ +## flagd server + +Start flagd as a server + +``` +flagd server [flags] +``` + +### Options + +``` + -p, --address string Path this server binds to (default "localhost:9090") + -c, --cert-path string TLS certificate path + -h, --help help for server + -k, --key-path string TLS key path of the certificate + -s, --secure Start secure server + -f, --source string CRD with feature flag configurations +``` + +### Options inherited from parent commands + +``` + --config string config file (default is $HOME/.agent.yaml) + -x, --debug verbose logging +``` + +### SEE ALSO + +* [flagd](flagd.md) - Flagd is a simple command line tool for fetching and presenting feature flags to services. It is designed to conform to Open Feature schema for flag definitions. + diff --git a/pkg/runtime/serverRuntime.go b/pkg/runtime/serverRuntime.go index f2053af65..6e546a373 100644 --- a/pkg/runtime/serverRuntime.go +++ b/pkg/runtime/serverRuntime.go @@ -4,13 +4,14 @@ import ( "context" "errors" "fmt" - "github.com/open-feature/flagd/pkg/server" - "github.com/open-feature/flagd/pkg/sync/kubernetes" - "go.uber.org/zap" "os" "os/signal" "syscall" + "github.com/open-feature/flagd/pkg/server" + "github.com/open-feature/flagd/pkg/sync/kubernetes" + "go.uber.org/zap" + "golang.org/x/sync/errgroup" "github.com/open-feature/flagd/pkg/logger" @@ -22,25 +23,25 @@ type ServerConfig struct { Secure bool CertPath string KeyPath string - SyncSources []string + SyncSources string } type ServerRuntime struct { - syncImpl []sync.ISync - logger *logger.Logger - config ServerConfig + syncProvider sync.ISync + logger *logger.Logger + config ServerConfig } func NewServerRuntime(config ServerConfig, rootLogger *logger.Logger) (*ServerRuntime, error) { - impls, err := buildSyncImpls(config.SyncSources, rootLogger) + syncImpl, err := buildSyncImpl(config.SyncSources, rootLogger) if err != nil { return nil, err } return &ServerRuntime{ - syncImpl: impls, - logger: rootLogger.WithFields(zap.String("component", "Server Runtime")), - config: config, + syncProvider: syncImpl, + logger: rootLogger.WithFields(zap.String("component", "Server Runtime")), + config: config, }, nil } @@ -58,22 +59,18 @@ func (sr *ServerRuntime) Start() error { } g, gCtx := errgroup.WithContext(ctx) - dataSync := make(chan sync.DataSync, len(sr.syncImpl)) + dataSync := make(chan sync.DataSync) // Start server g.Go(func() error { return s.Listen(gCtx, dataSync) }) - // Start sync providers - for _, s := range sr.syncImpl { - p := s - g.Go(func() error { - return p.Sync(gCtx, dataSync) - }) - } + // Start sync provider + g.Go(func() error { + return sr.syncProvider.Sync(gCtx, dataSync) + }) - <-gCtx.Done() if err := g.Wait(); err != nil { return err } @@ -81,27 +78,22 @@ func (sr *ServerRuntime) Start() error { return nil } -func buildSyncImpls(sources []string, rootLogger *logger.Logger) ([]sync.ISync, error) { - if len(sources) == 0 { +func buildSyncImpl(source string, rootLogger *logger.Logger) (sync.ISync, error) { + if len(source) == 0 { return nil, errors.New("no sync provider sources provided") } - var syncs []sync.ISync - for _, source := range sources { - switch sourceBytes := []byte(source); { - case regCrd.Match(sourceBytes): - syncs = append(syncs, &kubernetes.Sync{ - Logger: rootLogger.WithFields( - zap.String("component", "sync"), - zap.String("sync", "kubernetes"), - ), - URI: regCrd.ReplaceAllString(source, ""), - }) - rootLogger.Debug(fmt.Sprintf("using kubernetes sync-provider for: %s", source)) - default: - return nil, fmt.Errorf("server supports only crd sync providers. recieved : %s", source) - } + switch sourceBytes := []byte(source); { + case regCrd.Match(sourceBytes): + rootLogger.Debug(fmt.Sprintf("using kubernetes sync-provider for: %s", source)) + return &kubernetes.Sync{ + Logger: rootLogger.WithFields( + zap.String("component", "sync"), + zap.String("sync", "kubernetes"), + ), + URI: regCrd.ReplaceAllString(source, ""), + }, nil + default: + return nil, fmt.Errorf("server supports only crd sync provider, but received : %s", source) } - - return syncs, nil } diff --git a/pkg/server/server.go b/pkg/server/server.go index d1a8a76b2..6c2df35b0 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -1,17 +1,28 @@ package server import ( - "buf.build/gen/go/open-feature/flagd/grpc/go/sync/v1/syncv1grpc" - v1 "buf.build/gen/go/open-feature/flagd/protocolbuffers/go/sync/v1" "context" "crypto/tls" "fmt" + "net" + "time" + + "buf.build/gen/go/open-feature/flagd/grpc/go/sync/v1/syncv1grpc" + v1 "buf.build/gen/go/open-feature/flagd/protocolbuffers/go/sync/v1" + "github.com/open-feature/flagd/pkg/logger" "github.com/open-feature/flagd/pkg/sync" "golang.org/x/sync/errgroup" "google.golang.org/grpc" "google.golang.org/grpc/credentials" - "net" +) + +const ( + // type of the listener + serverListenType = "tcp" + + // Time between server to client pings + pingDelay time.Duration = 20 * time.Second ) type Server struct { @@ -24,22 +35,19 @@ type Server struct { } func (s *Server) Listen(ctx context.Context, sync <-chan sync.DataSync) error { - listen, err := net.Listen("tcp", s.Address) - if err != nil { - s.Logger.Error(fmt.Sprintf("Error when listening to address : %s\n", err.Error())) - return err - } - options, err := s.buildOptions() if err != nil { - s.Logger.Error(fmt.Sprintf("Error building dial options : %s\n", err.Error())) + s.Logger.Error(fmt.Sprintf("error building dial options : %s\n", err.Error())) return err } server := grpc.NewServer(options...) - defer server.Stop() - listener := NewListener() + store := NewDataStore() + syncv1grpc.RegisterFlagSyncServiceServer(server, &StreamHandler{ + Logger: s.Logger, + DS: store, + }) group, lcCtxt := errgroup.WithContext(ctx) @@ -47,30 +55,31 @@ func (s *Server) Listen(ctx context.Context, sync <-chan sync.DataSync) error { for { select { case data := <-sync: - fmt.Printf("New data :%s", data) - listener.persist(data.FlagData) - case <-ctx.Done(): + store.cache(dataType(data.FlagData)) + case <-lcCtxt.Done(): + s.Logger.Debug("exiting server with context done") + server.Stop() return nil } } }) - syncv1grpc.RegisterFlagSyncServiceServer(server, &internal{ - Logger: s.Logger, - Ls: &listener, - }) - group.Go(func() error { + listen, err := net.Listen(serverListenType, s.Address) + if err != nil { + s.Logger.Error(fmt.Sprintf("error when listening to address : %s\n", err.Error())) + return err + } + err = server.Serve(listen) if err != nil { - s.Logger.Error(fmt.Sprintf("Error when starting the server : %s\n", err.Error())) + s.Logger.Error(fmt.Sprintf("error when starting the server : %s\n", err.Error())) return err } return nil }) - <-lcCtxt.Done() err = group.Wait() if err != nil { return err @@ -96,61 +105,50 @@ func (s *Server) buildOptions() ([]grpc.ServerOption, error) { return options, nil } -type internal struct { +type StreamHandler struct { Logger *logger.Logger - Ls *Listener + DS *DataStore } -func (i *internal) SyncFlags(req *v1.SyncFlagsRequest, stream syncv1grpc.FlagSyncService_SyncFlagsServer) error { - i.Logger.Info(fmt.Sprintf("Request with ID: %s", req.ProviderId)) +func (sh *StreamHandler) SyncFlags(req *v1.SyncFlagsRequest, stream syncv1grpc.FlagSyncService_SyncFlagsServer) error { + sh.Logger.Debug(fmt.Sprintf("stream registering for provider identifier: %s", req.ProviderId)) + + subID := StorageID() + syncChan := make(chan dataType) + + sh.DS.subscribe(subID, syncChan) + defer sh.DS.unsubscribe(subID) // Initially send the current state err := stream.Send(&v1.SyncFlagsResponse{ - FlagConfiguration: i.Ls.currentState(), + FlagConfiguration: sh.DS.currentState().string(), State: v1.SyncState_SYNC_STATE_ALL, }) if err != nil { + sh.Logger.Warn(fmt.Sprintf("error writing to stream: %s", err.Error())) return err } - emit := i.Ls.getEmit() - // Then wait for updates for { select { - case _ = <-emit: - stream.Send(&v1.SyncFlagsResponse{ - FlagConfiguration: i.Ls.currentState(), + case data := <-syncChan: + err := stream.Send(&v1.SyncFlagsResponse{ + FlagConfiguration: data.string(), State: v1.SyncState_SYNC_STATE_ALL, }) + if err != nil { + sh.Logger.Warn(fmt.Sprintf("exiting stream listener, stream send failed: %s", err.Error())) + return err + } + case <-time.After(pingDelay): + err := stream.Send(&v1.SyncFlagsResponse{ + State: v1.SyncState_SYNC_STATE_PING, + }) + if err != nil { + sh.Logger.Warn(fmt.Sprintf("exiting stream listener, server ping failed: %s", err.Error())) + return err + } } } - -} - -// todo we need a sync mechanism better than listener - -type Listener struct { - emit chan string - data string -} - -func NewListener() Listener { - return Listener{ - emit: make(chan string), - data: "", - } -} - -func (s *Listener) persist(input string) { - s.data = input - s.emit <- s.data -} - -func (s *Listener) getEmit() <-chan string { - return s.emit -} - -func (s *Listener) currentState() string { - return s.data } diff --git a/pkg/server/store.go b/pkg/server/store.go new file mode 100644 index 000000000..99f23a68f --- /dev/null +++ b/pkg/server/store.go @@ -0,0 +1,66 @@ +package server + +import ( + "sync" + + "github.com/google/uuid" +) + +// dataType is an abstraction of the internal type of the storage +type dataType string + +func (t dataType) string() string { + return string(t) +} + +// DataStore is and intermediate storage layer between sync providers and server stream handler +type DataStore struct { + data dataType + subscriptions map[string]chan dataType + + mu sync.RWMutex +} + +func NewDataStore() *DataStore { + return &DataStore{ + data: "", + subscriptions: make(map[string]chan dataType), + } +} + +func (store *DataStore) subscribe(id string, c chan dataType) { + store.mu.Lock() + defer store.mu.Unlock() + + store.subscriptions[id] = c +} + +func (store *DataStore) unsubscribe(id string) { + store.mu.Lock() + defer store.mu.Unlock() + + delete(store.subscriptions, id) +} + +func (store *DataStore) cache(data dataType) { + store.mu.Lock() + defer store.mu.Unlock() + + store.data = data + + for _, sub := range store.subscriptions { + sub <- data + } +} + +func (store *DataStore) currentState() dataType { + store.mu.RLock() + defer store.mu.RUnlock() + + return store.data +} + +// StorageID is an abstraction to generate unique storage subscription identifiers +func StorageID() string { + return uuid.New().String() +}