diff --git a/shortcuts/doc/docs_create.go b/shortcuts/doc/docs_create.go index c88c6eaf6..9a38a766d 100644 --- a/shortcuts/doc/docs_create.go +++ b/shortcuts/doc/docs_create.go @@ -5,7 +5,10 @@ package doc import ( "context" + "fmt" + "os" + "github.com/larksuite/cli/internal/output" "github.com/larksuite/cli/shortcuts/common" ) @@ -15,13 +18,14 @@ var DocsCreate = common.Shortcut{ Description: "Create a Lark document", Risk: "write", AuthTypes: []string{"user", "bot"}, - Scopes: []string{"docx:document:create"}, + Scopes: []string{"docx:document:create", "docs:document.media:upload", "docx:document:write_only", "docx:document:readonly"}, Flags: []common.Flag{ {Name: "title", Desc: "document title"}, {Name: "markdown", Desc: "Markdown content (Lark-flavored)", Required: true}, {Name: "folder-token", Desc: "parent folder token"}, {Name: "wiki-node", Desc: "wiki node token"}, {Name: "wiki-space", Desc: "wiki space ID (use my_library for personal library)"}, + {Name: "base-dir", Desc: "base directory for resolving local image paths (default: current working directory)"}, }, Validate: func(ctx context.Context, runtime *common.RuntimeContext) error { count := 0 @@ -37,6 +41,16 @@ var DocsCreate = common.Shortcut{ if count > 1 { return common.FlagErrorf("--folder-token, --wiki-node, and --wiki-space are mutually exclusive") } + + if dir := runtime.Str("base-dir"); dir != "" { + info, err := os.Stat(dir) + if err != nil { + return output.ErrValidation("--base-dir %q does not exist: %v", dir, err) + } + if !info.IsDir() { + return output.ErrValidation("--base-dir %q is not a directory", dir) + } + } return nil }, DryRun: func(ctx context.Context, runtime *common.RuntimeContext) *common.DryRunAPI { @@ -55,15 +69,23 @@ var DocsCreate = common.Shortcut{ if v := runtime.Str("wiki-space"); v != "" { args["wiki_space"] = v } - return common.NewDryRunAPI(). + + d := common.NewDryRunAPI(). POST(common.MCPEndpoint(runtime.Config.Brand)). Desc("MCP tool: create-doc"). Body(map[string]interface{}{"method": "tools/call", "params": map[string]interface{}{"name": "create-doc", "arguments": args}}). Set("mcp_tool", "create-doc").Set("args", args) + + if hasLocalImages(runtime.Str("markdown")) { + d.Desc("Two-phase create: create-doc + upload local images + update-doc (overwrite)") + } + + return d }, Execute: func(ctx context.Context, runtime *common.RuntimeContext) error { + markdown := runtime.Str("markdown") args := map[string]interface{}{ - "markdown": runtime.Str("markdown"), + "markdown": markdown, } if v := runtime.Str("title"); v != "" { args["title"] = v @@ -78,11 +100,37 @@ var DocsCreate = common.Shortcut{ args["wiki_space"] = v } + // If markdown contains local image paths, use two-phase creation + if hasLocalImages(markdown) { + baseDir := runtime.Str("base-dir") + if baseDir == "" { + var err error + baseDir, err = os.Getwd() + if err != nil { + return output.ErrValidation("cannot determine working directory: %v", err) + } + } + + result, err := processMarkdownImages(ctx, runtime, markdown, baseDir, args) + if err != nil { + return err + } + runtime.Out(result, nil) + return nil + } + result, err := common.CallMCPTool(runtime, "create-doc", args) if err != nil { return err } + // Post-process: auto-resize table column widths + if docID := common.GetString(result, "doc_id"); docID != "" { + if warn := autoResizeTableColumns(runtime, docID); warn != "" { + fmt.Fprintf(runtime.IO().ErrOut, "warning: %s\n", warn) + } + } + runtime.Out(result, nil) return nil }, diff --git a/shortcuts/doc/docs_create_images.go b/shortcuts/doc/docs_create_images.go new file mode 100644 index 000000000..1b1d08cb7 --- /dev/null +++ b/shortcuts/doc/docs_create_images.go @@ -0,0 +1,259 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package doc + +import ( + "context" + "fmt" + "os" + "path/filepath" + "regexp" + "strings" + + "github.com/larksuite/cli/internal/validate" + "github.com/larksuite/cli/shortcuts/common" +) + +var imageRefRegex = regexp.MustCompile(`!\[[^\]]*\]\(([^)\s]+)\)`) + +var allowedImageExts = map[string]bool{ + ".jpg": true, ".jpeg": true, ".png": true, + ".gif": true, ".bmp": true, ".webp": true, +} + +type imageRef struct { + fullMatch string + path string +} + +// parseImageRefs extracts all markdown image references from the given text. +func parseImageRefs(markdown string) []imageRef { + matches := imageRefRegex.FindAllStringSubmatch(markdown, -1) + var refs []imageRef + for _, m := range matches { + refs = append(refs, imageRef{ + fullMatch: m[0], + path: m[1], + }) + } + return refs +} + +// isLocalPath returns true if the path is not an HTTP(S) URL. +func isLocalPath(p string) bool { + return !strings.HasPrefix(p, "http://") && !strings.HasPrefix(p, "https://") +} + +// hasLocalImages checks whether the markdown contains any local image references. +func hasLocalImages(markdown string) bool { + for _, ref := range parseImageRefs(markdown) { + if isLocalPath(ref.path) { + return true + } + } + return false +} + +// safeImagePath resolves an image path relative to baseDir and validates it. +// It rejects absolute paths, prevents traversal outside baseDir, resolves +// symlinks, and checks the file exists. +func safeImagePath(imgPath, baseDir string) (string, error) { + if filepath.IsAbs(imgPath) { + return "", fmt.Errorf("absolute image path not allowed: %s", imgPath) + } + if err := validate.RejectControlChars(imgPath, "image path"); err != nil { + return "", err + } + + cleaned := filepath.Clean(imgPath) + resolved := filepath.Join(baseDir, cleaned) + + // Resolve symlinks for the actual path + real, err := filepath.EvalSymlinks(resolved) + if err != nil { + return "", fmt.Errorf("cannot resolve %s: %w", imgPath, err) + } + + // Ensure the resolved path stays under baseDir + absBase, err := filepath.Abs(baseDir) + if err != nil { + return "", fmt.Errorf("cannot resolve base directory: %w", err) + } + realBase, err := filepath.EvalSymlinks(absBase) + if err != nil { + return "", fmt.Errorf("cannot resolve base directory: %w", err) + } + + rel, err := filepath.Rel(realBase, real) + if err != nil || strings.HasPrefix(rel, ".."+string(filepath.Separator)) || rel == ".." { + return "", fmt.Errorf("image path %q resolves outside base directory", imgPath) + } + + return real, nil +} + +// validateImageFile checks that the file has an allowed extension and is within size limits. +func validateImageFile(path string) (os.FileInfo, error) { + ext := strings.ToLower(filepath.Ext(path)) + if !allowedImageExts[ext] { + return nil, fmt.Errorf("unsupported image format %q (allowed: jpg, jpeg, png, gif, bmp, webp)", ext) + } + + stat, err := os.Stat(path) + if err != nil { + return nil, err + } + if stat.Size() > maxFileSize { + return nil, fmt.Errorf("file %.1fMB exceeds 20MB limit", float64(stat.Size())/1024/1024) + } + return stat, nil +} + +// processMarkdownImages implements two-phase document creation: +// 1. Create document with title only +// 2. Upload local images and replace paths with file tokens +// 3. Update document with processed markdown (overwrite mode) +func processMarkdownImages(ctx context.Context, runtime *common.RuntimeContext, markdown, baseDir string, createArgs map[string]interface{}) (map[string]interface{}, error) { + // Phase 1: Create document with minimal content + titleArgs := make(map[string]interface{}) + for k, v := range createArgs { + if k != "markdown" { + titleArgs[k] = v + } + } + titleArgs["markdown"] = " " + + result, err := common.CallMCPTool(runtime, "create-doc", titleArgs) + if err != nil { + return nil, fmt.Errorf("failed to create document: %w", err) + } + + documentID := extractDocumentID(result) + if documentID == "" { + return nil, fmt.Errorf("create-doc did not return document_id") + } + + fmt.Fprintf(runtime.IO().ErrOut, "Document created: %s, uploading local images...\n", common.MaskToken(documentID)) + + // Upload images and collect replacements + processedMarkdown, uploadCount, err := uploadAndReplaceImages(ctx, runtime, markdown, baseDir, documentID) + if err != nil { + return result, fmt.Errorf("image upload failed: %w", err) + } + + if uploadCount == 0 { + // No images were uploaded, just update with original markdown + processedMarkdown = markdown + } + + // Phase 2: Update document with processed markdown + updateArgs := map[string]interface{}{ + "doc_id": documentID, + "mode": "overwrite", + "markdown": processedMarkdown, + } + + _, err = common.CallMCPTool(runtime, "update-doc", updateArgs) + if err != nil { + return result, fmt.Errorf("failed to update document content: %w", err) + } + + fmt.Fprintf(runtime.IO().ErrOut, "Document content updated with %d uploaded image(s)\n", uploadCount) + return result, nil +} + +// uploadAndReplaceImages uploads local images and returns the markdown with paths replaced. +func uploadAndReplaceImages(ctx context.Context, runtime *common.RuntimeContext, markdown, baseDir, documentID string) (string, int, error) { + refs := parseImageRefs(markdown) + replacements := make(map[string]string) // path -> file_token (dedup) + + // Get document root block + rootData, err := runtime.CallAPI("GET", + fmt.Sprintf("/open-apis/docx/v1/documents/%s/blocks/%s", + validate.EncodePathSegment(documentID), validate.EncodePathSegment(documentID)), + nil, nil) + if err != nil { + return markdown, 0, fmt.Errorf("failed to get document root: %w", err) + } + + parentBlockID, insertIndex, err := extractAppendTarget(rootData, documentID) + if err != nil { + return markdown, 0, err + } + + for _, ref := range refs { + if !isLocalPath(ref.path) { + continue + } + + // Skip duplicates + if _, ok := replacements[ref.path]; ok { + continue + } + + resolved, err := safeImagePath(ref.path, baseDir) + if err != nil { + fmt.Fprintf(runtime.IO().ErrOut, "Warning: skipping image %s: %v\n", ref.path, err) + continue + } + + if _, err := validateImageFile(resolved); err != nil { + fmt.Fprintf(runtime.IO().ErrOut, "Warning: skipping image %s: %v\n", ref.path, err) + continue + } + + // Create empty image block as upload target + createData, err := runtime.CallAPI("POST", + fmt.Sprintf("/open-apis/docx/v1/documents/%s/blocks/%s/children", + validate.EncodePathSegment(documentID), validate.EncodePathSegment(parentBlockID)), + nil, buildCreateBlockData("image", insertIndex)) + if err != nil { + fmt.Fprintf(runtime.IO().ErrOut, "Warning: failed to create block for %s: %v\n", ref.path, err) + continue + } + + _, uploadParentNode, _ := extractCreatedBlockTargets(createData, "image") + if uploadParentNode == "" { + fmt.Fprintf(runtime.IO().ErrOut, "Warning: failed to create block for %s\n", ref.path) + continue + } + insertIndex++ + + // Upload file + fileName := filepath.Base(resolved) + fileToken, err := uploadMediaFile(ctx, runtime, resolved, fileName, "image", uploadParentNode, documentID) + if err != nil { + fmt.Fprintf(runtime.IO().ErrOut, "Warning: failed to upload %s: %v\n", ref.path, err) + continue + } + + fmt.Fprintf(runtime.IO().ErrOut, "Uploaded: %s -> %s\n", ref.path, fileToken) + replacements[ref.path] = fileToken + } + + // Replace paths in markdown + processed := markdown + for oldPath, fileToken := range replacements { + processed = strings.ReplaceAll(processed, "]("+oldPath+")", "]("+fileToken+")") + } + + return processed, len(replacements), nil +} + +// extractDocumentID tries to get document_id from a create-doc MCP result. +func extractDocumentID(result map[string]interface{}) string { + // MCP create-doc returns "doc_id" + if id := common.GetString(result, "doc_id"); id != "" { + return id + } + if id := common.GetString(result, "document_id"); id != "" { + return id + } + if id := common.GetString(result, "doc_url"); id != "" { + if ref, err := parseDocumentRef(id); err == nil { + return ref.Token + } + } + return "" +} diff --git a/shortcuts/doc/docs_create_images_test.go b/shortcuts/doc/docs_create_images_test.go new file mode 100644 index 000000000..40e1894b7 --- /dev/null +++ b/shortcuts/doc/docs_create_images_test.go @@ -0,0 +1,262 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package doc + +import ( + "os" + "path/filepath" + "testing" +) + +func TestParseImageRefs(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + markdown string + want []imageRef + }{ + { + name: "no images", + markdown: "hello world", + want: nil, + }, + { + name: "single local image", + markdown: "text ![alt](images/photo.png) more", + want: []imageRef{{fullMatch: "![alt](images/photo.png)", path: "images/photo.png"}}, + }, + { + name: "single URL image", + markdown: "![logo](https://example.com/logo.png)", + want: []imageRef{{fullMatch: "![logo](https://example.com/logo.png)", path: "https://example.com/logo.png"}}, + }, + { + name: "multiple images", + markdown: "![a](a.png) text ![b](https://x.com/b.jpg) ![c](./c.gif)", + want: []imageRef{ + {fullMatch: "![a](a.png)", path: "a.png"}, + {fullMatch: "![b](https://x.com/b.jpg)", path: "https://x.com/b.jpg"}, + {fullMatch: "![c](./c.gif)", path: "./c.gif"}, + }, + }, + { + name: "empty alt text", + markdown: "![](image.png)", + want: []imageRef{{fullMatch: "![](image.png)", path: "image.png"}}, + }, + { + name: "path with subdirectory", + markdown: "![screenshot](case_images/shot1.jpg)", + want: []imageRef{{fullMatch: "![screenshot](case_images/shot1.jpg)", path: "case_images/shot1.jpg"}}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := parseImageRefs(tt.markdown) + if len(got) != len(tt.want) { + t.Fatalf("parseImageRefs() returned %d refs, want %d", len(got), len(tt.want)) + } + for i := range got { + if got[i].fullMatch != tt.want[i].fullMatch { + t.Errorf("[%d] fullMatch = %q, want %q", i, got[i].fullMatch, tt.want[i].fullMatch) + } + if got[i].path != tt.want[i].path { + t.Errorf("[%d] path = %q, want %q", i, got[i].path, tt.want[i].path) + } + } + }) + } +} + +func TestIsLocalPath(t *testing.T) { + t.Parallel() + + tests := []struct { + path string + want bool + }{ + {"images/photo.png", true}, + {"./photo.png", true}, + {"photo.png", true}, + {"https://example.com/photo.png", false}, + {"http://example.com/photo.png", false}, + } + + for _, tt := range tests { + t.Run(tt.path, func(t *testing.T) { + t.Parallel() + if got := isLocalPath(tt.path); got != tt.want { + t.Errorf("isLocalPath(%q) = %v, want %v", tt.path, got, tt.want) + } + }) + } +} + +func TestHasLocalImages(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + markdown string + want bool + }{ + {"no images", "just text", false}, + {"only URL images", "![a](https://example.com/a.png)", false}, + {"local image", "![a](photo.png)", true}, + {"mixed", "![a](https://x.com/a.png) ![b](local.jpg)", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if got := hasLocalImages(tt.markdown); got != tt.want { + t.Errorf("hasLocalImages() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestSafeImagePath(t *testing.T) { + t.Parallel() + + // Create a temp directory with a test image + tmpDir := t.TempDir() + subDir := filepath.Join(tmpDir, "images") + if err := os.MkdirAll(subDir, 0o755); err != nil { + t.Fatal(err) + } + testFile := filepath.Join(subDir, "test.png") + if err := os.WriteFile(testFile, []byte("fake-png"), 0o644); err != nil { + t.Fatal(err) + } + + t.Run("valid relative path", func(t *testing.T) { + t.Parallel() + got, err := safeImagePath("images/test.png", tmpDir) + if err != nil { + t.Fatalf("safeImagePath() error: %v", err) + } + // Resolve symlinks for comparison (macOS /var -> /private/var) + wantResolved, _ := filepath.EvalSymlinks(testFile) + if got != wantResolved { + t.Errorf("safeImagePath() = %q, want %q", got, wantResolved) + } + }) + + t.Run("rejects absolute path", func(t *testing.T) { + t.Parallel() + _, err := safeImagePath("/etc/passwd", tmpDir) + if err == nil { + t.Fatal("safeImagePath() expected error for absolute path") + } + }) + + t.Run("rejects traversal outside base", func(t *testing.T) { + t.Parallel() + _, err := safeImagePath("../../etc/passwd", tmpDir) + if err == nil { + t.Fatal("safeImagePath() expected error for path traversal") + } + }) + + t.Run("rejects non-existent file", func(t *testing.T) { + t.Parallel() + _, err := safeImagePath("images/nonexistent.png", tmpDir) + if err == nil { + t.Fatal("safeImagePath() expected error for non-existent file") + } + }) +} + +func TestValidateImageFile(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + + t.Run("valid png file", func(t *testing.T) { + t.Parallel() + f := filepath.Join(tmpDir, "test.png") + if err := os.WriteFile(f, []byte("data"), 0o644); err != nil { + t.Fatal(err) + } + _, err := validateImageFile(f) + if err != nil { + t.Fatalf("validateImageFile() unexpected error: %v", err) + } + }) + + t.Run("rejects unsupported extension", func(t *testing.T) { + t.Parallel() + f := filepath.Join(tmpDir, "test.svg") + if err := os.WriteFile(f, []byte("data"), 0o644); err != nil { + t.Fatal(err) + } + _, err := validateImageFile(f) + if err == nil { + t.Fatal("validateImageFile() expected error for unsupported format") + } + }) + + t.Run("valid jpg file", func(t *testing.T) { + t.Parallel() + f := filepath.Join(tmpDir, "test.jpg") + if err := os.WriteFile(f, []byte("data"), 0o644); err != nil { + t.Fatal(err) + } + _, err := validateImageFile(f) + if err != nil { + t.Fatalf("validateImageFile() unexpected error: %v", err) + } + }) + + t.Run("valid webp file", func(t *testing.T) { + t.Parallel() + f := filepath.Join(tmpDir, "test.webp") + if err := os.WriteFile(f, []byte("data"), 0o644); err != nil { + t.Fatal(err) + } + _, err := validateImageFile(f) + if err != nil { + t.Fatalf("validateImageFile() unexpected error: %v", err) + } + }) +} + +func TestExtractDocumentID(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + result map[string]interface{} + want string + }{ + { + name: "direct document_id", + result: map[string]interface{}{"document_id": "doc123"}, + want: "doc123", + }, + { + name: "from URL", + result: map[string]interface{}{"doc_url": "https://example.com/docx/abc456"}, + want: "abc456", + }, + { + name: "empty result", + result: map[string]interface{}{}, + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if got := extractDocumentID(tt.result); got != tt.want { + t.Errorf("extractDocumentID() = %q, want %q", got, tt.want) + } + }) + } +} diff --git a/shortcuts/doc/docs_update.go b/shortcuts/doc/docs_update.go index 5c64b7cc7..8049bf582 100644 --- a/shortcuts/doc/docs_update.go +++ b/shortcuts/doc/docs_update.go @@ -5,6 +5,7 @@ package doc import ( "context" + "fmt" "strings" "github.com/larksuite/cli/shortcuts/common" @@ -112,6 +113,17 @@ var DocsUpdate = common.Shortcut{ } normalizeDocsUpdateResult(result, runtime.Str("markdown")) + + // Post-process: auto-resize table columns for modes that create tables + mode := runtime.Str("mode") + if mode == "overwrite" || mode == "append" { + if docID := common.GetString(result, "doc_id"); docID != "" { + if warn := autoResizeTableColumns(runtime, docID); warn != "" { + fmt.Fprintf(runtime.IO().ErrOut, "warning: %s\n", warn) + } + } + } + runtime.Out(result, nil) return nil }, diff --git a/shortcuts/doc/table_auto_width.go b/shortcuts/doc/table_auto_width.go new file mode 100644 index 000000000..948425911 --- /dev/null +++ b/shortcuts/doc/table_auto_width.go @@ -0,0 +1,293 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package doc + +import ( + "fmt" + "unicode/utf8" + + "github.com/larksuite/cli/internal/validate" + "github.com/larksuite/cli/shortcuts/common" +) + +const ( + blockTypeTable = 31 + blockTypeText = 2 + + minColumnWidth = 80 + maxColumnWidth = 400 + docContainerWidth = 700 + charUnitWidth = 8 // approximate pixel width per character unit +) + +// autoResizeTableColumns fetches all blocks from a document, finds table blocks, +// calculates optimal column widths based on cell content, and updates via API. +// Errors are non-fatal: returns a warning message or empty string on success. +func autoResizeTableColumns(runtime *common.RuntimeContext, documentID string) string { + blocks, err := fetchAllBlocks(runtime, documentID) + if err != nil { + return fmt.Sprintf("table auto-width skipped: %v", err) + } + + blockMap := make(map[string]map[string]interface{}, len(blocks)) + for _, b := range blocks { + if m, ok := b.(map[string]interface{}); ok { + if id, _ := m["block_id"].(string); id != "" { + blockMap[id] = m + } + } + } + + var warnings []string + for _, b := range blocks { + m, ok := b.(map[string]interface{}) + if !ok { + continue + } + blockType, _ := m["block_type"].(float64) + if int(blockType) != blockTypeTable { + continue + } + blockID, _ := m["block_id"].(string) + if blockID == "" { + continue + } + if warn := resizeOneTable(runtime, documentID, blockID, m, blockMap); warn != "" { + warnings = append(warnings, warn) + } + } + + if len(warnings) > 0 { + return fmt.Sprintf("table auto-width partial: %v", warnings) + } + return "" +} + +// fetchAllBlocks retrieves all document blocks with pagination. +func fetchAllBlocks(runtime *common.RuntimeContext, documentID string) ([]interface{}, error) { + var allItems []interface{} + var pageToken string + for { + params := map[string]interface{}{ + "page_size": 500, + } + if pageToken != "" { + params["page_token"] = pageToken + } + data, err := runtime.CallAPI("GET", + fmt.Sprintf("/open-apis/docx/v1/documents/%s/blocks", validate.EncodePathSegment(documentID)), + params, nil) + if err != nil { + return nil, err + } + + items := common.GetSlice(data, "items") + allItems = append(allItems, items...) + + if !common.GetBool(data, "has_more") { + break + } + nextToken := common.GetString(data, "page_token") + if nextToken == "" { + break + } + pageToken = nextToken + } + return allItems, nil +} + +// resizeOneTable calculates and applies optimal column widths for a single table. +func resizeOneTable(runtime *common.RuntimeContext, documentID, blockID string, tableBlock map[string]interface{}, blockMap map[string]map[string]interface{}) string { + table := common.GetMap(tableBlock, "table") + if table == nil { + return "" + } + prop := common.GetMap(table, "property") + if prop == nil { + return "" + } + + colSize := int(common.GetFloat(prop, "column_size")) + rowSize := int(common.GetFloat(prop, "row_size")) + if colSize == 0 || rowSize == 0 { + return "" + } + + // Get cell block IDs - they are ordered row by row, left to right + children, _ := tableBlock["children"].([]interface{}) + if len(children) == 0 { + return "" + } + + // Calculate max content width for each column + colMaxWidths := make([]int, colSize) + for i, childID := range children { + col := i % colSize + cellID, _ := childID.(string) + if cellID == "" { + continue + } + w := cellContentWidth(cellID, blockMap) + if w > colMaxWidths[col] { + colMaxWidths[col] = w + } + } + + // Convert character widths to pixel widths with constraints + columnWidths := computePixelWidths(colMaxWidths, colSize) + + // Check if widths actually differ from equal distribution + equalWidth := docContainerWidth / colSize + allEqual := true + for _, w := range columnWidths { + if w != equalWidth { + allEqual = false + break + } + } + if allEqual { + return "" + } + + // Update table column widths via batch_update + requests := []interface{}{ + map[string]interface{}{ + "block_id": blockID, + "update_table_property": map[string]interface{}{ + "column_width": columnWidths, + }, + }, + } + _, err := runtime.CallAPI("PATCH", + fmt.Sprintf("/open-apis/docx/v1/documents/%s/blocks/batch_update", validate.EncodePathSegment(documentID)), + nil, map[string]interface{}{"requests": requests}) + if err != nil { + return fmt.Sprintf("failed to update table %s: %v", blockID, err) + } + return "" +} + +// cellContentWidth returns the max text width (in character units) of a cell's content. +func cellContentWidth(cellID string, blockMap map[string]map[string]interface{}) int { + cellBlock, ok := blockMap[cellID] + if !ok { + return 0 + } + children, _ := cellBlock["children"].([]interface{}) + maxWidth := 0 + for _, childID := range children { + id, _ := childID.(string) + if id == "" { + continue + } + child, ok := blockMap[id] + if !ok { + continue + } + w := blockTextWidth(child) + if w > maxWidth { + maxWidth = w + } + } + return maxWidth +} + +// blockTextWidth calculates the display width of text in a block. +// Chinese/fullwidth characters count as 2 units, ASCII as 1. +func blockTextWidth(block map[string]interface{}) int { + blockType, _ := block["block_type"].(float64) + if int(blockType) != blockTypeText { + return 0 + } + text := common.GetMap(block, "text") + if text == nil { + return 0 + } + elements, _ := text["elements"].([]interface{}) + totalWidth := 0 + for _, elem := range elements { + e, ok := elem.(map[string]interface{}) + if !ok { + continue + } + textRun := common.GetMap(e, "text_run") + if textRun == nil { + continue + } + content, _ := textRun["content"].(string) + totalWidth += stringDisplayWidth(content) + } + return totalWidth +} + +// stringDisplayWidth calculates display width: CJK/fullwidth = 2, others = 1. +func stringDisplayWidth(s string) int { + width := 0 + for i := 0; i < len(s); { + r, size := utf8.DecodeRuneInString(s[i:]) + if r == utf8.RuneError && size <= 1 { + width++ + i++ + continue + } + if isWideChar(r) { + width += 2 + } else { + width++ + } + i += size + } + return width +} + +// isWideChar returns true for CJK and fullwidth characters. +func isWideChar(r rune) bool { + return (r >= 0x1100 && r <= 0x115F) || // Hangul Jamo + (r >= 0x2E80 && r <= 0x303E) || // CJK Radicals, Kangxi, Ideographic + (r >= 0x3040 && r <= 0x33BF) || // Hiragana, Katakana, Bopomofo, CJK Compatibility + (r >= 0x3400 && r <= 0x4DBF) || // CJK Extension A + (r >= 0x4E00 && r <= 0xA4CF) || // CJK Unified, Yi + (r >= 0xA960 && r <= 0xA97C) || // Hangul Jamo Extended-A + (r >= 0xAC00 && r <= 0xD7FF) || // Hangul Syllables, Hangul Jamo Extended-B + (r >= 0xF900 && r <= 0xFAFF) || // CJK Compatibility Ideographs + (r >= 0xFE30 && r <= 0xFE6F) || // CJK Compatibility Forms, Small Form Variants + (r >= 0xFF01 && r <= 0xFF60) || // Fullwidth Forms + (r >= 0xFFE0 && r <= 0xFFE6) || // Fullwidth Signs + (r >= 0x20000 && r <= 0x2FA1F) // CJK Extension B-F, Compatibility Supplement +} + +// computePixelWidths converts character-unit widths to pixel widths +// with min/max constraints and total width normalization. +func computePixelWidths(charWidths []int, colSize int) []int { + pixelWidths := make([]int, colSize) + for i, cw := range charWidths { + pw := cw * charUnitWidth + if pw < minColumnWidth { + pw = minColumnWidth + } + if pw > maxColumnWidth { + pw = maxColumnWidth + } + pixelWidths[i] = pw + } + + // Normalize to fit within container width + total := 0 + for _, w := range pixelWidths { + total += w + } + if total > docContainerWidth && total > 0 { + scale := float64(docContainerWidth) / float64(total) + newTotal := 0 + for i := range pixelWidths { + pixelWidths[i] = int(float64(pixelWidths[i]) * scale) + if pixelWidths[i] < minColumnWidth { + pixelWidths[i] = minColumnWidth + } + newTotal += pixelWidths[i] + } + } + + return pixelWidths +} diff --git a/shortcuts/doc/table_auto_width_test.go b/shortcuts/doc/table_auto_width_test.go new file mode 100644 index 000000000..f50493898 --- /dev/null +++ b/shortcuts/doc/table_auto_width_test.go @@ -0,0 +1,169 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package doc + +import ( + "testing" +) + +func TestStringDisplayWidth(t *testing.T) { + tests := []struct { + input string + want int + }{ + {"hello", 5}, + {"abc123", 6}, + {"你好", 4}, + {"Hello世界", 9}, + {"", 0}, + {"a", 1}, + {"中", 2}, + {"abc你好def", 10}, + } + for _, tt := range tests { + got := stringDisplayWidth(tt.input) + if got != tt.want { + t.Errorf("stringDisplayWidth(%q) = %d, want %d", tt.input, got, tt.want) + } + } +} + +func TestIsWideChar(t *testing.T) { + tests := []struct { + r rune + want bool + }{ + {'a', false}, + {'1', false}, + {' ', false}, + {'中', true}, + {'日', true}, + {'あ', true}, + {'ア', true}, + {'한', true}, + {'A', true}, // fullwidth A + } + for _, tt := range tests { + got := isWideChar(tt.r) + if got != tt.want { + t.Errorf("isWideChar(%q) = %v, want %v", tt.r, got, tt.want) + } + } +} + +func TestComputePixelWidths(t *testing.T) { + t.Run("applies minimum width", func(t *testing.T) { + widths := computePixelWidths([]int{1, 2, 3}, 3) + for i, w := range widths { + if w < minColumnWidth { + t.Errorf("column %d width %d < min %d", i, w, minColumnWidth) + } + } + }) + + t.Run("applies maximum width", func(t *testing.T) { + widths := computePixelWidths([]int{100, 200}, 2) + for i, w := range widths { + if w > maxColumnWidth { + t.Errorf("column %d width %d > max %d", i, w, maxColumnWidth) + } + } + }) + + t.Run("normalizes to container width", func(t *testing.T) { + // 3 columns each needing 400px = 1200px total, should be scaled down + widths := computePixelWidths([]int{50, 50, 50}, 3) + total := 0 + for _, w := range widths { + total += w + } + if total > docContainerWidth+colPaddingTolerance(3) { + t.Errorf("total width %d exceeds container %d", total, docContainerWidth) + } + }) + + t.Run("small content gets minimum", func(t *testing.T) { + widths := computePixelWidths([]int{0, 0}, 2) + for i, w := range widths { + if w != minColumnWidth { + t.Errorf("column %d width %d, want min %d", i, w, minColumnWidth) + } + } + }) +} + +func colPaddingTolerance(cols int) int { + // Allow some tolerance for rounding when minimum widths are enforced + return cols * minColumnWidth +} + +func TestBlockTextWidth(t *testing.T) { + block := map[string]interface{}{ + "block_type": float64(blockTypeText), + "text": map[string]interface{}{ + "elements": []interface{}{ + map[string]interface{}{ + "text_run": map[string]interface{}{ + "content": "Hello世界", + }, + }, + }, + }, + } + got := blockTextWidth(block) + if got != 9 { // "Hello" = 5, "世界" = 4 + t.Errorf("blockTextWidth = %d, want 9", got) + } +} + +func TestBlockTextWidthNonText(t *testing.T) { + block := map[string]interface{}{ + "block_type": float64(27), // image block + } + got := blockTextWidth(block) + if got != 0 { + t.Errorf("blockTextWidth for non-text = %d, want 0", got) + } +} + +func TestCellContentWidth(t *testing.T) { + blockMap := map[string]map[string]interface{}{ + "cell1": { + "block_id": "cell1", + "block_type": float64(34), + "children": []interface{}{"text1", "text2"}, + }, + "text1": { + "block_id": "text1", + "block_type": float64(blockTypeText), + "text": map[string]interface{}{ + "elements": []interface{}{ + map[string]interface{}{ + "text_run": map[string]interface{}{ + "content": "short", + }, + }, + }, + }, + }, + "text2": { + "block_id": "text2", + "block_type": float64(blockTypeText), + "text": map[string]interface{}{ + "elements": []interface{}{ + map[string]interface{}{ + "text_run": map[string]interface{}{ + "content": "a longer text line", + }, + }, + }, + }, + }, + } + + got := cellContentWidth("cell1", blockMap) + if got != 18 { // "a longer text line" = 18 + t.Errorf("cellContentWidth = %d, want 18", got) + } +}