diff --git a/internal/api/handler_accounts.go b/internal/api/handler_accounts.go index 98eecf67..255130e6 100644 --- a/internal/api/handler_accounts.go +++ b/internal/api/handler_accounts.go @@ -3,7 +3,9 @@ package api import ( "context" "encoding/json" + "errors" "fmt" + "slices" "strings" "time" @@ -1037,7 +1039,84 @@ func (h *Handler) deleteAccountServiceOverride(ctx context.Context, req *events. return nil, nil } +// validatePlanAccountProviders enforces the issue-#209 / spec E-4 rule +// that every account assigned to a plan must have its provider match +// one of the plan's derived providers. Returns: +// - 404 ClientError when the plan does not exist +// - 404 ClientError when an account_id does not exist (referencing +// the offending ID) +// - 400 ClientError listing every provider mismatch in one message +// (so clients fix everything in one round-trip rather than +// resubmitting to discover the next) +// - nil when all accounts match (or when the plan has no parseable +// services — defensive skip; production plans always carry at least +// one service, frontend enforces this) +// +// Pulled out of setPlanAccounts to keep that function under the gocyclo +// budget (limit 10). No business logic lives here that isn't otherwise +// described in setPlanAccounts' doc comment. +func (h *Handler) validatePlanAccountProviders(ctx context.Context, planID string, accountIDs []string) error { + plan, err := h.getPlanForAccountProviderValidation(ctx, planID) + if err != nil { + return err + } + + expected := config.DerivePlanProviders(plan) + if len(expected) == 0 { + return nil + } + + type mismatch struct { + ID string + Name string + Provider string + } + var mismatches []mismatch + for _, aid := range accountIDs { + acct, getErr := h.config.GetCloudAccount(ctx, aid) + if getErr != nil { + return fmt.Errorf("accounts: failed to get account %s: %w", aid, getErr) + } + if acct == nil { + return NewClientError(404, fmt.Sprintf("account not found: %s", aid)) + } + if !slices.Contains(expected, acct.Provider) { + mismatches = append(mismatches, mismatch{ID: aid, Name: acct.Name, Provider: acct.Provider}) + } + } + if len(mismatches) == 0 { + return nil + } + + parts := make([]string, len(mismatches)) + for i, m := range mismatches { + parts[i] = fmt.Sprintf("account %q has provider=%q, expected one of %v", + m.Name, m.Provider, expected) + } + return NewClientError(400, "plan provider mismatch: "+strings.Join(parts, "; ")) +} + +func (h *Handler) getPlanForAccountProviderValidation(ctx context.Context, planID string) (*config.PurchasePlan, error) { + plan, err := h.config.GetPurchasePlan(ctx, planID) + if err != nil { + if errors.Is(err, config.ErrNotFound) { + return nil, NewClientError(404, fmt.Sprintf("plan not found: %s", planID)) + } + return nil, fmt.Errorf("accounts: failed to get plan: %w", err) + } + if plan == nil { + return nil, NewClientError(404, fmt.Sprintf("plan not found: %s", planID)) + } + return plan, nil +} + // setPlanAccounts handles PUT /api/plans/:id/accounts. +// +// Per issue #209 / spec acceptance criterion E-4, every account assigned +// to a plan must have its provider match one of the plan's derived +// providers (extracted from plan.Services keys). Mismatches return 400 +// listing every offender; the assignment is rejected atomically (no +// partial writes). func (h *Handler) setPlanAccounts(ctx context.Context, httpReq *events.LambdaFunctionURLRequest, id string) (any, error) { if err := validateUUID(id); err != nil { return nil, err @@ -1060,7 +1139,16 @@ func (h *Handler) setPlanAccounts(ctx context.Context, httpReq *events.LambdaFun } } + // Provider-match validation (issue #209). Extracted to keep + // setPlanAccounts under the gocyclo budget (10). + if err := h.validatePlanAccountProviders(ctx, id, body.AccountIDs); err != nil { + return nil, err + } + if err := h.config.SetPlanAccounts(ctx, id, body.AccountIDs); err != nil { + if errors.Is(err, config.ErrNotFound) { + return nil, NewClientError(404, err.Error()) + } return nil, fmt.Errorf("accounts: %w", err) } diff --git a/internal/api/handler_accounts_test.go b/internal/api/handler_accounts_test.go index adb0c238..9336dc20 100644 --- a/internal/api/handler_accounts_test.go +++ b/internal/api/handler_accounts_test.go @@ -3,6 +3,7 @@ package api import ( "context" "errors" + "fmt" "testing" "github.com/LeanerCloud/CUDly/internal/accounts" @@ -485,6 +486,280 @@ func TestSetPlanAccounts_Success(t *testing.T) { assert.Nil(t, result) } +// ── Provider-validation tests for setPlanAccounts (issue #209) +// Every account assigned to a plan must have its provider match one of +// the providers derived from the plan's services map (key format +// "provider:service"). Mismatches return a single 400 listing every +// offender; the underlying store write is never invoked on failure. + +const ( + planID209 = "22222222-2222-2222-2222-222222222222" + awsAcct209 = "11111111-1111-1111-1111-111111111111" + azureAcct1 = "33333333-3333-3333-3333-333333333333" + azureAcct2 = "44444444-4444-4444-4444-444444444444" + missingAcct = "55555555-5555-5555-5555-555555555555" +) + +// awsPlan209 returns a plan whose services map yields a single derived +// provider ("aws"). Key format is "provider/service" as produced by +// buildServiceConfig. Used as the default fixture across the mismatch +// tests below. +func awsPlan209() *config.PurchasePlan { + return &config.PurchasePlan{ + ID: planID209, + Name: "AWS-only plan", + Services: map[string]config.ServiceConfig{"aws/ec2": {}}, + } +} + +func TestSetPlanAccounts_SingleMismatch(t *testing.T) { + ctx := context.Background() + mockAuth := new(MockAuthService) + setupAdminAuth(ctx, mockAuth) + + setCalled := false + store := setupAdminMock(ctx) + store.GetPurchasePlanFn = func(_ context.Context, _ string) (*config.PurchasePlan, error) { + return awsPlan209(), nil + } + store.GetCloudAccountFn = func(_ context.Context, id string) (*config.CloudAccount, error) { + return &config.CloudAccount{ID: id, Name: "prod-azure", Provider: "azure"}, nil + } + store.SetPlanAccountsFn = func(_ context.Context, _ string, _ []string) error { + setCalled = true + return nil + } + handler := &Handler{auth: mockAuth, config: store} + + body := `{"account_ids":["` + azureAcct1 + `"]}` + _, err := handler.setPlanAccounts(ctx, adminRequest(body), planID209) + require.Error(t, err) + ce, ok := IsClientError(err) + require.True(t, ok, "expected a clientError, got %T", err) + assert.Equal(t, 400, ce.code) + assert.Contains(t, ce.Error(), "prod-azure") + assert.Contains(t, ce.Error(), "azure") + assert.Contains(t, ce.Error(), "aws") + assert.False(t, setCalled, "SetPlanAccounts must NOT be called when validation fails") +} + +func TestSetPlanAccounts_MultipleMismatches(t *testing.T) { + ctx := context.Background() + mockAuth := new(MockAuthService) + setupAdminAuth(ctx, mockAuth) + + setCalled := false + store := setupAdminMock(ctx) + store.GetPurchasePlanFn = func(_ context.Context, _ string) (*config.PurchasePlan, error) { + return awsPlan209(), nil + } + store.GetCloudAccountFn = func(_ context.Context, id string) (*config.CloudAccount, error) { + switch id { + case azureAcct1: + return &config.CloudAccount{ID: id, Name: "prod-azure", Provider: "azure"}, nil + case azureAcct2: + return &config.CloudAccount{ID: id, Name: "stage-gcp", Provider: "gcp"}, nil + } + return nil, nil + } + store.SetPlanAccountsFn = func(_ context.Context, _ string, _ []string) error { + setCalled = true + return nil + } + handler := &Handler{auth: mockAuth, config: store} + + body := `{"account_ids":["` + azureAcct1 + `","` + azureAcct2 + `"]}` + _, err := handler.setPlanAccounts(ctx, adminRequest(body), planID209) + require.Error(t, err) + ce, ok := IsClientError(err) + require.True(t, ok) + assert.Equal(t, 400, ce.code) + // Both offenders named in a single error so the client gets the full + // picture in one round-trip. + assert.Contains(t, ce.Error(), "prod-azure") + assert.Contains(t, ce.Error(), "stage-gcp") + assert.Contains(t, ce.Error(), "azure") + assert.Contains(t, ce.Error(), "gcp") + assert.False(t, setCalled, "SetPlanAccounts must NOT be called when validation fails") +} + +func TestSetPlanAccounts_ValidHappyPath(t *testing.T) { + ctx := context.Background() + mockAuth := new(MockAuthService) + setupAdminAuth(ctx, mockAuth) + + var capturedIDs []string + store := setupAdminMock(ctx) + store.GetPurchasePlanFn = func(_ context.Context, _ string) (*config.PurchasePlan, error) { + return awsPlan209(), nil + } + store.GetCloudAccountFn = func(_ context.Context, id string) (*config.CloudAccount, error) { + return &config.CloudAccount{ID: id, Name: "prod-aws", Provider: "aws"}, nil + } + store.SetPlanAccountsFn = func(_ context.Context, _ string, ids []string) error { + capturedIDs = ids + return nil + } + handler := &Handler{auth: mockAuth, config: store} + + body := `{"account_ids":["` + awsAcct209 + `"]}` + result, err := handler.setPlanAccounts(ctx, adminRequest(body), planID209) + require.NoError(t, err) + assert.Nil(t, result) + assert.Equal(t, []string{awsAcct209}, capturedIDs, "SetPlanAccounts should be called with the validated IDs") +} + +func TestSetPlanAccounts_StoreNotFoundAfterValidationMapsTo404(t *testing.T) { + ctx := context.Background() + mockAuth := new(MockAuthService) + setupAdminAuth(ctx, mockAuth) + + store := setupAdminMock(ctx) + store.GetPurchasePlanFn = func(_ context.Context, _ string) (*config.PurchasePlan, error) { + return awsPlan209(), nil + } + store.GetCloudAccountFn = func(_ context.Context, id string) (*config.CloudAccount, error) { + return &config.CloudAccount{ID: id, Name: "prod-aws", Provider: "aws"}, nil + } + store.SetPlanAccountsFn = func(_ context.Context, _ string, _ []string) error { + return fmt.Errorf("%w: account %s", config.ErrNotFound, awsAcct209) + } + handler := &Handler{auth: mockAuth, config: store} + + body := `{"account_ids":["` + awsAcct209 + `"]}` + _, err := handler.setPlanAccounts(ctx, adminRequest(body), planID209) + require.Error(t, err) + ce, ok := IsClientError(err) + require.True(t, ok) + assert.Equal(t, 404, ce.code) + assert.Contains(t, ce.Error(), awsAcct209) +} + +func TestSetPlanAccounts_PlanNotFound(t *testing.T) { + ctx := context.Background() + mockAuth := new(MockAuthService) + setupAdminAuth(ctx, mockAuth) + + setCalled := false + store := setupAdminMock(ctx) + store.GetPurchasePlanFn = func(_ context.Context, _ string) (*config.PurchasePlan, error) { + return nil, nil // store-style "not found": (nil, nil) + } + store.SetPlanAccountsFn = func(_ context.Context, _ string, _ []string) error { + setCalled = true + return nil + } + handler := &Handler{auth: mockAuth, config: store} + + body := `{"account_ids":["` + awsAcct209 + `"]}` + _, err := handler.setPlanAccounts(ctx, adminRequest(body), planID209) + require.Error(t, err) + ce, ok := IsClientError(err) + require.True(t, ok) + assert.Equal(t, 404, ce.code) + assert.Contains(t, ce.Error(), planID209) + assert.False(t, setCalled, "SetPlanAccounts must NOT be called when the plan is not found") +} + +func TestSetPlanAccounts_AccountNotFound(t *testing.T) { + ctx := context.Background() + mockAuth := new(MockAuthService) + setupAdminAuth(ctx, mockAuth) + + setCalled := false + store := setupAdminMock(ctx) + store.GetPurchasePlanFn = func(_ context.Context, _ string) (*config.PurchasePlan, error) { + return awsPlan209(), nil + } + store.GetCloudAccountFn = func(_ context.Context, id string) (*config.CloudAccount, error) { + if id == missingAcct { + return nil, nil // store-style not-found + } + return &config.CloudAccount{ID: id, Name: "prod-aws", Provider: "aws"}, nil + } + store.SetPlanAccountsFn = func(_ context.Context, _ string, _ []string) error { + setCalled = true + return nil + } + handler := &Handler{auth: mockAuth, config: store} + + body := `{"account_ids":["` + missingAcct + `"]}` + _, err := handler.setPlanAccounts(ctx, adminRequest(body), planID209) + require.Error(t, err) + ce, ok := IsClientError(err) + require.True(t, ok) + assert.Equal(t, 404, ce.code) + assert.Contains(t, ce.Error(), missingAcct, "404 should reference the missing account ID") + assert.False(t, setCalled, "SetPlanAccounts must NOT be called when an account is not found") +} + +func TestSetPlanAccounts_MixedValidAndMismatch(t *testing.T) { + ctx := context.Background() + mockAuth := new(MockAuthService) + setupAdminAuth(ctx, mockAuth) + + setCalled := false + store := setupAdminMock(ctx) + store.GetPurchasePlanFn = func(_ context.Context, _ string) (*config.PurchasePlan, error) { + return awsPlan209(), nil + } + store.GetCloudAccountFn = func(_ context.Context, id string) (*config.CloudAccount, error) { + switch id { + case awsAcct209: + return &config.CloudAccount{ID: id, Name: "prod-aws", Provider: "aws"}, nil + case azureAcct1: + return &config.CloudAccount{ID: id, Name: "prod-azure", Provider: "azure"}, nil + } + return nil, nil + } + store.SetPlanAccountsFn = func(_ context.Context, _ string, _ []string) error { + setCalled = true + return nil + } + handler := &Handler{auth: mockAuth, config: store} + + body := `{"account_ids":["` + awsAcct209 + `","` + azureAcct1 + `"]}` + _, err := handler.setPlanAccounts(ctx, adminRequest(body), planID209) + require.Error(t, err) + ce, ok := IsClientError(err) + require.True(t, ok) + assert.Equal(t, 400, ce.code) + // Only the Azure account is the offender; the AWS account does not + // appear in the error. + assert.Contains(t, ce.Error(), "prod-azure") + assert.NotContains(t, ce.Error(), "prod-aws") + assert.False(t, setCalled, "SetPlanAccounts must NOT be called when even one account fails validation") +} + +func TestSetPlanAccounts_EmptyServicesSkipsValidation(t *testing.T) { + ctx := context.Background() + mockAuth := new(MockAuthService) + setupAdminAuth(ctx, mockAuth) + + var capturedIDs []string + store := setupAdminMock(ctx) + // Plan with an empty services map — derived provider set is empty; + // the validation block skips and the assignment passes through. + // Pins the defensive behaviour so a future change is conscious. + store.GetPurchasePlanFn = func(_ context.Context, _ string) (*config.PurchasePlan, error) { + return &config.PurchasePlan{ID: planID209, Name: "no-services"}, nil + } + store.GetCloudAccountFn = func(_ context.Context, id string) (*config.CloudAccount, error) { + return &config.CloudAccount{ID: id, Name: "prod-azure", Provider: "azure"}, nil + } + store.SetPlanAccountsFn = func(_ context.Context, _ string, ids []string) error { + capturedIDs = ids + return nil + } + handler := &Handler{auth: mockAuth, config: store} + + body := `{"account_ids":["` + azureAcct1 + `"]}` + result, err := handler.setPlanAccounts(ctx, adminRequest(body), planID209) + require.NoError(t, err) + assert.Nil(t, result) + assert.Equal(t, []string{azureAcct1}, capturedIDs, "empty services map → validation skipped → write proceeds") +} + func TestListPlanAccounts_Success(t *testing.T) { ctx := context.Background() mockAuth := new(MockAuthService) diff --git a/internal/api/mocks_test.go b/internal/api/mocks_test.go index 37d29da5..ce569313 100644 --- a/internal/api/mocks_test.go +++ b/internal/api/mocks_test.go @@ -27,6 +27,16 @@ type MockConfigStore struct { // CreateCloudAccountFn overrides CreateCloudAccount when non-nil (used by // org-discovery tests to capture the new rows the handler persists). CreateCloudAccountFn func(ctx context.Context, account *config.CloudAccount) error + // GetPurchasePlanFn overrides GetPurchasePlan when non-nil. Used by the + // setPlanAccounts provider-validation tests (issue #209) to seed the + // plan without registering a testify expectation — see the fall-through + // comment in GetPurchasePlan below for why the default no-expectation + // path returns a minimal stub instead of panicking via m.Called. + GetPurchasePlanFn func(ctx context.Context, planID string) (*config.PurchasePlan, error) + // SetPlanAccountsFn overrides SetPlanAccounts when non-nil. The + // provider-validation tests use it to assert whether the underlying + // store write was invoked (mismatched assignments must NOT call it). + SetPlanAccountsFn func(ctx context.Context, planID string, accountIDs []string) error } func (m *MockConfigStore) GetGlobalConfig(ctx context.Context) (*config.GlobalConfig, error) { @@ -68,7 +78,21 @@ func (m *MockConfigStore) CreatePurchasePlan(ctx context.Context, plan *config.P return args.Error(0) } +// GetPurchasePlan resolves to (in order): an explicit GetPurchasePlanFn +// override, a registered testify expectation, or a default minimal plan +// (`{ID: planID}` with empty Services). The default-fallback path lets +// tests written before the issue-#209 provider-validation block (e.g. +// TestSetPlanAccounts_Success) keep working without setting up the +// new mock call — the empty Services map trips the defensive "no +// parseable services, skip provider validation" branch in +// setPlanAccounts so behaviour is unchanged for those legacy tests. func (m *MockConfigStore) GetPurchasePlan(ctx context.Context, planID string) (*config.PurchasePlan, error) { + if m.GetPurchasePlanFn != nil { + return m.GetPurchasePlanFn(ctx, planID) + } + if !m.isExpected("GetPurchasePlan") { + return &config.PurchasePlan{ID: planID}, nil + } args := m.Called(ctx, planID) if args.Get(0) == nil { return nil, args.Error(1) @@ -292,7 +316,20 @@ func (m *MockConfigStore) DeleteAccountServiceOverride(ctx context.Context, acco func (m *MockConfigStore) ListAccountServiceOverrides(ctx context.Context, accountID string) ([]config.AccountServiceOverride, error) { return nil, nil } + +// SetPlanAccounts uses SetPlanAccountsFn when non-nil so tests can +// capture and assert on the call (the issue-#209 mismatch tests verify +// the underlying store write is NOT invoked when validation fails). +// Falls back to m.Called when a testify expectation is registered so +// .On("SetPlanAccounts", ...) works correctly. The no-op is preserved +// only when neither path applies (tests that don't care). func (m *MockConfigStore) SetPlanAccounts(ctx context.Context, planID string, accountIDs []string) error { + if m.SetPlanAccountsFn != nil { + return m.SetPlanAccountsFn(ctx, planID, accountIDs) + } + if m.isExpected("SetPlanAccounts") { + return m.Called(ctx, planID, accountIDs).Error(0) + } return nil } func (m *MockConfigStore) GetPlanAccounts(ctx context.Context, planID string) ([]config.CloudAccount, error) { diff --git a/internal/config/errors.go b/internal/config/errors.go new file mode 100644 index 00000000..8ce0cf27 --- /dev/null +++ b/internal/config/errors.go @@ -0,0 +1,6 @@ +package config + +import "errors" + +// ErrNotFound is returned when a requested config-store row does not exist. +var ErrNotFound = errors.New("not found") diff --git a/internal/config/purchase_plan_providers.go b/internal/config/purchase_plan_providers.go new file mode 100644 index 00000000..455867bc --- /dev/null +++ b/internal/config/purchase_plan_providers.go @@ -0,0 +1,29 @@ +package config + +import ( + "sort" + "strings" +) + +// DerivePlanProviders extracts the distinct set of providers a plan targets +// by parsing the keys of plan.Services. Keys are expected to use the +// "provider/service" format produced by buildServiceConfig. +func DerivePlanProviders(plan *PurchasePlan) []string { + if plan == nil { + return nil + } + seen := make(map[string]struct{}, len(plan.Services)) + for k := range plan.Services { + provider, _, ok := strings.Cut(k, "/") + if !ok || provider == "" { + continue + } + seen[provider] = struct{}{} + } + providers := make([]string, 0, len(seen)) + for provider := range seen { + providers = append(providers, provider) + } + sort.Strings(providers) + return providers +} diff --git a/internal/config/purchase_plan_providers_test.go b/internal/config/purchase_plan_providers_test.go new file mode 100644 index 00000000..ed5f0fd3 --- /dev/null +++ b/internal/config/purchase_plan_providers_test.go @@ -0,0 +1,22 @@ +package config + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestDerivePlanProviders(t *testing.T) { + plan := &PurchasePlan{ + Services: map[string]ServiceConfig{ + "aws/ec2": {}, + "aws/rds": {}, + "azure/compute": {}, + "gcp:compute": {}, + "malformed": {}, + }, + } + + assert.Equal(t, []string{"aws", "azure"}, DerivePlanProviders(plan)) + assert.Nil(t, DerivePlanProviders(nil)) +} diff --git a/internal/config/store_postgres.go b/internal/config/store_postgres.go index 36282bfc..5144a6dc 100644 --- a/internal/config/store_postgres.go +++ b/internal/config/store_postgres.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "encoding/json" + "errors" "fmt" "strings" "time" @@ -446,7 +447,7 @@ func (s *PostgresStore) GetPurchasePlan(ctx context.Context, planID string) (*Pu if err != nil { if err == pgx.ErrNoRows { - return nil, fmt.Errorf("purchase plan not found: %s", planID) + return nil, fmt.Errorf("%w: purchase plan %s", ErrNotFound, planID) } return nil, fmt.Errorf("failed to get purchase plan: %w", err) } @@ -1928,6 +1929,10 @@ func (s *PostgresStore) SetPlanAccounts(ctx context.Context, planID string, acco } defer tx.Rollback(ctx) //nolint:errcheck + if err = s.validatePlanAccountProvidersTx(ctx, tx, planID, accountIDs); err != nil { + return err + } + if _, err = tx.Exec(ctx, `DELETE FROM plan_accounts WHERE plan_id = $1`, planID); err != nil { return fmt.Errorf("failed to clear plan accounts: %w", err) } @@ -1947,6 +1952,88 @@ func (s *PostgresStore) SetPlanAccounts(ctx context.Context, planID string, acco return nil } +func (s *PostgresStore) validatePlanAccountProvidersTx(ctx context.Context, tx pgx.Tx, planID string, accountIDs []string) error { + services, err := s.getPlanServicesForShareTx(ctx, tx, planID) + if err != nil { + return err + } + if len(accountIDs) == 0 { + return nil + } + expected := DerivePlanProviders(&PurchasePlan{Services: services}) + if len(expected) == 0 { + return nil + } + + mismatches, err := s.findPlanAccountProviderMismatchesTx(ctx, tx, accountIDs, expected) + if err != nil { + return err + } + if len(mismatches) == 0 { + return nil + } + + parts := make([]string, len(mismatches)) + for i, mismatch := range mismatches { + parts[i] = fmt.Sprintf("account %q has provider=%q, expected one of %v", + mismatch.Name, mismatch.Provider, expected) + } + return fmt.Errorf("plan provider mismatch: %s", strings.Join(parts, "; ")) +} + +func (s *PostgresStore) getPlanServicesForShareTx(ctx context.Context, tx pgx.Tx, planID string) (map[string]ServiceConfig, error) { + var servicesJSON []byte + if err := tx.QueryRow(ctx, ` + SELECT services + FROM purchase_plans + WHERE id = $1 + FOR SHARE + `, planID).Scan(&servicesJSON); err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return nil, fmt.Errorf("%w: plan %s", ErrNotFound, planID) + } + return nil, fmt.Errorf("failed to get plan services: %w", err) + } + + services := make(map[string]ServiceConfig) + if err := json.Unmarshal(servicesJSON, &services); err != nil { + return nil, fmt.Errorf("failed to decode plan services: %w", err) + } + return services, nil +} + +type planAccountProviderMismatch struct { + Name string + Provider string +} + +func (s *PostgresStore) findPlanAccountProviderMismatchesTx(ctx context.Context, tx pgx.Tx, accountIDs []string, expected []string) ([]planAccountProviderMismatch, error) { + expectedSet := make(map[string]struct{}, len(expected)) + for _, provider := range expected { + expectedSet[provider] = struct{}{} + } + + var mismatches []planAccountProviderMismatch + for _, accountID := range accountIDs { + var name, provider string + if err := tx.QueryRow(ctx, ` + SELECT name, provider + FROM cloud_accounts + WHERE id = $1 + FOR SHARE + `, accountID).Scan(&name, &provider); err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return nil, fmt.Errorf("%w: account %s", ErrNotFound, accountID) + } + return nil, fmt.Errorf("failed to get account %s: %w", accountID, err) + } + if _, ok := expectedSet[provider]; !ok { + mismatches = append(mismatches, planAccountProviderMismatch{Name: name, Provider: provider}) + } + } + return mismatches, nil +} + // GetPlanAccounts returns all cloud accounts associated with a plan. func (s *PostgresStore) GetPlanAccounts(ctx context.Context, planID string) ([]CloudAccount, error) { query := ` diff --git a/internal/config/store_postgres_pgxmock_test.go b/internal/config/store_postgres_pgxmock_test.go index f0038583..78139b5d 100644 --- a/internal/config/store_postgres_pgxmock_test.go +++ b/internal/config/store_postgres_pgxmock_test.go @@ -1121,8 +1121,21 @@ func TestPGXMock_SetPlanAccounts_Success(t *testing.T) { mock := newMock(t) store := storeWith(mock) ctx := context.Background() + servicesJSON, err := json.Marshal(map[string]ServiceConfig{ + "aws/ec2": {Provider: "aws", Service: "ec2"}, + }) + require.NoError(t, err) mock.ExpectBegin() + mock.ExpectQuery("SELECT services").WithArgs("plan-1").WillReturnRows( + pgxmock.NewRows([]string{"services"}).AddRow(servicesJSON), + ) + mock.ExpectQuery("SELECT name, provider").WithArgs("acct-1").WillReturnRows( + pgxmock.NewRows([]string{"name", "provider"}).AddRow("Account 1", "aws"), + ) + mock.ExpectQuery("SELECT name, provider").WithArgs("acct-2").WillReturnRows( + pgxmock.NewRows([]string{"name", "provider"}).AddRow("Account 2", "aws"), + ) mock.ExpectExec("DELETE FROM plan_accounts").WithArgs(pgxmock.AnyArg()). WillReturnResult(pgxmock.NewResult("DELETE", 0)) mock.ExpectExec("INSERT INTO plan_accounts"). @@ -1133,23 +1146,71 @@ func TestPGXMock_SetPlanAccounts_Success(t *testing.T) { WillReturnResult(pgxmock.NewResult("INSERT", 1)) mock.ExpectCommit() - err := store.SetPlanAccounts(ctx, "plan-1", []string{"acct-1", "acct-2"}) + err = store.SetPlanAccounts(ctx, "plan-1", []string{"acct-1", "acct-2"}) require.NoError(t, err) assert.NoError(t, mock.ExpectationsWereMet()) } +func TestPGXMock_SetPlanAccounts_ProviderMismatchRollsBackBeforeDelete(t *testing.T) { + mock := newMock(t) + store := storeWith(mock) + ctx := context.Background() + servicesJSON, err := json.Marshal(map[string]ServiceConfig{ + "aws/ec2": {Provider: "aws", Service: "ec2"}, + }) + require.NoError(t, err) + + mock.ExpectBegin() + mock.ExpectQuery("SELECT services").WithArgs("plan-1").WillReturnRows( + pgxmock.NewRows([]string{"services"}).AddRow(servicesJSON), + ) + mock.ExpectQuery("SELECT name, provider").WithArgs("acct-1").WillReturnRows( + pgxmock.NewRows([]string{"name", "provider"}).AddRow("Azure Account", "azure"), + ) + mock.ExpectRollback() + + err = store.SetPlanAccounts(ctx, "plan-1", []string{"acct-1"}) + require.Error(t, err) + assert.Contains(t, err.Error(), "plan provider mismatch") + assert.Contains(t, err.Error(), "Azure Account") + assert.NoError(t, mock.ExpectationsWereMet()) +} + func TestPGXMock_SetPlanAccounts_Empty(t *testing.T) { mock := newMock(t) store := storeWith(mock) ctx := context.Background() + servicesJSON, err := json.Marshal(map[string]ServiceConfig{ + "aws/ec2": {Provider: "aws", Service: "ec2"}, + }) + require.NoError(t, err) mock.ExpectBegin() + mock.ExpectQuery("SELECT services").WithArgs("plan-1").WillReturnRows( + pgxmock.NewRows([]string{"services"}).AddRow(servicesJSON), + ) mock.ExpectExec("DELETE FROM plan_accounts").WithArgs(pgxmock.AnyArg()). WillReturnResult(pgxmock.NewResult("DELETE", 0)) mock.ExpectCommit() - err := store.SetPlanAccounts(ctx, "plan-1", []string{}) + err = store.SetPlanAccounts(ctx, "plan-1", []string{}) require.NoError(t, err) + assert.NoError(t, mock.ExpectationsWereMet()) +} + +func TestPGXMock_SetPlanAccounts_EmptyMissingPlanReturnsNotFound(t *testing.T) { + mock := newMock(t) + store := storeWith(mock) + ctx := context.Background() + + mock.ExpectBegin() + mock.ExpectQuery("SELECT services").WithArgs("missing-plan").WillReturnError(pgx.ErrNoRows) + mock.ExpectRollback() + + err := store.SetPlanAccounts(ctx, "missing-plan", []string{}) + require.ErrorIs(t, err, ErrNotFound) + assert.Contains(t, err.Error(), "missing-plan") + assert.NoError(t, mock.ExpectationsWereMet()) } func TestPGXMock_SetPlanAccounts_BeginError(t *testing.T) {