From 575e41212290c9d278ad7777785db29abe3b37cc Mon Sep 17 00:00:00 2001 From: Jordan Liggitt Date: Mon, 6 Apr 2015 13:48:39 -0400 Subject: [PATCH] Validate CSRF in external oauth flow --- pkg/auth/oauth/external/handler.go | 76 ++++++++++++++++++++++-------- pkg/cmd/server/origin/auth.go | 4 +- 2 files changed, 59 insertions(+), 21 deletions(-) diff --git a/pkg/auth/oauth/external/handler.go b/pkg/auth/oauth/external/handler.go index 3e8f47981d03..3a8e72ec766d 100644 --- a/pkg/auth/oauth/external/handler.go +++ b/pkg/auth/oauth/external/handler.go @@ -1,6 +1,7 @@ package external import ( + "encoding/base64" "errors" "fmt" "net/http" @@ -12,6 +13,7 @@ import ( authapi "github.com/openshift/origin/pkg/auth/api" "github.com/openshift/origin/pkg/auth/oauth/handlers" + "github.com/openshift/origin/pkg/auth/server/csrf" ) // Handler exposes an external oauth provider flow (including the call back) as an oauth.handlers.AuthenticationHandler to allow our internal oauth @@ -78,7 +80,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) { authData, err := authReq.HandleRequest(req) if err != nil { glog.V(4).Infof("Error handling request: %v", err) - h.errorHandler.AuthenticationError(err, w, req) + h.handleError(err, w, req) return } @@ -89,7 +91,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) { accessData, err := accessReq.GetToken() if err != nil { glog.V(4).Infof("Error getting access token:", err) - h.errorHandler.AuthenticationError(err, w, req) + h.handleError(err, w, req) return } @@ -98,12 +100,13 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) { identity, ok, err := h.provider.GetUserIdentity(accessData) if err != nil { glog.V(4).Infof("Error getting userIdentityInfo info: %v", err) - h.errorHandler.AuthenticationError(err, w, req) + h.handleError(err, w, req) return } if !ok { glog.V(4).Infof("Could not get userIdentityInfo info from access token") - h.errorHandler.AuthenticationError(errors.New("Could not get userIdentityInfo info from access token"), w, req) + err := errors.New("Could not get userIdentityInfo info from access token") + h.handleError(err, w, req) return } @@ -111,53 +114,75 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) { glog.V(4).Infof("Got userIdentityMapping: %#v", user) if err != nil { glog.V(4).Infof("Error creating or updating mapping for: %#v due to %v", identity, err) - h.errorHandler.AuthenticationError(err, w, req) + h.handleError(err, w, req) return } ok, err = h.state.Check(authData.State, w, req) if !ok { glog.V(4).Infof("State is invalid") - h.errorHandler.AuthenticationError(errors.New("State is invalid"), w, req) + err := errors.New("State is invalid") + h.handleError(err, w, req) return } if err != nil { glog.V(4).Infof("Error verifying state: %v", err) - h.errorHandler.AuthenticationError(err, w, req) + h.handleError(err, w, req) return } _, err = h.success.AuthenticationSucceeded(user, authData.State, w, req) if err != nil { glog.V(4).Infof("Error calling success handler: %v", err) - h.errorHandler.AuthenticationError(err, w, req) + h.handleError(err, w, req) return } } +func (h *Handler) handleError(err error, w http.ResponseWriter, req *http.Request) { + handled, err := h.errorHandler.AuthenticationError(err, w, req) + if handled { + return + } + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(`An error occurred`)) +} + // defaultState provides default state-building, validation, and parsing to contain CSRF and "then" redirection -type defaultState struct{} +type defaultState struct { + csrf csrf.CSRF +} -func DefaultState() State { - return defaultState{} +func DefaultState(csrf csrf.CSRF) State { + return &defaultState{csrf} } -func (defaultState) Generate(w http.ResponseWriter, req *http.Request) (string, error) { +func (d *defaultState) Generate(w http.ResponseWriter, req *http.Request) (string, error) { + csrfToken, err := d.csrf.Generate(w, req) + if err != nil { + return "", err + } + state := url.Values{ - "csrf": {"..."}, // TODO: get csrf + "csrf": {csrfToken}, "then": {req.URL.String()}, } - return state.Encode(), nil + return encodeState(state) } -func (defaultState) Check(state string, w http.ResponseWriter, req *http.Request) (bool, error) { - values, err := url.ParseQuery(state) +func (d *defaultState) Check(state string, w http.ResponseWriter, req *http.Request) (bool, error) { + values, err := decodeState(state) if err != nil { return false, err } csrf := values.Get("csrf") - if csrf != "..." { - return false, fmt.Errorf("State did not contain valid CSRF token (expected %s, got %s)", "...", csrf) + + ok, err := d.csrf.Check(req, csrf) + if err != nil { + return false, err + } + if !ok { + return false, fmt.Errorf("State did not contain a valid CSRF token") } then := values.Get("then") @@ -169,7 +194,7 @@ func (defaultState) Check(state string, w http.ResponseWriter, req *http.Request } func (defaultState) AuthenticationSucceeded(user user.Info, state string, w http.ResponseWriter, req *http.Request) (bool, error) { - values, err := url.ParseQuery(state) + values, err := decodeState(state) if err != nil { return false, err } @@ -182,3 +207,16 @@ func (defaultState) AuthenticationSucceeded(user user.Info, state string, w http http.Redirect(w, req, then, http.StatusFound) return true, nil } + +// URL-encode, then base-64 encode for OAuth providers that don't do a good job of treating the state param like an opaque value +func encodeState(values url.Values) (string, error) { + return base64.URLEncoding.EncodeToString([]byte(values.Encode())), nil +} + +func decodeState(state string) (url.Values, error) { + decodedState, err := base64.URLEncoding.DecodeString(state) + if err != nil { + return nil, err + } + return url.ParseQuery(string(decodedState)) +} diff --git a/pkg/cmd/server/origin/auth.go b/pkg/cmd/server/origin/auth.go index 7a93626db68d..80f1ab9ea87d 100644 --- a/pkg/cmd/server/origin/auth.go +++ b/pkg/cmd/server/origin/auth.go @@ -332,7 +332,7 @@ func (c *AuthConfig) getAuthenticationHandler(mux cmdutil.Mux, errorHandler hand return nil, fmt.Errorf("unexpected oauth provider %#v", provider) } - state := external.DefaultState() + state := external.DefaultState(getCSRF()) oauthHandler, err := external.NewExternalOAuthRedirector(oauthProvider, state, c.Options.MasterPublicURL+callbackPath, successHandler, errorHandler, identityMapper) if err != nil { return nil, fmt.Errorf("unexpected error: %v", err) @@ -402,7 +402,7 @@ func (c *AuthConfig) getAuthenticationSuccessHandler() handlers.AuthenticationSu switch identityProvider.Provider.Object.(type) { case (*configapi.OAuthRedirectingIdentityProvider): - successHandlers = append(successHandlers, external.DefaultState().(handlers.AuthenticationSuccessHandler)) + successHandlers = append(successHandlers, external.DefaultState(getCSRF()).(handlers.AuthenticationSuccessHandler)) } if !addedRedirectSuccessHandler && configapi.IsPasswordAuthenticator(identityProvider) {