diff --git a/cmd/api/api.go b/cmd/api/api.go index 1c8697650..2964d80af 100644 --- a/cmd/api/api.go +++ b/cmd/api/api.go @@ -239,12 +239,13 @@ func apiRun(opts *APIOptions) error { return output.MarkRaw(client.WrapDoAPIError(err)) } err = client.HandleResponse(resp, client.ResponseOptions{ - OutputPath: opts.Output, - Format: format, - JqExpr: opts.JqExpr, - Out: out, - ErrOut: f.IOStreams.ErrOut, - FileIO: f.ResolveFileIO(opts.Ctx), + OutputPath: opts.Output, + Format: format, + JqExpr: opts.JqExpr, + Out: out, + ErrOut: f.IOStreams.ErrOut, + FileIO: f.ResolveFileIO(opts.Ctx), + CommandPath: opts.Cmd.CommandPath(), }) // MarkRaw tells root error handler to skip enrichPermissionError, // preserving the original API error detail (log_id, troubleshooter, etc.). diff --git a/cmd/service/service.go b/cmd/service/service.go index cd1f7c364..4b8cbeab5 100644 --- a/cmd/service/service.go +++ b/cmd/service/service.go @@ -272,13 +272,14 @@ func serviceMethodRun(opts *ServiceMethodOptions) error { return output.ErrNetwork("API call failed: %s", err) } return client.HandleResponse(resp, client.ResponseOptions{ - OutputPath: opts.Output, - Format: format, - JqExpr: opts.JqExpr, - Out: out, - ErrOut: f.IOStreams.ErrOut, - FileIO: f.ResolveFileIO(opts.Ctx), - CheckError: checkErr, + OutputPath: opts.Output, + Format: format, + JqExpr: opts.JqExpr, + Out: out, + ErrOut: f.IOStreams.ErrOut, + FileIO: f.ResolveFileIO(opts.Ctx), + CommandPath: opts.Cmd.CommandPath(), + CheckError: checkErr, }) } diff --git a/extension/contentsafety/registry.go b/extension/contentsafety/registry.go new file mode 100644 index 000000000..af6df0f6a --- /dev/null +++ b/extension/contentsafety/registry.go @@ -0,0 +1,28 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package contentsafety + +import "sync" + +var ( + mu sync.Mutex + provider Provider +) + +// Register installs a content-safety Provider. 
Later registrations +// override earlier ones (last-write-wins). +// Typically called from init() via blank import. +func Register(p Provider) { + mu.Lock() + defer mu.Unlock() + provider = p +} + +// GetProvider returns the currently registered Provider. +// Returns nil if no provider has been registered. +func GetProvider() Provider { + mu.Lock() + defer mu.Unlock() + return provider +} diff --git a/extension/contentsafety/types.go b/extension/contentsafety/types.go new file mode 100644 index 000000000..5304f3234 --- /dev/null +++ b/extension/contentsafety/types.go @@ -0,0 +1,29 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package contentsafety + +import ( + "context" + "io" +) + +// Provider scans parsed response data for content-safety issues. +// Implementations must be safe for concurrent use. +type Provider interface { + Name() string + Scan(ctx context.Context, req ScanRequest) (*Alert, error) +} + +// ScanRequest carries the data to scan. +type ScanRequest struct { + Path string // normalized command path (e.g. "im.messages_search") + Data any // parsed response data (generic JSON shape) + ErrOut io.Writer // stderr for provider-level notices (e.g. lazy-config creation) +} + +// Alert holds the result of a content-safety scan that detected issues. +type Alert struct { + Provider string `json:"provider"` + MatchedRules []string `json:"matched_rules"` +} diff --git a/extension/contentsafety/types_test.go b/extension/contentsafety/types_test.go new file mode 100644 index 000000000..5e9f72a24 --- /dev/null +++ b/extension/contentsafety/types_test.go @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. 
+// SPDX-License-Identifier: MIT + +package contentsafety + +import ( + "context" + "io" + "testing" +) + +func TestAlertFields(t *testing.T) { + a := &Alert{ + Provider: "regex", + MatchedRules: []string{"rule_a", "rule_b"}, + } + if a.Provider != "regex" { + t.Errorf("Provider = %q, want %q", a.Provider, "regex") + } + if len(a.MatchedRules) != 2 { + t.Errorf("MatchedRules length = %d, want 2", len(a.MatchedRules)) + } +} + +type stubProvider struct{} + +func (s *stubProvider) Name() string { return "stub" } +func (s *stubProvider) Scan(_ context.Context, _ ScanRequest) (*Alert, error) { + return &Alert{Provider: "stub", MatchedRules: []string{"test"}}, nil +} + +func TestProviderInterface(t *testing.T) { + var p Provider = &stubProvider{} + if p.Name() != "stub" { + t.Errorf("Name() = %q, want %q", p.Name(), "stub") + } + alert, err := p.Scan(context.Background(), ScanRequest{Path: "test", Data: nil, ErrOut: io.Discard}) + if err != nil { + t.Fatalf("Scan() error = %v", err) + } + if alert.Provider != "stub" { + t.Errorf("alert.Provider = %q, want %q", alert.Provider, "stub") + } +} + +func TestRegistryLastWriteWins(t *testing.T) { + mu.Lock() + old := provider + provider = nil + mu.Unlock() + defer func() { + mu.Lock() + provider = old + mu.Unlock() + }() + + if GetProvider() != nil { + t.Fatal("expected nil provider initially") + } + p1 := &stubProvider{} + Register(p1) + if GetProvider() != p1 { + t.Fatal("expected p1 after first Register") + } + p2 := &stubProvider{} + Register(p2) + if GetProvider() != p2 { + t.Fatal("expected p2 after second Register (last-write-wins)") + } +} diff --git a/internal/client/response.go b/internal/client/response.go index 4025a7a78..aec73cb04 100644 --- a/internal/client/response.go +++ b/internal/client/response.go @@ -23,12 +23,13 @@ import ( // ResponseOptions configures how HandleResponse routes a raw API response. 
type ResponseOptions struct { - OutputPath string // --output flag; "" = auto-detect - Format output.Format // output format for JSON responses - JqExpr string // if set, apply jq filter instead of Format - Out io.Writer // stdout - ErrOut io.Writer // stderr - FileIO fileio.FileIO // file transfer abstraction; required when saving files (--output or binary response) + OutputPath string // --output flag; "" = auto-detect + Format output.Format // output format for JSON responses + JqExpr string // if set, apply jq filter instead of Format + Out io.Writer // stdout + ErrOut io.Writer // stderr + FileIO fileio.FileIO // file transfer abstraction; required when saving files (--output or binary response) + CommandPath string // raw cobra CommandPath() for content safety scanning // CheckError is called on parsed JSON results. Nil defaults to CheckLarkResponse. CheckError func(interface{}) error } @@ -60,9 +61,20 @@ func HandleResponse(resp *larkcore.ApiResp, opts ResponseOptions) error { if apiErr := check(result); apiErr != nil { return apiErr } + // Content safety scanning + scanResult := output.ScanForSafety(opts.CommandPath, result, opts.ErrOut) + if scanResult.Blocked { + return scanResult.BlockErr + } if opts.OutputPath != "" { + if scanResult.Alert != nil { + output.WriteAlertWarning(opts.ErrOut, scanResult.Alert) + } return saveAndPrint(opts.FileIO, resp, opts.OutputPath, opts.Out) } + if scanResult.Alert != nil { + output.WriteAlertWarning(opts.ErrOut, scanResult.Alert) + } if opts.JqExpr != "" { return output.JqFilter(opts.Out, result, opts.JqExpr) } diff --git a/internal/cmdutil/factory_default.go b/internal/cmdutil/factory_default.go index ffd97c449..ec858d35d 100644 --- a/internal/cmdutil/factory_default.go +++ b/internal/cmdutil/factory_default.go @@ -21,6 +21,7 @@ import ( "github.com/larksuite/cli/internal/credential" "github.com/larksuite/cli/internal/keychain" "github.com/larksuite/cli/internal/registry" + _ 
"github.com/larksuite/cli/internal/security/contentsafety" // register content safety provider "github.com/larksuite/cli/internal/util" _ "github.com/larksuite/cli/internal/vfs/localfileio" // register default FileIO provider ) diff --git a/internal/envvars/envvars.go b/internal/envvars/envvars.go index ecb629fd0..41560ec9d 100644 --- a/internal/envvars/envvars.go +++ b/internal/envvars/envvars.go @@ -15,4 +15,7 @@ const ( // Sidecar proxy (auth proxy mode) CliAuthProxy = "LARKSUITE_CLI_AUTH_PROXY" // sidecar HTTP address, e.g. "http://127.0.0.1:16384" CliProxyKey = "LARKSUITE_CLI_PROXY_KEY" // HMAC signing key shared with sidecar + + // Content safety scanning mode + CliContentSafetyMode = "LARKSUITE_CLI_CONTENT_SAFETY_MODE" ) diff --git a/internal/output/emit.go b/internal/output/emit.go new file mode 100644 index 000000000..dfc4598b1 --- /dev/null +++ b/internal/output/emit.go @@ -0,0 +1,61 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package output + +import ( + "errors" + "fmt" + "io" + "strings" + + extcs "github.com/larksuite/cli/extension/contentsafety" +) + +// ScanResult holds the output of ScanForSafety. +type ScanResult struct { + Alert *extcs.Alert + Blocked bool + BlockErr error +} + +// ScanForSafety runs content-safety scanning on the given data. +// cmdPath is the raw cobra CommandPath(). +// When MODE=off, no provider registered, or the command is not allowlisted, +// returns a zero ScanResult. +func ScanForSafety(cmdPath string, data any, errOut io.Writer) ScanResult { + alert, csErr := runContentSafety(cmdPath, data, errOut) + if errors.Is(csErr, errBlocked) { + return ScanResult{ + Alert: alert, + Blocked: true, + BlockErr: wrapBlockError(alert), + } + } + return ScanResult{Alert: alert} +} + +// wrapBlockError creates an ExitError for content-safety block. 
+func wrapBlockError(alert *extcs.Alert) error { + rules := "" + if alert != nil { + rules = strings.Join(alert.MatchedRules, ", ") + } + return &ExitError{ + Code: ExitContentSafety, + Detail: &ErrDetail{ + Type: "content_safety_blocked", + Message: fmt.Sprintf("content safety violation detected (rules: %s)", rules), + }, + } +} + +// WriteAlertWarning writes a human-readable content-safety warning to w. +// Used by non-JSON output paths (pretty, table, csv) in warn mode. +func WriteAlertWarning(w io.Writer, alert *extcs.Alert) { + if alert == nil { + return + } + fmt.Fprintf(w, "warning: content safety alert from %s (rules: %s)\n", + alert.Provider, strings.Join(alert.MatchedRules, ", ")) +} diff --git a/internal/output/emit_core.go b/internal/output/emit_core.go new file mode 100644 index 000000000..2c5360845 --- /dev/null +++ b/internal/output/emit_core.go @@ -0,0 +1,132 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package output + +import ( + "bytes" + "context" + "fmt" + "io" + "os" + "strings" + "time" + + extcs "github.com/larksuite/cli/extension/contentsafety" + "github.com/larksuite/cli/internal/envvars" +) + +type mode uint8 + +const ( + modeOff mode = iota + modeWarn + modeBlock +) + +// scanTimeout caps the content-safety scan so it cannot dominate CLI latency. +// 100 ms is generous for a regex walk of a typical API response (KB-scale JSON); +// larger responses hit maxDepth/maxStringBytes well before this fires. +const scanTimeout = 100 * time.Millisecond + +// modeFromEnv reads LARKSUITE_CLI_CONTENT_SAFETY_MODE. 
+func modeFromEnv(errOut io.Writer) mode { + raw := strings.TrimSpace(os.Getenv(envvars.CliContentSafetyMode)) + if raw == "" { + return modeOff + } + switch strings.ToLower(raw) { + case "off": + return modeOff + case "warn": + return modeWarn + case "block": + return modeBlock + default: + fmt.Fprintf(errOut, + "warning: unknown %s value %q, falling back to off\n", + envvars.CliContentSafetyMode, raw) + return modeOff + } +} + +// normalizeCommandPath converts cobra CommandPath() to dotted form. +// "lark-cli im +messages-search" -> "im.messages_search" +func normalizeCommandPath(cobraPath string) string { + segs := strings.Fields(cobraPath) + if len(segs) <= 1 { + return "" + } + segs = segs[1:] + for i, s := range segs { + s = strings.TrimPrefix(s, "+") + s = strings.ReplaceAll(s, "-", "_") + segs[i] = s + } + return strings.Join(segs, ".") +} + +var errBlocked = fmt.Errorf("content safety blocked") + +// runContentSafety orchestrates the scan: mode check -> provider -> scan with timeout + panic recovery. +func runContentSafety(cobraPath string, data any, errOut io.Writer) (*extcs.Alert, error) { + m := modeFromEnv(errOut) + if m == modeOff { + return nil, nil + } + + p := extcs.GetProvider() + if p == nil { + return nil, nil + } + + cmdPath := normalizeCommandPath(cobraPath) + if cmdPath == "" { + return nil, nil + } + + type result struct { + alert *extcs.Alert + err error + } + ch := make(chan result, 1) + ctx, cancel := context.WithTimeout(context.Background(), scanTimeout) + defer cancel() + + // Give the goroutine its own writer so it cannot race on errOut after timeout. + // On success, we copy any provider notices to the real errOut. + // On timeout, the buffer is owned by the goroutine until it finishes; no shared access. 
+ scanErrBuf := &bytes.Buffer{} + go func() { + defer func() { + if r := recover(); r != nil { + ch <- result{nil, fmt.Errorf("content safety panic: %v", r)} + } + }() + a, e := p.Scan(ctx, extcs.ScanRequest{Path: cmdPath, Data: data, ErrOut: scanErrBuf}) + ch <- result{a, e} + }() + + var res result + select { + case res = <-ch: + if scanErrBuf.Len() > 0 { + _, _ = io.Copy(errOut, scanErrBuf) + } + case <-ctx.Done(): + return nil, nil // timeout, fail-open; scanErrBuf stays with the goroutine + } + + if res.err != nil { + fmt.Fprintf(errOut, "warning: content safety scan error: %v\n", res.err) + return nil, nil // fail-open + } + if res.alert == nil { + return nil, nil + } + + if m == modeBlock { + return res.alert, errBlocked + } + return res.alert, nil +} diff --git a/internal/output/emit_core_test.go b/internal/output/emit_core_test.go new file mode 100644 index 000000000..5c2107b92 --- /dev/null +++ b/internal/output/emit_core_test.go @@ -0,0 +1,64 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. 
+// SPDX-License-Identifier: MIT + +package output + +import ( + "bytes" + "testing" +) + +func TestModeFromEnv(t *testing.T) { + tests := []struct { + name string + envVal string + want mode + wantWarn bool + }{ + {"empty", "", modeOff, false}, + {"off", "off", modeOff, false}, + {"OFF", "OFF", modeOff, false}, + {"warn", "warn", modeWarn, false}, + {"WARN", "WARN", modeWarn, false}, + {"block", "block", modeBlock, false}, + {"unknown", "banana", modeOff, true}, + {"whitespace", " warn ", modeWarn, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Setenv("LARKSUITE_CLI_CONTENT_SAFETY_MODE", tt.envVal) + var buf bytes.Buffer + got := modeFromEnv(&buf) + if got != tt.want { + t.Errorf("modeFromEnv() = %d, want %d", got, tt.want) + } + if tt.wantWarn && buf.Len() == 0 { + t.Error("expected stderr warning") + } + if !tt.wantWarn && buf.Len() > 0 { + t.Errorf("unexpected stderr: %s", buf.String()) + } + }) + } +} + +func TestNormalizeCommandPath(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"lark-cli im +messages-search", "im.messages_search"}, + {"lark-cli drive upload +file", "drive.upload.file"}, + {"lark-cli api GET /path", "api.GET./path"}, + {"lark-cli", ""}, + {"", ""}, + } + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + got := normalizeCommandPath(tt.input) + if got != tt.want { + t.Errorf("normalizeCommandPath(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} diff --git a/internal/output/emit_test.go b/internal/output/emit_test.go new file mode 100644 index 000000000..a25c1e620 --- /dev/null +++ b/internal/output/emit_test.go @@ -0,0 +1,149 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package output + +import ( + "bytes" + "context" + "errors" + "strings" + "testing" + "time" + + extcs "github.com/larksuite/cli/extension/contentsafety" +) + +// mockProvider is a test provider that returns a configurable alert. 
+type mockProvider struct { + name string + alert *extcs.Alert + err error +} + +func (m *mockProvider) Name() string { return m.name } +func (m *mockProvider) Scan(_ context.Context, _ extcs.ScanRequest) (*extcs.Alert, error) { + return m.alert, m.err +} + +func TestScanForSafety_ModeOff(t *testing.T) { + t.Setenv("LARKSUITE_CLI_CONTENT_SAFETY_MODE", "off") + var buf bytes.Buffer + result := ScanForSafety("lark-cli im +messages-search", map[string]any{"text": "inject"}, &buf) + if result.Alert != nil || result.Blocked { + t.Error("mode=off should produce zero ScanResult") + } +} + +func TestScanForSafety_ModeWarn_WithAlert(t *testing.T) { + t.Setenv("LARKSUITE_CLI_CONTENT_SAFETY_MODE", "warn") + alert := &extcs.Alert{Provider: "mock", MatchedRules: []string{"r1"}} + mp := &mockProvider{name: "mock", alert: alert} + + // Register mock provider (save and restore) + extcs.Register(mp) + defer extcs.Register(nil) + + var buf bytes.Buffer + result := ScanForSafety("lark-cli im +test", map[string]any{}, &buf) + if result.Alert == nil { + t.Fatal("expected non-nil alert in warn mode") + } + if result.Blocked { + t.Error("warn mode should not block") + } + if result.BlockErr != nil { + t.Error("warn mode should not have BlockErr") + } +} + +func TestScanForSafety_ModeBlock_WithAlert(t *testing.T) { + t.Setenv("LARKSUITE_CLI_CONTENT_SAFETY_MODE", "block") + alert := &extcs.Alert{Provider: "mock", MatchedRules: []string{"r1"}} + mp := &mockProvider{name: "mock", alert: alert} + extcs.Register(mp) + defer extcs.Register(nil) + + var buf bytes.Buffer + result := ScanForSafety("lark-cli im +test", map[string]any{}, &buf) + if !result.Blocked { + t.Error("block mode with alert should set Blocked=true") + } + if result.BlockErr == nil { + t.Error("block mode with alert should have BlockErr") + } + var exitErr *ExitError + if !errors.As(result.BlockErr, &exitErr) { + t.Fatalf("BlockErr should be *ExitError, got %T", result.BlockErr) + } + if exitErr.Code != ExitContentSafety { + 
t.Errorf("exit code = %d, want %d", exitErr.Code, ExitContentSafety) + } +} + +func TestScanForSafety_NoProvider(t *testing.T) { + t.Setenv("LARKSUITE_CLI_CONTENT_SAFETY_MODE", "warn") + extcs.Register(nil) + + var buf bytes.Buffer + result := ScanForSafety("lark-cli im +test", map[string]any{}, &buf) + if result.Alert != nil || result.Blocked { + t.Error("no provider should produce zero ScanResult") + } +} + +func TestScanForSafety_ScanError_FailOpen(t *testing.T) { + t.Setenv("LARKSUITE_CLI_CONTENT_SAFETY_MODE", "block") + mp := &mockProvider{name: "mock", err: errors.New("scan broke")} + extcs.Register(mp) + defer extcs.Register(nil) + + var buf bytes.Buffer + result := ScanForSafety("lark-cli im +test", map[string]any{}, &buf) + if result.Blocked { + t.Error("scan error should fail-open, not block") + } + if !strings.Contains(buf.String(), "scan error") { + t.Errorf("expected warning on stderr, got: %s", buf.String()) + } +} + +func TestScanForSafety_SlowProvider_Timeout_FailOpen(t *testing.T) { + t.Setenv("LARKSUITE_CLI_CONTENT_SAFETY_MODE", "block") + + slow := &slowProvider{} + extcs.Register(slow) + defer extcs.Register(nil) + + var buf bytes.Buffer + result := ScanForSafety("lark-cli im +test", map[string]any{}, &buf) + if result.Blocked { + t.Error("slow provider should fail-open on timeout, not block") + } + if result.Alert != nil { + t.Error("slow provider should return nil alert on timeout") + } +} + +// slowProvider blocks for longer than scanTimeout to trigger the timeout path. 
+type slowProvider struct{} + +func (s *slowProvider) Name() string { return "slow" } +func (s *slowProvider) Scan(ctx context.Context, _ extcs.ScanRequest) (*extcs.Alert, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(200 * time.Millisecond): + return &extcs.Alert{Provider: "slow", MatchedRules: []string{"never"}}, nil + } +} + +func TestWriteAlertWarning(t *testing.T) { + alert := &extcs.Alert{Provider: "regex", MatchedRules: []string{"r1", "r2"}} + var buf bytes.Buffer + WriteAlertWarning(&buf, alert) + got := buf.String() + if !strings.Contains(got, "r1") || !strings.Contains(got, "r2") { + t.Errorf("warning should contain rule IDs, got: %s", got) + } +} diff --git a/internal/output/envelope.go b/internal/output/envelope.go index e76b6d5c0..0109f643e 100644 --- a/internal/output/envelope.go +++ b/internal/output/envelope.go @@ -5,11 +5,12 @@ package output // Envelope is the standard success response wrapper. type Envelope struct { - OK bool `json:"ok"` - Identity string `json:"identity,omitempty"` - Data interface{} `json:"data,omitempty"` - Meta *Meta `json:"meta,omitempty"` - Notice map[string]interface{} `json:"_notice,omitempty"` + OK bool `json:"ok"` + Identity string `json:"identity,omitempty"` + Data interface{} `json:"data,omitempty"` + Meta *Meta `json:"meta,omitempty"` + ContentSafetyAlert interface{} `json:"_content_safety_alert,omitempty"` + Notice map[string]interface{} `json:"_notice,omitempty"` } // ErrorEnvelope is the standard error response wrapper. diff --git a/internal/output/exitcode.go b/internal/output/exitcode.go index 47628afda..266ae7ce8 100644 --- a/internal/output/exitcode.go +++ b/internal/output/exitcode.go @@ -7,10 +7,11 @@ package output // are communicated via the JSON error envelope's "type" field, // not via exit codes. 
const ( - ExitOK = 0 // 成功 - ExitAPI = 1 // API / 通用错误(含 permission、not_found、conflict、rate_limit) - ExitValidation = 2 // 参数校验失败 - ExitAuth = 3 // 认证失败(token 无效 / 过期) - ExitNetwork = 4 // 网络错误(连接超时、DNS 解析失败等) - ExitInternal = 5 // 内部错误(不应发生) + ExitOK = 0 // 成功 + ExitAPI = 1 // API / 通用错误(含 permission、not_found、conflict、rate_limit) + ExitValidation = 2 // 参数校验失败 + ExitAuth = 3 // 认证失败(token 无效 / 过期) + ExitNetwork = 4 // 网络错误(连接超时、DNS 解析失败等) + ExitInternal = 5 // 内部错误(不应发生) + ExitContentSafety = 6 // content safety violation (block mode) ) diff --git a/internal/security/contentsafety/config.go b/internal/security/contentsafety/config.go new file mode 100644 index 000000000..88bbb9e2d --- /dev/null +++ b/internal/security/contentsafety/config.go @@ -0,0 +1,109 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package contentsafety + +import ( + "encoding/json" + "fmt" + "io" + "io/fs" + "path/filepath" + "regexp" + "strings" + + "github.com/larksuite/cli/internal/vfs" +) + +const configFileName = "content-safety.json" + +type Config struct { + Allowlist []string + Rules []rule +} + +type rawConfig struct { + Allowlist []string `json:"allowlist"` + Rules []rawRule `json:"rules"` +} + +type rawRule struct { + ID string `json:"id"` + Pattern string `json:"pattern"` +} + +func LoadConfig(configDir string) (*Config, error) { + path := filepath.Join(configDir, configFileName) + data, err := vfs.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("read content-safety config: %w", err) + } + var raw rawConfig + if err := json.Unmarshal(data, &raw); err != nil { + return nil, fmt.Errorf("parse content-safety config: %w", err) + } + rules := make([]rule, 0, len(raw.Rules)) + for _, r := range raw.Rules { + compiled, err := regexp.Compile(r.Pattern) + if err != nil { + return nil, fmt.Errorf("compile rule %q pattern: %w", r.ID, err) + } + rules = append(rules, rule{ID: r.ID, Pattern: compiled}) + } + return &Config{Allowlist: 
raw.Allowlist, Rules: rules}, nil +} + +func EnsureDefaultConfig(configDir string, errOut io.Writer) error { + path := filepath.Join(configDir, configFileName) + if _, err := vfs.Stat(path); err == nil { + return nil + } + if err := vfs.MkdirAll(configDir, 0700); err != nil { + return fmt.Errorf("create config dir: %w", err) + } + data, err := json.MarshalIndent(defaultRawConfig(), "", " ") + if err != nil { + return fmt.Errorf("marshal default config: %w", err) + } + if err := vfs.WriteFile(path, append(data, '\n'), fs.FileMode(0600)); err != nil { + return err + } + fmt.Fprintf(errOut, "notice: created default content-safety config at %s\n", path) + return nil +} + +func defaultRawConfig() rawConfig { + return rawConfig{ + Allowlist: []string{"all"}, + Rules: []rawRule{ + { + ID: "instruction_override", + Pattern: `(?i)ignore\s+(all\s+|any\s+|the\s+)?(previous|prior|above|earlier)\s+(instructions?|prompts?|directives?)`, + }, + { + ID: "role_injection", + Pattern: `(?i)<\s*/?\s*(system|assistant|tool|user|developer)\s*>`, + }, + { + ID: "system_prompt_leak", + Pattern: `(?i)\b(reveal|print|show|output|display|repeat)\s+(your|the|all)\s+(system\s+|initial\s+|original\s+)?(prompt|instructions?|rules?)`, + }, + { + ID: "delimiter_smuggle", + Pattern: `<\|im_(start|end|sep)\|>|<\|endoftext\|>|###\s*(system|assistant|user)\s*:`, + }, + }, + } +} + +func IsAllowlisted(cmdPath string, allowlist []string) bool { + for _, entry := range allowlist { + if strings.EqualFold(entry, "all") { + return true + } + if cmdPath == entry || strings.HasPrefix(cmdPath, entry+".") { + return true + } + } + return false +} diff --git a/internal/security/contentsafety/config_test.go b/internal/security/contentsafety/config_test.go new file mode 100644 index 000000000..44d93fce0 --- /dev/null +++ b/internal/security/contentsafety/config_test.go @@ -0,0 +1,124 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. 
+// SPDX-License-Identifier: MIT + +package contentsafety + +import ( + "io" + "os" + "path/filepath" + "strings" + "testing" +) + +func TestLoadConfig_ValidFile(t *testing.T) { + dir := t.TempDir() + content := `{ + "allowlist": ["im", "drive.upload"], + "rules": [{"id": "r1", "pattern": "(?i)test_pattern"}] + }` + if err := os.WriteFile(filepath.Join(dir, "content-safety.json"), []byte(content), 0644); err != nil { + t.Fatal(err) + } + cfg, err := LoadConfig(dir) + if err != nil { + t.Fatalf("LoadConfig() error = %v", err) + } + if len(cfg.Allowlist) != 2 || cfg.Allowlist[0] != "im" { + t.Errorf("Allowlist = %v, want [im, drive.upload]", cfg.Allowlist) + } + if len(cfg.Rules) != 1 || cfg.Rules[0].ID != "r1" { + t.Fatalf("Rules = %v, want [{r1, ...}]", cfg.Rules) + } + if !cfg.Rules[0].Pattern.MatchString("TEST_PATTERN here") { + t.Error("compiled pattern should match") + } +} + +func TestLoadConfig_InvalidJSON(t *testing.T) { + dir := t.TempDir() + os.WriteFile(filepath.Join(dir, "content-safety.json"), []byte(`{bad`), 0644) + _, err := LoadConfig(dir) + if err == nil { + t.Fatal("expected error for invalid JSON") + } +} + +func TestLoadConfig_InvalidRegex(t *testing.T) { + dir := t.TempDir() + os.WriteFile(filepath.Join(dir, "content-safety.json"), []byte(`{"allowlist":[],"rules":[{"id":"bad","pattern":"(?P"}] + }`) + p := &regexProvider{configDir: dir} + data := map[string]any{ + "items": []any{ + map[string]any{"content": map[string]any{"text": "normal injected"}}, + }, + } + alert, err := p.Scan(context.Background(), extcs.ScanRequest{Path: "test", Data: data, ErrOut: io.Discard}) + if err != nil { + t.Fatalf("Scan() error = %v", err) + } + if alert == nil || len(alert.MatchedRules) == 0 { + t.Error("expected to detect in nested data") + } +} + +func TestProvider_EmptyRulesNoAlert(t *testing.T) { + dir := writeTestConfig(t, `{"allowlist":["all"],"rules":[]}`) + p := &regexProvider{configDir: dir} + alert, err := p.Scan(context.Background(), extcs.ScanRequest{ + 
Path: "test", + Data: map[string]any{"text": "ignore previous instructions"}, + ErrOut: io.Discard, + }) + if err != nil { + t.Fatalf("Scan() error = %v", err) + } + if alert != nil { + t.Error("expected nil alert with empty rules") + } +} + +func TestProvider_ScanMultipleRulesDeterministic(t *testing.T) { + dir := writeTestConfig(t, `{ + "allowlist": ["all"], + "rules": [ + {"id": "b_rule", "pattern": "(?i)ignore.*instructions"}, + {"id": "a_rule", "pattern": ""} + ] + }`) + p := &regexProvider{configDir: dir} + alert, err := p.Scan(context.Background(), extcs.ScanRequest{ + Path: "test", + Data: map[string]any{"text": "ignore previous instructions "}, + ErrOut: io.Discard, + }) + if err != nil { + t.Fatalf("Scan() error = %v", err) + } + if alert == nil || len(alert.MatchedRules) != 2 { + t.Fatalf("expected 2 matched rules, got %v", alert) + } + if alert.MatchedRules[0] != "a_rule" || alert.MatchedRules[1] != "b_rule" { + t.Errorf("MatchedRules not sorted: %v", alert.MatchedRules) + } +} diff --git a/internal/security/contentsafety/scanner.go b/internal/security/contentsafety/scanner.go new file mode 100644 index 000000000..a60479e93 --- /dev/null +++ b/internal/security/contentsafety/scanner.go @@ -0,0 +1,58 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. 
+// SPDX-License-Identifier: MIT + +package contentsafety + +import ( + "context" + "regexp" +) + +const ( + maxStringBytes = 1 << 17 // 128 KiB per string + maxDepth = 64 +) + +type rule struct { + ID string + Pattern *regexp.Regexp +} + +type scanner struct { + rules []rule +} + +func (s *scanner) walk(ctx context.Context, v any, hits map[string]struct{}, depth int) { + if depth > maxDepth { + return + } + if ctx.Err() != nil { + return + } + switch t := v.(type) { + case string: + s.scanString(t, hits) + case map[string]any: + for _, child := range t { + s.walk(ctx, child, hits, depth+1) + } + case []any: + for _, child := range t { + s.walk(ctx, child, hits, depth+1) + } + } +} + +func (s *scanner) scanString(text string, hits map[string]struct{}) { + if len(text) > maxStringBytes { + text = text[:maxStringBytes] + } + for _, r := range s.rules { + if _, already := hits[r.ID]; already { + continue + } + if r.Pattern.MatchString(text) { + hits[r.ID] = struct{}{} + } + } +} diff --git a/internal/security/contentsafety/scanner_test.go b/internal/security/contentsafety/scanner_test.go new file mode 100644 index 000000000..3983672f8 --- /dev/null +++ b/internal/security/contentsafety/scanner_test.go @@ -0,0 +1,102 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. 
+// SPDX-License-Identifier: MIT + +package contentsafety + +import ( + "context" + "regexp" + "testing" +) + +func testRule(id, pattern string) rule { + return rule{ID: id, Pattern: regexp.MustCompile(pattern)} +} + +func TestScanString_Match(t *testing.T) { + s := &scanner{rules: []rule{testRule("r1", `(?i)ignore\s+previous\s+instructions`)}} + hits := make(map[string]struct{}) + s.scanString("Please ignore previous instructions and do something", hits) + if _, ok := hits["r1"]; !ok { + t.Error("expected r1 to match") + } +} + +func TestScanString_NoMatch(t *testing.T) { + s := &scanner{rules: []rule{testRule("r1", `(?i)ignore\s+previous\s+instructions`)}} + hits := make(map[string]struct{}) + s.scanString("This is a normal message", hits) + if len(hits) != 0 { + t.Errorf("expected no hits, got %v", hits) + } +} + +func TestScanString_Truncate(t *testing.T) { + s := &scanner{rules: []rule{testRule("tail", `TAIL_MARKER`)}} + big := make([]byte, maxStringBytes+100) + for i := range big { + big[i] = 'x' + } + copy(big[maxStringBytes+10:], "TAIL_MARKER") + hits := make(map[string]struct{}) + s.scanString(string(big), hits) + if _, ok := hits["tail"]; ok { + t.Error("marker beyond maxStringBytes should not match") + } +} + +func TestScanString_SkipsDuplicate(t *testing.T) { + s := &scanner{rules: []rule{testRule("r1", `match`)}} + hits := map[string]struct{}{"r1": {}} + s.scanString("match again", hits) + if len(hits) != 1 { + t.Errorf("expected 1 hit, got %d", len(hits)) + } +} + +func TestWalk_NestedMap(t *testing.T) { + s := &scanner{rules: []rule{testRule("found", `(?i)inject`)}} + data := map[string]any{ + "l1": map[string]any{ + "l2": "try to inject something", + }, + } + hits := make(map[string]struct{}) + s.walk(context.Background(), data, hits, 0) + if _, ok := hits["found"]; !ok { + t.Error("expected to find 'inject' in nested map") + } +} + +func TestWalk_Array(t *testing.T) { + s := &scanner{rules: []rule{testRule("found", `(?i)inject`)}} + hits := 
make(map[string]struct{}) + s.walk(context.Background(), []any{"normal", "try to inject"}, hits, 0) + if _, ok := hits["found"]; !ok { + t.Error("expected to find 'inject' in array") + } +} + +func TestWalk_MaxDepth(t *testing.T) { + s := &scanner{rules: []rule{testRule("deep", `secret`)}} + var data any = "secret" + for i := 0; i < maxDepth+5; i++ { + data = map[string]any{"n": data} + } + hits := make(map[string]struct{}) + s.walk(context.Background(), data, hits, 0) + if _, ok := hits["deep"]; ok { + t.Error("should not reach string beyond maxDepth") + } +} + +func TestWalk_ContextCancel(t *testing.T) { + s := &scanner{rules: []rule{testRule("found", `target`)}} + ctx, cancel := context.WithCancel(context.Background()) + cancel() + hits := make(map[string]struct{}) + s.walk(ctx, map[string]any{"key": "target"}, hits, 0) + if _, ok := hits["found"]; ok { + t.Error("should not match after context cancel") + } +} diff --git a/shortcuts/common/runner.go b/shortcuts/common/runner.go index e81db01c7..b132d6e6f 100644 --- a/shortcuts/common/runner.go +++ b/shortcuts/common/runner.go @@ -482,7 +482,17 @@ func (ctx *RuntimeContext) ValidatePath(path string) error { // Out prints a success JSON envelope to stdout. func (ctx *RuntimeContext) Out(data interface{}, meta *output.Meta) { + // Content safety scanning + scanResult := output.ScanForSafety(ctx.Cmd.CommandPath(), data, ctx.IO().ErrOut) + if scanResult.Blocked { + ctx.outputErrOnce.Do(func() { ctx.outputErr = scanResult.BlockErr }) + return + } + env := output.Envelope{OK: true, Identity: string(ctx.As()), Data: data, Meta: meta, Notice: output.GetNotice()} + if scanResult.Alert != nil { + env.ContentSafetyAlert = scanResult.Alert + } if ctx.JqExpr != "" { if err := output.JqFilter(ctx.IO().Out, env, ctx.JqExpr); err != nil { fmt.Fprintf(ctx.IO().ErrOut, "error: %v\n", err) @@ -497,23 +507,41 @@ func (ctx *RuntimeContext) Out(data interface{}, meta *output.Meta) { // OutFormat prints output based on --format flag. 
// "json" (default) outputs JSON envelope; "pretty" calls prettyFn; others delegate to FormatValue. // When JqExpr is set, routes through Out() regardless of format. +// For json/"" and jq paths, Out() handles content safety scanning. +// For pretty/table/csv/ndjson, scanning is done here and the alert is written to stderr. func (ctx *RuntimeContext) OutFormat(data interface{}, meta *output.Meta, prettyFn func(w io.Writer)) { if ctx.JqExpr != "" { - ctx.Out(data, meta) + ctx.Out(data, meta) // Out() handles scanning return } switch ctx.Format { + case "json", "": + ctx.Out(data, meta) // Out() handles scanning case "pretty": + scanResult := output.ScanForSafety(ctx.Cmd.CommandPath(), data, ctx.IO().ErrOut) + if scanResult.Blocked { + ctx.outputErrOnce.Do(func() { ctx.outputErr = scanResult.BlockErr }) + return + } + if scanResult.Alert != nil { + output.WriteAlertWarning(ctx.IO().ErrOut, scanResult.Alert) + } if prettyFn != nil { prettyFn(ctx.IO().Out) } else { ctx.Out(data, meta) } - case "json", "": - ctx.Out(data, meta) default: // table, csv, ndjson — pass data directly; FormatValue handles both // plain arrays and maps with array fields (e.g. {"members":[…]}) + scanResult := output.ScanForSafety(ctx.Cmd.CommandPath(), data, ctx.IO().ErrOut) + if scanResult.Blocked { + ctx.outputErrOnce.Do(func() { ctx.outputErr = scanResult.BlockErr }) + return + } + if scanResult.Alert != nil { + output.WriteAlertWarning(ctx.IO().ErrOut, scanResult.Alert) + } format, formatOK := output.ParseFormat(ctx.Format) if !formatOK { fmt.Fprintf(ctx.IO().ErrOut, "warning: unknown format %q, falling back to json\n", ctx.Format) diff --git a/shortcuts/common/runner_contentsafety_test.go b/shortcuts/common/runner_contentsafety_test.go new file mode 100644 index 000000000..09d012696 --- /dev/null +++ b/shortcuts/common/runner_contentsafety_test.go @@ -0,0 +1,98 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. 
+// SPDX-License-Identifier: MIT
+
+package common
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"testing"
+
+	"github.com/spf13/cobra"
+
+	extcs "github.com/larksuite/cli/extension/contentsafety"
+	"github.com/larksuite/cli/internal/cmdutil"
+	"github.com/larksuite/cli/internal/core"
+	"github.com/larksuite/cli/internal/output"
+)
+
+// csTestProvider is a stub content-safety provider whose Scan always returns
+// the configured alert and a nil error.
+type csTestProvider struct {
+	alert *extcs.Alert
+}
+
+func (p *csTestProvider) Name() string { return "test" }
+func (p *csTestProvider) Scan(_ context.Context, _ extcs.ScanRequest) (*extcs.Alert, error) {
+	return p.alert, nil
+}
+
+// registerCSTestProvider installs a csTestProvider that reports alert and, when
+// the test finishes, restores whatever provider was registered before. Resetting
+// to nil instead would silently drop a provider installed outside the test.
+func registerCSTestProvider(t *testing.T, alert *extcs.Alert) {
+	t.Helper()
+	prev := extcs.GetProvider()
+	extcs.Register(&csTestProvider{alert: alert})
+	t.Cleanup(func() { extcs.Register(prev) })
+}
+
+// newCSTestContext builds a RuntimeContext for a "test" subcommand of
+// "lark-cli" whose IO streams are in-memory buffers. It returns the context
+// followed by the stdout and stderr buffers.
+func newCSTestContext(t *testing.T) (*RuntimeContext, *bytes.Buffer, *bytes.Buffer) {
+	t.Helper()
+	stdout := &bytes.Buffer{}
+	stderr := &bytes.Buffer{}
+	parentCmd := &cobra.Command{Use: "lark-cli"}
+	cmd := &cobra.Command{Use: "test"}
+	parentCmd.AddCommand(cmd)
+	rctx := &RuntimeContext{
+		ctx:        context.Background(),
+		Config:     &core.CliConfig{Brand: core.BrandFeishu},
+		Cmd:        cmd,
+		resolvedAs: core.AsBot,
+		Factory: &cmdutil.Factory{
+			IOStreams: &cmdutil.IOStreams{Out: stdout, ErrOut: stderr},
+		},
+	}
+	return rctx, stdout, stderr
+}
+
+// TestOut_ContentSafetyWarn checks that, in warn mode with an alerting
+// provider, Out still writes the envelope and attaches the alert to it.
+func TestOut_ContentSafetyWarn(t *testing.T) {
+	t.Setenv("LARKSUITE_CLI_CONTENT_SAFETY_MODE", "warn")
+	registerCSTestProvider(t, &extcs.Alert{Provider: "test", MatchedRules: []string{"r1"}})
+
+	rctx, stdout, _ := newCSTestContext(t)
+	rctx.Out(map[string]any{"msg": "hello"}, nil)
+
+	var env output.Envelope
+	if err := json.Unmarshal(stdout.Bytes(), &env); err != nil {
+		t.Fatalf("unmarshal envelope: %v", err)
+	}
+	if env.ContentSafetyAlert == nil {
+		t.Error("expected _content_safety_alert in envelope")
+	}
+}
+
+// TestOut_ContentSafetyBlock checks that, in block mode with an alerting
+// provider, Out writes nothing to stdout and records an output error.
+func TestOut_ContentSafetyBlock(t *testing.T) {
+	t.Setenv("LARKSUITE_CLI_CONTENT_SAFETY_MODE", "block")
+	registerCSTestProvider(t, &extcs.Alert{Provider: "test", MatchedRules: []string{"r1"}})
+
+	rctx, stdout, _ := newCSTestContext(t)
+	rctx.Out(map[string]any{"msg": "hello"}, nil)
+
+	if stdout.Len() > 0 {
+		t.Error("block mode should not write data to stdout")
+	}
+	if rctx.outputErr == nil {
+		t.Error("block mode should set outputErr")
+	}
+}
+
+// TestOut_ContentSafetyOff checks that mode=off produces no alert even though
+// an alerting provider is registered; without a registered provider the
+// assertion below could not fail.
+func TestOut_ContentSafetyOff(t *testing.T) {
+	t.Setenv("LARKSUITE_CLI_CONTENT_SAFETY_MODE", "off")
+	registerCSTestProvider(t, &extcs.Alert{Provider: "test", MatchedRules: []string{"r1"}})
+
+	rctx, stdout, _ := newCSTestContext(t)
+	rctx.Out(map[string]any{"msg": "hello"}, nil)
+
+	var env output.Envelope
+	if err := json.Unmarshal(stdout.Bytes(), &env); err != nil {
+		t.Fatalf("unmarshal: %v", err)
+	}
+	if env.ContentSafetyAlert != nil {
+		t.Error("mode=off should not produce alert")
+	}
+}