18 changes: 18 additions & 0 deletions internal/scheduler/scheduler.go
@@ -19,11 +19,13 @@ import (
"github.com/LeanerCloud/CUDly/internal/oidc"
"github.com/LeanerCloud/CUDly/internal/purchase"
"github.com/LeanerCloud/CUDly/pkg/common"
"github.com/LeanerCloud/CUDly/pkg/concurrency"
"github.com/LeanerCloud/CUDly/pkg/logging"
"github.com/LeanerCloud/CUDly/pkg/provider"
azureprovider "github.com/LeanerCloud/CUDly/providers/azure"
gcpprovider "github.com/LeanerCloud/CUDly/providers/gcp"
"golang.org/x/sync/errgroup"
"golang.org/x/sync/semaphore"
)

// SchedulerConfig holds configuration for the scheduler
@@ -168,6 +170,22 @@ func (s *Scheduler) CollectRecommendations(ctx context.Context) (*CollectResult,
return nil, err
}

// Attach a shared semaphore to ctx so every leaf goroutine in the
// recommendations-collection fan-out tree (AWS service, Azure service,
// GCP region×service) acquires one slot before issuing its cloud-API
// call and releases it after. This bounds aggregate concurrent IO across
// the whole tree at CUDLY_MAX_PARALLELISM (default 20) regardless of how
// nested the dispatch is — without it, peak concurrency multiplies
// through the nested fan-outs and can exhaust Lambda memory before work
// completes (observed with a 512 MB function in dev). Intermediate
// dispatchers (provider, account, GCP region) do NOT acquire — they only
// launch sub-goroutines — so no goroutine can deadlock by holding a
// permit while waiting for sub-permits.
maxParallelism := concurrency.MaxParallelismFromEnv()
sem := semaphore.NewWeighted(int64(maxParallelism))
ctx = concurrency.WithSharedSemaphore(ctx, sem)
logging.Infof("Recommendations collection: aggregate parallelism cap = %d", maxParallelism)

// Collect recommendations from each enabled provider concurrently, tracking
// per-provider outcomes so persistence can scope stale-row eviction to
// (provider, account) pairs that actually ran. A partial collection (e.g.
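Taken together with the leaf-only acquisition rule described in the comment above, a leaf call site reduces to acquire, call, release. A minimal sketch under assumed names; collectLeaf, RecommendationsClient, and Recommendation are illustrative stand-ins, not the scheduler's real types:

package scheduler

import (
	"context"

	"github.com/LeanerCloud/CUDly/pkg/concurrency"
)

// Recommendation and RecommendationsClient are hypothetical stand-ins for the
// real provider types; only the Acquire/Release shape matters here.
type Recommendation struct{ Savings float64 }

type RecommendationsClient interface {
	GetAllRecommendations(ctx context.Context) ([]Recommendation, error)
}

// collectLeaf is what a leaf goroutine in the fan-out looks like: take one
// shared slot, issue the single cloud-API call, give the slot back.
func collectLeaf(ctx context.Context, client RecommendationsClient) ([]Recommendation, error) {
	if err := concurrency.Acquire(ctx); err != nil {
		return nil, err // cancelled while queued; must NOT Release here
	}
	defer concurrency.Release(ctx)

	// The IO runs while the slot is held, so aggregate in-flight calls stay
	// under CUDLY_MAX_PARALLELISM no matter how deep the dispatch tree is.
	return client.GetAllRecommendations(ctx)
}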
16 changes: 12 additions & 4 deletions internal/scheduler/scheduler_test.go
@@ -1176,7 +1176,11 @@ func TestScheduler_ListRecommendations_ColdStartSync(t *testing.T) {
// global config. Return no enabled providers so the collect is a
// no-op but still runs the persistence path.
mockStore.On("GetGlobalConfig", ctx).Return(&config.GlobalConfig{EnabledProviders: []string{}}, nil)
mockStore.On("UpsertRecommendations", ctx, mock.Anything, mock.Anything, mock.Anything).Return(nil)
// UpsertRecommendations runs inside CollectRecommendations, after the
// shared semaphore is attached to ctx; the wrapped ctx is what reaches
// the persistence layer. mock.Anything keeps the assertion resilient
// to that wrap.
mockStore.On("UpsertRecommendations", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)
mockStore.On("ListStoredRecommendations", ctx, mock.Anything).
Return([]config.RecommendationRecord{}, nil)

@@ -1189,7 +1193,7 @@ func TestScheduler_ListRecommendations_ColdStartSync(t *testing.T) {
// CollectRecommendations, so seeing it proves the sync collect fired
// before the store read.
mockStore.AssertCalled(t, "GetGlobalConfig", ctx)
mockStore.AssertCalled(t, "UpsertRecommendations", ctx, mock.Anything, mock.Anything, mock.Anything)
mockStore.AssertCalled(t, "UpsertRecommendations", mock.Anything, mock.Anything, mock.Anything, mock.Anything)
}

// fanOutPerAccount bounds parallel in-flight calls to
@@ -1718,7 +1722,11 @@ func TestScheduler_CollectRecommendations_WithSuccessfulRecs(t *testing.T) {
// contract test cover the cancellation path).
mockProvider.On("GetRecommendationsClient", mock.Anything).Return(mockRecClient, nil)
mockRecClient.On("GetAllRecommendations", mock.Anything).Return(recommendations, nil)
mockEmail.On("SendNewRecommendationsNotification", ctx, mock.AnythingOfType("email.NotificationData")).Return(nil)
// SendNewRecommendationsNotification fires inside CollectRecommendations,
// after ctx has been wrapped via concurrency.WithSharedSemaphore — the
// wrapped ctx is what reaches the email sender. mock.Anything keeps the
// assertion resilient to that wrap.
mockEmail.On("SendNewRecommendationsNotification", mock.Anything, mock.AnythingOfType("email.NotificationData")).Return(nil)

scheduler := &Scheduler{
config: mockStore,
@@ -1733,7 +1741,7 @@
assert.Equal(t, 1, result.Recommendations)
assert.Equal(t, 500.0, result.TotalSavings)

mockEmail.AssertCalled(t, "SendNewRecommendationsNotification", ctx, mock.AnythingOfType("email.NotificationData"))
mockEmail.AssertCalled(t, "SendNewRecommendationsNotification", mock.Anything, mock.AnythingOfType("email.NotificationData"))
}

// Test AWS recommendations fallback to GetRecommendations when GetAllRecommendations returns empty
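The mock.Anything swaps above follow from how testify matches call arguments: expectations compare arguments with deep equality, and context.WithValue (which WithSharedSemaphore uses) returns a new context value, so an expectation registered against the bare ctx stops matching once the scheduler wraps it. A standalone sketch of that behaviour, in a hypothetical test file:

package scheduler_test

import (
	"context"
	"testing"

	"github.com/LeanerCloud/CUDly/pkg/concurrency"
	"github.com/stretchr/testify/assert"
	"golang.org/x/sync/semaphore"
)

// Wrapping a context produces a distinct value, which is why expectations on
// arguments that receive the wrapped ctx have to use mock.Anything.
func TestWrappedContextIsNotTheOriginal(t *testing.T) {
	ctx := context.Background()
	wrapped := concurrency.WithSharedSemaphore(ctx, semaphore.NewWeighted(1))

	// The wrapped context no longer compares equal to the original, so a
	// testify expectation pinned to ctx would not match a call made with it.
	assert.NotEqual(t, ctx, wrapped)
	// But the semaphore is still reachable from the wrapped context.
	assert.NotNil(t, concurrency.SharedSemaphore(wrapped))
}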
87 changes: 87 additions & 0 deletions pkg/concurrency/concurrency.go
@@ -0,0 +1,87 @@
// Package concurrency provides a shared global parallelism cap for the
// recommendations-collection fan-out tree.
//
// The fan-out has up to four nested levels (provider → account → service|region
// → per-region service). Each level was independently capped, so peak goroutine
// counts multiplied through the tree (3 providers × 20 accounts × 30 regions ×
// 2 services = thousands of in-flight gRPC/HTTP clients). On a 512 MB Lambda
// that exhausted memory before the work could finish.
//
// A single semaphore stashed on the context lets every leaf goroutine — the
// goroutine that issues the actual cloud-API call — acquire one slot before
// doing IO and release it after, so the aggregate concurrent IO count is hard-
// bounded regardless of nesting depth. Intermediate dispatchers (provider,
// account, GCP region) do NOT acquire — they only launch sub-goroutines — so
// no goroutine can deadlock by holding a permit while waiting for sub-permits.
//
// If no semaphore is attached to the context (e.g. unit tests, ambient calls
// from CLI tools), Acquire and Release are no-ops; callers don't need to
// branch on whether the semaphore is set.
package concurrency

import (
"context"
"os"
"strconv"

"golang.org/x/sync/semaphore"
)

// DefaultMaxParallelism is the default cap on aggregate concurrent leaf
// goroutines across the recommendations-collection fan-out tree. Override at
// runtime with CUDLY_MAX_PARALLELISM.
const DefaultMaxParallelism = 20

// MaxParallelismFromEnv reads CUDLY_MAX_PARALLELISM and returns its
// positive-integer value, falling back to DefaultMaxParallelism on unset /
// invalid / non-positive values.
func MaxParallelismFromEnv() int {
if v := os.Getenv("CUDLY_MAX_PARALLELISM"); v != "" {
if n, err := strconv.Atoi(v); err == nil && n > 0 {
return n
}
}
return DefaultMaxParallelism
}

type ctxKey struct{}

// WithSharedSemaphore returns a context carrying sem. Goroutines spawned from
// this context (or any descendant) can acquire/release slots via Acquire and
// Release. If sem is nil the context is returned unchanged.
func WithSharedSemaphore(ctx context.Context, sem *semaphore.Weighted) context.Context {
if sem == nil {
return ctx
}
return context.WithValue(ctx, ctxKey{}, sem)
}

// SharedSemaphore returns the semaphore stashed in ctx, or nil if none.
func SharedSemaphore(ctx context.Context) *semaphore.Weighted {
sem, _ := ctx.Value(ctxKey{}).(*semaphore.Weighted)
return sem
}

// Acquire blocks until a slot is available on the shared semaphore in ctx and
// returns nil. Returns ctx.Err() if the wait is cancelled. If no semaphore is
// attached to ctx, Acquire is a no-op and returns nil immediately — leaf
// callers can use it unconditionally without checking.
func Acquire(ctx context.Context) error {
sem := SharedSemaphore(ctx)
if sem == nil {
return nil
}
return sem.Acquire(ctx, 1)
}

// Release returns one slot to the shared semaphore in ctx. Always pair with a
// successful Acquire (return value nil); calling Release after a cancelled
// Acquire would corrupt the slot count. If no semaphore is attached to ctx,
// Release is a no-op.
func Release(ctx context.Context) {
sem := SharedSemaphore(ctx)
if sem == nil {
return
}
sem.Release(1)
}
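To make the dispatcher-versus-leaf split concrete, here is a sketch of a two-level fan-out in the shape the package comment describes: the intermediate dispatcher only launches sub-goroutines via errgroup and never touches the semaphore, while each leaf brackets its cloud-API call with Acquire/Release. fetchService and collectAccount are illustrative names, not the scheduler's real functions:

package scheduler

import (
	"context"

	"github.com/LeanerCloud/CUDly/pkg/concurrency"
	"golang.org/x/sync/errgroup"
)

// fetchService is a hypothetical leaf: exactly one cloud-API call, bracketed
// by the shared semaphore so aggregate IO stays under the global cap.
func fetchService(ctx context.Context, account, service string) error {
	if err := concurrency.Acquire(ctx); err != nil {
		return err // cancelled while waiting for a slot
	}
	defer concurrency.Release(ctx)
	// ... issue the single cloud-API call for (account, service) here ...
	return nil
}

// collectAccount is a hypothetical intermediate dispatcher: it fans out one
// goroutine per service but never acquires a permit itself, so it can never
// deadlock by holding a slot while waiting on its children.
func collectAccount(ctx context.Context, account string, services []string) error {
	g, ctx := errgroup.WithContext(ctx)
	for _, svc := range services {
		g.Go(func() error { return fetchService(ctx, account, svc) })
	}
	return g.Wait()
}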
136 changes: 136 additions & 0 deletions pkg/concurrency/concurrency_test.go
@@ -0,0 +1,136 @@
package concurrency

import (
"context"
"os"
"sync"
"sync/atomic"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/sync/semaphore"
)

// TestMaxParallelismFromEnv pins the env-knob parser semantics for
// CUDLY_MAX_PARALLELISM.
func TestMaxParallelismFromEnv(t *testing.T) {
cases := []struct {
name string
env string
want int
}{
{"unset returns default", "", DefaultMaxParallelism},
{"positive integer overrides", "50", 50},
{"non-numeric falls back to default", "many", DefaultMaxParallelism},
{"zero falls back to default", "0", DefaultMaxParallelism},
{"negative falls back to default", "-3", DefaultMaxParallelism},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Setenv("CUDLY_MAX_PARALLELISM", tc.env)
assert.Equal(t, tc.want, MaxParallelismFromEnv())
})
}

t.Run("explicit unset returns default", func(t *testing.T) {
os.Unsetenv("CUDLY_MAX_PARALLELISM")
assert.Equal(t, DefaultMaxParallelism, MaxParallelismFromEnv())
})
}

// TestSharedSemaphore_NoSemaphoreOnContext verifies Acquire/Release are
// no-ops when no semaphore is attached — the documented contract that lets
// CLI tools and unit tests skip the semaphore entirely without per-call
// branching.
func TestSharedSemaphore_NoSemaphoreOnContext(t *testing.T) {
ctx := context.Background()
assert.Nil(t, SharedSemaphore(ctx))
require.NoError(t, Acquire(ctx))
Release(ctx) // must not panic
}

// TestSharedSemaphore_WithNilSemaphore verifies WithSharedSemaphore returns
// the input ctx unchanged when sem is nil — defensive against accidental
// nil passes.
func TestSharedSemaphore_WithNilSemaphore(t *testing.T) {
ctx := context.Background()
assert.Equal(t, ctx, WithSharedSemaphore(ctx, nil))
}

// TestSharedSemaphore_BoundsConcurrency is the load-bearing contract test:
// with a cap of 3, 20 goroutines all calling Acquire/work/Release must
// never see more than 3 in-flight concurrently. Asserts peak concurrency
// observed via atomics.
func TestSharedSemaphore_BoundsConcurrency(t *testing.T) {
const cap = 3
const goroutines = 20
sem := semaphore.NewWeighted(cap)
ctx := WithSharedSemaphore(context.Background(), sem)

var inflight, peak atomic.Int32
updatePeak := func(cur int32) {
for {
p := peak.Load()
if cur <= p || peak.CompareAndSwap(p, cur) {
return
}
}
}

// Workers must never call require.* / FailNow on a non-test goroutine —
// testify's contract is that those land on the test's own goroutine
// (otherwise the failure mechanism uses runtime.Goexit on the worker
// instead of stopping the test, which can hang or skip cleanup). Each
// worker captures its Acquire result on a buffered channel and the main
// goroutine asserts after wg.Wait(). Release is only deferred on a
// successful Acquire — the documented pairing contract.
var wg sync.WaitGroup
errCh := make(chan error, goroutines)
for i := 0; i < goroutines; i++ {
wg.Add(1)
go func() {
defer wg.Done()
if err := Acquire(ctx); err != nil {
errCh <- err
return
}
defer Release(ctx)
errCh <- nil
cur := inflight.Add(1)
updatePeak(cur)
time.Sleep(2 * time.Millisecond) // make overlap observable
inflight.Add(-1)
}()
}
wg.Wait()
close(errCh)
for err := range errCh {
require.NoError(t, err)
}

assert.LessOrEqual(t, peak.Load(), int32(cap),
"peak concurrent in-flight goroutines must not exceed semaphore cap")
assert.GreaterOrEqual(t, peak.Load(), int32(2),
"with %d goroutines and cap %d, peak should reach at least 2 (proves goroutines genuinely overlapped)",
goroutines, cap)
}

// TestSharedSemaphore_AcquireRespectsCancellation verifies Acquire returns
// ctx.Err() when the parent ctx is cancelled while waiting for a slot.
// Without this, a cancelled refresh would leak a goroutine parked
// indefinitely on Acquire.
func TestSharedSemaphore_AcquireRespectsCancellation(t *testing.T) {
sem := semaphore.NewWeighted(1)
// Pre-occupy the only slot so the second Acquire must wait.
require.NoError(t, sem.Acquire(context.Background(), 1))
defer sem.Release(1)

ctx, cancel := context.WithCancel(WithSharedSemaphore(context.Background(), sem))
cancel() // cancel before Acquire even starts

err := Acquire(ctx)
require.Error(t, err)
assert.ErrorIs(t, err, context.Canceled)
}
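The cancellation test matters most in an errgroup fan-out: when one leaf fails, the group cancels its context, and a sibling still queued in Acquire unblocks with an error instead of waiting on a slot it will never use. A sketch of that interplay under assumed names (this test is illustrative, not part of the package):

package concurrency_test

import (
	"context"
	"errors"
	"testing"
	"time"

	"github.com/LeanerCloud/CUDly/pkg/concurrency"
	"github.com/stretchr/testify/assert"
	"golang.org/x/sync/errgroup"
	"golang.org/x/sync/semaphore"
)

// With a cap of 1, the failing leaf holds the only slot while the sibling
// queues behind it; the group's cancellation is what lets the sibling return.
func TestFanOutCancellationUnblocksQueuedLeaf(t *testing.T) {
	ctx := concurrency.WithSharedSemaphore(context.Background(), semaphore.NewWeighted(1))
	leafErr := errors.New("leaf failed")

	g, gctx := errgroup.WithContext(ctx)
	holding := make(chan struct{})

	g.Go(func() error {
		if err := concurrency.Acquire(gctx); err != nil {
			return err
		}
		defer concurrency.Release(gctx)
		close(holding)                   // the only slot is now held
		time.Sleep(5 * time.Millisecond) // let the sibling queue up behind it
		return leafErr                   // returning an error cancels gctx
	})
	g.Go(func() error {
		<-holding
		// Either unblocks with gctx.Err() once the group cancels, or briefly
		// wins the freed slot; both paths end the goroutine promptly.
		if err := concurrency.Acquire(gctx); err != nil {
			return err
		}
		concurrency.Release(gctx)
		return nil
	})

	// The group reports the first recorded error, which is the leaf failure.
	assert.ErrorIs(t, g.Wait(), leafErr)
}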
10 changes: 4 additions & 6 deletions pkg/go.mod
@@ -1,8 +1,6 @@
module github.com/LeanerCloud/CUDly/pkg

go 1.23

toolchain go1.24.4
go 1.25.0

// This module contains shared types, provider interfaces, and the exchange package.
// The exchange package has AWS SDK dependencies for RI exchange operations.
@@ -12,7 +10,10 @@ require (
github.com/aws/aws-sdk-go-v2/config v1.26.2
github.com/aws/aws-sdk-go-v2/service/ec2 v1.251.2
github.com/aws/aws-sdk-go-v2/service/sts v1.26.6
github.com/spf13/pflag v1.0.5
github.com/stretchr/testify v1.11.1
golang.org/x/sync v0.20.0
gopkg.in/yaml.v3 v3.0.1
)

require (
@@ -27,8 +28,5 @@
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.21.5 // indirect
github.com/aws/smithy-go v1.24.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
4 changes: 2 additions & 2 deletions pkg/go.sum
@@ -30,14 +30,14 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg=
github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=