Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 67 additions & 16 deletions shortcuts/mail/mail_watch.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"sort"
"strings"
"sync"
"sync/atomic"
"syscall"

larkcore "github.com/larksuite/oapi-sdk-go/v3/core"
Expand Down Expand Up @@ -49,6 +50,18 @@ func (l *mailWatchLogger) Error(_ context.Context, args ...interface{}) {

var _ larkcore.Logger = (*mailWatchLogger)(nil)

// handleMailWatchSignal processes a shutdown signal: logs status, unsubscribes
// mailbox events, restores default signal behavior for forced termination, and
// cancels the watch context.
func handleMailWatchSignal(errOut io.Writer, sig os.Signal, eventCount int64, unsubscribeWithLog func(), stopSignals func(), cancel context.CancelFunc) {
fmt.Fprintf(errOut, "\nShutting down (signal: %v)... (received %d events)\n", sig, eventCount)
// Restore default signal behavior so a second Ctrl+C can force terminate.
stopSignals()
signal.Reset(os.Interrupt, syscall.SIGTERM)
unsubscribeWithLog()
cancel()
}

const mailEventType = "mail.user_mailbox.event.message_received_v1"

// promptInjectionPatterns lists known prompt injection trigger phrases.
Expand Down Expand Up @@ -259,19 +272,30 @@ var MailWatch = common.Shortcut{
})
return unsubErr
}
var unsubLogOnce sync.Once
unsubscribeWithLog := func() {
unsubLogOnce.Do(func() {
info("Unsubscribing mailbox events...")
if err := unsubscribe(); err != nil {
fmt.Fprintf(errOut, "Warning: unsubscribe failed: %v\n", err)
} else {
info("Mailbox unsubscribed.")
}
})
}
defer unsubscribeWithLog()

// Resolve "me" to the actual email address so we can filter events.
mailboxFilter := mailbox
if mailbox == "me" {
resolved, profileErr := fetchMailboxPrimaryEmail(runtime, "me")
if profileErr != nil {
unsubscribe() //nolint:errcheck // best-effort cleanup; primary error is profileErr
return enhanceProfileError(profileErr)
}
mailboxFilter = resolved
}

eventCount := 0
var eventCount atomic.Int64

handleEvent := func(data map[string]interface{}) {
// Extract event body
Expand Down Expand Up @@ -337,7 +361,7 @@ var MailWatch = common.Shortcut{
}
}

eventCount++
eventCount.Add(1)

// Prompt injection detection: warn when email body contains known injection patterns.
// Body fields may be base64url-encoded; decode before scanning.
Expand Down Expand Up @@ -424,32 +448,59 @@ var MailWatch = common.Shortcut{
larkws.WithLogger(sdkLogger),
)

watchCtx, cancelWatch := context.WithCancel(ctx)
defer cancelWatch()

sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
stopSignals := func() { signal.Stop(sigCh) }
defer stopSignals()

shutdownBySignal := make(chan struct{})
var shutdownOnce sync.Once
triggerShutdown := func() {
shutdownOnce.Do(func() { close(shutdownBySignal) })
cancelWatch()
}
go func() {
defer func() {
if r := recover(); r != nil {
fmt.Fprintf(errOut, "panic in signal handler: %v\n", r)
triggerShutdown()
}
}()
<-sigCh
info(fmt.Sprintf("\nShutting down... (received %d events)", eventCount))
info("Unsubscribing mailbox events...")
if unsubErr := unsubscribe(); unsubErr != nil {
fmt.Fprintf(errOut, "Warning: unsubscribe failed: %v\n", unsubErr)
} else {
info("Mailbox unsubscribed.")
select {
case sig := <-sigCh:
handleMailWatchSignal(errOut, sig, eventCount.Load(), unsubscribeWithLog, stopSignals, cancelWatch)
triggerShutdown()
case <-watchCtx.Done():
return
}
signal.Stop(sigCh)
os.Exit(0)
}()
Comment thread
chanthuang marked this conversation as resolved.

startErrCh := make(chan error, 1)
go func() {
startErrCh <- cli.Start(watchCtx)
}()

info("Connected. Waiting for mail events... (Ctrl+C to stop)")
Comment thread
chanthuang marked this conversation as resolved.
if err := cli.Start(ctx); err != nil {
unsubscribe() //nolint:errcheck // best-effort cleanup
return output.ErrNetwork("WebSocket connection failed: %v", err)
select {
case <-shutdownBySignal:
return nil
case err := <-startErrCh:
if err != nil {
select {
case <-shutdownBySignal:
return nil
default:
}
if watchCtx.Err() != nil {
return nil
}
return output.ErrNetwork("WebSocket connection failed: %v", err)
}
return nil
}
return nil
},
}

Expand Down
100 changes: 100 additions & 0 deletions shortcuts/mail/mail_watch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,13 @@ import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"os"
"strings"
"sync"
"testing"
"time"

"github.com/larksuite/cli/internal/core"
"github.com/larksuite/cli/internal/output"
Expand Down Expand Up @@ -579,6 +584,101 @@ func TestSetKeysSorted(t *testing.T) {
}
}

// --- handleMailWatchSignal ---

// TestHandleMailWatchSignalUnsubscribesAndCancels verifies that all callbacks are invoked and the shutdown message is printed.
func TestHandleMailWatchSignalUnsubscribesAndCancels(t *testing.T) {
var buf bytes.Buffer
unsubscribed := false
stopped := false
canceled := false

handleMailWatchSignal(&buf, os.Interrupt, 3, func() {
unsubscribed = true
}, func() {
stopped = true
}, func() {
canceled = true
})

if !unsubscribed {
t.Fatal("expected unsubscribeWithLog to be called")
}
if !stopped {
t.Fatal("expected signal stop to be called")
}
if !canceled {
t.Fatal("expected cancel to be called")
}
out := buf.String()
if !strings.Contains(out, "Shutting down (signal: interrupt)... (received 3 events)") {
t.Fatalf("missing shutdown message, got: %q", out)
}
}

// TestHandleMailWatchSignalReportsUnsubscribeFailure verifies that unsubscribe errors are written to errOut.
func TestHandleMailWatchSignalReportsUnsubscribeFailure(t *testing.T) {
var buf bytes.Buffer

handleMailWatchSignal(&buf, os.Interrupt, 1, func() {
fmt.Fprintln(&buf, "Warning: unsubscribe failed: boom")
}, func() {}, func() {})

if got := buf.String(); !strings.Contains(got, "Warning: unsubscribe failed: boom") {
t.Fatalf("expected unsubscribe warning, got: %q", got)
}
}

// TestHandleMailWatchSignalPanicUnblocksShutdown verifies that a panic in unsubscribeWithLog still triggers shutdown.
func TestHandleMailWatchSignalPanicUnblocksShutdown(t *testing.T) {
shutdownBySignal := make(chan struct{})
var shutdownOnce sync.Once
_, cancelWatch := context.WithCancel(context.Background())
triggerShutdown := func() {
shutdownOnce.Do(func() { close(shutdownBySignal) })
cancelWatch()
}

sigCh := make(chan os.Signal, 1)
go func() {
defer func() {
if r := recover(); r != nil {
triggerShutdown()
}
}()
<-sigCh
// Simulate panic inside handleMailWatchSignal (e.g. unsubscribeWithLog panics)
panic("unsubscribe exploded")
}()

sigCh <- os.Interrupt

select {
case <-shutdownBySignal:
// Success: shutdown channel was closed despite the panic
case <-time.After(2 * time.Second):
t.Fatal("shutdownBySignal was not closed after panic — process would hang")
}
}

// TestHandleMailWatchSignalCallOrder verifies callbacks execute in order: stop signals → unsubscribe → cancel.
func TestHandleMailWatchSignalCallOrder(t *testing.T) {
var order []string

handleMailWatchSignal(io.Discard, os.Interrupt, 0, func() {
order = append(order, "unsub")
}, func() {
order = append(order, "stop")
}, func() {
order = append(order, "cancel")
})

// Expected: stop → unsub → cancel
if len(order) != 3 || order[0] != "stop" || order[1] != "unsub" || order[2] != "cancel" {
t.Fatalf("unexpected call order: %v, want [stop unsub cancel]", order)
}
}

func assertErr(msg string) error {
return &testErr{msg: msg}
}
Expand Down
Loading