Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 92 additions & 78 deletions pkg/docker/creds/credentials_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func TestCheckAuth(t *testing.T) {
incorrectPwd = "badpwd"
)

localhost, localhostTLS := startServer(t, uname, pwd)
localhost, localhostTLS, cert := startServer(t, uname, pwd)

_, portTLS, err := net.SplitHostPort(localhostTLS)
if err != nil {
Expand Down Expand Up @@ -132,7 +132,6 @@ func TestCheckAuth(t *testing.T) {
},
wantErr: false,
},

{
name: "correct credentials non-localhost",
args: args{
Expand Down Expand Up @@ -170,7 +169,30 @@ func TestCheckAuth(t *testing.T) {
Username: tt.args.username,
Password: tt.args.password,
}
if err := creds.CheckAuth(tt.args.ctx, tt.args.registry+"/someorg/someimage:sometag", c, http.DefaultTransport); (err != nil) != tt.wantErr {
// create trusted certificates pool and add our certificate
certPool := x509.NewCertPool()
certPool.AddCert(cert)

// client transport with the certificate
transport := &http.Transport{
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the startServer whichi previously used to handle the http transport mutation gets used by one more function TestCheckAuthEmptyCreds which still uses the default http transport instead of this one but doesnt seem to cause any problems?
Perhaps TestCheckAuthEmptyCreds needs to have the Transport changed as well

TLSClientConfig: &tls.Config{
RootCAs: certPool,
},
}

dialer := &net.Dialer{}

transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
h, p, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
if h == "test.io" {
h = "localhost"
}
return dialer.DialContext(ctx, network, net.JoinHostPort(h, p))
Comment on lines +172 to +193
Copy link
Copy Markdown
Contributor Author

@gauron99 gauron99 Mar 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

create custom transport and ca pool each run instead of mutating the http.defaultTransport on server creation

}
if err := creds.CheckAuth(tt.args.ctx, tt.args.registry+"/someorg/someimage:sometag", c, transport); (err != nil) != tt.wantErr {
t.Errorf("CheckAuth() error = %v, wantErr %v", err, tt.wantErr)
}
})
Expand All @@ -179,141 +201,133 @@ func TestCheckAuth(t *testing.T) {

func TestCheckAuthEmptyCreds(t *testing.T) {

localhost, _ := startServer(t, "", "")
localhost, _, _ := startServer(t, "", "")
err := creds.CheckAuth(context.Background(), localhost+"/someorg/someimage:sometag", docker.Credentials{}, http.DefaultTransport)
if err != nil {
t.Error(err)
}
}

func startServer(t *testing.T, uname, pwd string) (addr, addrTLS string) {
// TODO: this should be refactored to use OS-chosen ports so as not to
// fail when a user is running a function on the default port.)
listener, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatal(err)
}
addr = listener.Addr().String()

listenerTLS, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatal(err)
}
addrTLS = listenerTLS.Addr().String()

handler := http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) {
if uname == "" || pwd == "" {
if req.Method == http.MethodPost {
resp.WriteHeader(http.StatusCreated)
} else {
resp.WriteHeader(http.StatusOK)
}
return
}
// TODO add also test for token based auth
resp.Header().Add("WWW-Authenticate", "basic")
if u, p, ok := req.BasicAuth(); ok {
if u == uname && p == pwd {
if req.Method == http.MethodPost {
resp.WriteHeader(http.StatusCreated)
} else {
resp.WriteHeader(http.StatusOK)
}
return
}
}
resp.WriteHeader(http.StatusUnauthorized)
})

// generate Certificates
func generateCert(t *testing.T) (tls.Certificate, *x509.Certificate) {
Copy link
Copy Markdown
Contributor Author

@gauron99 gauron99 Mar 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

move certificate gen to separate function

var randReader io.Reader = rand.Reader

caPublicKey, caPrivateKey, err := ed25519.GenerateKey(randReader)
if err != nil {
t.Fatal(err)
}

ca := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
CommonName: "localhost",
},
caTemplate := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{CommonName: "localhost"},
IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback},
DNSNames: []string{"localhost", "test.io"},
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(10, 0, 0),
NotAfter: time.Now().AddDate(1, 0, 0),
IsCA: true,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
ExtraExtensions: []pkix.Extension{},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
BasicConstraintsValid: true,
}

caBytes, err := x509.CreateCertificate(randReader, ca, ca, caPublicKey, caPrivateKey)
caBytes, err := x509.CreateCertificate(randReader, caTemplate, caTemplate, caPublicKey, caPrivateKey)
if err != nil {
t.Fatal(err)
}

ca, err = x509.ParseCertificate(caBytes)
ca, err := x509.ParseCertificate(caBytes)
if err != nil {
t.Fatal(err)
}

cert := tls.Certificate{
tls := tls.Certificate{
Certificate: [][]byte{caBytes},
PrivateKey: caPrivateKey,
Leaf: ca,
}
return tls, ca
}

func startServer(t *testing.T, uname, pwd string) (addr, addrTLS string, ca *x509.Certificate) {
// create a custom handler function
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// no authentication required, empty creds
if uname == "" || pwd == "" {
if r.Method == http.MethodPost {
w.WriteHeader(http.StatusCreated)
} else {
w.WriteHeader(http.StatusOK)
}
return
}

w.Header().Add("WWW-Authenticate", "basic")
if u, p, ok := r.BasicAuth(); ok {
if u == uname && p == pwd {
if r.Method == http.MethodPost {
w.WriteHeader(http.StatusCreated)
} else {
w.WriteHeader(http.StatusOK)
}
return
}
}
w.WriteHeader(http.StatusUnauthorized)
})

// Setup certificates
// tls Cert for the TLS server (has ca as Leaf)
// x509 certificate which is its own CA for client
tlsCert, ca := generateCert(t)

// create Server config
server := http.Server{
Handler: handler,
TLSConfig: &tls.Config{
ServerName: "localhost",
Certificates: []tls.Certificate{cert},
ServerName: "localhost",
// with the TLS certificate
Certificates: []tls.Certificate{tlsCert},
},
}

// non-TLS listener
listener, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatal(err)
}

// TLS listener
listenerTLS, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatal(err)
}
addr = listener.Addr().String()
addrTLS = listenerTLS.Addr().String()

// listen for requests
go func() {
err := server.ServeTLS(listenerTLS, "", "")
if err != nil && !strings.Contains(err.Error(), "Server closed") {
if err != nil && err != http.ErrServerClosed {
panic(err)
}
}()

go func() {
err := server.Serve(listener)
if err != nil && !strings.Contains(err.Error(), "Server closed") {
if err != nil && err != http.ErrServerClosed {
panic(err)
}
}()
// make the testing CA trusted by default HTTP transport/client
oldDefaultTransport := http.DefaultTransport
newDefaultTransport := http.DefaultTransport.(*http.Transport).Clone()
http.DefaultTransport = newDefaultTransport
caPool := x509.NewCertPool()
caPool.AddCert(ca)
newDefaultTransport.TLSClientConfig.RootCAs = caPool
dc := newDefaultTransport.DialContext
newDefaultTransport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
h, p, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
if h == "test.io" {
h = "localhost"
}
addr = net.JoinHostPort(h, p)
return dc(ctx, network, addr)
}

// shutdown servers at cleanup
t.Cleanup(func() {
err := server.Shutdown(context.Background())
if err != nil {
t.Fatal(err)
}
http.DefaultTransport = oldDefaultTransport
})

return addr, addrTLS
return
}

const (
Expand Down
Loading