From f576db2dfbe78c9f67341f8c7b024f3cd58dd5ab Mon Sep 17 00:00:00 2001 From: Cristian Magherusan-Stanciu Date: Sun, 3 May 2026 13:22:12 +0200 Subject: [PATCH 1/5] feat(api): validate plan account provider matches plan provider on assignment (closes #209) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Backend gate for PUT /api/plans/:id/accounts. Every account assigned to a plan must have its provider match one of the providers derived from the plan's services map (key format "provider:service", e.g. "aws:ec2"). Mismatches return a single HTTP 400 listing every offender; the underlying SetPlanAccounts store write is never invoked on failure. Per spec acceptance criterion E-4 in specs/multi-account-execution/acceptance.md, this is the backend hardening of the existing frontend single-provider-per-plan rule. Implementation -------------- - New `derivePlanProviders(plan)` helper extracts the distinct provider set from `plan.Services` keys (sorted slice for stable error messages). - New `Handler.validatePlanAccountProviders(ctx, planID, accountIDs)` helper holds the validation block — extracted from setPlanAccounts to stay under the gocyclo budget (limit 10). - `setPlanAccounts` loads the plan, derives providers, then for each account_id in the request loads the account and checks its provider against the derived set. All offenders are collected and reported in a single error rather than failing fast — clients fix everything in one round-trip. - Plan not found → 404 with the plan ID. - Account not found → 404 with the account ID. - Mismatch(es) → 400 "plan provider mismatch: account "" has provider="", expected one of []; ..." (single line, parseable). - Empty services map → defensive skip of validation; production plans always have ≥1 service (frontend enforces this), and the test pins the behaviour so a future change is conscious. Mocks ----- - `MockConfigStore.GetPurchasePlan` now resolves to (in order): `GetPurchasePlanFn` override, registered testify expectation, or a default minimal `{ID: planID}` plan with empty Services. The default fallback lets pre-existing tests like `TestSetPlanAccounts_Success` keep working without setting up the new mock call — empty Services trips the defensive skip-validation branch. - `MockConfigStore.SetPlanAccounts` now uses `SetPlanAccountsFn` when set, so tests can capture and assert on the call (mismatch tests verify the underlying store write is NOT invoked on failure). Tests ----- Seven new tests in `handler_accounts_test.go`: - TestSetPlanAccounts_SingleMismatch — one Azure account vs aws plan → 400 - TestSetPlanAccounts_MultipleMismatches — Azure + GCP vs aws plan → 400 listing both - TestSetPlanAccounts_ValidHappyPath — AWS account vs aws plan → 200, store write captured - TestSetPlanAccounts_PlanNotFound — GetPurchasePlan returns (nil, nil) → 404 - TestSetPlanAccounts_AccountNotFound — GetCloudAccount returns (nil, nil) → 404 referencing the account ID - TestSetPlanAccounts_MixedValidAndMismatch — only the offender named in error; store write NOT called - TestSetPlanAccounts_EmptyServicesSkipsValidation — defensive behaviour pinned --- internal/api/handler_accounts.go | 101 +++++++++++ internal/api/handler_accounts_test.go | 247 ++++++++++++++++++++++++++ internal/api/mocks_test.go | 33 ++++ 3 files changed, 381 insertions(+) diff --git a/internal/api/handler_accounts.go b/internal/api/handler_accounts.go index 98eecf67..c1c4aa67 100644 --- a/internal/api/handler_accounts.go +++ b/internal/api/handler_accounts.go @@ -4,6 +4,8 @@ import ( "context" "encoding/json" "fmt" + "slices" + "sort" "strings" "time" @@ -1037,7 +1039,100 @@ func (h *Handler) deleteAccountServiceOverride(ctx context.Context, req *events. return nil, nil } +// derivePlanProviders extracts the distinct set of providers a plan +// targets by parsing the keys of plan.Services (format "provider:service", +// e.g. "aws:ec2"). Returns a sorted slice for stable error messages. +// An empty result means the plan has no parseable services — production +// plans always carry at least one (frontend enforces this), so an empty +// return is a defensive case that signals to skip provider validation. +func derivePlanProviders(plan *config.PurchasePlan) []string { + if plan == nil { + return nil + } + seen := make(map[string]struct{}, len(plan.Services)) + for k := range plan.Services { + // Keys are "provider:service"; skip malformed keys rather than guess. + parts := strings.SplitN(k, ":", 2) + if len(parts) != 2 || parts[0] == "" || parts[1] == "" { + continue + } + seen[parts[0]] = struct{}{} + } + out := make([]string, 0, len(seen)) + for p := range seen { + out = append(out, p) + } + sort.Strings(out) + return out +} + +// validatePlanAccountProviders enforces the issue-#209 / spec E-4 rule +// that every account assigned to a plan must have its provider match +// one of the plan's derived providers. Returns: +// - 404 ClientError when the plan does not exist +// - 404 ClientError when an account_id does not exist (referencing +// the offending ID) +// - 400 ClientError listing every provider mismatch in one message +// (so clients fix everything in one round-trip rather than +// resubmitting to discover the next) +// - nil when all accounts match (or when the plan has no parseable +// services — defensive skip; production plans always carry at least +// one service, frontend enforces this) +// +// Pulled out of setPlanAccounts to keep that function under the gocyclo +// budget (limit 10). No business logic lives here that isn't otherwise +// described in setPlanAccounts' doc comment. +func (h *Handler) validatePlanAccountProviders(ctx context.Context, planID string, accountIDs []string) error { + plan, err := h.config.GetPurchasePlan(ctx, planID) + if err != nil { + return fmt.Errorf("accounts: failed to get plan: %w", err) + } + if plan == nil { + return NewClientError(404, fmt.Sprintf("plan not found: %s", planID)) + } + + expected := derivePlanProviders(plan) + if len(expected) == 0 { + return nil + } + + type mismatch struct { + ID string + Name string + Provider string + } + var mismatches []mismatch + for _, aid := range accountIDs { + acct, getErr := h.config.GetCloudAccount(ctx, aid) + if getErr != nil { + return fmt.Errorf("accounts: failed to get account %s: %w", aid, getErr) + } + if acct == nil { + return NewClientError(404, fmt.Sprintf("account not found: %s", aid)) + } + if !slices.Contains(expected, acct.Provider) { + mismatches = append(mismatches, mismatch{ID: aid, Name: acct.Name, Provider: acct.Provider}) + } + } + if len(mismatches) == 0 { + return nil + } + + parts := make([]string, len(mismatches)) + for i, m := range mismatches { + parts[i] = fmt.Sprintf("account %q has provider=%q, expected one of %v", + m.Name, m.Provider, expected) + } + return NewClientError(400, "plan provider mismatch: "+strings.Join(parts, "; ")) +} + // setPlanAccounts handles PUT /api/plans/:id/accounts. +// +// Per issue #209 / spec acceptance criterion E-4, every account assigned +// to a plan must have its provider match one of the plan's derived +// providers (extracted from plan.Services keys). Mismatches return 400 +// listing every offender; the assignment is rejected atomically (no +// partial writes). func (h *Handler) setPlanAccounts(ctx context.Context, httpReq *events.LambdaFunctionURLRequest, id string) (any, error) { if err := validateUUID(id); err != nil { return nil, err @@ -1060,6 +1155,12 @@ func (h *Handler) setPlanAccounts(ctx context.Context, httpReq *events.LambdaFun } } + // Provider-match validation (issue #209). Extracted to keep + // setPlanAccounts under the gocyclo budget (10). + if err := h.validatePlanAccountProviders(ctx, id, body.AccountIDs); err != nil { + return nil, err + } + if err := h.config.SetPlanAccounts(ctx, id, body.AccountIDs); err != nil { return nil, fmt.Errorf("accounts: %w", err) } diff --git a/internal/api/handler_accounts_test.go b/internal/api/handler_accounts_test.go index adb0c238..ac540447 100644 --- a/internal/api/handler_accounts_test.go +++ b/internal/api/handler_accounts_test.go @@ -485,6 +485,253 @@ func TestSetPlanAccounts_Success(t *testing.T) { assert.Nil(t, result) } +// ── Provider-validation tests for setPlanAccounts (issue #209) +// Every account assigned to a plan must have its provider match one of +// the providers derived from the plan's services map (key format +// "provider:service"). Mismatches return a single 400 listing every +// offender; the underlying store write is never invoked on failure. + +const ( + planID209 = "22222222-2222-2222-2222-222222222222" + awsAcct209 = "11111111-1111-1111-1111-111111111111" + azureAcct1 = "33333333-3333-3333-3333-333333333333" + azureAcct2 = "44444444-4444-4444-4444-444444444444" + missingAcct = "55555555-5555-5555-5555-555555555555" +) + +// awsPlan209 returns a plan whose services map yields a single derived +// provider ("aws"). Used as the default fixture across the mismatch +// tests below. +func awsPlan209() *config.PurchasePlan { + return &config.PurchasePlan{ + ID: planID209, + Name: "AWS-only plan", + Services: map[string]config.ServiceConfig{"aws:ec2": {}}, + } +} + +func TestSetPlanAccounts_SingleMismatch(t *testing.T) { + ctx := context.Background() + mockAuth := new(MockAuthService) + setupAdminAuth(ctx, mockAuth) + + setCalled := false + store := setupAdminMock(ctx) + store.GetPurchasePlanFn = func(_ context.Context, _ string) (*config.PurchasePlan, error) { + return awsPlan209(), nil + } + store.GetCloudAccountFn = func(_ context.Context, id string) (*config.CloudAccount, error) { + return &config.CloudAccount{ID: id, Name: "prod-azure", Provider: "azure"}, nil + } + store.SetPlanAccountsFn = func(_ context.Context, _ string, _ []string) error { + setCalled = true + return nil + } + handler := &Handler{auth: mockAuth, config: store} + + body := `{"account_ids":["` + azureAcct1 + `"]}` + _, err := handler.setPlanAccounts(ctx, adminRequest(body), planID209) + require.Error(t, err) + ce, ok := IsClientError(err) + require.True(t, ok, "expected a clientError, got %T", err) + assert.Equal(t, 400, ce.code) + assert.Contains(t, ce.Error(), "prod-azure") + assert.Contains(t, ce.Error(), "azure") + assert.Contains(t, ce.Error(), "aws") + assert.False(t, setCalled, "SetPlanAccounts must NOT be called when validation fails") +} + +func TestSetPlanAccounts_MultipleMismatches(t *testing.T) { + ctx := context.Background() + mockAuth := new(MockAuthService) + setupAdminAuth(ctx, mockAuth) + + setCalled := false + store := setupAdminMock(ctx) + store.GetPurchasePlanFn = func(_ context.Context, _ string) (*config.PurchasePlan, error) { + return awsPlan209(), nil + } + store.GetCloudAccountFn = func(_ context.Context, id string) (*config.CloudAccount, error) { + switch id { + case azureAcct1: + return &config.CloudAccount{ID: id, Name: "prod-azure", Provider: "azure"}, nil + case azureAcct2: + return &config.CloudAccount{ID: id, Name: "stage-gcp", Provider: "gcp"}, nil + } + return nil, nil + } + store.SetPlanAccountsFn = func(_ context.Context, _ string, _ []string) error { + setCalled = true + return nil + } + handler := &Handler{auth: mockAuth, config: store} + + body := `{"account_ids":["` + azureAcct1 + `","` + azureAcct2 + `"]}` + _, err := handler.setPlanAccounts(ctx, adminRequest(body), planID209) + require.Error(t, err) + ce, ok := IsClientError(err) + require.True(t, ok) + assert.Equal(t, 400, ce.code) + // Both offenders named in a single error so the client gets the full + // picture in one round-trip. + assert.Contains(t, ce.Error(), "prod-azure") + assert.Contains(t, ce.Error(), "stage-gcp") + assert.Contains(t, ce.Error(), "azure") + assert.Contains(t, ce.Error(), "gcp") + assert.False(t, setCalled, "SetPlanAccounts must NOT be called when validation fails") +} + +func TestSetPlanAccounts_ValidHappyPath(t *testing.T) { + ctx := context.Background() + mockAuth := new(MockAuthService) + setupAdminAuth(ctx, mockAuth) + + var capturedIDs []string + store := setupAdminMock(ctx) + store.GetPurchasePlanFn = func(_ context.Context, _ string) (*config.PurchasePlan, error) { + return awsPlan209(), nil + } + store.GetCloudAccountFn = func(_ context.Context, id string) (*config.CloudAccount, error) { + return &config.CloudAccount{ID: id, Name: "prod-aws", Provider: "aws"}, nil + } + store.SetPlanAccountsFn = func(_ context.Context, _ string, ids []string) error { + capturedIDs = ids + return nil + } + handler := &Handler{auth: mockAuth, config: store} + + body := `{"account_ids":["` + awsAcct209 + `"]}` + result, err := handler.setPlanAccounts(ctx, adminRequest(body), planID209) + require.NoError(t, err) + assert.Nil(t, result) + assert.Equal(t, []string{awsAcct209}, capturedIDs, "SetPlanAccounts should be called with the validated IDs") +} + +func TestSetPlanAccounts_PlanNotFound(t *testing.T) { + ctx := context.Background() + mockAuth := new(MockAuthService) + setupAdminAuth(ctx, mockAuth) + + setCalled := false + store := setupAdminMock(ctx) + store.GetPurchasePlanFn = func(_ context.Context, _ string) (*config.PurchasePlan, error) { + return nil, nil // store-style "not found": (nil, nil) + } + store.SetPlanAccountsFn = func(_ context.Context, _ string, _ []string) error { + setCalled = true + return nil + } + handler := &Handler{auth: mockAuth, config: store} + + body := `{"account_ids":["` + awsAcct209 + `"]}` + _, err := handler.setPlanAccounts(ctx, adminRequest(body), planID209) + require.Error(t, err) + ce, ok := IsClientError(err) + require.True(t, ok) + assert.Equal(t, 404, ce.code) + assert.Contains(t, ce.Error(), planID209) + assert.False(t, setCalled, "SetPlanAccounts must NOT be called when the plan is not found") +} + +func TestSetPlanAccounts_AccountNotFound(t *testing.T) { + ctx := context.Background() + mockAuth := new(MockAuthService) + setupAdminAuth(ctx, mockAuth) + + setCalled := false + store := setupAdminMock(ctx) + store.GetPurchasePlanFn = func(_ context.Context, _ string) (*config.PurchasePlan, error) { + return awsPlan209(), nil + } + store.GetCloudAccountFn = func(_ context.Context, id string) (*config.CloudAccount, error) { + if id == missingAcct { + return nil, nil // store-style not-found + } + return &config.CloudAccount{ID: id, Name: "prod-aws", Provider: "aws"}, nil + } + store.SetPlanAccountsFn = func(_ context.Context, _ string, _ []string) error { + setCalled = true + return nil + } + handler := &Handler{auth: mockAuth, config: store} + + body := `{"account_ids":["` + missingAcct + `"]}` + _, err := handler.setPlanAccounts(ctx, adminRequest(body), planID209) + require.Error(t, err) + ce, ok := IsClientError(err) + require.True(t, ok) + assert.Equal(t, 404, ce.code) + assert.Contains(t, ce.Error(), missingAcct, "404 should reference the missing account ID") + assert.False(t, setCalled, "SetPlanAccounts must NOT be called when an account is not found") +} + +func TestSetPlanAccounts_MixedValidAndMismatch(t *testing.T) { + ctx := context.Background() + mockAuth := new(MockAuthService) + setupAdminAuth(ctx, mockAuth) + + setCalled := false + store := setupAdminMock(ctx) + store.GetPurchasePlanFn = func(_ context.Context, _ string) (*config.PurchasePlan, error) { + return awsPlan209(), nil + } + store.GetCloudAccountFn = func(_ context.Context, id string) (*config.CloudAccount, error) { + switch id { + case awsAcct209: + return &config.CloudAccount{ID: id, Name: "prod-aws", Provider: "aws"}, nil + case azureAcct1: + return &config.CloudAccount{ID: id, Name: "prod-azure", Provider: "azure"}, nil + } + return nil, nil + } + store.SetPlanAccountsFn = func(_ context.Context, _ string, _ []string) error { + setCalled = true + return nil + } + handler := &Handler{auth: mockAuth, config: store} + + body := `{"account_ids":["` + awsAcct209 + `","` + azureAcct1 + `"]}` + _, err := handler.setPlanAccounts(ctx, adminRequest(body), planID209) + require.Error(t, err) + ce, ok := IsClientError(err) + require.True(t, ok) + assert.Equal(t, 400, ce.code) + // Only the Azure account is the offender; the AWS account does not + // appear in the error. + assert.Contains(t, ce.Error(), "prod-azure") + assert.NotContains(t, ce.Error(), "prod-aws") + assert.False(t, setCalled, "SetPlanAccounts must NOT be called when even one account fails validation") +} + +func TestSetPlanAccounts_EmptyServicesSkipsValidation(t *testing.T) { + ctx := context.Background() + mockAuth := new(MockAuthService) + setupAdminAuth(ctx, mockAuth) + + var capturedIDs []string + store := setupAdminMock(ctx) + // Plan with an empty services map — derived provider set is empty; + // the validation block skips and the assignment passes through. + // Pins the defensive behaviour so a future change is conscious. + store.GetPurchasePlanFn = func(_ context.Context, _ string) (*config.PurchasePlan, error) { + return &config.PurchasePlan{ID: planID209, Name: "no-services"}, nil + } + store.GetCloudAccountFn = func(_ context.Context, id string) (*config.CloudAccount, error) { + return &config.CloudAccount{ID: id, Name: "prod-azure", Provider: "azure"}, nil + } + store.SetPlanAccountsFn = func(_ context.Context, _ string, ids []string) error { + capturedIDs = ids + return nil + } + handler := &Handler{auth: mockAuth, config: store} + + body := `{"account_ids":["` + azureAcct1 + `"]}` + result, err := handler.setPlanAccounts(ctx, adminRequest(body), planID209) + require.NoError(t, err) + assert.Nil(t, result) + assert.Equal(t, []string{azureAcct1}, capturedIDs, "empty services map → validation skipped → write proceeds") +} + func TestListPlanAccounts_Success(t *testing.T) { ctx := context.Background() mockAuth := new(MockAuthService) diff --git a/internal/api/mocks_test.go b/internal/api/mocks_test.go index 37d29da5..60c819d3 100644 --- a/internal/api/mocks_test.go +++ b/internal/api/mocks_test.go @@ -27,6 +27,16 @@ type MockConfigStore struct { // CreateCloudAccountFn overrides CreateCloudAccount when non-nil (used by // org-discovery tests to capture the new rows the handler persists). CreateCloudAccountFn func(ctx context.Context, account *config.CloudAccount) error + // GetPurchasePlanFn overrides GetPurchasePlan when non-nil. Used by the + // setPlanAccounts provider-validation tests (issue #209) to seed the + // plan without registering a testify expectation — see the fall-through + // comment in GetPurchasePlan below for why the default no-expectation + // path returns a minimal stub instead of panicking via m.Called. + GetPurchasePlanFn func(ctx context.Context, planID string) (*config.PurchasePlan, error) + // SetPlanAccountsFn overrides SetPlanAccounts when non-nil. The + // provider-validation tests use it to assert whether the underlying + // store write was invoked (mismatched assignments must NOT call it). + SetPlanAccountsFn func(ctx context.Context, planID string, accountIDs []string) error } func (m *MockConfigStore) GetGlobalConfig(ctx context.Context) (*config.GlobalConfig, error) { @@ -68,7 +78,21 @@ func (m *MockConfigStore) CreatePurchasePlan(ctx context.Context, plan *config.P return args.Error(0) } +// GetPurchasePlan resolves to (in order): an explicit GetPurchasePlanFn +// override, a registered testify expectation, or a default minimal plan +// (`{ID: planID}` with empty Services). The default-fallback path lets +// tests written before the issue-#209 provider-validation block (e.g. +// TestSetPlanAccounts_Success) keep working without setting up the +// new mock call — the empty Services map trips the defensive "no +// parseable services, skip provider validation" branch in +// setPlanAccounts so behaviour is unchanged for those legacy tests. func (m *MockConfigStore) GetPurchasePlan(ctx context.Context, planID string) (*config.PurchasePlan, error) { + if m.GetPurchasePlanFn != nil { + return m.GetPurchasePlanFn(ctx, planID) + } + if !m.isExpected("GetPurchasePlan") { + return &config.PurchasePlan{ID: planID}, nil + } args := m.Called(ctx, planID) if args.Get(0) == nil { return nil, args.Error(1) @@ -292,7 +316,16 @@ func (m *MockConfigStore) DeleteAccountServiceOverride(ctx context.Context, acco func (m *MockConfigStore) ListAccountServiceOverrides(ctx context.Context, accountID string) ([]config.AccountServiceOverride, error) { return nil, nil } + +// SetPlanAccounts uses SetPlanAccountsFn when non-nil so tests can +// capture and assert on the call (the issue-#209 mismatch tests verify +// the underlying store write is NOT invoked when validation fails). +// The default no-op preserves the previous behaviour for tests that +// don't care. func (m *MockConfigStore) SetPlanAccounts(ctx context.Context, planID string, accountIDs []string) error { + if m.SetPlanAccountsFn != nil { + return m.SetPlanAccountsFn(ctx, planID, accountIDs) + } return nil } func (m *MockConfigStore) GetPlanAccounts(ctx context.Context, planID string) ([]config.CloudAccount, error) { From 70b8c54a831954007de4af11405716e167150545 Mon Sep 17 00:00:00 2001 From: Cristian Magherusan-Stanciu Date: Sun, 3 May 2026 13:53:39 +0200 Subject: [PATCH 2/5] =?UTF-8?q?fix(api):=20address=20CR=20feedback=20on=20?= =?UTF-8?q?PR#228=20=E2=80=94=20delimiter=20bug=20+=20mock=20fallback?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - derivePlanProviders: fix key delimiter ":" → "/" to match buildServiceConfig (CR actionable: provider extraction was silently skipping all services, disabling validation; keys are "provider/service" not "provider:service") - awsPlan209() test fixture: update key to "aws/ec2" to match the corrected delimiter; previously the test passed only because both code and fixture used the wrong format - MockConfigStore.SetPlanAccounts: add isExpected/m.Called fallback (CR nitpick: testify .On("SetPlanAccounts",...) expectations were ignored) --- internal/api/handler_accounts.go | 13 +++++++------ internal/api/handler_accounts_test.go | 5 +++-- internal/api/mocks_test.go | 8 ++++++-- 3 files changed, 16 insertions(+), 10 deletions(-) diff --git a/internal/api/handler_accounts.go b/internal/api/handler_accounts.go index c1c4aa67..51691b53 100644 --- a/internal/api/handler_accounts.go +++ b/internal/api/handler_accounts.go @@ -1040,8 +1040,9 @@ func (h *Handler) deleteAccountServiceOverride(ctx context.Context, req *events. } // derivePlanProviders extracts the distinct set of providers a plan -// targets by parsing the keys of plan.Services (format "provider:service", -// e.g. "aws:ec2"). Returns a sorted slice for stable error messages. +// targets by parsing the keys of plan.Services (format "provider/service", +// e.g. "aws/ec2" — produced by buildServiceConfig). Returns a sorted slice +// for stable error messages. // An empty result means the plan has no parseable services — production // plans always carry at least one (frontend enforces this), so an empty // return is a defensive case that signals to skip provider validation. @@ -1051,12 +1052,12 @@ func derivePlanProviders(plan *config.PurchasePlan) []string { } seen := make(map[string]struct{}, len(plan.Services)) for k := range plan.Services { - // Keys are "provider:service"; skip malformed keys rather than guess. - parts := strings.SplitN(k, ":", 2) - if len(parts) != 2 || parts[0] == "" || parts[1] == "" { + // Keys are "provider/service"; skip malformed keys rather than guess. + p, _, ok := strings.Cut(k, "/") + if !ok || p == "" { continue } - seen[parts[0]] = struct{}{} + seen[p] = struct{}{} } out := make([]string, 0, len(seen)) for p := range seen { diff --git a/internal/api/handler_accounts_test.go b/internal/api/handler_accounts_test.go index ac540447..4f2d43fc 100644 --- a/internal/api/handler_accounts_test.go +++ b/internal/api/handler_accounts_test.go @@ -500,13 +500,14 @@ const ( ) // awsPlan209 returns a plan whose services map yields a single derived -// provider ("aws"). Used as the default fixture across the mismatch +// provider ("aws"). Key format is "provider/service" as produced by +// buildServiceConfig. Used as the default fixture across the mismatch // tests below. func awsPlan209() *config.PurchasePlan { return &config.PurchasePlan{ ID: planID209, Name: "AWS-only plan", - Services: map[string]config.ServiceConfig{"aws:ec2": {}}, + Services: map[string]config.ServiceConfig{"aws/ec2": {}}, } } diff --git a/internal/api/mocks_test.go b/internal/api/mocks_test.go index 60c819d3..ce569313 100644 --- a/internal/api/mocks_test.go +++ b/internal/api/mocks_test.go @@ -320,12 +320,16 @@ func (m *MockConfigStore) ListAccountServiceOverrides(ctx context.Context, accou // SetPlanAccounts uses SetPlanAccountsFn when non-nil so tests can // capture and assert on the call (the issue-#209 mismatch tests verify // the underlying store write is NOT invoked when validation fails). -// The default no-op preserves the previous behaviour for tests that -// don't care. +// Falls back to m.Called when a testify expectation is registered so +// .On("SetPlanAccounts", ...) works correctly. The no-op is preserved +// only when neither path applies (tests that don't care). func (m *MockConfigStore) SetPlanAccounts(ctx context.Context, planID string, accountIDs []string) error { if m.SetPlanAccountsFn != nil { return m.SetPlanAccountsFn(ctx, planID, accountIDs) } + if m.isExpected("SetPlanAccounts") { + return m.Called(ctx, planID, accountIDs).Error(0) + } return nil } func (m *MockConfigStore) GetPlanAccounts(ctx context.Context, planID string) ([]config.CloudAccount, error) { From 951707bdfc4e7d011f75b6fd8850d18801fc1678 Mon Sep 17 00:00:00 2001 From: Cristian Magherusan-Stanciu Date: Sun, 3 May 2026 16:54:37 +0200 Subject: [PATCH 3/5] fix(config): validate plan accounts atomically Move plan provider derivation into config so the API and store share the same provider/service parsing rule. Recheck account providers inside the Postgres SetPlanAccounts transaction before deleting or inserting plan_accounts, locking the plan and account rows that determine validity. Add pgxmock coverage that a provider mismatch rolls back before any plan_accounts mutation. --- internal/api/handler_accounts.go | 31 +------ internal/config/purchase_plan_providers.go | 29 +++++++ .../config/purchase_plan_providers_test.go | 22 +++++ internal/config/store_postgres.go | 87 +++++++++++++++++++ .../config/store_postgres_pgxmock_test.go | 40 ++++++++- 5 files changed, 178 insertions(+), 31 deletions(-) create mode 100644 internal/config/purchase_plan_providers.go create mode 100644 internal/config/purchase_plan_providers_test.go diff --git a/internal/api/handler_accounts.go b/internal/api/handler_accounts.go index 51691b53..534bb66d 100644 --- a/internal/api/handler_accounts.go +++ b/internal/api/handler_accounts.go @@ -5,7 +5,6 @@ import ( "encoding/json" "fmt" "slices" - "sort" "strings" "time" @@ -1039,34 +1038,6 @@ func (h *Handler) deleteAccountServiceOverride(ctx context.Context, req *events. return nil, nil } -// derivePlanProviders extracts the distinct set of providers a plan -// targets by parsing the keys of plan.Services (format "provider/service", -// e.g. "aws/ec2" — produced by buildServiceConfig). Returns a sorted slice -// for stable error messages. -// An empty result means the plan has no parseable services — production -// plans always carry at least one (frontend enforces this), so an empty -// return is a defensive case that signals to skip provider validation. -func derivePlanProviders(plan *config.PurchasePlan) []string { - if plan == nil { - return nil - } - seen := make(map[string]struct{}, len(plan.Services)) - for k := range plan.Services { - // Keys are "provider/service"; skip malformed keys rather than guess. - p, _, ok := strings.Cut(k, "/") - if !ok || p == "" { - continue - } - seen[p] = struct{}{} - } - out := make([]string, 0, len(seen)) - for p := range seen { - out = append(out, p) - } - sort.Strings(out) - return out -} - // validatePlanAccountProviders enforces the issue-#209 / spec E-4 rule // that every account assigned to a plan must have its provider match // one of the plan's derived providers. Returns: @@ -1092,7 +1063,7 @@ func (h *Handler) validatePlanAccountProviders(ctx context.Context, planID strin return NewClientError(404, fmt.Sprintf("plan not found: %s", planID)) } - expected := derivePlanProviders(plan) + expected := config.DerivePlanProviders(plan) if len(expected) == 0 { return nil } diff --git a/internal/config/purchase_plan_providers.go b/internal/config/purchase_plan_providers.go new file mode 100644 index 00000000..455867bc --- /dev/null +++ b/internal/config/purchase_plan_providers.go @@ -0,0 +1,29 @@ +package config + +import ( + "sort" + "strings" +) + +// DerivePlanProviders extracts the distinct set of providers a plan targets +// by parsing the keys of plan.Services. Keys are expected to use the +// "provider/service" format produced by buildServiceConfig. +func DerivePlanProviders(plan *PurchasePlan) []string { + if plan == nil { + return nil + } + seen := make(map[string]struct{}, len(plan.Services)) + for k := range plan.Services { + provider, _, ok := strings.Cut(k, "/") + if !ok || provider == "" { + continue + } + seen[provider] = struct{}{} + } + providers := make([]string, 0, len(seen)) + for provider := range seen { + providers = append(providers, provider) + } + sort.Strings(providers) + return providers +} diff --git a/internal/config/purchase_plan_providers_test.go b/internal/config/purchase_plan_providers_test.go new file mode 100644 index 00000000..ed5f0fd3 --- /dev/null +++ b/internal/config/purchase_plan_providers_test.go @@ -0,0 +1,22 @@ +package config + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestDerivePlanProviders(t *testing.T) { + plan := &PurchasePlan{ + Services: map[string]ServiceConfig{ + "aws/ec2": {}, + "aws/rds": {}, + "azure/compute": {}, + "gcp:compute": {}, + "malformed": {}, + }, + } + + assert.Equal(t, []string{"aws", "azure"}, DerivePlanProviders(plan)) + assert.Nil(t, DerivePlanProviders(nil)) +} diff --git a/internal/config/store_postgres.go b/internal/config/store_postgres.go index 36282bfc..11bcdcfc 100644 --- a/internal/config/store_postgres.go +++ b/internal/config/store_postgres.go @@ -1928,6 +1928,10 @@ func (s *PostgresStore) SetPlanAccounts(ctx context.Context, planID string, acco } defer tx.Rollback(ctx) //nolint:errcheck + if err = s.validatePlanAccountProvidersTx(ctx, tx, planID, accountIDs); err != nil { + return err + } + if _, err = tx.Exec(ctx, `DELETE FROM plan_accounts WHERE plan_id = $1`, planID); err != nil { return fmt.Errorf("failed to clear plan accounts: %w", err) } @@ -1947,6 +1951,89 @@ func (s *PostgresStore) SetPlanAccounts(ctx context.Context, planID string, acco return nil } +func (s *PostgresStore) validatePlanAccountProvidersTx(ctx context.Context, tx pgx.Tx, planID string, accountIDs []string) error { + if len(accountIDs) == 0 { + return nil + } + + services, err := s.getPlanServicesForShareTx(ctx, tx, planID) + if err != nil { + return err + } + expected := DerivePlanProviders(&PurchasePlan{Services: services}) + if len(expected) == 0 { + return nil + } + + mismatches, err := s.findPlanAccountProviderMismatchesTx(ctx, tx, accountIDs, expected) + if err != nil { + return err + } + if len(mismatches) == 0 { + return nil + } + + parts := make([]string, len(mismatches)) + for i, mismatch := range mismatches { + parts[i] = fmt.Sprintf("account %q has provider=%q, expected one of %v", + mismatch.Name, mismatch.Provider, expected) + } + return fmt.Errorf("plan provider mismatch: %s", strings.Join(parts, "; ")) +} + +func (s *PostgresStore) getPlanServicesForShareTx(ctx context.Context, tx pgx.Tx, planID string) (map[string]ServiceConfig, error) { + var servicesJSON []byte + if err := tx.QueryRow(ctx, ` + SELECT services + FROM purchase_plans + WHERE id = $1 + FOR SHARE + `, planID).Scan(&servicesJSON); err != nil { + if err == pgx.ErrNoRows { + return nil, fmt.Errorf("plan not found: %s", planID) + } + return nil, fmt.Errorf("failed to get plan services: %w", err) + } + + services := make(map[string]ServiceConfig) + if err := json.Unmarshal(servicesJSON, &services); err != nil { + return nil, fmt.Errorf("failed to decode plan services: %w", err) + } + return services, nil +} + +type planAccountProviderMismatch struct { + Name string + Provider string +} + +func (s *PostgresStore) findPlanAccountProviderMismatchesTx(ctx context.Context, tx pgx.Tx, accountIDs []string, expected []string) ([]planAccountProviderMismatch, error) { + expectedSet := make(map[string]struct{}, len(expected)) + for _, provider := range expected { + expectedSet[provider] = struct{}{} + } + + var mismatches []planAccountProviderMismatch + for _, accountID := range accountIDs { + var name, provider string + if err := tx.QueryRow(ctx, ` + SELECT name, provider + FROM cloud_accounts + WHERE id = $1 + FOR SHARE + `, accountID).Scan(&name, &provider); err != nil { + if err == pgx.ErrNoRows { + return nil, fmt.Errorf("account not found: %s", accountID) + } + return nil, fmt.Errorf("failed to get account %s: %w", accountID, err) + } + if _, ok := expectedSet[provider]; !ok { + mismatches = append(mismatches, planAccountProviderMismatch{Name: name, Provider: provider}) + } + } + return mismatches, nil +} + // GetPlanAccounts returns all cloud accounts associated with a plan. func (s *PostgresStore) GetPlanAccounts(ctx context.Context, planID string) ([]CloudAccount, error) { query := ` diff --git a/internal/config/store_postgres_pgxmock_test.go b/internal/config/store_postgres_pgxmock_test.go index f0038583..95f2270a 100644 --- a/internal/config/store_postgres_pgxmock_test.go +++ b/internal/config/store_postgres_pgxmock_test.go @@ -1121,8 +1121,21 @@ func TestPGXMock_SetPlanAccounts_Success(t *testing.T) { mock := newMock(t) store := storeWith(mock) ctx := context.Background() + servicesJSON, err := json.Marshal(map[string]ServiceConfig{ + "aws/ec2": {Provider: "aws", Service: "ec2"}, + }) + require.NoError(t, err) mock.ExpectBegin() + mock.ExpectQuery("SELECT services").WithArgs("plan-1").WillReturnRows( + pgxmock.NewRows([]string{"services"}).AddRow(servicesJSON), + ) + mock.ExpectQuery("SELECT name, provider").WithArgs("acct-1").WillReturnRows( + pgxmock.NewRows([]string{"name", "provider"}).AddRow("Account 1", "aws"), + ) + mock.ExpectQuery("SELECT name, provider").WithArgs("acct-2").WillReturnRows( + pgxmock.NewRows([]string{"name", "provider"}).AddRow("Account 2", "aws"), + ) mock.ExpectExec("DELETE FROM plan_accounts").WithArgs(pgxmock.AnyArg()). WillReturnResult(pgxmock.NewResult("DELETE", 0)) mock.ExpectExec("INSERT INTO plan_accounts"). @@ -1133,8 +1146,33 @@ func TestPGXMock_SetPlanAccounts_Success(t *testing.T) { WillReturnResult(pgxmock.NewResult("INSERT", 1)) mock.ExpectCommit() - err := store.SetPlanAccounts(ctx, "plan-1", []string{"acct-1", "acct-2"}) + err = store.SetPlanAccounts(ctx, "plan-1", []string{"acct-1", "acct-2"}) + require.NoError(t, err) + assert.NoError(t, mock.ExpectationsWereMet()) +} + +func TestPGXMock_SetPlanAccounts_ProviderMismatchRollsBackBeforeDelete(t *testing.T) { + mock := newMock(t) + store := storeWith(mock) + ctx := context.Background() + servicesJSON, err := json.Marshal(map[string]ServiceConfig{ + "aws/ec2": {Provider: "aws", Service: "ec2"}, + }) require.NoError(t, err) + + mock.ExpectBegin() + mock.ExpectQuery("SELECT services").WithArgs("plan-1").WillReturnRows( + pgxmock.NewRows([]string{"services"}).AddRow(servicesJSON), + ) + mock.ExpectQuery("SELECT name, provider").WithArgs("acct-1").WillReturnRows( + pgxmock.NewRows([]string{"name", "provider"}).AddRow("Azure Account", "azure"), + ) + mock.ExpectRollback() + + err = store.SetPlanAccounts(ctx, "plan-1", []string{"acct-1"}) + require.Error(t, err) + assert.Contains(t, err.Error(), "plan provider mismatch") + assert.Contains(t, err.Error(), "Azure Account") assert.NoError(t, mock.ExpectationsWereMet()) } From c566d9504da9b9ca9e423beffe32d2d5deefb8c4 Mon Sep 17 00:00:00 2001 From: Cristian Magherusan-Stanciu Date: Sun, 3 May 2026 17:07:02 +0200 Subject: [PATCH 4/5] fix(api): map atomic account validation misses to 404 Add a config-store ErrNotFound sentinel and use it for Postgres plan and account lookups in the atomic SetPlanAccounts validation path. The accounts handler now maps those store-level not-found races to 404s. Also validate the plan row before accepting empty account assignments so clearing a nonexistent plan cannot silently succeed. --- internal/api/handler_accounts.go | 25 +++++++++++++---- internal/api/handler_accounts_test.go | 27 +++++++++++++++++++ internal/config/errors.go | 6 +++++ internal/config/store_postgres.go | 18 ++++++------- .../config/store_postgres_pgxmock_test.go | 24 ++++++++++++++++- 5 files changed, 85 insertions(+), 15 deletions(-) create mode 100644 internal/config/errors.go diff --git a/internal/api/handler_accounts.go b/internal/api/handler_accounts.go index 534bb66d..255130e6 100644 --- a/internal/api/handler_accounts.go +++ b/internal/api/handler_accounts.go @@ -3,6 +3,7 @@ package api import ( "context" "encoding/json" + "errors" "fmt" "slices" "strings" @@ -1055,12 +1056,9 @@ func (h *Handler) deleteAccountServiceOverride(ctx context.Context, req *events. // budget (limit 10). No business logic lives here that isn't otherwise // described in setPlanAccounts' doc comment. func (h *Handler) validatePlanAccountProviders(ctx context.Context, planID string, accountIDs []string) error { - plan, err := h.config.GetPurchasePlan(ctx, planID) + plan, err := h.getPlanForAccountProviderValidation(ctx, planID) if err != nil { - return fmt.Errorf("accounts: failed to get plan: %w", err) - } - if plan == nil { - return NewClientError(404, fmt.Sprintf("plan not found: %s", planID)) + return err } expected := config.DerivePlanProviders(plan) @@ -1098,6 +1096,20 @@ func (h *Handler) validatePlanAccountProviders(ctx context.Context, planID strin return NewClientError(400, "plan provider mismatch: "+strings.Join(parts, "; ")) } +func (h *Handler) getPlanForAccountProviderValidation(ctx context.Context, planID string) (*config.PurchasePlan, error) { + plan, err := h.config.GetPurchasePlan(ctx, planID) + if err != nil { + if errors.Is(err, config.ErrNotFound) { + return nil, NewClientError(404, fmt.Sprintf("plan not found: %s", planID)) + } + return nil, fmt.Errorf("accounts: failed to get plan: %w", err) + } + if plan == nil { + return nil, NewClientError(404, fmt.Sprintf("plan not found: %s", planID)) + } + return plan, nil +} + // setPlanAccounts handles PUT /api/plans/:id/accounts. // // Per issue #209 / spec acceptance criterion E-4, every account assigned @@ -1134,6 +1146,9 @@ func (h *Handler) setPlanAccounts(ctx context.Context, httpReq *events.LambdaFun } if err := h.config.SetPlanAccounts(ctx, id, body.AccountIDs); err != nil { + if errors.Is(err, config.ErrNotFound) { + return nil, NewClientError(404, err.Error()) + } return nil, fmt.Errorf("accounts: %w", err) } diff --git a/internal/api/handler_accounts_test.go b/internal/api/handler_accounts_test.go index 4f2d43fc..9336dc20 100644 --- a/internal/api/handler_accounts_test.go +++ b/internal/api/handler_accounts_test.go @@ -3,6 +3,7 @@ package api import ( "context" "errors" + "fmt" "testing" "github.com/LeanerCloud/CUDly/internal/accounts" @@ -608,6 +609,32 @@ func TestSetPlanAccounts_ValidHappyPath(t *testing.T) { assert.Equal(t, []string{awsAcct209}, capturedIDs, "SetPlanAccounts should be called with the validated IDs") } +func TestSetPlanAccounts_StoreNotFoundAfterValidationMapsTo404(t *testing.T) { + ctx := context.Background() + mockAuth := new(MockAuthService) + setupAdminAuth(ctx, mockAuth) + + store := setupAdminMock(ctx) + store.GetPurchasePlanFn = func(_ context.Context, _ string) (*config.PurchasePlan, error) { + return awsPlan209(), nil + } + store.GetCloudAccountFn = func(_ context.Context, id string) (*config.CloudAccount, error) { + return &config.CloudAccount{ID: id, Name: "prod-aws", Provider: "aws"}, nil + } + store.SetPlanAccountsFn = func(_ context.Context, _ string, _ []string) error { + return fmt.Errorf("%w: account %s", config.ErrNotFound, awsAcct209) + } + handler := &Handler{auth: mockAuth, config: store} + + body := `{"account_ids":["` + awsAcct209 + `"]}` + _, err := handler.setPlanAccounts(ctx, adminRequest(body), planID209) + require.Error(t, err) + ce, ok := IsClientError(err) + require.True(t, ok) + assert.Equal(t, 404, ce.code) + assert.Contains(t, ce.Error(), awsAcct209) +} + func TestSetPlanAccounts_PlanNotFound(t *testing.T) { ctx := context.Background() mockAuth := new(MockAuthService) diff --git a/internal/config/errors.go b/internal/config/errors.go new file mode 100644 index 00000000..8ce0cf27 --- /dev/null +++ b/internal/config/errors.go @@ -0,0 +1,6 @@ +package config + +import "errors" + +// ErrNotFound is returned when a requested config-store row does not exist. +var ErrNotFound = errors.New("not found") diff --git a/internal/config/store_postgres.go b/internal/config/store_postgres.go index 11bcdcfc..5144a6dc 100644 --- a/internal/config/store_postgres.go +++ b/internal/config/store_postgres.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "encoding/json" + "errors" "fmt" "strings" "time" @@ -446,7 +447,7 @@ func (s *PostgresStore) GetPurchasePlan(ctx context.Context, planID string) (*Pu if err != nil { if err == pgx.ErrNoRows { - return nil, fmt.Errorf("purchase plan not found: %s", planID) + return nil, fmt.Errorf("%w: purchase plan %s", ErrNotFound, planID) } return nil, fmt.Errorf("failed to get purchase plan: %w", err) } @@ -1952,14 +1953,13 @@ func (s *PostgresStore) SetPlanAccounts(ctx context.Context, planID string, acco } func (s *PostgresStore) validatePlanAccountProvidersTx(ctx context.Context, tx pgx.Tx, planID string, accountIDs []string) error { - if len(accountIDs) == 0 { - return nil - } - services, err := s.getPlanServicesForShareTx(ctx, tx, planID) if err != nil { return err } + if len(accountIDs) == 0 { + return nil + } expected := DerivePlanProviders(&PurchasePlan{Services: services}) if len(expected) == 0 { return nil @@ -1989,8 +1989,8 @@ func (s *PostgresStore) getPlanServicesForShareTx(ctx context.Context, tx pgx.Tx WHERE id = $1 FOR SHARE `, planID).Scan(&servicesJSON); err != nil { - if err == pgx.ErrNoRows { - return nil, fmt.Errorf("plan not found: %s", planID) + if errors.Is(err, pgx.ErrNoRows) { + return nil, fmt.Errorf("%w: plan %s", ErrNotFound, planID) } return nil, fmt.Errorf("failed to get plan services: %w", err) } @@ -2022,8 +2022,8 @@ func (s *PostgresStore) findPlanAccountProviderMismatchesTx(ctx context.Context, WHERE id = $1 FOR SHARE `, accountID).Scan(&name, &provider); err != nil { - if err == pgx.ErrNoRows { - return nil, fmt.Errorf("account not found: %s", accountID) + if errors.Is(err, pgx.ErrNoRows) { + return nil, fmt.Errorf("%w: account %s", ErrNotFound, accountID) } return nil, fmt.Errorf("failed to get account %s: %w", accountID, err) } diff --git a/internal/config/store_postgres_pgxmock_test.go b/internal/config/store_postgres_pgxmock_test.go index 95f2270a..175aa635 100644 --- a/internal/config/store_postgres_pgxmock_test.go +++ b/internal/config/store_postgres_pgxmock_test.go @@ -1180,16 +1180,38 @@ func TestPGXMock_SetPlanAccounts_Empty(t *testing.T) { mock := newMock(t) store := storeWith(mock) ctx := context.Background() + servicesJSON, err := json.Marshal(map[string]ServiceConfig{ + "aws/ec2": {Provider: "aws", Service: "ec2"}, + }) + require.NoError(t, err) mock.ExpectBegin() + mock.ExpectQuery("SELECT services").WithArgs("plan-1").WillReturnRows( + pgxmock.NewRows([]string{"services"}).AddRow(servicesJSON), + ) mock.ExpectExec("DELETE FROM plan_accounts").WithArgs(pgxmock.AnyArg()). WillReturnResult(pgxmock.NewResult("DELETE", 0)) mock.ExpectCommit() - err := store.SetPlanAccounts(ctx, "plan-1", []string{}) + err = store.SetPlanAccounts(ctx, "plan-1", []string{}) require.NoError(t, err) } +func TestPGXMock_SetPlanAccounts_EmptyMissingPlanReturnsNotFound(t *testing.T) { + mock := newMock(t) + store := storeWith(mock) + ctx := context.Background() + + mock.ExpectBegin() + mock.ExpectQuery("SELECT services").WithArgs("missing-plan").WillReturnError(pgx.ErrNoRows) + mock.ExpectRollback() + + err := store.SetPlanAccounts(ctx, "missing-plan", []string{}) + require.ErrorIs(t, err, ErrNotFound) + assert.Contains(t, err.Error(), "missing-plan") + assert.NoError(t, mock.ExpectationsWereMet()) +} + func TestPGXMock_SetPlanAccounts_BeginError(t *testing.T) { mock := newMock(t) store := storeWith(mock) From 50f1af9165ec43993d3f520a00ba3080dd25852f Mon Sep 17 00:00:00 2001 From: Cristian Magherusan-Stanciu Date: Sun, 3 May 2026 17:16:34 +0200 Subject: [PATCH 5/5] test(config): assert empty plan account expectations Add the missing pgxmock expectation assertion to the empty SetPlanAccounts test so unmet SQL expectations are surfaced. --- internal/config/store_postgres_pgxmock_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/internal/config/store_postgres_pgxmock_test.go b/internal/config/store_postgres_pgxmock_test.go index 175aa635..78139b5d 100644 --- a/internal/config/store_postgres_pgxmock_test.go +++ b/internal/config/store_postgres_pgxmock_test.go @@ -1195,6 +1195,7 @@ func TestPGXMock_SetPlanAccounts_Empty(t *testing.T) { err = store.SetPlanAccounts(ctx, "plan-1", []string{}) require.NoError(t, err) + assert.NoError(t, mock.ExpectationsWereMet()) } func TestPGXMock_SetPlanAccounts_EmptyMissingPlanReturnsNotFound(t *testing.T) {