Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 57 additions & 19 deletions pkg/auth/oauth/external/handler.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package external

import (
"encoding/base64"
"errors"
"fmt"
"net/http"
Expand All @@ -12,6 +13,7 @@ import (

authapi "github.com/openshift/origin/pkg/auth/api"
"github.com/openshift/origin/pkg/auth/oauth/handlers"
"github.com/openshift/origin/pkg/auth/server/csrf"
)

// Handler exposes an external oauth provider flow (including the call back) as an oauth.handlers.AuthenticationHandler to allow our internal oauth
Expand Down Expand Up @@ -78,7 +80,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
authData, err := authReq.HandleRequest(req)
if err != nil {
glog.V(4).Infof("Error handling request: %v", err)
h.errorHandler.AuthenticationError(err, w, req)
h.handleError(err, w, req)
return
}

Expand All @@ -89,7 +91,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
accessData, err := accessReq.GetToken()
if err != nil {
glog.V(4).Infof("Error getting access token:", err)
h.errorHandler.AuthenticationError(err, w, req)
h.handleError(err, w, req)
return
}

Expand All @@ -98,66 +100,89 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
identity, ok, err := h.provider.GetUserIdentity(accessData)
if err != nil {
glog.V(4).Infof("Error getting userIdentityInfo info: %v", err)
h.errorHandler.AuthenticationError(err, w, req)
h.handleError(err, w, req)
return
}
if !ok {
glog.V(4).Infof("Could not get userIdentityInfo info from access token")
h.errorHandler.AuthenticationError(errors.New("Could not get userIdentityInfo info from access token"), w, req)
err := errors.New("Could not get userIdentityInfo info from access token")
h.handleError(err, w, req)
return
}

user, err := h.mapper.UserFor(identity)
glog.V(4).Infof("Got userIdentityMapping: %#v", user)
if err != nil {
glog.V(4).Infof("Error creating or updating mapping for: %#v due to %v", identity, err)
h.errorHandler.AuthenticationError(err, w, req)
h.handleError(err, w, req)
return
}

ok, err = h.state.Check(authData.State, w, req)
if !ok {
glog.V(4).Infof("State is invalid")
h.errorHandler.AuthenticationError(errors.New("State is invalid"), w, req)
err := errors.New("State is invalid")
h.handleError(err, w, req)
return
}
if err != nil {
glog.V(4).Infof("Error verifying state: %v", err)
h.errorHandler.AuthenticationError(err, w, req)
h.handleError(err, w, req)
return
}

_, err = h.success.AuthenticationSucceeded(user, authData.State, w, req)
if err != nil {
glog.V(4).Infof("Error calling success handler: %v", err)
h.errorHandler.AuthenticationError(err, w, req)
h.handleError(err, w, req)
return
}
}

func (h *Handler) handleError(err error, w http.ResponseWriter, req *http.Request) {
handled, err := h.errorHandler.AuthenticationError(err, w, req)
if handled {
return
}
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(`An error occurred`))
}

// defaultState provides default state-building, validation, and parsing to contain CSRF and "then" redirection
type defaultState struct{}
type defaultState struct {
csrf csrf.CSRF
}

func DefaultState() State {
return defaultState{}
func DefaultState(csrf csrf.CSRF) State {
return &defaultState{csrf}
}

func (defaultState) Generate(w http.ResponseWriter, req *http.Request) (string, error) {
func (d *defaultState) Generate(w http.ResponseWriter, req *http.Request) (string, error) {
csrfToken, err := d.csrf.Generate(w, req)
if err != nil {
return "", err
}

state := url.Values{
"csrf": {"..."}, // TODO: get csrf
"csrf": {csrfToken},
"then": {req.URL.String()},
}
return state.Encode(), nil
return encodeState(state)
}

func (defaultState) Check(state string, w http.ResponseWriter, req *http.Request) (bool, error) {
values, err := url.ParseQuery(state)
func (d *defaultState) Check(state string, w http.ResponseWriter, req *http.Request) (bool, error) {
values, err := decodeState(state)
if err != nil {
return false, err
}
csrf := values.Get("csrf")
if csrf != "..." {
return false, fmt.Errorf("State did not contain valid CSRF token (expected %s, got %s)", "...", csrf)

ok, err := d.csrf.Check(req, csrf)
if err != nil {
return false, err
}
if !ok {
return false, fmt.Errorf("State did not contain a valid CSRF token")
}

then := values.Get("then")
Expand All @@ -169,7 +194,7 @@ func (defaultState) Check(state string, w http.ResponseWriter, req *http.Request
}

func (defaultState) AuthenticationSucceeded(user user.Info, state string, w http.ResponseWriter, req *http.Request) (bool, error) {
values, err := url.ParseQuery(state)
values, err := decodeState(state)
if err != nil {
return false, err
}
Expand All @@ -182,3 +207,16 @@ func (defaultState) AuthenticationSucceeded(user user.Info, state string, w http
http.Redirect(w, req, then, http.StatusFound)
return true, nil
}

// URL-encode, then base-64 encode for OAuth providers that don't do a good job of treating the state param like an opaque value
func encodeState(values url.Values) (string, error) {
return base64.URLEncoding.EncodeToString([]byte(values.Encode())), nil
}

func decodeState(state string) (url.Values, error) {
decodedState, err := base64.URLEncoding.DecodeString(state)
if err != nil {
return nil, err
}
return url.ParseQuery(string(decodedState))
}
4 changes: 2 additions & 2 deletions pkg/cmd/server/origin/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ func (c *AuthConfig) getAuthenticationHandler(mux cmdutil.Mux, errorHandler hand
return nil, fmt.Errorf("unexpected oauth provider %#v", provider)
}

state := external.DefaultState()
state := external.DefaultState(getCSRF())
oauthHandler, err := external.NewExternalOAuthRedirector(oauthProvider, state, c.Options.MasterPublicURL+callbackPath, successHandler, errorHandler, identityMapper)
if err != nil {
return nil, fmt.Errorf("unexpected error: %v", err)
Expand Down Expand Up @@ -402,7 +402,7 @@ func (c *AuthConfig) getAuthenticationSuccessHandler() handlers.AuthenticationSu

switch identityProvider.Provider.Object.(type) {
case (*configapi.OAuthRedirectingIdentityProvider):
successHandlers = append(successHandlers, external.DefaultState().(handlers.AuthenticationSuccessHandler))
successHandlers = append(successHandlers, external.DefaultState(getCSRF()).(handlers.AuthenticationSuccessHandler))
}

if !addedRedirectSuccessHandler && configapi.IsPasswordAuthenticator(identityProvider) {
Expand Down