Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cmd/imageproxy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ var allowHosts = flag.String("allowHosts", "", "comma separated list of allowed
var denyHosts = flag.String("denyHosts", "", "comma separated list of denied remote hosts")
var referrers = flag.String("referrers", "", "comma separated list of allowed referring hosts")
var includeReferer = flag.Bool("includeReferer", false, "include referer header in remote requests")
var followRedirects = flag.Bool("followRedirects", true, "follow redirects")
var baseURL = flag.String("baseURL", "", "default base URL for relative remote URLs")
var cache tieredCache
var signatureKeys signatureKeyList
Expand Down Expand Up @@ -90,6 +91,7 @@ func main() {
}

p.IncludeReferer = *includeReferer
p.FollowRedirects = *followRedirects
p.Timeout = *timeout
p.ScaleUp = *scaleUp
p.Verbose = *verbose
Expand Down
21 changes: 20 additions & 1 deletion imageproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ type Proxy struct {
// is included in remote requests.
IncludeReferer bool

// FollowRedirects controls whether imageproxy will follow redirects or not.
FollowRedirects bool

// DefaultBaseURL is the URL that relative remote URLs are resolved in
// reference to. If nil, all remote URLs specified in requests must be
// absolute.
Expand Down Expand Up @@ -186,6 +189,21 @@ func (p *Proxy) serveImage(w http.ResponseWriter, r *http.Request) {
// pass along the referer header from the original request
copyHeader(actualReq.Header, r.Header, "referer")
}
if p.FollowRedirects {
// FollowRedirects is true (default), ensure that the redirected host is allowed
p.Client.CheckRedirect = func(newreq *http.Request, via []*http.Request) error {
if hostMatches(p.DenyHosts, newreq.URL) || (len(p.AllowHosts) > 0 && !hostMatches(p.AllowHosts, newreq.URL)) {
http.Error(w, msgNotAllowedInRedirect, http.StatusForbidden)
return errNotAllowed
}
return nil
}
} else {
// FollowRedirects is false, don't follow redirects
p.Client.CheckRedirect = func(newreq *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
}
}
resp, err := p.Client.Do(actualReq)

if err != nil {
Expand Down Expand Up @@ -269,7 +287,8 @@ var (
errDeniedHost = errors.New("request contains a denied host")
errNotAllowed = errors.New("request does not contain an allowed host or valid signature")

msgNotAllowed = "requested URL is not allowed"
msgNotAllowed = "requested URL is not allowed"
msgNotAllowedInRedirect = "requested URL in redirect is not allowed"
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm kind of meh on the wording here, happy to change it to something else.

)

// allowed determines whether the specified request contains an allowed
Expand Down