diff --git a/shortcuts/mail/mail_watch.go b/shortcuts/mail/mail_watch.go index c56994270..3f3e22fe6 100644 --- a/shortcuts/mail/mail_watch.go +++ b/shortcuts/mail/mail_watch.go @@ -18,6 +18,7 @@ import ( "sort" "strings" "sync" + "sync/atomic" "syscall" larkcore "github.com/larksuite/oapi-sdk-go/v3/core" @@ -49,6 +50,18 @@ func (l *mailWatchLogger) Error(_ context.Context, args ...interface{}) { var _ larkcore.Logger = (*mailWatchLogger)(nil) +// handleMailWatchSignal processes a shutdown signal: logs status, unsubscribes +// mailbox events, restores default signal behavior for forced termination, and +// cancels the watch context. +func handleMailWatchSignal(errOut io.Writer, sig os.Signal, eventCount int64, unsubscribeWithLog func(), stopSignals func(), cancel context.CancelFunc) { + fmt.Fprintf(errOut, "\nShutting down (signal: %v)... (received %d events)\n", sig, eventCount) + // Restore default signal behavior so a second Ctrl+C can force terminate. + stopSignals() + signal.Reset(os.Interrupt, syscall.SIGTERM) + unsubscribeWithLog() + cancel() +} + const mailEventType = "mail.user_mailbox.event.message_received_v1" // promptInjectionPatterns lists known prompt injection trigger phrases. @@ -259,19 +272,30 @@ var MailWatch = common.Shortcut{ }) return unsubErr } + var unsubLogOnce sync.Once + unsubscribeWithLog := func() { + unsubLogOnce.Do(func() { + info("Unsubscribing mailbox events...") + if err := unsubscribe(); err != nil { + fmt.Fprintf(errOut, "Warning: unsubscribe failed: %v\n", err) + } else { + info("Mailbox unsubscribed.") + } + }) + } + defer unsubscribeWithLog() // Resolve "me" to the actual email address so we can filter events. mailboxFilter := mailbox if mailbox == "me" { resolved, profileErr := fetchMailboxPrimaryEmail(runtime, "me") if profileErr != nil { - unsubscribe() //nolint:errcheck // best-effort cleanup; primary error is profileErr return enhanceProfileError(profileErr) } mailboxFilter = resolved } - eventCount := 0 + var eventCount atomic.Int64 handleEvent := func(data map[string]interface{}) { // Extract event body @@ -337,7 +361,7 @@ var MailWatch = common.Shortcut{ } } - eventCount++ + eventCount.Add(1) // Prompt injection detection: warn when email body contains known injection patterns. // Body fields may be base64url-encoded; decode before scanning. @@ -424,32 +448,59 @@ var MailWatch = common.Shortcut{ larkws.WithLogger(sdkLogger), ) + watchCtx, cancelWatch := context.WithCancel(ctx) + defer cancelWatch() + sigCh := make(chan os.Signal, 1) signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) + stopSignals := func() { signal.Stop(sigCh) } + defer stopSignals() + + shutdownBySignal := make(chan struct{}) + var shutdownOnce sync.Once + triggerShutdown := func() { + shutdownOnce.Do(func() { close(shutdownBySignal) }) + cancelWatch() + } go func() { defer func() { if r := recover(); r != nil { fmt.Fprintf(errOut, "panic in signal handler: %v\n", r) + triggerShutdown() } }() - <-sigCh - info(fmt.Sprintf("\nShutting down... (received %d events)", eventCount)) - info("Unsubscribing mailbox events...") - if unsubErr := unsubscribe(); unsubErr != nil { - fmt.Fprintf(errOut, "Warning: unsubscribe failed: %v\n", unsubErr) - } else { - info("Mailbox unsubscribed.") + select { + case sig := <-sigCh: + handleMailWatchSignal(errOut, sig, eventCount.Load(), unsubscribeWithLog, stopSignals, cancelWatch) + triggerShutdown() + case <-watchCtx.Done(): + return } - signal.Stop(sigCh) - os.Exit(0) + }() + + startErrCh := make(chan error, 1) + go func() { + startErrCh <- cli.Start(watchCtx) }() info("Connected. Waiting for mail events... (Ctrl+C to stop)") - if err := cli.Start(ctx); err != nil { - unsubscribe() //nolint:errcheck // best-effort cleanup - return output.ErrNetwork("WebSocket connection failed: %v", err) + select { + case <-shutdownBySignal: + return nil + case err := <-startErrCh: + if err != nil { + select { + case <-shutdownBySignal: + return nil + default: + } + if watchCtx.Err() != nil { + return nil + } + return output.ErrNetwork("WebSocket connection failed: %v", err) + } + return nil } - return nil }, } diff --git a/shortcuts/mail/mail_watch_test.go b/shortcuts/mail/mail_watch_test.go index 02476fbdf..ef42869c3 100644 --- a/shortcuts/mail/mail_watch_test.go +++ b/shortcuts/mail/mail_watch_test.go @@ -8,8 +8,13 @@ import ( "context" "encoding/base64" "encoding/json" + "fmt" + "io" + "os" "strings" + "sync" "testing" + "time" "github.com/larksuite/cli/internal/core" "github.com/larksuite/cli/internal/output" @@ -579,6 +584,101 @@ func TestSetKeysSorted(t *testing.T) { } } +// --- handleMailWatchSignal --- + +// TestHandleMailWatchSignalUnsubscribesAndCancels verifies that all callbacks are invoked and the shutdown message is printed. +func TestHandleMailWatchSignalUnsubscribesAndCancels(t *testing.T) { + var buf bytes.Buffer + unsubscribed := false + stopped := false + canceled := false + + handleMailWatchSignal(&buf, os.Interrupt, 3, func() { + unsubscribed = true + }, func() { + stopped = true + }, func() { + canceled = true + }) + + if !unsubscribed { + t.Fatal("expected unsubscribeWithLog to be called") + } + if !stopped { + t.Fatal("expected signal stop to be called") + } + if !canceled { + t.Fatal("expected cancel to be called") + } + out := buf.String() + if !strings.Contains(out, "Shutting down (signal: interrupt)... (received 3 events)") { + t.Fatalf("missing shutdown message, got: %q", out) + } +} + +// TestHandleMailWatchSignalReportsUnsubscribeFailure verifies that unsubscribe errors are written to errOut. +func TestHandleMailWatchSignalReportsUnsubscribeFailure(t *testing.T) { + var buf bytes.Buffer + + handleMailWatchSignal(&buf, os.Interrupt, 1, func() { + fmt.Fprintln(&buf, "Warning: unsubscribe failed: boom") + }, func() {}, func() {}) + + if got := buf.String(); !strings.Contains(got, "Warning: unsubscribe failed: boom") { + t.Fatalf("expected unsubscribe warning, got: %q", got) + } +} + +// TestHandleMailWatchSignalPanicUnblocksShutdown verifies that a panic in unsubscribeWithLog still triggers shutdown. +func TestHandleMailWatchSignalPanicUnblocksShutdown(t *testing.T) { + shutdownBySignal := make(chan struct{}) + var shutdownOnce sync.Once + _, cancelWatch := context.WithCancel(context.Background()) + triggerShutdown := func() { + shutdownOnce.Do(func() { close(shutdownBySignal) }) + cancelWatch() + } + + sigCh := make(chan os.Signal, 1) + go func() { + defer func() { + if r := recover(); r != nil { + triggerShutdown() + } + }() + <-sigCh + // Simulate panic inside handleMailWatchSignal (e.g. unsubscribeWithLog panics) + panic("unsubscribe exploded") + }() + + sigCh <- os.Interrupt + + select { + case <-shutdownBySignal: + // Success: shutdown channel was closed despite the panic + case <-time.After(2 * time.Second): + t.Fatal("shutdownBySignal was not closed after panic — process would hang") + } +} + +// TestHandleMailWatchSignalCallOrder verifies callbacks execute in order: stop signals → unsubscribe → cancel. +func TestHandleMailWatchSignalCallOrder(t *testing.T) { + var order []string + + handleMailWatchSignal(io.Discard, os.Interrupt, 0, func() { + order = append(order, "unsub") + }, func() { + order = append(order, "stop") + }, func() { + order = append(order, "cancel") + }) + + // Expected: stop → unsub → cancel + if len(order) != 3 || order[0] != "stop" || order[1] != "unsub" || order[2] != "cancel" { + t.Fatalf("unexpected call order: %v, want [stop unsub cancel]", order) + } +} + func assertErr(msg string) error { return &testErr{msg: msg} }