Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions extension/fileio/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package fileio
import (
"context"
"io"
"io/fs"
)

// Provider creates FileIO instances.
Expand Down Expand Up @@ -46,6 +47,7 @@ type FileIO interface {
type FileInfo interface {
Size() int64
IsDir() bool
Mode() fs.FileMode
}

// File is the interface returned by FileIO.Open.
Expand Down
31 changes: 0 additions & 31 deletions shortcuts/common/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ package common

import (
"fmt"
"os"
"path/filepath"
"strconv"
"testing"
"time"
Expand Down Expand Up @@ -57,32 +55,3 @@ func TestParseTimeEndHint(t *testing.T) {
t.Errorf("ParseTime(2026-03-15, end) = %v, want 23:59:59", parsed)
}
}

func TestEnsureWritableFile(t *testing.T) {
t.Run("allows missing target", func(t *testing.T) {
path := filepath.Join(t.TempDir(), "missing.txt")
if err := EnsureWritableFile(path, false); err != nil {
t.Fatalf("EnsureWritableFile() unexpected error: %v", err)
}
})

t.Run("rejects existing target without overwrite", func(t *testing.T) {
path := filepath.Join(t.TempDir(), "exists.txt")
if err := os.WriteFile(path, []byte("data"), 0644); err != nil {
t.Fatalf("WriteFile() error: %v", err)
}
if err := EnsureWritableFile(path, false); err == nil {
t.Fatalf("expected overwrite protection error, got nil")
}
})

t.Run("allows existing target with overwrite", func(t *testing.T) {
path := filepath.Join(t.TempDir(), "exists.txt")
if err := os.WriteFile(path, []byte("data"), 0644); err != nil {
t.Fatalf("WriteFile() error: %v", err)
}
if err := EnsureWritableFile(path, true); err != nil {
t.Fatalf("EnsureWritableFile() unexpected error: %v", err)
}
})
}
18 changes: 4 additions & 14 deletions shortcuts/common/drive_media_upload.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ import (
"io"
"net/http"

"github.com/larksuite/cli/internal/validate"
"github.com/larksuite/cli/internal/vfs"
larkcore "github.com/larksuite/oapi-sdk-go/v3/core"

"github.com/larksuite/cli/internal/output"
Expand Down Expand Up @@ -51,13 +49,9 @@ type DriveMediaMultipartUploadConfig struct {
}

func UploadDriveMediaAll(runtime *RuntimeContext, cfg DriveMediaUploadAllConfig) (string, error) {
safeFilePath, err := validate.SafeInputPath(cfg.FilePath)
f, err := runtime.FileIO().Open(cfg.FilePath)
if err != nil {
return "", output.ErrValidation("invalid file path: %s", err)
}
f, err := vfs.Open(safeFilePath)
if err != nil {
return "", output.ErrValidation("cannot read file: %s", err)
return "", WrapInputStatError(err)
}
defer f.Close()

Expand Down Expand Up @@ -173,13 +167,9 @@ func ExtractDriveMediaUploadFileToken(data map[string]interface{}, action string
}

func uploadDriveMediaMultipartParts(runtime *RuntimeContext, filePath string, fileSize int64, session DriveMediaMultipartUploadSession) error {
safeFilePath, err := validate.SafeInputPath(filePath)
if err != nil {
return output.ErrValidation("invalid file path: %s", err)
}
f, err := vfs.Open(safeFilePath)
f, err := runtime.FileIO().Open(filePath)
if err != nil {
return output.ErrValidation("cannot read file: %s", err)
return WrapInputStatError(err)
}
defer f.Close()

Expand Down
18 changes: 0 additions & 18 deletions shortcuts/common/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,9 @@ package common

import (
"encoding/json"
"errors"
"io"
"mime/multipart"
"net/textproto"
"os"

"github.com/larksuite/cli/internal/output"
"github.com/larksuite/cli/internal/vfs"
)

// MultipartWriter wraps multipart.Writer for file uploads.
Expand All @@ -37,16 +32,3 @@ func (mw *MultipartWriter) CreateFormFile(fieldname, filename string) (io.Writer
func ParseJSON(data []byte, v interface{}) error {
return json.Unmarshal(data, v)
}

// EnsureWritableFile refuses to overwrite an existing file unless overwrite is true.
func EnsureWritableFile(path string, overwrite bool) error {
if overwrite {
return nil
}
if _, err := vfs.Stat(path); err == nil {
return output.ErrValidation("output file already exists: %s (use --overwrite to replace)", path)
} else if !errors.Is(err, os.ErrNotExist) {
return output.Errorf(output.ExitInternal, "io", "cannot access output path %s: %v", path, err)
}
return nil
}
20 changes: 20 additions & 0 deletions shortcuts/common/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,26 @@ func WrapOpenError(err error, pathMsg, readMsg string) error {
return fmt.Errorf("%s: %w", readMsg, err)
}

// WrapInputStatError wraps a FileIO.Stat/Open error for input file validation,
// returning output.ErrValidation with the appropriate message:
// - Path validation failures → "unsafe file path: ..."
// - Other errors → readMsg prefix (default "cannot read file")
//
// Pass an optional readMsg to override the non-path-validation message prefix.
func WrapInputStatError(err error, readMsg ...string) error {
if err == nil {
return nil
}
if errors.Is(err, fileio.ErrPathValidation) {
return output.ErrValidation("unsafe file path: %s", err)
}
msg := "cannot read file"
if len(readMsg) > 0 && readMsg[0] != "" {
msg = readMsg[0]
}
return output.ErrValidation("%s: %s", msg, err)
}

// WrapSaveErrorByCategory maps a FileIO.Save error to structured output errors,
// using standardized messages and the given error category (e.g. "api_error", "io").
// Path validation errors always use ErrValidation (exit code 2).
Expand Down
43 changes: 6 additions & 37 deletions shortcuts/common/validate.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,11 @@ package common

import (
"fmt"
"os"
"path/filepath"
"strconv"
"strings"

"github.com/larksuite/cli/extension/fileio"
"github.com/larksuite/cli/internal/output"
"github.com/larksuite/cli/internal/vfs"
)

// FlagErrorf returns a validation error with flag context (exit code 2).
Expand Down Expand Up @@ -88,40 +86,11 @@ func ParseIntBounded(rt *RuntimeContext, name string, min, max int) int {
// ValidateSafeOutputDir ensures outputDir is a relative path that resolves
// within the current working directory, preventing path traversal attacks
// (including symlink-based escape).
func ValidateSafeOutputDir(outputDir string) error {
if filepath.IsAbs(outputDir) {
return fmt.Errorf("--output-dir must be a relative path, got: %q", outputDir)
}
cwd, err := vfs.Getwd()
if err != nil {
return fmt.Errorf("cannot determine working directory: %w", err)
}
canonicalCwd, err := filepath.EvalSymlinks(cwd)
if err != nil {
canonicalCwd = cwd
}
abs := filepath.Clean(filepath.Join(cwd, outputDir))

// Resolve symlinks in abs to prevent symlink-escape attacks (e.g. an
// attacker-controlled symlink inside CWD pointing outside).
canonicalAbs, err := filepath.EvalSymlinks(abs)
if err != nil {
if !os.IsNotExist(err) {
return fmt.Errorf("--output-dir %q: %w", outputDir, err)
}
// Path does not exist yet. If os.Lstat succeeds the entry is a dangling
// symlink — reject it to prevent future escapes once the target is created.
if _, lstErr := vfs.Lstat(abs); lstErr == nil {
return fmt.Errorf("--output-dir %q is a symlink with a non-existent target", outputDir)
}
// The path itself doesn't exist; the string-level check is sufficient.
canonicalAbs = abs
}

if !strings.HasPrefix(canonicalAbs, canonicalCwd+string(filepath.Separator)) {
return fmt.Errorf("--output-dir %q resolves outside the working directory", outputDir)
}
return nil
// It delegates all validation to FileIO.ResolvePath which already performs
// cwd-boundary checks, symlink resolution, and control-character rejection.
func ValidateSafeOutputDir(fio fileio.FileIO, outputDir string) error {
_, err := fio.ResolvePath(outputDir)
return err
}

// RejectDangerousChars returns an error if value contains ASCII control
Expand Down
9 changes: 5 additions & 4 deletions shortcuts/common/validate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"path/filepath"
"testing"

"github.com/larksuite/cli/internal/vfs/localfileio"
"github.com/spf13/cobra"
)

Expand Down Expand Up @@ -199,7 +200,7 @@ func TestValidateSafeOutputDir_RejectsSymlinkEscape(t *testing.T) {
t.Fatalf("Symlink: %v", err)
}

if err := ValidateSafeOutputDir("evil_out"); err == nil {
if err := ValidateSafeOutputDir(&localfileio.LocalFileIO{}, "evil_out"); err == nil {
t.Fatal("expected error for symlink pointing outside CWD, got nil")
}
}
Expand All @@ -214,7 +215,7 @@ func TestValidateSafeOutputDir_RejectsDanglingSymlink(t *testing.T) {
t.Fatalf("Symlink: %v", err)
}

if err := ValidateSafeOutputDir("dangling"); err == nil {
if err := ValidateSafeOutputDir(&localfileio.LocalFileIO{}, "dangling"); err == nil {
t.Fatal("expected error for dangling symlink, got nil")
}
}
Expand All @@ -230,7 +231,7 @@ func TestValidateSafeOutputDir_AllowsNormalSubdir(t *testing.T) {
t.Fatalf("Mkdir: %v", err)
}

if err := ValidateSafeOutputDir("output"); err != nil {
if err := ValidateSafeOutputDir(&localfileio.LocalFileIO{}, "output"); err != nil {
t.Fatalf("expected no error for real subdir, got: %v", err)
}
}
Expand All @@ -241,7 +242,7 @@ func TestValidateSafeOutputDir_AllowsNonExistentPath(t *testing.T) {
workDir := t.TempDir()
chdirForTest(t, workDir)

if err := ValidateSafeOutputDir("new_output_dir"); err != nil {
if err := ValidateSafeOutputDir(&localfileio.LocalFileIO{}, "new_output_dir"); err != nil {
t.Fatalf("expected no error for non-existent path, got: %v", err)
}
}
38 changes: 23 additions & 15 deletions shortcuts/doc/doc_media_download.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ import (

larkcore "github.com/larksuite/oapi-sdk-go/v3/core"

"github.com/larksuite/cli/extension/fileio"
"github.com/larksuite/cli/internal/output"
"github.com/larksuite/cli/internal/validate"
"github.com/larksuite/cli/internal/vfs"
"github.com/larksuite/cli/shortcuts/common"
)

Expand Down Expand Up @@ -66,8 +66,7 @@ var DocMediaDownload = common.Shortcut{
if err := validate.ResourceName(token, "--token"); err != nil {
return output.ErrValidation("%s", err)
}
// Early path validation before API call (final validation after auto-extension below)
if _, err := validate.SafeOutputPath(outputPath); err != nil {
if _, err := runtime.ResolveSavePath(outputPath); err != nil {
return output.ErrValidation("unsafe output path: %s", err)
}

Expand Down Expand Up @@ -105,26 +104,35 @@ var DocMediaDownload = common.Shortcut{
}
}

safePath, err := validate.SafeOutputPath(finalPath)
if err != nil {
return output.ErrValidation("unsafe output path: %s", err)
}
if err := common.EnsureWritableFile(safePath, overwrite); err != nil {
return err
// Validate final path after extension append
if finalPath != outputPath {
if _, err := runtime.ResolveSavePath(finalPath); err != nil {
return output.ErrValidation("unsafe output path: %s", err)
}
}

if err := vfs.MkdirAll(filepath.Dir(safePath), 0700); err != nil {
return output.Errorf(output.ExitInternal, "io", "cannot create parent directory: %v", err)
// Overwrite check on final path (after extension detection)
if !overwrite {
if _, statErr := runtime.FileIO().Stat(finalPath); statErr == nil {
return output.ErrValidation("output file already exists: %s (use --overwrite to replace)", finalPath)
}
}

sizeBytes, err := validate.AtomicWriteFromReader(safePath, resp.Body, 0600)
result, err := runtime.FileIO().Save(finalPath, fileio.SaveOptions{
ContentType: resp.Header.Get("Content-Type"),
ContentLength: resp.ContentLength,
}, resp.Body)
if err != nil {
return output.Errorf(output.ExitInternal, "io", "cannot create file: %v", err)
return common.WrapSaveErrorByCategory(err, "io")
}

savedPath, _ := runtime.ResolveSavePath(finalPath)
if savedPath == "" {
savedPath = finalPath
}
runtime.Out(map[string]interface{}{
"saved_path": safePath,
"size_bytes": sizeBytes,
"saved_path": savedPath,
"size_bytes": result.Size(),
"content_type": resp.Header.Get("Content-Type"),
}, nil)
return nil
Expand Down
17 changes: 6 additions & 11 deletions shortcuts/doc/doc_media_insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ import (
"fmt"
"path/filepath"

"github.com/larksuite/cli/extension/fileio"
"github.com/larksuite/cli/internal/output"
"github.com/larksuite/cli/internal/validate"
"github.com/larksuite/cli/internal/vfs"
"github.com/larksuite/cli/shortcuts/common"
)

Expand Down Expand Up @@ -79,7 +79,7 @@ var DocMediaInsert = common.Shortcut{
POST("/open-apis/docx/v1/documents/:document_id/blocks/:document_id/children").
Desc(fmt.Sprintf("[%d] Create empty block at document end", stepBase+1)).
Body(createBlockData)
appendDocMediaInsertUploadDryRun(d, filePath, parentType, stepBase+2)
appendDocMediaInsertUploadDryRun(d, runtime.FileIO(), filePath, parentType, stepBase+2)
d.PATCH("/open-apis/docx/v1/documents/:document_id/blocks/batch_update").
Desc(fmt.Sprintf("[%d] Bind uploaded file token to the new block", stepBase+3)).
Body(batchUpdateData)
Expand All @@ -93,20 +93,15 @@ var DocMediaInsert = common.Shortcut{
alignStr := runtime.Str("align")
caption := runtime.Str("caption")

safeFilePath, pathErr := validate.SafeInputPath(filePath)
if pathErr != nil {
return output.ErrValidation("unsafe file path: %s", pathErr)
}

documentID, err := resolveDocxDocumentID(runtime, docInput)
if err != nil {
return err
}

// Validate file
stat, err := vfs.Stat(safeFilePath)
stat, err := runtime.FileIO().Stat(filePath)
if err != nil {
return output.ErrValidation("file not found: %s", filePath)
return common.WrapInputStatError(err, "file not found")
}
if !stat.Mode().IsRegular() {
return output.ErrValidation("file must be a regular file: %s", filePath)
Expand Down Expand Up @@ -347,12 +342,12 @@ func extractCreatedBlockTargets(createData map[string]interface{}, mediaType str
return blockID, uploadParentNode, replaceBlockID
}

func appendDocMediaInsertUploadDryRun(d *common.DryRunAPI, filePath, parentType string, step int) {
func appendDocMediaInsertUploadDryRun(d *common.DryRunAPI, fio fileio.FileIO, filePath, parentType string, step int) {
// The upload step runs only after the empty placeholder block is created, so
// dry-run can refer to that future block ID only symbolically. For large
// files, keep multipart internals as substeps of the single user-facing
// "upload file" step.
if docMediaShouldUseMultipart(filePath) {
if docMediaShouldUseMultipart(fio, filePath) {
d.POST("/open-apis/drive/v1/medias/upload_prepare").
Desc(fmt.Sprintf("[%da] Initialize multipart upload", step)).
Body(map[string]interface{}{
Expand Down
Loading
Loading