From 8470c0955a73b9705a8d9b7faddcdb9d864fcd47 Mon Sep 17 00:00:00 2001 From: tuxedomm <273098272+tuxedomm@users.noreply.github.com> Date: Wed, 8 Apr 2026 13:02:56 +0800 Subject: [PATCH 1/3] refactor: migrate common/client/im to FileIO and add localfileio tests - runner resolveInputFlags: replace validate.SafeInputPath + vfs.ReadFile with FileIO.Open + io.ReadAll - SaveResponse: delegate to FileIO.Save + ResolvePath - cmd/api, cmd/service: pass FileIO to ResponseOptions - im: replace validate.SafeLocalFlagPath with RuntimeContext.ValidatePath, migrate download/upload to FileIO.Save/Open/Stat - Add path_test.go and atomicwrite_test.go for localfileio - Add validate_media_test.go for im media flag validation - Adapt test mocks to fileio.FileInfo interface Change-Id: I1cacaf36a07af6b011e275680ac20192422406e7 --- cmd/api/api.go | 1 + cmd/service/service.go | 1 + extension/fileio/errors.go | 43 +++ internal/client/response.go | 47 +-- internal/client/response_test.go | 60 +++- internal/cmdutil/factory_default_test.go | 34 ++ internal/vfs/localfileio/atomicwrite_test.go | 146 +++++++++ internal/vfs/localfileio/localfileio.go | 14 +- internal/vfs/localfileio/localfileio_test.go | 306 ++++++++++++++++++ internal/vfs/localfileio/path_test.go | 245 ++++++++++++++ shortcuts/common/runner.go | 41 ++- shortcuts/common/runner_input_test.go | 1 + shortcuts/common/runner_jq_test.go | 46 ++- shortcuts/im/coverage_additional_test.go | 39 ++- shortcuts/im/helpers.go | 48 +-- shortcuts/im/helpers_network_test.go | 102 +++--- shortcuts/im/helpers_test.go | 8 +- shortcuts/im/im_messages_reply.go | 60 +--- .../im/im_messages_resources_download.go | 84 ++--- shortcuts/im/im_messages_send.go | 75 ++--- shortcuts/im/validate_media_test.go | 51 +++ 21 files changed, 1167 insertions(+), 285 deletions(-) create mode 100644 extension/fileio/errors.go create mode 100644 internal/vfs/localfileio/atomicwrite_test.go create mode 100644 internal/vfs/localfileio/localfileio_test.go create mode 100644 internal/vfs/localfileio/path_test.go create mode 100644 shortcuts/im/validate_media_test.go diff --git a/cmd/api/api.go b/cmd/api/api.go index 084cb059..696eda9b 100644 --- a/cmd/api/api.go +++ b/cmd/api/api.go @@ -207,6 +207,7 @@ func apiRun(opts *APIOptions) error { JqExpr: opts.JqExpr, Out: out, ErrOut: f.IOStreams.ErrOut, + FileIO: f.ResolveFileIO(opts.Ctx), }) // MarkRaw tells root error handler to skip enrichPermissionError, // preserving the original API error detail (log_id, troubleshooter, etc.). diff --git a/cmd/service/service.go b/cmd/service/service.go index 85c62cc3..61759fa6 100644 --- a/cmd/service/service.go +++ b/cmd/service/service.go @@ -250,6 +250,7 @@ func serviceMethodRun(opts *ServiceMethodOptions) error { JqExpr: opts.JqExpr, Out: out, ErrOut: f.IOStreams.ErrOut, + FileIO: f.ResolveFileIO(opts.Ctx), CheckError: checkErr, }) } diff --git a/extension/fileio/errors.go b/extension/fileio/errors.go new file mode 100644 index 00000000..a68bdf7e --- /dev/null +++ b/extension/fileio/errors.go @@ -0,0 +1,43 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package fileio + +import "errors" + +// Sentinel errors for FileIO operations. Callers can use errors.Is to +// distinguish error categories and wrap with their own messages. + +// ErrPathValidation indicates the path failed security validation +// (traversal, absolute, control chars, symlink escape, etc.). +var ErrPathValidation = errors.New("path validation failed") + +// ErrMkdir indicates parent directory creation failed. +var ErrMkdir = errors.New("directory creation failed") + +// PathValidationError wraps a path validation error with ErrPathValidation. +// Both the sentinel (ErrPathValidation) and the original error are +// reachable via errors.Is / errors.As. +type PathValidationError struct { + Err error +} + +func (e *PathValidationError) Error() string { return e.Err.Error() } + +// Unwrap returns both the sentinel and the original error so that +// errors.Is(err, ErrPathValidation) and errors.Is(err, os.ErrPermission) +// (or any OS error in the chain) both work. +func (e *PathValidationError) Unwrap() []error { + return []error{ErrPathValidation, e.Err} +} + +// MkdirError wraps a directory creation error with ErrMkdir. +type MkdirError struct { + Err error +} + +func (e *MkdirError) Error() string { return e.Err.Error() } + +func (e *MkdirError) Unwrap() []error { + return []error{ErrMkdir, e.Err} +} diff --git a/internal/client/response.go b/internal/client/response.go index 10695614..9c4d385a 100644 --- a/internal/client/response.go +++ b/internal/client/response.go @@ -6,18 +6,17 @@ package client import ( "bytes" "encoding/json" + "errors" "fmt" "io" "mime" - "path/filepath" "strings" larkcore "github.com/larksuite/oapi-sdk-go/v3/core" + "github.com/larksuite/cli/extension/fileio" "github.com/larksuite/cli/internal/output" "github.com/larksuite/cli/internal/util" - "github.com/larksuite/cli/internal/validate" - "github.com/larksuite/cli/internal/vfs" ) // ── Response routing ── @@ -29,6 +28,7 @@ type ResponseOptions struct { JqExpr string // if set, apply jq filter instead of Format Out io.Writer // stdout ErrOut io.Writer // stderr + FileIO fileio.FileIO // file transfer abstraction; nil falls back to direct os calls // CheckError is called on parsed JSON results. Nil defaults to CheckLarkResponse. CheckError func(interface{}) error } @@ -61,7 +61,7 @@ func HandleResponse(resp *larkcore.ApiResp, opts ResponseOptions) error { return apiErr } if opts.OutputPath != "" { - return saveAndPrint(resp, opts.OutputPath, opts.Out) + return saveAndPrint(opts.FileIO, resp, opts.OutputPath, opts.Out) } if opts.JqExpr != "" { return output.JqFilter(opts.Out, result, opts.JqExpr) @@ -75,11 +75,11 @@ func HandleResponse(resp *larkcore.ApiResp, opts ResponseOptions) error { return output.ErrValidation("--jq requires a JSON response (got Content-Type: %s)", ct) } if opts.OutputPath != "" { - return saveAndPrint(resp, opts.OutputPath, opts.Out) + return saveAndPrint(opts.FileIO, resp, opts.OutputPath, opts.Out) } // No --output: auto-save with derived filename. - meta, err := SaveResponse(resp, ResolveFilename(resp)) + meta, err := SaveResponse(opts.FileIO, resp, ResolveFilename(resp)) if err != nil { return output.Errorf(output.ExitInternal, "file_error", "%s", err) } @@ -88,8 +88,8 @@ func HandleResponse(resp *larkcore.ApiResp, opts ResponseOptions) error { return nil } -func saveAndPrint(resp *larkcore.ApiResp, path string, w io.Writer) error { - meta, err := SaveResponse(resp, path) +func saveAndPrint(fio fileio.FileIO, resp *larkcore.ApiResp, path string, w io.Writer) error { + meta, err := SaveResponse(fio, resp, path) if err != nil { return output.Errorf(output.ExitInternal, "file_error", "%s", err) } @@ -119,23 +119,30 @@ func ParseJSONResponse(resp *larkcore.ApiResp) (interface{}, error) { // ── File saving ── // SaveResponse writes an API response body to the given outputPath and returns metadata. -func SaveResponse(resp *larkcore.ApiResp, outputPath string) (map[string]interface{}, error) { - safePath, err := validate.SafeOutputPath(outputPath) +// It delegates to FileIO.Save for path validation and atomic write; fio must not be nil. +func SaveResponse(fio fileio.FileIO, resp *larkcore.ApiResp, outputPath string) (map[string]interface{}, error) { + result, err := fio.Save(outputPath, fileio.SaveOptions{ + ContentType: resp.Header.Get("Content-Type"), + ContentLength: int64(len(resp.RawBody)), + }, bytes.NewReader(resp.RawBody)) if err != nil { - return nil, fmt.Errorf("unsafe output path: %s", err) - } - - if err := vfs.MkdirAll(filepath.Dir(safePath), 0700); err != nil { - return nil, fmt.Errorf("create directory: %s", err) + switch { + case errors.Is(err, fileio.ErrPathValidation): + return nil, fmt.Errorf("unsafe output path: %w", err) + case errors.Is(err, fileio.ErrMkdir): + return nil, fmt.Errorf("create directory: %w", err) + default: + return nil, fmt.Errorf("cannot write file: %w", err) + } } - if err := validate.AtomicWrite(safePath, resp.RawBody, 0644); err != nil { - return nil, fmt.Errorf("cannot write file: %s", err) + resolvedPath, err := fio.ResolvePath(outputPath) + if err != nil || resolvedPath == "" { + resolvedPath = outputPath } - return map[string]interface{}{ - "saved_path": safePath, - "size_bytes": len(resp.RawBody), + "saved_path": resolvedPath, + "size_bytes": result.Size(), "content_type": resp.Header.Get("Content-Type"), }, nil } diff --git a/internal/client/response_test.go b/internal/client/response_test.go index 0de09f97..edc0fa98 100644 --- a/internal/client/response_test.go +++ b/internal/client/response_test.go @@ -15,6 +15,7 @@ import ( larkcore "github.com/larksuite/oapi-sdk-go/v3/core" "github.com/larksuite/cli/internal/output" + "github.com/larksuite/cli/internal/vfs/localfileio" ) func newApiResp(body []byte, headers map[string]string) *larkcore.ApiResp { @@ -150,11 +151,11 @@ func TestSaveResponse(t *testing.T) { body := []byte("hello binary data") resp := newApiResp(body, map[string]string{"Content-Type": "application/octet-stream"}) - meta, err := SaveResponse(resp, "test_output.bin") + meta, err := SaveResponse(&localfileio.LocalFileIO{}, resp, "test_output.bin") if err != nil { t.Fatalf("SaveResponse failed: %v", err) } - if meta["size_bytes"] != len(body) { + if meta["size_bytes"] != int64(len(body)) { t.Errorf("expected size_bytes=%d, got %v", len(body), meta["size_bytes"]) } @@ -176,7 +177,7 @@ func TestSaveResponse_CreatesDir(t *testing.T) { resp := newApiResp([]byte("data"), map[string]string{"Content-Type": "application/octet-stream"}) - meta, err := SaveResponse(resp, filepath.Join("sub", "deep", "out.bin")) + meta, err := SaveResponse(&localfileio.LocalFileIO{}, resp, filepath.Join("sub", "deep", "out.bin")) if err != nil { t.Fatalf("SaveResponse with nested dir failed: %v", err) } @@ -195,6 +196,7 @@ func TestHandleResponse_JSON(t *testing.T) { err := HandleResponse(resp, ResponseOptions{ Out: &out, ErrOut: &errOut, + FileIO: &localfileio.LocalFileIO{}, }) if err != nil { t.Fatalf("HandleResponse failed: %v", err) @@ -213,6 +215,7 @@ func TestHandleResponse_JSONWithError(t *testing.T) { err := HandleResponse(resp, ResponseOptions{ Out: &out, ErrOut: &errOut, + FileIO: &localfileio.LocalFileIO{}, }) if err == nil { t.Error("expected error for non-zero code") @@ -232,6 +235,7 @@ func TestHandleResponse_BinaryAutoSave(t *testing.T) { err := HandleResponse(resp, ResponseOptions{ Out: &out, ErrOut: &errOut, + FileIO: &localfileio.LocalFileIO{}, }) if err != nil { t.Fatalf("HandleResponse binary failed: %v", err) @@ -255,6 +259,7 @@ func TestHandleResponse_BinaryWithOutput(t *testing.T) { OutputPath: "out.png", Out: &out, ErrOut: &errOut, + FileIO: &localfileio.LocalFileIO{}, }) if err != nil { t.Fatalf("HandleResponse with output path failed: %v", err) @@ -269,7 +274,7 @@ func TestHandleResponse_NonJSONError_404(t *testing.T) { resp := newApiRespWithStatus(404, []byte("404 page not found"), map[string]string{"Content-Type": "text/plain"}) var out, errOut bytes.Buffer - err := HandleResponse(resp, ResponseOptions{Out: &out, ErrOut: &errOut}) + err := HandleResponse(resp, ResponseOptions{Out: &out, ErrOut: &errOut, FileIO: &localfileio.LocalFileIO{}}) if err == nil { t.Fatal("expected error for 404 text/plain") } @@ -287,7 +292,7 @@ func TestHandleResponse_NonJSONError_502(t *testing.T) { resp := newApiRespWithStatus(502, []byte("Bad Gateway"), map[string]string{"Content-Type": "text/html"}) var out, errOut bytes.Buffer - err := HandleResponse(resp, ResponseOptions{Out: &out, ErrOut: &errOut}) + err := HandleResponse(resp, ResponseOptions{Out: &out, ErrOut: &errOut, FileIO: &localfileio.LocalFileIO{}}) if err == nil { t.Fatal("expected error for 502 text/html") } @@ -310,7 +315,7 @@ func TestHandleResponse_200TextPlain_SavesFile(t *testing.T) { resp := newApiRespWithStatus(200, []byte("plain text file content"), map[string]string{"Content-Type": "text/plain"}) var out, errOut bytes.Buffer - err := HandleResponse(resp, ResponseOptions{Out: &out, ErrOut: &errOut}) + err := HandleResponse(resp, ResponseOptions{Out: &out, ErrOut: &errOut, FileIO: &localfileio.LocalFileIO{}}) if err != nil { t.Fatalf("expected no error for 200 text/plain, got: %v", err) } @@ -336,12 +341,53 @@ func TestHandleResponse_BinaryWithJq_RejectsNonJSON(t *testing.T) { } } +func TestSaveResponse_RejectsPathTraversal(t *testing.T) { + dir := t.TempDir() + origWd, _ := os.Getwd() + os.Chdir(dir) + defer os.Chdir(origWd) + + resp := newApiResp([]byte("data"), map[string]string{"Content-Type": "application/octet-stream"}) + _, err := SaveResponse(&localfileio.LocalFileIO{}, resp, "../../evil.txt") + if err == nil { + t.Fatal("expected error for path traversal") + } + if !strings.Contains(err.Error(), "unsafe output path") { + t.Errorf("expected 'unsafe output path' wrapper, got: %v", err) + } +} + +func TestSaveResponse_RejectsAbsolutePath(t *testing.T) { + resp := newApiResp([]byte("data"), map[string]string{"Content-Type": "application/octet-stream"}) + _, err := SaveResponse(&localfileio.LocalFileIO{}, resp, "/tmp/evil.txt") + if err == nil { + t.Fatal("expected error for absolute path") + } +} + +func TestSaveResponse_MetadataContainsAbsolutePath(t *testing.T) { + dir := t.TempDir() + origWd, _ := os.Getwd() + os.Chdir(dir) + defer os.Chdir(origWd) + + resp := newApiResp([]byte("x"), map[string]string{"Content-Type": "text/plain"}) + meta, err := SaveResponse(&localfileio.LocalFileIO{}, resp, "rel.txt") + if err != nil { + t.Fatalf("SaveResponse failed: %v", err) + } + savedPath, _ := meta["saved_path"].(string) + if !filepath.IsAbs(savedPath) { + t.Errorf("saved_path should be absolute, got %q", savedPath) + } +} + func TestHandleResponse_403JSON_CheckLarkResponse(t *testing.T) { body := []byte(`{"code":99991400,"msg":"invalid token"}`) resp := newApiRespWithStatus(403, body, map[string]string{"Content-Type": "application/json"}) var out, errOut bytes.Buffer - err := HandleResponse(resp, ResponseOptions{Out: &out, ErrOut: &errOut}) + err := HandleResponse(resp, ResponseOptions{Out: &out, ErrOut: &errOut, FileIO: &localfileio.LocalFileIO{}}) if err == nil { t.Fatal("expected error for 403 JSON with non-zero code") } diff --git a/internal/cmdutil/factory_default_test.go b/internal/cmdutil/factory_default_test.go index 9d01e82f..f30b622e 100644 --- a/internal/cmdutil/factory_default_test.go +++ b/internal/cmdutil/factory_default_test.go @@ -11,13 +11,26 @@ import ( "testing" _ "github.com/larksuite/cli/extension/credential/env" + "github.com/larksuite/cli/extension/fileio" exttransport "github.com/larksuite/cli/extension/transport" internalauth "github.com/larksuite/cli/internal/auth" "github.com/larksuite/cli/internal/core" "github.com/larksuite/cli/internal/credential" "github.com/larksuite/cli/internal/envvars" + "github.com/larksuite/cli/internal/vfs/localfileio" ) +type countingFileIOProvider struct { + resolveCalls int +} + +func (p *countingFileIOProvider) Name() string { return "counting" } + +func (p *countingFileIOProvider) ResolveFileIO(context.Context) fileio.FileIO { + p.resolveCalls++ + return &localfileio.LocalFileIO{} +} + func TestNewDefault_InvocationProfileUsedByStrictModeAndConfig(t *testing.T) { t.Setenv(envvars.CliAppID, "") t.Setenv(envvars.CliAppSecret, "") @@ -198,6 +211,27 @@ func TestNewDefault_ConfigUsesRuntimePlaceholderForTokenOnlyEnvAccount(t *testin } } +func TestNewDefault_FileIOProviderDoesNotResolveDuringInitialization(t *testing.T) { + prev := fileio.GetProvider() + provider := &countingFileIOProvider{} + fileio.Register(provider) + t.Cleanup(func() { fileio.Register(prev) }) + + f := NewDefault(InvocationContext{}) + if f.FileIOProvider != provider { + t.Fatalf("NewDefault() provider = %T, want %T", f.FileIOProvider, provider) + } + if provider.resolveCalls != 0 { + t.Fatalf("ResolveFileIO() calls after NewDefault() = %d, want 0", provider.resolveCalls) + } + + if got := f.ResolveFileIO(context.Background()); got == nil { + t.Fatal("ResolveFileIO() = nil, want non-nil") + } + if provider.resolveCalls != 1 { + t.Fatalf("ResolveFileIO() calls after explicit resolve = %d, want 1", provider.resolveCalls) + } +} type stubTransportProvider struct { interceptor exttransport.Interceptor } diff --git a/internal/vfs/localfileio/atomicwrite_test.go b/internal/vfs/localfileio/atomicwrite_test.go new file mode 100644 index 00000000..d8dbbb75 --- /dev/null +++ b/internal/vfs/localfileio/atomicwrite_test.go @@ -0,0 +1,146 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package localfileio + +import ( + "os" + "path/filepath" + "runtime" + "sync" + "testing" +) + +func TestAtomicWrite_WritesContentAndPermissionCorrectly(t *testing.T) { + // GIVEN: a target path in a temp directory + dir := t.TempDir() + path := filepath.Join(dir, "test.json") + data := []byte(`{"key":"value"}`) + + // WHEN: AtomicWrite writes data with 0644 permission + if err := AtomicWrite(path, data, 0644); err != nil { + t.Fatalf("AtomicWrite failed: %v", err) + } + + // THEN: file content matches exactly + got, err := os.ReadFile(path) + if err != nil { + t.Fatalf("ReadFile failed: %v", err) + } + if string(got) != string(data) { + t.Errorf("content = %q, want %q", got, data) + } +} + +func TestAtomicWrite_SetsRestrictivePermission(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("permission test not reliable on Windows") + } + + // GIVEN: a target path + dir := t.TempDir() + path := filepath.Join(dir, "secret.json") + + // WHEN: AtomicWrite writes with 0600 permission + if err := AtomicWrite(path, []byte("secret"), 0600); err != nil { + t.Fatalf("AtomicWrite failed: %v", err) + } + + // THEN: file permission is exactly 0600 (owner read-write only) + info, _ := os.Stat(path) + if perm := info.Mode().Perm(); perm != 0600 { + t.Errorf("permission = %04o, want 0600", perm) + } +} + +func TestAtomicWrite_OverwritesExistingFile(t *testing.T) { + // GIVEN: an existing file with old content + dir := t.TempDir() + path := filepath.Join(dir, "test.json") + AtomicWrite(path, []byte("old"), 0644) + + // WHEN: AtomicWrite overwrites with new content + if err := AtomicWrite(path, []byte("new"), 0644); err != nil { + t.Fatalf("second write failed: %v", err) + } + + // THEN: file contains new content + got, _ := os.ReadFile(path) + if string(got) != "new" { + t.Errorf("content = %q, want %q", got, "new") + } +} + +func TestAtomicWrite_LeavesNoResidualTempFileOnError(t *testing.T) { + // GIVEN: a target path in a non-existent nested directory + path := filepath.Join(t.TempDir(), "nonexistent", "subdir", "file.txt") + + // WHEN: AtomicWrite fails (parent directory doesn't exist) + err := AtomicWrite(path, []byte("data"), 0644) + + // THEN: the write fails + if err == nil { + t.Fatal("expected error writing to nonexistent dir") + } + + // THEN: no .tmp files are left behind + parentDir := filepath.Dir(filepath.Dir(path)) + entries, _ := os.ReadDir(parentDir) + for _, e := range entries { + if filepath.Ext(e.Name()) == ".tmp" { + t.Errorf("residual temp file found: %s", e.Name()) + } + } +} + +func TestAtomicWrite_PreservesOriginalFileOnFailure(t *testing.T) { + // GIVEN: an existing file with known content + dir := t.TempDir() + original := []byte("original content") + path := filepath.Join(dir, "file.json") + if err := AtomicWrite(path, original, 0644); err != nil { + t.Fatal(err) + } + + // WHEN: AtomicWrite targets a non-existent directory (guaranteed to fail even as root) + badPath := filepath.Join(dir, "no", "such", "dir", "file.json") + err := AtomicWrite(badPath, []byte("new"), 0644) + + // THEN: write fails + if err == nil { + t.Fatal("expected error writing to non-existent dir") + } + + // THEN: the original file at the valid path is untouched + got, _ := os.ReadFile(path) + if string(got) != string(original) { + t.Errorf("original file corrupted: got %q, want %q", got, original) + } +} + +func TestAtomicWrite_HandlesCorrectlyUnderConcurrentWrites(t *testing.T) { + // GIVEN: a target file that will be written by 20 concurrent goroutines + dir := t.TempDir() + path := filepath.Join(dir, "concurrent.json") + + // WHEN: 20 goroutines write simultaneously + var wg sync.WaitGroup + for i := range 20 { + wg.Add(1) + go func(n int) { + defer wg.Done() + data := []byte(`{"n":` + string(rune('0'+n%10)) + `}`) + AtomicWrite(path, data, 0644) + }(i) + } + wg.Wait() + + // THEN: file exists and is valid (not corrupted by interleaved writes) + got, err := os.ReadFile(path) + if err != nil { + t.Fatalf("ReadFile failed: %v", err) + } + if len(got) == 0 { + t.Error("file is empty after concurrent writes") + } +} diff --git a/internal/vfs/localfileio/localfileio.go b/internal/vfs/localfileio/localfileio.go index 5c712c23..b4318a3d 100644 --- a/internal/vfs/localfileio/localfileio.go +++ b/internal/vfs/localfileio/localfileio.go @@ -34,7 +34,7 @@ type LocalFileIO struct{} func (l *LocalFileIO) Open(name string) (fileio.File, error) { safePath, err := SafeInputPath(name) if err != nil { - return nil, err + return nil, &fileio.PathValidationError{Err: err} } return vfs.Open(safePath) } @@ -43,7 +43,7 @@ func (l *LocalFileIO) Open(name string) (fileio.File, error) { func (l *LocalFileIO) Stat(name string) (fileio.FileInfo, error) { safePath, err := SafeInputPath(name) if err != nil { - return nil, err + return nil, &fileio.PathValidationError{Err: err} } return vfs.Stat(safePath) } @@ -55,7 +55,11 @@ func (r *saveResult) Size() int64 { return r.size } // ResolvePath returns the validated absolute path for the given output path. func (l *LocalFileIO) ResolvePath(path string) (string, error) { - return SafeOutputPath(path) + resolved, err := SafeOutputPath(path) + if err != nil { + return "", &fileio.PathValidationError{Err: err} + } + return resolved, nil } // Save writes body to path atomically after validating the output path. @@ -64,10 +68,10 @@ func (l *LocalFileIO) ResolvePath(path string) (string, error) { func (l *LocalFileIO) Save(path string, _ fileio.SaveOptions, body io.Reader) (fileio.SaveResult, error) { safePath, err := SafeOutputPath(path) if err != nil { - return nil, err + return nil, &fileio.PathValidationError{Err: err} } if err := vfs.MkdirAll(filepath.Dir(safePath), 0700); err != nil { - return nil, err + return nil, &fileio.MkdirError{Err: err} } n, err := AtomicWriteFromReader(safePath, body, 0600) if err != nil { diff --git a/internal/vfs/localfileio/localfileio_test.go b/internal/vfs/localfileio/localfileio_test.go new file mode 100644 index 00000000..3cb0efd3 --- /dev/null +++ b/internal/vfs/localfileio/localfileio_test.go @@ -0,0 +1,306 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package localfileio + +import ( + "io" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/larksuite/cli/extension/fileio" +) + +// testChdir temporarily changes the working directory for a test. +// Not compatible with t.Parallel(). +func testChdir(t *testing.T, dir string) { + t.Helper() + orig, err := os.Getwd() + if err != nil { + t.Fatal(err) + } + if err := os.Chdir(dir); err != nil { + t.Fatal(err) + } + t.Cleanup(func() { os.Chdir(orig) }) +} + +// ── Provider ── + +func TestProvider_Name(t *testing.T) { + p := &Provider{} + if got := p.Name(); got != "local" { + t.Errorf("Provider.Name() = %q, want %q", got, "local") + } +} + +func TestProvider_ResolveFileIO(t *testing.T) { + p := &Provider{} + fio := p.ResolveFileIO(nil) + if fio == nil { + t.Fatal("Provider.ResolveFileIO returned nil") + } + if _, ok := fio.(*LocalFileIO); !ok { + t.Errorf("expected *LocalFileIO, got %T", fio) + } +} + +// ── Open ── + +func TestLocalFileIO_Open_ValidFile(t *testing.T) { + dir := t.TempDir() + testChdir(t, dir) + + content := []byte("hello world") + os.WriteFile("test.txt", content, 0644) + + fio := &LocalFileIO{} + f, err := fio.Open("test.txt") + if err != nil { + t.Fatalf("Open failed: %v", err) + } + defer f.Close() + + got, err := io.ReadAll(f) + if err != nil { + t.Fatalf("ReadAll failed: %v", err) + } + if string(got) != string(content) { + t.Errorf("content = %q, want %q", got, content) + } +} + +func TestLocalFileIO_Open_RejectsTraversal(t *testing.T) { + dir := t.TempDir() + testChdir(t, dir) + + fio := &LocalFileIO{} + _, err := fio.Open("../../etc/passwd") + if err == nil { + t.Error("expected error for path traversal") + } +} + +func TestLocalFileIO_Open_RejectsAbsolutePath(t *testing.T) { + fio := &LocalFileIO{} + _, err := fio.Open("/etc/passwd") + if err == nil { + t.Error("expected error for absolute path") + } + if err != nil && !strings.Contains(err.Error(), "relative path") { + t.Errorf("error should mention relative path, got: %v", err) + } +} + +func TestLocalFileIO_Open_NonexistentFile(t *testing.T) { + dir := t.TempDir() + testChdir(t, dir) + + fio := &LocalFileIO{} + _, err := fio.Open("nonexistent.txt") + if err == nil { + t.Error("expected error for nonexistent file") + } +} + +// ── Stat ── + +func TestLocalFileIO_Stat_ValidFile(t *testing.T) { + dir := t.TempDir() + testChdir(t, dir) + + os.WriteFile("stat.txt", []byte("12345"), 0644) + + fio := &LocalFileIO{} + info, err := fio.Stat("stat.txt") + if err != nil { + t.Fatalf("Stat failed: %v", err) + } + if info.Size() != 5 { + t.Errorf("Size() = %d, want 5", info.Size()) + } + if info.IsDir() { + t.Error("expected IsDir() = false") + } +} + +func TestLocalFileIO_Stat_RejectsTraversal(t *testing.T) { + dir := t.TempDir() + testChdir(t, dir) + + fio := &LocalFileIO{} + _, err := fio.Stat("../../etc/passwd") + if err == nil { + t.Error("expected error for path traversal") + } + if err != nil && os.IsNotExist(err) { + t.Error("traversal should not be os.IsNotExist, should be a validation error") + } +} + +func TestLocalFileIO_Stat_NonexistentReturnsIsNotExist(t *testing.T) { + dir := t.TempDir() + testChdir(t, dir) + + fio := &LocalFileIO{} + _, err := fio.Stat("nope.txt") + if err == nil { + t.Error("expected error for nonexistent file") + } + if !os.IsNotExist(err) { + t.Errorf("expected os.IsNotExist, got: %v", err) + } +} + +// ── Save ── + +func TestLocalFileIO_Save_WritesContent(t *testing.T) { + dir := t.TempDir() + testChdir(t, dir) + + fio := &LocalFileIO{} + body := strings.NewReader("saved content") + result, err := fio.Save("output.bin", fileio.SaveOptions{}, body) + if err != nil { + t.Fatalf("Save failed: %v", err) + } + if result.Size() != int64(len("saved content")) { + t.Errorf("Size() = %d, want %d", result.Size(), len("saved content")) + } + + got, _ := os.ReadFile(filepath.Join(dir, "output.bin")) + if string(got) != "saved content" { + t.Errorf("file content = %q, want %q", got, "saved content") + } +} + +func TestLocalFileIO_Save_CreatesParentDirs(t *testing.T) { + dir := t.TempDir() + testChdir(t, dir) + + fio := &LocalFileIO{} + body := strings.NewReader("nested") + _, err := fio.Save(filepath.Join("a", "b", "c.txt"), fileio.SaveOptions{}, body) + if err != nil { + t.Fatalf("Save with nested dir failed: %v", err) + } + + got, _ := os.ReadFile(filepath.Join(dir, "a", "b", "c.txt")) + if string(got) != "nested" { + t.Errorf("file content = %q, want %q", got, "nested") + } +} + +func TestLocalFileIO_Save_RejectsTraversal(t *testing.T) { + dir := t.TempDir() + testChdir(t, dir) + + fio := &LocalFileIO{} + _, err := fio.Save("../../evil.txt", fileio.SaveOptions{}, strings.NewReader("bad")) + if err == nil { + t.Error("expected error for path traversal in Save") + } +} + +func TestLocalFileIO_Save_RejectsAbsolutePath(t *testing.T) { + fio := &LocalFileIO{} + _, err := fio.Save("/tmp/evil.txt", fileio.SaveOptions{}, strings.NewReader("bad")) + if err == nil { + t.Error("expected error for absolute path in Save") + } +} + +// ── ResolvePath ── + +func TestLocalFileIO_ResolvePath_ReturnsAbsolute(t *testing.T) { + dir := t.TempDir() + testChdir(t, dir) + + fio := &LocalFileIO{} + resolved, err := fio.ResolvePath("file.txt") + if err != nil { + t.Fatalf("ResolvePath failed: %v", err) + } + if !filepath.IsAbs(resolved) { + t.Errorf("expected absolute path, got %q", resolved) + } + if filepath.Base(resolved) != "file.txt" { + t.Errorf("expected base name file.txt, got %q", filepath.Base(resolved)) + } +} + +func TestLocalFileIO_ResolvePath_RejectsTraversal(t *testing.T) { + dir := t.TempDir() + testChdir(t, dir) + + fio := &LocalFileIO{} + _, err := fio.ResolvePath("../../etc/passwd") + if err == nil { + t.Error("expected error for path traversal in ResolvePath") + } +} + +func TestLocalFileIO_ResolvePath_RejectsAbsolute(t *testing.T) { + fio := &LocalFileIO{} + _, err := fio.ResolvePath("/etc/passwd") + if err == nil { + t.Error("expected error for absolute path in ResolvePath") + } +} + +// ── Error message consistency ── + +func TestLocalFileIO_ErrorMessages_ContainCorrectFlagName(t *testing.T) { + dir := t.TempDir() + testChdir(t, dir) + + fio := &LocalFileIO{} + + // Open/Stat use SafeInputPath → errors should mention "--file" + _, err := fio.Open("/absolute/path") + if err == nil || !strings.Contains(err.Error(), "--file") { + t.Errorf("Open absolute path error should mention --file, got: %v", err) + } + + _, err = fio.Stat("/absolute/path") + if err == nil || !strings.Contains(err.Error(), "--file") { + t.Errorf("Stat absolute path error should mention --file, got: %v", err) + } + + // Save/ResolvePath use SafeOutputPath → errors should mention "--output" + _, err = fio.Save("/absolute/path", fileio.SaveOptions{}, strings.NewReader("")) + if err == nil || !strings.Contains(err.Error(), "--output") { + t.Errorf("Save absolute path error should mention --output, got: %v", err) + } + + _, err = fio.ResolvePath("/absolute/path") + if err == nil || !strings.Contains(err.Error(), "--output") { + t.Errorf("ResolvePath absolute path error should mention --output, got: %v", err) + } +} + +// ── Control character / Unicode rejection ── + +func TestLocalFileIO_RejectsControlCharsInPath(t *testing.T) { + dir := t.TempDir() + testChdir(t, dir) + + fio := &LocalFileIO{} + paths := []string{ + "file\x00name.txt", // null byte + "file\x1fname.txt", // control char + "file\u200Bname.txt", // zero-width space + "file\u202Ename.txt", // bidi override + } + + for _, p := range paths { + if _, err := fio.Open(p); err == nil { + t.Errorf("Open(%q) should reject control/dangerous chars", p) + } + if _, err := fio.Save(p, fileio.SaveOptions{}, strings.NewReader("")); err == nil { + t.Errorf("Save(%q) should reject control/dangerous chars", p) + } + } +} diff --git a/internal/vfs/localfileio/path_test.go b/internal/vfs/localfileio/path_test.go new file mode 100644 index 00000000..946d1be1 --- /dev/null +++ b/internal/vfs/localfileio/path_test.go @@ -0,0 +1,245 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package localfileio + +import ( + "os" + "path/filepath" + "strings" + "testing" +) + +func TestSafeOutputPath_RejectsPathTraversalAndDangerousInput(t *testing.T) { + for _, tt := range []struct { + name string + input string + wantErr bool + }{ + // ── GIVEN: normal relative paths → THEN: allowed ── + {"normal file", "report.xlsx", false}, + {"subdir file", "output/report.xlsx", false}, + {"current dir explicit", "./file.txt", false}, + {"nested subdir", "a/b/c/file.txt", false}, + {"dot in name", "my.report.v2.xlsx", false}, + {"space in name", "my file.txt", false}, + {"unicode normal", "报告.xlsx", false}, + {"dot-dot resolves to cwd", "subdir/..", false}, + + // ── GIVEN: path traversal via .. → THEN: rejected ── + {"dot-dot escape", "../../.ssh/authorized_keys", true}, + {"dot-dot mid path", "subdir/../../etc/passwd", true}, + {"triple dot-dot", "../../../etc/shadow", true}, + + // ── GIVEN: absolute paths → THEN: rejected ── + {"absolute path unix", "/etc/passwd", true}, + {"absolute path root", "/tmp/evil", true}, + + // ── GIVEN: control characters in path → THEN: rejected ── + {"null byte", "file\x00.txt", true}, + {"carriage return", "file\r.txt", true}, + {"bell char", "file\x07.txt", true}, + + // ── GIVEN: dangerous Unicode in path → THEN: rejected ── + {"bidi RLO", "file\u202Ename.txt", true}, + {"zero width space", "file\u200Bname.txt", true}, + {"BOM char", "file\uFEFFname.txt", true}, + {"line separator", "file\u2028name.txt", true}, + {"bidi LRI", "file\u2066name.txt", true}, + + // ── GIVEN: looks dangerous but is actually safe → THEN: allowed ── + {"literal percent 2e", "%2e%2e/etc/passwd", false}, + {"tilde path", "~/file.txt", false}, + } { + t.Run(tt.name, func(t *testing.T) { + // WHEN: SafeOutputPath validates the path + _, err := SafeOutputPath(tt.input) + + // THEN: error matches expectation + if (err != nil) != tt.wantErr { + t.Errorf("SafeOutputPath(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr) + } + }) + } +} + +func TestSafeOutputPath_ReturnsCanonicalAbsolutePath(t *testing.T) { + // GIVEN: a clean temp directory as CWD + dir := t.TempDir() + dir, _ = filepath.EvalSymlinks(dir) + origDir, _ := os.Getwd() + defer os.Chdir(origDir) + os.Chdir(dir) + + // WHEN: SafeOutputPath validates a relative path + got, err := SafeOutputPath("output/file.txt") + + // THEN: returns the canonical absolute path for subsequent I/O + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + want := filepath.Join(dir, "output", "file.txt") + if got != want { + t.Errorf("got %q, want %q", got, want) + } +} + +func TestSafeOutputPath_RejectsSymlinkEscapingCWD(t *testing.T) { + // GIVEN: a symlink in CWD pointing to /etc (outside CWD) + dir := t.TempDir() + dir, _ = filepath.EvalSymlinks(dir) + origDir, _ := os.Getwd() + defer os.Chdir(origDir) + os.Chdir(dir) + os.Symlink("/etc", filepath.Join(dir, "link-to-etc")) + + // WHEN: SafeOutputPath validates a path through the symlink + _, err := SafeOutputPath("link-to-etc/passwd") + + // THEN: rejected because the resolved path is outside CWD + if err == nil { + t.Error("expected error for symlink escaping CWD, got nil") + } +} + +func TestSafeOutputPath_AllowsSymlinkWithinCWD(t *testing.T) { + // GIVEN: a symlink in CWD pointing to a subdirectory within CWD + dir := t.TempDir() + dir, _ = filepath.EvalSymlinks(dir) + origDir, _ := os.Getwd() + defer os.Chdir(origDir) + os.Chdir(dir) + os.MkdirAll(filepath.Join(dir, "real"), 0755) + os.Symlink(filepath.Join(dir, "real"), filepath.Join(dir, "link")) + + // WHEN: SafeOutputPath validates a path through the internal symlink + got, err := SafeOutputPath("link/file.txt") + + // THEN: allowed, resolved to the real path within CWD + if err != nil { + t.Fatalf("symlink within CWD should be allowed: %v", err) + } + want := filepath.Join(dir, "real", "file.txt") + if got != want { + t.Errorf("got %q, want %q", got, want) + } +} + +func TestSafeOutputPath_ResolvesAncestorSymlinkWhenParentMissing(t *testing.T) { + // GIVEN: CWD contains a symlink "escape" → /etc, and the target path + // goes through "escape/sub/file.txt" where "sub" does not exist. + // The old code failed to resolve the symlink because the immediate + // parent ("escape/sub") didn't exist, leaving resolved un-anchored. + dir := t.TempDir() + dir, _ = filepath.EvalSymlinks(dir) + origDir, _ := os.Getwd() + defer os.Chdir(origDir) + os.Chdir(dir) + os.Symlink("/etc", filepath.Join(dir, "escape")) + + // WHEN: SafeOutputPath validates a path through the symlink with missing intermediate dirs + _, err := SafeOutputPath("escape/nonexistent/file.txt") + + // THEN: rejected — the resolved path is under /etc, outside CWD + if err == nil { + t.Error("expected error for symlink escaping CWD via non-existent parent, got nil") + } +} + +func TestSafeOutputPath_DeepNonExistentPathStaysInCWD(t *testing.T) { + // GIVEN: a deeply nested non-existent path with no symlinks + dir := t.TempDir() + dir, _ = filepath.EvalSymlinks(dir) + origDir, _ := os.Getwd() + defer os.Chdir(origDir) + os.Chdir(dir) + + // WHEN: SafeOutputPath validates "a/b/c/d/file.txt" (none of a/b/c/d exist) + got, err := SafeOutputPath("a/b/c/d/file.txt") + + // THEN: allowed, resolved to canonical path under CWD + if err != nil { + t.Fatalf("deep non-existent path within CWD should be allowed: %v", err) + } + want := filepath.Join(dir, "a", "b", "c", "d", "file.txt") + if got != want { + t.Errorf("got %q, want %q", got, want) + } +} + +func TestSafeUploadPath_AllowsTempFileAbsolutePath(t *testing.T) { + // GIVEN: a real temp file (absolute path under os.TempDir()) + f, err := os.CreateTemp("", "upload-test-*.bin") + if err != nil { + t.Fatalf("CreateTemp: %v", err) + } + tmpPath := f.Name() + f.Close() + t.Cleanup(func() { os.Remove(tmpPath) }) + + // WHEN: SafeUploadPath validates the absolute temp path + _, err = SafeInputPath(tmpPath) + + // THEN: absolute paths are rejected even in temp dir + if err == nil { + t.Fatal("expected error for absolute temp path, got nil") + } +} + +func TestSafeUploadPath_RejectsNonTempAbsolutePath(t *testing.T) { + // GIVEN: an absolute path outside the temp directory + // WHEN / THEN: SafeUploadPath rejects it + _, err := SafeInputPath("/etc/passwd") + if err == nil { + t.Error("expected error for absolute non-temp path, got nil") + } +} + +func TestSafeUploadPath_AcceptsRelativePath(t *testing.T) { + // GIVEN: a clean temp CWD with a real file + dir := t.TempDir() + dir, _ = filepath.EvalSymlinks(dir) + orig, _ := os.Getwd() + defer os.Chdir(orig) + os.Chdir(dir) + + os.WriteFile(filepath.Join(dir, "upload.bin"), []byte("data"), 0600) + + // WHEN: SafeUploadPath validates a relative path to an existing file + got, err := SafeInputPath("upload.bin") + + // THEN: accepted and returned as absolute canonical path + if err != nil { + t.Fatalf("SafeUploadPath(relative) error = %v", err) + } + want := filepath.Join(dir, "upload.bin") + if got != want { + t.Errorf("SafeUploadPath(relative) = %q, want %q", got, want) + } +} + +func Test_SafeInputPath_ErrorMessageContainsCorrectFlagName(t *testing.T) { + // GIVEN: an absolute path + + // WHEN: SafeInputPath rejects it + _, err := SafeInputPath("/etc/passwd") + + // THEN: error message mentions --file (not --output) + if err == nil { + t.Fatal("expected error for absolute path") + } + if !strings.Contains(err.Error(), "--file") { + t.Errorf("error should mention --file, got: %s", err.Error()) + } + + // WHEN: SafeOutputPath rejects it + _, err = SafeOutputPath("/etc/passwd") + + // THEN: error message mentions --output (not --file) + if err == nil { + t.Fatal("expected error for absolute path") + } + if !strings.Contains(err.Error(), "--output") { + t.Errorf("error should mention --output, got: %s", err.Error()) + } +} diff --git a/shortcuts/common/runner.go b/shortcuts/common/runner.go index 6eef2b45..b431aaca 100644 --- a/shortcuts/common/runner.go +++ b/shortcuts/common/runner.go @@ -25,8 +25,6 @@ import ( "github.com/larksuite/cli/internal/core" "github.com/larksuite/cli/internal/credential" "github.com/larksuite/cli/internal/output" - "github.com/larksuite/cli/internal/validate" - "github.com/larksuite/cli/internal/vfs" "github.com/spf13/cobra" ) @@ -335,6 +333,35 @@ func (ctx *RuntimeContext) ResolveSavePath(path string) (string, error) { return resolved, nil } +// WrapSaveError matches a FileIO.Save error against known categories and wraps +// it with the caller-provided message prefix, preserving backward-compatible +// error text per shortcut. +func WrapSaveError(err error, pathMsg, mkdirMsg, writeMsg string) error { + if err == nil { + return nil + } + switch { + case errors.Is(err, fileio.ErrPathValidation): + return fmt.Errorf("%s: %w", pathMsg, err) + case errors.Is(err, fileio.ErrMkdir): + return fmt.Errorf("%s: %w", mkdirMsg, err) + default: + return fmt.Errorf("%s: %w", writeMsg, err) + } +} + +// WrapOpenError matches a FileIO.Open/Stat error and wraps it with the +// caller-provided message prefix. +func WrapOpenError(err error, pathMsg, readMsg string) error { + if err == nil { + return nil + } + if errors.Is(err, fileio.ErrPathValidation) { + return fmt.Errorf("%s: %w", pathMsg, err) + } + return fmt.Errorf("%s: %w", readMsg, err) +} + // ValidatePath checks that path is a valid relative input path within the // working directory by delegating to FileIO.Stat. Returns nil if the path is // valid or does not exist yet; returns an error only for illegal paths @@ -633,11 +660,15 @@ func resolveInputFlags(rctx *RuntimeContext, flags []Flag) error { if path == "" { return FlagErrorf("--%s: file path cannot be empty after @", fl.Name) } - safePath, err := validate.SafeInputPath(path) + f, err := rctx.FileIO().Open(path) if err != nil { - return FlagErrorf("--%s: invalid file path %q: %v", fl.Name, path, err) + if errors.Is(err, fileio.ErrPathValidation) { + return FlagErrorf("--%s: invalid file path %q: %v", fl.Name, path, err) + } + return FlagErrorf("--%s: cannot read file %q: %v", fl.Name, path, err) } - data, err := vfs.ReadFile(safePath) + data, err := io.ReadAll(f) + f.Close() if err != nil { return FlagErrorf("--%s: cannot read file %q: %v", fl.Name, path, err) } diff --git a/shortcuts/common/runner_input_test.go b/shortcuts/common/runner_input_test.go index 25aa806b..058e2870 100644 --- a/shortcuts/common/runner_input_test.go +++ b/shortcuts/common/runner_input_test.go @@ -10,6 +10,7 @@ import ( "testing" "github.com/larksuite/cli/internal/cmdutil" + _ "github.com/larksuite/cli/internal/vfs/localfileio" "github.com/spf13/cobra" ) diff --git a/shortcuts/common/runner_jq_test.go b/shortcuts/common/runner_jq_test.go index cce144f4..17af83dc 100644 --- a/shortcuts/common/runner_jq_test.go +++ b/shortcuts/common/runner_jq_test.go @@ -13,6 +13,7 @@ import ( lark "github.com/larksuite/oapi-sdk-go/v3" "github.com/spf13/cobra" + "github.com/larksuite/cli/extension/fileio" "github.com/larksuite/cli/internal/cmdutil" "github.com/larksuite/cli/internal/core" "github.com/larksuite/cli/internal/output" @@ -102,6 +103,48 @@ func TestRuntimeContext_Out_WithJq_InvalidExpr_WritesStderr(t *testing.T) { } } +type testResolvedFileIO struct{} + +func (testResolvedFileIO) Open(string) (fileio.File, error) { return nil, nil } +func (testResolvedFileIO) Stat(string) (fileio.FileInfo, error) { return nil, nil } +func (testResolvedFileIO) ResolvePath(path string) (string, error) { return path, nil } +func (testResolvedFileIO) Save(string, fileio.SaveOptions, io.Reader) (fileio.SaveResult, error) { + return nil, nil +} + +type capturingFileIOProvider struct { + gotCtx context.Context + fileIO fileio.FileIO +} + +func (p *capturingFileIOProvider) Name() string { return "capture" } + +func (p *capturingFileIOProvider) ResolveFileIO(ctx context.Context) fileio.FileIO { + p.gotCtx = ctx + return p.fileIO +} + +func TestRuntimeContext_FileIO_UsesExecutionContext(t *testing.T) { + execCtx := context.WithValue(context.Background(), "key", "value") + resolved := testResolvedFileIO{} + provider := &capturingFileIOProvider{fileIO: resolved} + + rctx := &RuntimeContext{ + ctx: execCtx, + Factory: &cmdutil.Factory{ + FileIOProvider: provider, + }, + } + + got := rctx.FileIO() + if got != resolved { + t.Fatalf("FileIO() returned %T, want %T", got, resolved) + } + if provider.gotCtx != execCtx { + t.Fatal("ResolveFileIO() did not receive the runtime execution context") + } +} + func newTestShortcutCmd(s *Shortcut) *cobra.Command { cmd := &cobra.Command{Use: "test-shortcut"} cmd.SetContext(context.Background()) @@ -119,7 +162,8 @@ func newTestFactory() *cmdutil.Factory { LarkClient: func() (*lark.Client, error) { return lark.NewClient("test", "test"), nil }, - IOStreams: &cmdutil.IOStreams{Out: &bytes.Buffer{}, ErrOut: &bytes.Buffer{}}, + IOStreams: &cmdutil.IOStreams{Out: &bytes.Buffer{}, ErrOut: &bytes.Buffer{}}, + FileIOProvider: fileio.GetProvider(), } } diff --git a/shortcuts/im/coverage_additional_test.go b/shortcuts/im/coverage_additional_test.go index ac7f82dc..3500021a 100644 --- a/shortcuts/im/coverage_additional_test.go +++ b/shortcuts/im/coverage_additional_test.go @@ -16,6 +16,8 @@ import ( larkcore "github.com/larksuite/oapi-sdk-go/v3/core" "github.com/spf13/cobra" + + "github.com/larksuite/cli/internal/cmdutil" ) func TestSanitizeURLForDisplay(t *testing.T) { @@ -404,39 +406,36 @@ func TestBuildSearchChatBodyAdditionalBranches(t *testing.T) { func TestParseMediaDurationSuccess(t *testing.T) { t.Run("mp4", func(t *testing.T) { - f, err := os.CreateTemp("", "im-duration-*.mp4") - if err != nil { - t.Fatalf("CreateTemp() error = %v", err) + cmdutil.TestChdir(t, t.TempDir()) + fname := "im-duration-test.mp4" + if err := os.WriteFile(fname, wrapInMoov(buildMvhdBox(0, 1000, 5000)), 0644); err != nil { + t.Fatalf("WriteFile() error = %v", err) } - defer os.Remove(f.Name()) - defer f.Close() - - if _, err := f.Write(wrapInMoov(buildMvhdBox(0, 1000, 5000))); err != nil { - t.Fatalf("Write() error = %v", err) - } - if got := parseMediaDuration(f.Name(), "mp4"); got != "5000" { + rt := newBotShortcutRuntime(t, shortcutRoundTripFunc(func(req *http.Request) (*http.Response, error) { + return nil, fmt.Errorf("unexpected") + })) + if got := parseMediaDuration(rt, fname, "mp4"); got != "5000" { t.Fatalf("parseMediaDuration(mp4) = %q, want %q", got, "5000") } }) t.Run("opus", func(t *testing.T) { - f, err := os.CreateTemp("", "im-duration-*.ogg") - if err != nil { - t.Fatalf("CreateTemp() error = %v", err) - } - defer os.Remove(f.Name()) - defer f.Close() - + cmdutil.TestChdir(t, t.TempDir()) page := make([]byte, 27) copy(page[0:4], "OggS") page[5] = 4 page[6] = 0x00 page[7] = 0x53 page[8] = 0x07 - if _, err := f.Write(page); err != nil { - t.Fatalf("Write() error = %v", err) + + fname := "im-duration-test.ogg" + if err := os.WriteFile(fname, page, 0644); err != nil { + t.Fatalf("WriteFile() error = %v", err) } - if got := parseMediaDuration(f.Name(), "opus"); got != "10000" { + rt := newBotShortcutRuntime(t, shortcutRoundTripFunc(func(req *http.Request) (*http.Response, error) { + return nil, fmt.Errorf("unexpected") + })) + if got := parseMediaDuration(rt, fname, "opus"); got != "10000" { t.Fatalf("parseMediaDuration(opus) = %q, want %q", got, "10000") } }) diff --git a/shortcuts/im/helpers.go b/shortcuts/im/helpers.go index 57f354a1..50fed673 100644 --- a/shortcuts/im/helpers.go +++ b/shortcuts/im/helpers.go @@ -13,16 +13,15 @@ import ( "math" "net/http" "net/url" - "os" "path" "path/filepath" "regexp" "strconv" "strings" + "github.com/larksuite/cli/extension/fileio" "github.com/larksuite/cli/internal/output" "github.com/larksuite/cli/internal/validate" - "github.com/larksuite/cli/internal/vfs" "github.com/larksuite/cli/shortcuts/common" larkcore "github.com/larksuite/oapi-sdk-go/v3/core" "github.com/spf13/cobra" @@ -327,21 +326,16 @@ func resolveURLMedia(ctx context.Context, runtime *common.RuntimeContext, s medi func resolveLocalMedia(ctx context.Context, runtime *common.RuntimeContext, s mediaSpec) (string, error) { fmt.Fprintf(runtime.IO().ErrOut, "uploading %s: %s\n", s.mediaType, filepath.Base(s.value)) - safePath, err := validate.SafeInputPath(s.value) - if err != nil { - return "", err - } - if s.kind == mediaKindImage { - return uploadImageToIM(ctx, runtime, safePath, "message") + return uploadImageToIM(ctx, runtime, s.value, "message") } - ft := detectIMFileType(safePath) + ft := detectIMFileType(s.value) dur := "" if s.withDuration { - dur = parseMediaDuration(safePath, ft) + dur = parseMediaDuration(runtime, s.value, ft) } - return uploadFileToIM(ctx, runtime, safePath, ft, dur) + return uploadFileToIM(ctx, runtime, s.value, ft, dur) } // resolveVideoContent handles the video case which needs both a file_key and @@ -556,18 +550,16 @@ func findMP4Box(data []byte, start, end int, boxType string) (int, int) { // for audio/video uploads. Only reads the minimal portion of the file needed // for parsing (tail for OGG, box headers + moov for MP4). // Returns "" if parsing fails or the file type is not audio/video. -func parseMediaDuration(filePath, fileType string) string { +func parseMediaDuration(runtime *common.RuntimeContext, filePath, fileType string) string { if fileType != "opus" && fileType != "mp4" { return "" } - f, err := vfs.Open(filePath) - if err != nil { + info, err := runtime.FileIO().Stat(filePath) + if err != nil || info.Size() == 0 { return "" } - defer f.Close() - - info, err := f.Stat() - if err != nil || info.Size() == 0 { + f, err := runtime.FileIO().Open(filePath) + if err != nil { return "" } @@ -698,7 +690,7 @@ func readMp4DurationBytes(data []byte) int64 { } // readOggDuration reads the tail of an OGG file (up to 64 KB) and parses duration. -func readOggDuration(f *os.File, fileSize int64) int64 { +func readOggDuration(f fileio.File, fileSize int64) int64 { const maxTail = 65536 readSize := fileSize if readSize > maxTail { @@ -713,7 +705,7 @@ func readOggDuration(f *os.File, fileSize int64) int64 { // readMp4Duration walks top-level MP4 boxes via file seeks to find moov, // then reads only the moov content to locate mvhd and extract the duration. -func readMp4Duration(f *os.File, fileSize int64) int64 { +func readMp4Duration(f fileio.File, fileSize int64) int64 { hdr := make([]byte, 16) var offset int64 for offset+8 <= fileSize { @@ -1005,14 +997,11 @@ const maxImageUploadSize = 5 * 1024 * 1024 // 5MB — Lark API limit for images const maxFileUploadSize = 100 * 1024 * 1024 // 100MB — Lark API limit for files func uploadImageToIM(ctx context.Context, runtime *common.RuntimeContext, filePath, imageType string) (string, error) { - // filePath is already validated by the caller (resolveLocalMedia). - safePath := filePath - - if info, err := vfs.Stat(safePath); err == nil && info.Size() > maxImageUploadSize { + if info, err := runtime.FileIO().Stat(filePath); err == nil && info.Size() > maxImageUploadSize { return "", fmt.Errorf("image size %s exceeds limit (max 5MB)", common.FormatSize(info.Size())) } - f, err := vfs.Open(safePath) + f, err := runtime.FileIO().Open(filePath) if err != nil { return "", err } @@ -1045,14 +1034,11 @@ func uploadImageToIM(ctx context.Context, runtime *common.RuntimeContext, filePa } func uploadFileToIM(ctx context.Context, runtime *common.RuntimeContext, filePath, fileType, duration string) (string, error) { - // filePath is already validated by the caller (resolveLocalMedia). - safePath := filePath - - if info, err := vfs.Stat(safePath); err == nil && info.Size() > maxFileUploadSize { + if info, err := runtime.FileIO().Stat(filePath); err == nil && info.Size() > maxFileUploadSize { return "", fmt.Errorf("file size %s exceeds limit (max 100MB)", common.FormatSize(info.Size())) } - f, err := vfs.Open(safePath) + f, err := runtime.FileIO().Open(filePath) if err != nil { return "", err } @@ -1060,7 +1046,7 @@ func uploadFileToIM(ctx context.Context, runtime *common.RuntimeContext, filePat fd := larkcore.NewFormdata() fd.AddField("file_type", fileType) - fd.AddField("file_name", filepath.Base(safePath)) + fd.AddField("file_name", filepath.Base(filePath)) if duration != "" { fd.AddField("duration", duration) } diff --git a/shortcuts/im/helpers_network_test.go b/shortcuts/im/helpers_network_test.go index 9e914fdd..d7b6c374 100644 --- a/shortcuts/im/helpers_network_test.go +++ b/shortcuts/im/helpers_network_test.go @@ -20,6 +20,7 @@ import ( lark "github.com/larksuite/oapi-sdk-go/v3" larkcore "github.com/larksuite/oapi-sdk-go/v3/core" + "github.com/larksuite/cli/extension/fileio" "github.com/larksuite/cli/internal/cmdutil" "github.com/larksuite/cli/internal/core" "github.com/larksuite/cli/internal/credential" @@ -52,9 +53,10 @@ func shortcutRawResponse(status int, body []byte, headers http.Header) *http.Res headers = make(http.Header) } return &http.Response{ - StatusCode: status, - Header: headers, - Body: io.NopCloser(bytes.NewReader(body)), + StatusCode: status, + Header: headers, + Body: io.NopCloser(bytes.NewReader(body)), + ContentLength: int64(len(body)), } } @@ -88,10 +90,11 @@ func newBotShortcutRuntime(t *testing.T, rt http.RoundTripper) *common.RuntimeCo runtime := &common.RuntimeContext{ Config: cfg, Factory: &cmdutil.Factory{ - Config: func() (*core.CliConfig, error) { return cfg, nil }, - HttpClient: func() (*http.Client, error) { return httpClient, nil }, - LarkClient: func() (*lark.Client, error) { return sdk, nil }, - Credential: testCred, + Config: func() (*core.CliConfig, error) { return cfg, nil }, + HttpClient: func() (*http.Client, error) { return httpClient, nil }, + LarkClient: func() (*lark.Client, error) { return sdk, nil }, + Credential: testCred, + FileIOProvider: fileio.GetProvider(), IOStreams: &cmdutil.IOStreams{ Out: &bytes.Buffer{}, ErrOut: &bytes.Buffer{}, @@ -241,7 +244,9 @@ func TestDownloadIMResourceToPathSuccess(t *testing.T) { } })) - target := filepath.Join(t.TempDir(), "nested", "resource.bin") + cmdutil.TestChdir(t, t.TempDir()) + + target := filepath.Join("nested", "resource.bin") _, size, err := downloadIMResourceToPath(context.Background(), runtime, "om_123", "file_123", "file", target) if err != nil { t.Fatalf("downloadIMResourceToPath() error = %v", err) @@ -280,7 +285,9 @@ func TestDownloadIMResourceToPathHTTPErrorBody(t *testing.T) { } })) - _, _, err := downloadIMResourceToPath(context.Background(), runtime, "om_403", "file_403", "file", filepath.Join(t.TempDir(), "out.bin")) + cmdutil.TestChdir(t, t.TempDir()) + + _, _, err := downloadIMResourceToPath(context.Background(), runtime, "om_403", "file_403", "file", "out.bin") if err == nil || !strings.Contains(err.Error(), "HTTP 403: denied") { t.Fatalf("downloadIMResourceToPath() error = %v", err) } @@ -305,28 +312,14 @@ func TestUploadImageToIMSuccess(t *testing.T) { } })) - wd, err := os.Getwd() - if err != nil { - t.Fatalf("Getwd() error = %v", err) - } - tmpDir := t.TempDir() - if err := os.Chdir(tmpDir); err != nil { - t.Fatalf("Chdir() error = %v", err) - } - t.Cleanup(func() { - _ = os.Chdir(wd) - }) + cmdutil.TestChdir(t, t.TempDir()) path := "demo.png" if err := os.WriteFile(path, []byte("png"), 0600); err != nil { t.Fatalf("WriteFile() error = %v", err) } - absPath, err := filepath.Abs(path) - if err != nil { - t.Fatalf("Abs() error = %v", err) - } - got, err := uploadImageToIM(context.Background(), runtime, absPath, "message") + got, err := uploadImageToIM(context.Background(), runtime, path, "message") if err != nil { t.Fatalf("uploadImageToIM() error = %v", err) } @@ -357,28 +350,14 @@ func TestUploadFileToIMSuccess(t *testing.T) { } })) - wd, err := os.Getwd() - if err != nil { - t.Fatalf("Getwd() error = %v", err) - } - tmpDir := t.TempDir() - if err := os.Chdir(tmpDir); err != nil { - t.Fatalf("Chdir() error = %v", err) - } - t.Cleanup(func() { - _ = os.Chdir(wd) - }) + cmdutil.TestChdir(t, t.TempDir()) path := "demo.txt" if err := os.WriteFile(path, []byte("demo"), 0600); err != nil { t.Fatalf("WriteFile() error = %v", err) } - absPath, err := filepath.Abs(path) - if err != nil { - t.Fatalf("Abs() error = %v", err) - } - got, err := uploadFileToIM(context.Background(), runtime, absPath, "stream", "1200") + got, err := uploadFileToIM(context.Background(), runtime, path, "stream", "1200") if err != nil { t.Fatalf("uploadFileToIM() error = %v", err) } @@ -394,7 +373,8 @@ func TestUploadFileToIMSuccess(t *testing.T) { } func TestUploadImageToIMSizeLimit(t *testing.T) { - path := filepath.Join(t.TempDir(), "too-large.png") + cmdutil.TestChdir(t, t.TempDir()) + path := "too-large.png" f, err := os.Create(path) if err != nil { t.Fatalf("Create() error = %v", err) @@ -404,14 +384,18 @@ func TestUploadImageToIMSizeLimit(t *testing.T) { } f.Close() - _, err = uploadImageToIM(context.Background(), nil, path, "message") + rt := newBotShortcutRuntime(t, shortcutRoundTripFunc(func(req *http.Request) (*http.Response, error) { + return nil, fmt.Errorf("unexpected") + })) + _, err = uploadImageToIM(context.Background(), rt, path, "message") if err == nil || !strings.Contains(err.Error(), "exceeds limit") { t.Fatalf("uploadImageToIM() error = %v", err) } } func TestUploadFileToIMSizeLimit(t *testing.T) { - path := filepath.Join(t.TempDir(), "too-large.bin") + cmdutil.TestChdir(t, t.TempDir()) + path := "too-large.bin" f, err := os.Create(path) if err != nil { t.Fatalf("Create() error = %v", err) @@ -421,7 +405,10 @@ func TestUploadFileToIMSizeLimit(t *testing.T) { } f.Close() - _, err = uploadFileToIM(context.Background(), nil, path, "stream", "") + rt := newBotShortcutRuntime(t, shortcutRoundTripFunc(func(req *http.Request) (*http.Response, error) { + return nil, fmt.Errorf("unexpected") + })) + _, err = uploadFileToIM(context.Background(), rt, path, "stream", "") if err == nil || !strings.Contains(err.Error(), "exceeds limit") { t.Fatalf("uploadFileToIM() error = %v", err) } @@ -430,6 +417,7 @@ func TestUploadFileToIMSizeLimit(t *testing.T) { func TestResolveMediaContentWrapsUploadError(t *testing.T) { runtime := &common.RuntimeContext{ Factory: &cmdutil.Factory{ + FileIOProvider: fileio.GetProvider(), IOStreams: &cmdutil.IOStreams{ Out: &bytes.Buffer{}, ErrOut: &bytes.Buffer{}, @@ -437,7 +425,9 @@ func TestResolveMediaContentWrapsUploadError(t *testing.T) { }, } - missing := filepath.Join(t.TempDir(), "missing.png") + cmdutil.TestChdir(t, t.TempDir()) + + missing := "missing.png" _, _, err := resolveMediaContent(context.Background(), runtime, "", missing, "", "", "", "") if err == nil || !strings.Contains(err.Error(), "image upload failed") { t.Fatalf("resolveMediaContent() error = %v", err) @@ -457,15 +447,7 @@ func TestResolveLocalMediaImage(t *testing.T) { return nil, fmt.Errorf("unexpected request: %s", req.URL.String()) })) - wd, err := os.Getwd() - if err != nil { - t.Fatalf("Getwd() error = %v", err) - } - tmpDir := t.TempDir() - if err := os.Chdir(tmpDir); err != nil { - t.Fatalf("Chdir() error = %v", err) - } - t.Cleanup(func() { _ = os.Chdir(wd) }) + cmdutil.TestChdir(t, t.TempDir()) if err := os.WriteFile("test.png", []byte("png-data"), 0600); err != nil { t.Fatalf("WriteFile() error = %v", err) @@ -496,15 +478,7 @@ func TestResolveLocalMediaFile(t *testing.T) { return nil, fmt.Errorf("unexpected request: %s", req.URL.String()) })) - wd, err := os.Getwd() - if err != nil { - t.Fatalf("Getwd() error = %v", err) - } - tmpDir := t.TempDir() - if err := os.Chdir(tmpDir); err != nil { - t.Fatalf("Chdir() error = %v", err) - } - t.Cleanup(func() { _ = os.Chdir(wd) }) + cmdutil.TestChdir(t, t.TempDir()) if err := os.WriteFile("test.txt", []byte("file-data"), 0600); err != nil { t.Fatalf("WriteFile() error = %v", err) diff --git a/shortcuts/im/helpers_test.go b/shortcuts/im/helpers_test.go index f1813e9e..c64785e5 100644 --- a/shortcuts/im/helpers_test.go +++ b/shortcuts/im/helpers_test.go @@ -7,6 +7,7 @@ import ( "context" "encoding/binary" "errors" + "fmt" "net/http" "reflect" "strings" @@ -263,10 +264,13 @@ func TestParseMp4Duration(t *testing.T) { } func TestParseMediaDuration(t *testing.T) { - if got := parseMediaDuration("test.pdf", "pdf"); got != "" { + rt := newBotShortcutRuntime(t, shortcutRoundTripFunc(func(req *http.Request) (*http.Response, error) { + return nil, fmt.Errorf("unexpected") + })) + if got := parseMediaDuration(rt, "test.pdf", "pdf"); got != "" { t.Fatalf("parseMediaDuration(pdf) = %q, want empty", got) } - if got := parseMediaDuration("nonexistent.opus", "opus"); got != "" { + if got := parseMediaDuration(rt, "nonexistent.opus", "opus"); got != "" { t.Fatalf("parseMediaDuration(missing) = %q, want empty", got) } } diff --git a/shortcuts/im/im_messages_reply.go b/shortcuts/im/im_messages_reply.go index f7b73cc0..806ee739 100644 --- a/shortcuts/im/im_messages_reply.go +++ b/shortcuts/im/im_messages_reply.go @@ -91,29 +91,13 @@ var ImMessagesReply = common.Shortcut{ videoCoverKey := runtime.Str("video-cover") audioKey := runtime.Str("audio") - if !isMediaKey(imageKey) { - if _, err := validate.SafeLocalFlagPath("--image", imageKey); err != nil { - return output.ErrValidation("%v", err) - } - } - if !isMediaKey(fileKey) { - if _, err := validate.SafeLocalFlagPath("--file", fileKey); err != nil { - return output.ErrValidation("%v", err) - } - } - if !isMediaKey(videoKey) { - if _, err := validate.SafeLocalFlagPath("--video", videoKey); err != nil { - return output.ErrValidation("%v", err) - } - } - if !isMediaKey(videoCoverKey) { - if _, err := validate.SafeLocalFlagPath("--video-cover", videoCoverKey); err != nil { - return output.ErrValidation("%v", err) - } - } - if !isMediaKey(audioKey) { - if _, err := validate.SafeLocalFlagPath("--audio", audioKey); err != nil { - return output.ErrValidation("%v", err) + fio := runtime.FileIO() + for _, mf := range []struct{ flag, val string }{ + {"--image", imageKey}, {"--file", fileKey}, {"--video", videoKey}, + {"--video-cover", videoCoverKey}, {"--audio", audioKey}, + } { + if err := validateMediaFlagPath(fio, mf.flag, mf.val); err != nil { + return err } } @@ -149,29 +133,13 @@ var ImMessagesReply = common.Shortcut{ audioVal := runtime.Str("audio") replyInThread := runtime.Bool("reply-in-thread") idempotencyKey := runtime.Str("idempotency-key") - if !isMediaKey(imageVal) { - if _, err := validate.SafeLocalFlagPath("--image", imageVal); err != nil { - return output.ErrValidation("%v", err) - } - } - if !isMediaKey(fileVal) { - if _, err := validate.SafeLocalFlagPath("--file", fileVal); err != nil { - return output.ErrValidation("%v", err) - } - } - if !isMediaKey(videoVal) { - if _, err := validate.SafeLocalFlagPath("--video", videoVal); err != nil { - return output.ErrValidation("%v", err) - } - } - if !isMediaKey(videoCoverVal) { - if _, err := validate.SafeLocalFlagPath("--video-cover", videoCoverVal); err != nil { - return output.ErrValidation("%v", err) - } - } - if !isMediaKey(audioVal) { - if _, err := validate.SafeLocalFlagPath("--audio", audioVal); err != nil { - return output.ErrValidation("%v", err) + fio := runtime.FileIO() + for _, mf := range []struct{ flag, val string }{ + {"--image", imageVal}, {"--file", fileVal}, {"--video", videoVal}, + {"--video-cover", videoCoverVal}, {"--audio", audioVal}, + } { + if err := validateMediaFlagPath(fio, mf.flag, mf.val); err != nil { + return err } } diff --git a/shortcuts/im/im_messages_resources_download.go b/shortcuts/im/im_messages_resources_download.go index 0f29392a..819af6df 100644 --- a/shortcuts/im/im_messages_resources_download.go +++ b/shortcuts/im/im_messages_resources_download.go @@ -6,15 +6,15 @@ package im import ( "context" "fmt" + "io" "net/http" "path/filepath" "strings" "time" + "github.com/larksuite/cli/extension/fileio" "github.com/larksuite/cli/internal/client" "github.com/larksuite/cli/internal/output" - "github.com/larksuite/cli/internal/validate" - "github.com/larksuite/cli/internal/vfs" "github.com/larksuite/cli/shortcuts/common" larkcore "github.com/larksuite/oapi-sdk-go/v3/core" ) @@ -54,7 +54,7 @@ var ImMessagesResourcesDownload = common.Shortcut{ if err != nil { return output.ErrValidation("%s", err) } - if _, err := validate.SafeOutputPath(relPath); err != nil { + if _, err := runtime.ResolveSavePath(relPath); err != nil { return output.ErrValidation("unsafe output path: %s", err) } return nil @@ -67,12 +67,8 @@ var ImMessagesResourcesDownload = common.Shortcut{ if err != nil { return output.ErrValidation("invalid output path: %s", err) } - safePath, err := validate.SafeOutputPath(relPath) - if err != nil { - return output.ErrValidation("unsafe output path: %s", err) - } - finalPath, sizeBytes, err := downloadIMResourceToPath(ctx, runtime, messageId, fileKey, fileType, safePath) + finalPath, sizeBytes, err := downloadIMResourceToPath(ctx, runtime, messageId, fileKey, fileType, relPath) if err != nil { return err } @@ -109,33 +105,33 @@ func normalizeDownloadOutputPath(fileKey, outputPath string) (string, error) { const defaultIMResourceDownloadTimeout = 120 * time.Second var imMimeToExt = map[string]string{ - "image/png": ".png", - "image/jpeg": ".jpg", - "image/gif": ".gif", - "image/webp": ".webp", - "image/svg+xml": ".svg", - "application/pdf": ".pdf", - "video/mp4": ".mp4", - "video/3gpp": ".3gp", - "video/x-msvideo": ".avi", - "audio/mpeg": ".mp3", - "audio/ogg": ".ogg", - "audio/wav": ".wav", - "text/plain": ".txt", - "text/html": ".html", - "text/css": ".css", - "text/csv": ".csv", - "application/zip": ".zip", + "image/png": ".png", + "image/jpeg": ".jpg", + "image/gif": ".gif", + "image/webp": ".webp", + "image/svg+xml": ".svg", + "application/pdf": ".pdf", + "video/mp4": ".mp4", + "video/3gpp": ".3gp", + "video/x-msvideo": ".avi", + "audio/mpeg": ".mp3", + "audio/ogg": ".ogg", + "audio/wav": ".wav", + "text/plain": ".txt", + "text/html": ".html", + "text/css": ".css", + "text/csv": ".csv", + "application/zip": ".zip", "application/x-zip-compressed": ".zip", "application/x-rar-compressed": ".rar", - "application/json": ".json", - "application/xml": ".xml", - "application/octet-stream": ".bin", - "application/msword": ".doc", + "application/json": ".json", + "application/xml": ".xml", + "application/octet-stream": ".bin", + "application/msword": ".doc", "application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx", - "application/vnd.ms-excel": ".xls", - "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx", - "application/vnd.ms-powerpoint": ".ppt", + "application/vnd.ms-excel": ".xls", + "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx", + "application/vnd.ms-powerpoint": ".ppt", "application/vnd.openxmlformats-officedocument.presentationml.presentation": ".pptx", } @@ -156,8 +152,12 @@ func downloadIMResourceToPath(ctx context.Context, runtime *common.RuntimeContex } defer downloadResp.Body.Close() - if err := vfs.MkdirAll(filepath.Dir(safePath), 0700); err != nil { - return "", 0, output.Errorf(output.ExitInternal, "api_error", "cannot create parent directory: %s", err) + if downloadResp.StatusCode >= 400 { + body, _ := io.ReadAll(io.LimitReader(downloadResp.Body, 4096)) + if len(body) > 0 { + return "", 0, output.ErrNetwork("download failed: HTTP %d: %s", downloadResp.StatusCode, strings.TrimSpace(string(body))) + } + return "", 0, output.ErrNetwork("download failed: HTTP %d", downloadResp.StatusCode) } // Auto-detect extension from Content-Type if missing @@ -171,9 +171,19 @@ func downloadIMResourceToPath(ctx context.Context, runtime *common.RuntimeContex } } - sizeBytes, err := validate.AtomicWriteFromReader(finalPath, downloadResp.Body, 0600) + result, err := runtime.FileIO().Save(finalPath, fileio.SaveOptions{ + ContentType: downloadResp.Header.Get("Content-Type"), + ContentLength: downloadResp.ContentLength, + }, downloadResp.Body) if err != nil { - return "", 0, output.Errorf(output.ExitInternal, "api_error", "cannot create file: %s", err) + return "", 0, output.Errorf(output.ExitInternal, "api_error", "%s", + common.WrapSaveError(err, "unsafe output path", "cannot create parent directory", "cannot create file")) + } + savedPath, resolveErr := runtime.ResolveSavePath(finalPath) + if resolveErr != nil { + // Save succeeded — file is on disk. Fall back to the relative path + // rather than returning an error for a successfully written file. + savedPath = finalPath } - return finalPath, sizeBytes, nil + return savedPath, result.Size(), nil } diff --git a/shortcuts/im/im_messages_send.go b/shortcuts/im/im_messages_send.go index 116b7b9b..efaa5485 100644 --- a/shortcuts/im/im_messages_send.go +++ b/shortcuts/im/im_messages_send.go @@ -7,10 +7,11 @@ import ( "context" "encoding/json" "net/http" + "os" "strings" + "github.com/larksuite/cli/extension/fileio" "github.com/larksuite/cli/internal/output" - "github.com/larksuite/cli/internal/validate" "github.com/larksuite/cli/shortcuts/common" larkcore "github.com/larksuite/oapi-sdk-go/v3/core" ) @@ -98,29 +99,13 @@ var ImMessagesSend = common.Shortcut{ videoCoverKey := runtime.Str("video-cover") audioKey := runtime.Str("audio") - if !isMediaKey(imageKey) { - if _, err := validate.SafeLocalFlagPath("--image", imageKey); err != nil { - return output.ErrValidation("%v", err) - } - } - if !isMediaKey(fileKey) { - if _, err := validate.SafeLocalFlagPath("--file", fileKey); err != nil { - return output.ErrValidation("%v", err) - } - } - if !isMediaKey(videoKey) { - if _, err := validate.SafeLocalFlagPath("--video", videoKey); err != nil { - return output.ErrValidation("%v", err) - } - } - if !isMediaKey(videoCoverKey) { - if _, err := validate.SafeLocalFlagPath("--video-cover", videoCoverKey); err != nil { - return output.ErrValidation("%v", err) - } - } - if !isMediaKey(audioKey) { - if _, err := validate.SafeLocalFlagPath("--audio", audioKey); err != nil { - return output.ErrValidation("%v", err) + fio := runtime.FileIO() + for _, mf := range []struct{ flag, val string }{ + {"--image", imageKey}, {"--file", fileKey}, {"--video", videoKey}, + {"--video-cover", videoCoverKey}, {"--audio", audioKey}, + } { + if err := validateMediaFlagPath(fio, mf.flag, mf.val); err != nil { + return err } } @@ -165,29 +150,13 @@ var ImMessagesSend = common.Shortcut{ videoVal := runtime.Str("video") videoCoverVal := runtime.Str("video-cover") audioVal := runtime.Str("audio") - if !isMediaKey(imageVal) { - if _, err := validate.SafeLocalFlagPath("--image", imageVal); err != nil { - return output.ErrValidation("%v", err) - } - } - if !isMediaKey(fileVal) { - if _, err := validate.SafeLocalFlagPath("--file", fileVal); err != nil { - return output.ErrValidation("%v", err) - } - } - if !isMediaKey(videoVal) { - if _, err := validate.SafeLocalFlagPath("--video", videoVal); err != nil { - return output.ErrValidation("%v", err) - } - } - if !isMediaKey(videoCoverVal) { - if _, err := validate.SafeLocalFlagPath("--video-cover", videoCoverVal); err != nil { - return output.ErrValidation("%v", err) - } - } - if !isMediaKey(audioVal) { - if _, err := validate.SafeLocalFlagPath("--audio", audioVal); err != nil { - return output.ErrValidation("%v", err) + fio := runtime.FileIO() + for _, mf := range []struct{ flag, val string }{ + {"--image", imageVal}, {"--file", fileVal}, {"--video", videoVal}, + {"--video-cover", videoCoverVal}, {"--audio", audioVal}, + } { + if err := validateMediaFlagPath(fio, mf.flag, mf.val); err != nil { + return err } } // Resolve content type @@ -239,3 +208,15 @@ var ImMessagesSend = common.Shortcut{ func isMediaKey(value string) bool { return strings.HasPrefix(value, "img_") || strings.HasPrefix(value, "file_") } + +// validateMediaFlagPath validates a media flag value as a local file path via FileIO. +// Empty values, URLs, and media keys are skipped (not local files). +func validateMediaFlagPath(fio fileio.FileIO, flagName, value string) error { + if value == "" || strings.HasPrefix(value, "http://") || strings.HasPrefix(value, "https://") || isMediaKey(value) { + return nil + } + if _, err := fio.Stat(value); err != nil && !os.IsNotExist(err) { + return output.ErrValidation("%s: %v", flagName, err) + } + return nil +} diff --git a/shortcuts/im/validate_media_test.go b/shortcuts/im/validate_media_test.go new file mode 100644 index 00000000..b6c63efd --- /dev/null +++ b/shortcuts/im/validate_media_test.go @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package im + +import ( + "os" + "path/filepath" + "testing" + + "github.com/larksuite/cli/internal/vfs/localfileio" +) + +func TestValidateMediaFlagPath(t *testing.T) { + dir := t.TempDir() + orig, _ := os.Getwd() + defer os.Chdir(orig) + os.Chdir(dir) + os.WriteFile(filepath.Join(dir, "photo.jpg"), []byte("img"), 0644) + + fio := &localfileio.LocalFileIO{} + + tests := []struct { + name string + flag string + value string + wantErr bool + }{ + {"empty value skipped", "--image", "", false}, + {"http URL skipped", "--image", "http://example.com/a.jpg", false}, + {"https URL skipped", "--file", "https://example.com/b.mp4", false}, + {"media key skipped", "--image", "img_abc123", false}, + {"file key skipped", "--file", "file_abc123", false}, + {"valid local file", "--image", "photo.jpg", false}, + {"nonexistent file allowed", "--file", "missing.txt", false}, + {"path traversal rejected", "--image", "../../etc/passwd", true}, + {"absolute path rejected", "--file", "/etc/passwd", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateMediaFlagPath(fio, tt.flag, tt.value) + if tt.wantErr && err == nil { + t.Fatalf("expected error for %s=%q, got nil", tt.flag, tt.value) + } + if !tt.wantErr && err != nil { + t.Fatalf("unexpected error for %s=%q: %v", tt.flag, tt.value, err) + } + }) + } +} From bd7220a68514a134377e278b9c6e48db0e097e55 Mon Sep 17 00:00:00 2001 From: tuxedomm <273098272+tuxedomm@users.noreply.github.com> Date: Wed, 8 Apr 2026 16:20:08 +0800 Subject: [PATCH 2/3] refactor: replace fileio sentinel errors with typed errors for mkdir/write - Remove ErrMkdir sentinel, replace with MkdirError type - Add WriteError type to wrap file write failures - Error types use transparent Error() (no prefix) to avoid breakchange - SaveResponse keeps switch for backward-compatible error messages - WrapSaveError uses errors.As for type-based matching Change-Id: Iaeba30b82cf5ebaa1868d6efbb471968094d405c --- extension/fileio/errors.go | 31 +++++++++++-------------- internal/client/response.go | 12 ++++++---- internal/vfs/localfileio/localfileio.go | 2 +- shortcuts/common/runner.go | 6 ++++- 4 files changed, 28 insertions(+), 23 deletions(-) diff --git a/extension/fileio/errors.go b/extension/fileio/errors.go index a68bdf7e..d02aec91 100644 --- a/extension/fileio/errors.go +++ b/extension/fileio/errors.go @@ -5,39 +5,36 @@ package fileio import "errors" -// Sentinel errors for FileIO operations. Callers can use errors.Is to -// distinguish error categories and wrap with their own messages. - // ErrPathValidation indicates the path failed security validation // (traversal, absolute, control chars, symlink escape, etc.). var ErrPathValidation = errors.New("path validation failed") -// ErrMkdir indicates parent directory creation failed. -var ErrMkdir = errors.New("directory creation failed") - -// PathValidationError wraps a path validation error with ErrPathValidation. -// Both the sentinel (ErrPathValidation) and the original error are -// reachable via errors.Is / errors.As. +// PathValidationError wraps a path validation error. +// errors.Is(err, ErrPathValidation) returns true. +// errors.Is(err, ) also works via the chain. type PathValidationError struct { - Err error + Err error // original error } func (e *PathValidationError) Error() string { return e.Err.Error() } - -// Unwrap returns both the sentinel and the original error so that -// errors.Is(err, ErrPathValidation) and errors.Is(err, os.ErrPermission) -// (or any OS error in the chain) both work. func (e *PathValidationError) Unwrap() []error { return []error{ErrPathValidation, e.Err} } -// MkdirError wraps a directory creation error with ErrMkdir. +// MkdirError indicates parent directory creation failed. +// Use errors.As(err, &fileio.MkdirError{}) to match. type MkdirError struct { Err error } func (e *MkdirError) Error() string { return e.Err.Error() } +func (e *MkdirError) Unwrap() error { return e.Err } -func (e *MkdirError) Unwrap() []error { - return []error{ErrMkdir, e.Err} +// WriteError indicates file write failed. +// Use errors.As(err, &fileio.WriteError{}) to match. +type WriteError struct { + Err error } + +func (e *WriteError) Error() string { return e.Err.Error() } +func (e *WriteError) Unwrap() error { return e.Err } diff --git a/internal/client/response.go b/internal/client/response.go index 9c4d385a..2a9b84a8 100644 --- a/internal/client/response.go +++ b/internal/client/response.go @@ -126,13 +126,17 @@ func SaveResponse(fio fileio.FileIO, resp *larkcore.ApiResp, outputPath string) ContentLength: int64(len(resp.RawBody)), }, bytes.NewReader(resp.RawBody)) if err != nil { + var me *fileio.MkdirError + var we *fileio.WriteError switch { case errors.Is(err, fileio.ErrPathValidation): - return nil, fmt.Errorf("unsafe output path: %w", err) - case errors.Is(err, fileio.ErrMkdir): - return nil, fmt.Errorf("create directory: %w", err) + return nil, fmt.Errorf("unsafe output path: %s", err) + case errors.As(err, &me): + return nil, fmt.Errorf("create directory: %s", err) + case errors.As(err, &we): + return nil, fmt.Errorf("cannot write file: %s", err) default: - return nil, fmt.Errorf("cannot write file: %w", err) + return nil, fmt.Errorf("cannot write file: %s", err) } } diff --git a/internal/vfs/localfileio/localfileio.go b/internal/vfs/localfileio/localfileio.go index b4318a3d..9fd60bc1 100644 --- a/internal/vfs/localfileio/localfileio.go +++ b/internal/vfs/localfileio/localfileio.go @@ -75,7 +75,7 @@ func (l *LocalFileIO) Save(path string, _ fileio.SaveOptions, body io.Reader) (f } n, err := AtomicWriteFromReader(safePath, body, 0600) if err != nil { - return nil, err + return nil, &fileio.WriteError{Err: err} } return &saveResult{size: n}, nil } diff --git a/shortcuts/common/runner.go b/shortcuts/common/runner.go index b431aaca..99b9b6ac 100644 --- a/shortcuts/common/runner.go +++ b/shortcuts/common/runner.go @@ -340,11 +340,15 @@ func WrapSaveError(err error, pathMsg, mkdirMsg, writeMsg string) error { if err == nil { return nil } + var me *fileio.MkdirError + var we *fileio.WriteError switch { case errors.Is(err, fileio.ErrPathValidation): return fmt.Errorf("%s: %w", pathMsg, err) - case errors.Is(err, fileio.ErrMkdir): + case errors.As(err, &me): return fmt.Errorf("%s: %w", mkdirMsg, err) + case errors.As(err, &we): + return fmt.Errorf("%s: %w", writeMsg, err) default: return fmt.Errorf("%s: %w", writeMsg, err) } From 8636b7aa12dc4fc363a8ee5f07062317b4306140 Mon Sep 17 00:00:00 2001 From: tuxedomm <273098272+tuxedomm@users.noreply.github.com> Date: Wed, 8 Apr 2026 16:24:27 +0800 Subject: [PATCH 3/3] style: fix gofmt formatting in test files Change-Id: Ibe275d3a67555e8cbfcf285c016d57d51e27b142 --- internal/client/response.go | 2 +- internal/cmdutil/factory_default_test.go | 1 + internal/vfs/localfileio/localfileio_test.go | 8 ++++---- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/internal/client/response.go b/internal/client/response.go index 2a9b84a8..8ac750a9 100644 --- a/internal/client/response.go +++ b/internal/client/response.go @@ -28,7 +28,7 @@ type ResponseOptions struct { JqExpr string // if set, apply jq filter instead of Format Out io.Writer // stdout ErrOut io.Writer // stderr - FileIO fileio.FileIO // file transfer abstraction; nil falls back to direct os calls + FileIO fileio.FileIO // file transfer abstraction; required when saving files (--output or binary response) // CheckError is called on parsed JSON results. Nil defaults to CheckLarkResponse. CheckError func(interface{}) error } diff --git a/internal/cmdutil/factory_default_test.go b/internal/cmdutil/factory_default_test.go index f30b622e..fe91ead7 100644 --- a/internal/cmdutil/factory_default_test.go +++ b/internal/cmdutil/factory_default_test.go @@ -232,6 +232,7 @@ func TestNewDefault_FileIOProviderDoesNotResolveDuringInitialization(t *testing. t.Fatalf("ResolveFileIO() calls after explicit resolve = %d, want 1", provider.resolveCalls) } } + type stubTransportProvider struct { interceptor exttransport.Interceptor } diff --git a/internal/vfs/localfileio/localfileio_test.go b/internal/vfs/localfileio/localfileio_test.go index 3cb0efd3..9581165a 100644 --- a/internal/vfs/localfileio/localfileio_test.go +++ b/internal/vfs/localfileio/localfileio_test.go @@ -289,10 +289,10 @@ func TestLocalFileIO_RejectsControlCharsInPath(t *testing.T) { fio := &LocalFileIO{} paths := []string{ - "file\x00name.txt", // null byte - "file\x1fname.txt", // control char - "file\u200Bname.txt", // zero-width space - "file\u202Ename.txt", // bidi override + "file\x00name.txt", // null byte + "file\x1fname.txt", // control char + "file\u200Bname.txt", // zero-width space + "file\u202Ename.txt", // bidi override } for _, p := range paths {