Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 69 additions & 16 deletions shortcuts/mail/mail_watch.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"sort"
"strings"
"sync"
"sync/atomic"
"syscall"

larkcore "github.com/larksuite/oapi-sdk-go/v3/core"
Expand Down Expand Up @@ -76,6 +77,33 @@ func detectPromptInjection(content string) bool {
return false
}

func waitForMailWatchShutdown(startErrCh <-chan error, shutdownBySignal <-chan struct{}) error {
select {
case <-shutdownBySignal:
return nil
case err := <-startErrCh:
select {
case <-shutdownBySignal:
return nil
default:
}
if err == nil || errors.Is(err, context.Canceled) {
return nil
}
return err
}
}

func finalizeMailWatchCleanup(runErr, cleanupErr error) error {
if cleanupErr == nil {
return runErr
}
if runErr != nil {
return runErr
}
return cleanupErr
}

var MailWatch = common.Shortcut{
Service: "mail",
Command: "+watch",
Expand Down Expand Up @@ -162,7 +190,7 @@ var MailWatch = common.Shortcut{
}
return d
},
Execute: func(ctx context.Context, runtime *common.RuntimeContext) error {
Execute: func(ctx context.Context, runtime *common.RuntimeContext) (retErr error) {
if runtime.Bool("print-output-schema") {
printWatchOutputSchema(runtime)
return nil
Expand Down Expand Up @@ -259,19 +287,34 @@ var MailWatch = common.Shortcut{
})
return unsubErr
}

var unsubscribeLogOnce sync.Once
unsubscribeWithLog := func(primaryErr error) error {
unsubscribeLogOnce.Do(func() {
info("Unsubscribing mailbox events...")
if err := unsubscribe(); err != nil {
if primaryErr != nil {
fmt.Fprintf(errOut, "Warning: unsubscribe failed during cleanup: %v\n", err)
}
} else {
info("Mailbox unsubscribed.")
}
})
return unsubErr
}
defer func() {
retErr = finalizeMailWatchCleanup(retErr, unsubscribeWithLog(retErr))
}()
// Resolve "me" to the actual email address so we can filter events.
mailboxFilter := mailbox
if mailbox == "me" {
resolved, profileErr := fetchMailboxPrimaryEmail(runtime, "me")
if profileErr != nil {
unsubscribe() //nolint:errcheck // best-effort cleanup; primary error is profileErr
return enhanceProfileError(profileErr)
}
mailboxFilter = resolved
}

eventCount := 0
var eventCount int64

handleEvent := func(data map[string]interface{}) {
// Extract event body
Expand Down Expand Up @@ -337,7 +380,7 @@ var MailWatch = common.Shortcut{
}
}

eventCount++
atomic.AddInt64(&eventCount, 1)

// Prompt injection detection: warn when email body contains known injection patterns.
// Body fields may be base64url-encoded; decode before scanning.
Expand Down Expand Up @@ -426,27 +469,37 @@ var MailWatch = common.Shortcut{

sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
defer signal.Stop(sigCh)

watchCtx, cancelWatch := context.WithCancel(ctx)
defer cancelWatch()

shutdownBySignal := make(chan struct{})
go func() {
defer func() {
if r := recover(); r != nil {
fmt.Fprintf(errOut, "panic in signal handler: %v\n", r)
}
}()
<-sigCh
info(fmt.Sprintf("\nShutting down... (received %d events)", eventCount))
info("Unsubscribing mailbox events...")
if unsubErr := unsubscribe(); unsubErr != nil {
fmt.Fprintf(errOut, "Warning: unsubscribe failed: %v\n", unsubErr)
} else {
info("Mailbox unsubscribed.")
select {
case <-sigCh:
// Restore default signal behavior so a second Ctrl+C can force terminate.
signal.Stop(sigCh)
signal.Reset(os.Interrupt, syscall.SIGTERM)
info(fmt.Sprintf("\nShutting down... (received %d events)", atomic.LoadInt64(&eventCount)))
close(shutdownBySignal)
cancelWatch()
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Comment thread
coderabbitai[bot] marked this conversation as resolved.
case <-watchCtx.Done():
return
}
signal.Stop(sigCh)
os.Exit(0)
}()

info("Connected. Waiting for mail events... (Ctrl+C to stop)")
if err := cli.Start(ctx); err != nil {
unsubscribe() //nolint:errcheck // best-effort cleanup
startErrCh := make(chan error, 1)
go func() {
startErrCh <- cli.Start(watchCtx)
}()
if err := waitForMailWatchShutdown(startErrCh, shutdownBySignal); err != nil {
return output.ErrNetwork("WebSocket connection failed: %v", err)
}
return nil
Expand Down
60 changes: 60 additions & 0 deletions shortcuts/mail/mail_watch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"encoding/json"
"strings"
"testing"
"time"

"github.com/larksuite/cli/internal/core"
"github.com/larksuite/cli/internal/output"
Expand Down Expand Up @@ -264,6 +265,65 @@ func TestMailWatchLoggerSuppressesDebugAlways(t *testing.T) {
}
}

func TestWaitForMailWatchShutdownReturnsOnSignalWithoutStartResult(t *testing.T) {
startErrCh := make(chan error)
shutdownBySignal := make(chan struct{})
close(shutdownBySignal)

done := make(chan error, 1)
go func() {
done <- waitForMailWatchShutdown(startErrCh, shutdownBySignal)
}()

select {
case err := <-done:
if err != nil {
t.Fatalf("expected nil on signal shutdown, got %v", err)
}
case <-time.After(100 * time.Millisecond):
t.Fatal("waitForMailWatchShutdown blocked after signal shutdown")
}
}

func TestWaitForMailWatchShutdownReturnsNilOnContextCanceled(t *testing.T) {
startErrCh := make(chan error, 1)
shutdownBySignal := make(chan struct{})
startErrCh <- context.Canceled

if err := waitForMailWatchShutdown(startErrCh, shutdownBySignal); err != nil {
t.Fatalf("expected nil for context.Canceled, got %v", err)
}
}

func TestWaitForMailWatchShutdownReturnsStartError(t *testing.T) {
startErrCh := make(chan error, 1)
shutdownBySignal := make(chan struct{})
want := assertErr("boom")
startErrCh <- want

if err := waitForMailWatchShutdown(startErrCh, shutdownBySignal); err != want {
t.Fatalf("expected original error, got %v", err)
}
}

func TestFinalizeMailWatchCleanup(t *testing.T) {
runErr := assertErr("run failed")
cleanupErr := assertErr("cleanup failed")

if err := finalizeMailWatchCleanup(nil, nil); err != nil {
t.Fatalf("expected nil, got %v", err)
}
if err := finalizeMailWatchCleanup(runErr, nil); err != runErr {
t.Fatalf("expected runErr when cleanup succeeds, got %v", err)
}
if err := finalizeMailWatchCleanup(nil, cleanupErr); err != cleanupErr {
t.Fatalf("expected cleanupErr when run succeeds, got %v", err)
}
if err := finalizeMailWatchCleanup(runErr, cleanupErr); err != runErr {
t.Fatalf("expected primary runErr to win, got %v", err)
}
}

func TestDecodeBodyFieldsForFileDecodesMessageWrapper(t *testing.T) {
htmlEncoded := base64.URLEncoding.EncodeToString([]byte("<h1>Hello</h1>"))
plainEncoded := base64.URLEncoding.EncodeToString([]byte("Hello plain"))
Expand Down
Loading