diff --git a/pkg/buses/bus.go b/pkg/buses/bus.go index ccc2596eec8..0ee44279a26 100644 --- a/pkg/buses/bus.go +++ b/pkg/buses/bus.go @@ -193,7 +193,6 @@ func (b *bus) dispatchMessage(subscription *channelsv1alpha1.Subscription, messa subscriber := subscription.Spec.Subscriber defaults := DispatchDefaults{ Namespace: subscription.Namespace, - ReplyTo: subscription.Spec.ReplyTo, } - return b.dispatcher.DispatchMessage(message, subscriber, defaults) + return b.dispatcher.DispatchMessage(message, subscriber, subscription.Spec.ReplyTo, defaults) } diff --git a/pkg/buses/message_dispatcher.go b/pkg/buses/message_dispatcher.go index ffb055dd2bb..c04b99508dc 100644 --- a/pkg/buses/message_dispatcher.go +++ b/pkg/buses/message_dispatcher.go @@ -18,6 +18,7 @@ package buses import ( "bytes" + "errors" "fmt" "io/ioutil" "net/http" @@ -42,7 +43,6 @@ type MessageDispatcher struct { // DispatchDefaults provides default parameter values used when dispatching a message. type DispatchDefaults struct { Namespace string - ReplyTo string } // NewMessageDispatcher creates a new message dispatcher that can dispatch @@ -66,14 +66,21 @@ func NewMessageDispatcher(logger *zap.SugaredLogger) *MessageDispatcher { // The destination and replyTo are DNS names. For names with a single label, // the default namespace is used to expand it into a fully qualified name // within the cluster. -func (d *MessageDispatcher) DispatchMessage(message *Message, destination string, defaults DispatchDefaults) error { - destinationURL := d.resolveURL(destination, defaults.Namespace) - reply, err := d.executeRequest(destinationURL, message) - if err != nil { - return fmt.Errorf("Unable to complete request %v", err) +func (d *MessageDispatcher) DispatchMessage(message *Message, destination, replyTo string, defaults DispatchDefaults) error { + var err error + // Default to replying with the original message. If there is a destination, then replace it + // with the response from the call to the destination instead. + reply := message + if destination != "" { + destinationURL := d.resolveURL(destination, defaults.Namespace) + reply, err = d.executeRequest(destinationURL, message) + if err != nil { + return fmt.Errorf("Unable to complete request %v", err) + } } - if defaults.ReplyTo != "" && reply != nil { - replyToURL := d.resolveURL(defaults.ReplyTo, defaults.Namespace) + + if replyTo != "" && reply != nil { + replyToURL := d.resolveURL(replyTo, defaults.Namespace) _, err = d.executeRequest(replyToURL, reply) if err != nil { return fmt.Errorf("Failed to forward reply %v", err) @@ -86,30 +93,43 @@ func (d *MessageDispatcher) executeRequest(url *url.URL, message *Message) (*Mes d.logger.Infof("Dispatching message to %s", url.String()) req, err := http.NewRequest(http.MethodPost, url.String(), bytes.NewReader(message.Payload)) if err != nil { - return nil, fmt.Errorf("Unable to create request %v", err) + return nil, fmt.Errorf("unable to create request %v", err) } req.Header = d.toHTTPHeaders(message.Headers) res, err := d.httpClient.Do(req) if err != nil { return nil, err } - if res != nil { - if res.StatusCode < 200 || res.StatusCode >= 300 { - // reject non-successful (2xx) responses - return nil, fmt.Errorf("unexpected HTTP response, expected 2xx, got %d", res.StatusCode) - } - headers := d.fromHTTPHeaders(res.Header) - // TODO: add configurable whitelisting of propagated headers/prefixes (configmap?) - if correlationID, ok := message.Headers[correlationIDHeaderName]; ok { - headers[correlationIDHeaderName] = correlationID - } - payload, err := ioutil.ReadAll(res.Body) - if err != nil { - return nil, fmt.Errorf("Unable to read response %v", err) - } - return &Message{headers, payload}, nil + if res == nil { + // I don't think this is actually reachable with http.Client.Do(), but just to be sure we + // check anyway. + return nil, errors.New("non-error nil result from http.Client.Do()") } - return nil, nil + defer res.Body.Close() + if isFailure(res.StatusCode) { + // reject non-successful responses + return nil, fmt.Errorf("unexpected HTTP response, expected 2xx, got %d", res.StatusCode) + } + headers := d.fromHTTPHeaders(res.Header) + // TODO: add configurable whitelisting of propagated headers/prefixes (configmap?) + if correlationID, ok := message.Headers[correlationIDHeaderName]; ok { + headers[correlationIDHeaderName] = correlationID + } + payload, err := ioutil.ReadAll(res.Body) + if err != nil { + return nil, fmt.Errorf("Unable to read response %v", err) + } + if len(payload) == 0 { + // The response body is empty, the event has 'finished'. + return nil, nil + } + return &Message{headers, payload}, nil +} + +// isFailure returns true if the status code is not a successful HTTP status. +func isFailure(statusCode int) bool { + return statusCode < http.StatusOK /* 200 */ || + statusCode >= http.StatusMultipleChoices /* 300 */ } // toHTTPHeaders converts message headers to HTTP headers. diff --git a/pkg/buses/message_dispatcher_test.go b/pkg/buses/message_dispatcher_test.go new file mode 100644 index 00000000000..4ccee0d462b --- /dev/null +++ b/pkg/buses/message_dispatcher_test.go @@ -0,0 +1,389 @@ +/* +Copyright 2018 The Knative Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package buses + +import ( + "bytes" + "github.com/google/go-cmp/cmp" + "go.uber.org/zap" + "io/ioutil" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" +) + +var ( + // Headers that are added to the response, but we don't want to check in our assertions. + unimportantHeaders = map[string]struct{}{ + "accept-encoding": {}, + "content-length": {}, + "content-type": {}, + "user-agent": {}, + } +) + +func TestDispatchMessage(t *testing.T) { + testCases := map[string]struct { + sendToDestination bool + sendToReply bool + message *Message + fakeResponse *http.Response + expectedErr bool + expectedDestRequest *requestValidation + expectedReplyRequest *requestValidation + }{ + "destination - only": { + sendToDestination: true, + message: &Message{ + Headers: map[string]string{ + // do-not-forward should not get forwarded. + "do-not-forward": "header", + "x-request-id": "id123", + "knative-1": "knative-1-value", + "knative-2": "knative-2-value", + "ce-abc": "ce-abc-value", + }, + Payload: []byte("destination"), + }, + expectedDestRequest: &requestValidation{ + Headers: map[string][]string{ + "x-request-id": {"id123"}, + "knative-1": {"knative-1-value"}, + "knative-2": {"knative-2-value"}, + "ce-abc": {"ce-abc-value"}, + }, + Body: "destination", + }, + }, + "destination - only -- error": { + sendToDestination: true, + message: &Message{ + Headers: map[string]string{ + // do-not-forward should not get forwarded. + "do-not-forward": "header", + "x-request-id": "id123", + "knative-1": "knative-1-value", + "knative-2": "knative-2-value", + "ce-abc": "ce-abc-value", + }, + Payload: []byte("destination"), + }, + expectedDestRequest: &requestValidation{ + Headers: map[string][]string{ + "x-request-id": {"id123"}, + "knative-1": {"knative-1-value"}, + "knative-2": {"knative-2-value"}, + "ce-abc": {"ce-abc-value"}, + }, + Body: "destination", + }, + fakeResponse: &http.Response{ + StatusCode: http.StatusNotFound, + Body: ioutil.NopCloser(bytes.NewBufferString("destination-response")), + }, + expectedErr: true, + }, + "reply - only": { + sendToReply: true, + message: &Message{ + Headers: map[string]string{ + // do-not-forward should not get forwarded. + "do-not-forward": "header", + "x-request-id": "id123", + "knative-1": "knative-1-value", + "knative-2": "knative-2-value", + "ce-abc": "ce-abc-value", + }, + Payload: []byte("replyTo"), + }, + expectedReplyRequest: &requestValidation{ + Headers: map[string][]string{ + "x-request-id": {"id123"}, + "knative-1": {"knative-1-value"}, + "knative-2": {"knative-2-value"}, + "ce-abc": {"ce-abc-value"}, + }, + Body: "replyTo", + }, + }, + "reply - only -- error": { + sendToReply: true, + message: &Message{ + Headers: map[string]string{ + // do-not-forward should not get forwarded. + "do-not-forward": "header", + "x-request-id": "id123", + "knative-1": "knative-1-value", + "knative-2": "knative-2-value", + "ce-abc": "ce-abc-value", + }, + Payload: []byte("replyTo"), + }, + expectedReplyRequest: &requestValidation{ + Headers: map[string][]string{ + "x-request-id": {"id123"}, + "knative-1": {"knative-1-value"}, + "knative-2": {"knative-2-value"}, + "ce-abc": {"ce-abc-value"}, + }, + Body: "replyTo", + }, + fakeResponse: &http.Response{ + StatusCode: http.StatusNotFound, + Body: ioutil.NopCloser(bytes.NewBufferString("destination-response")), + }, + expectedErr: true, + }, + "destination and reply - dest returns bad status code": { + sendToDestination: true, + sendToReply: true, + message: &Message{ + Headers: map[string]string{ + // do-not-forward should not get forwarded. + "do-not-forward": "header", + "x-request-id": "id123", + "knative-1": "knative-1-value", + "knative-2": "knative-2-value", + "ce-abc": "ce-abc-value", + }, + Payload: []byte("destination"), + }, + expectedDestRequest: &requestValidation{ + Headers: map[string][]string{ + "x-request-id": {"id123"}, + "knative-1": {"knative-1-value"}, + "knative-2": {"knative-2-value"}, + "ce-abc": {"ce-abc-value"}, + }, + Body: "destination", + }, + fakeResponse: &http.Response{ + StatusCode: http.StatusInternalServerError, + Body: ioutil.NopCloser(bytes.NewBufferString("destination-response")), + }, + expectedErr: true, + }, + "destination and reply - dest returns empty body": { + sendToDestination: true, + sendToReply: true, + message: &Message{ + Headers: map[string]string{ + // do-not-forward should not get forwarded. + "do-not-forward": "header", + "x-request-id": "id123", + "knative-1": "knative-1-value", + "knative-2": "knative-2-value", + "ce-abc": "ce-abc-value", + }, + Payload: []byte("destination"), + }, + expectedDestRequest: &requestValidation{ + Headers: map[string][]string{ + "x-request-id": {"id123"}, + "knative-1": {"knative-1-value"}, + "knative-2": {"knative-2-value"}, + "ce-abc": {"ce-abc-value"}, + }, + Body: "destination", + }, + fakeResponse: &http.Response{ + StatusCode: http.StatusAccepted, + Header: map[string][]string{ + "do-not-passthrough": {"no"}, + "x-request-id": {"altered-id"}, + "knative-1": {"new-knative-1-value"}, + "ce-abc": {"new-ce-abc-value"}, + }, + Body: ioutil.NopCloser(bytes.NewBufferString("")), + }, + }, + "destination and reply": { + sendToDestination: true, + sendToReply: true, + message: &Message{ + Headers: map[string]string{ + // do-not-forward should not get forwarded. + "do-not-forward": "header", + "x-request-id": "id123", + "knative-1": "knative-1-value", + "knative-2": "knative-2-value", + "ce-abc": "ce-abc-value", + }, + Payload: []byte("destination"), + }, + expectedDestRequest: &requestValidation{ + Headers: map[string][]string{ + "x-request-id": {"id123"}, + "knative-1": {"knative-1-value"}, + "knative-2": {"knative-2-value"}, + "ce-abc": {"ce-abc-value"}, + }, + Body: "destination", + }, + fakeResponse: &http.Response{ + StatusCode: http.StatusAccepted, + Header: map[string][]string{ + "do-not-passthrough": {"no"}, + "x-request-id": {"altered-id"}, + "knative-1": {"new-knative-1-value"}, + "ce-abc": {"new-ce-abc-value"}, + }, + Body: ioutil.NopCloser(bytes.NewBufferString("destination-response")), + }, + expectedReplyRequest: &requestValidation{ + Headers: map[string][]string{ + "x-request-id": {"altered-id"}, + "knative-1": {"new-knative-1-value"}, + "ce-abc": {"new-ce-abc-value"}, + }, + Body: "destination-response", + }, + }, + } + for n, tc := range testCases { + t.Run(n, func(t *testing.T) { + destHandler := &fakeHandler{ + t: t, + response: tc.fakeResponse, + requests: make([]requestValidation, 0), + } + destServer := httptest.NewServer(destHandler) + defer destServer.Close() + replyHandler := &fakeHandler{ + t: t, + response: tc.fakeResponse, + requests: make([]requestValidation, 0), + } + replyServer := httptest.NewServer(replyHandler) + defer replyServer.Close() + + md := NewMessageDispatcher(zap.NewNop().Sugar()) + err := md.DispatchMessage(tc.message, + getDomain(t, tc.sendToDestination, destServer.URL), + getDomain(t, tc.sendToReply, replyServer.URL), + DispatchDefaults{}) + if tc.expectedErr != (err != nil) { + t.Errorf("Unexpected error from DispatchRequest. Expected %v. Actual: %v", tc.expectedErr, err) + } + if tc.expectedDestRequest != nil { + rv := destHandler.popRequest(t) + assertEquality(t, destServer.URL, *tc.expectedDestRequest, rv) + } + if tc.expectedReplyRequest != nil { + rv := replyHandler.popRequest(t) + assertEquality(t, replyServer.URL, *tc.expectedReplyRequest, rv) + } + if len(destHandler.requests) != 0 { + t.Errorf("Unexpected destination requests: %+v", destHandler.requests) + } + if len(replyHandler.requests) != 0 { + t.Errorf("Unexpected reply requests: %+v", replyHandler.requests) + } + }) + } +} + +func getDomain(t *testing.T, shouldSend bool, serverURL string) string { + if shouldSend { + server, err := url.Parse(serverURL) + if err != nil { + t.Errorf("Bad serverURL: %q", serverURL) + } + return server.Host + } + return "" +} + +type requestValidation struct { + Host string + Headers http.Header + Body string +} + +type fakeHandler struct { + t *testing.T + response *http.Response + requests []requestValidation +} + +func (f *fakeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + defer r.Body.Close() + + // Make a copy of the request. + body, err := ioutil.ReadAll(r.Body) + if err != nil { + f.t.Error("Failed to read the request body") + } + f.requests = append(f.requests, requestValidation{ + Host: r.Host, + Headers: r.Header, + Body: string(body), + }) + + // Write the response. + if f.response != nil { + for h, vs := range f.response.Header { + for _, v := range vs { + w.Header().Add(h, v) + } + } + w.WriteHeader(f.response.StatusCode) + var buf bytes.Buffer + buf.ReadFrom(f.response.Body) + w.Write(buf.Bytes()) + } else { + w.WriteHeader(http.StatusOK) + w.Write([]byte("")) + } +} + +func (f *fakeHandler) popRequest(t *testing.T) requestValidation { + if len(f.requests) == 0 { + t.Error("Unable to pop request") + } + rv := f.requests[0] + f.requests = f.requests[1:] + return rv +} + +func assertEquality(t *testing.T, replacementURL string, expected, actual requestValidation) { + server, err := url.Parse(replacementURL) + if err != nil { + t.Errorf("Bad replacement URL: %q", replacementURL) + } + expected.Host = server.Host + canonicalizeHeaders(expected, actual) + if diff := cmp.Diff(expected, actual); diff != "" { + t.Errorf("Unexpected difference (-want, +got): %v", diff) + } +} + +func canonicalizeHeaders(rvs ...requestValidation) { + // HTTP header names are case-insensitive, so normalize them to lower case for comparison. + for _, rv := range rvs { + headers := rv.Headers + for n, v := range headers { + delete(headers, n) + ln := strings.ToLower(n) + if _, present := unimportantHeaders[ln]; !present { + headers[ln] = v + } + } + } +}