Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 102 additions & 42 deletions config/http_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,17 @@ import (
"net/http"
"net/url"
"strings"
"time"

"github.com/mwitkow/go-conntrack"
"gopkg.in/yaml.v2"
)

// BasicAuth contains basic HTTP authentication credentials.
type BasicAuth struct {
Username string `yaml:"username"`
Password Secret `yaml:"password"`
Username string `yaml:"username"`
Password Secret `yaml:"password,omitempty"`
PasswordFile string `yaml:"password_file,omitempty"`

// Catches all undefined fields and must be empty after parsing.
XXX map[string]interface{} `yaml:",inline"`
Expand Down Expand Up @@ -88,6 +91,12 @@ func (c *HTTPClientConfig) Validate() error {
if c.BasicAuth != nil && (len(c.BearerToken) > 0 || len(c.BearerTokenFile) > 0) {
return fmt.Errorf("at most one of basic_auth, bearer_token & bearer_token_file must be configured")
}
if c.BasicAuth != nil && c.BasicAuth.Username == "" {
return fmt.Errorf("basic_auth requires a username")
}
if c.BasicAuth != nil && (string(c.BasicAuth.Password) != "" && c.BasicAuth.PasswordFile != "") {
return fmt.Errorf("at most one of basic_auth password & password_file must be configured")
}
return nil
}

Expand Down Expand Up @@ -115,82 +124,134 @@ func (a *BasicAuth) UnmarshalYAML(unmarshal func(interface{}) error) error {
return checkOverflow(a.XXX, "basic_auth")
}

// NewHTTPClientFromConfig returns a new HTTP client configured for the
// given config.HTTPClientConfig.
func NewHTTPClientFromConfig(cfg *HTTPClientConfig) (*http.Client, error) {
tlsConfig, err := NewTLSConfig(&cfg.TLSConfig)
// NewClient returns a http.Client using the specified http.RoundTripper.
func newClient(rt http.RoundTripper) *http.Client {
return &http.Client{Transport: rt}
}

// NewClientFromConfig returns a new HTTP client configured for the
// given config.HTTPClientConfig. The name is used as go-conntrack metric label.
func NewClientFromConfig(cfg HTTPClientConfig, name string) (*http.Client, error) {
rt, err := NewRoundTripperFromConfig(cfg, name)
if err != nil {
return nil, err
}
return newClient(rt), nil
}

// It's the caller's job to handle timeouts
// NewRoundTripperFromConfig returns a new HTTP RoundTripper configured for the
// given config.HTTPClientConfig. The name is used as go-conntrack metric label.
func NewRoundTripperFromConfig(cfg HTTPClientConfig, name string) (http.RoundTripper, error) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy all this stuff wholesale from Prometheus, this code currently isn't handling bearer tokens correctly. The new code you add should follow how Prometheus currently does bearer tokens.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tlsConfig, err := NewTLSConfig(&cfg.TLSConfig)
if err != nil {
return nil, err
}
// The only timeout we care about is the configured scrape timeout.
// It is applied on request. So we leave out any timings here.
var rt http.RoundTripper = &http.Transport{
Proxy: http.ProxyURL(cfg.ProxyURL.URL),
DisableKeepAlives: true,
TLSClientConfig: tlsConfig,
Proxy: http.ProxyURL(cfg.ProxyURL.URL),
MaxIdleConns: 20000,
MaxIdleConnsPerHost: 1000, // see https://github.com/golang/go/issues/13801
DisableKeepAlives: false,
TLSClientConfig: tlsConfig,
DisableCompression: true,
// 5 minutes is typically above the maximum sane scrape interval. So we can
// use keepalive for all configurations.
IdleConnTimeout: 5 * time.Minute,
DialContext: conntrack.NewDialContextFunc(
conntrack.DialWithTracing(),
conntrack.DialWithName(name),
),
}

// If a bearer token is provided, create a round tripper that will set the
// Authorization header correctly on each request.
bearerToken := cfg.BearerToken
if len(bearerToken) == 0 && len(cfg.BearerTokenFile) > 0 {
b, err := ioutil.ReadFile(cfg.BearerTokenFile)
if err != nil {
return nil, fmt.Errorf("unable to read bearer token file %s: %s", cfg.BearerTokenFile, err)
}
bearerToken = Secret(strings.TrimSpace(string(b)))
}

if len(bearerToken) > 0 {
rt = NewBearerAuthRoundTripper(bearerToken, rt)
if len(cfg.BearerToken) > 0 {
rt = NewBearerAuthRoundTripper(cfg.BearerToken, rt)
} else if len(cfg.BearerTokenFile) > 0 {
rt = NewBearerAuthFileRoundTripper(cfg.BearerTokenFile, rt)
}

if cfg.BasicAuth != nil {
rt = NewBasicAuthRoundTripper(cfg.BasicAuth.Username, Secret(cfg.BasicAuth.Password), rt)
rt = NewBasicAuthRoundTripper(cfg.BasicAuth.Username, cfg.BasicAuth.Password, cfg.BasicAuth.PasswordFile, rt)
}

// Return a new client with the configured round tripper.
return &http.Client{Transport: rt}, nil
// Return a new configured RoundTripper.
return rt, nil
}

type bearerAuthRoundTripper struct {
bearerToken Secret
rt http.RoundTripper
}

type basicAuthRoundTripper struct {
username string
password Secret
rt http.RoundTripper
// NewBearerAuthRoundTripper adds the provided bearer token to a request unless the authorization
// header has already been set.
func NewBearerAuthRoundTripper(token Secret, rt http.RoundTripper) http.RoundTripper {
return &bearerAuthRoundTripper{token, rt}
}

// NewBasicAuthRoundTripper will apply a BASIC auth authorization header to a request unless it has
// already been set.
func NewBasicAuthRoundTripper(username string, password Secret, rt http.RoundTripper) http.RoundTripper {
return &basicAuthRoundTripper{username, password, rt}
func (rt *bearerAuthRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
if len(req.Header.Get("Authorization")) == 0 {
req = cloneRequest(req)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", string(rt.bearerToken)))
}
return rt.rt.RoundTrip(req)
}

func (rt *bearerAuthRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
type bearerAuthFileRoundTripper struct {
bearerFile string
rt http.RoundTripper
}

// NewBearerAuthFileRoundTripper adds the bearer token read from the provided file to a request unless
// the authorization header has already been set. This file is read for every request.
func NewBearerAuthFileRoundTripper(bearerFile string, rt http.RoundTripper) http.RoundTripper {
return &bearerAuthFileRoundTripper{bearerFile, rt}
}

func (rt *bearerAuthFileRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
if len(req.Header.Get("Authorization")) == 0 {
b, err := ioutil.ReadFile(rt.bearerFile)
if err != nil {
return nil, fmt.Errorf("unable to read bearer token file %s: %s", rt.bearerFile, err)
}
bearerToken := strings.TrimSpace(string(b))

req = cloneRequest(req)
req.Header.Set("Authorization", "Bearer "+string(rt.bearerToken))
req.Header.Set("Authorization", "Bearer "+bearerToken)
}

return rt.rt.RoundTrip(req)
}

// NewBearerAuthRoundTripper adds the provided bearer token to a request unless the authorization
// header has already been set.
func NewBearerAuthRoundTripper(bearer Secret, rt http.RoundTripper) http.RoundTripper {
return &bearerAuthRoundTripper{bearer, rt}
type basicAuthRoundTripper struct {
username string
password Secret
passwordFile string
rt http.RoundTripper
}

// NewBasicAuthRoundTripper will apply a BASIC auth authorization header to a request unless it has
// already been set.
func NewBasicAuthRoundTripper(username string, password Secret, passwordFile string, rt http.RoundTripper) http.RoundTripper {
return &basicAuthRoundTripper{username, password, passwordFile, rt}
}

func (rt *basicAuthRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
if len(req.Header.Get("Authorization")) != 0 {
return rt.RoundTrip(req)
return rt.rt.RoundTrip(req)
}
req = cloneRequest(req)
req.SetBasicAuth(rt.username, string(rt.password))
if rt.passwordFile != "" {
bs, err := ioutil.ReadFile(rt.passwordFile)
if err != nil {
return nil, fmt.Errorf("unable to read basic auth password file %s: %s", rt.passwordFile, err)
}
req.SetBasicAuth(rt.username, strings.TrimSpace(string(bs)))
} else {
req.SetBasicAuth(rt.username, strings.TrimSpace(string(rt.password)))
}
return rt.rt.RoundTrip(req)
}

Expand All @@ -208,7 +269,7 @@ func cloneRequest(r *http.Request) *http.Request {
return r2
}

// NewTLSConfig creates a new tls.Config from the given config.TLSConfig.
// NewTLSConfig creates a new tls.Config from the given TLSConfig.
func NewTLSConfig(cfg *TLSConfig) (*tls.Config, error) {
tlsConfig := &tls.Config{InsecureSkipVerify: cfg.InsecureSkipVerify}

Expand All @@ -228,7 +289,6 @@ func NewTLSConfig(cfg *TLSConfig) (*tls.Config, error) {
if len(cfg.ServerName) > 0 {
tlsConfig.ServerName = cfg.ServerName
}

// If a client cert & key is provided then configure TLS config accordingly.
if len(cfg.CertFile) > 0 && len(cfg.KeyFile) == 0 {
return nil, fmt.Errorf("client cert file %q specified without client key file", cfg.CertFile)
Expand Down
Loading