diff --git a/shortcuts/mail/mail_watch.go b/shortcuts/mail/mail_watch.go index d06cb47fc..4b11ff5e0 100644 --- a/shortcuts/mail/mail_watch.go +++ b/shortcuts/mail/mail_watch.go @@ -18,6 +18,7 @@ import ( "sort" "strings" "sync" + "sync/atomic" "syscall" larkcore "github.com/larksuite/oapi-sdk-go/v3/core" @@ -50,6 +51,50 @@ func (l *mailWatchLogger) Error(_ context.Context, args ...interface{}) { var _ larkcore.Logger = (*mailWatchLogger)(nil) +func handleMailWatchSignal(errOut io.Writer, sig os.Signal, eventCount int, unsubscribe func() error, stopSignals func(), cancel context.CancelFunc) { + fmt.Fprintf(errOut, "\nShutting down (signal: %v)... (received %d events)\n", sig, eventCount) + fmt.Fprintln(errOut, "Unsubscribing mailbox events...") + if unsubErr := unsubscribe(); unsubErr != nil { + fmt.Fprintf(errOut, "Warning: unsubscribe failed: %v\n", unsubErr) + } else { + fmt.Fprintln(errOut, "Mailbox unsubscribed.") + } + if stopSignals != nil { + stopSignals() + } + if cancel != nil { + cancel() + } +} + +func newMailWatchUnsubscribeOnce(unsubscribe func() error) func() error { + var once sync.Once + var unsubscribeErr error + return func() error { + once.Do(func() { + if unsubscribe != nil { + unsubscribeErr = unsubscribe() + } + }) + return unsubscribeErr + } +} + +func waitForMailWatchStart(startErrCh <-chan error, shutdownRequested <-chan struct{}, shutdownComplete <-chan struct{}) error { + select { + case <-shutdownComplete: + return nil + case err := <-startErrCh: + select { + case <-shutdownRequested: + <-shutdownComplete + return nil + default: + return err + } + } +} + const mailEventType = "mail.user_mailbox.event.message_received_v1" // promptInjectionPatterns lists known prompt injection trigger phrases. @@ -272,7 +317,7 @@ var MailWatch = common.Shortcut{ mailboxFilter = resolved } - eventCount := 0 + var eventCount atomic.Int64 handleEvent := func(data map[string]interface{}) { // Extract event body @@ -338,7 +383,7 @@ var MailWatch = common.Shortcut{ } } - eventCount++ + eventCount.Add(1) // Prompt injection detection: warn when email body contains known injection patterns. // Body fields may be base64url-encoded; decode before scanning. @@ -425,29 +470,41 @@ var MailWatch = common.Shortcut{ larkws.WithLogger(sdkLogger), ) + unsubscribeOnce := newMailWatchUnsubscribeOnce(unsubscribe) + + watchCtx, cancel := context.WithCancel(ctx) + defer cancel() + defer unsubscribeOnce() //nolint:errcheck // best-effort cleanup on all exit paths + sigCh := make(chan os.Signal, 1) signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) + stopSignals := func() { signal.Stop(sigCh) } + defer stopSignals() + + startErrCh := make(chan error, 1) + go func() { + startErrCh <- cli.Start(watchCtx) + }() + + shutdownRequested := make(chan struct{}) + shutdownComplete := make(chan struct{}) go func() { defer func() { if r := recover(); r != nil { fmt.Fprintf(errOut, "panic in signal handler: %v\n", r) } }() - <-sigCh - info(fmt.Sprintf("\nShutting down... (received %d events)", eventCount)) - info("Unsubscribing mailbox events...") - if unsubErr := unsubscribe(); unsubErr != nil { - fmt.Fprintf(errOut, "Warning: unsubscribe failed: %v\n", unsubErr) - } else { - info("Mailbox unsubscribed.") + sig, ok := <-sigCh + if !ok { + return } - signal.Stop(sigCh) - os.Exit(0) + close(shutdownRequested) + handleMailWatchSignal(errOut, sig, int(eventCount.Load()), unsubscribeOnce, stopSignals, cancel) + close(shutdownComplete) }() info("Connected. Waiting for mail events... (Ctrl+C to stop)") - if err := cli.Start(ctx); err != nil { - unsubscribe() //nolint:errcheck // best-effort cleanup + if err := waitForMailWatchStart(startErrCh, shutdownRequested, shutdownComplete); err != nil { return output.ErrNetwork("WebSocket connection failed: %v", err) } return nil diff --git a/shortcuts/mail/mail_watch_test.go b/shortcuts/mail/mail_watch_test.go index 4dcade091..56040d793 100644 --- a/shortcuts/mail/mail_watch_test.go +++ b/shortcuts/mail/mail_watch_test.go @@ -8,6 +8,8 @@ import ( "context" "encoding/base64" "encoding/json" + "errors" + "os" "strings" "testing" @@ -264,6 +266,107 @@ func TestMailWatchLoggerSuppressesDebugAlways(t *testing.T) { } } +func TestHandleMailWatchSignal_UnsubscribesAndCancels(t *testing.T) { + var buf bytes.Buffer + unsubscribed := false + stopped := false + canceled := false + + handleMailWatchSignal(&buf, os.Interrupt, 3, func() error { + unsubscribed = true + return nil + }, func() { + stopped = true + }, func() { + canceled = true + }) + + if !unsubscribed { + t.Fatal("expected unsubscribe to be called") + } + if !stopped { + t.Fatal("expected signal stop to be called") + } + if !canceled { + t.Fatal("expected cancel to be called") + } + out := buf.String() + if !strings.Contains(out, "Shutting down (signal: interrupt)... (received 3 events)") { + t.Fatalf("missing shutdown message, got: %q", out) + } + if !strings.Contains(out, "Mailbox unsubscribed.") { + t.Fatalf("missing unsubscribe success message, got: %q", out) + } +} + +func TestHandleMailWatchSignal_ReportsUnsubscribeFailure(t *testing.T) { + var buf bytes.Buffer + + handleMailWatchSignal(&buf, os.Interrupt, 1, func() error { + return errors.New("boom") + }, nil, nil) + + if got := buf.String(); !strings.Contains(got, "Warning: unsubscribe failed: boom") { + t.Fatalf("expected unsubscribe warning, got: %q", got) + } +} + +func TestNewMailWatchUnsubscribeOnce_OnlyRunsOnce(t *testing.T) { + calls := 0 + unsubscribeOnce := newMailWatchUnsubscribeOnce(func() error { + calls++ + return errors.New("boom") + }) + + if err := unsubscribeOnce(); err == nil || err.Error() != "boom" { + t.Fatalf("expected first unsubscribe error, got %v", err) + } + if err := unsubscribeOnce(); err == nil || err.Error() != "boom" { + t.Fatalf("expected cached unsubscribe error, got %v", err) + } + if calls != 1 { + t.Fatalf("expected unsubscribe to run once, got %d", calls) + } +} + +func TestWaitForMailWatchStart_ReturnsNilWhenSignalShutdownWinsRace(t *testing.T) { + startErrCh := make(chan error, 1) + shutdownRequested := make(chan struct{}) + shutdownComplete := make(chan struct{}) + + close(shutdownRequested) + startErrCh <- context.Canceled + + done := make(chan error, 1) + go func() { + done <- waitForMailWatchStart(startErrCh, shutdownRequested, shutdownComplete) + }() + + select { + case err := <-done: + t.Fatalf("expected waitForMailWatchStart to block until shutdown completes, got %v", err) + default: + } + + close(shutdownComplete) + + if err := <-done; err != nil { + t.Fatalf("expected nil after shutdown completion, got %v", err) + } +} + +func TestWaitForMailWatchStart_ReturnsStartErrorWithoutSignal(t *testing.T) { + startErrCh := make(chan error, 1) + shutdownRequested := make(chan struct{}) + shutdownComplete := make(chan struct{}) + + startErrCh <- context.Canceled + + if err := waitForMailWatchStart(startErrCh, shutdownRequested, shutdownComplete); !errors.Is(err, context.Canceled) { + t.Fatalf("expected context.Canceled, got %v", err) + } +} + func TestDecodeBodyFieldsForFileDecodesMessageWrapper(t *testing.T) { htmlEncoded := base64.URLEncoding.EncodeToString([]byte("