diff --git a/Makefile b/Makefile index b17f4bba4..94e8e1274 100644 --- a/Makefile +++ b/Makefile @@ -60,6 +60,7 @@ mockgen: install-mockgen cd core; mockgen -source=pkg/sync/grpc/grpc_sync.go -destination=pkg/sync/grpc/mock/grpc.go -package=grpcmock cd core; mockgen -source=pkg/sync/grpc/credentials/builder.go -destination=pkg/sync/grpc/credentials/mock/builder.go -package=credendialsmock cd core; mockgen -source=pkg/eval/ievaluator.go -destination=pkg/eval/mock/ievaluator.go -package=evalmock + cd core; mockgen -source=pkg/service/middleware/interface.go -destination=pkg/service/middleware/mock/interface.go -package=middlewaremock generate-docs: cd flagd; go run ./cmd/doc/main.go diff --git a/core/pkg/service/flag-evaluation/connect_service.go b/core/pkg/service/flag-evaluation/connect_service.go index 9e1d3bd53..3bdd642bb 100644 --- a/core/pkg/service/flag-evaluation/connect_service.go +++ b/core/pkg/service/flag-evaluation/connect_service.go @@ -10,17 +10,18 @@ import ( "sync" "time" + "github.com/open-feature/flagd/core/pkg/service/middleware" + schemaConnectV1 "buf.build/gen/go/open-feature/flagd/bufbuild/connect-go/schema/v1/schemav1connect" "github.com/open-feature/flagd/core/pkg/eval" "github.com/open-feature/flagd/core/pkg/logger" "github.com/open-feature/flagd/core/pkg/otel" "github.com/open-feature/flagd/core/pkg/service" - "github.com/open-feature/flagd/core/pkg/service/middleware" + corsmw "github.com/open-feature/flagd/core/pkg/service/middleware/cors" + h2cmw "github.com/open-feature/flagd/core/pkg/service/middleware/h2c" + metricsmw "github.com/open-feature/flagd/core/pkg/service/middleware/metrics" "github.com/prometheus/client_golang/prometheus/promhttp" - "github.com/rs/cors" "go.uber.org/zap" - "golang.org/x/net/http2" - "golang.org/x/net/http2/h2c" ) const ErrorPrefix = "FlagdError:" @@ -65,6 +66,8 @@ func (s *ConnectService) Serve(ctx context.Context, eval eval.IEvaluator, svcCon close(errChan) }() + go s.startMetricsServer(svcConf) + select { case err := <-errChan: return err @@ -94,30 +97,37 @@ func (s *ConnectService) setupServer(svcConf service.Configuration) (net.Listene path, handler := schemaConnectV1.NewServiceHandler(fes) mux.Handle(path, handler) - mdlw := middleware.NewHTTPMetric(middleware.Config{ - Service: "openfeature/flagd", + s.server = http.Server{ + ReadHeaderTimeout: time.Second, + Handler: handler, + } + + // Add middlewares + + metricsMiddleware := metricsmw.NewHTTPMetric(metricsmw.Config{ + Service: svcConf.ServiceName, MetricRecorder: s.Metrics, Logger: s.Logger, + HandlerID: "", }) - h := middleware.Handler("", mdlw, mux) - go bindMetrics(s, svcConf) + s.AddMiddleware(metricsMiddleware) - if svcConf.CertPath != "" && svcConf.KeyPath != "" { - handler = s.newCORS(svcConf).Handler(h) - } else { - handler = h2c.NewHandler( - s.newCORS(svcConf).Handler(h), - &http2.Server{}, - ) - } - s.server = http.Server{ - ReadHeaderTimeout: time.Second, - Handler: handler, + corsMiddleware := corsmw.New(svcConf.CORS) + s.AddMiddleware(corsMiddleware) + + if svcConf.CertPath == "" || svcConf.KeyPath == "" { + h2cMiddleware := h2cmw.New() + s.AddMiddleware(h2cMiddleware) } + return lis, nil } +func (s *ConnectService) AddMiddleware(mw middleware.IMiddleware) { + s.server.Handler = mw.Handler(s.server.Handler) +} + func (s *ConnectService) Notify(n service.Notification) { s.eventingConfiguration.mu.RLock() defer s.eventingConfiguration.mu.RUnlock() @@ -126,36 +136,7 @@ func (s *ConnectService) Notify(n service.Notification) { } } -func (s *ConnectService) newCORS(svcConf service.Configuration) *cors.Cors { - return cors.New(cors.Options{ - AllowedMethods: []string{ - http.MethodHead, - http.MethodGet, - http.MethodPost, - http.MethodPut, - http.MethodPatch, - http.MethodDelete, - }, - AllowedOrigins: svcConf.CORS, - AllowedHeaders: []string{"*"}, - ExposedHeaders: []string{ - // Content-Type is in the default safelist. - "Accept", - "Accept-Encoding", - "Accept-Post", - "Connect-Accept-Encoding", - "Connect-Content-Encoding", - "Content-Encoding", - "Grpc-Accept-Encoding", - "Grpc-Encoding", - "Grpc-Message", - "Grpc-Status", - "Grpc-Status-Details-Bin", - }, - }) -} - -func bindMetrics(s *ConnectService, svcConf service.Configuration) { +func (s *ConnectService) startMetricsServer(svcConf service.Configuration) { s.Logger.Info(fmt.Sprintf("metrics and probes listening at %d", svcConf.MetricsPort)) server := &http.Server{ Addr: fmt.Sprintf(":%d", svcConf.MetricsPort), diff --git a/core/pkg/service/flag-evaluation/connect_service_test.go b/core/pkg/service/flag-evaluation/connect_service_test.go index 84668c2b7..098673e08 100644 --- a/core/pkg/service/flag-evaluation/connect_service_test.go +++ b/core/pkg/service/flag-evaluation/connect_service_test.go @@ -4,10 +4,13 @@ import ( "context" "errors" "fmt" + "net/http" "os" "testing" "time" + middlewaremock "github.com/open-feature/flagd/core/pkg/service/middleware/mock" + schemaGrpcV1 "buf.build/gen/go/open-feature/flagd/grpc/go/schema/v1/schemav1grpc" schemaV1 "buf.build/gen/go/open-feature/flagd/protocolbuffers/go/schema/v1" "github.com/golang/mock/gomock" @@ -117,3 +120,53 @@ func TestConnectService_UnixConnection(t *testing.T) { }) } } + +func TestAddMiddleware(t *testing.T) { + const port = 12345 + ctrl := gomock.NewController(t) + + mwMock := middlewaremock.NewMockIMiddleware(ctrl) + + mwMock.EXPECT().Handler(gomock.Any()).Return( + http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + writer.WriteHeader(http.StatusOK) + })) + + exp := metric.NewManualReader() + metricRecorder := otel.NewOTelRecorder(exp, "my-exporter") + + svc := ConnectService{ + Logger: logger.NewLogger(nil, false), + Metrics: metricRecorder, + } + + serveConf := iservice.Configuration{ + ReadinessProbe: func() bool { + return true + }, + Port: port, + } + ctx := context.Background() + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + go func() { + err := svc.Serve(ctx, nil, serveConf) + fmt.Println(err) + }() + + require.Eventually(t, func() bool { + resp, err := http.Get(fmt.Sprintf("http://localhost:%d/schema.v1.Service/ResolveAll", port)) + // with the default http handler we should get a method not allowed (405) when attempting a GET request + return err == nil && resp.StatusCode == http.StatusMethodNotAllowed + }, 3*time.Second, 100*time.Millisecond) + + svc.AddMiddleware(mwMock) + + // with the injected middleware, the GET method should work + resp, err := http.Get(fmt.Sprintf("http://localhost:%d/schema.v1.Service/ResolveAll", port)) + + require.Nil(t, err) + // verify that the status we return in the mocked middleware + require.Equal(t, http.StatusOK, resp.StatusCode) +} diff --git a/core/pkg/service/middleware/cors/cors.go b/core/pkg/service/middleware/cors/cors.go new file mode 100644 index 000000000..303be268d --- /dev/null +++ b/core/pkg/service/middleware/cors/cors.go @@ -0,0 +1,46 @@ +package cors + +import ( + "net/http" + + "github.com/rs/cors" +) + +type Middleware struct { + cors *cors.Cors +} + +func New(allowedOrigins []string) *Middleware { + return &Middleware{ + cors: cors.New(cors.Options{ + AllowedMethods: []string{ + http.MethodHead, + http.MethodGet, + http.MethodPost, + http.MethodPut, + http.MethodPatch, + http.MethodDelete, + }, + AllowedOrigins: allowedOrigins, + AllowedHeaders: []string{"*"}, + ExposedHeaders: []string{ + // Content-Type is in the default safelist. + "Accept", + "Accept-Encoding", + "Accept-Post", + "Connect-Accept-Encoding", + "Connect-Content-Encoding", + "Content-Encoding", + "Grpc-Accept-Encoding", + "Grpc-Encoding", + "Grpc-Message", + "Grpc-Status", + "Grpc-Status-Details-Bin", + }, + }), + } +} + +func (c Middleware) Handler(handler http.Handler) http.Handler { + return c.cors.Handler(handler) +} diff --git a/core/pkg/service/middleware/cors/cors_test.go b/core/pkg/service/middleware/cors/cors_test.go new file mode 100644 index 000000000..f5f72cb52 --- /dev/null +++ b/core/pkg/service/middleware/cors/cors_test.go @@ -0,0 +1,44 @@ +package cors + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/golang/mock/gomock" + middlewaremock "github.com/open-feature/flagd/core/pkg/service/middleware/mock" + "github.com/stretchr/testify/require" +) + +func TestMiddleware(t *testing.T) { + ctrl := gomock.NewController(t) + mockMw := middlewaremock.NewMockIMiddleware(ctrl) + + handlerFunc := http.HandlerFunc( + func(writer http.ResponseWriter, request *http.Request) { + writer.WriteHeader(http.StatusOK) + }, + ) + + mockMw.EXPECT().Handler(gomock.Any()).Return(handlerFunc) + + ts := httptest.NewServer(handlerFunc) + + defer ts.Close() + + mw := New([]string{"*"}) + require.NotNil(t, mw) + + // wrap the cors middleware around the mock to make sure the wrapped handler is called by the cors middleware + ts.Config.Handler = mw.Handler(mockMw.Handler(handlerFunc)) + + req, err := http.NewRequest(http.MethodGet, ts.URL, nil) + + require.Nil(t, err) + + client := http.DefaultClient + resp, err := client.Do(req) + + require.Nil(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) +} diff --git a/core/pkg/service/middleware/h2c/h2c.go b/core/pkg/service/middleware/h2c/h2c.go new file mode 100644 index 000000000..73e573ad9 --- /dev/null +++ b/core/pkg/service/middleware/h2c/h2c.go @@ -0,0 +1,18 @@ +package h2c + +import ( + "net/http" + + "golang.org/x/net/http2" + "golang.org/x/net/http2/h2c" +) + +type Middleware struct{} + +func New() *Middleware { + return &Middleware{} +} + +func (m Middleware) Handler(handler http.Handler) http.Handler { + return h2c.NewHandler(handler, &http2.Server{}) +} diff --git a/core/pkg/service/middleware/h2c/h2c_test.go b/core/pkg/service/middleware/h2c/h2c_test.go new file mode 100644 index 000000000..c292dfa76 --- /dev/null +++ b/core/pkg/service/middleware/h2c/h2c_test.go @@ -0,0 +1,39 @@ +package h2c + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/golang/mock/gomock" + middlewaremock "github.com/open-feature/flagd/core/pkg/service/middleware/mock" + "github.com/stretchr/testify/require" +) + +func TestMiddleware(t *testing.T) { + ctrl := gomock.NewController(t) + mockMw := middlewaremock.NewMockIMiddleware(ctrl) + + handlerFunc := http.HandlerFunc( + func(writer http.ResponseWriter, request *http.Request) { + writer.WriteHeader(http.StatusOK) + }, + ) + + mockMw.EXPECT().Handler(gomock.Any()).Return(handlerFunc) + + ts := httptest.NewServer(handlerFunc) + + defer ts.Close() + + mw := New() + require.NotNil(t, mw) + + // wrap the h2c middleware around the mock to make sure the wrapped handler is called by the h2c middleware + ts.Config.Handler = mw.Handler(mockMw.Handler(handlerFunc)) + + resp, err := http.Get(ts.URL) + + require.Nil(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) +} diff --git a/core/pkg/service/middleware/interface.go b/core/pkg/service/middleware/interface.go new file mode 100644 index 000000000..47aaf307a --- /dev/null +++ b/core/pkg/service/middleware/interface.go @@ -0,0 +1,9 @@ +package middleware + +import ( + "net/http" +) + +type IMiddleware interface { + Handler(handler http.Handler) http.Handler +} diff --git a/core/pkg/service/middleware/http_metrics.go b/core/pkg/service/middleware/metrics/http_metrics.go similarity index 95% rename from core/pkg/service/middleware/http_metrics.go rename to core/pkg/service/middleware/metrics/http_metrics.go index daa295c2f..cc21390b2 100644 --- a/core/pkg/service/middleware/http_metrics.go +++ b/core/pkg/service/middleware/metrics/http_metrics.go @@ -1,4 +1,4 @@ -package middleware +package metrics import ( "bufio" @@ -21,6 +21,7 @@ type Config struct { Service string GroupedStatus bool DisableMeasureSize bool + HandlerID string } type Middleware struct { @@ -90,7 +91,7 @@ func (m Middleware) Measure(ctx context.Context, handlerID string, reporter Repo } // Handler returns an measuring standard http.Handler. -func Handler(handlerID string, m Middleware, h http.Handler) http.Handler { +func (m Middleware) Handler(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { wi := &responseWriterInterceptor{ statusCode: http.StatusOK, @@ -100,7 +101,7 @@ func Handler(handlerID string, m Middleware, h http.Handler) http.Handler { w: wi, r: r, } - m.Measure(r.Context(), handlerID, reporter, func() { + m.Measure(r.Context(), m.cfg.HandlerID, reporter, func() { h.ServeHTTP(wi, r) }) }) diff --git a/core/pkg/service/middleware/http_metrics_test.go b/core/pkg/service/middleware/metrics/http_metrics_test.go similarity index 98% rename from core/pkg/service/middleware/http_metrics_test.go rename to core/pkg/service/middleware/metrics/http_metrics_test.go index 4d03cf737..053d788f2 100644 --- a/core/pkg/service/middleware/http_metrics_test.go +++ b/core/pkg/service/middleware/metrics/http_metrics_test.go @@ -1,4 +1,4 @@ -package middleware +package metrics import ( "context" @@ -22,11 +22,12 @@ func TestMiddlewareExposesMetrics(t *testing.T) { MetricRecorder: otel.NewOTelRecorder(exp, svcName), Service: svcName, Logger: logger.NewLogger(l, true), + HandlerID: "id", }) handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, _ = w.Write([]byte("answer")) }) - svr := httptest.NewServer(Handler("id", m, handler)) + svr := httptest.NewServer(m.Handler(handler)) defer svr.Close() resp, err := http.Get(svr.URL) if err != nil { diff --git a/core/pkg/service/middleware/mock/interface.go b/core/pkg/service/middleware/mock/interface.go new file mode 100644 index 000000000..f4b098472 --- /dev/null +++ b/core/pkg/service/middleware/mock/interface.go @@ -0,0 +1,49 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: pkg/service/middleware/interface.go + +// Package middlewaremock is a generated GoMock package. +package middlewaremock + +import ( + http "net/http" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockIMiddleware is a mock of IMiddleware interface. +type MockIMiddleware struct { + ctrl *gomock.Controller + recorder *MockIMiddlewareMockRecorder +} + +// MockIMiddlewareMockRecorder is the mock recorder for MockIMiddleware. +type MockIMiddlewareMockRecorder struct { + mock *MockIMiddleware +} + +// NewMockIMiddleware creates a new mock instance. +func NewMockIMiddleware(ctrl *gomock.Controller) *MockIMiddleware { + mock := &MockIMiddleware{ctrl: ctrl} + mock.recorder = &MockIMiddlewareMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockIMiddleware) EXPECT() *MockIMiddlewareMockRecorder { + return m.recorder +} + +// Handler mocks base method. +func (m *MockIMiddleware) Handler(handler http.Handler) http.Handler { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Handler", handler) + ret0, _ := ret[0].(http.Handler) + return ret0 +} + +// Handler indicates an expected call of Handler. +func (mr *MockIMiddlewareMockRecorder) Handler(handler interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Handler", reflect.TypeOf((*MockIMiddleware)(nil).Handler), handler) +}