From e47bd0acdf3bce706ddbed1ab8022894391a3cc2 Mon Sep 17 00:00:00 2001 From: Adam Harwayne Date: Wed, 26 Sep 2018 13:12:48 -0700 Subject: [PATCH 1/9] Unit tests for passthrough headers. --- pkg/buses/bus.go | 3 +- pkg/buses/message_dispatcher.go | 29 +++++--- pkg/buses/message_dispatcher_test.go | 100 +++++++++++++++++++++++++++ 3 files changed, 121 insertions(+), 11 deletions(-) create mode 100644 pkg/buses/message_dispatcher_test.go diff --git a/pkg/buses/bus.go b/pkg/buses/bus.go index ccc2596eec8..0ee44279a26 100644 --- a/pkg/buses/bus.go +++ b/pkg/buses/bus.go @@ -193,7 +193,6 @@ func (b *bus) dispatchMessage(subscription *channelsv1alpha1.Subscription, messa subscriber := subscription.Spec.Subscriber defaults := DispatchDefaults{ Namespace: subscription.Namespace, - ReplyTo: subscription.Spec.ReplyTo, } - return b.dispatcher.DispatchMessage(message, subscriber, defaults) + return b.dispatcher.DispatchMessage(message, subscriber, subscription.Spec.ReplyTo, defaults) } diff --git a/pkg/buses/message_dispatcher.go b/pkg/buses/message_dispatcher.go index ffb055dd2bb..29d62c2ef73 100644 --- a/pkg/buses/message_dispatcher.go +++ b/pkg/buses/message_dispatcher.go @@ -29,9 +29,14 @@ import ( const correlationIDHeaderName = "Knative-Correlation-Id" +// httpDoer is an interface for making HTTP requests. +type httpDoer interface { + Do(*http.Request) (*http.Response, error) +} + // MessageDispatcher dispatches messages to a destination over HTTP. type MessageDispatcher struct { - httpClient *http.Client + httpClient httpDoer forwardHeaders map[string]bool forwardPrefixes []string supportedSchemes map[string]bool @@ -42,7 +47,6 @@ type MessageDispatcher struct { // DispatchDefaults provides default parameter values used when dispatching a message. type DispatchDefaults struct { Namespace string - ReplyTo string } // NewMessageDispatcher creates a new message dispatcher that can dispatch @@ -66,14 +70,21 @@ func NewMessageDispatcher(logger *zap.SugaredLogger) *MessageDispatcher { // The destination and replyTo are DNS names. For names with a single label, // the default namespace is used to expand it into a fully qualified name // within the cluster. -func (d *MessageDispatcher) DispatchMessage(message *Message, destination string, defaults DispatchDefaults) error { - destinationURL := d.resolveURL(destination, defaults.Namespace) - reply, err := d.executeRequest(destinationURL, message) - if err != nil { - return fmt.Errorf("Unable to complete request %v", err) +func (d *MessageDispatcher) DispatchMessage(message *Message, destination, replyTo string, defaults DispatchDefaults) error { + var err error + // Default to replying with the original message. If there is a destination, then replace it + // with the response from the call to the destination instead. + reply := message + if destination != "" { + destinationURL := d.resolveURL(destination, defaults.Namespace) + reply, err = d.executeRequest(destinationURL, message) + if err != nil { + return fmt.Errorf("Unable to complete request %v", err) + } } - if defaults.ReplyTo != "" && reply != nil { - replyToURL := d.resolveURL(defaults.ReplyTo, defaults.Namespace) + + if replyTo != "" && reply != nil { + replyToURL := d.resolveURL(replyTo, defaults.Namespace) _, err = d.executeRequest(replyToURL, reply) if err != nil { return fmt.Errorf("Failed to forward reply %v", err) diff --git a/pkg/buses/message_dispatcher_test.go b/pkg/buses/message_dispatcher_test.go new file mode 100644 index 00000000000..110a90e5f1d --- /dev/null +++ b/pkg/buses/message_dispatcher_test.go @@ -0,0 +1,100 @@ +/* +Copyright 2018 The Knative Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package buses + +import ( + "bytes" + "github.com/google/go-cmp/cmp" + "go.uber.org/zap" + "io/ioutil" + "net/http" + "strings" + "testing" +) + +func TestDispatchMessage(t *testing.T) { + testCases := map[string]struct { + message *Message + expectedHeaders http.Header + }{ + "filter unwanted headers": { + message: &Message{ + Headers: map[string]string{ + "do-not-forward": "header", + }, + Payload: []byte("{}"), + }, + expectedHeaders: map[string][]string{}, + }, + "multiple forward prefixes": { + message: &Message{ + Headers: map[string]string{ + "x-request-id": "id123", + "knative-1": "knative-1-value", + "knative-2": "knative-2-value", + "ce-abc": "ce-abc-value", + }, + Payload: []byte("{}"), + }, + expectedHeaders: map[string][]string{ + "x-request-id": {"id123"}, + "knative-1": {"knative-1-value"}, + "knative-2": {"knative-2-value"}, + "ce-abc": {"ce-abc-value"}, + }, + }, + } + for n, tc := range testCases { + t.Run(n, func(t *testing.T) { + md := NewMessageDispatcher(zap.NewNop().Sugar()) + fc := &fakeHttpClient{} + md.httpClient = fc + err := md.DispatchMessage(tc.message, "destination", "", DispatchDefaults{}) + if err != nil { + t.Errorf("Unexpected error dispatching message: %v", err) + } + if diff := headerDiff(tc.expectedHeaders, fc.requestHeaders); diff != "" { + t.Errorf("Unexpected request headers (-wanted, +got): %v", diff) + } + }) + } +} + +type fakeHttpClient struct { + requestHeaders http.Header +} + +var _ httpDoer = &fakeHttpClient{} + +func (f *fakeHttpClient) Do(r *http.Request) (*http.Response, error) { + f.requestHeaders = r.Header + return &http.Response{ + StatusCode: http.StatusAccepted, + Body: ioutil.NopCloser(bytes.NewBufferString("body")), + }, nil +} + +func headerDiff(expected http.Header, actual http.Header) string { + // HTTP header names are case-insensitive, so normalize them to lower case for comparison. + for _, headers := range []http.Header{expected, actual} { + for n, v := range headers { + delete(headers, n) + headers[strings.ToLower(n)] = v + } + } + return cmp.Diff(expected, actual) +} From 87dda5d54b29fc814e52022c9704edfd8ed9512e Mon Sep 17 00:00:00 2001 From: Adam Harwayne Date: Wed, 26 Sep 2018 14:59:40 -0700 Subject: [PATCH 2/9] Unit tests. --- pkg/buses/message_dispatcher.go | 33 +-- pkg/buses/message_dispatcher_test.go | 292 ++++++++++++++++++++++++--- 2 files changed, 285 insertions(+), 40 deletions(-) diff --git a/pkg/buses/message_dispatcher.go b/pkg/buses/message_dispatcher.go index 29d62c2ef73..c9b3afc7896 100644 --- a/pkg/buses/message_dispatcher.go +++ b/pkg/buses/message_dispatcher.go @@ -104,23 +104,24 @@ func (d *MessageDispatcher) executeRequest(url *url.URL, message *Message) (*Mes if err != nil { return nil, err } - if res != nil { - if res.StatusCode < 200 || res.StatusCode >= 300 { - // reject non-successful (2xx) responses - return nil, fmt.Errorf("unexpected HTTP response, expected 2xx, got %d", res.StatusCode) - } - headers := d.fromHTTPHeaders(res.Header) - // TODO: add configurable whitelisting of propagated headers/prefixes (configmap?) - if correlationID, ok := message.Headers[correlationIDHeaderName]; ok { - headers[correlationIDHeaderName] = correlationID - } - payload, err := ioutil.ReadAll(res.Body) - if err != nil { - return nil, fmt.Errorf("Unable to read response %v", err) - } - return &Message{headers, payload}, nil + if res.StatusCode < 200 || res.StatusCode >= 300 { + // reject non-successful (2xx) responses + return nil, fmt.Errorf("unexpected HTTP response, expected 2xx, got %d", res.StatusCode) + } + headers := d.fromHTTPHeaders(res.Header) + // TODO: add configurable whitelisting of propagated headers/prefixes (configmap?) + if correlationID, ok := message.Headers[correlationIDHeaderName]; ok { + headers[correlationIDHeaderName] = correlationID + } + payload, err := ioutil.ReadAll(res.Body) + if err != nil { + return nil, fmt.Errorf("Unable to read response %v", err) + } + if len(payload) == 0 { + // The response body is empty, the event has 'finished'. + return nil, nil } - return nil, nil + return &Message{headers, payload}, nil } // toHTTPHeaders converts message headers to HTTP headers. diff --git a/pkg/buses/message_dispatcher_test.go b/pkg/buses/message_dispatcher_test.go index 110a90e5f1d..7ee8aa81a65 100644 --- a/pkg/buses/message_dispatcher_test.go +++ b/pkg/buses/message_dispatcher_test.go @@ -28,66 +28,310 @@ import ( func TestDispatchMessage(t *testing.T) { testCases := map[string]struct { - message *Message - expectedHeaders http.Header + destination string + replyTo string + message *Message + fakeResponse *http.Response + expectedErr bool + expectedDestRequest *requestValidation + expectedReplyRequest *requestValidation }{ - "filter unwanted headers": { + "destination - only": { + destination: "test-destination-svc.test-namespace.svc.cluster.local", message: &Message{ Headers: map[string]string{ + // do-not-forward should not get forwarded. "do-not-forward": "header", + "x-request-id": "id123", + "knative-1": "knative-1-value", + "knative-2": "knative-2-value", + "ce-abc": "ce-abc-value", }, - Payload: []byte("{}"), + Payload: []byte("destination"), }, - expectedHeaders: map[string][]string{}, + expectedDestRequest: &requestValidation{ + url: "http://test-destination-svc.test-namespace.svc.cluster.local/", + headers: map[string][]string{ + "x-request-id": {"id123"}, + "knative-1": {"knative-1-value"}, + "knative-2": {"knative-2-value"}, + "ce-abc": {"ce-abc-value"}, + }, + body: "destination", + }, + }, + "destination - only -- error": { + destination: "test-destination-svc.test-namespace.svc.cluster.local", + message: &Message{ + Headers: map[string]string{ + // do-not-forward should not get forwarded. + "do-not-forward": "header", + "x-request-id": "id123", + "knative-1": "knative-1-value", + "knative-2": "knative-2-value", + "ce-abc": "ce-abc-value", + }, + Payload: []byte("destination"), + }, + expectedDestRequest: &requestValidation{ + url: "http://test-destination-svc.test-namespace.svc.cluster.local/", + headers: map[string][]string{ + "x-request-id": {"id123"}, + "knative-1": {"knative-1-value"}, + "knative-2": {"knative-2-value"}, + "ce-abc": {"ce-abc-value"}, + }, + body: "destination", + }, + fakeResponse: &http.Response{ + StatusCode: http.StatusNotFound, + Body: ioutil.NopCloser(bytes.NewBufferString("destination-response")), + }, + expectedErr: true, + }, + "reply - only": { + replyTo: "test-reply-svc.test-namespace.svc.cluster.local", + message: &Message{ + Headers: map[string]string{ + // do-not-forward should not get forwarded. + "do-not-forward": "header", + "x-request-id": "id123", + "knative-1": "knative-1-value", + "knative-2": "knative-2-value", + "ce-abc": "ce-abc-value", + }, + Payload: []byte("replyTo"), + }, + expectedReplyRequest: &requestValidation{ + url: "http://test-reply-svc.test-namespace.svc.cluster.local/", + headers: map[string][]string{ + "x-request-id": {"id123"}, + "knative-1": {"knative-1-value"}, + "knative-2": {"knative-2-value"}, + "ce-abc": {"ce-abc-value"}, + }, + body: "replyTo", + }, + }, + "reply - only -- error": { + replyTo: "test-reply-svc.test-namespace.svc.cluster.local", + message: &Message{ + Headers: map[string]string{ + // do-not-forward should not get forwarded. + "do-not-forward": "header", + "x-request-id": "id123", + "knative-1": "knative-1-value", + "knative-2": "knative-2-value", + "ce-abc": "ce-abc-value", + }, + Payload: []byte("replyTo"), + }, + expectedReplyRequest: &requestValidation{ + url: "http://test-reply-svc.test-namespace.svc.cluster.local/", + headers: map[string][]string{ + "x-request-id": {"id123"}, + "knative-1": {"knative-1-value"}, + "knative-2": {"knative-2-value"}, + "ce-abc": {"ce-abc-value"}, + }, + body: "replyTo", + }, + fakeResponse: &http.Response{ + StatusCode: http.StatusNotFound, + Body: ioutil.NopCloser(bytes.NewBufferString("destination-response")), + }, + expectedErr: true, + }, + "destination and reply - dest returns bad status code": { + destination: "test-destination-svc.test-namespace.svc.cluster.local", + replyTo: "test-reply-svc.test-namespace.svc.cluster.local", + message: &Message{ + Headers: map[string]string{ + // do-not-forward should not get forwarded. + "do-not-forward": "header", + "x-request-id": "id123", + "knative-1": "knative-1-value", + "knative-2": "knative-2-value", + "ce-abc": "ce-abc-value", + }, + Payload: []byte("destination"), + }, + expectedDestRequest: &requestValidation{ + url: "http://test-destination-svc.test-namespace.svc.cluster.local/", + headers: map[string][]string{ + "x-request-id": {"id123"}, + "knative-1": {"knative-1-value"}, + "knative-2": {"knative-2-value"}, + "ce-abc": {"ce-abc-value"}, + }, + body: "destination", + }, + fakeResponse: &http.Response{ + StatusCode: http.StatusInternalServerError, + Body: ioutil.NopCloser(bytes.NewBufferString("destination-response")), + }, + expectedErr: true, }, - "multiple forward prefixes": { + "destination and reply - dest returns empty body": { + destination: "test-destination-svc.test-namespace.svc.cluster.local", + replyTo: "test-reply-svc.test-namespace.svc.cluster.local", message: &Message{ Headers: map[string]string{ - "x-request-id": "id123", - "knative-1": "knative-1-value", - "knative-2": "knative-2-value", - "ce-abc": "ce-abc-value", + // do-not-forward should not get forwarded. + "do-not-forward": "header", + "x-request-id": "id123", + "knative-1": "knative-1-value", + "knative-2": "knative-2-value", + "ce-abc": "ce-abc-value", + }, + Payload: []byte("destination"), + }, + expectedDestRequest: &requestValidation{ + url: "http://test-destination-svc.test-namespace.svc.cluster.local/", + headers: map[string][]string{ + "x-request-id": {"id123"}, + "knative-1": {"knative-1-value"}, + "knative-2": {"knative-2-value"}, + "ce-abc": {"ce-abc-value"}, + }, + body: "destination", + }, + fakeResponse: &http.Response{ + StatusCode: http.StatusAccepted, + Header: map[string][]string{ + "do-not-passthrough": {"no"}, + "x-request-id": {"altered-id"}, + "knative-1": {"new-knative-1-value"}, + "ce-abc": {"new-ce-abc-value"}, + }, + Body: ioutil.NopCloser(bytes.NewBufferString("")), + }, + }, + "destination and reply": { + destination: "test-destination-svc.test-namespace.svc.cluster.local", + replyTo: "test-reply-svc.test-namespace.svc.cluster.local", + message: &Message{ + Headers: map[string]string{ + // do-not-forward should not get forwarded. + "do-not-forward": "header", + "x-request-id": "id123", + "knative-1": "knative-1-value", + "knative-2": "knative-2-value", + "ce-abc": "ce-abc-value", + }, + Payload: []byte("destination"), + }, + expectedDestRequest: &requestValidation{ + url: "http://test-destination-svc.test-namespace.svc.cluster.local/", + headers: map[string][]string{ + "x-request-id": {"id123"}, + "knative-1": {"knative-1-value"}, + "knative-2": {"knative-2-value"}, + "ce-abc": {"ce-abc-value"}, + }, + body: "destination", + }, + fakeResponse: &http.Response{ + StatusCode: http.StatusAccepted, + Header: map[string][]string{ + "do-not-passthrough": {"no"}, + "x-request-id": {"altered-id"}, + "knative-1": {"new-knative-1-value"}, + "ce-abc": {"new-ce-abc-value"}, }, - Payload: []byte("{}"), + Body: ioutil.NopCloser(bytes.NewBufferString("destination-response")), }, - expectedHeaders: map[string][]string{ - "x-request-id": {"id123"}, - "knative-1": {"knative-1-value"}, - "knative-2": {"knative-2-value"}, - "ce-abc": {"ce-abc-value"}, + expectedReplyRequest: &requestValidation{ + url: "http://test-reply-svc.test-namespace.svc.cluster.local/", + headers: map[string][]string{ + "x-request-id": {"altered-id"}, + "knative-1": {"new-knative-1-value"}, + "ce-abc": {"new-ce-abc-value"}, + }, + body: "destination-response", }, }, } for n, tc := range testCases { t.Run(n, func(t *testing.T) { md := NewMessageDispatcher(zap.NewNop().Sugar()) - fc := &fakeHttpClient{} + fc := &fakeHttpClient{ + response: tc.fakeResponse, + } md.httpClient = fc - err := md.DispatchMessage(tc.message, "destination", "", DispatchDefaults{}) - if err != nil { - t.Errorf("Unexpected error dispatching message: %v", err) + err := md.DispatchMessage(tc.message, tc.destination, tc.replyTo, DispatchDefaults{}) + if tc.expectedErr != (err != nil) { + t.Errorf("Unexpected error from DispatchRequest. Expected %v. Actual: %v", tc.expectedErr, err) + } + if tc.expectedDestRequest != nil { + rv := fc.popRequest(t) + assertEquality(t, *tc.expectedDestRequest, rv) + } + if tc.expectedReplyRequest != nil { + rv := fc.popRequest(t) + assertEquality(t, *tc.expectedReplyRequest, rv) } - if diff := headerDiff(tc.expectedHeaders, fc.requestHeaders); diff != "" { - t.Errorf("Unexpected request headers (-wanted, +got): %v", diff) + if len(fc.requests) != 0 { + t.Errorf("Unexpected requests: %+v", fc.requests) } }) } } +type requestValidation struct { + url string + headers http.Header + body string +} + type fakeHttpClient struct { - requestHeaders http.Header + t *testing.T + response *http.Response + requests []requestValidation } var _ httpDoer = &fakeHttpClient{} func (f *fakeHttpClient) Do(r *http.Request) (*http.Response, error) { - f.requestHeaders = r.Header + body, err := ioutil.ReadAll(r.Body) + if err != nil { + f.t.Error("Failed to read the request body") + } + f.requests = append(f.requests, requestValidation{ + url: r.URL.String(), + headers: r.Header, + body: string(body), + }) + if f.response != nil { + return f.response, nil + } return &http.Response{ StatusCode: http.StatusAccepted, Body: ioutil.NopCloser(bytes.NewBufferString("body")), }, nil } +func (f *fakeHttpClient) popRequest(t *testing.T) requestValidation { + if len(f.requests) == 0 { + t.Error("Unable to pop request") + } + rv := f.requests[0] + f.requests = f.requests[1:] + return rv +} + +func assertEquality(t *testing.T, expected, actual requestValidation) { + if diff := cmp.Diff(expected.url, actual.url); diff != "" { + t.Errorf("Unexpected URL (-wanted, +got): %v", diff) + } + if diff := headerDiff(expected.headers, actual.headers); diff != "" { + t.Errorf("Unexpected request headers (-wanted, +got): %v", diff) + } + if diff := cmp.Diff(expected.body, actual.body); diff != "" { + t.Errorf("Unexpected body (-want, +got): %v", diff) + } +} + func headerDiff(expected http.Header, actual http.Header) string { // HTTP header names are case-insensitive, so normalize them to lower case for comparison. for _, headers := range []http.Header{expected, actual} { From 1c4ee9a76a9244ff3e4ca5c98e5b93664a4ebbe2 Mon Sep 17 00:00:00 2001 From: Adam Harwayne Date: Thu, 27 Sep 2018 11:12:19 -0700 Subject: [PATCH 3/9] Respond to PR comments. --- pkg/buses/message_dispatcher_test.go | 77 +++++++++++++--------------- 1 file changed, 36 insertions(+), 41 deletions(-) diff --git a/pkg/buses/message_dispatcher_test.go b/pkg/buses/message_dispatcher_test.go index 7ee8aa81a65..f627d211842 100644 --- a/pkg/buses/message_dispatcher_test.go +++ b/pkg/buses/message_dispatcher_test.go @@ -50,14 +50,14 @@ func TestDispatchMessage(t *testing.T) { Payload: []byte("destination"), }, expectedDestRequest: &requestValidation{ - url: "http://test-destination-svc.test-namespace.svc.cluster.local/", - headers: map[string][]string{ + Url: "http://test-destination-svc.test-namespace.svc.cluster.local/", + Headers: map[string][]string{ "x-request-id": {"id123"}, "knative-1": {"knative-1-value"}, "knative-2": {"knative-2-value"}, "ce-abc": {"ce-abc-value"}, }, - body: "destination", + Body: "destination", }, }, "destination - only -- error": { @@ -74,14 +74,14 @@ func TestDispatchMessage(t *testing.T) { Payload: []byte("destination"), }, expectedDestRequest: &requestValidation{ - url: "http://test-destination-svc.test-namespace.svc.cluster.local/", - headers: map[string][]string{ + Url: "http://test-destination-svc.test-namespace.svc.cluster.local/", + Headers: map[string][]string{ "x-request-id": {"id123"}, "knative-1": {"knative-1-value"}, "knative-2": {"knative-2-value"}, "ce-abc": {"ce-abc-value"}, }, - body: "destination", + Body: "destination", }, fakeResponse: &http.Response{ StatusCode: http.StatusNotFound, @@ -103,14 +103,14 @@ func TestDispatchMessage(t *testing.T) { Payload: []byte("replyTo"), }, expectedReplyRequest: &requestValidation{ - url: "http://test-reply-svc.test-namespace.svc.cluster.local/", - headers: map[string][]string{ + Url: "http://test-reply-svc.test-namespace.svc.cluster.local/", + Headers: map[string][]string{ "x-request-id": {"id123"}, "knative-1": {"knative-1-value"}, "knative-2": {"knative-2-value"}, "ce-abc": {"ce-abc-value"}, }, - body: "replyTo", + Body: "replyTo", }, }, "reply - only -- error": { @@ -127,14 +127,14 @@ func TestDispatchMessage(t *testing.T) { Payload: []byte("replyTo"), }, expectedReplyRequest: &requestValidation{ - url: "http://test-reply-svc.test-namespace.svc.cluster.local/", - headers: map[string][]string{ + Url: "http://test-reply-svc.test-namespace.svc.cluster.local/", + Headers: map[string][]string{ "x-request-id": {"id123"}, "knative-1": {"knative-1-value"}, "knative-2": {"knative-2-value"}, "ce-abc": {"ce-abc-value"}, }, - body: "replyTo", + Body: "replyTo", }, fakeResponse: &http.Response{ StatusCode: http.StatusNotFound, @@ -157,14 +157,14 @@ func TestDispatchMessage(t *testing.T) { Payload: []byte("destination"), }, expectedDestRequest: &requestValidation{ - url: "http://test-destination-svc.test-namespace.svc.cluster.local/", - headers: map[string][]string{ + Url: "http://test-destination-svc.test-namespace.svc.cluster.local/", + Headers: map[string][]string{ "x-request-id": {"id123"}, "knative-1": {"knative-1-value"}, "knative-2": {"knative-2-value"}, "ce-abc": {"ce-abc-value"}, }, - body: "destination", + Body: "destination", }, fakeResponse: &http.Response{ StatusCode: http.StatusInternalServerError, @@ -187,14 +187,14 @@ func TestDispatchMessage(t *testing.T) { Payload: []byte("destination"), }, expectedDestRequest: &requestValidation{ - url: "http://test-destination-svc.test-namespace.svc.cluster.local/", - headers: map[string][]string{ + Url: "http://test-destination-svc.test-namespace.svc.cluster.local/", + Headers: map[string][]string{ "x-request-id": {"id123"}, "knative-1": {"knative-1-value"}, "knative-2": {"knative-2-value"}, "ce-abc": {"ce-abc-value"}, }, - body: "destination", + Body: "destination", }, fakeResponse: &http.Response{ StatusCode: http.StatusAccepted, @@ -222,14 +222,14 @@ func TestDispatchMessage(t *testing.T) { Payload: []byte("destination"), }, expectedDestRequest: &requestValidation{ - url: "http://test-destination-svc.test-namespace.svc.cluster.local/", - headers: map[string][]string{ + Url: "http://test-destination-svc.test-namespace.svc.cluster.local/", + Headers: map[string][]string{ "x-request-id": {"id123"}, "knative-1": {"knative-1-value"}, "knative-2": {"knative-2-value"}, "ce-abc": {"ce-abc-value"}, }, - body: "destination", + Body: "destination", }, fakeResponse: &http.Response{ StatusCode: http.StatusAccepted, @@ -242,13 +242,13 @@ func TestDispatchMessage(t *testing.T) { Body: ioutil.NopCloser(bytes.NewBufferString("destination-response")), }, expectedReplyRequest: &requestValidation{ - url: "http://test-reply-svc.test-namespace.svc.cluster.local/", - headers: map[string][]string{ + Url: "http://test-reply-svc.test-namespace.svc.cluster.local/", + Headers: map[string][]string{ "x-request-id": {"altered-id"}, "knative-1": {"new-knative-1-value"}, "ce-abc": {"new-ce-abc-value"}, }, - body: "destination-response", + Body: "destination-response", }, }, } @@ -279,9 +279,9 @@ func TestDispatchMessage(t *testing.T) { } type requestValidation struct { - url string - headers http.Header - body string + Url string + Headers http.Header + Body string } type fakeHttpClient struct { @@ -298,9 +298,9 @@ func (f *fakeHttpClient) Do(r *http.Request) (*http.Response, error) { f.t.Error("Failed to read the request body") } f.requests = append(f.requests, requestValidation{ - url: r.URL.String(), - headers: r.Header, - body: string(body), + Url: r.URL.String(), + Headers: r.Header, + Body: string(body), }) if f.response != nil { return f.response, nil @@ -321,24 +321,19 @@ func (f *fakeHttpClient) popRequest(t *testing.T) requestValidation { } func assertEquality(t *testing.T, expected, actual requestValidation) { - if diff := cmp.Diff(expected.url, actual.url); diff != "" { - t.Errorf("Unexpected URL (-wanted, +got): %v", diff) - } - if diff := headerDiff(expected.headers, actual.headers); diff != "" { - t.Errorf("Unexpected request headers (-wanted, +got): %v", diff) - } - if diff := cmp.Diff(expected.body, actual.body); diff != "" { - t.Errorf("Unexpected body (-want, +got): %v", diff) + canonicalizeHeaders(expected, actual) + if diff := cmp.Diff(expected, actual); diff != "" { + t.Errorf("Unexpected difference (-want, +got): %v", diff) } } -func headerDiff(expected http.Header, actual http.Header) string { +func canonicalizeHeaders(rvs ...requestValidation) { // HTTP header names are case-insensitive, so normalize them to lower case for comparison. - for _, headers := range []http.Header{expected, actual} { + for _, rv := range rvs { + headers := rv.Headers for n, v := range headers { delete(headers, n) headers[strings.ToLower(n)] = v } } - return cmp.Diff(expected, actual) } From ab5c51d6973e1643f8b0f88d2f264c59bccd9287 Mon Sep 17 00:00:00 2001 From: Adam Harwayne Date: Thu, 27 Sep 2018 16:17:28 -0700 Subject: [PATCH 4/9] Remove httpDoer and use an in-memory server instead. --- pkg/buses/message_dispatcher.go | 7 +- pkg/buses/message_dispatcher_test.go | 127 ++++++++++++++++++--------- 2 files changed, 85 insertions(+), 49 deletions(-) diff --git a/pkg/buses/message_dispatcher.go b/pkg/buses/message_dispatcher.go index c9b3afc7896..917087b63b0 100644 --- a/pkg/buses/message_dispatcher.go +++ b/pkg/buses/message_dispatcher.go @@ -29,14 +29,9 @@ import ( const correlationIDHeaderName = "Knative-Correlation-Id" -// httpDoer is an interface for making HTTP requests. -type httpDoer interface { - Do(*http.Request) (*http.Response, error) -} - // MessageDispatcher dispatches messages to a destination over HTTP. type MessageDispatcher struct { - httpClient httpDoer + httpClient *http.Client forwardHeaders map[string]bool forwardPrefixes []string supportedSchemes map[string]bool diff --git a/pkg/buses/message_dispatcher_test.go b/pkg/buses/message_dispatcher_test.go index f627d211842..62293bc26f3 100644 --- a/pkg/buses/message_dispatcher_test.go +++ b/pkg/buses/message_dispatcher_test.go @@ -22,14 +22,25 @@ import ( "go.uber.org/zap" "io/ioutil" "net/http" + "net/http/httptest" "strings" "testing" ) +var ( + // Headers that are added to the response, but we don't want to check in our assertions. + unimportantHeaders = map[string]struct{}{ + "accept-encoding": {}, + "content-length": {}, + "content-type": {}, + "user-agent": {}, + } +) + func TestDispatchMessage(t *testing.T) { testCases := map[string]struct { - destination string - replyTo string + sendToDestination bool + sendToReply bool message *Message fakeResponse *http.Response expectedErr bool @@ -37,7 +48,7 @@ func TestDispatchMessage(t *testing.T) { expectedReplyRequest *requestValidation }{ "destination - only": { - destination: "test-destination-svc.test-namespace.svc.cluster.local", + sendToDestination: true, message: &Message{ Headers: map[string]string{ // do-not-forward should not get forwarded. @@ -50,7 +61,6 @@ func TestDispatchMessage(t *testing.T) { Payload: []byte("destination"), }, expectedDestRequest: &requestValidation{ - Url: "http://test-destination-svc.test-namespace.svc.cluster.local/", Headers: map[string][]string{ "x-request-id": {"id123"}, "knative-1": {"knative-1-value"}, @@ -61,7 +71,7 @@ func TestDispatchMessage(t *testing.T) { }, }, "destination - only -- error": { - destination: "test-destination-svc.test-namespace.svc.cluster.local", + sendToDestination: true, message: &Message{ Headers: map[string]string{ // do-not-forward should not get forwarded. @@ -74,7 +84,6 @@ func TestDispatchMessage(t *testing.T) { Payload: []byte("destination"), }, expectedDestRequest: &requestValidation{ - Url: "http://test-destination-svc.test-namespace.svc.cluster.local/", Headers: map[string][]string{ "x-request-id": {"id123"}, "knative-1": {"knative-1-value"}, @@ -90,7 +99,7 @@ func TestDispatchMessage(t *testing.T) { expectedErr: true, }, "reply - only": { - replyTo: "test-reply-svc.test-namespace.svc.cluster.local", + sendToReply: true, message: &Message{ Headers: map[string]string{ // do-not-forward should not get forwarded. @@ -103,7 +112,6 @@ func TestDispatchMessage(t *testing.T) { Payload: []byte("replyTo"), }, expectedReplyRequest: &requestValidation{ - Url: "http://test-reply-svc.test-namespace.svc.cluster.local/", Headers: map[string][]string{ "x-request-id": {"id123"}, "knative-1": {"knative-1-value"}, @@ -114,7 +122,7 @@ func TestDispatchMessage(t *testing.T) { }, }, "reply - only -- error": { - replyTo: "test-reply-svc.test-namespace.svc.cluster.local", + sendToReply: true, message: &Message{ Headers: map[string]string{ // do-not-forward should not get forwarded. @@ -127,7 +135,6 @@ func TestDispatchMessage(t *testing.T) { Payload: []byte("replyTo"), }, expectedReplyRequest: &requestValidation{ - Url: "http://test-reply-svc.test-namespace.svc.cluster.local/", Headers: map[string][]string{ "x-request-id": {"id123"}, "knative-1": {"knative-1-value"}, @@ -143,8 +150,8 @@ func TestDispatchMessage(t *testing.T) { expectedErr: true, }, "destination and reply - dest returns bad status code": { - destination: "test-destination-svc.test-namespace.svc.cluster.local", - replyTo: "test-reply-svc.test-namespace.svc.cluster.local", + sendToDestination: true, + sendToReply: true, message: &Message{ Headers: map[string]string{ // do-not-forward should not get forwarded. @@ -157,7 +164,6 @@ func TestDispatchMessage(t *testing.T) { Payload: []byte("destination"), }, expectedDestRequest: &requestValidation{ - Url: "http://test-destination-svc.test-namespace.svc.cluster.local/", Headers: map[string][]string{ "x-request-id": {"id123"}, "knative-1": {"knative-1-value"}, @@ -173,8 +179,8 @@ func TestDispatchMessage(t *testing.T) { expectedErr: true, }, "destination and reply - dest returns empty body": { - destination: "test-destination-svc.test-namespace.svc.cluster.local", - replyTo: "test-reply-svc.test-namespace.svc.cluster.local", + sendToDestination: true, + sendToReply: true, message: &Message{ Headers: map[string]string{ // do-not-forward should not get forwarded. @@ -187,7 +193,6 @@ func TestDispatchMessage(t *testing.T) { Payload: []byte("destination"), }, expectedDestRequest: &requestValidation{ - Url: "http://test-destination-svc.test-namespace.svc.cluster.local/", Headers: map[string][]string{ "x-request-id": {"id123"}, "knative-1": {"knative-1-value"}, @@ -208,8 +213,8 @@ func TestDispatchMessage(t *testing.T) { }, }, "destination and reply": { - destination: "test-destination-svc.test-namespace.svc.cluster.local", - replyTo: "test-reply-svc.test-namespace.svc.cluster.local", + sendToDestination: true, + sendToReply: true, message: &Message{ Headers: map[string]string{ // do-not-forward should not get forwarded. @@ -222,7 +227,6 @@ func TestDispatchMessage(t *testing.T) { Payload: []byte("destination"), }, expectedDestRequest: &requestValidation{ - Url: "http://test-destination-svc.test-namespace.svc.cluster.local/", Headers: map[string][]string{ "x-request-id": {"id123"}, "knative-1": {"knative-1-value"}, @@ -242,7 +246,6 @@ func TestDispatchMessage(t *testing.T) { Body: ioutil.NopCloser(bytes.NewBufferString("destination-response")), }, expectedReplyRequest: &requestValidation{ - Url: "http://test-reply-svc.test-namespace.svc.cluster.local/", Headers: map[string][]string{ "x-request-id": {"altered-id"}, "knative-1": {"new-knative-1-value"}, @@ -254,64 +257,98 @@ func TestDispatchMessage(t *testing.T) { } for n, tc := range testCases { t.Run(n, func(t *testing.T) { - md := NewMessageDispatcher(zap.NewNop().Sugar()) - fc := &fakeHttpClient{ + destHandler := &fakeHandler{ + t: t, response: tc.fakeResponse, + requests: make([]requestValidation, 0), } - md.httpClient = fc - err := md.DispatchMessage(tc.message, tc.destination, tc.replyTo, DispatchDefaults{}) + destServer := httptest.NewServer(destHandler) + defer destServer.Close() + replyHandler := &fakeHandler{ + t: t, + response: tc.fakeResponse, + requests: make([]requestValidation, 0), + } + replyServer := httptest.NewServer(replyHandler) + defer replyServer.Close() + + md := NewMessageDispatcher(zap.NewNop().Sugar()) + err := md.DispatchMessage(tc.message, + getDomain(tc.sendToDestination, destServer.URL[7:]), + getDomain(tc.sendToReply, replyServer.URL[7:]), + DispatchDefaults{}) if tc.expectedErr != (err != nil) { t.Errorf("Unexpected error from DispatchRequest. Expected %v. Actual: %v", tc.expectedErr, err) } if tc.expectedDestRequest != nil { - rv := fc.popRequest(t) - assertEquality(t, *tc.expectedDestRequest, rv) + rv := destHandler.popRequest(t) + assertEquality(t, destServer.URL, *tc.expectedDestRequest, rv) } if tc.expectedReplyRequest != nil { - rv := fc.popRequest(t) - assertEquality(t, *tc.expectedReplyRequest, rv) + rv := replyHandler.popRequest(t) + assertEquality(t, replyServer.URL, *tc.expectedReplyRequest, rv) } - if len(fc.requests) != 0 { - t.Errorf("Unexpected requests: %+v", fc.requests) + if len(destHandler.requests) != 0 { + t.Errorf("Unexpected destination requests: %+v", destHandler.requests) + } + if len(replyHandler.requests) != 0 { + t.Errorf("Unexpected reply requests: %+v", replyHandler.requests) } }) } } +func getDomain(shouldSend bool, domain string) string { + if shouldSend { + return domain + } + return "" +} + type requestValidation struct { - Url string + Host string Headers http.Header Body string } -type fakeHttpClient struct { +type fakeHandler struct { t *testing.T response *http.Response requests []requestValidation } -var _ httpDoer = &fakeHttpClient{} +func (f *fakeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + defer r.Body.Close() -func (f *fakeHttpClient) Do(r *http.Request) (*http.Response, error) { + // Make a copy of the request. body, err := ioutil.ReadAll(r.Body) if err != nil { f.t.Error("Failed to read the request body") } f.requests = append(f.requests, requestValidation{ - Url: r.URL.String(), + Host: r.Host, Headers: r.Header, Body: string(body), }) + + // Write the response. if f.response != nil { - return f.response, nil + for h, vs := range f.response.Header { + for _, v := range vs { + w.Header().Add(h, v) + } + } + w.WriteHeader(f.response.StatusCode) + var buf bytes.Buffer + buf.ReadFrom(f.response.Body) + w.Write(buf.Bytes()) + } else { + w.WriteHeader(http.StatusOK) + w.Write([]byte("")) } - return &http.Response{ - StatusCode: http.StatusAccepted, - Body: ioutil.NopCloser(bytes.NewBufferString("body")), - }, nil } -func (f *fakeHttpClient) popRequest(t *testing.T) requestValidation { +func (f *fakeHandler) popRequest(t *testing.T) requestValidation { if len(f.requests) == 0 { t.Error("Unable to pop request") } @@ -320,7 +357,8 @@ func (f *fakeHttpClient) popRequest(t *testing.T) requestValidation { return rv } -func assertEquality(t *testing.T, expected, actual requestValidation) { +func assertEquality(t *testing.T, replacementURL string, expected, actual requestValidation) { + expected.Host = replacementURL[7:] canonicalizeHeaders(expected, actual) if diff := cmp.Diff(expected, actual); diff != "" { t.Errorf("Unexpected difference (-want, +got): %v", diff) @@ -333,7 +371,10 @@ func canonicalizeHeaders(rvs ...requestValidation) { headers := rv.Headers for n, v := range headers { delete(headers, n) - headers[strings.ToLower(n)] = v + ln := strings.ToLower(n) + if _, present := unimportantHeaders[ln]; !present { + headers[ln] = v + } } } } From fdd684252f9e31cb75c59b92ae15f611b0b1f6ee Mon Sep 17 00:00:00 2001 From: Adam Harwayne Date: Mon, 1 Oct 2018 08:44:27 -0700 Subject: [PATCH 5/9] Respond to PR comments. --- pkg/buses/message_dispatcher.go | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/pkg/buses/message_dispatcher.go b/pkg/buses/message_dispatcher.go index 917087b63b0..001a095a690 100644 --- a/pkg/buses/message_dispatcher.go +++ b/pkg/buses/message_dispatcher.go @@ -92,14 +92,17 @@ func (d *MessageDispatcher) executeRequest(url *url.URL, message *Message) (*Mes d.logger.Infof("Dispatching message to %s", url.String()) req, err := http.NewRequest(http.MethodPost, url.String(), bytes.NewReader(message.Payload)) if err != nil { - return nil, fmt.Errorf("Unable to create request %v", err) + return nil, fmt.Errorf("unable to create request %v", err) } req.Header = d.toHTTPHeaders(message.Headers) res, err := d.httpClient.Do(req) if err != nil { return nil, err } - if res.StatusCode < 200 || res.StatusCode >= 300 { + if res == nil { + return nil, nil + } + if res.StatusCode < http.StatusOK /* 200 */ || res.StatusCode >= http.StatusMultipleChoices /* 300 */ { // reject non-successful (2xx) responses return nil, fmt.Errorf("unexpected HTTP response, expected 2xx, got %d", res.StatusCode) } From 1e3c7cf3c2ee0c56701bbefdaa4e8eeca565ade3 Mon Sep 17 00:00:00 2001 From: Adam Harwayne Date: Mon, 1 Oct 2018 08:54:32 -0700 Subject: [PATCH 6/9] Close the response body. --- pkg/buses/message_dispatcher.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pkg/buses/message_dispatcher.go b/pkg/buses/message_dispatcher.go index 001a095a690..1d0c9995de8 100644 --- a/pkg/buses/message_dispatcher.go +++ b/pkg/buses/message_dispatcher.go @@ -100,8 +100,11 @@ func (d *MessageDispatcher) executeRequest(url *url.URL, message *Message) (*Mes return nil, err } if res == nil { + // I don't think this is actually rechable with http.Client.Do(), but just to be sure we + // check anyway. return nil, nil } + defer res.Body.Close() if res.StatusCode < http.StatusOK /* 200 */ || res.StatusCode >= http.StatusMultipleChoices /* 300 */ { // reject non-successful (2xx) responses return nil, fmt.Errorf("unexpected HTTP response, expected 2xx, got %d", res.StatusCode) From 8c45581bdf7fbe736e3188f9e6ab49e4ac14bdfb Mon Sep 17 00:00:00 2001 From: Adam Harwayne Date: Mon, 1 Oct 2018 11:40:21 -0700 Subject: [PATCH 7/9] Respond to PR comments. --- pkg/buses/message_dispatcher.go | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/pkg/buses/message_dispatcher.go b/pkg/buses/message_dispatcher.go index 1d0c9995de8..f0ad5e65fc2 100644 --- a/pkg/buses/message_dispatcher.go +++ b/pkg/buses/message_dispatcher.go @@ -100,13 +100,13 @@ func (d *MessageDispatcher) executeRequest(url *url.URL, message *Message) (*Mes return nil, err } if res == nil { - // I don't think this is actually rechable with http.Client.Do(), but just to be sure we + // I don't think this is actually reachable with http.Client.Do(), but just to be sure we // check anyway. return nil, nil } defer res.Body.Close() - if res.StatusCode < http.StatusOK /* 200 */ || res.StatusCode >= http.StatusMultipleChoices /* 300 */ { - // reject non-successful (2xx) responses + if isFailure(res.StatusCode) { + // reject non-successful responses return nil, fmt.Errorf("unexpected HTTP response, expected 2xx, got %d", res.StatusCode) } headers := d.fromHTTPHeaders(res.Header) @@ -125,6 +125,12 @@ func (d *MessageDispatcher) executeRequest(url *url.URL, message *Message) (*Mes return &Message{headers, payload}, nil } +// isFailure returns true if the status code is not a successful HTTP status. +func isFailure(statusCode int) bool { + return statusCode < http.StatusOK /* 200 */ || + statusCode >= http.StatusMultipleChoices /* 300 */ +} + // toHTTPHeaders converts message headers to HTTP headers. // // Only headers whitelisted as safe are copied. From 286dc722b122b074bdaf8055fac134b3ee2ae480 Mon Sep 17 00:00:00 2001 From: Adam Harwayne Date: Mon, 1 Oct 2018 11:49:10 -0700 Subject: [PATCH 8/9] Nil response is an error. --- pkg/buses/message_dispatcher.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pkg/buses/message_dispatcher.go b/pkg/buses/message_dispatcher.go index f0ad5e65fc2..c04b99508dc 100644 --- a/pkg/buses/message_dispatcher.go +++ b/pkg/buses/message_dispatcher.go @@ -18,6 +18,7 @@ package buses import ( "bytes" + "errors" "fmt" "io/ioutil" "net/http" @@ -102,7 +103,7 @@ func (d *MessageDispatcher) executeRequest(url *url.URL, message *Message) (*Mes if res == nil { // I don't think this is actually reachable with http.Client.Do(), but just to be sure we // check anyway. - return nil, nil + return nil, errors.New("non-error nil result from http.Client.Do()") } defer res.Body.Close() if isFailure(res.StatusCode) { From 757f14b1c59630b6ae00808bde76ad703a4a2408 Mon Sep 17 00:00:00 2001 From: Adam Harwayne Date: Mon, 1 Oct 2018 11:53:47 -0700 Subject: [PATCH 9/9] Unit test improvements suggested by PR comments. --- pkg/buses/message_dispatcher_test.go | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/pkg/buses/message_dispatcher_test.go b/pkg/buses/message_dispatcher_test.go index 62293bc26f3..4ccee0d462b 100644 --- a/pkg/buses/message_dispatcher_test.go +++ b/pkg/buses/message_dispatcher_test.go @@ -23,6 +23,7 @@ import ( "io/ioutil" "net/http" "net/http/httptest" + "net/url" "strings" "testing" ) @@ -274,8 +275,8 @@ func TestDispatchMessage(t *testing.T) { md := NewMessageDispatcher(zap.NewNop().Sugar()) err := md.DispatchMessage(tc.message, - getDomain(tc.sendToDestination, destServer.URL[7:]), - getDomain(tc.sendToReply, replyServer.URL[7:]), + getDomain(t, tc.sendToDestination, destServer.URL), + getDomain(t, tc.sendToReply, replyServer.URL), DispatchDefaults{}) if tc.expectedErr != (err != nil) { t.Errorf("Unexpected error from DispatchRequest. Expected %v. Actual: %v", tc.expectedErr, err) @@ -298,9 +299,13 @@ func TestDispatchMessage(t *testing.T) { } } -func getDomain(shouldSend bool, domain string) string { +func getDomain(t *testing.T, shouldSend bool, serverURL string) string { if shouldSend { - return domain + server, err := url.Parse(serverURL) + if err != nil { + t.Errorf("Bad serverURL: %q", serverURL) + } + return server.Host } return "" } @@ -358,7 +363,11 @@ func (f *fakeHandler) popRequest(t *testing.T) requestValidation { } func assertEquality(t *testing.T, replacementURL string, expected, actual requestValidation) { - expected.Host = replacementURL[7:] + server, err := url.Parse(replacementURL) + if err != nil { + t.Errorf("Bad replacement URL: %q", replacementURL) + } + expected.Host = server.Host canonicalizeHeaders(expected, actual) if diff := cmp.Diff(expected, actual); diff != "" { t.Errorf("Unexpected difference (-want, +got): %v", diff)