Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module github.com/aws/copilot-cli

go 1.20
go 1.21

require (
github.com/AlecAivazis/survey/v2 v2.3.2
Expand Down
61 changes: 51 additions & 10 deletions internal/pkg/aws/ecs/ecs.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ type api interface {
StopTask(input *ecs.StopTaskInput) (*ecs.StopTaskOutput, error)
UpdateService(input *ecs.UpdateServiceInput) (*ecs.UpdateServiceOutput, error)
WaitUntilTasksRunning(input *ecs.DescribeTasksInput) error
ListServicesByNamespacePages(input *ecs.ListServicesByNamespaceInput, fn func(*ecs.ListServicesByNamespaceOutput, bool) bool) error
}

type ssmSessionStarter interface {
Expand Down Expand Up @@ -100,20 +101,60 @@ func (e *ECS) TaskDefinition(taskDefName string) (*TaskDefinition, error) {

// Service calls ECS API and returns the specified service running in the cluster.
func (e *ECS) Service(clusterName, serviceName string) (*Service, error) {
resp, err := e.client.DescribeServices(&ecs.DescribeServicesInput{
Cluster: aws.String(clusterName),
Services: aws.StringSlice([]string{serviceName}),
})
svcs, err := e.Services(clusterName, serviceName)
if err != nil {
return nil, fmt.Errorf("describe service %s: %w", serviceName, err)
return nil, err
}
for _, service := range resp.Services {
if aws.StringValue(service.ServiceName) == serviceName {
svc := Service(*service)
return &svc, nil
if aws.StringValue(svcs[0].ServiceName) != serviceName {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we check the length just in case 🥺

return nil, fmt.Errorf("cannot find service %s", serviceName)
}

return svcs[0], nil
}

// Services calls the ECS API and returns all of the specified services running in cluster.
func (e *ECS) Services(cluster string, services ...string) ([]*Service, error) {
var svcs []*Service

for i := 0; i < len(services); i += 10 {
split := services[i:min(10+i, len(services))]

resp, err := e.client.DescribeServices(&ecs.DescribeServicesInput{
Cluster: aws.String(cluster),
Services: aws.StringSlice(split),
})
switch {
case err != nil:
return nil, fmt.Errorf("describe services: %w", err)
case len(resp.Failures) > 0:
return nil, fmt.Errorf("describe services: %s", resp.Failures[0].String())
case len(resp.Services) != len(split):
return nil, fmt.Errorf("describe services: got %v services, but expected %v", len(resp.Services), len(split))
}

for j := range resp.Services {
svc := Service(*resp.Services[j])
svcs = append(svcs, &svc)
}
}
return nil, fmt.Errorf("cannot find service %s", serviceName)

return svcs, nil
}

// ListServicesByNamespace returns a list of service ARNs of services that
// are in the given namespace.
func (e *ECS) ListServicesByNamespace(namespace string) ([]string, error) {
var arns []string
err := e.client.ListServicesByNamespacePages(&ecs.ListServicesByNamespaceInput{
Namespace: aws.String(namespace),
}, func(resp *ecs.ListServicesByNamespaceOutput, b bool) bool {
arns = append(arns, aws.StringValueSlice(resp.ServiceArns)...)
return true
})
if err != nil {
return nil, err
}
return arns, nil
}

// UpdateServiceOpts sets the optional parameter for UpdateService.
Expand Down
238 changes: 236 additions & 2 deletions internal/pkg/aws/ecs/ecs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ func TestECS_Service(t *testing.T) {
Services: aws.StringSlice([]string{"mockService"}),
}).Return(nil, errors.New("some error"))
},
wantErr: fmt.Errorf("describe service mockService: some error"),
wantErr: fmt.Errorf("describe services: some error"),
},
"errors if failed to find the service": {
clusterName: "mockCluster",
Expand Down Expand Up @@ -187,6 +187,240 @@ func TestECS_Service(t *testing.T) {
}
}

func TestECS_Services(t *testing.T) {
testCases := map[string]struct {
clusterName string
services []string
mockECSClient func(m *mocks.Mockapi)

wantErr string
wantSvcs []*Service
}{
"error if api call error": {
clusterName: "mockCluster",
services: []string{"1"},
mockECSClient: func(m *mocks.Mockapi) {
m.EXPECT().DescribeServices(&ecs.DescribeServicesInput{
Cluster: aws.String("mockCluster"),
Services: aws.StringSlice([]string{"1"}),
}).Return(nil, errors.New("some error"))
},
wantErr: "describe services: some error",
},
"error if api returns failure": {
clusterName: "mockCluster",
services: []string{"1"},
mockECSClient: func(m *mocks.Mockapi) {
m.EXPECT().DescribeServices(&ecs.DescribeServicesInput{
Cluster: aws.String("mockCluster"),
Services: aws.StringSlice([]string{"1"}),
}).Return(&ecs.DescribeServicesOutput{
Failures: []*ecs.Failure{
{
Arn: aws.String("arn:1"),
Reason: aws.String("some error"),
},
},
}, nil)
},
wantErr: `describe services: {
Arn: "arn:1",
Reason: "some error"
}`,
},
"error if api returns incorrect count": {
clusterName: "mockCluster",
services: []string{"1", "2"},
mockECSClient: func(m *mocks.Mockapi) {
m.EXPECT().DescribeServices(&ecs.DescribeServicesInput{
Cluster: aws.String("mockCluster"),
Services: aws.StringSlice([]string{"1", "2"}),
}).Return(&ecs.DescribeServicesOutput{
Services: []*ecs.Service{
{
ServiceName: aws.String("1"),
},
},
}, nil)
},
wantErr: "describe services: got 1 services, but expected 2",
},
"success with > 10": {
clusterName: "mockCluster",
services: []string{
"1", "2", "3", "4", "5", "6", "7", "8", "9", "10",
"11",
},
mockECSClient: func(m *mocks.Mockapi) {
m.EXPECT().DescribeServices(&ecs.DescribeServicesInput{
Cluster: aws.String("mockCluster"),
Services: aws.StringSlice([]string{"1", "2", "3", "4", "5", "6", "7", "8", "9", "10"}),
}).Return(&ecs.DescribeServicesOutput{
Services: []*ecs.Service{
{
ServiceName: aws.String("1"),
},
{
ServiceName: aws.String("2"),
},
{
ServiceName: aws.String("3"),
},
{
ServiceName: aws.String("4"),
},
{
ServiceName: aws.String("5"),
},
{
ServiceName: aws.String("6"),
},
{
ServiceName: aws.String("7"),
},
{
ServiceName: aws.String("8"),
},
{
ServiceName: aws.String("9"),
},
{
ServiceName: aws.String("10"),
},
},
}, nil)
m.EXPECT().DescribeServices(&ecs.DescribeServicesInput{
Cluster: aws.String("mockCluster"),
Services: aws.StringSlice([]string{"11"}),
}).Return(&ecs.DescribeServicesOutput{
Services: []*ecs.Service{
{
ServiceName: aws.String("11"),
},
},
}, nil)
},
wantSvcs: []*Service{
{
ServiceName: aws.String("1"),
},
{
ServiceName: aws.String("2"),
},
{
ServiceName: aws.String("3"),
},
{
ServiceName: aws.String("4"),
},
{
ServiceName: aws.String("5"),
},
{
ServiceName: aws.String("6"),
},
{
ServiceName: aws.String("7"),
},
{
ServiceName: aws.String("8"),
},
{
ServiceName: aws.String("9"),
},
{
ServiceName: aws.String("10"),
},
{
ServiceName: aws.String("11"),
},
},
},
}

for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
// GIVEN
ctrl := gomock.NewController(t)
defer ctrl.Finish()

mockECSClient := mocks.NewMockapi(ctrl)
tc.mockECSClient(mockECSClient)

service := ECS{
client: mockECSClient,
}

gotSvcs, gotErr := service.Services(tc.clusterName, tc.services...)

if tc.wantErr != "" {
require.EqualError(t, gotErr, tc.wantErr)
} else {
require.Equal(t, tc.wantSvcs, gotSvcs)
require.NoError(t, gotErr)
}
})
}
}

func TestECS_ListServicesByNamespace(t *testing.T) {
testCases := map[string]struct {
namespace string
mockECSClient func(m *mocks.Mockapi)

wantErr string
wantARNs []string
}{
"error if api call error": {
namespace: "mockNamespace",
mockECSClient: func(m *mocks.Mockapi) {
m.EXPECT().ListServicesByNamespacePages(&ecs.ListServicesByNamespaceInput{
Namespace: aws.String("mockNamespace"),
}, gomock.Any()).Return(errors.New("some error"))
},
wantErr: "some error",
},
"success": {
namespace: "mockNamespace",
mockECSClient: func(m *mocks.Mockapi) {
m.EXPECT().ListServicesByNamespacePages(&ecs.ListServicesByNamespaceInput{
Namespace: aws.String("mockNamespace"),
}, gomock.Any()).DoAndReturn(func(in *ecs.ListServicesByNamespaceInput, fn func(*ecs.ListServicesByNamespaceOutput, bool) bool) error {
fn(&ecs.ListServicesByNamespaceOutput{
ServiceArns: []*string{aws.String("svc1"), aws.String("svc2"), aws.String("svc3")},
}, true)
return nil
})
},
wantARNs: []string{"svc1", "svc2", "svc3"},
},
}

for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
// GIVEN
ctrl := gomock.NewController(t)
defer ctrl.Finish()

mockECSClient := mocks.NewMockapi(ctrl)
tc.mockECSClient(mockECSClient)

service := ECS{
client: mockECSClient,
}

gotARNs, gotErr := service.ListServicesByNamespace(tc.namespace)

if tc.wantErr != "" {
require.EqualError(t, gotErr, tc.wantErr)
} else {
require.Equal(t, tc.wantARNs, gotARNs)
require.NoError(t, gotErr)
}
})
}
}

func TestECS_UpdateService(t *testing.T) {
const (
clusterName = "mockCluster"
Expand Down Expand Up @@ -260,7 +494,7 @@ func TestECS_UpdateService(t *testing.T) {
Services: aws.StringSlice([]string{serviceName}),
}).Return(nil, errors.New("some error"))
},
wantErr: fmt.Errorf("wait until service mockService becomes stable: describe service mockService: some error"),
wantErr: fmt.Errorf("wait until service mockService becomes stable: describe services: some error"),
},
"success": {
forceUpdate: true,
Expand Down
14 changes: 14 additions & 0 deletions internal/pkg/aws/ecs/mocks/mock_ecs.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions internal/pkg/cli/flag.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ const (
// Run local flags
portOverrideFlag = "port-override"
envVarOverrideFlag = "env-var-override"
proxyFlag = "proxy"

// Flags for CI/CD.
githubURLFlag = "github-url"
Expand Down Expand Up @@ -320,6 +321,7 @@ Defaults to all logs. Only one of end-time / follow may be used.`
Format: [container]:KEY=VALUE. Omit container name to apply to all containers.`
portOverridesFlagDescription = `Optional. Override ports exposed by service. Format: <host port>:<service port>.
Example: --port-override 5000:80 binds localhost:5000 to the service's port 80.`
proxyFlagDescription = `Optional. Proxy outbound requests to your environment's VPC.`

svcManifestFlagDescription = `Optional. Name of the environment in which the service was deployed;
output the manifest file used for that deployment.`
Expand Down
3 changes: 2 additions & 1 deletion internal/pkg/cli/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,9 @@ type repositoryService interface {
imageBuilderPusher
}

type ecsLocalClient interface {
type ecsClient interface {
TaskDefinition(app, env, svc string) (*awsecs.TaskDefinition, error)
ServiceConnectServices(app, env, svc string) ([]*awsecs.Service, error)
}

type logEventsWriter interface {
Expand Down
Loading