diff --git a/cmd/activator/handlers.go b/cmd/activator/handlers.go new file mode 100644 index 000000000000..2d5dba94604f --- /dev/null +++ b/cmd/activator/handlers.go @@ -0,0 +1,130 @@ +/* +Copyright 2018 The Knative Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package main + +import ( + "bytes" + "fmt" + "io" + "io/ioutil" + "net/http" + "net/http/httputil" + "net/url" + + "github.com/knative/serving/pkg/activator" + "github.com/knative/serving/pkg/controller" + "go.uber.org/zap" +) + +// activationHandler will proxy a request to the active endpoint for the specified revision, +// using the provided transport +type activationHandler struct { + activator activator.Activator + logger *zap.SugaredLogger + transport http.RoundTripper +} + +func newActivationHandler(a activator.Activator, rt http.RoundTripper, l *zap.SugaredLogger) http.Handler { + return &activationHandler{activator: a, transport: rt, logger: l} +} + +func (a *activationHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + namespace := r.Header.Get(controller.GetRevisionHeaderNamespace()) + name := r.Header.Get(controller.GetRevisionHeaderName()) + + endpoint, status, err := a.activator.ActiveEndpoint(namespace, name) + if err != nil { + msg := fmt.Sprintf("Error getting active endpoint: %v", err) + + a.logger.Errorf(msg) + http.Error(w, msg, int(status)) + + return + } + + target := &url.URL{ + Scheme: "http", + Host: fmt.Sprintf("%s:%d", endpoint.FQDN, endpoint.Port), + } + + proxy := httputil.NewSingleHostReverseProxy(target) + proxy.Transport = a.transport + + // TODO: Clear the host to avoid 404's. + // https://github.com/knative/serving/issues/964 + r.Host = "" + + proxy.ServeHTTP(w, r) +} + +// uploadHandler wraps the provided handler with a request body that supports +// re-reading and prevents uploads larger than `maxUploadBytes` +type uploadHandler struct { + http.Handler + MaxUploadBytes int64 +} + +func newUploadHandler(h http.Handler, maxUploadBytes int64) http.Handler { + return uploadHandler{ + Handler: h, + MaxUploadBytes: maxUploadBytes, + } +} + +func (h uploadHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if r.ContentLength > h.MaxUploadBytes { + w.WriteHeader(http.StatusRequestEntityTooLarge) + return + } + + // The request body cannot be read multiple times for retries. + // The workaround is to clone the request body into a byte reader + // so the body can be read multiple times. + r.Body = newRewinder(r.Body) + + h.Handler.ServeHTTP(w, r) +} + +// rewinder wraps a single-use `ReadCloser` into a `ReadCloser` that can be read multiple times +type rewinder struct { + rc io.ReadCloser + rs io.ReadSeeker +} + +func newRewinder(rc io.ReadCloser) io.ReadCloser { + return &rewinder{rc: rc} +} + +func (r *rewinder) Read(b []byte) (int, error) { + // On the first `Read()`, the contents of `rc` is read into a buffer `rs`. + // This buffer is used for all subsequent reads + if r.rs == nil { + buf, err := ioutil.ReadAll(r.rc) + if err != nil { + return 0, err + } + r.rc.Close() + + r.rs = bytes.NewReader(buf) + } + + return r.rs.Read(b) +} + +func (r *rewinder) Close() error { + // Rewind the buffer on `Close()` for the next call to `Read` + r.rs.Seek(0, io.SeekStart) + + return nil +} diff --git a/cmd/activator/handlers_test.go b/cmd/activator/handlers_test.go new file mode 100644 index 000000000000..2f59925d1395 --- /dev/null +++ b/cmd/activator/handlers_test.go @@ -0,0 +1,205 @@ +/* +Copyright 2018 The Knative Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package main + +import ( + "bytes" + "errors" + "fmt" + "io" + "io/ioutil" + "net/http" + "net/http/httptest" + "net/url" + "strconv" + "testing" + + "github.com/knative/serving/pkg/activator" + "github.com/knative/serving/pkg/controller" + "go.uber.org/zap" +) + +type fakeActivator struct { + endpoint activator.Endpoint + namespace string + name string +} + +func newFakeActivator(namespace string, name string, server *httptest.Server) fakeActivator { + url, _ := url.Parse(server.URL) + host := url.Hostname() + port, _ := strconv.Atoi(url.Port()) + + return fakeActivator{ + endpoint: activator.Endpoint{FQDN: host, Port: int32(port)}, + namespace: namespace, + name: name, + } +} + +func (fa fakeActivator) ActiveEndpoint(namespace, name string) (activator.Endpoint, activator.Status, error) { + if namespace == fa.namespace && name == fa.name { + return fa.endpoint, http.StatusOK, nil + } + + return activator.Endpoint{}, http.StatusNotFound, errors.New("not found!") +} + +func (fa fakeActivator) Shutdown() { +} + +func TestActivationHandler(t *testing.T) { + logger := zap.NewExample().Sugar() + + errMsg := func(msg string) string { + return fmt.Sprintf("Error getting active endpoint: %v\n", msg) + } + + server := httptest.NewServer( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, "everything good!") + }), + ) + defer server.Close() + + act := newFakeActivator("real-namespace", "real-name", server) + + examples := []struct { + label string + namespace string + name string + wantBody string + wantCode int + wantErr error + }{ + {"active endpoint", "real-namespace", "real-name", "everything good!", http.StatusOK, nil}, + {"no active endpoint", "fake-namespace", "fake-name", errMsg("not found!"), http.StatusNotFound, nil}, + {"request error", "real-namespace", "real-name", "", http.StatusBadGateway, errors.New("request error!")}, + } + + for _, e := range examples { + t.Run(e.label, func(t *testing.T) { + rt := roundTripperFunc(func(r *http.Request) (*http.Response, error) { + if r.Host != "" { + t.Errorf("Unexpected request host. Want %q, got %q", "", r.Host) + } + + if e.wantErr != nil { + return nil, e.wantErr + } + + return http.DefaultTransport.RoundTrip(r) + }) + + handler := newActivationHandler(act, rt, logger) + + resp := httptest.NewRecorder() + + req := httptest.NewRequest("POST", "http://example.com", nil) + req.Header.Set(controller.GetRevisionHeaderNamespace(), e.namespace) + req.Header.Set(controller.GetRevisionHeaderName(), e.name) + + handler.ServeHTTP(resp, req) + + if resp.Code != e.wantCode { + t.Errorf("Unexpected response status. Want %d, got %d", e.wantCode, resp.Code) + } + + gotBody, _ := ioutil.ReadAll(resp.Body) + if string(gotBody) != e.wantBody { + t.Errorf("Unexpected response body. Want %q, got %q", e.wantBody, gotBody) + } + }) + } +} + +func TestUploadHandler(t *testing.T) { + payload := "SAMPLE PAYLOAD" + + examples := []struct { + label string + maxUpload int + status int + }{ + {"under", len(payload) + 1, http.StatusOK}, + {"equal", len(payload), http.StatusOK}, + {"over", len(payload) - 1, http.StatusRequestEntityTooLarge}, + } + + for _, e := range examples { + t.Run(e.label, func(t *testing.T) { + baseHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + b1, _ := ioutil.ReadAll(r.Body) + r.Body.Close() + + b2, _ := ioutil.ReadAll(r.Body) + r.Body.Close() + + if string(b1) != payload || string(b2) != payload { + t.Errorf("Expected request body to be rereadable. Want %q, got %q and %q.", payload, b1, b2) + } + }) + handler := newUploadHandler(baseHandler, int64(e.maxUpload)) + + resp := httptest.NewRecorder() + req := httptest.NewRequest("POST", "http://example.com", bytes.NewBufferString(payload)) + + handler.ServeHTTP(resp, req) + + if resp.Code != e.status { + t.Errorf("Unexpected response status for payload %q. Want %d, got %d", payload, e.status, resp.Code) + } + }) + } +} + +type readCloser struct { + io.Reader + closed bool +} + +func (rc *readCloser) Close() error { + rc.closed = true + + return nil +} + +func TestRewinder(t *testing.T) { + str := "test string" + rc := &readCloser{bytes.NewBufferString(str), false} + rewinder := newRewinder(rc) + + b1, err := ioutil.ReadAll(rewinder) + if err != nil { + t.Errorf("Unexpected error reading b1: %v", err) + } + rewinder.Close() + + b2, err := ioutil.ReadAll(rewinder) + if err != nil { + t.Errorf("Unexpected error reading b2: %v", err) + } + rewinder.Close() + + if string(b1) != str { + t.Errorf("Unexpected str b1. Want %q, got %q", str, b1) + } + + if string(b2) != str { + t.Errorf("Unexpected str b2. Want %q, got %q", str, b2) + } + + if !rc.closed { + t.Errorf("Expected ReadCloser to be closed") + } +} diff --git a/cmd/activator/main.go b/cmd/activator/main.go index 71f78842a92a..dbdb19c50337 100644 --- a/cmd/activator/main.go +++ b/cmd/activator/main.go @@ -14,22 +14,14 @@ limitations under the License. package main import ( - "bytes" "flag" - "fmt" - "io" - "io/ioutil" "log" "net/http" - "net/http/httputil" - "net/url" "time" "github.com/knative/serving/pkg/activator" clientset "github.com/knative/serving/pkg/client/clientset/versioned" "github.com/knative/serving/pkg/configmap" - "github.com/knative/serving/pkg/controller" - h2cutil "github.com/knative/serving/pkg/h2c" "github.com/knative/serving/pkg/logging" "github.com/knative/serving/pkg/signals" "github.com/knative/serving/third_party/h2c" @@ -39,111 +31,11 @@ import ( ) const ( - maxUploadBytes = 32e6 // 32MB - same as app engine - maxRetry = 60 - retryInterval = 1 * time.Second + defaultMaxUploadBytes = 32e6 // 32MB - same as app engine + defaultMaxRetries = 60 + defaultRetryInterval = 1 * time.Second ) -type activationHandler struct { - act activator.Activator - logger *zap.SugaredLogger -} - -// retryRoundTripper retries on 503's for up to 60 seconds. The reason is there is -// a small delay for k8s to include the ready IP in service. -// https://github.com/knative/serving/issues/660#issuecomment-384062553 -type retryRoundTripper struct { - logger *zap.SugaredLogger -} - -func (rrt retryRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) { - var err error - var reqBody *bytes.Reader - - transport := http.DefaultTransport - - if r.ProtoMajor == 2 { - transport = h2cutil.NewTransport() - } - - if r.Body != nil { - reqBytes, err := ioutil.ReadAll(r.Body) - - if err != nil { - rrt.logger.Errorf("Error reading request body: %s", err) - return nil, err - } - - reqBody = bytes.NewReader(reqBytes) - r.Body = ioutil.NopCloser(reqBody) - } - - resp, err := transport.RoundTrip(r) - // TODO: Activator should retry with backoff. - // https://github.com/knative/serving/issues/1229 - i := 1 - for ; i < maxRetry; i++ { - if err == nil && resp != nil && resp.StatusCode != 503 { - break - } - - if err != nil { - rrt.logger.Errorf("Error making a request: %s", err) - } - - if resp != nil { - resp.Body.Close() - } - - time.Sleep(retryInterval) - - // The request body cannot be read multiple times for retries. - // The workaround is to clone the request body into a byte reader - // so the body can be read multiple times. - if r.Body != nil { - reqBody.Seek(0, io.SeekStart) - } - - resp, err = transport.RoundTrip(r) - } - // TODO: add metrics for number of tries and the response code. - if resp != nil { - rrt.logger.Infof("It took %d tries to get response code %d", i, resp.StatusCode) - } - return resp, err -} - -func (a *activationHandler) handler(w http.ResponseWriter, r *http.Request) { - if r.ContentLength > maxUploadBytes { - w.WriteHeader(http.StatusRequestEntityTooLarge) - return - } - - namespace := r.Header.Get(controller.GetRevisionHeaderNamespace()) - name := r.Header.Get(controller.GetRevisionHeaderName()) - endpoint, status, err := a.act.ActiveEndpoint(namespace, name) - if err != nil { - msg := fmt.Sprintf("Error getting active endpoint: %v", err) - a.logger.Errorf(msg) - http.Error(w, msg, int(status)) - return - } - target := &url.URL{ - Scheme: "http", - Host: fmt.Sprintf("%s:%d", endpoint.FQDN, endpoint.Port), - } - proxy := httputil.NewSingleHostReverseProxy(target) - proxy.Transport = retryRoundTripper{ - logger: a.logger, - } - - // TODO: Clear the host to avoid 404's. - // https://github.com/knative/serving/issues/964 - r.Host = "" - - proxy.ServeHTTP(w, r) -} - func main() { flag.Parse() cm, err := configmap.Load("/etc/config-logging") @@ -174,7 +66,16 @@ func main() { a := activator.NewRevisionActivator(kubeClient, servingClient, logger) a = activator.NewDedupingActivator(a) - ah := &activationHandler{a, logger} + + // Retry on 503's for up to 60 seconds. The reason is there is + // a small delay for k8s to include the ready IP in service. + // https://github.com/knative/serving/issues/660#issuecomment-384062553 + rt := baseTransport + rt = newStatusFilterRoundTripper(rt, http.StatusServiceUnavailable) + rt = newRetryRoundTripper(rt, logger, defaultMaxRetries, defaultRetryInterval) + + ah := newActivationHandler(a, rt, logger) + ah = newUploadHandler(ah, defaultMaxUploadBytes) // set up signals so we handle the first shutdown signal gracefully stopCh := signals.SetupSignalHandler() @@ -183,6 +84,6 @@ func main() { a.Shutdown() }() - http.HandleFunc("/", ah.handler) + http.Handle("/", ah) h2c.ListenAndServe(":8080", nil) } diff --git a/cmd/activator/round_trippers.go b/cmd/activator/round_trippers.go new file mode 100644 index 000000000000..9d7175d0ae74 --- /dev/null +++ b/cmd/activator/round_trippers.go @@ -0,0 +1,103 @@ +/* +Copyright 2018 The Knative Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package main + +import ( + "fmt" + "net/http" + "time" + + h2cutil "github.com/knative/serving/pkg/h2c" + "go.uber.org/zap" +) + +type roundTripperFunc func(*http.Request) (*http.Response, error) + +func (rt roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) { + return rt(r) +} + +// baseTransport will the appropriate transport for the request's http protocol version +var baseTransport http.RoundTripper = roundTripperFunc(func(r *http.Request) (*http.Response, error) { + var transport http.RoundTripper = http.DefaultTransport + if r.ProtoMajor == 2 { + transport = h2cutil.NewTransport() + } + + return transport.RoundTrip(r) +}) + +// statusFilterRoundTripper returns an error if the response contains one of the filtered statuses. +type statusFilterRoundTripper struct { + transport http.RoundTripper + statuses []int +} + +func newStatusFilterRoundTripper(rt http.RoundTripper, statuses ...int) http.RoundTripper { + return statusFilterRoundTripper{rt, statuses} +} + +func (rt statusFilterRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) { + resp, err := rt.transport.RoundTrip(r) + if err != nil { + return nil, err + } + + for _, status := range rt.statuses { + if resp.StatusCode == status { + resp.Body.Close() + + return nil, fmt.Errorf("Filtering %d", status) + } + } + + return resp, nil +} + +// retryRoundTripper retries a request on error up to `maxRetries` times, +// waiting `interval` milliseconds between retries +type retryRoundTripper struct { + logger *zap.SugaredLogger + maxRetries int + interval time.Duration + transport http.RoundTripper +} + +func newRetryRoundTripper(rt http.RoundTripper, l *zap.SugaredLogger, mr int, i time.Duration) http.RoundTripper { + return retryRoundTripper{logger: l, maxRetries: mr, interval: i, transport: rt} +} + +func (rrt retryRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) { + resp, err := rrt.transport.RoundTrip(r) + // TODO: Activator should retry with backoff. + // https://github.com/knative/serving/issues/1229 + i := 1 + for ; i < rrt.maxRetries; i++ { + if err == nil { + break + } + + rrt.logger.Errorf("Error making a request: %s", err) + + time.Sleep(rrt.interval) + + resp, err = rrt.transport.RoundTrip(r) + } + + // TODO: add metrics for number of tries and the response code. + if resp != nil { + rrt.logger.Infof("It took %d tries to get response code %d", i, resp.StatusCode) + } + return resp, err +} diff --git a/cmd/activator/round_trippers_test.go b/cmd/activator/round_trippers_test.go new file mode 100644 index 000000000000..ca01ae73f56a --- /dev/null +++ b/cmd/activator/round_trippers_test.go @@ -0,0 +1,152 @@ +/* +Copyright 2018 The Knative Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package main + +import ( + "errors" + "io/ioutil" + "net/http" + "net/http/httptest" + "testing" + "time" + + "go.uber.org/zap" +) + +func TestRetryRoundTripper(t *testing.T) { + wantBody := "all good!" + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(wantBody)) + })) + l := zap.NewExample().Sugar() + maxRetries := 3 + interval := 10 * time.Millisecond + + examples := []struct { + label string + retries int + wantErr bool + }{ + {"success", maxRetries, false}, + {"failure", maxRetries + 1, true}, + } + + for _, e := range examples { + t.Run(e.label, func(t *testing.T) { + var last time.Time + + gotRetries := 0 + rt := roundTripperFunc(func(r *http.Request) (*http.Response, error) { + gotRetries += 1 + + now := time.Now() + duration := now.Sub(last) + if duration < interval { + t.Errorf("Unexpected retry interval. Want %v, got %v", interval, duration) + } + last = now + + if gotRetries < e.retries { + + if r.Body != nil { + ioutil.ReadAll(r.Body) + r.Body.Close() + } + + return nil, errors.New("some error!") + } + + return http.DefaultTransport.RoundTrip(r) + }) + + rrt := newRetryRoundTripper(rt, l, maxRetries, interval) + req := httptest.NewRequest("", ts.URL, nil) + + resp, err := rrt.RoundTrip(req) + + wantRetries := maxRetries + if e.retries < wantRetries { + wantRetries = e.retries + } + + if gotRetries != wantRetries { + t.Errorf("Unexpected number of retries. Want %d, got %d", wantRetries, gotRetries) + } + + if e.wantErr { + if err == nil { + t.Errorf("Expected error") + } + } else { + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + gotBody, _ := ioutil.ReadAll(resp.Body) + if string(gotBody) != wantBody { + t.Errorf("Unexpected response. Want %q, got %q", wantBody, gotBody) + } + } + }) + } +} + +func TestStatusFilterRoundTripper(t *testing.T) { + testServer := func(status int) *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(status) + })) + } + + goodRT := http.DefaultTransport + errorRT := roundTripperFunc(func(r *http.Request) (*http.Response, error) { + return nil, errors.New("some error") + }) + + filtered := []int{501, 502} + + examples := []struct { + label string + transport http.RoundTripper + status int + err error + }{ + {"filtered status", goodRT, 502, errors.New("Filtering 502")}, + {"unfiltered status", goodRT, 503, nil}, + {"transport error", errorRT, 200, errors.New("some error")}, + } + + for _, e := range examples { + t.Run(e.label, func(t *testing.T) { + ts := testServer(e.status) + defer ts.Close() + + rt := newStatusFilterRoundTripper(e.transport, filtered...) + + req := httptest.NewRequest("", ts.URL, nil) + resp, err := rt.RoundTrip(req) + + if e.err != nil { + if err.Error() != e.err.Error() { + t.Errorf("Unexpected error. Want %v, got %v", e.err, err) + } + } else { + if err != nil { + t.Errorf("Unexpected error %v", err) + } + if resp.StatusCode != e.status { + t.Errorf("Unexpected response status. Want %d, got %d", e.status, resp.StatusCode) + } + } + }) + } +}