diff --git a/lib/pubsub/grpcclient.go b/lib/pubsub/grpcclient.go index 031cf902..d7178a04 100644 --- a/lib/pubsub/grpcclient.go +++ b/lib/pubsub/grpcclient.go @@ -147,7 +147,7 @@ func (c *PubSubClient) Subscribe(channel string, replayPreset proto.ReplayPreset subscribeClient, err := c.pubSubClient.Subscribe(ctx) if err != nil { - if isAuthError(subscribeClient.Trailer().Get("error-code")) { + if subscribeClient != nil && isAuthError(subscribeClient.Trailer().Get("error-code")) { return replayId, SessionExpiredError } return replayId, err diff --git a/lib/pubsub/grpcclient_test.go b/lib/pubsub/grpcclient_test.go index a104d1b6..6fa39682 100644 --- a/lib/pubsub/grpcclient_test.go +++ b/lib/pubsub/grpcclient_test.go @@ -4,6 +4,9 @@ import ( "encoding/json" "reflect" "testing" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) func TestFlattenAvroUnions(t *testing.T) { @@ -214,3 +217,128 @@ func TestFlattenAvroUnions_EdgeCases(t *testing.T) { } }) } + +func TestIsAuthError(t *testing.T) { + tests := []struct { + name string + values []string + expected bool + }{ + { + name: "returns_true_for_auth_error_code", + values: []string{"sfdc.platform.eventbus.grpc.service.auth.error"}, + expected: true, + }, + { + name: "returns_true_when_auth_error_in_list", + values: []string{"other.error", "sfdc.platform.eventbus.grpc.service.auth.error", "another.error"}, + expected: true, + }, + { + name: "returns_false_for_empty_list", + values: []string{}, + expected: false, + }, + { + name: "returns_false_for_nil_list", + values: nil, + expected: false, + }, + { + name: "returns_false_for_other_errors", + values: []string{"some.other.error", "another.error"}, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isAuthError(tt.values) + if result != tt.expected { + t.Errorf("isAuthError(%v) = %v, expected %v", tt.values, result, tt.expected) + } + }) + } +} + +func TestIsInvalidReplayIdError(t *testing.T) { + tests := []struct { + name string + values []string + expected bool + }{ + { + name: "returns_true_for_replay_id_error_code", + values: []string{"sfdc.platform.eventbus.grpc.subscription.fetch.replayid.corrupted"}, + expected: true, + }, + { + name: "returns_false_for_empty_list", + values: []string{}, + expected: false, + }, + { + name: "returns_false_for_nil_list", + values: nil, + expected: false, + }, + { + name: "returns_false_for_other_errors", + values: []string{"some.other.error"}, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isInvalidReplayIdError(tt.values) + if result != tt.expected { + t.Errorf("isInvalidReplayIdError(%v) = %v, expected %v", tt.values, result, tt.expected) + } + }) + } +} + +func TestGRPCUnavailableErrorDetection(t *testing.T) { + tests := []struct { + name string + err error + isUnavail bool + description string + }{ + { + name: "detects_unavailable_error", + err: status.Error(codes.Unavailable, "connection refused"), + isUnavail: true, + description: "gRPC Unavailable error should be detected", + }, + { + name: "does_not_match_other_grpc_errors", + err: status.Error(codes.Internal, "internal error"), + isUnavail: false, + description: "Other gRPC errors should not match Unavailable", + }, + { + name: "does_not_match_cancelled_error", + err: status.Error(codes.Canceled, "context canceled"), + isUnavail: false, + description: "Canceled error should not match Unavailable", + }, + { + name: "handles_nil_error", + err: nil, + isUnavail: false, + description: "nil error should not match Unavailable", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s, ok := status.FromError(tt.err) + isUnavailable := ok && s.Code() == codes.Unavailable + if isUnavailable != tt.isUnavail { + t.Errorf("%s: got %v, expected %v", tt.description, isUnavailable, tt.isUnavail) + } + }) + } +} diff --git a/lib/pubsub/subscribe.go b/lib/pubsub/subscribe.go index ff7836e9..94d2f2e3 100644 --- a/lib/pubsub/subscribe.go +++ b/lib/pubsub/subscribe.go @@ -9,6 +9,8 @@ import ( . "github.com/ForceCLI/force/lib" "github.com/pkg/errors" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" "github.com/ForceCLI/force/lib/pubsub/proto" ) @@ -73,6 +75,9 @@ func Subscribe(f *Force, channel string, replayId string, replayPreset proto.Rep if err == InvalidReplayIdError { return errors.Wrap(err, fmt.Sprintf("could not subscribe starting at replay id: %s", base64.StdEncoding.EncodeToString(curReplayId))) } + if s, ok := status.FromError(err); ok && s.Code() == codes.Unavailable { + return errors.Wrap(err, "server unavailable") + } if err != nil { Log.Info(fmt.Sprintf("error occurred while subscribing to topic: %v", err)) }