diff --git a/cmd/start.go b/cmd/provider.go similarity index 62% rename from cmd/start.go rename to cmd/provider.go index d614a70e8..c6ecab0de 100644 --- a/cmd/start.go +++ b/cmd/provider.go @@ -7,6 +7,7 @@ import ( "github.com/open-feature/flagd/pkg/logger" "github.com/open-feature/flagd/pkg/runtime" "github.com/spf13/cobra" + "github.com/spf13/pflag" "github.com/spf13/viper" "go.uber.org/zap" "go.uber.org/zap/zapcore" @@ -27,9 +28,21 @@ const ( uriFlagName = "uri" ) -func init() { - flags := startCmd.Flags() +// NewProviderCmd is the command to start flagd as a provider +func NewProviderCmd() *cobra.Command { + flagd := &cobra.Command{ + Use: "start", + Short: "Start flagd", + Long: ``, + Run: runProvider, + } + setupProvider(flagd.Flags()) + return flagd +} + +// setupProvider setup flags of the command +func setupProvider(flags *pflag.FlagSet) { // allows environment variables to use _ instead of - viper.SetEnvKeyReplacer(strings.NewReplacer("-", "_")) // sync-provider-args becomes SYNC_PROVIDER_ARGS viper.SetEnvPrefix("FLAGD") // port becomes FLAGD_PORT @@ -72,54 +85,49 @@ func init() { _ = viper.BindPFlag(uriFlagName, flags.Lookup(uriFlagName)) } -// startCmd represents the start command -var startCmd = &cobra.Command{ - Use: "start", - Short: "Start flagd", - Long: ``, - Run: func(cmd *cobra.Command, args []string) { - // Configure loggers ------------------------------------------------------- - var level zapcore.Level - var err error - if Debug { - level = zapcore.DebugLevel - } else { - level = zapcore.InfoLevel - } - l, err := logger.NewZapLogger(level, viper.GetString(logFormatFlagName)) - if err != nil { - log.Fatalf("can't initialize zap logger: %v", err) - } - logger := logger.NewLogger(l, Debug) - rtLogger := logger.WithFields(zap.String("component", "start")) +// runProvider starts the provider implementation +func runProvider(cmd *cobra.Command, args []string) { + // Configure loggers ------------------------------------------------------- + var level zapcore.Level + var err error + if Debug { + level = zapcore.DebugLevel + } else { + level = zapcore.InfoLevel + } + l, err := logger.NewZapLogger(level, viper.GetString(logFormatFlagName)) + if err != nil { + log.Fatalf("can't initialize zap logger: %v", err) + } + logger := logger.NewLogger(l, Debug) + rtLogger := logger.WithFields(zap.String("component", "start")) - if viper.GetString(syncProviderFlagName) != "" { - rtLogger.Warn("DEPRECATED: The --sync-provider flag has been deprecated. " + - "Docs: https://github.com/open-feature/flagd/blob/main/docs/configuration/configuration.md") - } + if viper.GetString(syncProviderFlagName) != "" { + rtLogger.Warn("DEPRECATED: The --sync-provider flag has been deprecated. " + + "Docs: https://github.com/open-feature/flagd/blob/main/docs/configuration/configuration.md") + } - if viper.GetString(evaluatorFlagName) != "json" { - rtLogger.Warn("DEPRECATED: The --evaluator flag has been deprecated. " + - "Docs: https://github.com/open-feature/flagd/blob/main/docs/configuration/configuration.md") - } - // Build Runtime ----------------------------------------------------------- - rt, err := runtime.FromConfig(logger, runtime.Config{ - CORS: viper.GetStringSlice(corsFlagName), - MetricsPort: viper.GetInt32(metricsPortFlagName), - ProviderArgs: viper.GetStringMapString(providerArgsFlagName), - ServiceCertPath: viper.GetString(serverCertPathFlagName), - ServiceKeyPath: viper.GetString(serverKeyPathFlagName), - ServicePort: viper.GetInt32(portFlagName), - ServiceSocketPath: viper.GetString(socketPathFlagName), - SyncBearerToken: viper.GetString(bearerTokenFlagName), - SyncURI: viper.GetStringSlice(uriFlagName), - }) - if err != nil { - rtLogger.Fatal(err.Error()) - } + if viper.GetString(evaluatorFlagName) != "json" { + rtLogger.Warn("DEPRECATED: The --evaluator flag has been deprecated. " + + "Docs: https://github.com/open-feature/flagd/blob/main/docs/configuration/configuration.md") + } + // Build Runtime ----------------------------------------------------------- + rt, err := runtime.FromConfig(logger, runtime.Config{ + CORS: viper.GetStringSlice(corsFlagName), + MetricsPort: viper.GetInt32(metricsPortFlagName), + ProviderArgs: viper.GetStringMapString(providerArgsFlagName), + ServiceCertPath: viper.GetString(serverCertPathFlagName), + ServiceKeyPath: viper.GetString(serverKeyPathFlagName), + ServicePort: viper.GetInt32(portFlagName), + ServiceSocketPath: viper.GetString(socketPathFlagName), + SyncBearerToken: viper.GetString(bearerTokenFlagName), + SyncURI: viper.GetStringSlice(uriFlagName), + }) + if err != nil { + rtLogger.Fatal(err.Error()) + } - if err := rt.Start(); err != nil { - rtLogger.Fatal(err.Error()) - } - }, + if err := rt.Start(); err != nil { + rtLogger.Fatal(err.Error()) + } } diff --git a/cmd/root.go b/cmd/root.go index c70638f9b..dd15746d8 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -48,7 +48,8 @@ func init() { // will be global for your application. rootCmd.PersistentFlags().BoolVarP(&Debug, "debug", "x", false, "verbose logging") rootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", "config file (default is $HOME/.agent.yaml)") - rootCmd.AddCommand(startCmd) + rootCmd.AddCommand(NewProviderCmd()) + rootCmd.AddCommand(NewServerCmd()) rootCmd.AddCommand(versionCmd) } diff --git a/cmd/server.go b/cmd/server.go new file mode 100644 index 000000000..35efac5b2 --- /dev/null +++ b/cmd/server.go @@ -0,0 +1,86 @@ +package cmd + +import ( + "fmt" + "log" + + "github.com/open-feature/flagd/pkg/logger" + "github.com/open-feature/flagd/pkg/runtime" + "github.com/spf13/cobra" + "github.com/spf13/pflag" + "github.com/spf13/viper" + "go.uber.org/zap/zapcore" +) + +const ( + address = "address" + secure = "secure" + certPath = "cert-path" + keyPath = "key-path" + source = "source" +) + +// NewServerCmd is the command to start flagd in server mode +func NewServerCmd() *cobra.Command { + flagdCmd := &cobra.Command{ + Use: "server", + Short: "Start flagd as a server", + Run: runServer, + } + + setupServer(flagdCmd) + return flagdCmd +} + +// setupServer setup flags of the command +func setupServer(cmd *cobra.Command) { + flags := cmd.Flags() + + flags.StringP(address, "p", "localhost:9090", "Path this server binds to") + + flags.BoolP(secure, "s", false, "Start secure server") + flags.StringP(certPath, "c", "", "TLS certificate path") + flags.StringP(keyPath, "k", "", "TLS key path of the certificate") + cmd.MarkFlagsRequiredTogether(secure, certPath, keyPath) + + flags.StringP(source, "f", "", "CRD with feature flag configurations") + + _ = viper.BindPFlag(address, flags.Lookup(address)) + _ = viper.BindPFlag(secure, flags.Lookup(secure)) + _ = viper.BindPFlag(certPath, flags.Lookup(certPath)) + _ = viper.BindPFlag(keyPath, flags.Lookup(keyPath)) + _ = viper.BindPFlag(source, flags.Lookup(source)) +} + +func runServer(cmd *cobra.Command, args []string) { + // todo align log format with provider runtime + zapLogger, err := logger.NewZapLogger(zapcore.DebugLevel, "console") + if err != nil { + log.Fatalf("error setting up the logger: %s", err) + } + + logWrapper := logger.NewLogger(zapLogger, true) + + err = viper.BindPFlags(pflag.CommandLine) + if err != nil { + logWrapper.Fatal(fmt.Sprintf("error parsing flags: %s", err.Error())) + } + + serverConfig := runtime.ServerConfig{ + Address: viper.GetString(address), + Secure: viper.GetBool(secure), + CertPath: viper.GetString(certPath), + KeyPath: viper.GetString(keyPath), + SyncSources: viper.GetString(source), + } + + serverRuntime, err := runtime.NewServerRuntime(serverConfig, logWrapper) + if err != nil { + logWrapper.Fatal(fmt.Sprintf("error creating the server runtime: %s", err.Error())) + } + + err = serverRuntime.Start() + if err != nil { + logWrapper.Fatal(fmt.Sprintf("error from server runtime: %s", err.Error())) + } +} diff --git a/docs/configuration/flagd.md b/docs/configuration/flagd.md index 90cd032fd..0e49ac803 100644 --- a/docs/configuration/flagd.md +++ b/docs/configuration/flagd.md @@ -12,6 +12,7 @@ Flagd is a simple command line tool for fetching and presenting feature flags to ### SEE ALSO +* [flagd server](flagd_server.md) - Start flagd as a server * [flagd start](flagd_start.md) - Start flagd * [flagd version](flagd_version.md) - Print the version number of FlagD diff --git a/docs/configuration/flagd_server.md b/docs/configuration/flagd_server.md new file mode 100644 index 000000000..712d41262 --- /dev/null +++ b/docs/configuration/flagd_server.md @@ -0,0 +1,30 @@ +## flagd server + +Start flagd as a server + +``` +flagd server [flags] +``` + +### Options + +``` + -p, --address string Path this server binds to (default "localhost:9090") + -c, --cert-path string TLS certificate path + -h, --help help for server + -k, --key-path string TLS key path of the certificate + -s, --secure Start secure server + -f, --source string CRD with feature flag configurations +``` + +### Options inherited from parent commands + +``` + --config string config file (default is $HOME/.agent.yaml) + -x, --debug verbose logging +``` + +### SEE ALSO + +* [flagd](flagd.md) - Flagd is a simple command line tool for fetching and presenting feature flags to services. It is designed to conform to Open Feature schema for flag definitions. + diff --git a/pkg/runtime/runtime.go b/pkg/runtime/providerRuntime.go similarity index 100% rename from pkg/runtime/runtime.go rename to pkg/runtime/providerRuntime.go diff --git a/pkg/runtime/serverRuntime.go b/pkg/runtime/serverRuntime.go new file mode 100644 index 000000000..6e546a373 --- /dev/null +++ b/pkg/runtime/serverRuntime.go @@ -0,0 +1,99 @@ +package runtime + +import ( + "context" + "errors" + "fmt" + "os" + "os/signal" + "syscall" + + "github.com/open-feature/flagd/pkg/server" + "github.com/open-feature/flagd/pkg/sync/kubernetes" + "go.uber.org/zap" + + "golang.org/x/sync/errgroup" + + "github.com/open-feature/flagd/pkg/logger" + "github.com/open-feature/flagd/pkg/sync" +) + +type ServerConfig struct { + Address string + Secure bool + CertPath string + KeyPath string + SyncSources string +} + +type ServerRuntime struct { + syncProvider sync.ISync + logger *logger.Logger + config ServerConfig +} + +func NewServerRuntime(config ServerConfig, rootLogger *logger.Logger) (*ServerRuntime, error) { + syncImpl, err := buildSyncImpl(config.SyncSources, rootLogger) + if err != nil { + return nil, err + } + + return &ServerRuntime{ + syncProvider: syncImpl, + logger: rootLogger.WithFields(zap.String("component", "Server Runtime")), + config: config, + }, nil +} + +func (sr *ServerRuntime) Start() error { + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer cancel() + + // Build server + s := server.Server{ + Logger: sr.logger.WithFields(zap.String("component", "Server")), + Secure: sr.config.Secure, + CertPath: sr.config.CertPath, + KeyPath: sr.config.KeyPath, + Address: sr.config.Address, + } + + g, gCtx := errgroup.WithContext(ctx) + dataSync := make(chan sync.DataSync) + + // Start server + g.Go(func() error { + return s.Listen(gCtx, dataSync) + }) + + // Start sync provider + g.Go(func() error { + return sr.syncProvider.Sync(gCtx, dataSync) + }) + + if err := g.Wait(); err != nil { + return err + } + + return nil +} + +func buildSyncImpl(source string, rootLogger *logger.Logger) (sync.ISync, error) { + if len(source) == 0 { + return nil, errors.New("no sync provider sources provided") + } + + switch sourceBytes := []byte(source); { + case regCrd.Match(sourceBytes): + rootLogger.Debug(fmt.Sprintf("using kubernetes sync-provider for: %s", source)) + return &kubernetes.Sync{ + Logger: rootLogger.WithFields( + zap.String("component", "sync"), + zap.String("sync", "kubernetes"), + ), + URI: regCrd.ReplaceAllString(source, ""), + }, nil + default: + return nil, fmt.Errorf("server supports only crd sync provider, but received : %s", source) + } +} diff --git a/pkg/server/server.go b/pkg/server/server.go new file mode 100644 index 000000000..6c2df35b0 --- /dev/null +++ b/pkg/server/server.go @@ -0,0 +1,154 @@ +package server + +import ( + "context" + "crypto/tls" + "fmt" + "net" + "time" + + "buf.build/gen/go/open-feature/flagd/grpc/go/sync/v1/syncv1grpc" + v1 "buf.build/gen/go/open-feature/flagd/protocolbuffers/go/sync/v1" + + "github.com/open-feature/flagd/pkg/logger" + "github.com/open-feature/flagd/pkg/sync" + "golang.org/x/sync/errgroup" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" +) + +const ( + // type of the listener + serverListenType = "tcp" + + // Time between server to client pings + pingDelay time.Duration = 20 * time.Second +) + +type Server struct { + Logger *logger.Logger + + Secure bool + CertPath string + KeyPath string + Address string +} + +func (s *Server) Listen(ctx context.Context, sync <-chan sync.DataSync) error { + options, err := s.buildOptions() + if err != nil { + s.Logger.Error(fmt.Sprintf("error building dial options : %s\n", err.Error())) + return err + } + + server := grpc.NewServer(options...) + + store := NewDataStore() + syncv1grpc.RegisterFlagSyncServiceServer(server, &StreamHandler{ + Logger: s.Logger, + DS: store, + }) + + group, lcCtxt := errgroup.WithContext(ctx) + + group.Go(func() error { + for { + select { + case data := <-sync: + store.cache(dataType(data.FlagData)) + case <-lcCtxt.Done(): + s.Logger.Debug("exiting server with context done") + server.Stop() + return nil + } + } + }) + + group.Go(func() error { + listen, err := net.Listen(serverListenType, s.Address) + if err != nil { + s.Logger.Error(fmt.Sprintf("error when listening to address : %s\n", err.Error())) + return err + } + + err = server.Serve(listen) + if err != nil { + s.Logger.Error(fmt.Sprintf("error when starting the server : %s\n", err.Error())) + return err + } + + return nil + }) + + err = group.Wait() + if err != nil { + return err + } + + return nil +} + +func (s *Server) buildOptions() ([]grpc.ServerOption, error) { + var options []grpc.ServerOption + + if !s.Secure { + return options, nil + } + + keyPair, err := tls.LoadX509KeyPair(s.CertPath, s.KeyPath) + if err != nil { + return nil, err + } + + options = append(options, grpc.Creds(credentials.NewServerTLSFromCert(&keyPair))) + + return options, nil +} + +type StreamHandler struct { + Logger *logger.Logger + DS *DataStore +} + +func (sh *StreamHandler) SyncFlags(req *v1.SyncFlagsRequest, stream syncv1grpc.FlagSyncService_SyncFlagsServer) error { + sh.Logger.Debug(fmt.Sprintf("stream registering for provider identifier: %s", req.ProviderId)) + + subID := StorageID() + syncChan := make(chan dataType) + + sh.DS.subscribe(subID, syncChan) + defer sh.DS.unsubscribe(subID) + + // Initially send the current state + err := stream.Send(&v1.SyncFlagsResponse{ + FlagConfiguration: sh.DS.currentState().string(), + State: v1.SyncState_SYNC_STATE_ALL, + }) + if err != nil { + sh.Logger.Warn(fmt.Sprintf("error writing to stream: %s", err.Error())) + return err + } + + // Then wait for updates + for { + select { + case data := <-syncChan: + err := stream.Send(&v1.SyncFlagsResponse{ + FlagConfiguration: data.string(), + State: v1.SyncState_SYNC_STATE_ALL, + }) + if err != nil { + sh.Logger.Warn(fmt.Sprintf("exiting stream listener, stream send failed: %s", err.Error())) + return err + } + case <-time.After(pingDelay): + err := stream.Send(&v1.SyncFlagsResponse{ + State: v1.SyncState_SYNC_STATE_PING, + }) + if err != nil { + sh.Logger.Warn(fmt.Sprintf("exiting stream listener, server ping failed: %s", err.Error())) + return err + } + } + } +} diff --git a/pkg/server/store.go b/pkg/server/store.go new file mode 100644 index 000000000..99f23a68f --- /dev/null +++ b/pkg/server/store.go @@ -0,0 +1,66 @@ +package server + +import ( + "sync" + + "github.com/google/uuid" +) + +// dataType is an abstraction of the internal type of the storage +type dataType string + +func (t dataType) string() string { + return string(t) +} + +// DataStore is and intermediate storage layer between sync providers and server stream handler +type DataStore struct { + data dataType + subscriptions map[string]chan dataType + + mu sync.RWMutex +} + +func NewDataStore() *DataStore { + return &DataStore{ + data: "", + subscriptions: make(map[string]chan dataType), + } +} + +func (store *DataStore) subscribe(id string, c chan dataType) { + store.mu.Lock() + defer store.mu.Unlock() + + store.subscriptions[id] = c +} + +func (store *DataStore) unsubscribe(id string) { + store.mu.Lock() + defer store.mu.Unlock() + + delete(store.subscriptions, id) +} + +func (store *DataStore) cache(data dataType) { + store.mu.Lock() + defer store.mu.Unlock() + + store.data = data + + for _, sub := range store.subscriptions { + sub <- data + } +} + +func (store *DataStore) currentState() dataType { + store.mu.RLock() + defer store.mu.RUnlock() + + return store.data +} + +// StorageID is an abstraction to generate unique storage subscription identifiers +func StorageID() string { + return uuid.New().String() +}