diff --git a/shortcuts/mail/mail_watch.go b/shortcuts/mail/mail_watch.go index c56994270..a318b1990 100644 --- a/shortcuts/mail/mail_watch.go +++ b/shortcuts/mail/mail_watch.go @@ -18,6 +18,7 @@ import ( "sort" "strings" "sync" + "sync/atomic" "syscall" larkcore "github.com/larksuite/oapi-sdk-go/v3/core" @@ -76,6 +77,33 @@ func detectPromptInjection(content string) bool { return false } +func waitForMailWatchShutdown(startErrCh <-chan error, shutdownBySignal <-chan struct{}) error { + select { + case <-shutdownBySignal: + return nil + case err := <-startErrCh: + select { + case <-shutdownBySignal: + return nil + default: + } + if err == nil || errors.Is(err, context.Canceled) { + return nil + } + return err + } +} + +func finalizeMailWatchCleanup(runErr, cleanupErr error) error { + if cleanupErr == nil { + return runErr + } + if runErr != nil { + return runErr + } + return cleanupErr +} + var MailWatch = common.Shortcut{ Service: "mail", Command: "+watch", @@ -162,7 +190,7 @@ var MailWatch = common.Shortcut{ } return d }, - Execute: func(ctx context.Context, runtime *common.RuntimeContext) error { + Execute: func(ctx context.Context, runtime *common.RuntimeContext) (retErr error) { if runtime.Bool("print-output-schema") { printWatchOutputSchema(runtime) return nil @@ -259,19 +287,34 @@ var MailWatch = common.Shortcut{ }) return unsubErr } - + var unsubscribeLogOnce sync.Once + unsubscribeWithLog := func(primaryErr error) error { + unsubscribeLogOnce.Do(func() { + info("Unsubscribing mailbox events...") + if err := unsubscribe(); err != nil { + if primaryErr != nil { + fmt.Fprintf(errOut, "Warning: unsubscribe failed during cleanup: %v\n", err) + } + } else { + info("Mailbox unsubscribed.") + } + }) + return unsubErr + } + defer func() { + retErr = finalizeMailWatchCleanup(retErr, unsubscribeWithLog(retErr)) + }() // Resolve "me" to the actual email address so we can filter events. mailboxFilter := mailbox if mailbox == "me" { resolved, profileErr := fetchMailboxPrimaryEmail(runtime, "me") if profileErr != nil { - unsubscribe() //nolint:errcheck // best-effort cleanup; primary error is profileErr return enhanceProfileError(profileErr) } mailboxFilter = resolved } - eventCount := 0 + var eventCount int64 handleEvent := func(data map[string]interface{}) { // Extract event body @@ -337,7 +380,7 @@ var MailWatch = common.Shortcut{ } } - eventCount++ + atomic.AddInt64(&eventCount, 1) // Prompt injection detection: warn when email body contains known injection patterns. // Body fields may be base64url-encoded; decode before scanning. @@ -426,27 +469,37 @@ var MailWatch = common.Shortcut{ sigCh := make(chan os.Signal, 1) signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) + defer signal.Stop(sigCh) + + watchCtx, cancelWatch := context.WithCancel(ctx) + defer cancelWatch() + + shutdownBySignal := make(chan struct{}) go func() { defer func() { if r := recover(); r != nil { fmt.Fprintf(errOut, "panic in signal handler: %v\n", r) } }() - <-sigCh - info(fmt.Sprintf("\nShutting down... (received %d events)", eventCount)) - info("Unsubscribing mailbox events...") - if unsubErr := unsubscribe(); unsubErr != nil { - fmt.Fprintf(errOut, "Warning: unsubscribe failed: %v\n", unsubErr) - } else { - info("Mailbox unsubscribed.") + select { + case <-sigCh: + // Restore default signal behavior so a second Ctrl+C can force terminate. + signal.Stop(sigCh) + signal.Reset(os.Interrupt, syscall.SIGTERM) + info(fmt.Sprintf("\nShutting down... (received %d events)", atomic.LoadInt64(&eventCount))) + close(shutdownBySignal) + cancelWatch() + case <-watchCtx.Done(): + return } - signal.Stop(sigCh) - os.Exit(0) }() info("Connected. Waiting for mail events... (Ctrl+C to stop)") - if err := cli.Start(ctx); err != nil { - unsubscribe() //nolint:errcheck // best-effort cleanup + startErrCh := make(chan error, 1) + go func() { + startErrCh <- cli.Start(watchCtx) + }() + if err := waitForMailWatchShutdown(startErrCh, shutdownBySignal); err != nil { return output.ErrNetwork("WebSocket connection failed: %v", err) } return nil diff --git a/shortcuts/mail/mail_watch_test.go b/shortcuts/mail/mail_watch_test.go index 02476fbdf..8139b3c22 100644 --- a/shortcuts/mail/mail_watch_test.go +++ b/shortcuts/mail/mail_watch_test.go @@ -10,6 +10,7 @@ import ( "encoding/json" "strings" "testing" + "time" "github.com/larksuite/cli/internal/core" "github.com/larksuite/cli/internal/output" @@ -264,6 +265,65 @@ func TestMailWatchLoggerSuppressesDebugAlways(t *testing.T) { } } +func TestWaitForMailWatchShutdownReturnsOnSignalWithoutStartResult(t *testing.T) { + startErrCh := make(chan error) + shutdownBySignal := make(chan struct{}) + close(shutdownBySignal) + + done := make(chan error, 1) + go func() { + done <- waitForMailWatchShutdown(startErrCh, shutdownBySignal) + }() + + select { + case err := <-done: + if err != nil { + t.Fatalf("expected nil on signal shutdown, got %v", err) + } + case <-time.After(100 * time.Millisecond): + t.Fatal("waitForMailWatchShutdown blocked after signal shutdown") + } +} + +func TestWaitForMailWatchShutdownReturnsNilOnContextCanceled(t *testing.T) { + startErrCh := make(chan error, 1) + shutdownBySignal := make(chan struct{}) + startErrCh <- context.Canceled + + if err := waitForMailWatchShutdown(startErrCh, shutdownBySignal); err != nil { + t.Fatalf("expected nil for context.Canceled, got %v", err) + } +} + +func TestWaitForMailWatchShutdownReturnsStartError(t *testing.T) { + startErrCh := make(chan error, 1) + shutdownBySignal := make(chan struct{}) + want := assertErr("boom") + startErrCh <- want + + if err := waitForMailWatchShutdown(startErrCh, shutdownBySignal); err != want { + t.Fatalf("expected original error, got %v", err) + } +} + +func TestFinalizeMailWatchCleanup(t *testing.T) { + runErr := assertErr("run failed") + cleanupErr := assertErr("cleanup failed") + + if err := finalizeMailWatchCleanup(nil, nil); err != nil { + t.Fatalf("expected nil, got %v", err) + } + if err := finalizeMailWatchCleanup(runErr, nil); err != runErr { + t.Fatalf("expected runErr when cleanup succeeds, got %v", err) + } + if err := finalizeMailWatchCleanup(nil, cleanupErr); err != cleanupErr { + t.Fatalf("expected cleanupErr when run succeeds, got %v", err) + } + if err := finalizeMailWatchCleanup(runErr, cleanupErr); err != runErr { + t.Fatalf("expected primary runErr to win, got %v", err) + } +} + func TestDecodeBodyFieldsForFileDecodesMessageWrapper(t *testing.T) { htmlEncoded := base64.URLEncoding.EncodeToString([]byte("