Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions pkg/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@ type Proxy struct {
config *Config
}

func filterHeaders(r *http.Response) error {
func filterHeaders(r *http.Response, headers []string) error {
badHeaders := []string{"Connection", "Keep-Alive", "Proxy-Connection", "Transfer-Encoding", "Upgrade"}
badHeaders = append(badHeaders, headers...)
for _, h := range badHeaders {
r.Header.Del(h)
}
Expand All @@ -54,7 +55,9 @@ func NewProxy(cfg *Config) *Proxy {
reverseProxy := httputil.NewSingleHostReverseProxy(cfg.Endpoint)
reverseProxy.FlushInterval = time.Millisecond * 100
reverseProxy.Transport = transport
reverseProxy.ModifyResponse = filterHeaders
reverseProxy.ModifyResponse = func(r *http.Response) error {
Copy link
Member

@spadgett spadgett Jul 21, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

HeaderBlacklist is intended to remove request headers, not response headers.

return filterHeaders(r, cfg.HeaderBlacklist)
}

proxy := &Proxy{
reverseProxy: reverseProxy,
Expand Down
23 changes: 17 additions & 6 deletions pkg/proxy/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ import (
)

func TestProxyWebsocket(t *testing.T) {
proxyURL, closer, err := startProxyServer(t)
config := &Config{}
proxyURL, closer, err := startProxyServer(t, config)
if err != nil {
t.Fatalf("problem setting up proxy server: %v", err)
}
Expand Down Expand Up @@ -43,7 +44,10 @@ func TestProxyWebsocket(t *testing.T) {
}

func TestProxyHTTP(t *testing.T) {
proxyURL, closer, err := startProxyServer(t)
config := &Config{
HeaderBlacklist: []string{"Unwanted-Header"},
}
proxyURL, closer, err := startProxyServer(t, config)
if err != nil {
t.Fatalf("problem setting up proxy server: %v", err)
}
Expand All @@ -60,6 +64,13 @@ func TestProxyHTTP(t *testing.T) {
if string(body) != "static" {
t.Errorf("string(body) == %q, want %q", string(body), "static")
}
for k := range res.Header {
for _, h := range config.HeaderBlacklist {
if k == h {
t.Errorf("Blacklisted header %s should have been deleted", k)
}
}
}

}

Expand Down Expand Up @@ -99,7 +110,7 @@ func TestProxyDecodeSubprotocol(t *testing.T) {
// endppint which receives strings and responds with lowercased versions of
// those strings.
// The proxy server proxies requests to the underlying server on the endpoint "/proxy".
func startProxyServer(t *testing.T) (string, func(), error) {
func startProxyServer(t *testing.T, config *Config) (string, func(), error) {
// Setup the server we want to proxy.
mux := http.NewServeMux()
mux.HandleFunc("/lower", lowercaseServer(t))
Expand All @@ -112,9 +123,8 @@ func startProxyServer(t *testing.T) (string, func(), error) {
return "", nil, err
}
targetURL.Path = ""
p := NewProxy(&Config{
Endpoint: targetURL,
})
config.Endpoint = targetURL
p := NewProxy(config)
proxyMux := http.NewServeMux()
proxyMux.Handle("/proxy/", http.StripPrefix("/proxy/", p))
proxyServer := httptest.NewServer(proxyMux)
Expand Down Expand Up @@ -154,6 +164,7 @@ func lowercaseServer(t *testing.T) func(w http.ResponseWriter, r *http.Request)
}

func staticServer(res http.ResponseWriter, req *http.Request) {
res.Header().Set("Unwanted-Header", "test-value")
res.Write([]byte("static"))
}

Expand Down