diff --git a/core/pkg/service/flag-evaluation/connect_service.go b/core/pkg/service/flag-evaluation/connect_service.go index 3bdd642bb..461623010 100644 --- a/core/pkg/service/flag-evaluation/connect_service.go +++ b/core/pkg/service/flag-evaluation/connect_service.go @@ -10,6 +10,8 @@ import ( "sync" "time" + "golang.org/x/sync/errgroup" + "github.com/open-feature/flagd/core/pkg/service/middleware" schemaConnectV1 "buf.build/gen/go/open-feature/flagd/bufbuild/connect-go/schema/v1/schemav1connect" @@ -31,7 +33,11 @@ type ConnectService struct { Eval eval.IEvaluator Metrics *otel.MetricsRecorder eventingConfiguration *eventingConfiguration - server http.Server + server *http.Server + metricsServer *http.Server + + serverMtx sync.RWMutex + metricsServerMtx sync.RWMutex } func (s *ConnectService) Serve(ctx context.Context, eval eval.IEvaluator, svcConf service.Configuration) error { @@ -40,40 +46,38 @@ func (s *ConnectService) Serve(ctx context.Context, eval eval.IEvaluator, svcCon subs: make(map[interface{}]chan service.Notification), mu: &sync.RWMutex{}, } - lis, err := s.setupServer(svcConf) - if err != nil { - return err - } - errChan := make(chan error, 1) - go func() { - s.Logger.Info(fmt.Sprintf("Flag Evaluation listening at %s", lis.Addr())) - if svcConf.CertPath != "" && svcConf.KeyPath != "" { - if err := s.server.ServeTLS( - lis, - svcConf.CertPath, - svcConf.KeyPath, - ); err != nil && !errors.Is(err, http.ErrServerClosed) { - errChan <- err + g, gCtx := errgroup.WithContext(ctx) + + g.Go(func() error { + return s.startServer(svcConf) + }) + g.Go(func() error { + return s.startMetricsServer(svcConf) + }) + g.Go(func() error { + <-gCtx.Done() + s.serverMtx.RLock() + defer s.serverMtx.RUnlock() + if s.server != nil { + if err := s.server.Shutdown(gCtx); err != nil { + return err } - } else { - if err := s.server.Serve( - lis, - ); err != nil && !errors.Is(err, http.ErrServerClosed) { - errChan <- err + } + return nil + }) + g.Go(func() error { + <-gCtx.Done() + s.metricsServerMtx.RLock() + defer s.metricsServerMtx.RUnlock() + if s.metricsServer != nil { + if err := s.metricsServer.Shutdown(gCtx); err != nil { + return err } } - close(errChan) - }() - - go s.startMetricsServer(svcConf) - - select { - case err := <-errChan: - return err - case <-ctx.Done(): - return s.server.Shutdown(ctx) - } + return nil + }) + return g.Wait() } func (s *ConnectService) setupServer(svcConf service.Configuration) (net.Listener, error) { @@ -97,10 +101,12 @@ func (s *ConnectService) setupServer(svcConf service.Configuration) (net.Listene path, handler := schemaConnectV1.NewServiceHandler(fes) mux.Handle(path, handler) - s.server = http.Server{ + s.serverMtx.Lock() + s.server = &http.Server{ ReadHeaderTimeout: time.Second, Handler: handler, } + s.serverMtx.Unlock() // Add middlewares @@ -136,13 +142,39 @@ func (s *ConnectService) Notify(n service.Notification) { } } -func (s *ConnectService) startMetricsServer(svcConf service.Configuration) { +func (s *ConnectService) startServer(svcConf service.Configuration) error { + lis, err := s.setupServer(svcConf) + if err != nil { + return err + } + s.Logger.Info(fmt.Sprintf("Flag Evaluation listening at %s", lis.Addr())) + if svcConf.CertPath != "" && svcConf.KeyPath != "" { + if err := s.server.ServeTLS( + lis, + svcConf.CertPath, + svcConf.KeyPath, + ); err != nil && !errors.Is(err, http.ErrServerClosed) { + return err + } + } else { + if err := s.server.Serve( + lis, + ); err != nil && !errors.Is(err, http.ErrServerClosed) { + return err + } + } + return nil +} + +func (s *ConnectService) startMetricsServer(svcConf service.Configuration) error { s.Logger.Info(fmt.Sprintf("metrics and probes listening at %d", svcConf.MetricsPort)) - server := &http.Server{ + s.metricsServerMtx.Lock() + s.metricsServer = &http.Server{ Addr: fmt.Sprintf(":%d", svcConf.MetricsPort), ReadHeaderTimeout: 3 * time.Second, } - server.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + s.metricsServerMtx.Unlock() + s.metricsServer.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/healthz": w.WriteHeader(http.StatusOK) @@ -158,8 +190,8 @@ func (s *ConnectService) startMetricsServer(svcConf service.Configuration) { w.WriteHeader(http.StatusNotFound) } }) - err := server.ListenAndServe() - if err != nil { - panic(err) + if err := s.metricsServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + return err } + return nil }