diff --git a/internal/pkg/cli/errors.go b/internal/pkg/cli/errors.go index 58561972c52..eb89caa4017 100644 --- a/internal/pkg/cli/errors.go +++ b/internal/pkg/cli/errors.go @@ -173,9 +173,11 @@ func (e *errTaskRoleRetrievalFailed) Error() string { } func (e *errTaskRoleRetrievalFailed) RecommendActions() string { - return fmt.Sprintf(`TaskRole retrieval failed. You can manually add permissions for your account to assume TaskRole by adding the following YAML override to your service: + return fmt.Sprintf(`TaskRole retrieval failed. If your containers don't require the TaskRole for local testing, you can use %s to disable this feature. +If you require the TaskRole, you can manually add permissions for your account to assume TaskRole by adding the following YAML override to your service: %s For more information on YAML overrides see %s`, + color.HighlightCode(`copilot run local --use-task-role=false`), color.HighlightCodeBlock(`- op: add path: /Resources/TaskRole/Properties/AssumeRolePolicyDocument/Statement/- value: diff --git a/internal/pkg/cli/run_local.go b/internal/pkg/cli/run_local.go index 900c4e80667..3fd82a913e2 100644 --- a/internal/pkg/cli/run_local.go +++ b/internal/pkg/cli/run_local.go @@ -4,9 +4,12 @@ package cli import ( + "bytes" "context" + "encoding/json" "errors" "fmt" + "io" "net" "os" "os/signal" @@ -65,6 +68,12 @@ const ( workloadAskPrompt = "Which workload would you like to run locally?" ) +const ( + // Command to retrieve container credentials with ecs exec. See more at https://docs.aws.amazon.com/AmazonECS/latest/developerguide/task-iam-roles.html. + // Example output: {"AccessKeyId":"ACCESS_KEY_ID","Expiration":"EXPIRATION_DATE","RoleArn":"TASK_ROLE_ARN","SecretAccessKey":"SECRET_ACCESS_KEY","Token":"SECURITY_TOKEN_STRING"} + curlContainerCredentialsCmd = "curl 169.254.170.2$AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" +) + type containerOrchestrator interface { Start() <-chan error RunTask(orchestrator.Task, ...orchestrator.RunTaskOption) @@ -109,6 +118,7 @@ type runLocalOpts struct { sel deploySelector ecsClient ecsClient + ecsExecutor ecsCommandExecutor ssm secretGetter secretsManager secretGetter sessProvider sessionProvider @@ -133,6 +143,9 @@ type runLocalOpts struct { labeledTermPrinter func(fw syncbuffer.FileWriter, bufs []*syncbuffer.LabeledSyncBuffer, opts ...syncbuffer.LabeledTermPrinterOption) clideploy.LabeledTermPrinter unmarshal func([]byte) (manifest.DynamicWorkload, error) newInterpolator func(app, env string) interpolator + + captureStdout func() (io.Reader, error) + releaseStdout func() } func newRunLocalOpts(vars runLocalVars) (*runLocalOpts, error) { @@ -184,6 +197,7 @@ func newRunLocalOpts(vars runLocalVars) (*runLocalOpts, error) { // so use the default sess and *hope* they have permissions. o.ecsClient = ecs.New(o.envManagerSess) o.ssm = ssm.New(o.envManagerSess) + o.ecsExecutor = awsecs.New(o.envManagerSess) o.secretsManager = secretsmanager.New(defaultSessEnvRegion) resources, err := cloudformation.New(o.sess, cloudformation.WithProgressTracker(os.Stderr)).GetAppResourcesByRegion(o.targetApp, o.targetEnv.Region) @@ -256,6 +270,32 @@ func newRunLocalOpts(vars runLocalVars) (*runLocalOpts, error) { o.newRecursiveWatcher = func() (recursiveWatcher, error) { return file.NewRecursiveWatcher(0) } + + // Capture stdout by replacing it with a piped writer and returning an attached io.Reader. + // Functions are concurrency safe and idempotent. + var mu sync.Mutex + var savedWriter, savedStdout *os.File + savedStdout = os.Stdout + o.captureStdout = func() (io.Reader, error) { + if savedWriter != nil { + savedWriter.Close() + } + pipeReader, pipeWriter, err := os.Pipe() + if err != nil { + return nil, err + } + mu.Lock() + defer mu.Unlock() + savedWriter = pipeWriter + os.Stdout = savedWriter + return (io.Reader)(pipeReader), nil + } + o.releaseStdout = func() { + mu.Lock() + defer mu.Unlock() + os.Stdout = savedStdout + savedWriter.Close() + } return o, nil } @@ -666,7 +706,102 @@ func (o *runLocalOpts) taskRoleCredentials(ctx context.Context) (map[string]stri // ecsExecMethod tries to use ECS Exec to retrive credentials from running container ecsExecMethod := func() (map[string]string, error) { - return nil, errors.New("ecs exec method not implemented") + svcDesc, err := o.ecsClient.DescribeService(o.appName, o.envName, o.wkldName) + if err != nil { + return nil, fmt.Errorf("describe ECS service for %s in environment %s: %w", o.wkldName, o.envName, err) + } + + stdoutReader, err := o.captureStdout() + if err != nil { + return nil, err + } + defer o.releaseStdout() + + // try exec on each container within the service + var wg sync.WaitGroup + containerErr := make(chan error) + for _, task := range svcDesc.Tasks { + taskID, err := awsecs.TaskID(aws.StringValue(task.TaskArn)) + if err != nil { + return nil, err + } + + for _, container := range task.Containers { + wg.Add(1) + containerName := aws.StringValue(container.Name) + go func() { + defer wg.Done() + err := o.ecsExecutor.ExecuteCommand(awsecs.ExecuteCommandInput{ + Cluster: svcDesc.ClusterName, + Command: fmt.Sprintf("/bin/sh -c %q\n", curlContainerCredentialsCmd), + Task: taskID, + Container: containerName, + }) + if err != nil { + containerErr <- fmt.Errorf("container %s in task %s: %w", containerName, taskID, err) + } + }() + } + } + + // wait for containers to finish and reset stdout + containersFinished := make(chan struct{}) + go func() { + wg.Wait() + o.releaseStdout() + close(containersFinished) + }() + + type containerCredentialsOutput struct { + AccessKeyId string + SecretAccessKey string + Token string + } + + // parse stdout to try and find credentials + credsResult := make(chan map[string]string) + parseErr := make(chan error) + go func() { + select { + case <-containersFinished: + buf, err := io.ReadAll(stdoutReader) + if err != nil { + parseErr <- err + return + } + lines := bytes.Split(buf, []byte("\n")) + var creds containerCredentialsOutput + for _, line := range lines { + err := json.Unmarshal(line, &creds) + if err != nil { + continue + } + credsResult <- map[string]string{ + "AWS_ACCESS_KEY_ID": creds.AccessKeyId, + "AWS_SECRET_ACCESS_KEY": creds.SecretAccessKey, + "AWS_SESSION_TOKEN": creds.Token, + } + return + } + parseErr <- errors.New("all containers failed to retrieve credentials") + case <-ctx.Done(): + return + } + }() + + var containerErrs []error + for { + select { + case creds := <-credsResult: + return creds, nil + case <-ctx.Done(): + return nil, ctx.Err() + case err := <-parseErr: + return nil, errors.Join(append([]error{err}, containerErrs...)...) + case err := <-containerErr: + containerErrs = append(containerErrs, err) + } + } } credentialsChain := []func() (map[string]string, error){ @@ -674,14 +809,19 @@ func (o *runLocalOpts) taskRoleCredentials(ctx context.Context) (map[string]stri ecsExecMethod, } + credentialsChainWrappedErrs := []string{ + "assume role", + "ecs exec", + } + // return TaskRole credentials from first successful method var errs []error - for _, method := range credentialsChain { + for errIndex, method := range credentialsChain { vars, err := method() if err == nil { return vars, nil } - errs = append(errs, err) + errs = append(errs, fmt.Errorf("%s: %w", credentialsChainWrappedErrs[errIndex], err)) } return nil, &errTaskRoleRetrievalFailed{errs} @@ -1056,6 +1196,7 @@ func BuildRunLocalCmd() *cobra.Command { cmd.Flags().StringVarP(&vars.envName, envFlag, envFlagShort, "", envFlagDescription) cmd.Flags().StringVarP(&vars.appName, appFlag, appFlagShort, tryReadingAppName(), appFlagDescription) cmd.Flags().BoolVar(&vars.watch, watchFlag, false, watchFlagDescription) + cmd.Flags().BoolVar(&vars.useTaskRole, useTaskRoleFlag, true, useTaskRoleFlagDescription) cmd.Flags().Var(&vars.portOverrides, portOverrideFlag, portOverridesFlagDescription) cmd.Flags().StringToStringVar(&vars.envOverrides, envVarOverrideFlag, nil, envVarOverrideFlagDescription) cmd.Flags().BoolVar(&vars.proxy, proxyFlag, false, proxyFlagDescription) diff --git a/internal/pkg/cli/run_local_test.go b/internal/pkg/cli/run_local_test.go index 9e7313cf46b..fba711e60ce 100644 --- a/internal/pkg/cli/run_local_test.go +++ b/internal/pkg/cli/run_local_test.go @@ -7,6 +7,8 @@ import ( "context" "errors" "fmt" + "io" + "strings" "syscall" "testing" @@ -201,6 +203,7 @@ func TestRunLocalOpts_Ask(t *testing.T) { type runLocalExecuteMocks struct { ecsClient *mocks.MockecsClient + ecsExecutor *mocks.MockecsCommandExecutor store *mocks.Mockstore sessCreds credentials.Provider sessProvider *mocks.MocksessionProvider @@ -432,7 +435,7 @@ func TestRunLocalOpts_Execute(t *testing.T) { "AWS_SESSION_TOKEN": "myEnvToken", }, } - expectedTaskRoleTask := orchestrator.Task{ + expectedTaskWithRegion := orchestrator.Task{ Containers: map[string]orchestrator.ContainerDefinition{ "foo": { ImageURI: "image1", @@ -441,9 +444,9 @@ func TestRunLocalOpts_Execute(t *testing.T) { }, Secrets: map[string]string{ "SHARED_SECRET": "secretvalue", - "AWS_ACCESS_KEY_ID": "taskRoleID", - "AWS_SECRET_ACCESS_KEY": "taskRoleSecret", - "AWS_SESSION_TOKEN": "taskRoleToken", + "AWS_ACCESS_KEY_ID": "myID", + "AWS_SECRET_ACCESS_KEY": "mySecret", + "AWS_SESSION_TOKEN": "myToken", "AWS_DEFAULT_REGION": testRegion, "AWS_REGION": testRegion, }, @@ -463,9 +466,9 @@ func TestRunLocalOpts_Execute(t *testing.T) { }, Secrets: map[string]string{ "SHARED_SECRET": "secretvalue", - "AWS_ACCESS_KEY_ID": "taskRoleID", - "AWS_SECRET_ACCESS_KEY": "taskRoleSecret", - "AWS_SESSION_TOKEN": "taskRoleToken", + "AWS_ACCESS_KEY_ID": "myID", + "AWS_SECRET_ACCESS_KEY": "mySecret", + "AWS_SESSION_TOKEN": "myToken", "AWS_DEFAULT_REGION": testRegion, "AWS_REGION": testRegion, }, @@ -488,6 +491,7 @@ func TestRunLocalOpts_Execute(t *testing.T) { inputWatch bool inputTaskRole bool inputProxy bool + inputReader io.Reader buildImagesError error setupMocks func(t *testing.T, m *runLocalExecuteMocks) @@ -522,14 +526,35 @@ func TestRunLocalOpts_Execute(t *testing.T) { inputWkldName: testWkldName, inputEnvName: testEnvName, inputTaskRole: true, + inputReader: strings.NewReader("some error"), setupMocks: func(t *testing.T, m *runLocalExecuteMocks) { m.ecsClient.EXPECT().TaskDefinition(testAppName, testEnvName, testWkldName).Return(taskDef, nil) m.ssm.EXPECT().GetSecretValue(gomock.Any(), "mysecret").Return("secretvalue", nil) m.ecsClient.EXPECT().TaskDefinition(testAppName, testEnvName, testWkldName).Return(taskDef, nil) m.sessProvider.EXPECT().FromRole("mock-arn", testRegion).Return(nil, errors.New("some error")) + m.ecsClient.EXPECT().DescribeService(testAppName, testEnvName, testWkldName).Return(&ecs.ServiceDesc{ + Tasks: []*awsecs.Task{ + { + TaskArn: aws.String("arn:aws:ecs:us-west-2:123456789:task/clusterName/taskName"), + Containers: []*sdkecs.Container{ + { + RuntimeId: aws.String("runtime-id"), + LastStatus: aws.String("RUNNING"), + ManagedAgents: []*sdkecs.ManagedAgent{ + { + Name: aws.String("ExecuteCommandAgent"), + LastStatus: aws.String("RUNNING"), + }, + }, + }, + }, + }, + }, + }, nil) + m.ecsExecutor.EXPECT().ExecuteCommand(gomock.Any()).Return(nil) }, - wantedError: errors.New(`get task: retrieve task role credentials: some error -ecs exec method not implemented`), + wantedError: errors.New(`get task: retrieve task role credentials: assume role: some error +ecs exec: all containers failed to retrieve credentials`), }, "error reading workload manifest": { inputAppName: testAppName, @@ -807,7 +832,7 @@ ecs exec method not implemented`), } }, }, - "success, one run task call, taskrole assumerole method": { + "success, one run task call, task role assume role method": { inputAppName: testAppName, inputWkldName: testWkldName, inputEnvName: testEnvName, @@ -818,7 +843,7 @@ ecs exec method not implemented`), m.ecsClient.EXPECT().TaskDefinition(testAppName, testEnvName, testWkldName).Return(taskDef, nil) taskRoleSess := &session.Session{ Config: &aws.Config{ - Credentials: credentials.NewStaticCredentials("taskRoleID", "taskRoleSecret", "taskRoleToken"), + Credentials: credentials.NewStaticCredentials("myID", "mySecret", "myToken"), Region: aws.String(testRegion), }, } @@ -832,7 +857,55 @@ ecs exec method not implemented`), return errCh } m.orchestrator.RunTaskFn = func(task orchestrator.Task, opts ...orchestrator.RunTaskOption) { - require.Equal(t, expectedTaskRoleTask, task) + require.Equal(t, expectedTaskWithRegion, task) + } + m.orchestrator.StopFn = func() { + require.Len(t, errCh, 0) + close(errCh) + } + }, + }, + "success, one run task call, task role ecs exec method": { + inputAppName: testAppName, + inputWkldName: testWkldName, + inputEnvName: testEnvName, + inputTaskRole: true, + inputReader: strings.NewReader(`{"AccessKeyId":"myID","SecretAccessKey":"mySecret","Token":"myToken"}`), + setupMocks: func(t *testing.T, m *runLocalExecuteMocks) { + m.ecsClient.EXPECT().TaskDefinition(testAppName, testEnvName, testWkldName).Return(taskDef, nil) + m.ssm.EXPECT().GetSecretValue(gomock.Any(), "mysecret").Return("secretvalue", nil) + m.ecsClient.EXPECT().TaskDefinition(testAppName, testEnvName, testWkldName).Return(taskDef, nil) + m.sessProvider.EXPECT().FromRole("mock-arn", testRegion).Return(nil, errors.New("some error")) + m.ecsClient.EXPECT().DescribeService(testAppName, testEnvName, testWkldName).Return(&ecs.ServiceDesc{ + Tasks: []*awsecs.Task{ + { + TaskArn: aws.String("arn:aws:ecs:us-west-2:123456789:task/clusterName/taskName"), + Containers: []*sdkecs.Container{ + { + RuntimeId: aws.String("runtime-id"), + LastStatus: aws.String("RUNNING"), + ManagedAgents: []*sdkecs.ManagedAgent{ + { + Name: aws.String("ExecuteCommandAgent"), + LastStatus: aws.String("RUNNING"), + }, + }, + }, + }, + }, + }, + }, nil) + m.ecsExecutor.EXPECT().ExecuteCommand(gomock.Any()).Return(nil) + m.ws.EXPECT().ReadWorkloadManifest(testWkldName).Return([]byte(""), nil) + m.interpolator.EXPECT().Interpolate("").Return("", nil) + + errCh := make(chan error, 1) + m.orchestrator.StartFn = func() <-chan error { + errCh <- errors.New("some error") + return errCh + } + m.orchestrator.RunTaskFn = func(task orchestrator.Task, opts ...orchestrator.RunTaskOption) { + require.Equal(t, expectedTask, task) } m.orchestrator.StopFn = func() { require.Len(t, errCh, 0) @@ -1053,6 +1126,7 @@ ecs exec method not implemented`), defer ctrl.Finish() m := &runLocalExecuteMocks{ ecsClient: mocks.NewMockecsClient(ctrl), + ecsExecutor: mocks.NewMockecsCommandExecutor(ctrl), ssm: mocks.NewMocksecretGetter(ctrl), secretsManager: mocks.NewMocksecretGetter(ctrl), store: mocks.NewMockstore(ctrl), @@ -1103,6 +1177,7 @@ ecs exec method not implemented`), }, ws: m.ws, ecsClient: m.ecsClient, + ecsExecutor: m.ecsExecutor, ssm: m.ssm, secretsManager: m.secretsManager, store: m.store, @@ -1130,6 +1205,10 @@ ecs exec method not implemented`), newRecursiveWatcher: func() (recursiveWatcher, error) { return m.watcher, nil }, + captureStdout: func() (io.Reader, error) { + return tc.inputReader, nil + }, + releaseStdout: func() {}, } // WHEN err := opts.Execute()