From 4114d537336441318d2e1acdee3707cb3c0ab462 Mon Sep 17 00:00:00 2001 From: Richard Wall Date: Wed, 27 Aug 2025 11:37:33 +0100 Subject: [PATCH] Refactor the CyberArk identity client to take an HTTP client So that we can more easily pass in an HTTP client which is configured with CA certificates to connect to an httptest TLS server. Signed-off-by: Richard Wall --- .../cyberark/dataupload/dataupload_test.go | 29 +++++++---- .../identity/advance_authentication_test.go | 16 ++---- .../identity/cmd/testidentity/main.go | 39 ++++++++++---- pkg/internal/cyberark/identity/identity.go | 51 +++++-------------- pkg/internal/cyberark/identity/mock.go | 26 +++++----- .../identity/start_authentication_test.go | 14 +---- .../cyberark/servicediscovery/discovery.go | 12 +---- 7 files changed, 81 insertions(+), 106 deletions(-) diff --git a/pkg/internal/cyberark/dataupload/dataupload_test.go b/pkg/internal/cyberark/dataupload/dataupload_test.go index cadb296a..2de68ad4 100644 --- a/pkg/internal/cyberark/dataupload/dataupload_test.go +++ b/pkg/internal/cyberark/dataupload/dataupload_test.go @@ -9,7 +9,9 @@ import ( "testing" "time" + "github.com/jetstack/venafi-connection-lib/http_client" "github.com/stretchr/testify/require" + "k8s.io/client-go/transport" "k8s.io/klog/v2" "k8s.io/klog/v2/ktesting" @@ -17,6 +19,7 @@ import ( "github.com/jetstack/preflight/pkg/internal/cyberark/dataupload" "github.com/jetstack/preflight/pkg/internal/cyberark/identity" "github.com/jetstack/preflight/pkg/internal/cyberark/servicediscovery" + "github.com/jetstack/preflight/pkg/version" _ "k8s.io/klog/v2/ktesting/init" ) @@ -129,6 +132,9 @@ func TestCyberArkClient_PostDataReadingsWithOptions(t *testing.T) { // ARK_SUBDOMAIN should be your tenant subdomain. // ARK_PLATFORM_DOMAIN should be either integration-cyberark.cloud or cyberark.cloud // +// To test against a tenant on the integration platform, also set: +// ARK_DISCOVERY_API=https://platform-discovery.integration-cyberark.cloud/api/v2 +// // To enable verbose request logging: // // go test ./pkg/internal/cyberark/dataupload/... \ @@ -138,6 +144,10 @@ func TestPostDataReadingsWithOptionsWithRealAPI(t *testing.T) { subdomain := os.Getenv("ARK_SUBDOMAIN") username := os.Getenv("ARK_USERNAME") secret := os.Getenv("ARK_SECRET") + serviceDiscoveryAPI := os.Getenv("ARK_DISCOVERY_API") + if serviceDiscoveryAPI == "" { + serviceDiscoveryAPI = servicediscovery.ProdDiscoveryEndpoint + } if platformDomain == "" || subdomain == "" || username == "" || secret == "" { t.Skip("Skipping because one of the following environment variables is unset or empty: ARK_PLATFORM_DOMAIN, ARK_SUBDOMAIN, ARK_USERNAME, ARK_SECRET") @@ -154,18 +164,19 @@ func TestPostDataReadingsWithOptionsWithRealAPI(t *testing.T) { serviceURL := fmt.Sprintf("https://%s%s%s.%s", subdomain, separator, discoveryContextServiceName, platformDomain) - var ( - identityClient *identity.Client - err error + var rootCAs *x509.CertPool + httpClient := http_client.NewDefaultClient(version.UserAgent(), rootCAs) + httpClient.Transport = transport.NewDebuggingRoundTripper(httpClient.Transport, transport.DebugByContext) + + discoveryClient := servicediscovery.New( + servicediscovery.WithHTTPClient(httpClient), + servicediscovery.WithCustomEndpoint(serviceDiscoveryAPI), ) - if platformDomain == "cyberark.cloud" { - identityClient, err = identity.New(ctx, subdomain) - } else { - discoveryClient := servicediscovery.New(servicediscovery.WithIntegrationEndpoint()) - identityClient, err = identity.NewWithDiscoveryClient(ctx, discoveryClient, subdomain) - } + + identityAPI, err := discoveryClient.DiscoverIdentityAPIURL(ctx, subdomain) require.NoError(t, err) + identityClient := identity.New(httpClient, identityAPI, subdomain) err = identityClient.LoginUsernamePassword(ctx, username, []byte(secret)) require.NoError(t, err) diff --git a/pkg/internal/cyberark/identity/advance_authentication_test.go b/pkg/internal/cyberark/identity/advance_authentication_test.go index 96749835..eb2c2b37 100644 --- a/pkg/internal/cyberark/identity/advance_authentication_test.go +++ b/pkg/internal/cyberark/identity/advance_authentication_test.go @@ -99,21 +99,11 @@ func Test_IdentityAdvanceAuthentication(t *testing.T) { t.Run(name, func(t *testing.T) { ctx := t.Context() - identityServer := MockIdentityServer() - defer identityServer.Close() + identityAPI, httpClient := MockIdentityServer(t) - mockDiscoveryServer := servicediscovery.MockDiscoveryServerWithCustomAPIURL(identityServer.Server.URL) - defer mockDiscoveryServer.Close() + client := New(httpClient, identityAPI, servicediscovery.MockDiscoverySubdomain) - discoveryClient := servicediscovery.New(servicediscovery.WithCustomEndpoint(mockDiscoveryServer.Server.URL)) - - client, err := NewWithDiscoveryClient(ctx, discoveryClient, servicediscovery.MockDiscoverySubdomain) - if err != nil { - t.Errorf("failed to create identity client: %s", err) - return - } - - err = client.doAdvanceAuthentication(ctx, testSpec.username, &testSpec.password, testSpec.advanceBody) + err := client.doAdvanceAuthentication(ctx, testSpec.username, &testSpec.password, testSpec.advanceBody) if testSpec.expectedError != err { if testSpec.expectedError == nil { t.Errorf("didn't expect an error but got %v", err) diff --git a/pkg/internal/cyberark/identity/cmd/testidentity/main.go b/pkg/internal/cyberark/identity/cmd/testidentity/main.go index b7df3562..45e9b3bf 100644 --- a/pkg/internal/cyberark/identity/cmd/testidentity/main.go +++ b/pkg/internal/cyberark/identity/cmd/testidentity/main.go @@ -2,25 +2,32 @@ package main import ( "context" + "crypto/x509" "flag" "fmt" "os" "os/signal" + "github.com/jetstack/venafi-connection-lib/http_client" + "k8s.io/client-go/transport" "k8s.io/klog/v2" "github.com/jetstack/preflight/pkg/internal/cyberark/identity" "github.com/jetstack/preflight/pkg/internal/cyberark/servicediscovery" + "github.com/jetstack/preflight/pkg/version" ) // This is a trivial CLI application for testing our identity client end-to-end. // It's not intended for distribution; it simply allows us to run our client and check // the login is successful. - +// +// To test against a tenant on the integration platform, set: +// ARK_DISCOVERY_API=https://platform-discovery.integration-cyberark.cloud/api/v2 const ( - subdomainFlag = "subdomain" - usernameFlag = "username" - passwordEnv = "TESTIDENTITY_PASSWORD" + subdomainFlag = "subdomain" + usernameFlag = "username" + passwordEnv = "ARK_SECRET" + serviceDiscoveryAPIEnv = "ARK_DISCOVERY_API" ) var ( @@ -41,16 +48,30 @@ func run(ctx context.Context) error { if password == "" { return fmt.Errorf("no password provided in %s", passwordEnv) } - sdClient := servicediscovery.New(servicediscovery.WithIntegrationEndpoint()) - client, err := identity.NewWithDiscoveryClient(ctx, sdClient, subdomain) + serviceDiscoveryAPI := os.Getenv(serviceDiscoveryAPIEnv) + if serviceDiscoveryAPI == "" { + serviceDiscoveryAPI = servicediscovery.ProdDiscoveryEndpoint + } + + var rootCAs *x509.CertPool + httpClient := http_client.NewDefaultClient(version.UserAgent(), rootCAs) + httpClient.Transport = transport.NewDebuggingRoundTripper(httpClient.Transport, transport.DebugByContext) + + sdClient := servicediscovery.New( + servicediscovery.WithHTTPClient(httpClient), + servicediscovery.WithCustomEndpoint(serviceDiscoveryAPI), + ) + identityAPI, err := sdClient.DiscoverIdentityAPIURL(ctx, subdomain) if err != nil { - return err + return fmt.Errorf("while performing service discovery: %s", err) } + client := identity.New(httpClient, identityAPI, subdomain) + err = client.LoginUsernamePassword(ctx, username, []byte(password)) if err != nil { - return err + return fmt.Errorf("while performing login with username and password: %s", err) } return nil @@ -61,7 +82,7 @@ func main() { flagSet := flag.NewFlagSet("test", flag.ExitOnError) klog.InitFlags(flagSet) - _ = flagSet.Parse([]string{"--v", "5"}) + _ = flagSet.Parse([]string{"--v", "6"}) logger := klog.Background() diff --git a/pkg/internal/cyberark/identity/identity.go b/pkg/internal/cyberark/identity/identity.go index 430fcc1f..1e48eeb1 100644 --- a/pkg/internal/cyberark/identity/identity.go +++ b/pkg/internal/cyberark/identity/identity.go @@ -12,10 +12,8 @@ import ( "time" "github.com/cenkalti/backoff/v5" - "k8s.io/client-go/transport" "k8s.io/klog/v2" - "github.com/jetstack/preflight/pkg/internal/cyberark/servicediscovery" "github.com/jetstack/preflight/pkg/logs" "github.com/jetstack/preflight/pkg/version" ) @@ -177,10 +175,9 @@ type advanceAuthenticationResponseResult struct { // Client is an client for interacting with the CyberArk Identity API and performing a login using a username and password. // For context on the behaviour of this client, see the Python SDK: https://github.com/cyberark/ark-sdk-python/blob/3be12c3f2d3a2d0407025028943e584b6edc5996/ark_sdk_python/auth/identity/ark_identity.py type Client struct { - client *http.Client - - endpoint string - subdomain string + httpClient *http.Client + baseURL string + subdomain string tokenCached token tokenCachedMutex sync.Mutex @@ -190,39 +187,15 @@ type Client struct { type token string // New returns an initialized CyberArk Identity client using a default service discovery client. -// NB: This function performs service discovery when called, in order to ensure that all Identity -// clients are created with a valid Identity API URL. This function blocks on the network call to -// the discovery service. -func New(ctx context.Context, subdomain string) (*Client, error) { - return NewWithDiscoveryClient(ctx, servicediscovery.New(), subdomain) -} - -// NewWithDiscoveryClient returns an initialized CyberArk Identity client using the given service discovery client. -// NB: This function performs service discovery when called, in order to ensure that all Identity -// clients are created with a valid Identity API URL. This function blocks on the network call to -// the discovery service. -func NewWithDiscoveryClient(ctx context.Context, discoveryClient *servicediscovery.Client, subdomain string) (*Client, error) { - if discoveryClient == nil { - return nil, fmt.Errorf("must provide a non-nil discovery client to the Identity Client") - } - - endpoint, err := discoveryClient.DiscoverIdentityAPIURL(ctx, subdomain) - if err != nil { - return nil, err - } - +func New(httpClient *http.Client, baseURL string, subdomain string) *Client { return &Client{ - client: &http.Client{ - Timeout: 10 * time.Second, - Transport: transport.NewDebuggingRoundTripper(http.DefaultTransport, transport.DebugByContext), - }, - - endpoint: endpoint, - subdomain: subdomain, + httpClient: httpClient, + baseURL: baseURL, + subdomain: subdomain, tokenCached: "", tokenCachedMutex: sync.Mutex{}, - }, nil + } } // LoginUsernamePassword performs a blocking call to fetch an auth token from CyberArk Identity using the given username and password. @@ -282,7 +255,7 @@ func (c *Client) doStartAuthentication(ctx context.Context, username string) (ad return response, fmt.Errorf("failed to marshal JSON for request to StartAuthentication endpoint: %s", err) } - endpoint, err := url.JoinPath(c.endpoint, "Security", "StartAuthentication") + endpoint, err := url.JoinPath(c.baseURL, "Security", "StartAuthentication") if err != nil { return response, fmt.Errorf("failed to create URL for request to CyberArk Identity StartAuthentication: %s", err) } @@ -294,7 +267,7 @@ func (c *Client) doStartAuthentication(ctx context.Context, username string) (ad setIdentityHeaders(request) - httpResponse, err := c.client.Do(request) + httpResponse, err := c.httpClient.Do(request) if err != nil { return response, fmt.Errorf("failed to perform HTTP request to start authentication: %s", err) } @@ -391,7 +364,7 @@ func (c *Client) doAdvanceAuthentication(ctx context.Context, username string, p return backoff.Permanent(fmt.Errorf("failed to marshal JSON for request to AdvanceAuthentication endpoint: %s", err)) } - endpoint, err := url.JoinPath(c.endpoint, "Security", "AdvanceAuthentication") + endpoint, err := url.JoinPath(c.baseURL, "Security", "AdvanceAuthentication") if err != nil { return backoff.Permanent(fmt.Errorf("failed to create URL for request to CyberArk Identity AdvanceAuthentication: %s", err)) } @@ -403,7 +376,7 @@ func (c *Client) doAdvanceAuthentication(ctx context.Context, username string, p setIdentityHeaders(request) - httpResponse, err := c.client.Do(request) + httpResponse, err := c.httpClient.Do(request) if err != nil { return fmt.Errorf("failed to perform HTTP request to advance authentication: %s", err) } diff --git a/pkg/internal/cyberark/identity/mock.go b/pkg/internal/cyberark/identity/mock.go index d57d2c38..991b6733 100644 --- a/pkg/internal/cyberark/identity/mock.go +++ b/pkg/internal/cyberark/identity/mock.go @@ -6,6 +6,9 @@ import ( "fmt" "net/http" "net/http/httptest" + "testing" + + "k8s.io/client-go/transport" "github.com/jetstack/preflight/pkg/version" @@ -51,22 +54,17 @@ var ( advanceAuthenticationFailureResponse string ) -type mockIdentityServer struct { - Server *httptest.Server -} +type mockIdentityServer struct{} -// MockIdentityServer returns a mocked CyberArk Identity server. -// The returned server should be Closed by the caller after use. -func MockIdentityServer() *mockIdentityServer { +// MockIdentityServer returns a URL of a mocked CyberArk identity server and an +// HTTP client with the CA certs needed to connect to it.. +func MockIdentityServer(t *testing.T) (string, *http.Client) { mis := &mockIdentityServer{} - - mis.Server = httptest.NewServer(mis) - - return mis -} - -func (mis *mockIdentityServer) Close() { - mis.Server.Close() + server := httptest.NewTLSServer(mis) + t.Cleanup(server.Close) + httpClient := server.Client() + httpClient.Transport = transport.NewDebuggingRoundTripper(httpClient.Transport, transport.DebugByContext) + return server.URL, httpClient } func (mis *mockIdentityServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { diff --git a/pkg/internal/cyberark/identity/start_authentication_test.go b/pkg/internal/cyberark/identity/start_authentication_test.go index a25ad139..17dd85eb 100644 --- a/pkg/internal/cyberark/identity/start_authentication_test.go +++ b/pkg/internal/cyberark/identity/start_authentication_test.go @@ -40,19 +40,9 @@ func Test_IdentityStartAuthentication(t *testing.T) { t.Run(name, func(t *testing.T) { ctx := t.Context() - identityServer := MockIdentityServer() - defer identityServer.Close() + identityServer, httpClient := MockIdentityServer(t) - mockDiscoveryServer := servicediscovery.MockDiscoveryServerWithCustomAPIURL(identityServer.Server.URL) - defer mockDiscoveryServer.Close() - - discoveryClient := servicediscovery.New(servicediscovery.WithCustomEndpoint(mockDiscoveryServer.Server.URL)) - - client, err := NewWithDiscoveryClient(ctx, discoveryClient, servicediscovery.MockDiscoverySubdomain) - if err != nil { - t.Errorf("failed to create identity client: %s", err) - return - } + client := New(httpClient, identityServer, servicediscovery.MockDiscoverySubdomain) advanceBody, err := client.doStartAuthentication(ctx, testSpec.username) if err != nil { diff --git a/pkg/internal/cyberark/servicediscovery/discovery.go b/pkg/internal/cyberark/servicediscovery/discovery.go index 5c2cf98a..d99ddc79 100644 --- a/pkg/internal/cyberark/servicediscovery/discovery.go +++ b/pkg/internal/cyberark/servicediscovery/discovery.go @@ -15,8 +15,7 @@ import ( ) const ( - prodDiscoveryEndpoint = "https://platform-discovery.cyberark.cloud/api/v2/" - integrationDiscoveryEndpoint = "https://platform-discovery.integration-cyberark.cloud/api/v2/" + ProdDiscoveryEndpoint = "https://platform-discovery.cyberark.cloud/api/v2/" // identityServiceName is the name of the identity service we're looking for in responses from the Service Discovery API // We were told to use the identity_administration field, not the identity_user_portal field. @@ -45,13 +44,6 @@ func WithHTTPClient(httpClient *http.Client) ClientOpt { } } -// WithIntegrationEndpoint sets the discovery client to use the integration testing endpoint rather than production -func WithIntegrationEndpoint() ClientOpt { - return func(c *Client) { - c.endpoint = integrationDiscoveryEndpoint - } -} - // WithCustomEndpoint sets the endpoint to a custom URL without checking that the URL is a CyberArk Service Discovery // server. func WithCustomEndpoint(endpoint string) ClientOpt { @@ -67,7 +59,7 @@ func New(clientOpts ...ClientOpt) *Client { Timeout: 10 * time.Second, Transport: transport.NewDebuggingRoundTripper(http.DefaultTransport, transport.DebugByContext), }, - endpoint: prodDiscoveryEndpoint, + endpoint: ProdDiscoveryEndpoint, } for _, opt := range clientOpts {