Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions pkg/broker/filter/filter_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"errors"
"fmt"
"net/http"
"strings"
"time"

opencensusclient "github.com/cloudevents/sdk-go/observability/opencensus/v2/client"
Expand Down Expand Up @@ -56,6 +57,12 @@ const (
defaultMaxIdleConnectionsPerHost = 100
)

// HeaderProxyAllowList contains the headers that are proxied from the reply; other than the CloudEvents headers.
// Other headers are not proxied because of security concerns.
var HeaderProxyAllowList = map[string]struct{}{
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is fine. As an FYI, Kubernetes has some helpers here: https://pkg.go.dev/k8s.io/apimachinery/pkg/util/sets#String

There will probably be an equivalent for go 1.19 or so with generics, I expect.

strings.ToLower("Retry-After"): {},
}

// Handler parses Cloud Events, determines if they pass a filter, and sends them to a subscriber.
type Handler struct {
// receiver receives incoming HTTP requests
Expand Down Expand Up @@ -435,8 +442,16 @@ func triggerFilterAttribute(filter *eventingv1.TriggerFilter, attributeName stri
// proxyHeaders adds the specified HTTP Headers to the ResponseWriter.
func proxyHeaders(httpHeader http.Header, writer http.ResponseWriter) {
for headerKey, headerValues := range httpHeader {
for _, headerValue := range headerValues {
writer.Header().Add(headerKey, headerValue)
// *Only* proxy some headers because of security reasons
if isInProxyHeaderAllowList(headerKey) {
for _, headerValue := range headerValues {
writer.Header().Add(headerKey, headerValue)
}
}
}
}

func isInProxyHeaderAllowList(headerKey string) bool {
_, exists := HeaderProxyAllowList[strings.ToLower(headerKey)]
return exists
}
138 changes: 72 additions & 66 deletions pkg/broker/filter/filter_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,21 +69,24 @@ type TriggerOption func(trigger *eventingv1.Trigger)

func TestReceiver(t *testing.T) {
testCases := map[string]struct {
triggers []*eventingv1.Trigger
request *http.Request
event *cloudevents.Event
requestFails bool
failureStatus int
returnedEvent *cloudevents.Event
expectNewToFail bool
// input
triggers []*eventingv1.Trigger
request *http.Request
event *cloudevents.Event
requestFails bool
failureStatus int
additionalReplyHeaders http.Header

// expectations
expectedResponseEvent *cloudevents.Event
expectedResponse *http.Response
expectedDispatch bool
expectedStatus int
expectedHeaders http.Header
expectedEventCount bool
expectedEventDispatchTime bool
expectedEventProcessingTime bool
response *http.Response
responseHeaders http.Header
expectedResponseHeaders http.Header
}{
"Not POST": {
request: httptest.NewRequest(http.MethodGet, validPath, nil),
Expand Down Expand Up @@ -231,7 +234,7 @@ func TestReceiver(t *testing.T) {
expectedDispatch: true,
expectedEventCount: true,
expectedEventDispatchTime: true,
returnedEvent: makeDifferentEvent(),
expectedResponseEvent: makeDifferentEvent(),
},
"Error From Trigger": {
triggers: []*eventingv1.Trigger{
Expand Down Expand Up @@ -282,7 +285,7 @@ func TestReceiver(t *testing.T) {
expectedDispatch: true,
expectedEventCount: true,
expectedEventDispatchTime: true,
returnedEvent: makeDifferentEvent(),
expectedResponseEvent: makeDifferentEvent(),
},
"Maintain `Prefer: reply` header when it is provided in the original request": {
triggers: []*eventingv1.Trigger{
Expand All @@ -307,7 +310,7 @@ func TestReceiver(t *testing.T) {
expectedDispatch: true,
expectedEventCount: true,
expectedEventDispatchTime: true,
returnedEvent: makeDifferentEvent(),
expectedResponseEvent: makeDifferentEvent(),
},
"Add `Prefer: reply` header when it isn't provided in the original request": {
triggers: []*eventingv1.Trigger{
Expand All @@ -328,17 +331,17 @@ func TestReceiver(t *testing.T) {
expectedDispatch: true,
expectedEventCount: true,
expectedEventDispatchTime: true,
returnedEvent: makeDifferentEvent(),
expectedResponseEvent: makeDifferentEvent(),
},
"Returned non empty non event response": {
"Returned non empty non event expectedResponse": {
triggers: []*eventingv1.Trigger{
makeTrigger(withAttributesFilter(&eventingv1.TriggerFilter{})),
},
expectedDispatch: true,
expectedEventCount: true,
expectedEventDispatchTime: true,
expectedStatus: http.StatusBadGateway,
response: makeNonEmptyResponse(),
expectedResponse: makeNonEmptyResponse(),
},
"Returned malformed Cloud Event": {
triggers: []*eventingv1.Trigger{
Expand All @@ -348,7 +351,7 @@ func TestReceiver(t *testing.T) {
expectedEventCount: true,
expectedEventDispatchTime: true,
expectedStatus: http.StatusOK,
response: makeMalformedEventResponse(),
expectedResponse: makeMalformedEventResponse(),
},
"Returned malformed structured Cloud Event": {
triggers: []*eventingv1.Trigger{
Expand All @@ -358,7 +361,7 @@ func TestReceiver(t *testing.T) {
expectedEventCount: true,
expectedEventDispatchTime: true,
expectedStatus: http.StatusBadGateway,
response: makeMalformedStructuredEventResponse(),
expectedResponse: makeMalformedStructuredEventResponse(),
},
"Returned empty body 200": {
triggers: []*eventingv1.Trigger{
Expand All @@ -368,7 +371,7 @@ func TestReceiver(t *testing.T) {
expectedEventCount: true,
expectedEventDispatchTime: true,
expectedStatus: http.StatusOK,
response: makeEmptyResponse(200),
expectedResponse: makeEmptyResponse(200),
},
"Returned empty body 202": {
triggers: []*eventingv1.Trigger{
Expand All @@ -378,41 +381,43 @@ func TestReceiver(t *testing.T) {
expectedEventCount: true,
expectedEventDispatchTime: true,
expectedStatus: http.StatusAccepted,
response: makeEmptyResponse(202),
expectedResponse: makeEmptyResponse(202),
},
"Proxy CloudEvent response headers": {
"Proxy allowed empty non event response headers": {
triggers: []*eventingv1.Trigger{
makeTrigger(withAttributesFilter(&eventingv1.TriggerFilter{})),
},
expectedDispatch: true,
expectedEventCount: true,
expectedEventDispatchTime: true,
returnedEvent: makeDifferentEvent(),
responseHeaders: http.Header{"Test-Header": []string{"TestValue"}},
expectedStatus: http.StatusTooManyRequests,
expectedResponse: makeEmptyResponse(http.StatusTooManyRequests),
additionalReplyHeaders: http.Header{"Retry-After": []string{"10"}},
expectedResponseHeaders: http.Header{"Retry-After": []string{"10"}},
},
"Proxy empty non event response headers": {
"Do not proxy disallowed response headers": {
triggers: []*eventingv1.Trigger{
makeTrigger(withAttributesFilter(&eventingv1.TriggerFilter{})),
},
expectedDispatch: true,
expectedEventCount: true,
expectedEventDispatchTime: true,
expectedStatus: http.StatusTooManyRequests,
response: makeEmptyResponse(http.StatusTooManyRequests),
responseHeaders: http.Header{"Retry-After": []string{"10"}},
expectedResponseEvent: makeDifferentEvent(),
additionalReplyHeaders: http.Header{"Retry-After": []string{"10"}, "Test-Header": []string{"TestValue"}},
expectedResponseHeaders: http.Header{"Retry-After": []string{"10"}},
},
}
for n, tc := range testCases {
t.Run(n, func(t *testing.T) {

fh := fakeHandler{
failRequest: tc.requestFails,
failStatus: tc.failureStatus,
returnedEvent: tc.returnedEvent,
headers: tc.expectedHeaders,
t: t,
response: tc.response,
responseHeaders: tc.responseHeaders,
failRequest: tc.requestFails,
failStatus: tc.failureStatus,
expectedResponseEvent: tc.expectedResponseEvent,
expectedRequestHeaders: tc.expectedHeaders,
t: t,
expectedResponse: tc.expectedResponse,
additionalReplyHeaders: tc.additionalReplyHeaders,
}
s := httptest.NewServer(&fh)
defer s.Close()
Expand Down Expand Up @@ -441,12 +446,7 @@ func TestReceiver(t *testing.T) {
return ctx
},
)
if tc.expectNewToFail {
if err == nil {
t.Fatal("Expected New to fail, it didn't")
}
return
} else if err != nil {
if err != nil {
t.Fatal("Unable to create receiver:", err)
}

Expand All @@ -472,7 +472,7 @@ func TestReceiver(t *testing.T) {
response := responseWriter.Result()

if tc.expectedStatus != http.StatusInternalServerError && tc.expectedStatus != http.StatusBadGateway {
for expectedHeaderKey, expectedHeaderValues := range tc.responseHeaders {
for expectedHeaderKey, expectedHeaderValues := range tc.expectedResponseHeaders {
if response.Header[expectedHeaderKey] == nil || response.Header[expectedHeaderKey][0] != expectedHeaderValues[0] {
t.Errorf("Response header proxy failed for header '%v'. Expected %v, Actual %v", expectedHeaderKey, expectedHeaderValues[0], response.Header[expectedHeaderKey])
}
Expand All @@ -494,15 +494,15 @@ func TestReceiver(t *testing.T) {
if tc.expectedEventProcessingTime != reporter.eventProcessingTimeReported {
t.Errorf("Incorrect event processing time reported metric. Expected %v, Actual %v", tc.expectedEventProcessingTime, reporter.eventProcessingTimeReported)
}
if tc.returnedEvent != nil {
if tc.returnedEvent.SpecVersion() != event.CloudEventsVersionV1 {
t.Errorf("Incorrect spec version. Expected %v, Actual %v", tc.returnedEvent.SpecVersion(), event.CloudEventsVersionV1)
if tc.expectedResponseEvent != nil {
if tc.expectedResponseEvent.SpecVersion() != event.CloudEventsVersionV1 {
t.Errorf("Incorrect spec version. Expected %v, Actual %v", tc.expectedResponseEvent.SpecVersion(), event.CloudEventsVersionV1)
}
}
// Compare the returned event.
message := cehttp.NewMessageFromHttpResponse(response)
event, err := binding.ToEvent(context.Background(), message)
if tc.returnedEvent == nil {
if tc.expectedResponseEvent == nil {
if err == nil || event != nil {
t.Fatal("Unexpected response event:", event)
}
Expand All @@ -513,7 +513,7 @@ func TestReceiver(t *testing.T) {
}

// The TTL will be added again.
expectedResponseEvent := addTTLToEvent(*tc.returnedEvent)
expectedResponseEvent := addTTLToEvent(*tc.expectedResponseEvent)

// cloudevents/sdk-go doesn't preserve the extension type, so get TTL and set it back again.
// https://github.com/cloudevents/sdk-go/blob/97abfeb3da0bed09e395bff2c5bcf35b6435cb5f/v2/types/value.go#L57
Expand Down Expand Up @@ -704,23 +704,29 @@ func (r *mockReporter) ReportEventProcessingTime(args *ReportArgs, d time.Durati
}

type fakeHandler struct {
failRequest bool
failStatus int
t *testing.T

// input
failRequest bool
failStatus int
additionalReplyHeaders http.Header

// expectations
expectedRequestHeaders http.Header
expectedResponseEvent *cloudevents.Event
expectedResponse *http.Response

// results
requestReceived bool
headers http.Header
returnedEvent *cloudevents.Event
t *testing.T
response *http.Response
responseHeaders http.Header
}

func (h *fakeHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
if h.returnedEvent != nil && h.response != nil {
h.t.Errorf("Can not specify both returnedEvent and response.")
if h.expectedResponseEvent != nil && h.expectedResponse != nil {
h.t.Errorf("Can not specify both expectedResponseEvent and expectedResponse.")
}
h.requestReceived = true

for n, v := range h.headers {
for n, v := range h.expectedRequestHeaders {
if strings.Contains(strings.ToLower(n), strings.ToLower(broker.TTLAttribute)) {
h.t.Errorf("Broker TTL should not be seen by the subscriber: %s", n)
}
Expand All @@ -737,33 +743,33 @@ func (h *fakeHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
}
return
}
if h.returnedEvent == nil && h.response == nil {
if h.expectedResponseEvent == nil && h.expectedResponse == nil {
resp.WriteHeader(http.StatusAccepted)
return
}

if h.returnedEvent != nil {
message := binding.ToMessage(h.returnedEvent)
if h.expectedResponseEvent != nil {
message := binding.ToMessage(h.expectedResponseEvent)
defer message.Finish(nil)
for k, v := range h.responseHeaders {
for k, v := range h.additionalReplyHeaders {
resp.Header().Set(k, v[0])
}
err := cehttp.WriteResponseWriter(context.Background(), message, http.StatusAccepted, resp)
if err != nil {
h.t.Fatalf("Unable to write body: %v", err)
}
}
if h.response != nil {
for k, v := range h.response.Header {
if h.expectedResponse != nil {
for k, v := range h.expectedResponse.Header {
resp.Header().Set(k, v[0])
}
for k, v := range h.responseHeaders {
for k, v := range h.additionalReplyHeaders {
resp.Header().Add(k, v[0])
}
resp.WriteHeader(h.response.StatusCode)
if h.response.Body != nil {
defer h.response.Body.Close()
body, err := ioutil.ReadAll(h.response.Body)
resp.WriteHeader(h.expectedResponse.StatusCode)
if h.expectedResponse.Body != nil {
defer h.expectedResponse.Body.Close()
body, err := ioutil.ReadAll(h.expectedResponse.Body)
if err != nil {
h.t.Fatal("Unable to read body: ", err)
}
Expand Down