diff --git a/pkg/proxy/proxy.go b/pkg/proxy/proxy.go index 8d7f93e39f0..938c15d1f8f 100644 --- a/pkg/proxy/proxy.go +++ b/pkg/proxy/proxy.go @@ -31,8 +31,9 @@ type Proxy struct { config *Config } -func filterHeaders(r *http.Response) error { +func filterHeaders(r *http.Response, headers []string) error { badHeaders := []string{"Connection", "Keep-Alive", "Proxy-Connection", "Transfer-Encoding", "Upgrade"} + badHeaders = append(badHeaders, headers...) for _, h := range badHeaders { r.Header.Del(h) } @@ -54,7 +55,9 @@ func NewProxy(cfg *Config) *Proxy { reverseProxy := httputil.NewSingleHostReverseProxy(cfg.Endpoint) reverseProxy.FlushInterval = time.Millisecond * 100 reverseProxy.Transport = transport - reverseProxy.ModifyResponse = filterHeaders + reverseProxy.ModifyResponse = func(r *http.Response) error { + return filterHeaders(r, cfg.HeaderBlacklist) + } proxy := &Proxy{ reverseProxy: reverseProxy, diff --git a/pkg/proxy/proxy_test.go b/pkg/proxy/proxy_test.go index f6bcbe89e72..2ee6dcd16e2 100644 --- a/pkg/proxy/proxy_test.go +++ b/pkg/proxy/proxy_test.go @@ -12,7 +12,8 @@ import ( ) func TestProxyWebsocket(t *testing.T) { - proxyURL, closer, err := startProxyServer(t) + config := &Config{} + proxyURL, closer, err := startProxyServer(t, config) if err != nil { t.Fatalf("problem setting up proxy server: %v", err) } @@ -43,7 +44,10 @@ func TestProxyWebsocket(t *testing.T) { } func TestProxyHTTP(t *testing.T) { - proxyURL, closer, err := startProxyServer(t) + config := &Config{ + HeaderBlacklist: []string{"Unwanted-Header"}, + } + proxyURL, closer, err := startProxyServer(t, config) if err != nil { t.Fatalf("problem setting up proxy server: %v", err) } @@ -60,6 +64,13 @@ func TestProxyHTTP(t *testing.T) { if string(body) != "static" { t.Errorf("string(body) == %q, want %q", string(body), "static") } + for k := range res.Header { + for _, h := range config.HeaderBlacklist { + if k == h { + t.Errorf("Blacklisted header %s should have been deleted", k) + } + } + } } @@ -99,7 +110,7 @@ func TestProxyDecodeSubprotocol(t *testing.T) { // endppint which receives strings and responds with lowercased versions of // those strings. // The proxy server proxies requests to the underlying server on the endpoint "/proxy". -func startProxyServer(t *testing.T) (string, func(), error) { +func startProxyServer(t *testing.T, config *Config) (string, func(), error) { // Setup the server we want to proxy. mux := http.NewServeMux() mux.HandleFunc("/lower", lowercaseServer(t)) @@ -112,9 +123,8 @@ func startProxyServer(t *testing.T) (string, func(), error) { return "", nil, err } targetURL.Path = "" - p := NewProxy(&Config{ - Endpoint: targetURL, - }) + config.Endpoint = targetURL + p := NewProxy(config) proxyMux := http.NewServeMux() proxyMux.Handle("/proxy/", http.StripPrefix("/proxy/", p)) proxyServer := httptest.NewServer(proxyMux) @@ -154,6 +164,7 @@ func lowercaseServer(t *testing.T) func(w http.ResponseWriter, r *http.Request) } func staticServer(res http.ResponseWriter, req *http.Request) { + res.Header().Set("Unwanted-Header", "test-value") res.Write([]byte("static")) }