diff --git a/auth/auth.go b/auth/auth.go index 1eb18dac9a8..e7fea6c8d94 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -44,7 +44,7 @@ var log = capnslog.NewPackageLogger("github.com/openshift/console", "auth") type Authenticator struct { tokenVerifier func(string) (*loginState, error) - oauth2Client *oauth2.Config + authFunc func() (*oauth2.Config, loginMethod) clientFunc func() *http.Client @@ -53,8 +53,6 @@ type Authenticator struct { cookiePath string refererURL *url.URL secureCookies bool - - loginMethod loginMethod } // loginMethod is used to handle OAuth2 responses and associate bearer tokens @@ -146,44 +144,43 @@ func NewAuthenticator(ctx context.Context, c *Config) (*Authenticator, error) { steps := 0 for { - var ( - a *Authenticator - lm loginMethod - endpoint oauth2.Endpoint - err error - ) - - a, err = newUnstartedAuthenticator(c) + a, err := newUnstartedAuthenticator(c) if err != nil { return nil, err } + var authSourceFunc func() (oauth2.Endpoint, loginMethod, error) switch c.AuthSource { case AuthSourceOpenShift: - // Use the k8s CA for OAuth metadata discovery. - var k8sClient *http.Client - // Don't include system roots when talking to the API server. - k8sClient, err = newHTTPClient(c.K8sCA, false) - if err != nil { - return nil, err + authSourceFunc = func() (oauth2.Endpoint, loginMethod, error) { + // Use the k8s CA for OAuth metadata discovery. + // Don't include system roots when talking to the API server. + k8sClient, errK8Client := newHTTPClient(c.K8sCA, false) + if errK8Client != nil { + return oauth2.Endpoint{}, nil, errK8Client + } + + return newOpenShiftAuth(ctx, &openShiftConfig{ + k8sClient: k8sClient, + oauthClient: a.clientFunc(), + issuerURL: c.IssuerURL, + cookiePath: c.CookiePath, + secureCookies: c.SecureCookies, + }) } - - endpoint, lm, err = newOpenShiftAuth(ctx, &openShiftConfig{ - k8sClient: k8sClient, - oauthClient: a.clientFunc(), - issuerURL: c.IssuerURL, - cookiePath: c.CookiePath, - secureCookies: c.SecureCookies, - }) default: - endpoint, lm, err = newOIDCAuth(ctx, &oidcConfig{ - client: a.clientFunc(), - issuerURL: c.IssuerURL, - clientID: c.ClientID, - cookiePath: c.CookiePath, - secureCookies: c.SecureCookies, - }) + authSourceFunc = func() (oauth2.Endpoint, loginMethod, error) { + return newOIDCAuth(ctx, &oidcConfig{ + client: a.clientFunc(), + issuerURL: c.IssuerURL, + clientID: c.ClientID, + cookiePath: c.CookiePath, + secureCookies: c.SecureCookies, + }) + } } + + fallbackEndpoint, fallbackLoginMethod, err := authSourceFunc() if err != nil { steps++ if steps > maxSteps { @@ -198,8 +195,26 @@ func NewAuthenticator(ctx context.Context, c *Config) (*Authenticator, error) { continue } - a.loginMethod = lm - a.oauth2Client.Endpoint = endpoint + a.authFunc = func() (*oauth2.Config, loginMethod) { + // rebuild non-pointer struct each time to prevent any mutation + baseOAuth2Config := oauth2.Config{ + ClientID: c.ClientID, + ClientSecret: c.ClientSecret, + RedirectURL: c.RedirectURL, + Scopes: c.Scope, + Endpoint: fallbackEndpoint, + } + + currentEndpoint, currentLoginMethod, errAuthSource := authSourceFunc() + if errAuthSource != nil { + log.Errorf("failed to get latest auth source data: %v", errAuthSource) + return &baseOAuth2Config, fallbackLoginMethod + } + + baseOAuth2Config.Endpoint = currentEndpoint + return &baseOAuth2Config, currentLoginMethod + } + return a, nil } } @@ -220,13 +235,6 @@ func newUnstartedAuthenticator(c *Config) (*Authenticator, error) { return currentClient } - oauth2Client := &oauth2.Config{ - ClientID: c.ClientID, - ClientSecret: c.ClientSecret, - RedirectURL: c.RedirectURL, - Scopes: c.Scope, - } - errURL := "/" if c.ErrorURL != "" { errURL = c.ErrorURL @@ -247,7 +255,6 @@ func newUnstartedAuthenticator(c *Config) (*Authenticator, error) { } return &Authenticator{ - oauth2Client: oauth2Client, clientFunc: clientFunc, errorURL: errURL, successURL: sucURL, @@ -265,7 +272,7 @@ type User struct { } func (a *Authenticator) Authenticate(r *http.Request) (*User, error) { - return a.loginMethod.authenticate(r) + return a.getLoginMethod().authenticate(r) } // LoginFunc redirects to the OIDC provider for user login. @@ -283,17 +290,17 @@ func (a *Authenticator) LoginFunc(w http.ResponseWriter, r *http.Request) { Secure: a.secureCookies, } http.SetCookie(w, &cookie) - http.Redirect(w, r, a.oauth2Client.AuthCodeURL(state), http.StatusSeeOther) + http.Redirect(w, r, a.getOAuth2Config().AuthCodeURL(state), http.StatusSeeOther) } // LogoutFunc cleans up session cookies. func (a *Authenticator) LogoutFunc(w http.ResponseWriter, r *http.Request) { - a.loginMethod.logout(w, r) + a.getLoginMethod().logout(w, r) } // GetKubeAdminLogoutURL returns the logout URL for the special kube:admin user in OpenShift func (a *Authenticator) GetKubeAdminLogoutURL() string { - return a.loginMethod.getKubeAdminLogoutURL() + return a.getLoginMethod().getKubeAdminLogoutURL() } // CallbackFunc handles OAuth2 callbacks and code/token exchange. @@ -330,14 +337,15 @@ func (a *Authenticator) CallbackFunc(fn func(loginInfo LoginJSON, successURL str return } ctx := oidc.ClientContext(context.TODO(), a.clientFunc()) - token, err := a.oauth2Client.Exchange(ctx, code) + oauthConfig, lm := a.authFunc() + token, err := oauthConfig.Exchange(ctx, code) if err != nil { log.Infof("unable to verify auth code with issuer: %v", err) a.redirectAuthError(w, errorInvalidCode, err) return } - ls, err := a.loginMethod.login(w, token) + ls, err := lm.login(w, token) if err != nil { log.Errorf("error constructing login state: %v", err) a.redirectAuthError(w, errorInternal, nil) @@ -349,6 +357,16 @@ func (a *Authenticator) CallbackFunc(fn func(loginInfo LoginJSON, successURL str } } +func (a *Authenticator) getOAuth2Config() *oauth2.Config { + oauthConfig, _ := a.authFunc() + return oauthConfig +} + +func (a *Authenticator) getLoginMethod() loginMethod { + _, lm := a.authFunc() + return lm +} + func (a *Authenticator) redirectAuthError(w http.ResponseWriter, authErr string, err error) { var u url.URL up, err := url.Parse(a.errorURL)