Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 66 additions & 48 deletions auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ var log = capnslog.NewPackageLogger("github.com/openshift/console", "auth")
type Authenticator struct {
tokenVerifier func(string) (*loginState, error)

oauth2Client *oauth2.Config
authFunc func() (*oauth2.Config, loginMethod)

clientFunc func() *http.Client

Expand All @@ -53,8 +53,6 @@ type Authenticator struct {
cookiePath string
refererURL *url.URL
secureCookies bool

loginMethod loginMethod
}

// loginMethod is used to handle OAuth2 responses and associate bearer tokens
Expand Down Expand Up @@ -146,44 +144,43 @@ func NewAuthenticator(ctx context.Context, c *Config) (*Authenticator, error) {
steps := 0

for {
var (
a *Authenticator
lm loginMethod
endpoint oauth2.Endpoint
err error
)

a, err = newUnstartedAuthenticator(c)
a, err := newUnstartedAuthenticator(c)
if err != nil {
return nil, err
}

var authSourceFunc func() (oauth2.Endpoint, loginMethod, error)
switch c.AuthSource {
case AuthSourceOpenShift:
// Use the k8s CA for OAuth metadata discovery.
var k8sClient *http.Client
// Don't include system roots when talking to the API server.
k8sClient, err = newHTTPClient(c.K8sCA, false)
if err != nil {
return nil, err
authSourceFunc = func() (oauth2.Endpoint, loginMethod, error) {
// Use the k8s CA for OAuth metadata discovery.
// Don't include system roots when talking to the API server.
k8sClient, errK8Client := newHTTPClient(c.K8sCA, false)
if errK8Client != nil {
return oauth2.Endpoint{}, nil, errK8Client
}

return newOpenShiftAuth(ctx, &openShiftConfig{
k8sClient: k8sClient,
oauthClient: a.clientFunc(),
issuerURL: c.IssuerURL,
cookiePath: c.CookiePath,
secureCookies: c.SecureCookies,
})
}

endpoint, lm, err = newOpenShiftAuth(ctx, &openShiftConfig{
k8sClient: k8sClient,
oauthClient: a.clientFunc(),
issuerURL: c.IssuerURL,
cookiePath: c.CookiePath,
secureCookies: c.SecureCookies,
})
default:
endpoint, lm, err = newOIDCAuth(ctx, &oidcConfig{
client: a.clientFunc(),
issuerURL: c.IssuerURL,
clientID: c.ClientID,
cookiePath: c.CookiePath,
secureCookies: c.SecureCookies,
})
authSourceFunc = func() (oauth2.Endpoint, loginMethod, error) {
return newOIDCAuth(ctx, &oidcConfig{
client: a.clientFunc(),
issuerURL: c.IssuerURL,
clientID: c.ClientID,
cookiePath: c.CookiePath,
secureCookies: c.SecureCookies,
})
}
}

fallbackEndpoint, fallbackLoginMethod, err := authSourceFunc()
if err != nil {
steps++
if steps > maxSteps {
Expand All @@ -198,8 +195,26 @@ func NewAuthenticator(ctx context.Context, c *Config) (*Authenticator, error) {
continue
}

a.loginMethod = lm
a.oauth2Client.Endpoint = endpoint
a.authFunc = func() (*oauth2.Config, loginMethod) {
// rebuild non-pointer struct each time to prevent any mutation
baseOAuth2Config := oauth2.Config{
ClientID: c.ClientID,
ClientSecret: c.ClientSecret,
RedirectURL: c.RedirectURL,
Scopes: c.Scope,
Endpoint: fallbackEndpoint,
}

currentEndpoint, currentLoginMethod, errAuthSource := authSourceFunc()
if errAuthSource != nil {
log.Errorf("failed to get latest auth source data: %v", errAuthSource)
return &baseOAuth2Config, fallbackLoginMethod
}

baseOAuth2Config.Endpoint = currentEndpoint
return &baseOAuth2Config, currentLoginMethod
}

return a, nil
}
}
Expand All @@ -220,13 +235,6 @@ func newUnstartedAuthenticator(c *Config) (*Authenticator, error) {
return currentClient
}

oauth2Client := &oauth2.Config{
ClientID: c.ClientID,
ClientSecret: c.ClientSecret,
RedirectURL: c.RedirectURL,
Scopes: c.Scope,
}

errURL := "/"
if c.ErrorURL != "" {
errURL = c.ErrorURL
Expand All @@ -247,7 +255,6 @@ func newUnstartedAuthenticator(c *Config) (*Authenticator, error) {
}

return &Authenticator{
oauth2Client: oauth2Client,
clientFunc: clientFunc,
errorURL: errURL,
successURL: sucURL,
Expand All @@ -265,7 +272,7 @@ type User struct {
}

func (a *Authenticator) Authenticate(r *http.Request) (*User, error) {
return a.loginMethod.authenticate(r)
return a.getLoginMethod().authenticate(r)
}

// LoginFunc redirects to the OIDC provider for user login.
Expand All @@ -283,17 +290,17 @@ func (a *Authenticator) LoginFunc(w http.ResponseWriter, r *http.Request) {
Secure: a.secureCookies,
}
http.SetCookie(w, &cookie)
http.Redirect(w, r, a.oauth2Client.AuthCodeURL(state), http.StatusSeeOther)
http.Redirect(w, r, a.getOAuth2Config().AuthCodeURL(state), http.StatusSeeOther)
}

// LogoutFunc cleans up session cookies.
func (a *Authenticator) LogoutFunc(w http.ResponseWriter, r *http.Request) {
a.loginMethod.logout(w, r)
a.getLoginMethod().logout(w, r)
}

// GetKubeAdminLogoutURL returns the logout URL for the special kube:admin user in OpenShift
func (a *Authenticator) GetKubeAdminLogoutURL() string {
return a.loginMethod.getKubeAdminLogoutURL()
return a.getLoginMethod().getKubeAdminLogoutURL()
}

// CallbackFunc handles OAuth2 callbacks and code/token exchange.
Expand Down Expand Up @@ -330,14 +337,15 @@ func (a *Authenticator) CallbackFunc(fn func(loginInfo LoginJSON, successURL str
return
}
ctx := oidc.ClientContext(context.TODO(), a.clientFunc())
token, err := a.oauth2Client.Exchange(ctx, code)
oauthConfig, lm := a.authFunc()
token, err := oauthConfig.Exchange(ctx, code)
if err != nil {
log.Infof("unable to verify auth code with issuer: %v", err)
a.redirectAuthError(w, errorInvalidCode, err)
return
}

ls, err := a.loginMethod.login(w, token)
ls, err := lm.login(w, token)
if err != nil {
log.Errorf("error constructing login state: %v", err)
a.redirectAuthError(w, errorInternal, nil)
Expand All @@ -349,6 +357,16 @@ func (a *Authenticator) CallbackFunc(fn func(loginInfo LoginJSON, successURL str
}
}

func (a *Authenticator) getOAuth2Config() *oauth2.Config {
oauthConfig, _ := a.authFunc()
return oauthConfig
}

func (a *Authenticator) getLoginMethod() loginMethod {
_, lm := a.authFunc()
return lm
}

func (a *Authenticator) redirectAuthError(w http.ResponseWriter, authErr string, err error) {
var u url.URL
up, err := url.Parse(a.errorURL)
Expand Down