Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
233 changes: 193 additions & 40 deletions config/http_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,31 @@
// See the License for the specific language governing permissions and
// limitations under the License.

// +build go1.8

package config

import (
"bytes"
"crypto/md5"
"crypto/tls"
"crypto/x509"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"strings"
"sync"
"time"

"github.com/mwitkow/go-conntrack"
"gopkg.in/yaml.v2"
)

type closeIdler interface {
CloseIdleConnections()
}

// BasicAuth contains basic HTTP authentication credentials.
type BasicAuth struct {
Username string `yaml:"username"`
Expand Down Expand Up @@ -124,42 +133,51 @@ func NewClientFromConfig(cfg HTTPClientConfig, name string) (*http.Client, error
// NewRoundTripperFromConfig returns a new HTTP RoundTripper configured for the
// given config.HTTPClientConfig. The name is used as go-conntrack metric label.
func NewRoundTripperFromConfig(cfg HTTPClientConfig, name string) (http.RoundTripper, error) {
newRT := func(tlsConfig *tls.Config) (http.RoundTripper, error) {
// The only timeout we care about is the configured scrape timeout.
// It is applied on request. So we leave out any timings here.
var rt http.RoundTripper = &http.Transport{
Proxy: http.ProxyURL(cfg.ProxyURL.URL),
MaxIdleConns: 20000,
MaxIdleConnsPerHost: 1000, // see https://github.com/golang/go/issues/13801
DisableKeepAlives: false,
TLSClientConfig: tlsConfig,
DisableCompression: true,
// 5 minutes is typically above the maximum sane scrape interval. So we can
// use keepalive for all configurations.
IdleConnTimeout: 5 * time.Minute,
DialContext: conntrack.NewDialContextFunc(
conntrack.DialWithTracing(),
conntrack.DialWithName(name),
),
}

// If a bearer token is provided, create a round tripper that will set the
// Authorization header correctly on each request.
if len(cfg.BearerToken) > 0 {
rt = NewBearerAuthRoundTripper(cfg.BearerToken, rt)
} else if len(cfg.BearerTokenFile) > 0 {
rt = NewBearerAuthFileRoundTripper(cfg.BearerTokenFile, rt)
}

if cfg.BasicAuth != nil {
rt = NewBasicAuthRoundTripper(cfg.BasicAuth.Username, cfg.BasicAuth.Password, cfg.BasicAuth.PasswordFile, rt)
}
// Return a new configured RoundTripper.
return rt, nil
}

tlsConfig, err := NewTLSConfig(&cfg.TLSConfig)
if err != nil {
return nil, err
}
// The only timeout we care about is the configured scrape timeout.
// It is applied on request. So we leave out any timings here.
var rt http.RoundTripper = &http.Transport{
Proxy: http.ProxyURL(cfg.ProxyURL.URL),
MaxIdleConns: 20000,
MaxIdleConnsPerHost: 1000, // see https://github.com/golang/go/issues/13801
DisableKeepAlives: false,
TLSClientConfig: tlsConfig,
DisableCompression: true,
// 5 minutes is typically above the maximum sane scrape interval. So we can
// use keepalive for all configurations.
IdleConnTimeout: 5 * time.Minute,
DialContext: conntrack.NewDialContextFunc(
conntrack.DialWithTracing(),
conntrack.DialWithName(name),
),
}

// If a bearer token is provided, create a round tripper that will set the
// Authorization header correctly on each request.
if len(cfg.BearerToken) > 0 {
rt = NewBearerAuthRoundTripper(cfg.BearerToken, rt)
} else if len(cfg.BearerTokenFile) > 0 {
rt = NewBearerAuthFileRoundTripper(cfg.BearerTokenFile, rt)
}

if cfg.BasicAuth != nil {
rt = NewBasicAuthRoundTripper(cfg.BasicAuth.Username, cfg.BasicAuth.Password, cfg.BasicAuth.PasswordFile, rt)
if len(cfg.TLSConfig.CAFile) == 0 {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not getting what you're trying to do here.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If no CA file is provided, we don't need a round tripper that reloads the CA. So we return a "normal" round-tripper.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if the root changes?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Which root?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I meant the CA (as that's usually the root), but that would miss the client ssl auth changing.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

// No need for a RoundTripper that reloads the CA file automatically.
return newRT(tlsConfig)
}

// Return a new configured RoundTripper.
return rt, nil
return newTLSRoundTripper(tlsConfig, cfg.TLSConfig.CAFile, newRT)
}

type bearerAuthRoundTripper struct {
Expand All @@ -181,6 +199,12 @@ func (rt *bearerAuthRoundTripper) RoundTrip(req *http.Request) (*http.Response,
return rt.rt.RoundTrip(req)
}

func (rt *bearerAuthRoundTripper) CloseIdleConnections() {
if ci, ok := rt.rt.(closeIdler); ok {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is cleaner in Go 1.12 as it exposes this directly, without having to reach into the transport

Copy link
Member Author

@simonpasquier simonpasquier Mar 7, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You refer to client.CloseIdleConnections()? I don't know how we can avoid going through the nested round-trippers until we call http.Transport.CloseIdleConnections().

https://golang.org/src/net/http/client.go?s=27593:27632#L841

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have we even switched over to go 1.12 everywhere? I feel like it might be a little early to use that. I think I'd prefer this approach at this point.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a TODO at least to clean this up?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm still not clear about what the TODO should be about... AFAICT consumers of the library using Go 1.12 will be able to do this and it will just work:

client, err := config.NewClient(...)
...
client.CloseIdleConnections()

But again I don't see how we could avoid having CloseIdleConnections() for all the custom round-trippers if we want client.CloseIdleConnections() to be effective.

ci.CloseIdleConnections()
}
}

type bearerAuthFileRoundTripper struct {
bearerFile string
rt http.RoundTripper
Expand All @@ -207,6 +231,12 @@ func (rt *bearerAuthFileRoundTripper) RoundTrip(req *http.Request) (*http.Respon
return rt.rt.RoundTrip(req)
}

func (rt *bearerAuthFileRoundTripper) CloseIdleConnections() {
if ci, ok := rt.rt.(closeIdler); ok {
ci.CloseIdleConnections()
}
}

type basicAuthRoundTripper struct {
username string
password Secret
Expand Down Expand Up @@ -237,6 +267,12 @@ func (rt *basicAuthRoundTripper) RoundTrip(req *http.Request) (*http.Response, e
return rt.rt.RoundTrip(req)
}

func (rt *basicAuthRoundTripper) CloseIdleConnections() {
if ci, ok := rt.rt.(closeIdler); ok {
ci.CloseIdleConnections()
}
}

// cloneRequest returns a clone of the provided *http.Request.
// The clone is a shallow copy of the struct and its Header map.
func cloneRequest(r *http.Request) *http.Request {
Expand All @@ -258,14 +294,13 @@ func NewTLSConfig(cfg *TLSConfig) (*tls.Config, error) {
// If a CA cert is provided then let's read it in so we can validate the
// scrape target's certificate properly.
if len(cfg.CAFile) > 0 {
caCertPool := x509.NewCertPool()
// Load CA cert.
caCert, err := ioutil.ReadFile(cfg.CAFile)
b, err := readCAFile(cfg.CAFile)
if err != nil {
return nil, fmt.Errorf("unable to use specified CA cert %s: %s", cfg.CAFile, err)
return nil, err
}
if !updateRootCA(tlsConfig, b) {
return nil, fmt.Errorf("unable to use specified CA cert %s", cfg.CAFile)
}
caCertPool.AppendCertsFromPEM(caCert)
tlsConfig.RootCAs = caCertPool
}

if len(cfg.ServerName) > 0 {
Expand All @@ -277,13 +312,12 @@ func NewTLSConfig(cfg *TLSConfig) (*tls.Config, error) {
} else if len(cfg.KeyFile) > 0 && len(cfg.CertFile) == 0 {
return nil, fmt.Errorf("client key file %q specified without client cert file", cfg.KeyFile)
} else if len(cfg.CertFile) > 0 && len(cfg.KeyFile) > 0 {
cert, err := tls.LoadX509KeyPair(cfg.CertFile, cfg.KeyFile)
if err != nil {
return nil, fmt.Errorf("unable to use specified client cert (%s) & key (%s): %s", cfg.CertFile, cfg.KeyFile, err)
// Verify that client cert and key are valid.
if _, err := cfg.getClientCertificate(nil); err != nil {
return nil, err
}
tlsConfig.Certificates = []tls.Certificate{cert}
tlsConfig.GetClientCertificate = cfg.getClientCertificate
}
tlsConfig.BuildNameToCertificate()

return tlsConfig, nil
}
Expand All @@ -308,6 +342,125 @@ func (c *TLSConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
return unmarshal((*plain)(c))
}

// getClientCertificate reads the pair of client cert and key from disk and returns a tls.Certificate.
func (c *TLSConfig) getClientCertificate(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
cert, err := tls.LoadX509KeyPair(c.CertFile, c.KeyFile)
if err != nil {
return nil, fmt.Errorf("unable to use specified client cert (%s) & key (%s): %s", c.CertFile, c.KeyFile, err)
}
return &cert, nil
}

// readCAFile reads the CA cert file from disk.
func readCAFile(f string) ([]byte, error) {
data, err := ioutil.ReadFile(f)
if err != nil {
return nil, fmt.Errorf("unable to load specified CA cert %s: %s", f, err)
}
return data, nil
}

// updateRootCA parses the given byte slice as a series of PEM encoded certificates and updates tls.Config.RootCAs.
func updateRootCA(cfg *tls.Config, b []byte) bool {
caCertPool := x509.NewCertPool()
if !caCertPool.AppendCertsFromPEM(b) {
return false
}
cfg.RootCAs = caCertPool
return true
}

// tlsRoundTripper is a RoundTripper that updates automatically its TLS
// configuration whenever the content of the CA file changes.
type tlsRoundTripper struct {
caFile string
// newRT returns a new RoundTripper.
newRT func(*tls.Config) (http.RoundTripper, error)

mtx sync.RWMutex
rt http.RoundTripper
hashCAFile []byte
tlsConfig *tls.Config
}

func newTLSRoundTripper(
cfg *tls.Config,
caFile string,
newRT func(*tls.Config) (http.RoundTripper, error),
) (http.RoundTripper, error) {
t := &tlsRoundTripper{
caFile: caFile,
newRT: newRT,
tlsConfig: cfg,
}

rt, err := t.newRT(t.tlsConfig)
if err != nil {
return nil, err
}
t.rt = rt

_, t.hashCAFile, err = t.getCAWithHash()
if err != nil {
return nil, err
}

return t, nil
}

func (t *tlsRoundTripper) getCAWithHash() ([]byte, []byte, error) {
b, err := readCAFile(t.caFile)
if err != nil {
return nil, nil, err
}
h := md5.Sum(b)
return b, h[:], nil

}

// RoundTrip implements the http.RoundTrip interface.
func (t *tlsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
b, h, err := t.getCAWithHash()
if err != nil {
return nil, err
}

t.mtx.RLock()
equal := bytes.Equal(h[:], t.hashCAFile)
rt := t.rt
t.mtx.RUnlock()
if equal {
// The CA cert hasn't changed, use the existing RoundTripper.
return rt.RoundTrip(req)
}

// Create a new RoundTripper.
tlsConfig := t.tlsConfig.Clone()
if !updateRootCA(tlsConfig, b) {
return nil, fmt.Errorf("unable to use specified CA cert %s", t.caFile)
}
rt, err = t.newRT(tlsConfig)
if err != nil {
return nil, err
}
t.CloseIdleConnections()

t.mtx.Lock()
t.rt = rt
t.hashCAFile = h[:]
t.mtx.Unlock()

return rt.RoundTrip(req)
}

func (t *tlsRoundTripper) CloseIdleConnections() {
t.mtx.RLock()
defer t.mtx.RUnlock()
if ci, ok := t.rt.(closeIdler); ok {
ci.CloseIdleConnections()
}
}

func (c HTTPClientConfig) String() string {
b, err := yaml.Marshal(c)
if err != nil {
Expand Down
Loading