diff --git a/pkg/activator/throttler.go b/pkg/activator/throttler.go index 35412deb8b7d..c0910e2ebb09 100644 --- a/pkg/activator/throttler.go +++ b/pkg/activator/throttler.go @@ -76,8 +76,7 @@ func (t *Throttler) UpdateCapacity(rev RevisionID, size int32) error { return err } breaker, _ := t.getOrCreateBreaker(rev) - err = t.updateCapacity(revision, breaker, size) - return err + return t.updateCapacity(revision, breaker, size) } // Try potentially registers a new breaker in our bookkeeping @@ -92,7 +91,7 @@ func (t *Throttler) Try(rev RevisionID, function func()) error { return err } } - if ok := breaker.Maybe(function); !ok { + if !breaker.Maybe(function) { return errors.New(OverloadMessage) } return nil diff --git a/pkg/activator/throttler_test.go b/pkg/activator/throttler_test.go index 49cd2a24c8a2..00a9b85cf5ab 100644 --- a/pkg/activator/throttler_test.go +++ b/pkg/activator/throttler_test.go @@ -19,23 +19,23 @@ package activator import ( "errors" "testing" + "time" + + "golang.org/x/sync/errgroup" . "github.com/knative/pkg/logging/testing" testinghelper "github.com/knative/serving/pkg/activator/testing" "github.com/knative/serving/pkg/apis/serving/v1alpha1" v1alpha12 "github.com/knative/serving/pkg/apis/serving/v1alpha1" "github.com/knative/serving/pkg/queue" + "go.uber.org/zap" corev1 "k8s.io/api/core/v1" - "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "go.uber.org/zap" ) var ( revID = RevisionID{"good-namespace", "good-name"} - sampleError = "some error" - existingRevisionGetter = func(concurrency v1alpha1.RevisionContainerConcurrencyType) func(RevisionID) (*v1alpha12.Revision, error) { return func(RevisionID) (*v1alpha12.Revision, error) { revision := &v1alpha12.Revision{Spec: v1alpha12.RevisionSpec{ContainerConcurrency: concurrency}} @@ -43,10 +43,8 @@ var ( } } nonExistingRevisionGetter = func(RevisionID) (*v1alpha12.Revision, error) { - revision := &v1alpha12.Revision{} - return revision, errors.New(sampleError) + return nil, errors.New(sampleError) } - initCapacity = int32(0) existingEndpointsGetter = func(RevisionID) (int32, error) { return initCapacity, nil } @@ -57,6 +55,8 @@ var ( const ( defaultMaxConcurrency = int32(10) + initCapacity = int32(0) + sampleError = "some error" ) func TestThrottler_UpdateCapacity(t *testing.T) { @@ -89,20 +89,20 @@ func TestThrottler_UpdateCapacity(t *testing.T) { }} for _, s := range samples { t.Run(s.label, func(t *testing.T) { - want := s.want throttler := getThrottler(s.maxConcurrency, s.revisionGetter, s.endpointsGetter, TestLogger(t), initCapacity) err := throttler.UpdateCapacity(revID, 1) if s.wantError != "" { - received := err.Error() - if received != s.wantError { - t.Errorf("Expected error in Update capacity. Want %s, got %s", s.wantError, err.Error()) + if err == nil { + t.Fatal("Expected error, got nil") + } + if got := err.Error(); got != s.wantError { + t.Errorf("Update capacity error message = %s, want: %s", got, s.wantError) } } - if want > 0 { + if s.want > 0 { breaker, _ := throttler.breakers[revID] - got := breaker.Capacity() - if got != want { - t.Errorf("Unexpected capacity of the breaker. Want %d, got %d", want, got) + if got := breaker.Capacity(); got != s.want { + t.Errorf("Breaker Capacity = %d, want: %d", got, s.want) } } }) @@ -121,7 +121,6 @@ func TestThrottler_Try(t *testing.T) { label: "all good", addCapacity: true, wantBreakers: int32(1), - wantError: "", revisionGetter: existingRevisionGetter(10), endpointsGetter: existingEndpointsGetter, }, { @@ -143,7 +142,8 @@ func TestThrottler_Try(t *testing.T) { t.Run(s.label, func(t *testing.T) { var got int32 want := s.wantBreakers - throttler := getThrottler(defaultMaxConcurrency, s.revisionGetter, s.endpointsGetter, TestLogger(t), initCapacity) + throttler := getThrottler( + defaultMaxConcurrency, s.revisionGetter, s.endpointsGetter, TestLogger(t), initCapacity) if s.addCapacity { throttler.UpdateCapacity(revID, 1) } @@ -151,18 +151,54 @@ func TestThrottler_Try(t *testing.T) { got++ }) if s.wantError != "" { - received := err.Error() - if received != s.wantError { - t.Errorf("Expected error in the Try. Want %s, got %s", s.wantError, received) + if err == nil { + t.Fatal("Expected error got nil") + } + + if got := err.Error(); got != s.wantError { + t.Errorf("Try error = %s, want: %s", got, s.wantError) } } if got != want { - t.Errorf("Unexpected number of function runs in Try. Want %d, got %d", want, got) + t.Errorf("Unexpected number of function runs in Try = %d, want: %d", got, want) } }) } } +func TestThrottler_TryOverload(t *testing.T) { + th := getThrottler( + 1 /*maxConcurrency*/, existingRevisionGetter(10), existingEndpointsGetter, TestLogger(t), + 1 /*initial capacity*/) + done := make(chan struct{}) + + // We have two slots to fill. + var g errgroup.Group + for i := 0; i < 2; i++ { + g.Go(func() error { + return th.Try(revID, func() { + select { + case <-done: + } + }) + }) + } + // Give the chance for the goroutines to launch. + time.Sleep(150 * time.Millisecond) + err := th.Try(revID, func() { + t.Fatal("This should not have executed") + }) + // `err` must be non-nil here, since `t.Fatal()` above would ensure we + // don't reach here on success. + if got := err.Error(); got != OverloadMessage { + t.Errorf("Error message = %q, want: %q", got, OverloadMessage) + } + close(done) + if err := g.Wait(); err != nil { + t.Errorf("Error in the parallel requests: %v", err) + } +} + func TestUpdateEndpoints(t *testing.T) { samples := []struct { label string @@ -206,14 +242,14 @@ func TestUpdateEndpoints(t *testing.T) { throttler.breakers[revID] = queue.NewBreaker(throttler.breakerParams) updater := UpdateEndpoints(throttler) - endpointsBefore := corev1.Endpoints{ObjectMeta: v1.ObjectMeta{Name: revID.Name + "-service", Namespace: revID.Namespace}, Subsets: testinghelper.GetTestEndpointsSubset(s.endpointBefore, 1)} - endpointsAfter := corev1.Endpoints{ObjectMeta: v1.ObjectMeta{Name: revID.Name + "-service", Namespace: revID.Namespace}, Subsets: testinghelper.GetTestEndpointsSubset(s.endpointsAfter, 1)} + endpointsBefore := corev1.Endpoints{ObjectMeta: metav1.ObjectMeta{Name: revID.Name + "-service", Namespace: revID.Namespace}, Subsets: testinghelper.GetTestEndpointsSubset(s.endpointBefore, 1)} + endpointsAfter := corev1.Endpoints{ObjectMeta: metav1.ObjectMeta{Name: revID.Name + "-service", Namespace: revID.Namespace}, Subsets: testinghelper.GetTestEndpointsSubset(s.endpointsAfter, 1)} updater(&endpointsBefore, &endpointsAfter) breaker, _ := throttler.breakers[revID] got := breaker.Capacity() if got != s.wantCapacity { - t.Errorf("Unexpected Breaker capacity received. Want %d, got %d", s.wantCapacity, got) + t.Errorf("Breaker capacity = %d, want: %d", got, s.wantCapacity) } } } @@ -221,14 +257,14 @@ func TestUpdateEndpoints(t *testing.T) { func TestThrottler_Remove(t *testing.T) { throttler := getThrottler(defaultMaxConcurrency, existingRevisionGetter(10), existingEndpointsGetter, TestLogger(t), initCapacity) throttler.breakers[revID] = queue.NewBreaker(throttler.breakerParams) - got := len(throttler.breakers) - if got != 1 { - t.Errorf("Unexpected number of Breakers was created. Want %d, got %d", 1, got) + + if got := len(throttler.breakers); got != 1 { + t.Errorf("Number of Breakers created = %d, want: 1", got) } throttler.Remove(revID) - got = len(throttler.breakers) - if got != 0 { - t.Errorf("Unexpected number of Breakers was created. Want %d, got %d", 0, got) + + if got := len(throttler.breakers); got != 0 { + t.Errorf("Number of Breakers created = %d, want: %d", got, 0) } } @@ -242,18 +278,20 @@ func TestHelper_DeleteBreaker(t *testing.T) { } revID := RevisionID{Namespace: revID.Namespace, Name: revID.Name} throttler.breakers[revID] = queue.NewBreaker(throttler.breakerParams) - if len(throttler.breakers) != 1 { - t.Errorf("Breaker map size didn't change. Wanted %d, got %d", 1, len(throttler.breakers)) + if got := len(throttler.breakers); got != 1 { + t.Errorf("Breaker map size got %d, want: 1", got) } - deleter := DeleteBreaker(throttler) - deleter(revision) + DeleteBreaker(throttler)(revision) if len(throttler.breakers) != 0 { t.Error("Breaker map is not empty") } } -func getThrottler(maxConcurrency int32, revisionGetter func(RevisionID) (*v1alpha12.Revision, error), endpointsGetter func(RevisionID) (int32, error), logger *zap.SugaredLogger, initCapacity int32) *Throttler { - params := queue.BreakerParams{QueueDepth: 10, MaxConcurrency: maxConcurrency, InitialCapacity: initCapacity} +func getThrottler( + maxConcurrency int32, revisionGetter func(RevisionID) (*v1alpha12.Revision, error), + endpointsGetter func(RevisionID) (int32, error), logger *zap.SugaredLogger, + initCapacity int32) *Throttler { + params := queue.BreakerParams{QueueDepth: 1, MaxConcurrency: maxConcurrency, InitialCapacity: initCapacity} throttlerParams := ThrottlerParams{BreakerParams: params, Logger: logger, GetRevision: revisionGetter, GetEndpoints: endpointsGetter} return NewThrottler(throttlerParams) } diff --git a/pkg/queue/breaker.go b/pkg/queue/breaker.go index 3276ab3befbf..86da17e4fc7d 100644 --- a/pkg/queue/breaker.go +++ b/pkg/queue/breaker.go @@ -78,7 +78,6 @@ func NewBreaker(params BreakerParams) *Breaker { // already consumed, Maybe returns immediately without calling thunk. If // the thunk was executed, Maybe returns true, else false. func (b *Breaker) Maybe(thunk func()) bool { - var t token select { default: