Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 70 additions & 13 deletions shortcuts/mail/mail_watch.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"sort"
"strings"
"sync"
"sync/atomic"
"syscall"

larkcore "github.com/larksuite/oapi-sdk-go/v3/core"
Expand Down Expand Up @@ -50,6 +51,50 @@ func (l *mailWatchLogger) Error(_ context.Context, args ...interface{}) {

var _ larkcore.Logger = (*mailWatchLogger)(nil)

func handleMailWatchSignal(errOut io.Writer, sig os.Signal, eventCount int, unsubscribe func() error, stopSignals func(), cancel context.CancelFunc) {
Comment thread
greptile-apps[bot] marked this conversation as resolved.
fmt.Fprintf(errOut, "\nShutting down (signal: %v)... (received %d events)\n", sig, eventCount)
fmt.Fprintln(errOut, "Unsubscribing mailbox events...")
if unsubErr := unsubscribe(); unsubErr != nil {
fmt.Fprintf(errOut, "Warning: unsubscribe failed: %v\n", unsubErr)
} else {
fmt.Fprintln(errOut, "Mailbox unsubscribed.")
}
if stopSignals != nil {
stopSignals()
}
if cancel != nil {
cancel()
}
}

func newMailWatchUnsubscribeOnce(unsubscribe func() error) func() error {
var once sync.Once
var unsubscribeErr error
return func() error {
once.Do(func() {
if unsubscribe != nil {
unsubscribeErr = unsubscribe()
}
})
return unsubscribeErr
}
}

func waitForMailWatchStart(startErrCh <-chan error, shutdownRequested <-chan struct{}, shutdownComplete <-chan struct{}) error {
select {
case <-shutdownComplete:
return nil
case err := <-startErrCh:
select {
case <-shutdownRequested:
<-shutdownComplete
return nil
default:
return err
}
}
}

const mailEventType = "mail.user_mailbox.event.message_received_v1"

// promptInjectionPatterns lists known prompt injection trigger phrases.
Expand Down Expand Up @@ -272,7 +317,7 @@ var MailWatch = common.Shortcut{
mailboxFilter = resolved
}

eventCount := 0
var eventCount atomic.Int64

handleEvent := func(data map[string]interface{}) {
// Extract event body
Expand Down Expand Up @@ -338,7 +383,7 @@ var MailWatch = common.Shortcut{
}
}

eventCount++
eventCount.Add(1)

// Prompt injection detection: warn when email body contains known injection patterns.
// Body fields may be base64url-encoded; decode before scanning.
Expand Down Expand Up @@ -425,29 +470,41 @@ var MailWatch = common.Shortcut{
larkws.WithLogger(sdkLogger),
)

unsubscribeOnce := newMailWatchUnsubscribeOnce(unsubscribe)

watchCtx, cancel := context.WithCancel(ctx)
defer cancel()
defer unsubscribeOnce() //nolint:errcheck // best-effort cleanup on all exit paths

sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
stopSignals := func() { signal.Stop(sigCh) }
defer stopSignals()

startErrCh := make(chan error, 1)
go func() {
startErrCh <- cli.Start(watchCtx)
}()

shutdownRequested := make(chan struct{})
shutdownComplete := make(chan struct{})
go func() {
defer func() {
if r := recover(); r != nil {
fmt.Fprintf(errOut, "panic in signal handler: %v\n", r)
}
}()
<-sigCh
info(fmt.Sprintf("\nShutting down... (received %d events)", eventCount))
info("Unsubscribing mailbox events...")
if unsubErr := unsubscribe(); unsubErr != nil {
fmt.Fprintf(errOut, "Warning: unsubscribe failed: %v\n", unsubErr)
} else {
info("Mailbox unsubscribed.")
sig, ok := <-sigCh
if !ok {
return
}
signal.Stop(sigCh)
os.Exit(0)
close(shutdownRequested)
handleMailWatchSignal(errOut, sig, int(eventCount.Load()), unsubscribeOnce, stopSignals, cancel)
close(shutdownComplete)
}()

info("Connected. Waiting for mail events... (Ctrl+C to stop)")
if err := cli.Start(ctx); err != nil {
unsubscribe() //nolint:errcheck // best-effort cleanup
if err := waitForMailWatchStart(startErrCh, shutdownRequested, shutdownComplete); err != nil {
return output.ErrNetwork("WebSocket connection failed: %v", err)
}
Comment thread
greptile-apps[bot] marked this conversation as resolved.
return nil
Expand Down
103 changes: 103 additions & 0 deletions shortcuts/mail/mail_watch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"os"
"strings"
"testing"

Expand Down Expand Up @@ -264,6 +266,107 @@ func TestMailWatchLoggerSuppressesDebugAlways(t *testing.T) {
}
}

func TestHandleMailWatchSignal_UnsubscribesAndCancels(t *testing.T) {
var buf bytes.Buffer
unsubscribed := false
stopped := false
canceled := false

handleMailWatchSignal(&buf, os.Interrupt, 3, func() error {
unsubscribed = true
return nil
}, func() {
stopped = true
}, func() {
canceled = true
})

if !unsubscribed {
t.Fatal("expected unsubscribe to be called")
}
if !stopped {
t.Fatal("expected signal stop to be called")
}
if !canceled {
t.Fatal("expected cancel to be called")
}
out := buf.String()
if !strings.Contains(out, "Shutting down (signal: interrupt)... (received 3 events)") {
t.Fatalf("missing shutdown message, got: %q", out)
}
if !strings.Contains(out, "Mailbox unsubscribed.") {
t.Fatalf("missing unsubscribe success message, got: %q", out)
}
}

func TestHandleMailWatchSignal_ReportsUnsubscribeFailure(t *testing.T) {
var buf bytes.Buffer

handleMailWatchSignal(&buf, os.Interrupt, 1, func() error {
return errors.New("boom")
}, nil, nil)

if got := buf.String(); !strings.Contains(got, "Warning: unsubscribe failed: boom") {
t.Fatalf("expected unsubscribe warning, got: %q", got)
}
}

func TestNewMailWatchUnsubscribeOnce_OnlyRunsOnce(t *testing.T) {
calls := 0
unsubscribeOnce := newMailWatchUnsubscribeOnce(func() error {
calls++
return errors.New("boom")
})

if err := unsubscribeOnce(); err == nil || err.Error() != "boom" {
t.Fatalf("expected first unsubscribe error, got %v", err)
}
if err := unsubscribeOnce(); err == nil || err.Error() != "boom" {
t.Fatalf("expected cached unsubscribe error, got %v", err)
}
if calls != 1 {
t.Fatalf("expected unsubscribe to run once, got %d", calls)
}
}

func TestWaitForMailWatchStart_ReturnsNilWhenSignalShutdownWinsRace(t *testing.T) {
startErrCh := make(chan error, 1)
shutdownRequested := make(chan struct{})
shutdownComplete := make(chan struct{})

close(shutdownRequested)
startErrCh <- context.Canceled

done := make(chan error, 1)
go func() {
done <- waitForMailWatchStart(startErrCh, shutdownRequested, shutdownComplete)
}()

select {
case err := <-done:
t.Fatalf("expected waitForMailWatchStart to block until shutdown completes, got %v", err)
default:
}

close(shutdownComplete)

if err := <-done; err != nil {
t.Fatalf("expected nil after shutdown completion, got %v", err)
}
}

func TestWaitForMailWatchStart_ReturnsStartErrorWithoutSignal(t *testing.T) {
startErrCh := make(chan error, 1)
shutdownRequested := make(chan struct{})
shutdownComplete := make(chan struct{})

startErrCh <- context.Canceled

if err := waitForMailWatchStart(startErrCh, shutdownRequested, shutdownComplete); !errors.Is(err, context.Canceled) {
t.Fatalf("expected context.Canceled, got %v", err)
}
}

func TestDecodeBodyFieldsForFileDecodesMessageWrapper(t *testing.T) {
htmlEncoded := base64.URLEncoding.EncodeToString([]byte("<h1>Hello</h1>"))
plainEncoded := base64.URLEncoding.EncodeToString([]byte("Hello plain"))
Expand Down
Loading