diff --git a/cmd/prefilter-manual/main.go b/cmd/prefilter-manual/main.go
deleted file mode 100644
index 6f0d1a8..0000000
--- a/cmd/prefilter-manual/main.go
+++ /dev/null
@@ -1,59 +0,0 @@
-package main
-
-import (
-    "context"
-    "flag"
-    "os"
-    "path/filepath"
-    "strings"
-
-    "github.com/tis24dev/proxsave/internal/backup"
-    "github.com/tis24dev/proxsave/internal/logging"
-    "github.com/tis24dev/proxsave/internal/types"
-)
-
-func parseLogLevel(raw string) types.LogLevel {
-    switch strings.ToLower(strings.TrimSpace(raw)) {
-    case "debug":
-        return types.LogLevelDebug
-    case "info", "":
-        return types.LogLevelInfo
-    case "warning", "warn":
-        return types.LogLevelWarning
-    case "error":
-        return types.LogLevelError
-    default:
-        return types.LogLevelInfo
-    }
-}
-
-func main() {
-    var (
-        root       string
-        maxSize    int64
-        levelLabel string
-    )
-
-    flag.StringVar(&root, "root", "/tmp/test_prefilter", "Root directory to run prefilter on")
-    flag.Int64Var(&maxSize, "max-size", 8*1024*1024, "Max file size (bytes) to prefilter")
-    flag.StringVar(&levelLabel, "log-level", "info", "Log level: debug|info|warn|error")
-    flag.Parse()
-
-    root = filepath.Clean(strings.TrimSpace(root))
-    if root == "" || root == "." {
-        root = string(os.PathSeparator)
-    }
-
-    logger := logging.New(parseLogLevel(levelLabel), false)
-    logger.SetOutput(os.Stdout)
-
-    cfg := backup.OptimizationConfig{
-        EnablePrefilter:           true,
-        PrefilterMaxFileSizeBytes: maxSize,
-    }
-
-    if err := backup.ApplyOptimizations(context.Background(), logger, root, cfg); err != nil {
-        logger.Error("Prefilter failed: %v", err)
-        os.Exit(1)
-    }
-}
diff --git a/cmd/proxsave/install.go b/cmd/proxsave/install.go
index d95e881..b781773 100644
--- a/cmd/proxsave/install.go
+++ b/cmd/proxsave/install.go
@@ -9,6 +9,7 @@ import (
     "os"
     "os/exec"
     "path/filepath"
+    "sort"
     "strings"
 
     "github.com/tis24dev/proxsave/internal/config"
@@ -88,6 +89,14 @@ func runInstall(ctx context.Context, configPath string, bootstrap *logging.Boots
         return err
     }
 
+    // Optional post-install audit: run a dry-run and offer to disable unused collectors.
+    if !skipConfigWizard {
+        logging.DebugStepBootstrap(bootstrap, "install workflow (cli)", "post-install audit")
+        if err := runPostInstallAuditCLI(ctx, reader, execInfo.ExecPath, configPath, bootstrap); err != nil {
+            return err
+        }
+    }
+
     logging.DebugStepBootstrap(bootstrap, "install workflow (cli)", "finalizing symlinks and cron")
     runPostInstallSymlinksAndCron(ctx, baseDir, execInfo, bootstrap)
 
@@ -108,6 +117,122 @@ func runInstall(ctx context.Context, configPath string, bootstrap *logging.Boots
     return nil
 }
 
+func runPostInstallAuditCLI(ctx context.Context, reader *bufio.Reader, execPath, configPath string, bootstrap *logging.BootstrapLogger) error {
+    fmt.Println("\n--- Post-install check (optional) ---")
+    run, err := promptYesNo(ctx, reader, "Run a dry-run to detect unused components and reduce warnings? [Y/n]: ", true)
+    if err != nil {
+        return wrapInstallError(err)
+    }
+    if !run {
+        if bootstrap != nil {
+            bootstrap.Info("Post-install audit: skipped by user")
+        }
+        return nil
+    }
+
+    if bootstrap != nil {
+        bootstrap.Info("Post-install audit: running dry-run (this may take a minute)")
+    }
+
+    suggestions, err := wizard.CollectPostInstallDisableSuggestions(ctx, execPath, configPath)
+    if err != nil {
+        fmt.Printf("WARNING: Post-install check failed (non-blocking): %v\n", err)
+        if bootstrap != nil {
+            bootstrap.Warning("Post-install audit failed (non-blocking): %v", err)
+        }
+        return nil
+    }
+    if len(suggestions) == 0 {
+        fmt.Println("No unused components detected. No changes required.")
+        if bootstrap != nil {
+            bootstrap.Info("Post-install audit: no unused components detected")
+        }
+        return nil
+    }
+
+    fmt.Printf("Detected %d unused/optional component(s) that may cause WARNINGs.\n", len(suggestions))
+    if bootstrap != nil {
+        keys := make([]string, 0, len(suggestions))
+        for _, s := range suggestions {
+            keys = append(keys, s.Key)
+        }
+        bootstrap.Info("Post-install audit: suggested disables (%d): %s", len(keys), strings.Join(keys, ", "))
+    }
+    for _, s := range suggestions {
+        reason := ""
+        if len(s.Messages) > 0 {
+            reason = strings.TrimSpace(s.Messages[0])
+        }
+        if reason != "" {
+            fmt.Printf(" - %s: %s\n", s.Key, reason)
+        } else {
+            fmt.Printf(" - %s\n", s.Key)
+        }
+    }
+    fmt.Println()
+
+    disableAny, err := promptYesNo(ctx, reader, "Disable any of the suggested components now (set KEY=false)? [y/N]: ", false)
+    if err != nil {
+        return wrapInstallError(err)
+    }
+    if !disableAny {
+        fmt.Println("No changes applied. You can disable unused components later by editing backup.env.")
+        if bootstrap != nil {
+            bootstrap.Info("Post-install audit: no disables applied")
+        }
+        return nil
+    }
+
+    keys := make([]string, 0, len(suggestions))
+    for _, s := range suggestions {
+        disable, err := promptYesNo(ctx, reader, fmt.Sprintf("Disable %s? [y/N]: ", s.Key), false)
+        if err != nil {
+            return wrapInstallError(err)
+        }
+        if disable {
+            keys = append(keys, s.Key)
+        }
+    }
+    if len(keys) == 0 {
+        fmt.Println("No changes selected. Nothing was modified.")
+        if bootstrap != nil {
+            bootstrap.Info("Post-install audit: no disables selected")
+        }
+        return nil
+    }
+
+    contentBytes, err := os.ReadFile(configPath)
+    if err != nil {
+        fmt.Printf("ERROR: Unable to update configuration (read failed): %v\n", err)
+        if bootstrap != nil {
+            bootstrap.Warning("Post-install audit: unable to update configuration (read failed): %v", err)
+        }
+        return nil
+    }
+    content := string(contentBytes)
+
+    sort.Strings(keys)
+    for _, key := range keys {
+        content = setEnvValue(content, key, "false")
+    }
+
+    tmpAuditPath := configPath + ".tmp.audit"
+    defer cleanupTempConfig(tmpAuditPath)
+    if err := writeConfigFile(configPath, tmpAuditPath, content); err != nil {
+        fmt.Printf("ERROR: Unable to update configuration (write failed): %v\n", err)
+        if bootstrap != nil {
+            bootstrap.Warning("Post-install audit: unable to update configuration (write failed): %v", err)
+        }
+        return nil
+    }
+
+    fmt.Printf("✓ Updated %s: disabled %d component(s): %s\n", configPath, len(keys), strings.Join(keys, ", "))
+    if bootstrap != nil {
+        bootstrap.Info("Post-install audit: disabled (%d): %s", len(keys), strings.Join(keys, ", "))
+    }
+    return nil
+}
+
 func runNewInstall(ctx context.Context, configPath string, bootstrap *logging.BootstrapLogger, useCLI bool) (err error) {
     done := logging.DebugStartBootstrap(bootstrap, "new-install workflow", "config=%s", configPath)
     defer func() { done(err) }()
diff --git a/cmd/proxsave/install_tui.go b/cmd/proxsave/install_tui.go
index 7d6d8ee..9f4c1cd 100644
--- a/cmd/proxsave/install_tui.go
+++ b/cmd/proxsave/install_tui.go
@@ -93,7 +93,7 @@ func runInstallTUI(ctx context.Context, configPath string, bootstrap *logging.Bo
     if !skipConfigWizard {
         // Run the wizard
         logging.DebugStepBootstrap(bootstrap, "install workflow (tui)", "running install wizard")
-        wizardData, err = wizard.RunInstallWizard(ctx, configPath, baseDir, buildSig)
+        wizardData, err = wizard.RunInstallWizard(ctx, configPath, baseDir, buildSig, baseTemplate)
         if err != nil {
             if errors.Is(err, wizard.ErrInstallCancelled) {
                 return wrapInstallError(errInteractiveAborted)
@@ -179,6 +179,35 @@ func runInstallTUI(ctx context.Context, configPath string, bootstrap *logging.Bo
         bootstrap.Info("IMPORTANT: Keep your passphrase/private key offline and secure!")
     }
 
+    // Optional post-install audit: run a dry-run and offer to disable unused collectors
+    // based on actionable warning hints like "set BACKUP_*=false to disable".
+    auditRes, auditErr := wizard.RunPostInstallAuditWizard(ctx, execInfo.ExecPath, configPath, buildSig)
+    if bootstrap != nil {
+        if auditErr != nil {
+            bootstrap.Warning("Post-install check failed (non-blocking): %v", auditErr)
+        } else {
+            switch {
+            case !auditRes.Ran:
+                bootstrap.Info("Post-install audit: skipped by user")
+            case auditRes.CollectErr != nil:
+                bootstrap.Warning("Post-install audit failed (non-blocking): %v", auditRes.CollectErr)
+            case len(auditRes.Suggestions) == 0:
+                bootstrap.Info("Post-install audit: no unused components detected")
+            default:
+                keys := make([]string, 0, len(auditRes.Suggestions))
+                for _, s := range auditRes.Suggestions {
+                    keys = append(keys, s.Key)
+                }
+                bootstrap.Info("Post-install audit: suggested disables (%d): %s", len(keys), strings.Join(keys, ", "))
+                if len(auditRes.AppliedKeys) > 0 {
+                    bootstrap.Info("Post-install audit: disabled (%d): %s", len(auditRes.AppliedKeys), strings.Join(auditRes.AppliedKeys, ", "))
+                } else {
+                    bootstrap.Info("Post-install audit: no disables applied")
+                }
+            }
+        }
+    }
+
     // Clean up legacy bash-based symlinks
     if bootstrap != nil {
         bootstrap.Info("Cleaning up legacy bash-based symlinks (if present)")
diff --git a/cmd/proxsave/main.go b/cmd/proxsave/main.go
index 1d51c69..185e7b6 100644
--- a/cmd/proxsave/main.go
+++ b/cmd/proxsave/main.go
@@ -1302,13 +1302,6 @@ func run() int {
 
     fmt.Println()
 
-    if !cfg.EnableGoBackup && !args.Support {
-        logging.Warning("ENABLE_GO_BACKUP=false is ignored; the Go backup pipeline is always used.")
-    } else {
-        logging.Debug("Go backup pipeline enabled")
-    }
-    fmt.Println()
-
     // Storage info
     logging.Info("Storage configuration:")
    logging.Info("  Primary: %s", formatStorageLabel(cfg.BackupPath, localFS))
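The TUI path above consumes an audit result from `wizard.RunPostInstallAuditWizard`. The field shapes below are inferred purely from how the result is used in this diff; the actual definitions live in the wizard package and may carry more fields:

```go
package wizard

// DisableSuggestion and PostInstallAuditResult as consumed by install_tui.go.
// Assumed shapes for illustration; not the package's authoritative types.
type DisableSuggestion struct {
	Key      string   // configuration key the dry-run flags, e.g. a BACKUP_* toggle
	Messages []string // human-readable reasons surfaced by the dry-run warnings
}

type PostInstallAuditResult struct {
	Ran         bool                // false when the user skipped the audit
	CollectErr  error               // non-nil when the dry-run itself failed (non-blocking)
	Suggestions []DisableSuggestion // keys the audit proposes to disable
	AppliedKeys []string            // keys the user actually accepted
}
```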
diff --git a/docs/CLI_REFERENCE.md b/docs/CLI_REFERENCE.md
index d69cdc1..a5c2602 100644
--- a/docs/CLI_REFERENCE.md
+++ b/docs/CLI_REFERENCE.md
@@ -130,6 +130,10 @@ Some interactive commands support two interface modes:
 
 **Use `--cli` when**: TUI rendering issues occur or advanced debugging is needed.
 
+**Existing configuration**:
+- If the configuration file already exists, the **TUI wizard** prompts you to **Overwrite**, **Edit existing** (uses the current file as base and pre-fills the wizard fields), or **Keep & exit**.
+- In **CLI mode** (`--cli`), you will be prompted to overwrite; choosing "No" keeps the file and skips the configuration wizard.
+
 **Wizard workflow**:
 1. Generates/updates the configuration file (`configs/backup.env` by default)
 2. Optionally configures secondary storage
@@ -137,7 +141,11 @@ Some interactive commands support two interface modes:
 4. Optionally enables firewall rules collection (`BACKUP_FIREWALL_RULES=false` by default)
 5. Optionally sets up notifications (Telegram, Email; Email defaults to `EMAIL_DELIVERY_METHOD=relay`)
 6. Optionally configures encryption (AGE setup)
-7. Finalizes installation (symlinks, cron migration, permission checks)
+7. (TUI) Optionally selects a cron time (HH:MM) for the `proxsave` cron entry
+8. Optionally runs a post-install dry-run audit and offers to disable unused collectors (TUI: checklist; CLI: per-key prompts; actionable hints like `set BACKUP_*=false to disable`)
+9. Finalizes installation (symlinks, cron migration, permission checks)
+
+**Install log**: The installer writes a session log under `/tmp/proxsave/install-*.log` (includes post-install audit suggestions and any accepted disables).
 
 ### Configuration Upgrade
diff --git a/docs/CONFIGURATION.md b/docs/CONFIGURATION.md
index fddead4..87e655b 100644
--- a/docs/CONFIGURATION.md
+++ b/docs/CONFIGURATION.md
@@ -46,9 +46,6 @@ Complete reference for all 200+ configuration variables in `configs/backup.env`.
 # Enable/disable backup system
 BACKUP_ENABLED=true               # true | false
 
-# Enable Go pipeline (vs legacy Bash)
-ENABLE_GO_BACKUP=true             # true | false
-
 # Colored output in terminal
 USE_COLOR=true                    # true | false
 
@@ -922,8 +919,6 @@ METRICS_ENABLED=false             # true | false
 METRICS_PATH=${BASE_DIR}/metrics  # Empty = /var/lib/prometheus/node-exporter
 ```
 
-> ℹ️ Metrics export is available only for the Go pipeline (`ENABLE_GO_BACKUP=true`).
-
 **Output**: Creates `proxmox_backup.prom` in `METRICS_PATH` with:
 - Backup duration and start/end timestamps
 - Archive size and raw bytes collected
diff --git a/docs/EXAMPLES.md b/docs/EXAMPLES.md
index 0a53278..3f8a6eb 100644
--- a/docs/EXAMPLES.md
+++ b/docs/EXAMPLES.md
@@ -866,7 +866,6 @@ CLOUD_LOG_PATH=
 # configs/backup.env
 SYSTEM_ROOT_PREFIX=/mnt/snapshot-root  # points to the alternate root
 BACKUP_ENABLED=true
-ENABLE_GO_BACKUP=true
 
 # /etc, /var, /root, /home are resolved under the prefix
 ```
diff --git a/docs/INSTALL.md b/docs/INSTALL.md
index fc45a23..eae383d 100644
--- a/docs/INSTALL.md
+++ b/docs/INSTALL.md
@@ -204,14 +204,21 @@ The installation wizard creates your configuration file interactively:
 ./build/proxsave --new-install
 ```
 
+If the configuration file already exists, the **TUI wizard** will ask whether to:
+- **Overwrite** (start from the embedded template)
+- **Edit existing** (use the current file as base and pre-fill the wizard fields)
+- **Keep & exit** (leave the file untouched and exit)
+
 **Wizard prompts:**
 
 1. **Configuration file path**: Default `configs/backup.env` (accepts absolute or relative paths within repo)
 2. **Secondary storage**: Optional path for backup/log copies
-3. **Cloud storage**: Optional rclone remote configuration
+3. **Cloud storage (rclone)**: Optional rclone configuration (supports `CLOUD_REMOTE` as a remote name (recommended) or legacy `remote:path`; `CLOUD_LOG_PATH` supports path-only (recommended) or `otherremote:/path`)
 4. **Firewall rules**: Optional firewall rules collection toggle (`BACKUP_FIREWALL_RULES=false` by default; supports iptables/nftables)
 5. **Notifications**: Enable Telegram (centralized) and Email notifications (wizard defaults to `EMAIL_DELIVERY_METHOD=relay`; you can switch to `sendmail` or `pmf` later)
 6. **Encryption**: AGE encryption setup (runs sub-wizard immediately if enabled)
+7. **Cron schedule**: Choose cron time (HH:MM) for the `proxsave` cron entry (TUI mode only)
+8. **Post-install check (optional)**: Runs `proxsave --dry-run` and shows actionable warnings like `set BACKUP_*=false to disable`, allowing you to disable unused collectors and reduce WARNING noise
 
 **Features:**
 
@@ -219,6 +226,8 @@ The installation wizard creates your configuration file interactively:
 - Template comment preservation
 - Creates all necessary directories with proper permissions (0700)
 - Immediate AGE key generation if encryption is enabled
+- Optional post-install audit to disable unused collectors (keeps changes explicit; nothing is disabled silently)
+- Install session log under `/tmp/proxsave/install-*.log` (includes post-install audit suggestions and any accepted disables)
 
 After completion, edit `configs/backup.env` manually for advanced options.
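The INSTALL.md change above distinguishes a plain remote name from the legacy `remote:path` form for `CLOUD_REMOTE`. A minimal sketch of how such a value could be split (the wizard's real parser is not shown in this diff and may differ):

```go
package main

import (
	"fmt"
	"strings"
)

// splitCloudRemote returns the rclone remote name and an optional path,
// accepting both "gdrive" (recommended) and legacy "gdrive:/backups" forms.
// Hypothetical helper for illustration only.
func splitCloudRemote(value string) (remote, path string) {
	if i := strings.IndexByte(value, ':'); i >= 0 {
		return value[:i], strings.TrimPrefix(value[i+1:], "/") // legacy remote:path
	}
	return value, "" // plain remote name
}

func main() {
	fmt.Println(splitCloudRemote("gdrive"))          // gdrive
	fmt.Println(splitCloudRemote("gdrive:/backups")) // gdrive backups
}
```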
diff --git a/go.mod b/go.mod
index 717ff99..4a36bb8 100644
--- a/go.mod
+++ b/go.mod
@@ -14,7 +14,7 @@ require (
 )
 
 require (
-    filippo.io/edwards25519 v1.1.0 // indirect
+    filippo.io/edwards25519 v1.1.1 // indirect
     filippo.io/hpke v0.4.0 // indirect
     github.com/gdamore/encoding v1.0.1 // indirect
     github.com/lucasb-eyer/go-colorful v1.3.0 // indirect
diff --git a/go.sum b/go.sum
index ab99f6d..3ce3e36 100644
--- a/go.sum
+++ b/go.sum
@@ -2,8 +2,8 @@ c2sp.org/CCTV/age v0.0.0-20251208015420-e9274a7bdbfd h1:ZLsPO6WdZ5zatV4UfVpr7oAw
 c2sp.org/CCTV/age v0.0.0-20251208015420-e9274a7bdbfd/go.mod h1:SrHC2C7r5GkDk8R+NFVzYy/sdj0Ypg9htaPXQq5Cqeo=
 filippo.io/age v1.3.1 h1:hbzdQOJkuaMEpRCLSN1/C5DX74RPcNCk6oqhKMXmZi0=
 filippo.io/age v1.3.1/go.mod h1:EZorDTYUxt836i3zdori5IJX/v2Lj6kWFU0cfh6C0D4=
-filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
-filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
+filippo.io/edwards25519 v1.1.1 h1:YpjwWWlNmGIDyXOn8zLzqiD+9TyIlPhGFG96P39uBpw=
+filippo.io/edwards25519 v1.1.1/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
 filippo.io/hpke v0.4.0 h1:p575VVQ6ted4pL+it6M00V/f2qTZITO0zgmdKCkd5+A=
 filippo.io/hpke v0.4.0/go.mod h1:EmAN849/P3qdeK+PCMkDpDm83vRHM5cDipBJ8xbQLVY=
 github.com/gdamore/encoding v1.0.1 h1:YzKZckdBL6jVt2Gc+5p82qhrGiqMdG/eNs6Wy0u3Uhw=
diff --git a/internal/config/config.go b/internal/config/config.go
index 1bea4f2..bb35388 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -33,7 +33,6 @@ type Config struct {
     DebugLevel       types.LogLevel
     UseColor         bool
     ColorizeStepLogs bool
-    EnableGoBackup   bool
     ProfilingEnabled bool
     BaseDir          string
     DryRun           bool
@@ -423,7 +422,6 @@ func (c *Config) parseOptimizationSettings() {
 }
 
 func (c *Config) parseSecuritySettings() {
-    c.EnableGoBackup = c.getBoolWithFallback([]string{"ENABLE_GO_BACKUP", "ENABLE_GO_PIPELINE"}, true)
     c.DisableNetworkPreflight = c.getBool("DISABLE_NETWORK_PREFLIGHT", false)
 
     // Base directory
diff --git a/internal/config/config_test.go b/internal/config/config_test.go
index eb26f81..3fb7ab0 100644
--- a/internal/config/config_test.go
+++ b/internal/config/config_test.go
@@ -427,38 +427,11 @@ func TestConfigDefaults(t *testing.T) {
         t.Errorf("Default LocalRetentionDays = %d; want 7", cfg.LocalRetentionDays)
     }
 
-    if !cfg.EnableGoBackup {
-        t.Error("Expected default EnableGoBackup to be true")
-    }
-
     if cfg.BaseDir != "/defaults/base" {
         t.Errorf("Default BaseDir = %q; want %q", cfg.BaseDir, "/defaults/base")
     }
 }
 
-func TestEnableGoBackupFlag(t *testing.T) {
-    tmpDir := t.TempDir()
-    configPath := filepath.Join(tmpDir, "go_pipeline.env")
-
-    content := `ENABLE_GO_BACKUP=false
-`
-    if err := os.WriteFile(configPath, []byte(content), 0644); err != nil {
-        t.Fatalf("Failed to create test config: %v", err)
-    }
-
-    cleanup := setBaseDirEnv(t, "/flag/base")
-    defer cleanup()
-
-    cfg, err := LoadConfig(configPath)
-    if err != nil {
-        t.Fatalf("LoadConfig() error = %v", err)
-    }
-
-    if cfg.EnableGoBackup {
-        t.Error("Expected EnableGoBackup to be false when explicitly disabled")
-    }
-}
-
 func TestLoadConfigBaseDirFromConfig(t *testing.T) {
     tmpDir := t.TempDir()
     configPath := filepath.Join(tmpDir, "base_dir.env")
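The removed `parseSecuritySettings` line read the flag through a `getBoolWithFallback` method whose body is not shown here. A hedged sketch of that lookup pattern (first present key wins, otherwise a default), under the assumption that it parses standard Go bool syntax:

```go
package config

import (
	"strconv"
	"strings"
)

// getBoolWithFallback returns the parsed value of the first key present in
// env, falling back to def. Sketch of the pattern only; the real method on
// *Config in internal/config may differ.
func getBoolWithFallback(env map[string]string, keys []string, def bool) bool {
	for _, key := range keys {
		if raw, ok := env[key]; ok {
			if v, err := strconv.ParseBool(strings.TrimSpace(raw)); err == nil {
				return v // e.g. ENABLE_GO_BACKUP consulted before ENABLE_GO_PIPELINE
			}
		}
	}
	return def
}
```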
diff --git a/internal/config/templates/backup.env b/internal/config/templates/backup.env
index 9c3ec39..2dbf810 100644
--- a/internal/config/templates/backup.env
+++ b/internal/config/templates/backup.env
@@ -7,7 +7,6 @@
 # General settings
 # ----------------------------------------------------------------------
 BACKUP_ENABLED=true
-ENABLE_GO_BACKUP=true
 PROFILING_ENABLED=true      # Enable CPU/heap profiling (pprof) for Go pipeline
 USE_COLOR=true
 COLORIZE_STEP_LOGS=true     # Highlight "Step N/8" lines (requires USE_COLOR=true)
diff --git a/internal/identity/identity.go b/internal/identity/identity.go
index 0e43740..87ae582 100644
--- a/internal/identity/identity.go
+++ b/internal/identity/identity.go
@@ -170,11 +170,6 @@ func collectMACCandidates(logger *logging.Logger) ([]macCandidate, []string) {
     return candidates, macs
 }
 
-func collectMACAddresses() []string {
-    _, macs := collectMACCandidates(nil)
-    return macs
-}
-
 func selectPreferredMAC(candidates []macCandidate) (string, string) {
     var best *macCandidate
     for i := range candidates {
@@ -469,10 +464,6 @@ func buildSystemData(macs []string, logger *logging.Logger) string {
     return builder.String()
 }
 
-func encodeProtectedServerID(serverID, primaryMAC string, logger *logging.Logger) (string, error) {
-    return encodeProtectedServerIDWithMACs(serverID, []string{primaryMAC}, primaryMAC, logger)
-}
-
 func encodeProtectedServerIDWithMACs(serverID string, macs []string, primaryMAC string, logger *logging.Logger) (string, error) {
     logDebug(logger, "Identity: encodeProtectedServerID: start (serverID=%s primaryMAC=%s)", serverID, primaryMAC)
     keyField := buildIdentityKeyField(macs, primaryMAC, logger)
@@ -574,16 +565,6 @@ func decodeProtectedServerID(fileContent, primaryMAC string, logger *logging.Log
     return serverID, matchedByMAC, nil
 }
 
-func generateSystemKey(primaryMAC string, logger *logging.Logger) string {
-    machineID := readMachineID(logger)
-    hostnamePart := readHostnamePart(logger)
-
-    macPart := strings.ReplaceAll(primaryMAC, ":", "")
-    key := computeSystemKey(machineID, hostnamePart, macPart)
-    logDebug(logger, "Identity: generateSystemKey: systemKey=%s", key)
-    return key
-}
-
 func buildIdentityKeyField(macs []string, primaryMAC string, logger *logging.Logger) string {
     machineID := readMachineID(logger)
     hostnamePart := readHostnamePart(logger)
diff --git a/internal/identity/identity_test.go b/internal/identity/identity_test.go
index d7a6354..8271696 100644
--- a/internal/identity/identity_test.go
+++ b/internal/identity/identity_test.go
@@ -15,11 +15,15 @@ import (
     "github.com/tis24dev/proxsave/internal/types"
 )
 
+func encodeProtectedServerIDForTest(serverID, primaryMAC string, logger *logging.Logger) (string, error) {
+    return encodeProtectedServerIDWithMACs(serverID, []string{primaryMAC}, primaryMAC, logger)
+}
+
 func TestEncodeDecodeProtectedServerIDRoundTrip(t *testing.T) {
     const serverID = "1234567890123456"
     const mac = "aa:bb:cc:dd:ee:ff"
 
-    content, err := encodeProtectedServerID(serverID, mac, nil)
+    content, err := encodeProtectedServerIDForTest(serverID, mac, nil)
     if err != nil {
         t.Fatalf("encodeProtectedServerID() error = %v", err)
     }
@@ -57,7 +61,7 @@ func TestDecodeProtectedServerIDAcceptsDifferentMACOnSameHost(t *testing.T) {
     }
 
     const serverID = "1111222233334444"
-    content, err := encodeProtectedServerID(serverID, "aa:bb:cc:dd:ee:ff", nil)
+    content, err := encodeProtectedServerIDForTest(serverID, "aa:bb:cc:dd:ee:ff", nil)
     if err != nil {
         t.Fatalf("encodeProtectedServerID() error = %v", err)
     }
@@ -95,7 +99,7 @@ func TestDecodeProtectedServerIDRejectsDifferentHost(t *testing.T) {
     }
 
     const serverID = "1111222233334444"
-    content, err := encodeProtectedServerID(serverID, "aa:bb:cc:dd:ee:ff", nil)
+    content, err := encodeProtectedServerIDForTest(serverID, "aa:bb:cc:dd:ee:ff", nil)
     if err != nil {
         t.Fatalf("encodeProtectedServerID() error = %v", err)
     }
@@ -147,7 +151,7 @@ func TestDecodeProtectedServerIDDetectsCorruptedData(t *testing.T) {
     const serverID = "5555666677778888"
     const mac = "aa:aa:aa:aa:aa:aa"
 
-    content, err := encodeProtectedServerID(serverID, mac, nil)
+    content, err := encodeProtectedServerIDForTest(serverID, mac, nil)
     if err != nil {
         t.Fatalf("encodeProtectedServerID() error = %v", err)
     }
@@ -204,14 +208,14 @@ func TestDetectUsesExistingIdentityFile(t *testing.T) {
     }
     identityPath := filepath.Join(identityDir, identityFileName)
 
-    macs := collectMACAddresses()
+    _, macs := collectMACCandidates(nil)
     if len(macs) == 0 {
         t.Skip("no non-loopback MACs available on this system")
     }
     primary := macs[0]
 
     const serverID = "1234567890123456"
-    content, err := encodeProtectedServerID(serverID, primary, nil)
+    content, err := encodeProtectedServerIDForTest(serverID, primary, nil)
     if err != nil {
         t.Fatalf("encodeProtectedServerID() error = %v", err)
     }
@@ -354,21 +358,8 @@ func TestFallbackServerIDFormat(t *testing.T) {
     }
 }
 
-func TestGenerateSystemKeyStableAndLength(t *testing.T) {
-    const mac = "aa:bb:cc:dd:ee:ff"
-    k1 := generateSystemKey(mac, nil)
-    k2 := generateSystemKey(mac, nil)
-
-    if len(k1) != 16 {
-        t.Fatalf("generateSystemKey length = %d, want 16", len(k1))
-    }
-    if k1 != k2 {
-        t.Fatalf("generateSystemKey should be stable, got %q and %q", k1, k2)
-    }
-}
-
 func TestCollectMACAddressesSortedAndUnique(t *testing.T) {
-    macs := collectMACAddresses()
+    _, macs := collectMACCandidates(nil)
     for i := 0; i < len(macs); i++ {
         if macs[i] == "" {
             t.Fatalf("unexpected empty MAC at index %d", i)
@@ -413,7 +404,7 @@ func TestDecodeProtectedServerIDInvalidPayloadFormat(t *testing.T) {
 
 func TestDecodeProtectedServerIDInvalidServerIDFormat(t *testing.T) {
     const mac = "aa:bb:cc:dd:ee:ff"
-    content, err := encodeProtectedServerID("AAAAAAAAAAAAAAAA", mac, nil)
+    content, err := encodeProtectedServerIDForTest("AAAAAAAAAAAAAAAA", mac, nil)
     if err != nil {
         t.Fatalf("encodeProtectedServerID() error = %v", err)
     }
@@ -446,7 +437,7 @@ func TestLoadServerIDWithEmptyMACSlice(t *testing.T) {
     path := filepath.Join(dir, "identity.conf")
 
     const serverID = "1234567890123456"
-    content, err := encodeProtectedServerID(serverID, "", nil)
+    content, err := encodeProtectedServerIDForTest(serverID, "", nil)
     if err != nil {
         t.Fatalf("encodeProtectedServerID() error = %v", err)
     }
@@ -487,7 +478,10 @@ func TestLoadServerIDFailsAllMACs(t *testing.T) {
 }
 
 func encodeProtectedServerIDLegacy(serverID, primaryMAC string) (string, error) {
-    systemKey := generateSystemKey(primaryMAC, nil)
+    machineID := readMachineID(nil)
+    hostnamePart := readHostnamePart(nil)
+    macPart := strings.ReplaceAll(primaryMAC, ":", "")
+    systemKey := computeSystemKey(machineID, hostnamePart, macPart)
     timestamp := time.Unix(1700000000, 0).Unix()
     data := fmt.Sprintf("%s:%d:%s", serverID, timestamp, systemKey[:systemKeyPrefixLength])
     checksum := sha256.Sum256([]byte(data))
@@ -1588,7 +1582,7 @@ func TestDecodeProtectedServerIDWithEmptyMAC(t *testing.T) {
     }
 
     const serverID = "1234567890123456"
-    content, err := encodeProtectedServerID(serverID, "aa:bb:cc:dd:ee:ff", nil)
+    content, err := encodeProtectedServerIDForTest(serverID, "aa:bb:cc:dd:ee:ff", nil)
     if err != nil {
         t.Fatalf("encodeProtectedServerID() error = %v", err)
     }
diff --git a/internal/orchestrator/bundle_test.go b/internal/orchestrator/bundle_test.go
index 9462b4b..280060e 100644
--- a/internal/orchestrator/bundle_test.go
+++ b/internal/orchestrator/bundle_test.go
@@ -121,41 +121,6 @@ func TestCreateBundle_CreatesValidTarArchive(t *testing.T) {
     }
 }
 
-func TestLegacyCreateBundleWrapper_DelegatesToMethod(t *testing.T) {
-    logger := logging.New(logging.GetDefaultLogger().GetLevel(), false)
-    tempDir := t.TempDir()
-    archive := filepath.Join(tempDir, "backup.tar")
-
-    // Minimal associated files required by createBundle
-    required := map[string]string{
-        "":          "archive-content",
-        ".sha256":   "checksum",
-        ".metadata": "metadata-json",
-    }
-    for suffix, content := range required {
-        if err := os.WriteFile(archive+suffix, []byte(content), 0o640); err != nil {
-            t.Fatalf("write %s: %v", suffix, err)
-        }
-    }
-
-    ctx := context.Background()
-
-    // Call legacy wrapper
-    bundlePath, err := createBundle(ctx, logger, archive)
-    if err != nil {
-        t.Fatalf("legacy createBundle returned error: %v", err)
-    }
-
-    expectedPath := archive + ".bundle.tar"
-    if bundlePath != expectedPath {
-        t.Fatalf("bundle path = %s, want %s", bundlePath, expectedPath)
-    }
-
-    if _, err := os.Stat(bundlePath); err != nil {
-        t.Fatalf("expected bundle file to exist, got %v", err)
-    }
-}
-
 func TestRemoveAssociatedFiles_RemovesAll(t *testing.T) {
     logger := logging.New(logging.GetDefaultLogger().GetLevel(), false)
     tempDir := t.TempDir()
diff --git a/internal/orchestrator/decrypt_test.go b/internal/orchestrator/decrypt_test.go
index ae17bd2..84fcee1 100644
--- a/internal/orchestrator/decrypt_test.go
+++ b/internal/orchestrator/decrypt_test.go
@@ -4250,22 +4250,6 @@ exit 0
     os.Setenv("PATH", tmp+":"+origPath)
     defer os.Setenv("PATH", origPath)
 
-    // Create a filesystem wrapper that allows download but fails MkdirAll for tempRoot
-    type fakeMkdirAllFailOnTempRoot struct {
-        osFS
-    }
-    fake := &struct {
-        osFS
-        mkdirCalls int
-    }{}
-
-    // Use osFS with a hook to fail on the second MkdirAll (tempRoot creation)
-    type osFSWithMkdirHook struct {
-        osFS
-        mkdirCalls int
-    }
-    hookFS := &osFSWithMkdirHook{}
-
     orig := restoreFS
     // Use regular osFS - the download will work, then MkdirAll for /tmp/proxsave should succeed
     // but we can trigger error by making /tmp/proxsave unwritable after download
@@ -4289,27 +4273,6 @@ exit 0
     }
     // If download succeeds and extraction succeeds, that's fine - we've tested the path
     _ = err
-    _ = fake
-    _ = hookFS
-}
-
-// fakeChecksumFailFS wraps osFS to make the plain archive unreadable after extraction
-// This triggers GenerateChecksum error (lines 670-673)
-type fakeChecksumFailFS struct {
-    osFS
-    extractDone bool
-}
-
-func (f *fakeChecksumFailFS) OpenFile(path string, flag int, perm os.FileMode) (*os.File, error) {
-    file, err := os.OpenFile(path, flag, perm)
-    if err != nil {
-        return nil, err
-    }
-    // After extracting, make the archive unreadable for checksum
-    if f.extractDone && strings.Contains(path, "proxmox-decrypt") && strings.HasSuffix(path, ".tar.xz") {
-        os.Chmod(path, 0o000)
-    }
-    return file, nil
 }
 
 // fakeStatThenRemoveFS removes the file after stat succeeds
diff --git a/internal/orchestrator/decrypt_workflow_ui.go b/internal/orchestrator/decrypt_workflow_ui.go
index a57d45a..2ae37d7 100644
--- a/internal/orchestrator/decrypt_workflow_ui.go
+++ b/internal/orchestrator/decrypt_workflow_ui.go
@@ -361,7 +361,8 @@ func runDecryptWorkflowWithUI(ctx context.Context, cfg *config.Config, logger *l
     }
 
     logger.Info("Creating decrypted bundle...")
-    bundlePath, err := createBundle(ctx, logger, tempArchivePath)
+    o := &Orchestrator{logger: logger, fs: osFS{}}
+    bundlePath, err := o.createBundle(ctx, tempArchivePath)
     if err != nil {
         return err
     }
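The mount-guard changes that follow replace direct `os`/`syscall` calls with package-level function variables so tests can substitute fakes. This is the seam pattern the new tests rely on; a minimal sketch of the test-side usage (illustrative names, not the production code):

```go
package seam

import (
	"os"
	"testing"
)

// Production code calls readFile instead of os.ReadFile, giving tests a seam.
var readFile = os.ReadFile

func TestWithFakeReadFile(t *testing.T) {
	orig := readFile
	t.Cleanup(func() { readFile = orig }) // always restore the real function
	readFile = func(string) ([]byte, error) { return []byte("fake"), nil }
	// ... exercise code under test, which now sees the fake ...
}
```

The save/override/restore-in-`t.Cleanup` shape here mirrors how the new `mount_guard_more_test.go` tests swap `mountGuardReadFile`, `mountGuardSysMount`, and friends.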
logger.Warning("PBS mount guard: %s resolves to root filesystem (mount missing?) — bind-mounted a read-only guard to prevent writes until storage is available", guardTarget) @@ -241,22 +253,22 @@ func guardMountPoint(ctx context.Context, guardTarget string) error { } guardDir := guardDirForTarget(target) - if err := os.MkdirAll(guardDir, 0o755); err != nil { + if err := mountGuardMkdirAll(guardDir, 0o755); err != nil { return fmt.Errorf("mkdir guard dir: %w", err) } - if err := os.MkdirAll(target, 0o755); err != nil { + if err := mountGuardMkdirAll(target, 0o755); err != nil { return fmt.Errorf("mkdir target: %w", err) } // Bind mount guard directory over the mountpoint to avoid writes to the underlying rootfs path. - if err := syscall.Mount(guardDir, target, "", syscall.MS_BIND, ""); err != nil { + if err := mountGuardSysMount(guardDir, target, "", syscall.MS_BIND, ""); err != nil { return fmt.Errorf("bind mount guard: %w", err) } // Make the bind mount read-only to ensure PBS cannot write backup data to the guard directory. remountFlags := uintptr(syscall.MS_BIND | syscall.MS_REMOUNT | syscall.MS_RDONLY | syscall.MS_NODEV | syscall.MS_NOSUID | syscall.MS_NOEXEC) - if err := syscall.Mount("", target, "", remountFlags, ""); err != nil { - _ = syscall.Unmount(target, 0) + if err := mountGuardSysMount("", target, "", remountFlags, ""); err != nil { + _ = mountGuardSysUnmount(target, 0) return fmt.Errorf("remount guard read-only: %w", err) } @@ -274,7 +286,7 @@ func guardDirForTarget(target string) string { } func isMounted(path string) (bool, error) { - data, err := os.ReadFile("/proc/self/mountinfo") + data, err := mountGuardReadFile("/proc/self/mountinfo") if err == nil { return isMountedFromMountinfo(string(data), path), nil } @@ -315,7 +327,7 @@ func isMountedFromMountinfo(mountinfo, path string) bool { } func isMountedFromProcMounts(path string) (bool, error) { - data, err := os.ReadFile("/proc/mounts") + data, err := mountGuardReadFile("/proc/mounts") if err != nil { return false, err } @@ -408,9 +420,6 @@ func pbsMountGuardRootForDatastorePath(path string) string { case strings.HasPrefix(p, "/run/media/"): rest := strings.TrimPrefix(p, "/run/media/") parts := splitPath(rest) - if len(parts) == 0 { - return "" - } if len(parts) == 1 { return filepath.Join("/run/media", parts[0]) } diff --git a/internal/orchestrator/mount_guard_more_test.go b/internal/orchestrator/mount_guard_more_test.go new file mode 100644 index 0000000..4110908 --- /dev/null +++ b/internal/orchestrator/mount_guard_more_test.go @@ -0,0 +1,896 @@ +package orchestrator + +import ( + "context" + "crypto/sha256" + "errors" + "fmt" + "os" + "path/filepath" + "strings" + "syscall" + "testing" + "time" +) + +func TestGuardDirForTarget(t *testing.T) { + t.Parallel() + + target := "/mnt/datastore" + sum := sha256.Sum256([]byte(target)) + id := fmt.Sprintf("%x", sum[:8]) + want := filepath.Join(mountGuardBaseDir, fmt.Sprintf("%s-%s", filepath.Base(target), id)) + if got := guardDirForTarget(target); got != want { + t.Fatalf("guardDirForTarget(%q)=%q want %q", target, got, want) + } + + rootTarget := "/" + sum = sha256.Sum256([]byte(rootTarget)) + id = fmt.Sprintf("%x", sum[:8]) + want = filepath.Join(mountGuardBaseDir, fmt.Sprintf("%s-%s", "guard", id)) + if got := guardDirForTarget(rootTarget); got != want { + t.Fatalf("guardDirForTarget(%q)=%q want %q", rootTarget, got, want) + } +} + +func TestIsMountedFromMountinfo(t *testing.T) { + t.Parallel() + + mountinfo := strings.Join([]string{ + "36 25 0:32 / / rw,relatime - ext4 
/dev/sda1 rw", + `37 36 0:33 / /mnt/pbs\040datastore rw,relatime - ext4 /dev/sdb1 rw`, + "bad line", + "", + }, "\n") + + if got := isMountedFromMountinfo(mountinfo, "/"); !got { + t.Fatalf("expected / to be mounted") + } + if got := isMountedFromMountinfo(mountinfo, "/mnt/pbs datastore"); !got { + t.Fatalf("expected escaped mountpoint to match") + } + if got := isMountedFromMountinfo(mountinfo, "/not-mounted"); got { + t.Fatalf("expected /not-mounted to be unmounted") + } + if got := isMountedFromMountinfo(mountinfo, ""); got { + t.Fatalf("expected empty path to be unmounted") + } +} + +func TestFstabMountpointsSet(t *testing.T) { + tmp := filepath.Join(t.TempDir(), "fstab") + content := strings.Join([]string{ + "# comment", + "UUID=abc / ext4 defaults 0 1", + "/dev/sdb1 /mnt/data/ ext4 defaults 0 2", + "/dev/sdc1 /mnt/data2 ext4 defaults 0 2 # inline comment", + "/dev/sdd1 . ext4 defaults 0 0", + "invalidline", + "", + }, "\n") + if err := os.WriteFile(tmp, []byte(content), 0o600); err != nil { + t.Fatalf("write temp fstab: %v", err) + } + + mps, err := fstabMountpointsSet(tmp) + if err != nil { + t.Fatalf("fstabMountpointsSet error: %v", err) + } + + for _, mp := range []string{"/", "/mnt/data", "/mnt/data2"} { + if _, ok := mps[mp]; !ok { + t.Fatalf("expected mountpoint %s to be present", mp) + } + } + if _, ok := mps["."]; ok { + t.Fatalf("expected dot mountpoint to be skipped") + } +} + +func TestFstabMountpointsSet_Error(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + restoreFS = NewFakeFS() + + if _, err := fstabMountpointsSet("/does-not-exist"); err == nil { + t.Fatalf("expected error") + } +} + +func TestSplitPathAndMountRootWithPrefix(t *testing.T) { + t.Parallel() + + if got := splitPath("a//b/ /c/"); strings.Join(got, ",") != "a,b,c" { + t.Fatalf("splitPath unexpected: %#v", got) + } + if got := mountRootWithPrefix("/mnt/datastore/Data1", "/mnt/"); got != "/mnt/datastore" { + t.Fatalf("mountRootWithPrefix got %q want %q", got, "/mnt/datastore") + } + if got := mountRootWithPrefix("/mnt/", "/mnt/"); got != "" { + t.Fatalf("mountRootWithPrefix(/mnt/)=%q want empty", got) + } +} + +func TestSortByLengthDesc(t *testing.T) { + t.Parallel() + + items := []string{"a", "abc", "ab"} + sortByLengthDesc(items) + if len(items) != 3 { + t.Fatalf("unexpected len: %d", len(items)) + } + if !(len(items[0]) >= len(items[1]) && len(items[1]) >= len(items[2])) { + t.Fatalf("expected non-increasing lengths, got %#v", items) + } +} + +func TestFirstFstabMountpointMatch(t *testing.T) { + t.Parallel() + + mountpoints := []string{"/mnt/storage/pbs", "/mnt/storage", "/"} + if got := firstFstabMountpointMatch("/mnt/storage/pbs/ds1/data", mountpoints); got != "/mnt/storage/pbs" { + t.Fatalf("firstFstabMountpointMatch got %q want %q", got, "/mnt/storage/pbs") + } + if got := firstFstabMountpointMatch(" ", mountpoints); got != "" { + t.Fatalf("firstFstabMountpointMatch empty got %q want empty", got) + } +} + +func TestIsMounted_Variants(t *testing.T) { + origReadFile := mountGuardReadFile + t.Cleanup(func() { mountGuardReadFile = origReadFile }) + + t.Run("prefers mountinfo", func(t *testing.T) { + mountGuardReadFile = func(path string) ([]byte, error) { + if path != "/proc/self/mountinfo" { + t.Fatalf("unexpected read path: %s", path) + } + return []byte("1 2 3:4 / /mnt/target rw - ext4 /dev/sda1 rw\n"), nil + } + mounted, err := isMounted("/mnt/target") + if err != nil { + t.Fatalf("isMounted error: %v", err) + } + if !mounted { + t.Fatalf("expected mounted") + } + }) 
+ + t.Run("falls back to proc mounts", func(t *testing.T) { + mountGuardReadFile = func(path string) ([]byte, error) { + switch path { + case "/proc/self/mountinfo": + return nil, os.ErrNotExist + case "/proc/mounts": + return []byte("/dev/sda1 /mnt/target ext4 rw 0 0\n"), nil + default: + t.Fatalf("unexpected read path: %s", path) + return nil, nil + } + } + mounted, err := isMounted("/mnt/target") + if err != nil { + t.Fatalf("isMounted error: %v", err) + } + if !mounted { + t.Fatalf("expected mounted") + } + }) + + t.Run("reports mounts error when mountinfo missing", func(t *testing.T) { + wantErr := errors.New("mounts read failed") + mountGuardReadFile = func(path string) ([]byte, error) { + switch path { + case "/proc/self/mountinfo": + return nil, os.ErrNotExist + case "/proc/mounts": + return nil, wantErr + default: + t.Fatalf("unexpected read path: %s", path) + return nil, nil + } + } + _, err := isMounted("/mnt/target") + if !errors.Is(err, wantErr) { + t.Fatalf("expected mounts error, got %v", err) + } + }) + + t.Run("includes both errors when mountinfo read fails", func(t *testing.T) { + mountErr := errors.New("mountinfo boom") + mountsErr := errors.New("mounts boom") + mountGuardReadFile = func(path string) ([]byte, error) { + switch path { + case "/proc/self/mountinfo": + return nil, mountErr + case "/proc/mounts": + return nil, mountsErr + default: + t.Fatalf("unexpected read path: %s", path) + return nil, nil + } + } + _, err := isMounted("/mnt/target") + if err == nil || !strings.Contains(err.Error(), "mountinfo boom") || !strings.Contains(err.Error(), "mounts boom") { + t.Fatalf("expected combined error, got %v", err) + } + }) +} + +func TestIsMountedFromProcMounts_Parsing(t *testing.T) { + origReadFile := mountGuardReadFile + t.Cleanup(func() { mountGuardReadFile = origReadFile }) + + mountGuardReadFile = func(path string) ([]byte, error) { + if path != "/proc/mounts" { + t.Fatalf("unexpected read path: %s", path) + } + return []byte(strings.Join([]string{ + "", + "invalid", + "/dev/sda1 /mnt/other ext4 rw 0 0", + "", + }, "\n")), nil + } + + mounted, err := isMountedFromProcMounts("/mnt/target") + if err != nil { + t.Fatalf("isMountedFromProcMounts error: %v", err) + } + if mounted { + t.Fatalf("expected unmounted") + } + + mounted, err = isMountedFromProcMounts(" ") + if err != nil { + t.Fatalf("isMountedFromProcMounts empty target error: %v", err) + } + if mounted { + t.Fatalf("expected empty target to be unmounted") + } +} + +func TestGuardMountPoint(t *testing.T) { + origReadFile := mountGuardReadFile + origMkdirAll := mountGuardMkdirAll + origMount := mountGuardSysMount + origUnmount := mountGuardSysUnmount + t.Cleanup(func() { + mountGuardReadFile = origReadFile + mountGuardMkdirAll = origMkdirAll + mountGuardSysMount = origMount + mountGuardSysUnmount = origUnmount + }) + + t.Run("rejects invalid target", func(t *testing.T) { + if err := guardMountPoint(context.Background(), "/"); err == nil { + t.Fatalf("expected error") + } + }) + + t.Run("nil context uses background", func(t *testing.T) { + mountGuardReadFile = func(string) ([]byte, error) { + return []byte("1 2 3:4 / / rw - ext4 /dev/sda1 rw\n"), nil + } + mountGuardMkdirAll = func(string, os.FileMode) error { return nil } + mountGuardSysMount = func(string, string, string, uintptr, string) error { return nil } + mountGuardSysUnmount = func(string, int) error { + t.Fatalf("unexpected unmount call") + return nil + } + + if err := guardMountPoint(nil, "/mnt/nilctx"); err != nil { + t.Fatalf("unexpected error: 
%v", err) + } + }) + + t.Run("mount status check error", func(t *testing.T) { + mountGuardReadFile = func(string) ([]byte, error) { return nil, errors.New("read failed") } + if err := guardMountPoint(context.Background(), "/mnt/statuserr"); err == nil || !strings.Contains(err.Error(), "check mount status") { + t.Fatalf("expected status check error, got %v", err) + } + }) + + t.Run("mkdir guard dir failure", func(t *testing.T) { + target := "/mnt/mkdir-guard-dir-fail" + guardDir := guardDirForTarget(target) + + mountGuardReadFile = func(string) ([]byte, error) { + return []byte("1 2 3:4 / / rw - ext4 /dev/sda1 rw\n"), nil + } + mountGuardMkdirAll = func(path string, _ os.FileMode) error { + if filepath.Clean(path) == filepath.Clean(guardDir) { + return errors.New("mkdir guard dir failed") + } + return nil + } + mountGuardSysMount = func(string, string, string, uintptr, string) error { + t.Fatalf("unexpected mount call") + return nil + } + + if err := guardMountPoint(context.Background(), target); err == nil || !strings.Contains(err.Error(), "mkdir guard dir") { + t.Fatalf("expected mkdir guard dir error, got %v", err) + } + }) + + t.Run("mkdir target failure", func(t *testing.T) { + target := "/mnt/mkdir-target-fail" + guardDir := guardDirForTarget(target) + + mountGuardReadFile = func(string) ([]byte, error) { + return []byte("1 2 3:4 / / rw - ext4 /dev/sda1 rw\n"), nil + } + mountGuardMkdirAll = func(path string, _ os.FileMode) error { + switch filepath.Clean(path) { + case filepath.Clean(guardDir): + return nil + case filepath.Clean(target): + return errors.New("mkdir target failed") + default: + return nil + } + } + mountGuardSysMount = func(string, string, string, uintptr, string) error { + t.Fatalf("unexpected mount call") + return nil + } + + if err := guardMountPoint(context.Background(), target); err == nil || !strings.Contains(err.Error(), "mkdir target") { + t.Fatalf("expected mkdir target error, got %v", err) + } + }) + + t.Run("returns nil when already mounted", func(t *testing.T) { + mountGuardReadFile = func(path string) ([]byte, error) { + if path != "/proc/self/mountinfo" { + return nil, os.ErrNotExist + } + return []byte("1 2 3:4 / /mnt/already rw - ext4 /dev/sda1 rw\n"), nil + } + mountGuardMkdirAll = func(string, os.FileMode) error { + t.Fatalf("unexpected mkdir call") + return nil + } + mountGuardSysMount = func(string, string, string, uintptr, string) error { + t.Fatalf("unexpected mount call") + return nil + } + + if err := guardMountPoint(context.Background(), "/mnt/already"); err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) + + t.Run("bind mount failure", func(t *testing.T) { + mountGuardReadFile = func(path string) ([]byte, error) { + return []byte("1 2 3:4 / / rw - ext4 /dev/sda1 rw\n"), nil + } + mountGuardMkdirAll = func(string, os.FileMode) error { return nil } + mountGuardSysMount = func(_, _, _ string, flags uintptr, _ string) error { + if flags == syscall.MS_BIND { + return syscall.EPERM + } + return nil + } + mountGuardSysUnmount = func(string, int) error { + t.Fatalf("unexpected unmount call") + return nil + } + + if err := guardMountPoint(context.Background(), "/mnt/failbind"); err == nil || !strings.Contains(err.Error(), "bind mount guard") { + t.Fatalf("expected bind mount error, got %v", err) + } + }) + + t.Run("remount failure unmounts", func(t *testing.T) { + mountGuardReadFile = func(path string) ([]byte, error) { + return []byte("1 2 3:4 / / rw - ext4 /dev/sda1 rw\n"), nil + } + mountGuardMkdirAll = func(string, os.FileMode) error { 
return nil } + + mountCalls := 0 + mountGuardSysMount = func(_, _, _ string, _ uintptr, _ string) error { + mountCalls++ + if mountCalls == 2 { + return syscall.EPERM + } + return nil + } + + unmountCalls := 0 + mountGuardSysUnmount = func(target string, flags int) error { + unmountCalls++ + if target != "/mnt/failremount" || flags != 0 { + t.Fatalf("unexpected unmount args: target=%s flags=%d", target, flags) + } + return nil + } + + if err := guardMountPoint(context.Background(), "/mnt/failremount"); err == nil || !strings.Contains(err.Error(), "remount guard read-only") { + t.Fatalf("expected remount error, got %v", err) + } + if unmountCalls != 1 { + t.Fatalf("expected 1 unmount call, got %d", unmountCalls) + } + }) + + t.Run("context canceled", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + if err := guardMountPoint(ctx, "/mnt/ctx"); !errors.Is(err, context.Canceled) { + t.Fatalf("expected context canceled, got %v", err) + } + }) +} + +type mountGuardCommandCall struct { + name string + args []string +} + +type mountGuardCommandRunner struct { + run func(ctx context.Context, name string, args ...string) ([]byte, error) + calls []mountGuardCommandCall +} + +func (f *mountGuardCommandRunner) Run(ctx context.Context, name string, args ...string) ([]byte, error) { + f.calls = append(f.calls, mountGuardCommandCall{name: name, args: append([]string{}, args...)}) + if f.run != nil { + return f.run(ctx, name, args...) + } + return nil, nil +} + +func TestMaybeApplyPBSDatastoreMountGuards_EarlyReturns(t *testing.T) { + logger := newTestLogger() + ctx := context.Background() + plan := &RestorePlan{SystemType: SystemTypePBS, NormalCategories: []Category{{ID: "datastore_pbs"}}} + + if err := maybeApplyPBSDatastoreMountGuards(ctx, logger, nil, "/stage", "/", false); err != nil { + t.Fatalf("nil plan error: %v", err) + } + if err := maybeApplyPBSDatastoreMountGuards(ctx, logger, &RestorePlan{SystemType: SystemTypePVE}, "/stage", "/", false); err != nil { + t.Fatalf("wrong system type error: %v", err) + } + if err := maybeApplyPBSDatastoreMountGuards(ctx, logger, &RestorePlan{SystemType: SystemTypePBS}, "/stage", "/", false); err != nil { + t.Fatalf("missing category error: %v", err) + } + if err := maybeApplyPBSDatastoreMountGuards(ctx, logger, plan, "", "/", false); err != nil { + t.Fatalf("empty stageRoot error: %v", err) + } + if err := maybeApplyPBSDatastoreMountGuards(ctx, logger, plan, "/stage", "/not-root", false); err != nil { + t.Fatalf("destRoot not root error: %v", err) + } + if err := maybeApplyPBSDatastoreMountGuards(ctx, logger, plan, "/stage", "/", true); err != nil { + t.Fatalf("dryRun error: %v", err) + } + + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + restoreFS = NewFakeFS() + if err := maybeApplyPBSDatastoreMountGuards(ctx, logger, plan, "/stage", "/", false); err != nil { + t.Fatalf("non-real restoreFS error: %v", err) + } + + origGeteuid := mountGuardGeteuid + t.Cleanup(func() { mountGuardGeteuid = origGeteuid }) + mountGuardGeteuid = func() int { return 1 } + restoreFS = origFS + if err := maybeApplyPBSDatastoreMountGuards(ctx, logger, plan, "/stage", "/", false); err != nil { + t.Fatalf("non-root user error: %v", err) + } +} + +func TestMaybeApplyPBSDatastoreMountGuards_StagedDatastoreCfgHandling(t *testing.T) { + logger := newTestLogger() + ctx := context.Background() + plan := &RestorePlan{SystemType: SystemTypePBS, NormalCategories: []Category{{ID: "datastore_pbs"}}} + + origGeteuid := mountGuardGeteuid + 
t.Cleanup(func() { mountGuardGeteuid = origGeteuid }) + mountGuardGeteuid = func() int { return 0 } + + stageRoot := t.TempDir() + + // Missing file => no-op. + if err := maybeApplyPBSDatastoreMountGuards(ctx, logger, plan, stageRoot, "/", false); err != nil { + t.Fatalf("missing staged file error: %v", err) + } + + // Non-file error should propagate. + stagePath := filepath.Join(stageRoot, "etc/proxmox-backup/datastore.cfg") + if err := os.MkdirAll(stagePath, 0o755); err != nil { + t.Fatalf("mkdir staged path: %v", err) + } + if err := maybeApplyPBSDatastoreMountGuards(ctx, logger, plan, stageRoot, "/", false); err == nil || !strings.Contains(err.Error(), "read staged datastore.cfg") { + t.Fatalf("expected read staged error, got %v", err) + } + + // Empty file => no-op. + if err := os.RemoveAll(filepath.Dir(stagePath)); err != nil { + t.Fatalf("cleanup staged dir: %v", err) + } + if err := os.MkdirAll(filepath.Dir(stagePath), 0o755); err != nil { + t.Fatalf("mkdir staged dir: %v", err) + } + if err := os.WriteFile(stagePath, []byte(" \n\t"), 0o600); err != nil { + t.Fatalf("write staged file: %v", err) + } + if err := maybeApplyPBSDatastoreMountGuards(ctx, logger, plan, stageRoot, "/", false); err != nil { + t.Fatalf("empty staged content error: %v", err) + } +} + +func TestMaybeApplyPBSDatastoreMountGuards_NoBlocks(t *testing.T) { + logger := newTestLogger() + ctx := context.Background() + plan := &RestorePlan{SystemType: SystemTypePBS, NormalCategories: []Category{{ID: "datastore_pbs"}}} + + origGeteuid := mountGuardGeteuid + t.Cleanup(func() { mountGuardGeteuid = origGeteuid }) + mountGuardGeteuid = func() int { return 0 } + + stageRoot := t.TempDir() + stagePath := filepath.Join(stageRoot, "etc/proxmox-backup/datastore.cfg") + if err := os.MkdirAll(filepath.Dir(stagePath), 0o755); err != nil { + t.Fatalf("mkdir staged dir: %v", err) + } + if err := os.WriteFile(stagePath, []byte("# comment only\n"), 0o600); err != nil { + t.Fatalf("write datastore.cfg: %v", err) + } + + if err := maybeApplyPBSDatastoreMountGuards(ctx, logger, plan, stageRoot, "/", false); err != nil { + t.Fatalf("maybeApplyPBSDatastoreMountGuards error: %v", err) + } +} + +func TestMaybeApplyPBSDatastoreMountGuards_FstabParseErrorContinues(t *testing.T) { + logger := newTestLogger() + ctx := context.Background() + plan := &RestorePlan{SystemType: SystemTypePBS, NormalCategories: []Category{{ID: "datastore_pbs"}}} + + origGeteuid := mountGuardGeteuid + origFstab := mountGuardFstabMountpointsSet + origMkdirAll := mountGuardMkdirAll + origRootFS := mountGuardIsPathOnRootFilesystem + t.Cleanup(func() { + mountGuardGeteuid = origGeteuid + mountGuardFstabMountpointsSet = origFstab + mountGuardMkdirAll = origMkdirAll + mountGuardIsPathOnRootFilesystem = origRootFS + }) + + mountGuardGeteuid = func() int { return 0 } + mountGuardFstabMountpointsSet = func(string) (map[string]struct{}, error) { + return nil, errors.New("fstab parse failed") + } + mountGuardMkdirAll = func(string, os.FileMode) error { return nil } + mountGuardIsPathOnRootFilesystem = func(path string) (bool, string, error) { + return false, filepath.Clean(path), nil + } + + stageRoot := t.TempDir() + stagePath := filepath.Join(stageRoot, "etc/proxmox-backup/datastore.cfg") + if err := os.MkdirAll(filepath.Dir(stagePath), 0o755); err != nil { + t.Fatalf("mkdir staged dir: %v", err) + } + if err := os.WriteFile(stagePath, []byte("datastore: ds\n path /mnt/test/store\n"), 0o600); err != nil { + t.Fatalf("write datastore.cfg: %v", err) + } + + if err := 
maybeApplyPBSDatastoreMountGuards(ctx, logger, plan, stageRoot, "/", false); err != nil { + t.Fatalf("maybeApplyPBSDatastoreMountGuards error: %v", err) + } +} + +func TestMaybeApplyPBSDatastoreMountGuards_ParseBlocksError(t *testing.T) { + logger := newTestLogger() + ctx := context.Background() + plan := &RestorePlan{SystemType: SystemTypePBS, NormalCategories: []Category{{ID: "datastore_pbs"}}} + + origGeteuid := mountGuardGeteuid + origParse := mountGuardParsePBSDatastoreCfg + t.Cleanup(func() { + mountGuardGeteuid = origGeteuid + mountGuardParsePBSDatastoreCfg = origParse + }) + + mountGuardGeteuid = func() int { return 0 } + wantErr := errors.New("parse blocks failed") + mountGuardParsePBSDatastoreCfg = func(string) ([]pbsDatastoreBlock, error) { + return nil, wantErr + } + + stageRoot := t.TempDir() + stagePath := filepath.Join(stageRoot, "etc/proxmox-backup/datastore.cfg") + if err := os.MkdirAll(filepath.Dir(stagePath), 0o755); err != nil { + t.Fatalf("mkdir staged dir: %v", err) + } + if err := os.WriteFile(stagePath, []byte("datastore: ds\n path /mnt/test/store\n"), 0o600); err != nil { + t.Fatalf("write datastore.cfg: %v", err) + } + + if err := maybeApplyPBSDatastoreMountGuards(ctx, logger, plan, stageRoot, "/", false); !errors.Is(err, wantErr) { + t.Fatalf("expected parse error %v, got %v", wantErr, err) + } +} + +func TestMaybeApplyPBSDatastoreMountGuards_FullFlow(t *testing.T) { + logger := newTestLogger() + ctx := context.Background() + plan := &RestorePlan{SystemType: SystemTypePBS, NormalCategories: []Category{{ID: "datastore_pbs"}}} + + origGeteuid := mountGuardGeteuid + origReadFile := mountGuardReadFile + origMkdirAll := mountGuardMkdirAll + origReadDir := mountGuardReadDir + origMount := mountGuardSysMount + origUnmount := mountGuardSysUnmount + origFstab := mountGuardFstabMountpointsSet + origRootFS := mountGuardIsPathOnRootFilesystem + origCmd := restoreCmd + t.Cleanup(func() { + mountGuardGeteuid = origGeteuid + mountGuardReadFile = origReadFile + mountGuardMkdirAll = origMkdirAll + mountGuardReadDir = origReadDir + mountGuardSysMount = origMount + mountGuardSysUnmount = origUnmount + mountGuardFstabMountpointsSet = origFstab + mountGuardIsPathOnRootFilesystem = origRootFS + restoreCmd = origCmd + }) + + mountGuardGeteuid = func() int { return 0 } + + stageRoot := t.TempDir() + stagePath := filepath.Join(stageRoot, "etc/proxmox-backup/datastore.cfg") + if err := os.MkdirAll(filepath.Dir(stagePath), 0o755); err != nil { + t.Fatalf("mkdir staged dir: %v", err) + } + cfg := strings.Join([]string{ + "datastore: ds-chattrsuccess", + " path /mnt/chattrsuccess/store", + "datastore: ds-invalid", + " path /", + "datastore: ds-nomountstyle", + " path /srv/pbs", + "datastore: ds-storage", + " path /mnt/storage/pbs/ds1/data", + "datastore: ds-media-skip-fstab", + " path /media/USB/PBS", + "datastore: ds-mkdirerr", + " path /mnt/mkdirerr/store", + "datastore: ds-deverr", + " path /mnt/deverr/store", + "datastore: ds-notroot", + " path /mnt/notroot/store", + "datastore: ds-mounted", + " path /mnt/mounted/store", + "datastore: ds-mountok", + " path /mnt/mountok/store", + "datastore: ds-mountok2", + " path /mnt/mountok2/store", + "datastore: ds-chattrfail", + " path /mnt/chattrfail/store", + "datastore: ds-guardok", + " path /mnt/guardok/store", + "datastore: ds-guarddup", + " path /mnt/guardok/other", + "", + }, "\n") + if err := os.WriteFile(stagePath, []byte(cfg), 0o600); err != nil { + t.Fatalf("write datastore.cfg: %v", err) + } + + mountGuardFstabMountpointsSet = 
func(string) (map[string]struct{}, error) { + return map[string]struct{}{ + "/": {}, + "/srv": {}, + "/mnt/storage": {}, + "/mnt/storage/pbs": {}, + "/mnt/mkdirerr": {}, + "/mnt/deverr": {}, + "/mnt/notroot": {}, + "/mnt/mounted": {}, + "/mnt/mountok": {}, + "/mnt/mountok2": {}, + "/mnt/chattrfail": {}, + "/mnt/chattrsuccess": {}, + "/mnt/guardok": {}, + }, nil + } + + mountGuardMkdirAll = func(path string, _ os.FileMode) error { + if filepath.Clean(path) == "/mnt/mkdirerr" { + return errors.New("mkdir denied") + } + return nil + } + + rootCalls := make(map[string]int) + mountGuardIsPathOnRootFilesystem = func(path string) (bool, string, error) { + path = filepath.Clean(path) + rootCalls[path]++ + switch path { + case "/mnt/deverr": + return false, path, errors.New("stat failed") + case "/mnt/notroot": + return false, path, nil + case "/mnt/mountok": + if rootCalls[path] == 1 { + return true, path, nil + } + return false, path, nil + default: + return true, path, nil + } + } + + mountedTargets := map[string]struct{}{ + "/mnt/mounted": {}, + } + mountinfoReads := 0 + mountsReads := 0 + buildMountinfo := func() string { + var b strings.Builder + for mp := range mountedTargets { + b.WriteString(fmt.Sprintf("1 2 3:4 / %s rw - ext4 /dev/sda1 rw\n", mp)) + } + return b.String() + } + buildProcMounts := func() string { + var b strings.Builder + for mp := range mountedTargets { + b.WriteString(fmt.Sprintf("/dev/sda1 %s ext4 rw 0 0\n", mp)) + } + return b.String() + } + mountGuardReadFile = func(path string) ([]byte, error) { + switch path { + case "/proc/self/mountinfo": + mountinfoReads++ + if mountinfoReads == 1 { + return nil, errors.New("mountinfo read failed") + } + return []byte(buildMountinfo()), nil + case "/proc/mounts": + mountsReads++ + if mountsReads == 1 { + return nil, errors.New("mounts read failed") + } + return []byte(buildProcMounts()), nil + default: + return nil, fmt.Errorf("unexpected read: %s", path) + } + } + + mountGuardReadDir = func(path string) ([]os.DirEntry, error) { + if filepath.Clean(path) == "/mnt/guardok" { + return []os.DirEntry{&fakeDirEntry{name: "nonempty"}}, nil + } + return nil, os.ErrNotExist + } + + mountGuardSysMount = func(_, target, _ string, _ uintptr, _ string) error { + switch filepath.Clean(target) { + case "/mnt/chattrfail", "/mnt/chattrsuccess": + return syscall.EPERM + default: + return nil + } + } + mountGuardSysUnmount = func(string, int) error { return nil } + + cmd := &mountGuardCommandRunner{} + cmd.run = func(_ context.Context, name string, args ...string) ([]byte, error) { + switch name { + case "mount": + if len(args) != 1 { + return nil, fmt.Errorf("unexpected mount args: %v", args) + } + target := filepath.Clean(args[0]) + switch target { + case "/mnt/mountok", "/mnt/mountok2": + if target == "/mnt/mountok2" { + mountedTargets[target] = struct{}{} + } + return nil, nil + case "/mnt/chattrfail": + return []byte(" \n\t"), errors.New("mount failed") + default: + return []byte("mount: failed"), errors.New("mount failed") + } + case "chattr": + if len(args) != 2 || args[0] != "+i" { + return nil, fmt.Errorf("unexpected chattr args: %v", args) + } + target := filepath.Clean(args[1]) + if target == "/mnt/chattrfail" { + return nil, errors.New("chattr failed") + } + return nil, nil + default: + return nil, fmt.Errorf("unexpected command: %s", name) + } + } + restoreCmd = cmd + + if err := maybeApplyPBSDatastoreMountGuards(ctx, logger, plan, stageRoot, "/", false); err != nil { + t.Fatalf("maybeApplyPBSDatastoreMountGuards error: %v", err) + } + 
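+ // The stubs above pin one outcome per target: mkdir fails for /mnt/mkdirerr, the root-FS probe errors for /mnt/deverr, /mnt/notroot is already off the root FS, /mnt/mounted is already mounted, the external mount succeeds for /mnt/mountok and /mnt/mountok2, the mount syscall returns EPERM for both chattr targets, chattr +i then fails only for /mnt/chattrfail, and /mnt/guardok is non-empty.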
+ // Ensure the longest fstab mountpoint match wins (/mnt/storage/pbs instead of /mnt/storage). + foundStoragePBS := false + for _, c := range cmd.calls { + if c.name == "mount" && len(c.args) == 1 && filepath.Clean(c.args[0]) == "/mnt/storage/pbs" { + foundStoragePBS = true + break + } + } + if !foundStoragePBS { + t.Fatalf("expected mount attempt for /mnt/storage/pbs, calls=%#v", cmd.calls) + } +} + +func TestMaybeApplyPBSDatastoreMountGuards_MountAttemptTimeout(t *testing.T) { + logger := newTestLogger() + baseCtx, cancel := context.WithDeadline(context.Background(), time.Now().Add(-1*time.Second)) + t.Cleanup(cancel) + plan := &RestorePlan{SystemType: SystemTypePBS, NormalCategories: []Category{{ID: "datastore_pbs"}}} + + origGeteuid := mountGuardGeteuid + origReadFile := mountGuardReadFile + origMkdirAll := mountGuardMkdirAll + origMount := mountGuardSysMount + origUnmount := mountGuardSysUnmount + origFstab := mountGuardFstabMountpointsSet + origRootFS := mountGuardIsPathOnRootFilesystem + origCmd := restoreCmd + t.Cleanup(func() { + mountGuardGeteuid = origGeteuid + mountGuardReadFile = origReadFile + mountGuardMkdirAll = origMkdirAll + mountGuardSysMount = origMount + mountGuardSysUnmount = origUnmount + mountGuardFstabMountpointsSet = origFstab + mountGuardIsPathOnRootFilesystem = origRootFS + restoreCmd = origCmd + }) + + mountGuardGeteuid = func() int { return 0 } + mountGuardReadFile = func(string) ([]byte, error) { return []byte(""), nil } + mountGuardMkdirAll = func(string, os.FileMode) error { return nil } + mountGuardSysMount = func(string, string, string, uintptr, string) error { return nil } + mountGuardSysUnmount = func(string, int) error { return nil } + mountGuardIsPathOnRootFilesystem = func(path string) (bool, string, error) { return true, filepath.Clean(path), nil } + mountGuardFstabMountpointsSet = func(string) (map[string]struct{}, error) { + return map[string]struct{}{"/mnt/timeout": {}}, nil + } + + stageRoot := t.TempDir() + stagePath := filepath.Join(stageRoot, "etc/proxmox-backup/datastore.cfg") + if err := os.MkdirAll(filepath.Dir(stagePath), 0o755); err != nil { + t.Fatalf("mkdir staged dir: %v", err) + } + if err := os.WriteFile(stagePath, []byte("datastore: ds\n path /mnt/timeout/store\n"), 0o600); err != nil { + t.Fatalf("write datastore.cfg: %v", err) + } + + restoreCmd = &mountGuardCommandRunner{ + run: func(ctx context.Context, name string, args ...string) ([]byte, error) { + if name == "mount" { + return nil, ctx.Err() + } + if name == "chattr" { + return nil, ctx.Err() + } + return nil, fmt.Errorf("unexpected command: %s", name) + }, + } + + if err := maybeApplyPBSDatastoreMountGuards(baseCtx, logger, plan, stageRoot, "/", false); err != nil { + t.Fatalf("maybeApplyPBSDatastoreMountGuards error: %v", err) + } +} diff --git a/internal/orchestrator/nic_mapping.go b/internal/orchestrator/nic_mapping.go index f4fa1d1..f2a0372c 100644 --- a/internal/orchestrator/nic_mapping.go +++ b/internal/orchestrator/nic_mapping.go @@ -20,6 +20,7 @@ import ( const maxArchiveInventoryBytes = 10 << 20 // 10 MiB var nicRepairSequence uint64 +var sysClassNetPath = "/sys/class/net" type archivedNetworkInventory struct { GeneratedAt string `json:"generated_at,omitempty"` @@ -371,7 +372,7 @@ func readArchiveEntry(ctx context.Context, archivePath string, candidates []stri } func collectCurrentNetworkInventory(ctx context.Context) (*archivedNetworkInventory, error) { - sysNet := "/sys/class/net" + sysNet := sysClassNetPath entries, err := os.ReadDir(sysNet) if err != nil { 
return nil, err diff --git a/internal/orchestrator/nic_mapping_additional_test.go b/internal/orchestrator/nic_mapping_additional_test.go new file mode 100644 index 0000000..c4cb681 --- /dev/null +++ b/internal/orchestrator/nic_mapping_additional_test.go @@ -0,0 +1,1262 @@ +package orchestrator + +import ( + "archive/tar" + "bytes" + "context" + "encoding/json" + "errors" + "os" + "path/filepath" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/tis24dev/proxsave/internal/logging" + "github.com/tis24dev/proxsave/internal/types" +) + +type tarEntry struct { + Name string + Typeflag byte + Mode int64 + Data []byte +} + +type mkdirAllFailFS struct { + FS + failPath string + err error +} + +func (f mkdirAllFailFS) MkdirAll(path string, perm os.FileMode) error { + if filepath.Clean(path) == filepath.Clean(f.failPath) { + return f.err + } + return f.FS.MkdirAll(path, perm) +} + +func writeTarToFakeFS(t *testing.T, fs *FakeFS, archivePath string, entries []tarEntry) { + t.Helper() + + var buf bytes.Buffer + tw := tar.NewWriter(&buf) + for _, entry := range entries { + hdr := &tar.Header{ + Name: entry.Name, + Typeflag: entry.Typeflag, + Mode: entry.Mode, + Size: int64(len(entry.Data)), + } + if entry.Typeflag == tar.TypeDir { + hdr.Size = 0 + } + if err := tw.WriteHeader(hdr); err != nil { + t.Fatalf("WriteHeader %s: %v", entry.Name, err) + } + if hdr.Size > 0 { + if _, err := tw.Write(entry.Data); err != nil { + t.Fatalf("Write %s: %v", entry.Name, err) + } + } + } + if err := tw.Close(); err != nil { + t.Fatalf("tar close: %v", err) + } + if err := fs.WriteFile(archivePath, buf.Bytes(), 0o644); err != nil { + t.Fatalf("write archive: %v", err) + } +} + +func TestNICMappingResult_RenameMapAndDetails(t *testing.T) { + r := nicMappingResult{ + Entries: []nicMappingEntry{ + {OldName: "eno1", NewName: "enp3s0"}, + {OldName: "", NewName: "enp4s0"}, + {OldName: "ens2", NewName: ""}, + {OldName: "eno1", NewName: "enp3s1"}, + }, + } + m := r.RenameMap() + if len(m) != 1 || m["eno1"] != "enp3s1" { + t.Fatalf("RenameMap=%v; want {eno1:enp3s1}", m) + } + + if got := (nicMappingResult{}).Details(); got != "NIC mapping: none" { + t.Fatalf("empty Details=%q", got) + } + + details := nicMappingResult{ + Entries: []nicMappingEntry{ + {OldName: "b", NewName: "B", Method: nicMatchMAC, Identifier: "m2"}, + {OldName: "a", NewName: "A", Method: nicMatchPermanentMAC, Identifier: "m1"}, + }, + }.Details() + want := strings.Join([]string{ + "NIC mapping (backup -> current):", + "- a -> A (permanent_mac=m1)", + "- b -> B (mac=m2)", + }, "\n") + if details != want { + t.Fatalf("Details=%q; want %q", details, want) + } +} + +func TestNICNameConflict_Details(t *testing.T) { + c := nicNameConflict{ + Mapping: nicMappingEntry{ + OldName: "eno1", + NewName: "eth0", + Method: nicMatchMAC, + Identifier: "aa:bb", + }, + Existing: archivedNetworkInterface{ + Name: "eno1", + PermanentMAC: "AA:BB:CC:DD:EE:FF", + MAC: "mac:11:22:33:44:55:66 ", + PCIPath: "/pci/0000:00:1f.6", + }, + } + got := c.Details() + if !strings.Contains(got, "permMAC=aa:bb:cc:dd:ee:ff") || !strings.Contains(got, "mac=11:22:33:44:55:66") || !strings.Contains(got, "pci=/pci/0000:00:1f.6") { + t.Fatalf("Details=%q; want identifiers included", got) + } + if !strings.Contains(got, "but current eno1 exists") { + t.Fatalf("Details=%q; want conflict message", got) + } + + none := nicNameConflict{ + Mapping: nicMappingEntry{OldName: "eno1", NewName: "eth0", Method: nicMatchPCIPath, Identifier: "pci0"}, + Existing: archivedNetworkInterface{ + Name: 
"eno1", + }, + }.Details() + if !strings.Contains(none, "no identifiers") { + t.Fatalf("Details=%q; want no identifiers", none) + } +} + +func TestNICRepairPlan_HasWork(t *testing.T) { + if (nicRepairPlan{}).HasWork() { + t.Fatalf("expected HasWork=false") + } + if !(nicRepairPlan{SafeMappings: []nicMappingEntry{{OldName: "a", NewName: "b"}}}.HasWork()) { + t.Fatalf("expected HasWork=true with safe mappings") + } + if !(nicRepairPlan{Conflicts: []nicNameConflict{{Mapping: nicMappingEntry{OldName: "a", NewName: "b"}}}}.HasWork()) { + t.Fatalf("expected HasWork=true with conflicts") + } +} + +func TestNICRepairResult_SummaryAndDetails(t *testing.T) { + r := nicRepairResult{SkippedReason: "test"} + if got := r.Summary(); got != "NIC name repair skipped: test" { + t.Fatalf("Summary=%q", got) + } + if r.Applied() { + t.Fatalf("Applied=true; want false") + } + + r = nicRepairResult{} + if got := r.Summary(); got != "NIC name repair: no changes needed" { + t.Fatalf("Summary=%q", got) + } + + r = nicRepairResult{ + ChangedFiles: []string{"/etc/network/interfaces"}, + BackupDir: "/tmp/proxsave/nic_repair_test", + AppliedNICMap: []nicMappingEntry{ + {OldName: "eno1", NewName: "eth0", Method: nicMatchMAC, Identifier: "aa"}, + }, + } + if got := r.Summary(); got != "NIC name repair applied: 1 file(s) updated" { + t.Fatalf("Summary=%q", got) + } + if !r.Applied() { + t.Fatalf("Applied=false; want true") + } + details := r.Details() + if !strings.Contains(details, "Backup of pre-repair files: /tmp/proxsave/nic_repair_test") || !strings.Contains(details, "Updated files:\n- /etc/network/interfaces") { + t.Fatalf("Details=%q; want backup and updated files", details) + } + if !strings.Contains(details, "NIC mapping (backup -> current):") || !strings.Contains(details, "- eno1 -> eth0 (mac=aa)") { + t.Fatalf("Details=%q; want mapping details", details) + } +} + +func TestReadArchiveEntry_ErrorsAndNotFound(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + archivePath := "/backup.tar" + writeTarToFakeFS(t, fakeFS, archivePath, []tarEntry{ + {Name: "./other.txt", Typeflag: tar.TypeReg, Mode: 0o644, Data: []byte("ok")}, + }) + + _, _, err := readArchiveEntry(context.Background(), archivePath, []string{"./missing.txt"}, 16) + if !errors.Is(err, os.ErrNotExist) { + t.Fatalf("err=%v; want os.ErrNotExist", err) + } + + _, _, err = readArchiveEntry(context.Background(), "/does-not-exist.tar", []string{"./x"}, 16) + if !errors.Is(err, os.ErrNotExist) { + t.Fatalf("err=%v; want open os.ErrNotExist", err) + } + + cctx, cancel := context.WithCancel(context.Background()) + cancel() + _, _, err = readArchiveEntry(cctx, archivePath, []string{"./other.txt"}, 16) + if !errors.Is(err, context.Canceled) { + t.Fatalf("err=%v; want context.Canceled", err) + } + + if err := fakeFS.WriteFile("/backup.zip", []byte("not a tar"), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + _, _, err = readArchiveEntry(context.Background(), "/backup.zip", []string{"./other.txt"}, 16) + if err == nil || !strings.Contains(err.Error(), "unsupported archive format") { + t.Fatalf("err=%v; want unsupported archive format", err) + } + + writeTarToFakeFS(t, fakeFS, "/nonregular.tar", []tarEntry{ + {Name: "./entry", Typeflag: tar.TypeDir, Mode: 0o755}, + }) + _, _, err = readArchiveEntry(context.Background(), "/nonregular.tar", []string{"./entry"}, 16) + if err == nil || !strings.Contains(err.Error(), "not a regular 
file") { + t.Fatalf("err=%v; want not a regular file", err) + } + + writeTarToFakeFS(t, fakeFS, "/toolarge.tar", []tarEntry{ + {Name: "./big", Typeflag: tar.TypeReg, Mode: 0o644, Data: []byte("0123456789")}, + }) + _, _, err = readArchiveEntry(context.Background(), "/toolarge.tar", []string{"./big"}, 4) + if err == nil || !strings.Contains(err.Error(), "too large") { + t.Fatalf("err=%v; want too large", err) + } +} + +func TestLoadBackupNetworkInventoryFromArchive_SuccessAndBadJSON(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + type invJSON struct { + Interfaces []archivedNetworkInterface `json:"interfaces"` + } + payload, err := json.Marshal(invJSON{ + Interfaces: []archivedNetworkInterface{{Name: "eth0", MAC: "aa:bb"}}, + }) + if err != nil { + t.Fatalf("marshal: %v", err) + } + writeTarToFakeFS(t, fakeFS, "/inv.tar", []tarEntry{ + {Name: "./commands/network_inventory.json", Typeflag: tar.TypeReg, Mode: 0o644, Data: payload}, + }) + + inv, used, err := loadBackupNetworkInventoryFromArchive(context.Background(), "/inv.tar") + if err != nil { + t.Fatalf("load: %v", err) + } + if used != "./commands/network_inventory.json" { + t.Fatalf("used=%q", used) + } + if inv == nil || len(inv.Interfaces) != 1 || inv.Interfaces[0].Name != "eth0" { + t.Fatalf("inv=%+v", inv) + } + + writeTarToFakeFS(t, fakeFS, "/bad.tar", []tarEntry{ + {Name: "./commands/network_inventory.json", Typeflag: tar.TypeReg, Mode: 0o644, Data: []byte("{")}, + }) + _, _, err = loadBackupNetworkInventoryFromArchive(context.Background(), "/bad.tar") + if err == nil || !strings.Contains(err.Error(), "parse network inventory json") { + t.Fatalf("err=%v; want parse network inventory json", err) + } +} + +func TestParseAndReadPermanentMAC(t *testing.T) { + output := "some header\nPermanent Address: AA:BB:CC:DD:EE:FF \n" + if got := parsePermanentMAC(output); got != "aa:bb:cc:dd:ee:ff" { + t.Fatalf("parsePermanentMAC=%q", got) + } + if got := parsePermanentMAC("nope"); got != "" { + t.Fatalf("parsePermanentMAC=%q; want empty", got) + } + + origCmd := restoreCmd + t.Cleanup(func() { restoreCmd = origCmd }) + cmd := &FakeCommandRunner{ + Outputs: map[string][]byte{ + "ethtool -P eth0": []byte(output), + }, + } + restoreCmd = cmd + + got, err := readPermanentMAC(context.Background(), "eth0") + if err != nil { + t.Fatalf("readPermanentMAC: %v", err) + } + if got != "aa:bb:cc:dd:ee:ff" { + t.Fatalf("readPermanentMAC=%q", got) + } + + cmd.Errors = map[string]error{"ethtool -P eth0": errors.New("boom")} + _, err = readPermanentMAC(context.Background(), "eth0") + if err == nil { + t.Fatalf("expected error") + } +} + +func TestReadUdevProperties(t *testing.T) { + origCmd := restoreCmd + t.Cleanup(func() { restoreCmd = origCmd }) + + cmd := &FakeCommandRunner{ + Outputs: map[string][]byte{ + "udevadm info -q property -p /sys/class/net/eth0": []byte(strings.Join([]string{ + "ID_SERIAL=abc", + "ID_PATH= pci-0000:00:1f.6 ", + "BADLINE", + "FOO=", + "=bar", + "", + }, "\n")), + }, + } + restoreCmd = cmd + + props, err := readUdevProperties(context.Background(), "/sys/class/net/eth0") + if err != nil { + t.Fatalf("readUdevProperties: %v", err) + } + if props["ID_SERIAL"] != "abc" || props["ID_PATH"] != "pci-0000:00:1f.6" { + t.Fatalf("props=%v", props) + } + if _, ok := props["FOO"]; ok { + t.Fatalf("expected empty-value key to be skipped: %v", props) + } + + cmd.Errors = map[string]error{"udevadm info -q 
property -p /sys/class/net/eth0": errors.New("boom")} + _, err = readUdevProperties(context.Background(), "/sys/class/net/eth0") + if err == nil { + t.Fatalf("expected error") + } +} + +func TestReadTrimmedLine(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "line") + if err := os.WriteFile(path, []byte(" HELLO WORLD \n"), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + + if got := readTrimmedLine(path, 0); got != "HELLO WORLD" { + t.Fatalf("readTrimmedLine=%q", got) + } + if got := readTrimmedLine(path, 5); got != "HELLO" { + t.Fatalf("readTrimmedLine=%q", got) + } + if got := readTrimmedLine(filepath.Join(dir, "missing"), 5); got != "" { + t.Fatalf("readTrimmedLine=%q; want empty", got) + } +} + +func TestCollectCurrentNetworkInventory_WithFakeSysfs(t *testing.T) { + origSys := sysClassNetPath + t.Cleanup(func() { sysClassNetPath = origSys }) + t.Setenv("PATH", "") + + root := t.TempDir() + netDir := filepath.Join(root, "sys/class/net") + if err := os.MkdirAll(netDir, 0o755); err != nil { + t.Fatalf("mkdir netDir: %v", err) + } + sysClassNetPath = netDir + if err := os.MkdirAll(filepath.Join(netDir, " "), 0o755); err != nil { + t.Fatalf("mkdir blank-name entry: %v", err) + } + + writePhysical := func(name, mac, pciDevice, driver string) { + t.Helper() + ifaceDir := filepath.Join(netDir, name) + if err := os.MkdirAll(ifaceDir, 0o755); err != nil { + t.Fatalf("mkdir ifaceDir: %v", err) + } + if err := os.WriteFile(filepath.Join(ifaceDir, "address"), []byte(mac+"\n"), 0o644); err != nil { + t.Fatalf("write address: %v", err) + } + + devDir := filepath.Join(root, "devices", pciDevice) + driverDir := filepath.Join(root, "drivers", driver) + if err := os.MkdirAll(devDir, 0o755); err != nil { + t.Fatalf("mkdir devDir: %v", err) + } + if err := os.MkdirAll(driverDir, 0o755); err != nil { + t.Fatalf("mkdir driverDir: %v", err) + } + if err := os.Symlink(devDir, filepath.Join(ifaceDir, "device")); err != nil { + t.Fatalf("symlink device: %v", err) + } + if err := os.Symlink(driverDir, filepath.Join(devDir, "driver")); err != nil { + t.Fatalf("symlink driver: %v", err) + } + } + writePhysical("eth0", "MAC:AA:BB:CC:DD:EE:01", "pci0000:00/0000:00:1f.6", "e1000") + writePhysical("eno1", "aa:bb:cc:dd:ee:02", "pci0000:00/0000:00:1c.0", "igb") + + ifaceLo := filepath.Join(netDir, "lo") + if err := os.MkdirAll(ifaceLo, 0o755); err != nil { + t.Fatalf("mkdir lo: %v", err) + } + if err := os.WriteFile(filepath.Join(ifaceLo, "address"), []byte("00:00:00:00:00:00\n"), 0o644); err != nil { + t.Fatalf("write lo address: %v", err) + } + + virtualTarget := filepath.Join(root, "devices/virtual/net/vmbr0") + if err := os.MkdirAll(virtualTarget, 0o755); err != nil { + t.Fatalf("mkdir virtualTarget: %v", err) + } + if err := os.WriteFile(filepath.Join(virtualTarget, "address"), []byte("aa:aa:aa:aa:aa:aa\n"), 0o644); err != nil { + t.Fatalf("write vmbr0 address: %v", err) + } + if err := os.Symlink(virtualTarget, filepath.Join(netDir, "vmbr0")); err != nil { + t.Fatalf("symlink vmbr0: %v", err) + } + + inv, err := collectCurrentNetworkInventory(context.Background()) + if err != nil { + t.Fatalf("collect: %v", err) + } + if inv == nil || len(inv.Interfaces) != 4 { + t.Fatalf("inv=%+v", inv) + } + if inv.Interfaces[0].Name != "eno1" || inv.Interfaces[1].Name != "eth0" || inv.Interfaces[2].Name != "lo" || inv.Interfaces[3].Name != "vmbr0" { + t.Fatalf("sorted names=%v", []string{inv.Interfaces[0].Name, inv.Interfaces[1].Name, inv.Interfaces[2].Name, inv.Interfaces[3].Name}) + } + var gotEth0, 
gotVmbr0 archivedNetworkInterface + for _, iface := range inv.Interfaces { + switch iface.Name { + case "eth0": + gotEth0 = iface + case "vmbr0": + gotVmbr0 = iface + } + } + if gotEth0.MAC != "aa:bb:cc:dd:ee:01" { + t.Fatalf("eth0 MAC=%q", gotEth0.MAC) + } + if gotEth0.Driver != "e1000" { + t.Fatalf("eth0 Driver=%q", gotEth0.Driver) + } + if !strings.Contains(gotEth0.PCIPath, "devices/pci0000:00/0000:00:1f.6") { + t.Fatalf("eth0 PCIPath=%q", gotEth0.PCIPath) + } + if !gotVmbr0.IsVirtual { + t.Fatalf("vmbr0 IsVirtual=false") + } +} + +func TestPlanAndApplyNICNameRepair_WithFakeInventory(t *testing.T) { + origFS := restoreFS + origTime := restoreTime + origSys := sysClassNetPath + origSeq := atomic.LoadUint64(&nicRepairSequence) + t.Cleanup(func() { + restoreFS = origFS + restoreTime = origTime + sysClassNetPath = origSys + atomic.StoreUint64(&nicRepairSequence, origSeq) + }) + t.Setenv("PATH", "") + + // Fake current inventory via fake sysfs. + root := t.TempDir() + netDir := filepath.Join(root, "sys/class/net") + if err := os.MkdirAll(netDir, 0o755); err != nil { + t.Fatalf("mkdir netDir: %v", err) + } + sysClassNetPath = netDir + mustWriteFile := func(path, content string) { + t.Helper() + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + if err := os.WriteFile(path, []byte(content), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + } + mustWriteFile(filepath.Join(netDir, "eth0/address"), "aa:bb:cc:dd:ee:01\n") + mustWriteFile(filepath.Join(netDir, "eno1/address"), "aa:bb:cc:dd:ee:02\n") + mustWriteFile(filepath.Join(netDir, "lo/address"), "00:00:00:00:00:00\n") + + // Fake backup archive. + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + restoreTime = &FakeTime{Current: time.Date(2025, 1, 1, 1, 2, 3, 0, time.UTC)} + + backupInv := archivedNetworkInventory{ + GeneratedAt: time.Now().Format(time.RFC3339), + Hostname: "backup-host", + Interfaces: []archivedNetworkInterface{ + {Name: "eno1", MAC: "aa:bb:cc:dd:ee:01"}, // maps to current eth0 -> conflict (current eno1 exists) + {Name: "ens20", MAC: "aa:bb:cc:dd:ee:02"}, // maps to current eno1 -> safe + }, + } + payload, err := json.Marshal(backupInv) + if err != nil { + t.Fatalf("marshal: %v", err) + } + writeTarToFakeFS(t, fakeFS, "/backup.tar", []tarEntry{ + {Name: "./commands/network_inventory.json", Typeflag: tar.TypeReg, Mode: 0o644, Data: payload}, + }) + + plan, err := planNICNameRepair(context.Background(), "/backup.tar") + if err != nil { + t.Fatalf("plan: %v", err) + } + if plan == nil { + t.Fatalf("plan=nil") + } + if plan.SkippedReason != "" { + t.Fatalf("SkippedReason=%q", plan.SkippedReason) + } + if plan.Mapping.BackupSourcePath != "./commands/network_inventory.json" { + t.Fatalf("BackupSourcePath=%q", plan.Mapping.BackupSourcePath) + } + if len(plan.SafeMappings) != 1 || plan.SafeMappings[0].OldName != "ens20" || plan.SafeMappings[0].NewName != "eno1" { + t.Fatalf("SafeMappings=%+v", plan.SafeMappings) + } + if len(plan.Conflicts) != 1 || plan.Conflicts[0].Mapping.OldName != "eno1" || plan.Conflicts[0].Mapping.NewName != "eth0" { + t.Fatalf("Conflicts=%+v", plan.Conflicts) + } + + logger := logging.New(types.LogLevelDebug, false) + logger.SetOutput(bytes.NewBuffer(nil)) + + // Prepare config to exercise both safe mapping and conflict mapping. 
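+ // From the plan asserted above: ens20 -> eno1 is the safe mapping and eno1 -> eth0 is the conflicting one, so both old names appear in the staged interfaces file below.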
+ if err := fakeFS.WriteFile("/etc/network/interfaces", []byte(strings.Join([]string{ + "auto ens20", + "iface ens20 inet manual", + "auto eno1", + "iface eno1 inet manual", + "", + }, "\n")), 0o644); err != nil { + t.Fatalf("write interfaces: %v", err) + } + + // includeConflicts=false: applies only safe mapping (ens20 -> eno1). + res, err := applyNICNameRepair(logger, plan, false) + if err != nil { + t.Fatalf("apply: %v", err) + } + if res == nil || res.SkippedReason != "" { + t.Fatalf("result=%+v", res) + } + if len(res.ChangedFiles) != 1 || res.ChangedFiles[0] != "/etc/network/interfaces" { + t.Fatalf("ChangedFiles=%v", res.ChangedFiles) + } + data, err := fakeFS.ReadFile("/etc/network/interfaces") + if err != nil { + t.Fatalf("read: %v", err) + } + if strings.Contains(string(data), "ens20") { + t.Fatalf("expected ens20 to be replaced:\n%s", string(data)) + } + if !strings.Contains(string(data), "auto eno1") { + t.Fatalf("expected eno1 to remain:\n%s", string(data)) + } + + // includeConflicts=true: also applies eno1 -> eth0. + if err := fakeFS.WriteFile("/etc/network/interfaces", []byte(strings.Join([]string{ + "auto ens20", + "iface ens20 inet manual", + "auto eno1", + "iface eno1 inet manual", + "", + }, "\n")), 0o644); err != nil { + t.Fatalf("rewrite interfaces: %v", err) + } + res, err = applyNICNameRepair(logger, plan, true) + if err != nil { + t.Fatalf("apply conflicts: %v", err) + } + data, err = fakeFS.ReadFile("/etc/network/interfaces") + if err != nil { + t.Fatalf("read: %v", err) + } + if strings.Contains(string(data), "ens20") || strings.Contains(string(data), "auto eno1\n") { + t.Fatalf("expected ens20 and eno1 to be replaced:\n%s", string(data)) + } + if !strings.Contains(string(data), "auto eth0") { + t.Fatalf("expected eth0:\n%s", string(data)) + } +} + +func TestPlanNICNameRepair_SkipAndErrorBranches(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + plan, err := planNICNameRepair(context.Background(), " ") + if err != nil || plan == nil || plan.SkippedReason == "" { + t.Fatalf("plan=%+v err=%v", plan, err) + } + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + writeTarToFakeFS(t, fakeFS, "/missing_inv.tar", []tarEntry{ + {Name: "./unrelated", Typeflag: tar.TypeReg, Mode: 0o644, Data: []byte("x")}, + }) + plan, err = planNICNameRepair(context.Background(), "/missing_inv.tar") + if err != nil || plan == nil || !strings.Contains(plan.SkippedReason, "backup does not include network inventory") { + t.Fatalf("plan=%+v err=%v", plan, err) + } + + if err := fakeFS.WriteFile("/bad.zip", []byte("x"), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + _, err = planNICNameRepair(context.Background(), "/bad.zip") + if err == nil || !strings.Contains(err.Error(), "unsupported archive format") { + t.Fatalf("err=%v; want unsupported archive format", err) + } +} + +func TestApplyNICNameRepair_SkipBranches(t *testing.T) { + logger := logging.New(types.LogLevelDebug, false) + logger.SetOutput(bytes.NewBuffer(nil)) + + res, err := applyNICNameRepair(logger, nil, false) + if err != nil || res == nil || res.SkippedReason == "" { + t.Fatalf("res=%+v err=%v", res, err) + } + + res, err = applyNICNameRepair(logger, &nicRepairPlan{SkippedReason: "nope"}, false) + if err != nil || res == nil || !strings.Contains(res.SkippedReason, "nope") { + t.Fatalf("res=%+v err=%v", res, err) + } + + res, err = applyNICNameRepair(logger, &nicRepairPlan{ + Conflicts: []nicNameConflict{{Mapping: 
nicMappingEntry{OldName: "a", NewName: "b"}}}, + }, false) + if err != nil || res == nil || !strings.Contains(res.SkippedReason, "conflicting NIC mappings") { + t.Fatalf("res=%+v err=%v", res, err) + } + + res, err = applyNICNameRepair(logger, &nicRepairPlan{ + SafeMappings: []nicMappingEntry{ + {OldName: "eno1", NewName: "eno1"}, + }, + Conflicts: []nicNameConflict{ + {Mapping: nicMappingEntry{OldName: "eno2", NewName: "eth0"}}, + }, + }, false) + if err != nil || res == nil || res.SkippedReason != "conflicting NIC mappings detected; skipped by user" { + t.Fatalf("res=%+v err=%v", res, err) + } +} + +func TestMappingHelpersAndEdgeCases(t *testing.T) { + if !computeNICMapping(nil, nil).IsEmpty() { + t.Fatalf("expected empty mapping for nil inventories") + } + + if !hasStableUdevIdentifiers(map[string]string{"ID_SERIAL": " abc "}) { + t.Fatalf("expected stable udev identifiers") + } + if hasStableUdevIdentifiers(map[string]string{"ID_SERIAL": " "}) { + t.Fatalf("expected false for blank udev values") + } + + if shouldAddMapping("a", "a", map[string]struct{}{}) { + t.Fatalf("expected false for old==new") + } + if !shouldAddMapping("a", "b", nil) { + t.Fatalf("expected true when usedCurrent nil") + } + if shouldAddMapping("a", "b", map[string]struct{}{"b": {}}) { + t.Fatalf("expected false when usedCurrent already contains new") + } + + if isCandidatePhysicalNIC(archivedNetworkInterface{Name: "lo", MAC: "aa"}) { + t.Fatalf("expected lo to be non-candidate") + } + if isCandidatePhysicalNIC(archivedNetworkInterface{Name: "eth0", IsVirtual: true, MAC: "aa"}) { + t.Fatalf("expected virtual to be non-candidate") + } + if isCandidatePhysicalNIC(archivedNetworkInterface{Name: "eth0"}) { + t.Fatalf("expected no-identifiers to be non-candidate") + } + if !isCandidatePhysicalNIC(archivedNetworkInterface{Name: "eth0", UdevProps: map[string]string{"ID_PATH": "pci-1"}}) { + t.Fatalf("expected udev identifiers to make candidate") + } + + if out, changed := applyInterfaceRenameMap("", map[string]string{"a": "b"}); out != "" || changed { + t.Fatalf("applyInterfaceRenameMap unexpected result: out=%q changed=%v", out, changed) + } + if out, changed := applyInterfaceRenameMap("auto a\n", map[string]string{}); out != "auto a\n" || changed { + t.Fatalf("applyInterfaceRenameMap unexpected result: out=%q changed=%v", out, changed) + } + + if out, changed := replaceInterfaceToken("", "a", "b"); out != "" || changed { + t.Fatalf("replaceInterfaceToken unexpected: out=%q changed=%v", out, changed) + } + if _, changed := replaceInterfaceToken("auto a\n", "a", "a"); changed { + t.Fatalf("replaceInterfaceToken should not change when old==new") + } + + cases := map[byte]bool{ + 'a': true, + 'Z': true, + '0': true, + '_': true, + '-': true, + '.': false, + ' ': false, + } + for ch, want := range cases { + if got := isIfaceNameChar(ch); got != want { + t.Fatalf("isIfaceNameChar(%q)=%v want %v", string(ch), got, want) + } + } + + backup := &archivedNetworkInventory{ + Interfaces: []archivedNetworkInterface{ + {Name: "eno1", MAC: "aa:aa", PCIPath: "/pci/1"}, + }, + } + current := &archivedNetworkInventory{ + Interfaces: []archivedNetworkInterface{ + {Name: "eth0", MAC: "aa:aa", PCIPath: "/pci/1"}, + {Name: "eth1", MAC: "aa:aa", PCIPath: "/pci/2"}, + }, + } + got := computeNICMapping(backup, current) + if got.IsEmpty() || got.Entries[0].Method != nicMatchPCIPath { + t.Fatalf("got=%+v; want pci_path match due to MAC dupes", got) + } + + backup = &archivedNetworkInventory{ + Interfaces: []archivedNetworkInterface{ + {Name: 
"eno1", MAC: "bb:bb"}, + {Name: "ens2", MAC: "bb:bb"}, + }, + } + current = &archivedNetworkInventory{ + Interfaces: []archivedNetworkInterface{ + {Name: "eth0", MAC: "bb:bb"}, + }, + } + got = computeNICMapping(backup, current) + if len(got.Entries) != 1 { + t.Fatalf("Entries=%+v; want 1 due to usedCurrent", got.Entries) + } +} + +func TestPlanNICNameRepair_NoMappingAndCurrentInventoryError(t *testing.T) { + origFS := restoreFS + origSys := sysClassNetPath + t.Cleanup(func() { + restoreFS = origFS + sysClassNetPath = origSys + }) + t.Setenv("PATH", "") + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + t.Run("no mapping", func(t *testing.T) { + root := t.TempDir() + netDir := filepath.Join(root, "sys/class/net") + if err := os.MkdirAll(netDir, 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + sysClassNetPath = netDir + if err := os.MkdirAll(filepath.Join(netDir, "eth0"), 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + if err := os.WriteFile(filepath.Join(netDir, "eth0/address"), []byte("aa:bb:cc:dd:ee:01\n"), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + + backupInv := archivedNetworkInventory{ + Interfaces: []archivedNetworkInterface{ + {Name: "eth0", MAC: "aa:bb:cc:dd:ee:01"}, + }, + } + payload, err := json.Marshal(backupInv) + if err != nil { + t.Fatalf("marshal: %v", err) + } + writeTarToFakeFS(t, fakeFS, "/nomap.tar", []tarEntry{ + {Name: "./commands/network_inventory.json", Typeflag: tar.TypeReg, Mode: 0o644, Data: payload}, + }) + + plan, err := planNICNameRepair(context.Background(), "/nomap.tar") + if err != nil { + t.Fatalf("plan: %v", err) + } + if plan == nil || !strings.Contains(plan.SkippedReason, "no NIC rename mapping found") { + t.Fatalf("plan=%+v", plan) + } + }) + + t.Run("current inventory error", func(t *testing.T) { + sysClassNetPath = filepath.Join(t.TempDir(), "does-not-exist") + + backupInv := archivedNetworkInventory{ + Interfaces: []archivedNetworkInterface{ + {Name: "eno1", MAC: "aa:bb"}, + }, + } + payload, err := json.Marshal(backupInv) + if err != nil { + t.Fatalf("marshal: %v", err) + } + writeTarToFakeFS(t, fakeFS, "/inv.tar", []tarEntry{ + {Name: "./commands/network_inventory.json", Typeflag: tar.TypeReg, Mode: 0o644, Data: payload}, + }) + + _, err = planNICNameRepair(context.Background(), "/inv.tar") + if err == nil || !strings.Contains(err.Error(), "collect current network inventory") { + t.Fatalf("err=%v; want collect current network inventory", err) + } + }) +} + +func TestApplyNICNameRepair_NoMatchesAndNoRenamesSelected(t *testing.T) { + origFS := restoreFS + origTime := restoreTime + t.Cleanup(func() { + restoreFS = origFS + restoreTime = origTime + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + restoreTime = &FakeTime{Current: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)} + + if err := fakeFS.WriteFile("/etc/network/interfaces", []byte("auto lo\niface lo inet loopback\n"), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + + res, err := applyNICNameRepair(nil, &nicRepairPlan{ + SafeMappings: []nicMappingEntry{ + {OldName: "eno1", NewName: "eth0", Method: nicMatchMAC, Identifier: "aa"}, + }, + }, false) + if err != nil { + t.Fatalf("apply: %v", err) + } + if res == nil || !strings.Contains(res.SkippedReason, "no matching interface names found") { + t.Fatalf("res=%+v", res) + } + + res, err = applyNICNameRepair(nil, &nicRepairPlan{ + SafeMappings: []nicMappingEntry{ + {OldName: "eno1", NewName: "eno1"}, + }, + 
}, false) + if err != nil { + t.Fatalf("apply: %v", err) + } + if res == nil || res.SkippedReason != "no NIC renames selected" { + t.Fatalf("res=%+v", res) + } +} + +func TestApplyNICNameRepair_PropagatesRewriteErrors(t *testing.T) { + origFS := restoreFS + origTime := restoreTime + t.Cleanup(func() { + restoreFS = origFS + restoreTime = origTime + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + restoreTime = &FakeTime{Current: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)} + + if err := fakeFS.WriteFile("/etc/network/interfaces", []byte("auto eno1\n"), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + fakeFS.MkdirAllErr = errors.New("disk full") + + res, err := applyNICNameRepair(nil, &nicRepairPlan{ + SafeMappings: []nicMappingEntry{{OldName: "eno1", NewName: "eth0"}}, + }, false) + if err == nil || res != nil { + t.Fatalf("res=%+v err=%v; want nil result and error", res, err) + } +} + +func TestReadArchiveEntry_CorruptTar(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + if err := fakeFS.WriteFile("/corrupt.tar", []byte("not a tar"), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + _, _, err := readArchiveEntry(context.Background(), "/corrupt.tar", []string{"./x"}, 16) + if err == nil || errors.Is(err, os.ErrNotExist) { + t.Fatalf("err=%v; want tar parse error", err) + } +} + +func TestReadArchiveEntry_TruncatedEntryReadError(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + var buf bytes.Buffer + tw := tar.NewWriter(&buf) + data := []byte("0123456789") + if err := tw.WriteHeader(&tar.Header{Name: "./x", Mode: 0o644, Size: int64(len(data))}); err != nil { + t.Fatalf("WriteHeader: %v", err) + } + if _, err := tw.Write(data); err != nil { + t.Fatalf("Write: %v", err) + } + if err := tw.Close(); err != nil { + t.Fatalf("Close: %v", err) + } + + full := buf.Bytes() + if len(full) <= 512+5 { + t.Fatalf("unexpected tar size: %d", len(full)) + } + truncated := full[:512+5] + if err := fakeFS.WriteFile("/trunc.tar", truncated, 0o644); err != nil { + t.Fatalf("write: %v", err) + } + + _, _, err := readArchiveEntry(context.Background(), "/trunc.tar", []string{"./x"}, 16) + if err == nil || errors.Is(err, os.ErrNotExist) { + t.Fatalf("err=%v; want read error", err) + } +} + +func TestCollectCurrentNetworkInventory_UdevAndPermanentMAC(t *testing.T) { + origSys := sysClassNetPath + origCmd := restoreCmd + t.Cleanup(func() { + sysClassNetPath = origSys + restoreCmd = origCmd + }) + + // Make commandAvailable() succeed. + binDir := t.TempDir() + for _, name := range []string{"udevadm", "ethtool"} { + path := filepath.Join(binDir, name) + if err := os.WriteFile(path, []byte("#!/bin/sh\nexit 0\n"), 0o755); err != nil { + t.Fatalf("write %s: %v", name, err) + } + } + t.Setenv("PATH", binDir) + + root := t.TempDir() + netDir := filepath.Join(root, "sys/class/net") + if err := os.MkdirAll(netDir, 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + sysClassNetPath = netDir + + // eth0 directory. 
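+ // eth0 is a plain directory, while eth1 below is a symlink outside devices/virtual; eth1's udevadm and ethtool probes are forced to fail, and the test expects those errors to be ignored rather than marking the interface virtual.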
+ if err := os.MkdirAll(filepath.Join(netDir, "eth0"), 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + if err := os.WriteFile(filepath.Join(netDir, "eth0/address"), []byte("aa:bb:cc:dd:ee:01\n"), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + + // eth1 symlink (non-virtual). + eth1Target := filepath.Join(root, "devices/pci0000:00/eth1") + if err := os.MkdirAll(eth1Target, 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + if err := os.WriteFile(filepath.Join(eth1Target, "address"), []byte("aa:bb:cc:dd:ee:02\n"), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + if err := os.Symlink(eth1Target, filepath.Join(netDir, "eth1")); err != nil { + t.Fatalf("symlink: %v", err) + } + + cmd := &FakeCommandRunner{ + Outputs: map[string][]byte{ + "udevadm info -q property -p " + filepath.Join(netDir, "eth0"): []byte("ID_SERIAL=abc\n"), + "ethtool -P eth0": []byte("permanent address: AA:BB:CC:DD:EE:FF\n"), + "udevadm info -q property -p " + filepath.Join(netDir, "eth1"): []byte("ID_SERIAL=ignored\n"), + "ethtool -P eth1": []byte("permanent address: 11:22:33:44:55:66\n"), + }, + Errors: map[string]error{ + "udevadm info -q property -p " + filepath.Join(netDir, "eth1"): errors.New("boom"), + "ethtool -P eth1": errors.New("boom"), + }, + } + restoreCmd = cmd + + inv, err := collectCurrentNetworkInventory(context.Background()) + if err != nil { + t.Fatalf("collect: %v", err) + } + var eth0, eth1 archivedNetworkInterface + for _, iface := range inv.Interfaces { + switch iface.Name { + case "eth0": + eth0 = iface + case "eth1": + eth1 = iface + } + } + if eth0.PermanentMAC != "aa:bb:cc:dd:ee:ff" { + t.Fatalf("eth0 PermanentMAC=%q", eth0.PermanentMAC) + } + if eth0.UdevProps == nil || eth0.UdevProps["ID_SERIAL"] != "abc" { + t.Fatalf("eth0 UdevProps=%v", eth0.UdevProps) + } + if eth1.IsVirtual { + t.Fatalf("eth1 IsVirtual=true; want false") + } + if eth1.PermanentMAC != "" || eth1.UdevProps != nil { + t.Fatalf("eth1 should ignore cmd errors: %+v", eth1) + } +} + +func TestRewriteIfupdownConfigFiles_EdgeAndErrorPaths(t *testing.T) { + origFS := restoreFS + origTime := restoreTime + t.Cleanup(func() { + restoreFS = origFS + restoreTime = origTime + }) + + restoreTime = &FakeTime{Current: time.Date(2025, 1, 1, 1, 2, 3, 0, time.UTC)} + + t.Run("empty rename map", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + paths, dir, err := rewriteIfupdownConfigFiles(nil, map[string]string{}) + if err != nil || len(paths) != 0 || dir != "" { + t.Fatalf("paths=%v dir=%q err=%v", paths, dir, err) + } + }) + + t.Run("interfaces.d missing but update succeeds", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + if err := fakeFS.WriteFile("/etc/network/interfaces", []byte("auto eno1\niface eno1 inet manual\n"), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + paths, _, err := rewriteIfupdownConfigFiles(nil, map[string]string{"eno1": "eth0"}) + if err != nil || len(paths) != 1 { + t.Fatalf("paths=%v err=%v", paths, err) + } + }) + + t.Run("interfaces.d entries skipped", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + if err := fakeFS.WriteFile("/etc/network/interfaces", []byte("auto eno1\n"), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + if err := fakeFS.MkdirAll("/etc/network/interfaces.d/subdir", 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + if err := 
fakeFS.WriteFile("/etc/network/interfaces.d/ ", []byte("auto eno1\n"), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + if err := fakeFS.WriteFile("/etc/network/interfaces.d/extra", []byte("auto eno1\n"), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + + paths, _, err := rewriteIfupdownConfigFiles(nil, map[string]string{"eno1": "eth0"}) + if err != nil || len(paths) != 2 { + t.Fatalf("paths=%v err=%v", paths, err) + } + }) + + t.Run("stat failure skips", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + base := statFailFS{FS: fakeFS, failPath: "/etc/network/interfaces", err: errors.New("boom")} + restoreFS = base + + paths, dir, err := rewriteIfupdownConfigFiles(nil, map[string]string{"eno1": "eth0"}) + if err != nil || len(paths) != 0 || dir != "" { + t.Fatalf("paths=%v dir=%q err=%v", paths, dir, err) + } + }) + + t.Run("not regular file skipped", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + if err := fakeFS.MkdirAll("/etc/network/interfaces", 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + paths, dir, err := rewriteIfupdownConfigFiles(nil, map[string]string{"eno1": "eth0"}) + if err != nil || len(paths) != 0 || dir != "" { + t.Fatalf("paths=%v dir=%q err=%v", paths, dir, err) + } + }) + + t.Run("read failure skipped", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + if err := fakeFS.WriteFile("/etc/network/interfaces", []byte("auto eno1\n"), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + restoreFS = readFileFailFS{FS: fakeFS, failPath: "/etc/network/interfaces", err: errors.New("boom")} + + paths, dir, err := rewriteIfupdownConfigFiles(nil, map[string]string{"eno1": "eth0"}) + if err != nil || len(paths) != 0 || dir != "" { + t.Fatalf("paths=%v dir=%q err=%v", paths, dir, err) + } + }) + + t.Run("base dir create fails", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + if err := fakeFS.WriteFile("/etc/network/interfaces", []byte("auto eno1\n"), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + fakeFS.MkdirAllErr = errors.New("disk full") + + _, _, err := rewriteIfupdownConfigFiles(nil, map[string]string{"eno1": "eth0"}) + if err == nil || !strings.Contains(err.Error(), "create nic repair base directory") { + t.Fatalf("err=%v; want create nic repair base directory", err) + } + }) + + t.Run("write updated fails", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + if err := fakeFS.WriteFile("/etc/network/interfaces", []byte("auto eno1\n"), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + restoreFS = writeFileFailFS{FS: fakeFS, failPath: "/etc/network/interfaces", err: errors.New("disk full")} + + _, _, err := rewriteIfupdownConfigFiles(nil, map[string]string{"eno1": "eth0"}) + if err == nil || !strings.Contains(err.Error(), "write updated file") { + t.Fatalf("err=%v; want write updated file", err) + } + }) +} + +func TestRewriteIfupdownConfigFiles_BackupStageErrors(t *testing.T) { + origFS := restoreFS + origTime := restoreTime + origSeq := atomic.LoadUint64(&nicRepairSequence) + t.Cleanup(func() { + restoreFS = origFS + restoreTime = origTime + atomic.StoreUint64(&nicRepairSequence, origSeq) + }) + + restoreTime = &FakeTime{Current: time.Date(2025, 1, 1, 1, 2, 3, 0, time.UTC)} + expectedBackupDir := "/tmp/proxsave/nic_repair_20250101_010203_1" 
+ expectedBackupPath := filepath.Join(expectedBackupDir, "etc/network/interfaces") + expectedBackupPathDir := filepath.Dir(expectedBackupPath) + + t.Run("create backup dir fails", func(t *testing.T) { + atomic.StoreUint64(&nicRepairSequence, 0) + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + if err := fakeFS.WriteFile("/etc/network/interfaces", []byte("auto eno1\n"), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + restoreFS = mkdirAllFailFS{FS: fakeFS, failPath: expectedBackupDir, err: errors.New("boom")} + + _, _, err := rewriteIfupdownConfigFiles(nil, map[string]string{"eno1": "eth0"}) + if err == nil || !strings.Contains(err.Error(), "create nic repair backup directory") { + t.Fatalf("err=%v; want create nic repair backup directory", err) + } + }) + + t.Run("create backup path dir fails", func(t *testing.T) { + atomic.StoreUint64(&nicRepairSequence, 0) + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + if err := fakeFS.WriteFile("/etc/network/interfaces", []byte("auto eno1\n"), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + restoreFS = mkdirAllFailFS{FS: fakeFS, failPath: expectedBackupPathDir, err: errors.New("boom")} + + _, _, err := rewriteIfupdownConfigFiles(nil, map[string]string{"eno1": "eth0"}) + if err == nil || !strings.Contains(err.Error(), "create backup directory for") { + t.Fatalf("err=%v; want create backup directory for", err) + } + }) + + t.Run("write backup file fails", func(t *testing.T) { + atomic.StoreUint64(&nicRepairSequence, 0) + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + if err := fakeFS.WriteFile("/etc/network/interfaces", []byte("auto eno1\n"), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + restoreFS = writeFileFailFS{FS: fakeFS, failPath: expectedBackupPath, err: errors.New("disk full")} + + _, _, err := rewriteIfupdownConfigFiles(nil, map[string]string{"eno1": "eth0"}) + if err == nil || !strings.Contains(err.Error(), "write backup file") { + t.Fatalf("err=%v; want write backup file", err) + } + }) +} + +func TestMapToEntriesAndTokenBoundary(t *testing.T) { + if got := mapToEntries(map[string]string{}); got != nil { + t.Fatalf("mapToEntries=%v; want nil", got) + } + got := mapToEntries(map[string]string{"b": "B", "a": "A"}) + if len(got) != 2 || got[0].OldName != "a" || got[1].OldName != "b" { + t.Fatalf("entries=%+v", got) + } + + if isTokenBoundary("abc", -1, "a") { + t.Fatalf("expected false for negative idx") + } + if isTokenBoundary("abc", 2, "zz") { + t.Fatalf("expected false for token overflow") + } + if isTokenBoundary("xeno1", 1, "eno1") { + t.Fatalf("expected false for iface-char prefix") + } + if isTokenBoundary("eno10", 0, "eno1") { + t.Fatalf("expected false for iface-char suffix") + } + if !isTokenBoundary("eno1", 0, "eno1") { + t.Fatalf("expected true for token covering full string") + } + + if out, changed := applyInterfaceRenameMap("auto a\n", map[string]string{"a": "a"}); out != "auto a\n" || changed { + t.Fatalf("applyInterfaceRenameMap unexpected: out=%q changed=%v", out, changed) + } +} diff --git a/internal/orchestrator/orchestrator.go b/internal/orchestrator/orchestrator.go index c629435..4eb71f7 100644 --- a/internal/orchestrator/orchestrator.go +++ b/internal/orchestrator/orchestrator.go @@ -1192,12 +1192,6 @@ func (o *Orchestrator) removeAssociatedFiles(archivePath string) error { return nil } -// Legacy compatibility wrapper for callers that used the package-level createBundle function. 
-func createBundle(ctx context.Context, logger *logging.Logger, archivePath string) (string, error) { - o := &Orchestrator{logger: logger, fs: osFS{}, clock: realTimeProvider{}} - return o.createBundle(ctx, archivePath) -} - // encryptArchive was replaced by streaming encryption inside the archiver. // SaveStatsReport writes a JSON report with backup statistics to the log directory. diff --git a/internal/orchestrator/pbs_api_apply.go b/internal/orchestrator/pbs_api_apply.go index 830bcc4..9851866 100644 --- a/internal/orchestrator/pbs_api_apply.go +++ b/internal/orchestrator/pbs_api_apply.go @@ -12,6 +12,8 @@ import ( "github.com/tis24dev/proxsave/internal/logging" ) +var pbsAPIApplyGeteuid = os.Geteuid + func normalizeProxmoxCfgKey(key string) string { key = strings.ToLower(strings.TrimSpace(key)) key = strings.ReplaceAll(key, "_", "-") @@ -176,7 +178,7 @@ func ensurePBSServicesForAPI(ctx context.Context, logger *logging.Logger) error if !isRealRestoreFS(restoreFS) { return fmt.Errorf("non-system filesystem in use") } - if os.Geteuid() != 0 { + if pbsAPIApplyGeteuid() != 0 { return fmt.Errorf("requires root privileges") } diff --git a/internal/orchestrator/pbs_api_apply_test.go b/internal/orchestrator/pbs_api_apply_test.go new file mode 100644 index 0000000..f9b3b20 --- /dev/null +++ b/internal/orchestrator/pbs_api_apply_test.go @@ -0,0 +1,1276 @@ +package orchestrator + +import ( + "context" + "errors" + "os" + "reflect" + "strings" + "testing" + + "github.com/tis24dev/proxsave/internal/logging" + "github.com/tis24dev/proxsave/internal/types" +) + +func setupPBSAPIApplyTestDeps(t *testing.T) (stageRoot string, fs *FakeFS, runner *fakeCommandRunner) { + t.Helper() + + origCmd := restoreCmd + origFS := restoreFS + t.Cleanup(func() { + restoreCmd = origCmd + restoreFS = origFS + }) + + fs = NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fs.Root) }) + restoreFS = fs + + runner = &fakeCommandRunner{} + restoreCmd = runner + + return "/stage", fs, runner +} + +func writeStageFile(t *testing.T, fs *FakeFS, stageRoot, relPath, content string, perm os.FileMode) { + t.Helper() + if err := fs.WriteFile(stageRoot+"/"+relPath, []byte(content), perm); err != nil { + t.Fatalf("write staged %s: %v", relPath, err) + } +} + +func TestNormalizeProxmoxCfgKey(t *testing.T) { + tests := []struct { + in string + want string + }{ + {in: " Foo_Bar ", want: "foo-bar"}, + {in: "dns1", want: "dns1"}, + {in: "", want: ""}, + {in: " ", want: ""}, + } + for _, tt := range tests { + if got := normalizeProxmoxCfgKey(tt.in); got != tt.want { + t.Fatalf("normalizeProxmoxCfgKey(%q)=%q want %q", tt.in, got, tt.want) + } + } +} + +func TestBuildProxmoxManagerFlags_SkipsAndNormalizes(t *testing.T) { + entries := []proxmoxNotificationEntry{ + {Key: "HOST", Value: " pbs.example "}, + {Key: "digest", Value: "abc"}, + {Key: "name", Value: "ignored"}, + {Key: "foo_bar", Value: "baz"}, + {Key: "skip_me", Value: "nope"}, + {Key: "", Value: "x"}, + } + got := buildProxmoxManagerFlags(entries, "SKIP_ME") + want := []string{"--host", "pbs.example", "--foo-bar", "baz"} + if !reflect.DeepEqual(got, want) { + t.Fatalf("flags=%v want %v", got, want) + } +} + +func TestBuildProxmoxManagerFlags_EmptyReturnsNil(t *testing.T) { + if got := buildProxmoxManagerFlags(nil); got != nil { + t.Fatalf("expected nil, got %v", got) + } + if got := buildProxmoxManagerFlags([]proxmoxNotificationEntry{}); got != nil { + t.Fatalf("expected nil, got %v", got) + } +} + +func TestPopEntryValue_FirstMatchOnly(t *testing.T) { + entries := 
[]proxmoxNotificationEntry{ + {Key: "Foo_Bar", Value: " first "}, + {Key: "foo-bar", Value: "second"}, + {Key: "x", Value: "y"}, + } + value, remaining, ok := popEntryValue(entries, "foo-bar") + if !ok || value != "first" { + t.Fatalf("ok=%v value=%q want ok=true value=first", ok, value) + } + wantRemaining := []proxmoxNotificationEntry{ + {Key: "foo-bar", Value: "second"}, + {Key: "x", Value: "y"}, + } + if !reflect.DeepEqual(remaining, wantRemaining) { + t.Fatalf("remaining=%v want %v", remaining, wantRemaining) + } +} + +func TestPopEntryValue_NoKeysOrEntries(t *testing.T) { + value, remaining, ok := popEntryValue(nil, "k") + if ok || value != "" || remaining != nil { + t.Fatalf("ok=%v value=%q remaining=%v want ok=false value=\"\" remaining=nil", ok, value, remaining) + } + + entries := []proxmoxNotificationEntry{{Key: "k", Value: "v"}} + value, remaining, ok = popEntryValue(entries) + if ok || value != "" || !reflect.DeepEqual(remaining, entries) { + t.Fatalf("ok=%v value=%q remaining=%v want ok=false value=\"\" remaining=%v", ok, value, remaining, entries) + } +} + +func TestRunPBSManagerRedacted_RedactsFlagsAndIndexes(t *testing.T) { + orig := restoreCmd + t.Cleanup(func() { restoreCmd = orig }) + + runner := &fakeCommandRunner{ + outputs: map[string][]byte{}, + errs: map[string]error{}, + } + restoreCmd = runner + + args := []string{"remote", "create", "r1", "--password", "secret123", "--token", "token-value-123"} + key := "proxmox-backup-manager " + strings.Join(args, " ") + runner.outputs[key] = []byte("boom") + runner.errs[key] = errors.New("exit 1") + + out, err := runPBSManagerRedacted(context.Background(), args, []string{"--password"}, []int{6}) + if string(out) != "boom" { + t.Fatalf("out=%q want %q", string(out), "boom") + } + if err == nil { + t.Fatalf("expected error") + } + msg := err.Error() + if strings.Contains(msg, "secret123") || strings.Contains(msg, "token-value-123") { + t.Fatalf("expected secrets to be redacted, got: %s", msg) + } + // assumption: "<redacted>" is the placeholder runPBSManagerRedacted substitutes for secret values. + if !strings.Contains(msg, "<redacted>") { + t.Fatalf("expected <redacted> in error, got: %s", msg) + } +} + +func TestUnwrapPBSJSONData(t *testing.T) { + if got := unwrapPBSJSONData([]byte(" ")); got != nil { + t.Fatalf("expected nil for empty input, got %q", string(got)) + } + + got := string(unwrapPBSJSONData([]byte(" not-json "))) + if got != "not-json" { + t.Fatalf("got %q want %q", got, "not-json") + } + + got = string(unwrapPBSJSONData([]byte(`{"data":[{"id":"a"}]}`))) + if got != `[{"id":"a"}]` { + t.Fatalf("got %q want %q", got, `[{"id":"a"}]`) + } + + got = string(unwrapPBSJSONData([]byte(`{"foo":"bar"}`))) + if got != `{"foo":"bar"}` { + t.Fatalf("got %q want %q", got, `{"foo":"bar"}`) + } +} + +func TestParsePBSListIDs(t *testing.T) { + if _, err := parsePBSListIDs([]byte(`[{"id":"a"}]`)); err == nil { + t.Fatalf("expected error for missing candidate keys") + } + + ids, err := parsePBSListIDs([]byte(" "), "id") + if err != nil || ids != nil { + t.Fatalf("ids=%v err=%v want nil,nil", ids, err) + } + + ids, err = parsePBSListIDs([]byte(`{"data":[{"id":"b"},{"id":"a"},{"id":"a"}]}`), "id") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !reflect.DeepEqual(ids, []string{"a", "b"}) { + t.Fatalf("ids=%v want %v", ids, []string{"a", "b"}) + } + + _, err = parsePBSListIDs([]byte(`{"data":[{"id":123,"name":""}]}`), "id", "name") + if err == nil || !strings.Contains(err.Error(), "failed to parse PBS list row 0") { + t.Fatalf("expected row parse error, got: %v", err) + } + + ids, err = parsePBSListIDs([]byte(`[{"name":"x"}]`), "id", 
"name") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !reflect.DeepEqual(ids, []string{"x"}) { + t.Fatalf("ids=%v want %v", ids, []string{"x"}) + } + + ids, err = parsePBSListIDs([]byte(`{"data":[{"id":"a"}]}`), " id ", " ", "name") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !reflect.DeepEqual(ids, []string{"a"}) { + t.Fatalf("ids=%v want %v", ids, []string{"a"}) + } + + if _, err := parsePBSListIDs([]byte(`{"data":{}}`), "id"); err == nil { + t.Fatalf("expected unmarshal error for non-array data") + } +} + +func TestPBSAPIApply_ReadStageFileOptionalErrors(t *testing.T) { + logger := logging.New(types.LogLevelDebug, false) + + type applyFunc struct { + name string + relPath string + apply func(context.Context, *logging.Logger, string, bool) error + } + + for _, tt := range []applyFunc{ + {name: "remote", relPath: "etc/proxmox-backup/remote.cfg", apply: applyPBSRemoteCfgViaAPI}, + {name: "s3", relPath: "etc/proxmox-backup/s3.cfg", apply: applyPBSS3CfgViaAPI}, + {name: "datastore", relPath: "etc/proxmox-backup/datastore.cfg", apply: applyPBSDatastoreCfgViaAPI}, + {name: "sync", relPath: "etc/proxmox-backup/sync.cfg", apply: applyPBSSyncCfgViaAPI}, + {name: "verification", relPath: "etc/proxmox-backup/verification.cfg", apply: applyPBSVerificationCfgViaAPI}, + {name: "prune", relPath: "etc/proxmox-backup/prune.cfg", apply: applyPBSPruneCfgViaAPI}, + {name: "traffic-control", relPath: "etc/proxmox-backup/traffic-control.cfg", apply: applyPBSTrafficControlCfgViaAPI}, + } { + tt := tt + t.Run(tt.name, func(t *testing.T) { + stageRoot, fs, _ := setupPBSAPIApplyTestDeps(t) + if err := fs.AddDir(stageRoot + "/" + tt.relPath); err != nil { + t.Fatalf("create staged dir: %v", err) + } + err := tt.apply(context.Background(), logger, stageRoot, false) + if err == nil || !strings.Contains(err.Error(), "read staged") { + t.Fatalf("expected read staged error, got: %v", err) + } + }) + } + + t.Run("node", func(t *testing.T) { + stageRoot, fs, _ := setupPBSAPIApplyTestDeps(t) + if err := fs.AddDir(stageRoot + "/etc/proxmox-backup/node.cfg"); err != nil { + t.Fatalf("create staged dir: %v", err) + } + err := applyPBSNodeCfgViaAPI(context.Background(), stageRoot) + if err == nil || !strings.Contains(err.Error(), "read staged") { + t.Fatalf("expected read staged error, got: %v", err) + } + }) +} + +func TestPBSAPIApply_NoFileBranches(t *testing.T) { + logger := logging.New(types.LogLevelDebug, false) + + type applyFunc struct { + name string + apply func(context.Context, *logging.Logger, string, bool) error + } + + for _, tt := range []applyFunc{ + {name: "datastore", apply: applyPBSDatastoreCfgViaAPI}, + {name: "sync", apply: applyPBSSyncCfgViaAPI}, + {name: "verification", apply: applyPBSVerificationCfgViaAPI}, + {name: "prune", apply: applyPBSPruneCfgViaAPI}, + {name: "traffic-control", apply: applyPBSTrafficControlCfgViaAPI}, + } { + tt := tt + t.Run(tt.name, func(t *testing.T) { + stageRoot, _, runner := setupPBSAPIApplyTestDeps(t) + if err := tt.apply(context.Background(), logger, stageRoot, true); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(runner.calls) != 0 { + t.Fatalf("expected no calls, got %v", runner.calls) + } + }) + } +} + +func TestPBSAPIApply_StrictListCommandErrors(t *testing.T) { + logger := logging.New(types.LogLevelDebug, false) + + tests := []struct { + name string + relPath string + content string + listCmd string + apply func(context.Context, *logging.Logger, string, bool) error + }{ + { + name: "remote", + relPath: 
"etc/proxmox-backup/remote.cfg", + content: "remote: r1\n host pbs.example\n", + listCmd: "proxmox-backup-manager remote list --output-format=json", + apply: applyPBSRemoteCfgViaAPI, + }, + { + name: "s3", + relPath: "etc/proxmox-backup/s3.cfg", + content: "s3: e1\n endpoint https://s3.example\n", + listCmd: "proxmox-backup-manager s3 endpoint list --output-format=json", + apply: applyPBSS3CfgViaAPI, + }, + { + name: "sync", + relPath: "etc/proxmox-backup/sync.cfg", + content: "sync-job: job1\n remote r1\n store ds1\n", + listCmd: "proxmox-backup-manager sync-job list --output-format=json", + apply: applyPBSSyncCfgViaAPI, + }, + { + name: "verification", + relPath: "etc/proxmox-backup/verification.cfg", + content: "verify-job: v1\n store ds1\n", + listCmd: "proxmox-backup-manager verify-job list --output-format=json", + apply: applyPBSVerificationCfgViaAPI, + }, + { + name: "prune", + relPath: "etc/proxmox-backup/prune.cfg", + content: "prune-job: p1\n store ds1\n", + listCmd: "proxmox-backup-manager prune-job list --output-format=json", + apply: applyPBSPruneCfgViaAPI, + }, + { + name: "traffic-control", + relPath: "etc/proxmox-backup/traffic-control.cfg", + content: "traffic-control: tc1\n rate-in 1000\n", + listCmd: "proxmox-backup-manager traffic-control list --output-format=json", + apply: applyPBSTrafficControlCfgViaAPI, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + stageRoot, fs, runner := setupPBSAPIApplyTestDeps(t) + writeStageFile(t, fs, stageRoot, tt.relPath, tt.content, 0o640) + runner.errs = map[string]error{tt.listCmd: errors.New("boom")} + if err := tt.apply(context.Background(), logger, stageRoot, true); err == nil { + t.Fatalf("expected error") + } + }) + } +} + +func TestPBSAPIApply_StrictListParseErrors(t *testing.T) { + logger := logging.New(types.LogLevelDebug, false) + + tests := []struct { + name string + relPath string + content string + listCmd string + wantErrPrefix string + apply func(context.Context, *logging.Logger, string, bool) error + }{ + { + name: "remote", + relPath: "etc/proxmox-backup/remote.cfg", + content: "remote: r1\n host pbs.example\n", + listCmd: "proxmox-backup-manager remote list --output-format=json", + wantErrPrefix: "parse remote list:", + apply: applyPBSRemoteCfgViaAPI, + }, + { + name: "s3", + relPath: "etc/proxmox-backup/s3.cfg", + content: "s3: e1\n endpoint https://s3.example\n", + listCmd: "proxmox-backup-manager s3 endpoint list --output-format=json", + wantErrPrefix: "parse s3 endpoint list:", + apply: applyPBSS3CfgViaAPI, + }, + { + name: "sync", + relPath: "etc/proxmox-backup/sync.cfg", + content: "sync-job: job1\n remote r1\n store ds1\n", + listCmd: "proxmox-backup-manager sync-job list --output-format=json", + wantErrPrefix: "parse sync-job list:", + apply: applyPBSSyncCfgViaAPI, + }, + { + name: "verification", + relPath: "etc/proxmox-backup/verification.cfg", + content: "verify-job: v1\n store ds1\n", + listCmd: "proxmox-backup-manager verify-job list --output-format=json", + wantErrPrefix: "parse verify-job list:", + apply: applyPBSVerificationCfgViaAPI, + }, + { + name: "prune", + relPath: "etc/proxmox-backup/prune.cfg", + content: "prune-job: p1\n store ds1\n", + listCmd: "proxmox-backup-manager prune-job list --output-format=json", + wantErrPrefix: "parse prune-job list:", + apply: applyPBSPruneCfgViaAPI, + }, + { + name: "traffic-control", + relPath: "etc/proxmox-backup/traffic-control.cfg", + content: "traffic-control: tc1\n rate-in 1000\n", + listCmd: "proxmox-backup-manager 
traffic-control list --output-format=json", + wantErrPrefix: "parse traffic-control list:", + apply: applyPBSTrafficControlCfgViaAPI, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + stageRoot, fs, runner := setupPBSAPIApplyTestDeps(t) + writeStageFile(t, fs, stageRoot, tt.relPath, tt.content, 0o640) + runner.outputs = map[string][]byte{tt.listCmd: []byte("not-json")} + if err := tt.apply(context.Background(), logger, stageRoot, true); err == nil || !strings.Contains(err.Error(), tt.wantErrPrefix) { + t.Fatalf("expected %q error, got: %v", tt.wantErrPrefix, err) + } + }) + } +} + +func TestPBSAPIApply_CreateUpdateBothFailErrors(t *testing.T) { + logger := logging.New(types.LogLevelDebug, false) + + tests := []struct { + name string + relPath string + content string + create string + update string + apply func(context.Context, *logging.Logger, string, bool) error + }{ + { + name: "sync", + relPath: "etc/proxmox-backup/sync.cfg", + content: "sync-job: job1\n remote r1\n store ds1\n", + create: "proxmox-backup-manager sync-job create job1 --remote r1 --store ds1", + update: "proxmox-backup-manager sync-job update job1 --remote r1 --store ds1", + apply: applyPBSSyncCfgViaAPI, + }, + { + name: "verification", + relPath: "etc/proxmox-backup/verification.cfg", + content: "verify-job: v1\n store ds1\n", + create: "proxmox-backup-manager verify-job create v1 --store ds1", + update: "proxmox-backup-manager verify-job update v1 --store ds1", + apply: applyPBSVerificationCfgViaAPI, + }, + { + name: "prune", + relPath: "etc/proxmox-backup/prune.cfg", + content: "prune-job: p1\n store ds1\n keep-last 3\n", + create: "proxmox-backup-manager prune-job create p1 --store ds1 --keep-last 3", + update: "proxmox-backup-manager prune-job update p1 --store ds1 --keep-last 3", + apply: applyPBSPruneCfgViaAPI, + }, + { + name: "traffic-control", + relPath: "etc/proxmox-backup/traffic-control.cfg", + content: "traffic-control: tc1\n rate-in 1000\n", + create: "proxmox-backup-manager traffic-control create tc1 --rate-in 1000", + update: "proxmox-backup-manager traffic-control update tc1 --rate-in 1000", + apply: applyPBSTrafficControlCfgViaAPI, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + stageRoot, fs, runner := setupPBSAPIApplyTestDeps(t) + writeStageFile(t, fs, stageRoot, tt.relPath, tt.content, 0o640) + runner.errs = map[string]error{ + tt.create: errors.New("create failed"), + tt.update: errors.New("update failed"), + } + err := tt.apply(context.Background(), logger, stageRoot, false) + if err == nil { + t.Fatalf("expected error") + } + if !strings.Contains(err.Error(), "create failed") || !strings.Contains(err.Error(), "update failed") { + t.Fatalf("expected create/update errors, got: %s", err.Error()) + } + }) + } +} + +func TestEnsurePBSServicesForAPI(t *testing.T) { + logger := logging.New(types.LogLevelDebug, false) + + t.Run("non-system-fs", func(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + restoreFS = NewFakeFS() + if err := ensurePBSServicesForAPI(context.Background(), logger); err == nil || !strings.Contains(err.Error(), "non-system filesystem") { + t.Fatalf("expected non-system filesystem error, got: %v", err) + } + }) + + t.Run("non-root", func(t *testing.T) { + origFS := restoreFS + origCmd := restoreCmd + origGeteuid := pbsAPIApplyGeteuid + t.Cleanup(func() { + restoreFS = origFS + restoreCmd = origCmd + pbsAPIApplyGeteuid = origGeteuid + }) + restoreFS = osFS{} + restoreCmd 
= &fakeCommandRunner{} + pbsAPIApplyGeteuid = func() int { return 1000 } + + if err := ensurePBSServicesForAPI(context.Background(), logger); err == nil || !strings.Contains(err.Error(), "requires root privileges") { + t.Fatalf("expected root privileges error, got: %v", err) + } + }) + + t.Run("pbs-manager-missing", func(t *testing.T) { + origFS := restoreFS + origCmd := restoreCmd + origGeteuid := pbsAPIApplyGeteuid + t.Cleanup(func() { + restoreFS = origFS + restoreCmd = origCmd + pbsAPIApplyGeteuid = origGeteuid + }) + restoreFS = osFS{} + + runner := &fakeCommandRunner{ + errs: map[string]error{ + "proxmox-backup-manager version": errors.New("not found"), + }, + } + restoreCmd = runner + pbsAPIApplyGeteuid = func() int { return 0 } + + if err := ensurePBSServicesForAPI(context.Background(), logger); err == nil || !strings.Contains(err.Error(), "proxmox-backup-manager not available") { + t.Fatalf("expected pbs manager missing error, got: %v", err) + } + }) + + t.Run("start-services-fails", func(t *testing.T) { + origFS := restoreFS + origCmd := restoreCmd + origGeteuid := pbsAPIApplyGeteuid + t.Cleanup(func() { + restoreFS = origFS + restoreCmd = origCmd + pbsAPIApplyGeteuid = origGeteuid + }) + restoreFS = osFS{} + + runner := &fakeCommandRunner{ + errs: map[string]error{ + "which systemctl": errors.New("no systemctl"), + }, + } + restoreCmd = runner + pbsAPIApplyGeteuid = func() int { return 0 } + + if err := ensurePBSServicesForAPI(context.Background(), logger); err == nil || !strings.Contains(err.Error(), "systemctl not available") { + t.Fatalf("expected systemctl not available error, got: %v", err) + } + }) + + t.Run("ok", func(t *testing.T) { + origFS := restoreFS + origCmd := restoreCmd + origGeteuid := pbsAPIApplyGeteuid + t.Cleanup(func() { + restoreFS = origFS + restoreCmd = origCmd + pbsAPIApplyGeteuid = origGeteuid + }) + restoreFS = osFS{} + restoreCmd = &fakeCommandRunner{} + pbsAPIApplyGeteuid = func() int { return 0 } + + if err := ensurePBSServicesForAPI(context.Background(), nil); err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) +} + +func TestApplyPBSRemoteCfgViaAPI_StrictCleanupAndCreate(t *testing.T) { + stageRoot, fs, runner := setupPBSAPIApplyTestDeps(t) + logger := logging.New(types.LogLevelDebug, false) + + writeStageFile(t, fs, stageRoot, "etc/proxmox-backup/remote.cfg", + "remote: r1\n"+ + " HOST pbs1.example\n"+ + " password secret1\n"+ + " foo_bar baz\n"+ + " digest abc\n"+ + " name ignored\n"+ + "\n"+ + "remote: r2\n"+ + " host pbs2.example\n"+ + " username admin\n", + 0o640, + ) + + runner.outputs = map[string][]byte{ + "proxmox-backup-manager remote list --output-format=json": []byte(`{"data":[{"id":"r1"},{"id":"old"}]}`), + } + runner.errs = map[string]error{ + "proxmox-backup-manager remote remove old": errors.New("cannot remove old"), + } + + if err := applyPBSRemoteCfgViaAPI(context.Background(), logger, stageRoot, true); err != nil { + t.Fatalf("applyPBSRemoteCfgViaAPI error: %v", err) + } + + want := []string{ + "proxmox-backup-manager remote list --output-format=json", + "proxmox-backup-manager remote remove old", + "proxmox-backup-manager remote create r1 --host pbs1.example --password secret1 --foo-bar baz", + "proxmox-backup-manager remote create r2 --host pbs2.example --username admin", + } + if !reflect.DeepEqual(runner.calls, want) { + t.Fatalf("calls=%v want %v", runner.calls, want) + } +} + +func TestApplyPBSRemoteCfgViaAPI_NoFileNoCalls(t *testing.T) { + stageRoot, _, runner := setupPBSAPIApplyTestDeps(t) + logger := 
logging.New(types.LogLevelDebug, false) + + if err := applyPBSRemoteCfgViaAPI(context.Background(), logger, stageRoot, true); err != nil { + t.Fatalf("applyPBSRemoteCfgViaAPI error: %v", err) + } + if len(runner.calls) != 0 { + t.Fatalf("expected no calls, got %v", runner.calls) + } +} + +func TestApplyPBSRemoteCfgViaAPI_CreateFailsThenUpdate(t *testing.T) { + stageRoot, fs, runner := setupPBSAPIApplyTestDeps(t) + logger := logging.New(types.LogLevelDebug, false) + + writeStageFile(t, fs, stageRoot, "etc/proxmox-backup/remote.cfg", + "remote: r1\n"+ + " host pbs.example\n"+ + " password secret1\n", + 0o640, + ) + + runner.errs = map[string]error{ + "proxmox-backup-manager remote create r1 --host pbs.example --password secret1": errors.New("already exists"), + } + + if err := applyPBSRemoteCfgViaAPI(context.Background(), logger, stageRoot, false); err != nil { + t.Fatalf("applyPBSRemoteCfgViaAPI error: %v", err) + } + + want := []string{ + "proxmox-backup-manager remote create r1 --host pbs.example --password secret1", + "proxmox-backup-manager remote update r1 --host pbs.example --password secret1", + } + if !reflect.DeepEqual(runner.calls, want) { + t.Fatalf("calls=%v want %v", runner.calls, want) + } +} + +func TestApplyPBSRemoteCfgViaAPI_RedactsPasswordOnFailure(t *testing.T) { + stageRoot, fs, runner := setupPBSAPIApplyTestDeps(t) + logger := logging.New(types.LogLevelDebug, false) + + writeStageFile(t, fs, stageRoot, "etc/proxmox-backup/remote.cfg", + "remote: r1\n"+ + " host pbs.example\n"+ + " password secret1\n", + 0o640, + ) + + runner.errs = map[string]error{ + "proxmox-backup-manager remote create r1 --host pbs.example --password secret1": errors.New("create failed"), + "proxmox-backup-manager remote update r1 --host pbs.example --password secret1": errors.New("update failed"), + } + + err := applyPBSRemoteCfgViaAPI(context.Background(), logger, stageRoot, false) + if err == nil { + t.Fatalf("expected error") + } + if strings.Contains(err.Error(), "secret1") { + t.Fatalf("expected password to be redacted, got: %s", err.Error()) + } + if !strings.Contains(err.Error(), "<redacted>") { + t.Fatalf("expected <redacted> in error, got: %s", err.Error()) + } +} + +func TestApplyPBSS3CfgViaAPI_CreateUpdateAndStrictCleanup(t *testing.T) { + stageRoot, fs, runner := setupPBSAPIApplyTestDeps(t) + logger := logging.New(types.LogLevelDebug, false) + + writeStageFile(t, fs, stageRoot, "etc/proxmox-backup/s3.cfg", + "s3: e1\n"+ + " endpoint https://s3.example\n"+ + " access-key access1\n"+ + " secret-key secret1\n", + 0o640, + ) + + runner.outputs = map[string][]byte{ + "proxmox-backup-manager s3 endpoint list --output-format=json": []byte(`{"data":[{"id":"e1"},{"id":"old"}]}`), + } + runner.errs = map[string]error{ + "proxmox-backup-manager s3 endpoint remove old": errors.New("cannot remove old"), + "proxmox-backup-manager s3 endpoint create e1 --endpoint https://s3.example --access-key access1 --secret-key secret1": errors.New("already exists"), + } + + if err := applyPBSS3CfgViaAPI(context.Background(), logger, stageRoot, true); err != nil { + t.Fatalf("applyPBSS3CfgViaAPI error: %v", err) + } + + want := []string{ + "proxmox-backup-manager s3 endpoint list --output-format=json", + "proxmox-backup-manager s3 endpoint remove old", + "proxmox-backup-manager s3 endpoint create e1 --endpoint https://s3.example --access-key access1 --secret-key secret1", + "proxmox-backup-manager s3 endpoint update e1 --endpoint https://s3.example --access-key access1 --secret-key secret1", + } + if !reflect.DeepEqual(runner.calls,
want) { + t.Fatalf("calls=%v want %v", runner.calls, want) + } +} + +func TestApplyPBSS3CfgViaAPI_NoFileNoCalls(t *testing.T) { + stageRoot, _, runner := setupPBSAPIApplyTestDeps(t) + logger := logging.New(types.LogLevelDebug, false) + + if err := applyPBSS3CfgViaAPI(context.Background(), logger, stageRoot, true); err != nil { + t.Fatalf("applyPBSS3CfgViaAPI error: %v", err) + } + if len(runner.calls) != 0 { + t.Fatalf("expected no calls, got %v", runner.calls) + } +} + +func TestApplyPBSS3CfgViaAPI_RedactsKeysOnFailure(t *testing.T) { + stageRoot, fs, runner := setupPBSAPIApplyTestDeps(t) + logger := logging.New(types.LogLevelDebug, false) + + writeStageFile(t, fs, stageRoot, "etc/proxmox-backup/s3.cfg", + "s3: e1\n"+ + " endpoint https://s3.example\n"+ + " access-key access1\n"+ + " secret-key secret1\n", + 0o640, + ) + + runner.errs = map[string]error{ + "proxmox-backup-manager s3 endpoint create e1 --endpoint https://s3.example --access-key access1 --secret-key secret1": errors.New("create failed"), + "proxmox-backup-manager s3 endpoint update e1 --endpoint https://s3.example --access-key access1 --secret-key secret1": errors.New("update failed"), + } + + err := applyPBSS3CfgViaAPI(context.Background(), logger, stageRoot, false) + if err == nil { + t.Fatalf("expected error") + } + if strings.Contains(err.Error(), "access1") || strings.Contains(err.Error(), "secret1") { + t.Fatalf("expected keys to be redacted, got: %s", err.Error()) + } + if !strings.Contains(err.Error(), "<redacted>") { + t.Fatalf("expected <redacted> in error, got: %s", err.Error()) + } +} + +func TestApplyPBSDatastoreCfgViaAPI_StrictFullFlow(t *testing.T) { + stageRoot, fs, runner := setupPBSAPIApplyTestDeps(t) + logger := logging.New(types.LogLevelDebug, false) + + writeStageFile(t, fs, stageRoot, "etc/proxmox-backup/datastore.cfg", + "datastore: ds1\n"+ + " path /new1\n"+ + " comment c1\n"+ + "\n"+ + "datastore: ds2\n"+ + " comment missing-path\n"+ + "\n"+ + "datastore: ds3\n"+ + " path /p3\n"+ + " comment c3\n"+ + "\n"+ + "datastore: ds4\n"+ + " path /p4\n"+ + " comment c4\n", + 0o640, + ) + + runner.outputs = map[string][]byte{ + "proxmox-backup-manager datastore list --output-format=json": []byte(`{"data":[{"name":"ds1","path":"/old1"},{"name":"ds3","path":"/p3"},{"name":"ds-old","path":"/old"}]}`), + } + runner.errs = map[string]error{ + "proxmox-backup-manager datastore create ds4 /p4 --comment c4": errors.New("already exists"), + } + + if err := applyPBSDatastoreCfgViaAPI(context.Background(), logger, stageRoot, true); err != nil { + t.Fatalf("applyPBSDatastoreCfgViaAPI error: %v", err) + } + + want := []string{ + "proxmox-backup-manager datastore list --output-format=json", + "proxmox-backup-manager datastore remove ds-old", + "proxmox-backup-manager datastore remove ds1", + "proxmox-backup-manager datastore create ds1 /new1 --comment c1", + "proxmox-backup-manager datastore update ds3 --comment c3", + "proxmox-backup-manager datastore create ds4 /p4 --comment c4", + "proxmox-backup-manager datastore update ds4 --comment c4", + } + if !reflect.DeepEqual(runner.calls, want) { + t.Fatalf("calls=%v want %v", runner.calls, want) + } +} + +func TestApplyPBSDatastoreCfgViaAPI_CurrentPathsFallbacksAndStrictRemoveWarn(t *testing.T) { + stageRoot, fs, runner := setupPBSAPIApplyTestDeps(t) + logger := logging.New(types.LogLevelDebug, false) + + writeStageFile(t, fs, stageRoot, "etc/proxmox-backup/datastore.cfg", + "datastore: ds1\n"+ + " path /p1\n"+ + " comment c1\n", + 0o640, + ) + + runner.outputs = map[string][]byte{ +
"proxmox-backup-manager datastore list --output-format=json": []byte( + `{"data":[` + + `{"store":"store1","path":"/ps"},` + + `{"id":"id1","path":"/pi"},` + + `{"path":"/no-id"},` + + `{"name":"ds1","path":"/p1"}` + + `]}`, + ), + } + runner.errs = map[string]error{ + "proxmox-backup-manager datastore remove id1": errors.New("cannot remove id1"), + } + + if err := applyPBSDatastoreCfgViaAPI(context.Background(), logger, stageRoot, true); err != nil { + t.Fatalf("applyPBSDatastoreCfgViaAPI error: %v", err) + } + + want := []string{ + "proxmox-backup-manager datastore list --output-format=json", + "proxmox-backup-manager datastore remove id1", + "proxmox-backup-manager datastore remove store1", + "proxmox-backup-manager datastore update ds1 --comment c1", + } + if !reflect.DeepEqual(runner.calls, want) { + t.Fatalf("calls=%v want %v", runner.calls, want) + } +} + +func TestApplyPBSDatastoreCfgViaAPI_StrictPathMismatchRemoveFails(t *testing.T) { + stageRoot, fs, runner := setupPBSAPIApplyTestDeps(t) + logger := logging.New(types.LogLevelDebug, false) + + writeStageFile(t, fs, stageRoot, "etc/proxmox-backup/datastore.cfg", + "datastore: ds1\n"+ + " path /new\n"+ + " comment c1\n", + 0o640, + ) + + runner.outputs = map[string][]byte{ + "proxmox-backup-manager datastore list --output-format=json": []byte(`{"data":[{"name":"ds1","path":"/old"}]}`), + } + runner.errs = map[string]error{ + "proxmox-backup-manager datastore remove ds1": errors.New("cannot remove ds1"), + } + + err := applyPBSDatastoreCfgViaAPI(context.Background(), logger, stageRoot, true) + if err == nil { + t.Fatalf("expected error") + } + if !strings.Contains(err.Error(), "path mismatch") || !strings.Contains(err.Error(), "remove failed") { + t.Fatalf("unexpected error: %s", err.Error()) + } +} + +func TestApplyPBSDatastoreCfgViaAPI_StrictPathMismatchRecreateFails(t *testing.T) { + stageRoot, fs, runner := setupPBSAPIApplyTestDeps(t) + logger := logging.New(types.LogLevelDebug, false) + + writeStageFile(t, fs, stageRoot, "etc/proxmox-backup/datastore.cfg", + "datastore: ds1\n"+ + " path /new\n"+ + " comment c1\n", + 0o640, + ) + + runner.outputs = map[string][]byte{ + "proxmox-backup-manager datastore list --output-format=json": []byte(`{"data":[{"name":"ds1","path":"/old"}]}`), + } + runner.errs = map[string]error{ + "proxmox-backup-manager datastore create ds1 /new --comment c1": errors.New("create failed"), + } + + err := applyPBSDatastoreCfgViaAPI(context.Background(), logger, stageRoot, true) + if err == nil { + t.Fatalf("expected error") + } + if !strings.Contains(err.Error(), "recreate after path mismatch failed") { + t.Fatalf("unexpected error: %s", err.Error()) + } +} + +func TestApplyPBSDatastoreCfgViaAPI_UpdateFails(t *testing.T) { + stageRoot, fs, runner := setupPBSAPIApplyTestDeps(t) + logger := logging.New(types.LogLevelDebug, false) + + writeStageFile(t, fs, stageRoot, "etc/proxmox-backup/datastore.cfg", + "datastore: ds1\n"+ + " path /p1\n"+ + " comment c1\n", + 0o640, + ) + + runner.outputs = map[string][]byte{ + "proxmox-backup-manager datastore list --output-format=json": []byte(`{"data":[{"name":"ds1","path":"/p1"}]}`), + } + runner.errs = map[string]error{ + "proxmox-backup-manager datastore update ds1 --comment c1": errors.New("update failed"), + } + + err := applyPBSDatastoreCfgViaAPI(context.Background(), logger, stageRoot, false) + if err == nil || !strings.Contains(err.Error(), "update failed") { + t.Fatalf("expected update failed error, got: %v", err) + } +} + +func 
TestApplyPBSDatastoreCfgViaAPI_CreateAndUpdateFail(t *testing.T) { + stageRoot, fs, runner := setupPBSAPIApplyTestDeps(t) + logger := logging.New(types.LogLevelDebug, false) + + writeStageFile(t, fs, stageRoot, "etc/proxmox-backup/datastore.cfg", + "datastore: ds1\n"+ + " path /p1\n"+ + " comment c1\n", + 0o640, + ) + + runner.errs = map[string]error{ + "proxmox-backup-manager datastore create ds1 /p1 --comment c1": errors.New("create failed"), + "proxmox-backup-manager datastore update ds1 --comment c1": errors.New("update failed"), + } + + err := applyPBSDatastoreCfgViaAPI(context.Background(), logger, stageRoot, false) + if err == nil { + t.Fatalf("expected error") + } + if !strings.Contains(err.Error(), "create failed") || !strings.Contains(err.Error(), "update failed") { + t.Fatalf("expected create/update errors, got: %s", err.Error()) + } +} + +func TestApplyPBSDatastoreCfgViaAPI_NonStrictPathMismatchKeepsPath(t *testing.T) { + stageRoot, fs, runner := setupPBSAPIApplyTestDeps(t) + logger := logging.New(types.LogLevelDebug, false) + + writeStageFile(t, fs, stageRoot, "etc/proxmox-backup/datastore.cfg", + "datastore: ds1\n"+ + " path /new1\n"+ + " comment c1\n", + 0o640, + ) + + runner.outputs = map[string][]byte{ + "proxmox-backup-manager datastore list --output-format=json": []byte(`{"data":[{"name":"ds1","path":"/old1"}]}`), + } + + if err := applyPBSDatastoreCfgViaAPI(context.Background(), logger, stageRoot, false); err != nil { + t.Fatalf("applyPBSDatastoreCfgViaAPI error: %v", err) + } + + want := []string{ + "proxmox-backup-manager datastore list --output-format=json", + "proxmox-backup-manager datastore update ds1 --comment c1", + } + if !reflect.DeepEqual(runner.calls, want) { + t.Fatalf("calls=%v want %v", runner.calls, want) + } +} + +func TestApplyPBSSyncCfgViaAPI_Create(t *testing.T) { + stageRoot, fs, runner := setupPBSAPIApplyTestDeps(t) + logger := logging.New(types.LogLevelDebug, false) + + writeStageFile(t, fs, stageRoot, "etc/proxmox-backup/sync.cfg", + "sync-job: job1\n"+ + " remote r1\n"+ + " store ds1\n", + 0o640, + ) + + if err := applyPBSSyncCfgViaAPI(context.Background(), logger, stageRoot, false); err != nil { + t.Fatalf("applyPBSSyncCfgViaAPI error: %v", err) + } + + want := []string{ + "proxmox-backup-manager sync-job create job1 --remote r1 --store ds1", + } + if !reflect.DeepEqual(runner.calls, want) { + t.Fatalf("calls=%v want %v", runner.calls, want) + } +} + +func TestApplyPBSSyncCfgViaAPI_StrictCleanupAndFallbackUpdate(t *testing.T) { + stageRoot, fs, runner := setupPBSAPIApplyTestDeps(t) + logger := logging.New(types.LogLevelDebug, false) + + writeStageFile(t, fs, stageRoot, "etc/proxmox-backup/sync.cfg", + "sync-job: job1\n"+ + " remote r1\n"+ + " store ds1\n", + 0o640, + ) + + runner.outputs = map[string][]byte{ + "proxmox-backup-manager sync-job list --output-format=json": []byte(`{"data":[{"id":"job1"},{"id":"old"}]}`), + } + runner.errs = map[string]error{ + "proxmox-backup-manager sync-job remove old": errors.New("cannot remove old"), + "proxmox-backup-manager sync-job create job1 --remote r1 --store ds1": errors.New("already exists"), + } + + if err := applyPBSSyncCfgViaAPI(context.Background(), logger, stageRoot, true); err != nil { + t.Fatalf("applyPBSSyncCfgViaAPI error: %v", err) + } + + want := []string{ + "proxmox-backup-manager sync-job list --output-format=json", + "proxmox-backup-manager sync-job remove old", + "proxmox-backup-manager sync-job create job1 --remote r1 --store ds1", + "proxmox-backup-manager sync-job update job1 
--remote r1 --store ds1", + } + if !reflect.DeepEqual(runner.calls, want) { + t.Fatalf("calls=%v want %v", runner.calls, want) + } +} + +func TestApplyPBSVerificationCfgViaAPI_Create(t *testing.T) { + stageRoot, fs, runner := setupPBSAPIApplyTestDeps(t) + logger := logging.New(types.LogLevelDebug, false) + + writeStageFile(t, fs, stageRoot, "etc/proxmox-backup/verification.cfg", + "verify-job: v1\n"+ + " store ds1\n", + 0o640, + ) + + if err := applyPBSVerificationCfgViaAPI(context.Background(), logger, stageRoot, false); err != nil { + t.Fatalf("applyPBSVerificationCfgViaAPI error: %v", err) + } + + want := []string{ + "proxmox-backup-manager verify-job create v1 --store ds1", + } + if !reflect.DeepEqual(runner.calls, want) { + t.Fatalf("calls=%v want %v", runner.calls, want) + } +} + +func TestApplyPBSVerificationCfgViaAPI_StrictCleanupAndFallbackUpdate(t *testing.T) { + stageRoot, fs, runner := setupPBSAPIApplyTestDeps(t) + logger := logging.New(types.LogLevelDebug, false) + + writeStageFile(t, fs, stageRoot, "etc/proxmox-backup/verification.cfg", + "verify-job: v1\n"+ + " store ds1\n", + 0o640, + ) + + runner.outputs = map[string][]byte{ + "proxmox-backup-manager verify-job list --output-format=json": []byte(`{"data":[{"id":"v1"},{"id":"old"}]}`), + } + runner.errs = map[string]error{ + "proxmox-backup-manager verify-job remove old": errors.New("cannot remove old"), + "proxmox-backup-manager verify-job create v1 --store ds1": errors.New("already exists"), + } + + if err := applyPBSVerificationCfgViaAPI(context.Background(), logger, stageRoot, true); err != nil { + t.Fatalf("applyPBSVerificationCfgViaAPI error: %v", err) + } + + want := []string{ + "proxmox-backup-manager verify-job list --output-format=json", + "proxmox-backup-manager verify-job remove old", + "proxmox-backup-manager verify-job create v1 --store ds1", + "proxmox-backup-manager verify-job update v1 --store ds1", + } + if !reflect.DeepEqual(runner.calls, want) { + t.Fatalf("calls=%v want %v", runner.calls, want) + } +} + +func TestApplyPBSPruneCfgViaAPI_Create(t *testing.T) { + stageRoot, fs, runner := setupPBSAPIApplyTestDeps(t) + logger := logging.New(types.LogLevelDebug, false) + + writeStageFile(t, fs, stageRoot, "etc/proxmox-backup/prune.cfg", + "prune-job: p1\n"+ + " store ds1\n"+ + " keep-last 3\n", + 0o640, + ) + + if err := applyPBSPruneCfgViaAPI(context.Background(), logger, stageRoot, false); err != nil { + t.Fatalf("applyPBSPruneCfgViaAPI error: %v", err) + } + + want := []string{ + "proxmox-backup-manager prune-job create p1 --store ds1 --keep-last 3", + } + if !reflect.DeepEqual(runner.calls, want) { + t.Fatalf("calls=%v want %v", runner.calls, want) + } +} + +func TestApplyPBSPruneCfgViaAPI_StrictCleanupAndFallbackUpdate(t *testing.T) { + stageRoot, fs, runner := setupPBSAPIApplyTestDeps(t) + logger := logging.New(types.LogLevelDebug, false) + + writeStageFile(t, fs, stageRoot, "etc/proxmox-backup/prune.cfg", + "prune-job: p1\n"+ + " store ds1\n"+ + " keep-last 3\n", + 0o640, + ) + + runner.outputs = map[string][]byte{ + "proxmox-backup-manager prune-job list --output-format=json": []byte(`{"data":[{"id":"old"},{"id":"p1"}]}`), + } + runner.errs = map[string]error{ + "proxmox-backup-manager prune-job remove old": errors.New("cannot remove old"), + "proxmox-backup-manager prune-job create p1 --store ds1 --keep-last 3": errors.New("already exists"), + } + + if err := applyPBSPruneCfgViaAPI(context.Background(), logger, stageRoot, true); err != nil { + t.Fatalf("applyPBSPruneCfgViaAPI error: %v", err) + } + + 
want := []string{ + "proxmox-backup-manager prune-job list --output-format=json", + "proxmox-backup-manager prune-job remove old", + "proxmox-backup-manager prune-job create p1 --store ds1 --keep-last 3", + "proxmox-backup-manager prune-job update p1 --store ds1 --keep-last 3", + } + if !reflect.DeepEqual(runner.calls, want) { + t.Fatalf("calls=%v want %v", runner.calls, want) + } +} + +func TestApplyPBSTrafficControlCfgViaAPI_StrictCleanupAndCreate(t *testing.T) { + stageRoot, fs, runner := setupPBSAPIApplyTestDeps(t) + logger := logging.New(types.LogLevelDebug, false) + + writeStageFile(t, fs, stageRoot, "etc/proxmox-backup/traffic-control.cfg", + "traffic-control: tc1\n"+ + " rate-in 1000\n", + 0o640, + ) + + runner.outputs = map[string][]byte{ + "proxmox-backup-manager traffic-control list --output-format=json": []byte(`{"data":[{"name":"old"},{"name":"tc1"}]}`), + } + runner.errs = map[string]error{ + "proxmox-backup-manager traffic-control remove old": errors.New("cannot remove old"), + } + + if err := applyPBSTrafficControlCfgViaAPI(context.Background(), logger, stageRoot, true); err != nil { + t.Fatalf("applyPBSTrafficControlCfgViaAPI error: %v", err) + } + + want := []string{ + "proxmox-backup-manager traffic-control list --output-format=json", + "proxmox-backup-manager traffic-control remove old", + "proxmox-backup-manager traffic-control create tc1 --rate-in 1000", + } + if !reflect.DeepEqual(runner.calls, want) { + t.Fatalf("calls=%v want %v", runner.calls, want) + } +} + +func TestApplyPBSTrafficControlCfgViaAPI_CreateFailsThenUpdate(t *testing.T) { + stageRoot, fs, runner := setupPBSAPIApplyTestDeps(t) + logger := logging.New(types.LogLevelDebug, false) + + writeStageFile(t, fs, stageRoot, "etc/proxmox-backup/traffic-control.cfg", + "traffic-control: tc1\n"+ + " rate-in 1000\n", + 0o640, + ) + + runner.errs = map[string]error{ + "proxmox-backup-manager traffic-control create tc1 --rate-in 1000": errors.New("already exists"), + } + + if err := applyPBSTrafficControlCfgViaAPI(context.Background(), logger, stageRoot, false); err != nil { + t.Fatalf("applyPBSTrafficControlCfgViaAPI error: %v", err) + } + + want := []string{ + "proxmox-backup-manager traffic-control create tc1 --rate-in 1000", + "proxmox-backup-manager traffic-control update tc1 --rate-in 1000", + } + if !reflect.DeepEqual(runner.calls, want) { + t.Fatalf("calls=%v want %v", runner.calls, want) + } +} + +func TestApplyPBSNodeCfgViaAPI_UsesFirstSection(t *testing.T) { + stageRoot, fs, runner := setupPBSAPIApplyTestDeps(t) + + writeStageFile(t, fs, stageRoot, "etc/proxmox-backup/node.cfg", + "node: n1\n"+ + " dns1 1.1.1.1\n"+ + " foo_bar baz\n"+ + "\n"+ + "node: n2\n"+ + " dns1 9.9.9.9\n", + 0o640, + ) + + if err := applyPBSNodeCfgViaAPI(context.Background(), stageRoot); err != nil { + t.Fatalf("applyPBSNodeCfgViaAPI error: %v", err) + } + + want := []string{ + "proxmox-backup-manager node update --dns1 1.1.1.1 --foo-bar baz", + } + if !reflect.DeepEqual(runner.calls, want) { + t.Fatalf("calls=%v want %v", runner.calls, want) + } +} + +func TestApplyPBSNodeCfgViaAPI_NoFileAndEmptyAndError(t *testing.T) { + t.Run("no-file", func(t *testing.T) { + stageRoot, _, runner := setupPBSAPIApplyTestDeps(t) + if err := applyPBSNodeCfgViaAPI(context.Background(), stageRoot); err != nil { + t.Fatalf("applyPBSNodeCfgViaAPI error: %v", err) + } + if len(runner.calls) != 0 { + t.Fatalf("expected no calls, got %v", runner.calls) + } + }) + + t.Run("empty-sections", func(t *testing.T) { + stageRoot, fs, runner := 
setupPBSAPIApplyTestDeps(t) + writeStageFile(t, fs, stageRoot, "etc/proxmox-backup/node.cfg", " \n# comment\n", 0o640) + if err := applyPBSNodeCfgViaAPI(context.Background(), stageRoot); err != nil { + t.Fatalf("applyPBSNodeCfgViaAPI error: %v", err) + } + if len(runner.calls) != 0 { + t.Fatalf("expected no calls, got %v", runner.calls) + } + }) + + t.Run("command-error", func(t *testing.T) { + stageRoot, fs, runner := setupPBSAPIApplyTestDeps(t) + writeStageFile(t, fs, stageRoot, "etc/proxmox-backup/node.cfg", + "node: n1\n"+ + " dns1 1.1.1.1\n", + 0o640, + ) + runner.errs = map[string]error{ + "proxmox-backup-manager node update --dns1 1.1.1.1": errors.New("boom"), + } + if err := applyPBSNodeCfgViaAPI(context.Background(), stageRoot); err == nil { + t.Fatalf("expected error") + } + }) +} diff --git a/internal/orchestrator/pbs_mount_guard_test.go b/internal/orchestrator/pbs_mount_guard_test.go index f456170..a9efbc1 100644 --- a/internal/orchestrator/pbs_mount_guard_test.go +++ b/internal/orchestrator/pbs_mount_guard_test.go @@ -13,6 +13,7 @@ func TestPBSMountGuardRootForDatastorePath(t *testing.T) { {name: "mnt nested", in: "/mnt/datastore/Data1", want: "/mnt/datastore"}, {name: "mnt deep", in: "/mnt/Synology_NFS/PBS_Backup", want: "/mnt/Synology_NFS"}, {name: "media", in: "/media/USB/PBS", want: "/media/USB"}, + {name: "run media root", in: "/run/media/root", want: "/run/media/root"}, {name: "run media", in: "/run/media/root/USB/PBS", want: "/run/media/root/USB"}, {name: "not mount style", in: "/srv/pbs", want: ""}, {name: "empty", in: "", want: ""}, diff --git a/internal/orchestrator/pbs_staged_apply.go b/internal/orchestrator/pbs_staged_apply.go index 0b84e8f..265910c 100644 --- a/internal/orchestrator/pbs_staged_apply.go +++ b/internal/orchestrator/pbs_staged_apply.go @@ -12,6 +12,21 @@ import ( "github.com/tis24dev/proxsave/internal/logging" ) +// Hookable functions for testing staged PBS apply logic without touching the real system/API. 
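+//
+// Tests can stub any of these hooks and restore the original via t.Cleanup,
+// mirroring the save/restore pattern the tests in this diff already use.
+// A minimal sketch (forcing the geteuid hook to report root):
+//
+//	orig := pbsStagedApplyGeteuidFn
+//	t.Cleanup(func() { pbsStagedApplyGeteuidFn = orig })
+//	pbsStagedApplyGeteuidFn = func() int { return 0 }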
+var ( + pbsStagedApplyIsRealRestoreFSFn = isRealRestoreFS + pbsStagedApplyGeteuidFn = os.Geteuid + pbsStagedApplyEnsurePBSServicesForAPIFn = ensurePBSServicesForAPI + pbsStagedApplyTrafficControlCfgViaAPIFn = applyPBSTrafficControlCfgViaAPI + pbsStagedApplyNodeCfgViaAPIFn = applyPBSNodeCfgViaAPI + pbsStagedApplyS3CfgViaAPIFn = applyPBSS3CfgViaAPI + pbsStagedApplyDatastoreCfgViaAPIFn = applyPBSDatastoreCfgViaAPI + pbsStagedApplyRemoteCfgViaAPIFn = applyPBSRemoteCfgViaAPI + pbsStagedApplySyncCfgViaAPIFn = applyPBSSyncCfgViaAPI + pbsStagedApplyVerificationCfgViaAPIFn = applyPBSVerificationCfgViaAPI + pbsStagedApplyPruneCfgViaAPIFn = applyPBSPruneCfgViaAPI +) + func maybeApplyPBSConfigsFromStage(ctx context.Context, logger *logging.Logger, plan *RestorePlan, stageRoot string, dryRun bool) (err error) { if plan == nil || plan.SystemType != SystemTypePBS { return nil @@ -31,11 +46,11 @@ func maybeApplyPBSConfigsFromStage(ctx context.Context, logger *logging.Logger, logger.Info("Dry run enabled: skipping staged PBS config apply") return nil } - if !isRealRestoreFS(restoreFS) { + if !pbsStagedApplyIsRealRestoreFSFn(restoreFS) { logger.Debug("Skipping staged PBS config apply: non-system filesystem in use") return nil } - if os.Geteuid() != 0 { + if pbsStagedApplyGeteuidFn() != 0 { logger.Warning("Skipping staged PBS config apply: requires root privileges") return nil } @@ -47,7 +62,7 @@ func maybeApplyPBSConfigsFromStage(ctx context.Context, logger *logging.Logger, needsAPI := plan.HasCategoryID("pbs_host") || plan.HasCategoryID("datastore_pbs") || plan.HasCategoryID("pbs_remotes") || plan.HasCategoryID("pbs_jobs") apiAvailable := false if needsAPI { - if err := ensurePBSServicesForAPI(ctx, logger); err != nil { + if err := pbsStagedApplyEnsurePBSServicesForAPIFn(ctx, logger); err != nil { if allowFileFallback { logger.Warning("PBS API apply unavailable; falling back to file-based staged apply where possible: %v", err) } else { @@ -73,14 +88,14 @@ func maybeApplyPBSConfigsFromStage(ctx context.Context, logger *logging.Logger, } if apiAvailable { - if err := applyPBSTrafficControlCfgViaAPI(ctx, logger, stageRoot, strict); err != nil { + if err := pbsStagedApplyTrafficControlCfgViaAPIFn(ctx, logger, stageRoot, strict); err != nil { logger.Warning("PBS API apply: traffic-control failed: %v", err) if allowFileFallback { logger.Warning("PBS staged apply: falling back to file-based traffic-control.cfg") _ = applyPBSConfigFileFromStage(ctx, logger, stageRoot, "etc/proxmox-backup/traffic-control.cfg") } } - if err := applyPBSNodeCfgViaAPI(ctx, stageRoot); err != nil { + if err := pbsStagedApplyNodeCfgViaAPIFn(ctx, stageRoot); err != nil { logger.Warning("PBS API apply: node config failed: %v", err) if allowFileFallback { logger.Warning("PBS staged apply: falling back to file-based node.cfg") @@ -103,14 +118,14 @@ func maybeApplyPBSConfigsFromStage(ctx context.Context, logger *logging.Logger, if plan.HasCategoryID("datastore_pbs") { if apiAvailable { - if err := applyPBSS3CfgViaAPI(ctx, logger, stageRoot, strict); err != nil { + if err := pbsStagedApplyS3CfgViaAPIFn(ctx, logger, stageRoot, strict); err != nil { logger.Warning("PBS API apply: s3.cfg failed: %v", err) if allowFileFallback { logger.Warning("PBS staged apply: falling back to file-based s3.cfg") _ = applyPBSS3CfgFromStage(ctx, logger, stageRoot) } } - if err := applyPBSDatastoreCfgViaAPI(ctx, logger, stageRoot, strict); err != nil { + if err := pbsStagedApplyDatastoreCfgViaAPIFn(ctx, logger, stageRoot, strict); err != nil { logger.Warning("PBS 
API apply: datastore.cfg failed: %v", err) if allowFileFallback { logger.Warning("PBS staged apply: falling back to file-based datastore.cfg") @@ -131,7 +146,7 @@ func maybeApplyPBSConfigsFromStage(ctx context.Context, logger *logging.Logger, if plan.HasCategoryID("pbs_remotes") { if apiAvailable { - if err := applyPBSRemoteCfgViaAPI(ctx, logger, stageRoot, strict); err != nil { + if err := pbsStagedApplyRemoteCfgViaAPIFn(ctx, logger, stageRoot, strict); err != nil { logger.Warning("PBS API apply: remote.cfg failed: %v", err) if allowFileFallback { logger.Warning("PBS staged apply: falling back to file-based remote.cfg") @@ -149,21 +164,21 @@ func maybeApplyPBSConfigsFromStage(ctx context.Context, logger *logging.Logger, if plan.HasCategoryID("pbs_jobs") { if apiAvailable { - if err := applyPBSSyncCfgViaAPI(ctx, logger, stageRoot, strict); err != nil { + if err := pbsStagedApplySyncCfgViaAPIFn(ctx, logger, stageRoot, strict); err != nil { logger.Warning("PBS API apply: sync jobs failed: %v", err) if allowFileFallback { logger.Warning("PBS staged apply: falling back to file-based job configs") _ = applyPBSJobConfigsFromStage(ctx, logger, stageRoot) } } - if err := applyPBSVerificationCfgViaAPI(ctx, logger, stageRoot, strict); err != nil { + if err := pbsStagedApplyVerificationCfgViaAPIFn(ctx, logger, stageRoot, strict); err != nil { logger.Warning("PBS API apply: verification jobs failed: %v", err) if allowFileFallback { logger.Warning("PBS staged apply: falling back to file-based job configs") _ = applyPBSJobConfigsFromStage(ctx, logger, stageRoot) } } - if err := applyPBSPruneCfgViaAPI(ctx, logger, stageRoot, strict); err != nil { + if err := pbsStagedApplyPruneCfgViaAPIFn(ctx, logger, stageRoot, strict); err != nil { logger.Warning("PBS API apply: prune jobs failed: %v", err) if allowFileFallback { logger.Warning("PBS staged apply: falling back to file-based job configs") diff --git a/internal/orchestrator/pbs_staged_apply_additional_test.go b/internal/orchestrator/pbs_staged_apply_additional_test.go new file mode 100644 index 0000000..80cb253 --- /dev/null +++ b/internal/orchestrator/pbs_staged_apply_additional_test.go @@ -0,0 +1,1031 @@ +package orchestrator + +import ( + "context" + "errors" + "fmt" + "os" + "path/filepath" + "strings" + "testing" + "time" +) + +func TestPBSConfigHasHeader_AcceptsAndRejectsExpectedForms(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + content string + want bool + }{ + { + name: "HeaderWithSpaceSeparatedName", + content: strings.Join([]string{ + "# comment", + "", + "remote: pbs1", + " host 10.0.0.10", + }, "\n"), + want: true, + }, + { + name: "HeaderWithInlineName", + content: strings.Join([]string{ + "remote:pbs1", + " host 10.0.0.10", + }, "\n"), + want: true, + }, + { + name: "RejectsHeaderWithoutName", + content: strings.Join([]string{ + "datastore:", + " path /mnt/datastore", + }, "\n"), + want: false, + }, + { + name: "RejectsInvalidKeyCharacters", + content: "foo.bar: baz\n", + want: false, + }, + { + name: "RejectsEmptyKey", + content: ": x\n", + want: false, + }, + { + name: "RejectsOnlyComments", + content: "# comment\n# still comment\n", + want: false, + }, + { + name: "AcceptsDashAndUnderscore", + content: "key-with-dash: v\nkey_with_underscore: v\n", + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := pbsConfigHasHeader(tt.content); got != tt.want { + t.Fatalf("pbsConfigHasHeader()=%v want %v", got, tt.want) + } + }) + } +} + +func 
TestMaybeApplyPBSConfigsFromStage_EarlyReturns(t *testing.T) { + ctx := context.Background() + logger := newTestLogger() + + if err := maybeApplyPBSConfigsFromStage(ctx, logger, nil, "/stage", false); err != nil { + t.Fatalf("nil plan: expected nil error, got %v", err) + } + + planWrongSystem := &RestorePlan{SystemType: SystemTypePVE, NormalCategories: []Category{{ID: "pbs_host"}}} + if err := maybeApplyPBSConfigsFromStage(ctx, logger, planWrongSystem, "/stage", false); err != nil { + t.Fatalf("wrong system type: expected nil error, got %v", err) + } + + planNoCategories := &RestorePlan{SystemType: SystemTypePBS, NormalCategories: []Category{{ID: "unrelated"}}} + if err := maybeApplyPBSConfigsFromStage(ctx, logger, planNoCategories, "/stage", false); err != nil { + t.Fatalf("no pbs categories: expected nil error, got %v", err) + } + + plan := &RestorePlan{SystemType: SystemTypePBS, NormalCategories: []Category{{ID: "pbs_host"}}} + if err := maybeApplyPBSConfigsFromStage(ctx, logger, plan, " ", false); err != nil { + t.Fatalf("blank stageRoot: expected nil error, got %v", err) + } + if err := maybeApplyPBSConfigsFromStage(ctx, logger, plan, "/stage", true); err != nil { + t.Fatalf("dryRun: expected nil error, got %v", err) + } + + origFS := restoreFS + fakeFS := NewFakeFS() + t.Cleanup(func() { + restoreFS = origFS + _ = os.RemoveAll(fakeFS.Root) + }) + restoreFS = fakeFS + if err := maybeApplyPBSConfigsFromStage(ctx, logger, plan, "/stage", false); err != nil { + t.Fatalf("non-real FS: expected nil error, got %v", err) + } +} + +func TestApplyPBSConfigFileFromStage_SkipsMissingFile(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + if err := applyPBSConfigFileFromStage(context.Background(), newTestLogger(), "/stage", "etc/proxmox-backup/s3.cfg"); err != nil { + t.Fatalf("applyPBSConfigFileFromStage: %v", err) + } + if _, err := fakeFS.Stat("/etc/proxmox-backup/s3.cfg"); err == nil { + t.Fatalf("expected s3.cfg to not be created") + } else if !os.IsNotExist(err) { + t.Fatalf("stat s3.cfg: %v", err) + } +} + +func TestApplyPBSConfigFileFromStage_SkipsInvalidHeader_LeavesTargetUnchanged(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + const destPath = "/etc/proxmox-backup/remote.cfg" + existing := "remote: old\n host 1.2.3.4\n" + if err := fakeFS.WriteFile(destPath, []byte(existing), 0o640); err != nil { + t.Fatalf("write existing remote.cfg: %v", err) + } + + stageRoot := "/stage" + staged := "this is not a PBS config file\n(no section header)\n" + if err := fakeFS.WriteFile(filepath.Join(stageRoot, "etc/proxmox-backup/remote.cfg"), []byte(staged), 0o640); err != nil { + t.Fatalf("write staged remote.cfg: %v", err) + } + + if err := applyPBSConfigFileFromStage(context.Background(), newTestLogger(), stageRoot, "etc/proxmox-backup/remote.cfg"); err != nil { + t.Fatalf("applyPBSConfigFileFromStage: %v", err) + } + + got, err := fakeFS.ReadFile(destPath) + if err != nil { + t.Fatalf("read dest remote.cfg: %v", err) + } + if string(got) != existing { + t.Fatalf("dest remote.cfg changed unexpectedly: got=%q want=%q", string(got), existing) + } +} + +func TestApplyPBSConfigFileFromStage_ReturnsErrorOnAtomicWriteFailure(t *testing.T) { + origFS := restoreFS + origTime := restoreTime + t.Cleanup(func() { + restoreFS = origFS + 
restoreTime = origTime + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + restoreTime = &FakeTime{Current: time.Unix(123, 456)} + + stageRoot := "/stage" + rel := "etc/proxmox-backup/s3.cfg" + staged := "s3: r1\n bucket test\n" + if err := fakeFS.WriteFile(filepath.Join(stageRoot, rel), []byte(staged), 0o640); err != nil { + t.Fatalf("write staged s3.cfg: %v", err) + } + + destPath := filepath.Join(string(os.PathSeparator), filepath.FromSlash(rel)) + tmpPath := fmt.Sprintf("%s.proxsave.tmp.%d", destPath, nowRestore().UnixNano()) + fakeFS.OpenFileErr[filepath.Clean(tmpPath)] = errors.New("forced OpenFile error") + + if err := applyPBSConfigFileFromStage(context.Background(), newTestLogger(), stageRoot, rel); err == nil { + t.Fatalf("expected error, got nil") + } +} + +func TestApplyPBSS3CfgFromStage_WritesS3Cfg(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + stageRoot := "/stage" + content := "s3: r1\n bucket test\n" + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/s3.cfg", []byte(content), 0o640); err != nil { + t.Fatalf("write staged s3.cfg: %v", err) + } + + if err := applyPBSS3CfgFromStage(context.Background(), newTestLogger(), stageRoot); err != nil { + t.Fatalf("applyPBSS3CfgFromStage: %v", err) + } + + if _, err := fakeFS.Stat("/etc/proxmox-backup/s3.cfg"); err != nil { + t.Fatalf("expected s3.cfg to exist: %v", err) + } +} + +func TestApplyPBSConfigFileFromStage_PropagatesReadError(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + stageRoot := "/stage" + rel := "etc/proxmox-backup/remote.cfg" + if err := fakeFS.MkdirAll(filepath.Join(stageRoot, rel), 0o755); err != nil { + t.Fatalf("mkdir staged remote.cfg dir: %v", err) + } + + if err := applyPBSConfigFileFromStage(context.Background(), newTestLogger(), stageRoot, rel); err == nil { + t.Fatalf("expected error, got nil") + } +} + +func TestLoadPBSDatastoreCfgFromInventory_FallsBackToDatastoreList(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + stageRoot := "/stage" + inventory := `{"datastores":[{"name":"DS1","path":"/mnt/ds1","comment":"primary"},{"name":"DS2","path":"/mnt/ds2","comment":""}]}` + if err := fakeFS.WriteFile(stageRoot+"/var/lib/proxsave-info/commands/pbs/pbs_datastore_inventory.json", []byte(inventory), 0o640); err != nil { + t.Fatalf("write inventory: %v", err) + } + + content, src, err := loadPBSDatastoreCfgFromInventory(stageRoot) + if err != nil { + t.Fatalf("loadPBSDatastoreCfgFromInventory: %v", err) + } + if src != "pbs_datastore_inventory.json.datastores" { + t.Fatalf("src=%q", src) + } + + blocks, err := parsePBSDatastoreCfgBlocks(content) + if err != nil { + t.Fatalf("parsePBSDatastoreCfgBlocks: %v", err) + } + if len(blocks) != 2 { + t.Fatalf("expected 2 blocks, got %d", len(blocks)) + } + paths := map[string]string{} + for _, b := range blocks { + paths[b.Name] = b.Path + } + if paths["DS1"] != "/mnt/ds1" { + t.Fatalf("DS1 path=%q", paths["DS1"]) + } + if paths["DS2"] != "/mnt/ds2" { + t.Fatalf("DS2 path=%q", paths["DS2"]) + } + if !strings.Contains(content, "comment primary") { + t.Fatalf("expected DS1 comment in generated content") 
+ } +} + +func TestLoadPBSDatastoreCfgFromInventory_PropagatesErrors(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + stageRoot := "/stage" + inventoryPath := stageRoot + "/var/lib/proxsave-info/commands/pbs/pbs_datastore_inventory.json" + + if err := fakeFS.WriteFile(inventoryPath, []byte(" \n"), 0o640); err != nil { + t.Fatalf("write empty inventory: %v", err) + } + if _, _, err := loadPBSDatastoreCfgFromInventory(stageRoot); err == nil { + t.Fatalf("expected error for empty inventory") + } + + if err := fakeFS.WriteFile(inventoryPath, []byte("not-json"), 0o640); err != nil { + t.Fatalf("write invalid inventory: %v", err) + } + if _, _, err := loadPBSDatastoreCfgFromInventory(stageRoot); err == nil { + t.Fatalf("expected error for invalid JSON") + } + + if err := fakeFS.WriteFile(inventoryPath, []byte(`{"datastores":[{"name":"","path":"","comment":""}]}`), 0o640); err != nil { + t.Fatalf("write unusable inventory: %v", err) + } + if _, _, err := loadPBSDatastoreCfgFromInventory(stageRoot); err == nil { + t.Fatalf("expected error for unusable inventory") + } +} + +func TestDetectPBSDatastoreCfgDuplicateKeys_DetectsDuplicateKeys(t *testing.T) { + t.Parallel() + + blocks := []pbsDatastoreBlock{{ + Name: "DS1", + Lines: []string{ + "datastore: DS1", + "# comment", + "", + " path /mnt/a", + " path /mnt/b", + }, + }} + if reason := detectPBSDatastoreCfgDuplicateKeys(blocks); reason == "" { + t.Fatalf("expected duplicate key detection") + } +} + +func TestDetectPBSDatastoreCfgDuplicateKeys_AllowsUniqueKeys(t *testing.T) { + t.Parallel() + + blocks := []pbsDatastoreBlock{{ + Name: "DS1", + Lines: []string{ + "datastore: DS1", + " comment one", + " path /mnt/a", + }, + }} + if reason := detectPBSDatastoreCfgDuplicateKeys(blocks); reason != "" { + t.Fatalf("expected no duplicates, got %q", reason) + } +} + +func TestParsePBSDatastoreCfgBlocks_IgnoresGarbageAndHandlesMissingNames(t *testing.T) { + t.Parallel() + + content := strings.Join([]string{ + "path /should/be/ignored", + "datastore:", + " path /also/ignored", + "datastore: DS1", + "# keep comment", + "path /mnt/ds1", + "", + "datastore: DS2:", + " path /mnt/ds2", + "", + }, "\n") + + blocks, err := parsePBSDatastoreCfgBlocks(content) + if err != nil { + t.Fatalf("parsePBSDatastoreCfgBlocks: %v", err) + } + if len(blocks) != 2 { + t.Fatalf("expected 2 blocks, got %d", len(blocks)) + } + if blocks[0].Name != "DS1" || blocks[0].Path != "/mnt/ds1" { + t.Fatalf("block[0]=%+v", blocks[0]) + } + if blocks[1].Name != "DS2" || blocks[1].Path != "/mnt/ds2" { + t.Fatalf("block[1]=%+v", blocks[1]) + } + if gotLines := strings.Join(blocks[0].Lines, "\n"); !strings.Contains(gotLines, "# keep comment") { + t.Fatalf("expected DS1 block to retain comment line; got=%q", gotLines) + } +} + +func TestParsePBSDatastoreCfgBlocks_DropsEmptyNamedBlocks(t *testing.T) { + t.Parallel() + + content := strings.Join([]string{ + "datastore: :", + " path /mnt/ignored", + "", + "datastore: DS1", + " path /mnt/ds1", + "", + }, "\n") + + blocks, err := parsePBSDatastoreCfgBlocks(content) + if err != nil { + t.Fatalf("parsePBSDatastoreCfgBlocks: %v", err) + } + if len(blocks) != 1 { + t.Fatalf("expected 1 block, got %d", len(blocks)) + } + if blocks[0].Name != "DS1" { + t.Fatalf("block[0].Name=%q", blocks[0].Name) + } +} + +func TestShouldApplyPBSDatastoreBlock_CoversCommonBranches(t *testing.T) { + t.Parallel() + + if ok, reason := 
shouldApplyPBSDatastoreBlock(pbsDatastoreBlock{Name: "ds", Path: "/"}, newTestLogger()); ok || !strings.Contains(reason, "invalid") { + t.Fatalf("expected invalid path rejection, got ok=%v reason=%q", ok, reason) + } + + dsDir := t.TempDir() + if err := os.MkdirAll(filepath.Join(dsDir, ".chunks"), 0o755); err != nil { + t.Fatalf("mkdir .chunks: %v", err) + } + if err := os.WriteFile(filepath.Join(dsDir, ".chunks", "c1"), []byte("x"), 0o644); err != nil { + t.Fatalf("write chunk: %v", err) + } + if ok, reason := shouldApplyPBSDatastoreBlock(pbsDatastoreBlock{Name: "ds", Path: dsDir}, newTestLogger()); !ok { + t.Fatalf("expected hasData datastore to be applied, got ok=false reason=%q", reason) + } + + tooLong := "/" + strings.Repeat("a", 5000) + if ok, reason := shouldApplyPBSDatastoreBlock(pbsDatastoreBlock{Name: "ds", Path: tooLong}, newTestLogger()); ok || !strings.Contains(reason, "inspection failed") { + t.Fatalf("expected inspection failure, got ok=%v reason=%q", ok, reason) + } +} + +func TestWriteDeferredPBSDatastoreCfg_EmptyInputIsNoop(t *testing.T) { + t.Parallel() + + if path, err := writeDeferredPBSDatastoreCfg(nil); err != nil { + t.Fatalf("err=%v", err) + } else if path != "" { + t.Fatalf("expected empty path, got %q", path) + } +} + +func TestWriteDeferredPBSDatastoreCfg_WritesFile(t *testing.T) { + origFS := restoreFS + origTime := restoreTime + t.Cleanup(func() { + restoreFS = origFS + restoreTime = origTime + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + restoreTime = &FakeTime{Current: time.Date(2026, 1, 2, 3, 4, 5, 0, time.UTC)} + + blocks := []pbsDatastoreBlock{{ + Name: "DS1", + Path: "/mnt/ds1", + Lines: []string{"datastore: DS1", " path /mnt/ds1"}, + }} + + path, err := writeDeferredPBSDatastoreCfg(blocks) + if err != nil { + t.Fatalf("writeDeferredPBSDatastoreCfg: %v", err) + } + if path == "" { + t.Fatalf("expected non-empty path") + } + + raw, err := fakeFS.ReadFile(path) + if err != nil { + t.Fatalf("read deferred file: %v", err) + } + if !strings.Contains(string(raw), "datastore: DS1") { + t.Fatalf("unexpected deferred content: %q", string(raw)) + } +} + +func TestWriteDeferredPBSDatastoreCfg_PropagatesMkdirError(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + fakeFS.MkdirAllErr = errors.New("forced mkdir error") + + _, err := writeDeferredPBSDatastoreCfg([]pbsDatastoreBlock{{Name: "DS1", Path: "/mnt/ds1", Lines: []string{"datastore: DS1", " path /mnt/ds1"}}}) + if err == nil { + t.Fatalf("expected error, got nil") + } +} + +func TestWriteDeferredPBSDatastoreCfg_PropagatesWriteError(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + fakeFS.WriteErr = errors.New("forced write error") + + _, err := writeDeferredPBSDatastoreCfg([]pbsDatastoreBlock{{Name: "DS1", Path: "/mnt/ds1", Lines: []string{"datastore: DS1", " path /mnt/ds1"}}}) + if err == nil { + t.Fatalf("expected error, got nil") + } +} + +func TestWriteDeferredPBSDatastoreCfg_MultipleBlocksAddsSeparator(t *testing.T) { + origFS := restoreFS + origTime := restoreTime + t.Cleanup(func() { + restoreFS = origFS + restoreTime = origTime + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + restoreTime = &FakeTime{Current: 
time.Date(2026, 1, 2, 3, 4, 6, 0, time.UTC)} + + blocks := []pbsDatastoreBlock{ + {Name: "DS1", Lines: []string{"datastore: DS1", " path /mnt/ds1"}}, + {Name: "DS2", Lines: []string{"datastore: DS2", " path /mnt/ds2"}}, + } + + path, err := writeDeferredPBSDatastoreCfg(blocks) + if err != nil { + t.Fatalf("writeDeferredPBSDatastoreCfg: %v", err) + } + + raw, err := fakeFS.ReadFile(path) + if err != nil { + t.Fatalf("read deferred file: %v", err) + } + if !strings.Contains(string(raw), "datastore: DS1\n path /mnt/ds1\n\ndatastore: DS2\n path /mnt/ds2") { + t.Fatalf("expected blank line separator between blocks; got=%q", string(raw)) + } +} + +func TestApplyPBSDatastoreCfgFromStage_SkipsMissingStagedFile(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + if err := applyPBSDatastoreCfgFromStage(context.Background(), newTestLogger(), "/stage"); err != nil { + t.Fatalf("applyPBSDatastoreCfgFromStage: %v", err) + } + if _, err := fakeFS.Stat("/etc/proxmox-backup/datastore.cfg"); err == nil { + t.Fatalf("expected datastore.cfg not to be created") + } +} + +func TestApplyPBSDatastoreCfgFromStage_RemovesTargetWhenStagedEmpty(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + if err := fakeFS.WriteFile("/etc/proxmox-backup/datastore.cfg", []byte("datastore: old\n path /mnt/old\n"), 0o640); err != nil { + t.Fatalf("write existing datastore.cfg: %v", err) + } + if err := fakeFS.WriteFile("/stage/etc/proxmox-backup/datastore.cfg", []byte(" \n"), 0o640); err != nil { + t.Fatalf("write staged datastore.cfg: %v", err) + } + + if err := applyPBSDatastoreCfgFromStage(context.Background(), newTestLogger(), "/stage"); err != nil { + t.Fatalf("applyPBSDatastoreCfgFromStage: %v", err) + } + if _, err := fakeFS.Stat("/etc/proxmox-backup/datastore.cfg"); err == nil { + t.Fatalf("expected datastore.cfg removed") + } +} + +func TestApplyPBSDatastoreCfgFromStage_DefersUnsafeAndAppliesSafe(t *testing.T) { + origFS := restoreFS + origTime := restoreTime + t.Cleanup(func() { + restoreFS = origFS + restoreTime = origTime + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + restoreTime = &FakeTime{Current: time.Date(2026, 2, 3, 4, 5, 6, 0, time.UTC)} + + safeDir := t.TempDir() + unsafeDir := t.TempDir() + if err := os.WriteFile(filepath.Join(unsafeDir, "unexpected"), []byte("x"), 0o644); err != nil { + t.Fatalf("write unexpected file: %v", err) + } + + stageRoot := "/stage" + staged := strings.Join([]string{ + "datastore: Safe", + fmt.Sprintf("path %s", safeDir), + "", + "datastore: Unsafe", + fmt.Sprintf("path %s", unsafeDir), + "", + }, "\n") + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/datastore.cfg", []byte(staged), 0o640); err != nil { + t.Fatalf("write staged datastore.cfg: %v", err) + } + + if err := applyPBSDatastoreCfgFromStage(context.Background(), newTestLogger(), stageRoot); err != nil { + t.Fatalf("applyPBSDatastoreCfgFromStage: %v", err) + } + + out, err := fakeFS.ReadFile("/etc/proxmox-backup/datastore.cfg") + if err != nil { + t.Fatalf("read applied datastore.cfg: %v", err) + } + if !strings.Contains(string(out), "datastore: Safe") { + t.Fatalf("expected Safe datastore in output: %q", string(out)) + } + if strings.Contains(string(out), "datastore: Unsafe") { + 
t.Fatalf("did not expect Unsafe datastore in output: %q", string(out)) + } + + entries, err := fakeFS.ReadDir("/tmp/proxsave") + if err != nil { + t.Fatalf("readdir deferred dir: %v", err) + } + if len(entries) != 1 { + t.Fatalf("expected 1 deferred file, got %d", len(entries)) + } + deferredPath := filepath.Join("/tmp/proxsave", entries[0].Name()) + deferred, err := fakeFS.ReadFile(deferredPath) + if err != nil { + t.Fatalf("read deferred file: %v", err) + } + if !strings.Contains(string(deferred), "datastore: Unsafe") { + t.Fatalf("expected Unsafe datastore deferred: %q", string(deferred)) + } +} + +func TestApplyPBSDatastoreCfgFromStage_AllDeferredLeavesTargetUnchanged(t *testing.T) { + origFS := restoreFS + origTime := restoreTime + t.Cleanup(func() { + restoreFS = origFS + restoreTime = origTime + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + restoreTime = &FakeTime{Current: time.Date(2026, 2, 3, 4, 5, 7, 0, time.UTC)} + + existing := "datastore: Existing\n path /mnt/existing\n" + if err := fakeFS.WriteFile("/etc/proxmox-backup/datastore.cfg", []byte(existing), 0o640); err != nil { + t.Fatalf("write existing datastore.cfg: %v", err) + } + + unsafeDir := t.TempDir() + if err := os.WriteFile(filepath.Join(unsafeDir, "unexpected"), []byte("x"), 0o644); err != nil { + t.Fatalf("write unexpected file: %v", err) + } + + stageRoot := "/stage" + staged := strings.Join([]string{ + "datastore: UnsafeOnly", + fmt.Sprintf("path %s", unsafeDir), + "", + }, "\n") + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/datastore.cfg", []byte(staged), 0o640); err != nil { + t.Fatalf("write staged datastore.cfg: %v", err) + } + + if err := applyPBSDatastoreCfgFromStage(context.Background(), newTestLogger(), stageRoot); err != nil { + t.Fatalf("applyPBSDatastoreCfgFromStage: %v", err) + } + + got, err := fakeFS.ReadFile("/etc/proxmox-backup/datastore.cfg") + if err != nil { + t.Fatalf("read datastore.cfg: %v", err) + } + if string(got) != existing { + t.Fatalf("datastore.cfg changed unexpectedly: got=%q want=%q", string(got), existing) + } +} + +func TestApplyPBSDatastoreCfgFromStage_DuplicateKeysWithoutInventoryLeavesTargetUnchanged(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + existing := "datastore: Existing\n path /mnt/existing\n" + if err := fakeFS.WriteFile("/etc/proxmox-backup/datastore.cfg", []byte(existing), 0o640); err != nil { + t.Fatalf("write existing datastore.cfg: %v", err) + } + + stageRoot := "/stage" + staged := strings.Join([]string{ + "datastore: Broken", + " path /mnt/a", + " path /mnt/b", + "", + }, "\n") + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/datastore.cfg", []byte(staged), 0o640); err != nil { + t.Fatalf("write staged datastore.cfg: %v", err) + } + + if err := applyPBSDatastoreCfgFromStage(context.Background(), newTestLogger(), stageRoot); err != nil { + t.Fatalf("applyPBSDatastoreCfgFromStage: %v", err) + } + + got, err := fakeFS.ReadFile("/etc/proxmox-backup/datastore.cfg") + if err != nil { + t.Fatalf("read datastore.cfg: %v", err) + } + if string(got) != existing { + t.Fatalf("datastore.cfg changed unexpectedly: got=%q want=%q", string(got), existing) + } +} + +func TestApplyPBSDatastoreCfgFromStage_SkipsWhenNoBlocksDetected(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = 
os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + stageRoot := "/stage" + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/datastore.cfg", []byte("# only comments\n\n# nothing else\n"), 0o640); err != nil { + t.Fatalf("write staged datastore.cfg: %v", err) + } + + if err := applyPBSDatastoreCfgFromStage(context.Background(), newTestLogger(), stageRoot); err != nil { + t.Fatalf("applyPBSDatastoreCfgFromStage: %v", err) + } + if _, err := fakeFS.Stat("/etc/proxmox-backup/datastore.cfg"); err == nil { + t.Fatalf("expected datastore.cfg not to be created") + } +} + +func TestApplyPBSDatastoreCfgFromStage_PropagatesReadError(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + stageRoot := "/stage" + if err := fakeFS.MkdirAll(stageRoot+"/etc/proxmox-backup/datastore.cfg", 0o755); err != nil { + t.Fatalf("mkdir staged datastore.cfg dir: %v", err) + } + + if err := applyPBSDatastoreCfgFromStage(context.Background(), newTestLogger(), stageRoot); err == nil { + t.Fatalf("expected error, got nil") + } +} + +func TestApplyPBSDatastoreCfgFromStage_ContinuesWhenDeferredWriteFails(t *testing.T) { + origFS := restoreFS + origTime := restoreTime + t.Cleanup(func() { + restoreFS = origFS + restoreTime = origTime + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + restoreTime = &FakeTime{Current: time.Unix(10, 0)} + + safeDir := t.TempDir() + unsafeDir := t.TempDir() + if err := os.WriteFile(filepath.Join(unsafeDir, "unexpected"), []byte("x"), 0o644); err != nil { + t.Fatalf("write unexpected file: %v", err) + } + + stageRoot := "/stage" + staged := strings.Join([]string{ + "datastore: Safe", + fmt.Sprintf("path %s", safeDir), + "", + "datastore: Unsafe", + fmt.Sprintf("path %s", unsafeDir), + "", + }, "\n") + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/datastore.cfg", []byte(staged), 0o640); err != nil { + t.Fatalf("write staged datastore.cfg: %v", err) + } + + // Fail deferred file writes but still allow atomic apply. 
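+	// Two distinct write paths are exercised here: deferred files go through
+	// FS.WriteFile (broken below via WriteErr), while the final datastore.cfg
+	// is written atomically via FS.OpenFile on a temp file, so the apply still
+	// succeeds. (The OpenFile failure mode is covered by the next test.)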
+ fakeFS.WriteErr = errors.New("forced write error") + + if err := applyPBSDatastoreCfgFromStage(context.Background(), newTestLogger(), stageRoot); err != nil { + t.Fatalf("applyPBSDatastoreCfgFromStage: %v", err) + } + + out, err := fakeFS.ReadFile("/etc/proxmox-backup/datastore.cfg") + if err != nil { + t.Fatalf("read applied datastore.cfg: %v", err) + } + if !strings.Contains(string(out), "datastore: Safe") || strings.Contains(string(out), "datastore: Unsafe") { + t.Fatalf("unexpected apply result: %q", string(out)) + } + + if entries, err := fakeFS.ReadDir("/tmp/proxsave"); err != nil { + t.Fatalf("readdir /tmp/proxsave: %v", err) + } else if len(entries) != 0 { + t.Fatalf("expected no deferred files due to forced write error, got %d", len(entries)) + } +} + +func TestApplyPBSDatastoreCfgFromStage_ReturnsErrorOnAtomicWriteFailure(t *testing.T) { + origFS := restoreFS + origTime := restoreTime + t.Cleanup(func() { + restoreFS = origFS + restoreTime = origTime + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + restoreTime = &FakeTime{Current: time.Unix(10, 0)} + + safeDir := t.TempDir() + stageRoot := "/stage" + staged := strings.Join([]string{ + "datastore: DS1", + fmt.Sprintf("path %s", safeDir), + "", + }, "\n") + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/datastore.cfg", []byte(staged), 0o640); err != nil { + t.Fatalf("write staged datastore.cfg: %v", err) + } + + dest := "/etc/proxmox-backup/datastore.cfg" + tmp := fmt.Sprintf("%s.proxsave.tmp.%d", dest, nowRestore().UnixNano()) + fakeFS.OpenFileErr[filepath.Clean(tmp)] = errors.New("forced OpenFile error") + + if err := applyPBSDatastoreCfgFromStage(context.Background(), newTestLogger(), stageRoot); err == nil { + t.Fatalf("expected error, got nil") + } +} + +func TestApplyPBSJobConfigsFromStage_WritesAllJobConfigs(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + stageRoot := "/stage" + for name, content := range map[string]string{ + "sync.cfg": "sync: job1\n remote r1\n", + "verification.cfg": "verification: v1\n datastore ds1\n", + "prune.cfg": "prune: p1\n keep-last 1\n", + } { + if err := fakeFS.WriteFile(filepath.Join(stageRoot, "etc/proxmox-backup", name), []byte(content), 0o640); err != nil { + t.Fatalf("write staged %s: %v", name, err) + } + } + + if err := applyPBSJobConfigsFromStage(context.Background(), newTestLogger(), stageRoot); err != nil { + t.Fatalf("applyPBSJobConfigsFromStage: %v", err) + } + + for _, name := range []string{"sync.cfg", "verification.cfg", "prune.cfg"} { + dest := filepath.Join("/etc/proxmox-backup", name) + if _, err := fakeFS.Stat(dest); err != nil { + t.Fatalf("expected %s to exist: %v", dest, err) + } + } +} + +func TestApplyPBSJobConfigsFromStage_ContinuesOnApplyErrors(t *testing.T) { + origFS := restoreFS + origTime := restoreTime + t.Cleanup(func() { + restoreFS = origFS + restoreTime = origTime + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + restoreTime = &FakeTime{Current: time.Unix(10, 0)} + + stageRoot := "/stage" + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/sync.cfg", []byte("sync: job1\n remote r1\n"), 0o640); err != nil { + t.Fatalf("write staged sync.cfg: %v", err) + } + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/verification.cfg", []byte("verification: v1\n datastore ds1\n"), 0o640); err 
!= nil { + t.Fatalf("write staged verification.cfg: %v", err) + } + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/prune.cfg", []byte("prune: p1\n keep-last 1\n"), 0o640); err != nil { + t.Fatalf("write staged prune.cfg: %v", err) + } + + relFail := "etc/proxmox-backup/verification.cfg" + destFail := filepath.Join(string(os.PathSeparator), filepath.FromSlash(relFail)) + tmpFail := fmt.Sprintf("%s.proxsave.tmp.%d", destFail, nowRestore().UnixNano()) + fakeFS.OpenFileErr[filepath.Clean(tmpFail)] = errors.New("forced OpenFile error") + + if err := applyPBSJobConfigsFromStage(context.Background(), newTestLogger(), stageRoot); err != nil { + t.Fatalf("applyPBSJobConfigsFromStage: %v", err) + } + + if _, err := fakeFS.Stat("/etc/proxmox-backup/sync.cfg"); err != nil { + t.Fatalf("expected sync.cfg to exist: %v", err) + } + if _, err := fakeFS.Stat("/etc/proxmox-backup/prune.cfg"); err != nil { + t.Fatalf("expected prune.cfg to exist: %v", err) + } + if _, err := fakeFS.Stat("/etc/proxmox-backup/verification.cfg"); err == nil { + t.Fatalf("expected verification.cfg not to be created due to forced error") + } +} + +func TestApplyPBSTapeConfigsFromStage_WritesConfigsAndSensitiveKeys(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + stageRoot := "/stage" + for name, content := range map[string]string{ + "tape.cfg": "drive: d1\n path /dev/nst0\n", + "tape-job.cfg": "tape-job: job1\n drive d1\n", + "media-pool.cfg": "media-pool: pool1\n retention 30\n", + } { + if err := fakeFS.WriteFile(filepath.Join(stageRoot, "etc/proxmox-backup", name), []byte(content), 0o640); err != nil { + t.Fatalf("write staged %s: %v", name, err) + } + } + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/tape-encryption-keys.json", []byte(`{"keys":[{"fingerprint":"abc"}]}`), 0o640); err != nil { + t.Fatalf("write staged tape-encryption-keys.json: %v", err) + } + + if err := applyPBSTapeConfigsFromStage(context.Background(), newTestLogger(), stageRoot); err != nil { + t.Fatalf("applyPBSTapeConfigsFromStage: %v", err) + } + + for _, name := range []string{"tape.cfg", "tape-job.cfg", "media-pool.cfg"} { + dest := filepath.Join("/etc/proxmox-backup", name) + if _, err := fakeFS.Stat(dest); err != nil { + t.Fatalf("expected %s to exist: %v", dest, err) + } + } + + if info, err := fakeFS.Stat("/etc/proxmox-backup/tape-encryption-keys.json"); err != nil { + t.Fatalf("stat tape-encryption-keys.json: %v", err) + } else if info.Mode().Perm() != 0o600 { + t.Fatalf("tape-encryption-keys.json mode=%#o want %#o", info.Mode().Perm(), 0o600) + } +} + +func TestApplyPBSTapeConfigsFromStage_ContinuesOnErrors(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + stageRoot := "/stage" + // Force applyPBSConfigFileFromStage and applySensitiveFileFromStage errors (ReadFile on directory). 
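+	// Turning each staged path into a directory makes ReadFile fail for every
+	// tape config and for the sensitive keys file; the apply helper should log
+	// each failure and continue, which is why the call below must return nil.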
+ if err := fakeFS.MkdirAll(stageRoot+"/etc/proxmox-backup/tape.cfg", 0o755); err != nil { + t.Fatalf("mkdir staged tape.cfg dir: %v", err) + } + if err := fakeFS.MkdirAll(stageRoot+"/etc/proxmox-backup/tape-encryption-keys.json", 0o755); err != nil { + t.Fatalf("mkdir staged tape-encryption-keys.json dir: %v", err) + } + + if err := applyPBSTapeConfigsFromStage(context.Background(), newTestLogger(), stageRoot); err != nil { + t.Fatalf("applyPBSTapeConfigsFromStage: %v", err) + } +} + +func TestRemoveIfExists_IgnoresMissingAndRemovesExisting(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + if err := removeIfExists("/etc/proxmox-backup/missing.cfg"); err != nil { + t.Fatalf("removeIfExists missing: %v", err) + } + + if err := fakeFS.WriteFile("/etc/proxmox-backup/existing.cfg", []byte("x"), 0o640); err != nil { + t.Fatalf("write existing.cfg: %v", err) + } + if err := removeIfExists("/etc/proxmox-backup/existing.cfg"); err != nil { + t.Fatalf("removeIfExists existing: %v", err) + } + if _, err := fakeFS.Stat("/etc/proxmox-backup/existing.cfg"); err == nil { + t.Fatalf("expected existing.cfg removed") + } +} + +func TestRemoveIfExists_PropagatesNonExistErrors(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + // Make Remove fail with a non-ENOENT error. + if err := fakeFS.MkdirAll("/etc/proxmox-backup/dir", 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + if err := fakeFS.WriteFile("/etc/proxmox-backup/dir/file", []byte("x"), 0o640); err != nil { + t.Fatalf("write: %v", err) + } + + if err := removeIfExists("/etc/proxmox-backup/dir"); err == nil { + t.Fatalf("expected error, got nil") + } +} diff --git a/internal/orchestrator/pbs_staged_apply_maybeapply_test.go b/internal/orchestrator/pbs_staged_apply_maybeapply_test.go new file mode 100644 index 0000000..3c4b5d1 --- /dev/null +++ b/internal/orchestrator/pbs_staged_apply_maybeapply_test.go @@ -0,0 +1,407 @@ +package orchestrator + +import ( + "context" + "errors" + "os" + "testing" + + "github.com/tis24dev/proxsave/internal/logging" +) + +func TestMaybeApplyPBSConfigsFromStage_SkipsWhenNonRoot(t *testing.T) { + origFS := restoreFS + origIsReal := pbsStagedApplyIsRealRestoreFSFn + origGeteuid := pbsStagedApplyGeteuidFn + t.Cleanup(func() { + restoreFS = origFS + pbsStagedApplyIsRealRestoreFSFn = origIsReal + pbsStagedApplyGeteuidFn = origGeteuid + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + pbsStagedApplyIsRealRestoreFSFn = func(FS) bool { return true } + pbsStagedApplyGeteuidFn = func() int { return 1000 } + + stageRoot := "/stage" + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/acme/accounts.cfg", []byte("account: a1\n foo bar\n"), 0o640); err != nil { + t.Fatalf("write staged accounts.cfg: %v", err) + } + + plan := &RestorePlan{ + SystemType: SystemTypePBS, + PBSRestoreBehavior: PBSRestoreBehaviorClean, + NormalCategories: []Category{{ID: "pbs_host"}}, + } + if err := maybeApplyPBSConfigsFromStage(context.Background(), newTestLogger(), plan, stageRoot, false); err != nil { + t.Fatalf("maybeApplyPBSConfigsFromStage: %v", err) + } + + if _, err := fakeFS.Stat("/etc/proxmox-backup/acme/accounts.cfg"); err == nil { + t.Fatalf("expected no writes when non-root") + } +} + +func 
TestMaybeApplyPBSConfigsFromStage_CleanMode_ApiUnavailableFallsBackToFiles(t *testing.T) { + origFS := restoreFS + origIsReal := pbsStagedApplyIsRealRestoreFSFn + origGeteuid := pbsStagedApplyGeteuidFn + origEnsure := pbsStagedApplyEnsurePBSServicesForAPIFn + t.Cleanup(func() { + restoreFS = origFS + pbsStagedApplyIsRealRestoreFSFn = origIsReal + pbsStagedApplyGeteuidFn = origGeteuid + pbsStagedApplyEnsurePBSServicesForAPIFn = origEnsure + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + pbsStagedApplyIsRealRestoreFSFn = func(FS) bool { return true } + pbsStagedApplyGeteuidFn = func() int { return 0 } + pbsStagedApplyEnsurePBSServicesForAPIFn = func(context.Context, *logging.Logger) error { + return errors.New("forced API unavailable") + } + + stageRoot := "/stage" + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/acme/accounts.cfg", []byte("account: a1\n foo bar\n"), 0o640); err != nil { + t.Fatalf("write staged accounts.cfg: %v", err) + } + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/traffic-control.cfg", []byte("traffic-control: tc1\n rate 10mbit\n"), 0o640); err != nil { + t.Fatalf("write staged traffic-control.cfg: %v", err) + } + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/node.cfg", []byte("node: n1\n description test\n"), 0o640); err != nil { + t.Fatalf("write staged node.cfg: %v", err) + } + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/s3.cfg", []byte("s3: r1\n bucket test\n"), 0o640); err != nil { + t.Fatalf("write staged s3.cfg: %v", err) + } + + safeDir := t.TempDir() + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/datastore.cfg", []byte("datastore: DS1\npath "+safeDir+"\n"), 0o640); err != nil { + t.Fatalf("write staged datastore.cfg: %v", err) + } + + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/remote.cfg", []byte("remote: r1\n host 10.0.0.10\n"), 0o640); err != nil { + t.Fatalf("write staged remote.cfg: %v", err) + } + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/sync.cfg", []byte("sync: job1\n remote r1\n"), 0o640); err != nil { + t.Fatalf("write staged sync.cfg: %v", err) + } + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/verification.cfg", []byte("verification: v1\n datastore DS1\n"), 0o640); err != nil { + t.Fatalf("write staged verification.cfg: %v", err) + } + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/prune.cfg", []byte("prune: p1\n keep-last 1\n"), 0o640); err != nil { + t.Fatalf("write staged prune.cfg: %v", err) + } + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/tape.cfg", []byte("drive: d1\n path /dev/nst0\n"), 0o640); err != nil { + t.Fatalf("write staged tape.cfg: %v", err) + } + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/tape-job.cfg", []byte("tape-job: job1\n drive d1\n"), 0o640); err != nil { + t.Fatalf("write staged tape-job.cfg: %v", err) + } + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/media-pool.cfg", []byte("media-pool: pool1\n retention 30\n"), 0o640); err != nil { + t.Fatalf("write staged media-pool.cfg: %v", err) + } + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/tape-encryption-keys.json", []byte(`{"keys":[{"fingerprint":"abc"}]}`), 0o640); err != nil { + t.Fatalf("write staged tape-encryption-keys.json: %v", err) + } + + plan := &RestorePlan{ + SystemType: SystemTypePBS, + PBSRestoreBehavior: PBSRestoreBehaviorClean, + NormalCategories: []Category{ + {ID: "pbs_host"}, + {ID: "datastore_pbs"}, + {ID: "pbs_remotes"}, + 
{ID: "pbs_jobs"}, + {ID: "pbs_tape"}, + }, + } + + if err := maybeApplyPBSConfigsFromStage(context.Background(), newTestLogger(), plan, stageRoot, false); err != nil { + t.Fatalf("maybeApplyPBSConfigsFromStage: %v", err) + } + + for _, path := range []string{ + "/etc/proxmox-backup/acme/accounts.cfg", + "/etc/proxmox-backup/traffic-control.cfg", + "/etc/proxmox-backup/node.cfg", + "/etc/proxmox-backup/s3.cfg", + "/etc/proxmox-backup/datastore.cfg", + "/etc/proxmox-backup/remote.cfg", + "/etc/proxmox-backup/sync.cfg", + "/etc/proxmox-backup/verification.cfg", + "/etc/proxmox-backup/prune.cfg", + "/etc/proxmox-backup/tape-encryption-keys.json", + } { + if _, err := fakeFS.Stat(path); err != nil { + t.Fatalf("expected %s to exist: %v", path, err) + } + } +} + +func TestMaybeApplyPBSConfigsFromStage_MergeMode_ApiUnavailableSkipsApiCategories(t *testing.T) { + origFS := restoreFS + origIsReal := pbsStagedApplyIsRealRestoreFSFn + origGeteuid := pbsStagedApplyGeteuidFn + origEnsure := pbsStagedApplyEnsurePBSServicesForAPIFn + t.Cleanup(func() { + restoreFS = origFS + pbsStagedApplyIsRealRestoreFSFn = origIsReal + pbsStagedApplyGeteuidFn = origGeteuid + pbsStagedApplyEnsurePBSServicesForAPIFn = origEnsure + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + pbsStagedApplyIsRealRestoreFSFn = func(FS) bool { return true } + pbsStagedApplyGeteuidFn = func() int { return 0 } + pbsStagedApplyEnsurePBSServicesForAPIFn = func(context.Context, *logging.Logger) error { + return errors.New("forced API unavailable") + } + + stageRoot := "/stage" + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/acme/accounts.cfg", []byte("account: a1\n foo bar\n"), 0o640); err != nil { + t.Fatalf("write staged accounts.cfg: %v", err) + } + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/node.cfg", []byte("node: n1\n description test\n"), 0o640); err != nil { + t.Fatalf("write staged node.cfg: %v", err) + } + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/s3.cfg", []byte("s3: r1\n bucket test\n"), 0o640); err != nil { + t.Fatalf("write staged s3.cfg: %v", err) + } + + plan := &RestorePlan{ + SystemType: SystemTypePBS, + PBSRestoreBehavior: PBSRestoreBehaviorMerge, + NormalCategories: []Category{ + {ID: "pbs_host"}, + {ID: "datastore_pbs"}, + {ID: "pbs_remotes"}, + {ID: "pbs_jobs"}, + }, + } + + if err := maybeApplyPBSConfigsFromStage(context.Background(), newTestLogger(), plan, stageRoot, false); err != nil { + t.Fatalf("maybeApplyPBSConfigsFromStage: %v", err) + } + + if _, err := fakeFS.Stat("/etc/proxmox-backup/acme/accounts.cfg"); err != nil { + t.Fatalf("expected accounts.cfg to exist: %v", err) + } + + // Merge mode requires API for these categories, so they must not be file-applied. 
+ for _, path := range []string{ + "/etc/proxmox-backup/node.cfg", + "/etc/proxmox-backup/s3.cfg", + "/etc/proxmox-backup/datastore.cfg", + "/etc/proxmox-backup/remote.cfg", + "/etc/proxmox-backup/sync.cfg", + } { + if _, err := fakeFS.Stat(path); err == nil { + t.Fatalf("did not expect %s to be applied in merge mode without API", path) + } + } +} + +func TestMaybeApplyPBSConfigsFromStage_ApiErrorsTriggerFallbackOnlyInCleanMode(t *testing.T) { + origFS := restoreFS + origIsReal := pbsStagedApplyIsRealRestoreFSFn + origGeteuid := pbsStagedApplyGeteuidFn + origEnsure := pbsStagedApplyEnsurePBSServicesForAPIFn + origTraffic := pbsStagedApplyTrafficControlCfgViaAPIFn + origNode := pbsStagedApplyNodeCfgViaAPIFn + origS3 := pbsStagedApplyS3CfgViaAPIFn + origDS := pbsStagedApplyDatastoreCfgViaAPIFn + origRemote := pbsStagedApplyRemoteCfgViaAPIFn + origSync := pbsStagedApplySyncCfgViaAPIFn + origVerify := pbsStagedApplyVerificationCfgViaAPIFn + origPrune := pbsStagedApplyPruneCfgViaAPIFn + t.Cleanup(func() { + restoreFS = origFS + pbsStagedApplyIsRealRestoreFSFn = origIsReal + pbsStagedApplyGeteuidFn = origGeteuid + pbsStagedApplyEnsurePBSServicesForAPIFn = origEnsure + pbsStagedApplyTrafficControlCfgViaAPIFn = origTraffic + pbsStagedApplyNodeCfgViaAPIFn = origNode + pbsStagedApplyS3CfgViaAPIFn = origS3 + pbsStagedApplyDatastoreCfgViaAPIFn = origDS + pbsStagedApplyRemoteCfgViaAPIFn = origRemote + pbsStagedApplySyncCfgViaAPIFn = origSync + pbsStagedApplyVerificationCfgViaAPIFn = origVerify + pbsStagedApplyPruneCfgViaAPIFn = origPrune + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + pbsStagedApplyIsRealRestoreFSFn = func(FS) bool { return true } + pbsStagedApplyGeteuidFn = func() int { return 0 } + pbsStagedApplyEnsurePBSServicesForAPIFn = func(context.Context, *logging.Logger) error { return nil } + + var strictArgsClean []bool + var strictArgsMerge []bool + strictSink := func(strict bool) { + strictArgsClean = append(strictArgsClean, strict) + } + pbsStagedApplyTrafficControlCfgViaAPIFn = func(_ context.Context, _ *logging.Logger, _ string, strict bool) error { + strictSink(strict) + return errors.New("forced API error") + } + pbsStagedApplyNodeCfgViaAPIFn = func(context.Context, string) error { return errors.New("forced API error") } + pbsStagedApplyS3CfgViaAPIFn = func(_ context.Context, _ *logging.Logger, _ string, strict bool) error { + strictSink(strict) + return errors.New("forced API error") + } + pbsStagedApplyDatastoreCfgViaAPIFn = func(_ context.Context, _ *logging.Logger, _ string, strict bool) error { + strictSink(strict) + return errors.New("forced API error") + } + pbsStagedApplyRemoteCfgViaAPIFn = func(_ context.Context, _ *logging.Logger, _ string, strict bool) error { + strictSink(strict) + return errors.New("forced API error") + } + pbsStagedApplySyncCfgViaAPIFn = func(_ context.Context, _ *logging.Logger, _ string, strict bool) error { + strictSink(strict) + return errors.New("forced API error") + } + pbsStagedApplyVerificationCfgViaAPIFn = func(_ context.Context, _ *logging.Logger, _ string, strict bool) error { + strictSink(strict) + return errors.New("forced API error") + } + pbsStagedApplyPruneCfgViaAPIFn = func(_ context.Context, _ *logging.Logger, _ string, strict bool) error { + strictSink(strict) + return errors.New("forced API error") + } + + stageRoot := "/stage" + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/acme/accounts.cfg", []byte("account: a1\n foo bar\n"), 0o640); err != nil { + 
t.Fatalf("write staged accounts.cfg: %v", err) + } + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/traffic-control.cfg", []byte("traffic-control: tc1\n rate 10mbit\n"), 0o640); err != nil { + t.Fatalf("write staged traffic-control.cfg: %v", err) + } + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/node.cfg", []byte("node: n1\n description test\n"), 0o640); err != nil { + t.Fatalf("write staged node.cfg: %v", err) + } + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/s3.cfg", []byte("s3: r1\n bucket test\n"), 0o640); err != nil { + t.Fatalf("write staged s3.cfg: %v", err) + } + + safeDir := t.TempDir() + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/datastore.cfg", []byte("datastore: DS1\npath "+safeDir+"\n"), 0o640); err != nil { + t.Fatalf("write staged datastore.cfg: %v", err) + } + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/remote.cfg", []byte("remote: r1\n host 10.0.0.10\n"), 0o640); err != nil { + t.Fatalf("write staged remote.cfg: %v", err) + } + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/sync.cfg", []byte("sync: job1\n remote r1\n"), 0o640); err != nil { + t.Fatalf("write staged sync.cfg: %v", err) + } + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/verification.cfg", []byte("verification: v1\n datastore DS1\n"), 0o640); err != nil { + t.Fatalf("write staged verification.cfg: %v", err) + } + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/prune.cfg", []byte("prune: p1\n keep-last 1\n"), 0o640); err != nil { + t.Fatalf("write staged prune.cfg: %v", err) + } + + planClean := &RestorePlan{ + SystemType: SystemTypePBS, + PBSRestoreBehavior: PBSRestoreBehaviorClean, + NormalCategories: []Category{ + {ID: "pbs_host"}, + {ID: "datastore_pbs"}, + {ID: "pbs_remotes"}, + {ID: "pbs_jobs"}, + }, + } + if err := maybeApplyPBSConfigsFromStage(context.Background(), newTestLogger(), planClean, stageRoot, false); err != nil { + t.Fatalf("maybeApplyPBSConfigsFromStage clean: %v", err) + } + + for _, path := range []string{ + "/etc/proxmox-backup/traffic-control.cfg", + "/etc/proxmox-backup/node.cfg", + "/etc/proxmox-backup/s3.cfg", + "/etc/proxmox-backup/datastore.cfg", + "/etc/proxmox-backup/remote.cfg", + "/etc/proxmox-backup/sync.cfg", + } { + if _, err := fakeFS.Stat(path); err != nil { + t.Fatalf("expected %s to exist in clean fallback mode: %v", path, err) + } + } + + if len(strictArgsClean) == 0 { + t.Fatalf("expected strict API calls in clean mode") + } + for _, strict := range strictArgsClean { + if !strict { + t.Fatalf("expected strict=true in clean mode, got false") + } + } + + // In merge mode, the same API failures must not trigger file-based fallbacks. 
+ strictSink = func(strict bool) { + strictArgsMerge = append(strictArgsMerge, strict) + } + fakeFS2 := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS2.Root) }) + restoreFS = fakeFS2 + if err := fakeFS2.WriteFile(stageRoot+"/etc/proxmox-backup/acme/accounts.cfg", []byte("account: a1\n foo bar\n"), 0o640); err != nil { + t.Fatalf("write staged accounts.cfg (merge): %v", err) + } + if err := fakeFS2.WriteFile(stageRoot+"/etc/proxmox-backup/node.cfg", []byte("node: n1\n description test\n"), 0o640); err != nil { + t.Fatalf("write staged node.cfg (merge): %v", err) + } + if err := fakeFS2.WriteFile(stageRoot+"/etc/proxmox-backup/s3.cfg", []byte("s3: r1\n bucket test\n"), 0o640); err != nil { + t.Fatalf("write staged s3.cfg (merge): %v", err) + } + if err := fakeFS2.WriteFile(stageRoot+"/etc/proxmox-backup/remote.cfg", []byte("remote: r1\n host 10.0.0.10\n"), 0o640); err != nil { + t.Fatalf("write staged remote.cfg (merge): %v", err) + } + if err := fakeFS2.WriteFile(stageRoot+"/etc/proxmox-backup/sync.cfg", []byte("sync: job1\n remote r1\n"), 0o640); err != nil { + t.Fatalf("write staged sync.cfg (merge): %v", err) + } + + planMerge := &RestorePlan{ + SystemType: SystemTypePBS, + PBSRestoreBehavior: PBSRestoreBehaviorMerge, + NormalCategories: []Category{ + {ID: "pbs_host"}, + {ID: "datastore_pbs"}, + {ID: "pbs_remotes"}, + {ID: "pbs_jobs"}, + }, + } + if err := maybeApplyPBSConfigsFromStage(context.Background(), newTestLogger(), planMerge, stageRoot, false); err != nil { + t.Fatalf("maybeApplyPBSConfigsFromStage merge: %v", err) + } + for _, path := range []string{ + "/etc/proxmox-backup/node.cfg", + "/etc/proxmox-backup/s3.cfg", + "/etc/proxmox-backup/remote.cfg", + "/etc/proxmox-backup/sync.cfg", + } { + if _, err := fakeFS2.Stat(path); err == nil { + t.Fatalf("did not expect %s to be file-applied in merge mode", path) + } + } + + if len(strictArgsMerge) == 0 { + t.Fatalf("expected strict API calls in merge mode") + } + for _, strict := range strictArgsMerge { + if strict { + t.Fatalf("expected strict=false in merge mode, got true") + } + } +} diff --git a/internal/orchestrator/restore_access_control_ui.go b/internal/orchestrator/restore_access_control_ui.go index 7dee769..f4dbe86 100644 --- a/internal/orchestrator/restore_access_control_ui.go +++ b/internal/orchestrator/restore_access_control_ui.go @@ -17,6 +17,16 @@ const defaultAccessControlRollbackTimeout = 180 * time.Second var ErrAccessControlApplyNotCommitted = errors.New("access control changes not committed") +var ( + accessControlApplyGeteuid = os.Geteuid + accessControlIsMounted = isMounted + accessControlIsRealRestoreFS = isRealRestoreFS + + accessControlArmRollback = armAccessControlRollback + accessControlDisarmRollback = disarmAccessControlRollback + accessControlApplyFromStage = applyPVEAccessControlFromStage +) + type AccessControlApplyNotCommittedError struct { RollbackLog string RollbackMarker string @@ -155,7 +165,7 @@ func maybeApplyPVEAccessControlFromClusterBackupWithUI( stageRoot string, dryRun bool, ) (err error) { - if plan == nil || plan.SystemType != SystemTypePVE || !plan.HasCategoryID("pve_access_control") || !plan.ClusterBackup || plan.NeedsClusterRestore { + if plan == nil || plan.SystemType != SystemTypePVE || !plan.HasCategoryID("pve_access_control") || !plan.ClusterBackup { return nil } @@ -165,7 +175,7 @@ func maybeApplyPVEAccessControlFromClusterBackupWithUI( if ui == nil { return fmt.Errorf("restore UI not available") } - if !isRealRestoreFS(restoreFS) { + if 
!accessControlIsRealRestoreFS(restoreFS) { logger.Debug("Skipping PVE access control apply (cluster backup): non-system filesystem in use") return nil } @@ -173,7 +183,7 @@ func maybeApplyPVEAccessControlFromClusterBackupWithUI( logger.Info("Dry run enabled: skipping PVE access control apply (cluster backup)") return nil } - if os.Geteuid() != 0 { + if accessControlApplyGeteuid() != 0 { logger.Warning("Skipping PVE access control apply (cluster backup): requires root privileges") return nil } @@ -198,7 +208,7 @@ func maybeApplyPVEAccessControlFromClusterBackupWithUI( } etcPVE := "/etc/pve" - mounted, mountErr := isMounted(etcPVE) + mounted, mountErr := accessControlIsMounted(etcPVE) if mountErr != nil { logger.Warning("PVE access control apply: unable to check pmxcfs mount (%s): %v", etcPVE, mountErr) } @@ -279,14 +289,14 @@ func maybeApplyPVEAccessControlFromClusterBackupWithUI( if rollbackPath != "" { logger.Info("") logger.Info("Arming access control rollback timer (%ds)...", int(defaultAccessControlRollbackTimeout.Seconds())) - rollbackHandle, err = armAccessControlRollback(ctx, logger, rollbackPath, defaultAccessControlRollbackTimeout, "/tmp/proxsave") + rollbackHandle, err = accessControlArmRollback(ctx, logger, rollbackPath, defaultAccessControlRollbackTimeout, "/tmp/proxsave") if err != nil { return fmt.Errorf("arm access control rollback: %w", err) } logger.Info("Access control rollback log: %s", rollbackHandle.logPath) } - if err := applyPVEAccessControlFromStage(ctx, logger, stageRoot); err != nil { + if err := accessControlApplyFromStage(ctx, logger, stageRoot); err != nil { return err } @@ -295,7 +305,7 @@ func maybeApplyPVEAccessControlFromClusterBackupWithUI( return nil } - remaining := rollbackHandle.remaining(time.Now()) + remaining := rollbackHandle.remaining(nowRestore()) if remaining <= 0 { return buildAccessControlApplyNotCommittedError(rollbackHandle) } @@ -317,7 +327,7 @@ func maybeApplyPVEAccessControlFromClusterBackupWithUI( } if commit { - disarmAccessControlRollback(ctx, logger, rollbackHandle) + accessControlDisarmRollback(ctx, logger, rollbackHandle) logger.Info("Access control changes committed.") return nil } @@ -353,7 +363,7 @@ func armAccessControlRollback(ctx context.Context, logger *logging.Logger, backu markerPath: filepath.Join(baseDir, fmt.Sprintf("access_control_rollback_pending_%s", timestamp)), scriptPath: filepath.Join(baseDir, fmt.Sprintf("access_control_rollback_%s.sh", timestamp)), logPath: filepath.Join(baseDir, fmt.Sprintf("access_control_rollback_%s.log", timestamp)), - armedAt: time.Now(), + armedAt: nowRestore(), timeout: timeout, } diff --git a/internal/orchestrator/restore_access_control_ui_additional_test.go b/internal/orchestrator/restore_access_control_ui_additional_test.go new file mode 100644 index 0000000..079a35e --- /dev/null +++ b/internal/orchestrator/restore_access_control_ui_additional_test.go @@ -0,0 +1,847 @@ +package orchestrator + +import ( + "context" + "errors" + "fmt" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/tis24dev/proxsave/internal/input" + "github.com/tis24dev/proxsave/internal/logging" +) + +func TestAccessControlApplyNotCommittedError_UnwrapAndMessage(t *testing.T) { + var e *AccessControlApplyNotCommittedError + if e.Error() != ErrAccessControlApplyNotCommitted.Error() { + t.Fatalf("nil receiver Error()=%q want %q", e.Error(), ErrAccessControlApplyNotCommitted.Error()) + } + if !errors.Is(error(e), ErrAccessControlApplyNotCommitted) { + t.Fatalf("expected errors.Is(..., 
ErrAccessControlApplyNotCommitted) to be true") + } + if errors.Unwrap(error(e)) != ErrAccessControlApplyNotCommitted { + t.Fatalf("expected Unwrap to return ErrAccessControlApplyNotCommitted") + } + + e2 := &AccessControlApplyNotCommittedError{} + if e2.Error() != ErrAccessControlApplyNotCommitted.Error() { + t.Fatalf("Error()=%q want %q", e2.Error(), ErrAccessControlApplyNotCommitted.Error()) + } +} + +func TestAccessControlRollbackHandle_Remaining(t *testing.T) { + var h *accessControlRollbackHandle + if got := h.remaining(time.Now()); got != 0 { + t.Fatalf("nil handle remaining=%s want 0", got) + } + + handle := &accessControlRollbackHandle{ + armedAt: time.Unix(100, 0), + timeout: 10 * time.Second, + } + if got := handle.remaining(time.Unix(105, 0)); got != 5*time.Second { + t.Fatalf("remaining=%s want %s", got, 5*time.Second) + } + if got := handle.remaining(time.Unix(999, 0)); got != 0 { + t.Fatalf("remaining=%s want 0", got) + } +} + +func TestBuildAccessControlApplyNotCommittedError_PopulatesFields(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + if e := buildAccessControlApplyNotCommittedError(nil); e == nil { + t.Fatalf("expected error struct, got nil") + } else if e.RollbackArmed || e.RollbackMarker != "" || e.RollbackLog != "" || !e.RollbackDeadline.IsZero() { + t.Fatalf("unexpected fields for nil handle: %#v", e) + } + + armedAt := time.Unix(10, 0) + handle := &accessControlRollbackHandle{ + markerPath: " /tmp/ac.marker \n", + logPath: " /tmp/ac.log\t", + armedAt: armedAt, + timeout: 3 * time.Second, + } + if err := fakeFS.AddFile("/tmp/ac.marker", []byte("pending\n")); err != nil { + t.Fatalf("add marker: %v", err) + } + + e := buildAccessControlApplyNotCommittedError(handle) + if e.RollbackMarker != "/tmp/ac.marker" { + t.Fatalf("RollbackMarker=%q", e.RollbackMarker) + } + if e.RollbackLog != "/tmp/ac.log" { + t.Fatalf("RollbackLog=%q", e.RollbackLog) + } + if !e.RollbackArmed { + t.Fatalf("expected RollbackArmed=true") + } + if !e.RollbackDeadline.Equal(armedAt.Add(3 * time.Second)) { + t.Fatalf("RollbackDeadline=%s want %s", e.RollbackDeadline, armedAt.Add(3*time.Second)) + } +} + +func TestStageHasPVEAccessControlConfig_DetectsFilesAndErrors(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + ok, err := stageHasPVEAccessControlConfig(" ") + if err != nil || ok { + t.Fatalf("expected ok=false err=nil for empty stageRoot, got ok=%v err=%v", ok, err) + } + + ok, err = stageHasPVEAccessControlConfig("/stage-empty") + if err != nil || ok { + t.Fatalf("expected ok=false err=nil for missing files, got ok=%v err=%v", ok, err) + } + + if err := fakeFS.AddDir("/stage-dir/etc/pve/user.cfg"); err != nil { + t.Fatalf("add staged dir: %v", err) + } + ok, err = stageHasPVEAccessControlConfig("/stage-dir") + if err != nil || ok { + t.Fatalf("expected ok=false err=nil when candidates are dirs, got ok=%v err=%v", ok, err) + } + + if err := fakeFS.AddFile("/stage-hit/etc/pve/user.cfg", []byte("x")); err != nil { + t.Fatalf("add staged file: %v", err) + } + ok, err = stageHasPVEAccessControlConfig("/stage-hit") + if err != nil || !ok { + t.Fatalf("expected ok=true err=nil when stage has access control files, got ok=%v err=%v", ok, err) + } + + fakeFS.StatErrors[filepath.Clean("/stage-err/etc/pve/user.cfg")] = 
fmt.Errorf("boom") + ok, err = stageHasPVEAccessControlConfig("/stage-err") + if err == nil || ok { + t.Fatalf("expected error, got ok=%v err=%v", ok, err) + } +} + +func TestBuildAccessControlRollbackScript_QuotesPaths(t *testing.T) { + script := buildAccessControlRollbackScript("/tmp/marker path", "/tmp/backup's.tar.gz", "/tmp/log path") + if !strings.Contains(script, "MARKER='/tmp/marker path'") { + t.Fatalf("expected MARKER to be quoted, got script:\n%s", script) + } + if !strings.Contains(script, "LOG='/tmp/log path'") { + t.Fatalf("expected LOG to be quoted, got script:\n%s", script) + } + if !strings.Contains(script, "BACKUP='/tmp/backup'\\''s.tar.gz'") { + t.Fatalf("expected BACKUP to escape single quotes, got script:\n%s", script) + } + if !strings.HasSuffix(script, "\n") { + t.Fatalf("expected script to end with newline") + } +} + +func TestArmAccessControlRollback_SystemdAndBackgroundPaths(t *testing.T) { + origFS := restoreFS + origCmd := restoreCmd + origTime := restoreTime + t.Cleanup(func() { + restoreFS = origFS + restoreCmd = origCmd + restoreTime = origTime + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + fakeTime := &FakeTime{Current: time.Date(2020, 1, 2, 3, 4, 5, 0, time.UTC)} + restoreTime = fakeTime + + logger := logging.New(logging.GetDefaultLogger().GetLevel(), false) + + t.Run("rejects invalid args", func(t *testing.T) { + if _, err := armAccessControlRollback(context.Background(), logger, "", 1*time.Second, "/tmp/proxsave"); err == nil { + t.Fatalf("expected error for empty backupPath") + } + if _, err := armAccessControlRollback(context.Background(), logger, "/backup.tgz", 0, "/tmp/proxsave"); err == nil { + t.Fatalf("expected error for invalid timeout") + } + }) + + t.Run("uses systemd-run when available", func(t *testing.T) { + binDir := t.TempDir() + oldPath := os.Getenv("PATH") + if err := os.Setenv("PATH", binDir); err != nil { + t.Fatalf("set PATH: %v", err) + } + t.Cleanup(func() { _ = os.Setenv("PATH", oldPath) }) + writeExecutable(t, binDir, "systemd-run") + + fakeCmd := &FakeCommandRunner{} + restoreCmd = fakeCmd + + handle, err := armAccessControlRollback(context.Background(), logger, "/backup.tgz", 2*time.Second, "/tmp/proxsave") + if err != nil { + t.Fatalf("armAccessControlRollback error: %v", err) + } + if handle == nil || handle.unitName == "" { + t.Fatalf("expected systemd unit name, got %#v", handle) + } + if got := fakeCmd.CallsList(); len(got) != 1 || !strings.HasPrefix(got[0], "systemd-run --unit=proxsave-access-control-rollback-20200102_030405") { + t.Fatalf("unexpected calls: %#v", got) + } + if data, err := fakeFS.ReadFile(handle.markerPath); err != nil || string(data) != "pending\n" { + t.Fatalf("marker read err=%v data=%q", err, string(data)) + } + }) + + t.Run("falls back to background timer on systemd-run failure", func(t *testing.T) { + binDir := t.TempDir() + oldPath := os.Getenv("PATH") + if err := os.Setenv("PATH", binDir); err != nil { + t.Fatalf("set PATH: %v", err) + } + t.Cleanup(func() { _ = os.Setenv("PATH", oldPath) }) + writeExecutable(t, binDir, "systemd-run") + + fakeCmd := &FakeCommandRunner{ + Errors: map[string]error{}, + } + restoreCmd = fakeCmd + + timestamp := fakeTime.Current.Format("20060102_150405") + scriptPath := filepath.Join("/tmp/proxsave", fmt.Sprintf("access_control_rollback_%s.sh", timestamp)) + systemdKey := "systemd-run --unit=proxsave-access-control-rollback-" + timestamp + " --on-active=2s /bin/sh " + scriptPath + fakeCmd.Errors[systemdKey] 
= fmt.Errorf("fail") + + handle, err := armAccessControlRollback(context.Background(), logger, "/backup.tgz", 2*time.Second, "/tmp/proxsave") + if err != nil { + t.Fatalf("armAccessControlRollback error: %v", err) + } + if handle == nil { + t.Fatalf("expected handle") + } + if handle.unitName != "" { + t.Fatalf("expected unitName cleared after systemd-run failure, got %q", handle.unitName) + } + + cmd := fmt.Sprintf("nohup sh -c 'sleep %d; /bin/sh %s' >/dev/null 2>&1 &", 2, scriptPath) + wantBackground := "sh -c " + cmd + calls := fakeCmd.CallsList() + if len(calls) != 2 || calls[1] != wantBackground { + t.Fatalf("unexpected calls: %#v", calls) + } + }) + + t.Run("background timer failure returns error", func(t *testing.T) { + emptyBin := t.TempDir() + oldPath := os.Getenv("PATH") + if err := os.Setenv("PATH", emptyBin); err != nil { + t.Fatalf("set PATH: %v", err) + } + t.Cleanup(func() { _ = os.Setenv("PATH", oldPath) }) + + fakeCmd := &FakeCommandRunner{ + Errors: map[string]error{}, + } + restoreCmd = fakeCmd + + timestamp := fakeTime.Current.Format("20060102_150405") + scriptPath := filepath.Join("/tmp/proxsave", fmt.Sprintf("access_control_rollback_%s.sh", timestamp)) + cmd := fmt.Sprintf("nohup sh -c 'sleep %d; /bin/sh %s' >/dev/null 2>&1 &", 1, scriptPath) + backgroundKey := "sh -c " + cmd + fakeCmd.Errors[backgroundKey] = fmt.Errorf("boom") + + if _, err := armAccessControlRollback(context.Background(), logger, "/backup.tgz", 1*time.Second, "/tmp/proxsave"); err == nil { + t.Fatalf("expected error") + } + }) +} + +func TestArmAccessControlRollback_DefaultWorkDirAndMinTimeout(t *testing.T) { + origFS := restoreFS + origCmd := restoreCmd + origTime := restoreTime + t.Cleanup(func() { + restoreFS = origFS + restoreCmd = origCmd + restoreTime = origTime + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + fakeTime := &FakeTime{Current: time.Date(2020, 1, 2, 3, 4, 5, 0, time.UTC)} + restoreTime = fakeTime + + emptyBin := t.TempDir() + oldPath := os.Getenv("PATH") + if err := os.Setenv("PATH", emptyBin); err != nil { + t.Fatalf("set PATH: %v", err) + } + t.Cleanup(func() { _ = os.Setenv("PATH", oldPath) }) + + fakeCmd := &FakeCommandRunner{} + restoreCmd = fakeCmd + + handle, err := armAccessControlRollback(context.Background(), newTestLogger(), "/backup.tgz", 500*time.Millisecond, " ") + if err != nil { + t.Fatalf("armAccessControlRollback error: %v", err) + } + if handle == nil || handle.workDir != "/tmp/proxsave" { + t.Fatalf("unexpected handle: %#v", handle) + } + if data, err := fakeFS.ReadFile(handle.markerPath); err != nil || string(data) != "pending\n" { + t.Fatalf("marker err=%v data=%q", err, string(data)) + } + calls := fakeCmd.CallsList() + if len(calls) != 1 { + t.Fatalf("unexpected calls: %#v", calls) + } + if !strings.Contains(calls[0], "sleep 1; /bin/sh") { + t.Fatalf("expected timeoutSeconds to clamp to 1, got call=%q", calls[0]) + } +} + +func TestArmAccessControlRollback_ReturnsErrorOnMkdirAllFailure(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + fakeFS.MkdirAllErr = fmt.Errorf("disk full") + restoreFS = fakeFS + + if _, err := armAccessControlRollback(context.Background(), newTestLogger(), "/backup.tgz", 1*time.Second, "/tmp/proxsave"); err == nil { + t.Fatalf("expected error") + } +} + +func TestArmAccessControlRollback_ReturnsErrorOnMarkerWriteFailure(t *testing.T) { + origFS := restoreFS + 
t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + fakeFS.WriteErr = fmt.Errorf("disk full") + restoreFS = fakeFS + + if _, err := armAccessControlRollback(context.Background(), newTestLogger(), "/backup.tgz", 1*time.Second, "/tmp/proxsave"); err == nil || !strings.Contains(err.Error(), "write rollback marker") { + t.Fatalf("expected write rollback marker error, got %v", err) + } +} + +func TestArmAccessControlRollback_ReturnsErrorOnScriptWriteFailure(t *testing.T) { + origFS := restoreFS + origTime := restoreTime + t.Cleanup(func() { + restoreFS = origFS + restoreTime = origTime + }) + + base := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(base.Root) }) + + fakeTime := &FakeTime{Current: time.Date(2020, 1, 2, 3, 4, 5, 0, time.UTC)} + restoreTime = fakeTime + timestamp := fakeTime.Current.Format("20060102_150405") + scriptPath := filepath.Join("/tmp/proxsave", fmt.Sprintf("access_control_rollback_%s.sh", timestamp)) + + restoreFS = writeFileFailFS{FS: base, failPath: scriptPath, err: fmt.Errorf("disk full")} + + if _, err := armAccessControlRollback(context.Background(), newTestLogger(), "/backup.tgz", 1*time.Second, "/tmp/proxsave"); err == nil || !strings.Contains(err.Error(), "write rollback script") { + t.Fatalf("expected write rollback script error, got %v", err) + } +} + +func TestDisarmAccessControlRollback_RemovesMarkerScriptAndStopsTimer(t *testing.T) { + origFS := restoreFS + origCmd := restoreCmd + t.Cleanup(func() { + restoreFS = origFS + restoreCmd = origCmd + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + binDir := t.TempDir() + oldPath := os.Getenv("PATH") + if err := os.Setenv("PATH", binDir); err != nil { + t.Fatalf("set PATH: %v", err) + } + t.Cleanup(func() { _ = os.Setenv("PATH", oldPath) }) + writeExecutable(t, binDir, "systemctl") + + fakeCmd := &FakeCommandRunner{} + restoreCmd = fakeCmd + + handle := &accessControlRollbackHandle{ + markerPath: "/tmp/proxsave/ac.marker", + scriptPath: "/tmp/proxsave/ac.sh", + unitName: "proxsave-access-control-rollback-test", + } + if err := fakeFS.AddFile(handle.markerPath, []byte("pending\n")); err != nil { + t.Fatalf("add marker: %v", err) + } + if err := fakeFS.AddFile(handle.scriptPath, []byte("#!/bin/sh\n")); err != nil { + t.Fatalf("add script: %v", err) + } + + disarmAccessControlRollback(context.Background(), newTestLogger(), handle) + + if _, err := fakeFS.Stat(handle.markerPath); err == nil || !os.IsNotExist(err) { + t.Fatalf("expected marker removed; stat err=%v", err) + } + if _, err := fakeFS.Stat(handle.scriptPath); err == nil || !os.IsNotExist(err) { + t.Fatalf("expected script removed; stat err=%v", err) + } + + timerUnit := handle.unitName + ".timer" + want1 := "systemctl stop " + timerUnit + want2 := "systemctl reset-failed " + handle.unitName + ".service " + timerUnit + calls := fakeCmd.CallsList() + if len(calls) != 2 || calls[0] != want1 || calls[1] != want2 { + t.Fatalf("unexpected calls: %#v", calls) + } +} + +func TestMaybeApplyPVEAccessControlFromClusterBackupWithUI_CoversUserFlows(t *testing.T) { + origFS := restoreFS + origCmd := restoreCmd + origTime := restoreTime + origGeteuid := accessControlApplyGeteuid + origMounted := accessControlIsMounted + origRealFS := accessControlIsRealRestoreFS + origArm := accessControlArmRollback + origDisarm := accessControlDisarmRollback + origApply := accessControlApplyFromStage + t.Cleanup(func() { + restoreFS = origFS + restoreCmd = 
origCmd + restoreTime = origTime + accessControlApplyGeteuid = origGeteuid + accessControlIsMounted = origMounted + accessControlIsRealRestoreFS = origRealFS + accessControlArmRollback = origArm + accessControlDisarmRollback = origDisarm + accessControlApplyFromStage = origApply + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + restoreTime = &FakeTime{Current: time.Unix(100, 0)} + restoreCmd = &FakeCommandRunner{} + + accessControlIsRealRestoreFS = func(fs FS) bool { return true } + accessControlApplyGeteuid = func() int { return 0 } + accessControlIsMounted = func(path string) (bool, error) { return true, nil } + + plan := &RestorePlan{ + SystemType: SystemTypePVE, + ClusterBackup: true, + NormalCategories: []Category{{ID: "pve_access_control"}}, + } + stageWithAC := "/stage-ac" + if err := fakeFS.AddFile(stageWithAC+"/etc/pve/user.cfg", []byte("x")); err != nil { + t.Fatalf("add staged user.cfg: %v", err) + } + stageWithoutAC := "/stage-empty" + logger := newTestLogger() + + t.Run("errors when ui missing", func(t *testing.T) { + err := maybeApplyPVEAccessControlFromClusterBackupWithUI(context.Background(), nil, logger, plan, nil, nil, stageWithAC, false) + if err == nil { + t.Fatalf("expected error") + } + }) + + t.Run("skips when /etc/pve not mounted", func(t *testing.T) { + accessControlIsMounted = func(path string) (bool, error) { return false, nil } + t.Cleanup(func() { accessControlIsMounted = func(path string) (bool, error) { return true, nil } }) + + ui := &scriptedRestoreWorkflowUI{fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, script: nil} + err := maybeApplyPVEAccessControlFromClusterBackupWithUI(context.Background(), ui, logger, plan, nil, nil, stageWithAC, false) + if err != nil { + t.Fatalf("expected nil, got %v", err) + } + }) + + t.Run("skips when stage has no access control files", func(t *testing.T) { + ui := &scriptedRestoreWorkflowUI{fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, script: nil} + err := maybeApplyPVEAccessControlFromClusterBackupWithUI(context.Background(), ui, logger, plan, nil, nil, stageWithoutAC, false) + if err != nil { + t.Fatalf("expected nil, got %v", err) + } + }) + + t.Run("user skips apply", func(t *testing.T) { + ui := &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{{ok: false}}, + } + err := maybeApplyPVEAccessControlFromClusterBackupWithUI(context.Background(), ui, logger, plan, nil, nil, stageWithAC, false) + if err != nil { + t.Fatalf("expected nil, got %v", err) + } + }) + + t.Run("missing rollback backup declines full rollback", func(t *testing.T) { + accessControlArmRollback = func(ctx context.Context, logger *logging.Logger, backupPath string, timeout time.Duration, workDir string) (*accessControlRollbackHandle, error) { + t.Fatalf("unexpected rollback arm") + return nil, nil + } + + ui := &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{ + {ok: true}, // Apply now + {ok: false}, // Skip full rollback + }, + } + safety := &SafetyBackupResult{BackupPath: "/backups/full.tgz"} + err := maybeApplyPVEAccessControlFromClusterBackupWithUI(context.Background(), ui, logger, plan, safety, nil, stageWithAC, false) + if err != nil { + t.Fatalf("expected nil, got %v", err) + } + }) + + t.Run("proceed without rollback applies and returns without commit prompt", func(t *testing.T) { + accessControlArmRollback = func(ctx context.Context, logger *logging.Logger, 
backupPath string, timeout time.Duration, workDir string) (*accessControlRollbackHandle, error) { + t.Fatalf("unexpected rollback arm") + return nil, nil + } + called := false + accessControlApplyFromStage = func(ctx context.Context, logger *logging.Logger, stageRoot string) error { + called = true + return nil + } + + ui := &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{ + {ok: true}, // Apply now + {ok: true}, // Proceed without rollback + }, + } + err := maybeApplyPVEAccessControlFromClusterBackupWithUI(context.Background(), ui, logger, plan, nil, nil, stageWithAC, false) + if err != nil { + t.Fatalf("expected nil, got %v", err) + } + if !called { + t.Fatalf("expected access control apply to be invoked") + } + }) + + t.Run("commit keeps changes and disarms rollback", func(t *testing.T) { + markerPath := "/tmp/proxsave/ac.marker" + scriptPath := "/tmp/proxsave/ac.sh" + accessControlArmRollback = func(ctx context.Context, logger *logging.Logger, backupPath string, timeout time.Duration, workDir string) (*accessControlRollbackHandle, error) { + handle := &accessControlRollbackHandle{ + markerPath: markerPath, + scriptPath: scriptPath, + logPath: "/tmp/proxsave/ac.log", + armedAt: nowRestore(), + timeout: timeout, + } + _ = restoreFS.WriteFile(markerPath, []byte("pending\n"), 0o640) + _ = restoreFS.WriteFile(scriptPath, []byte("#!/bin/sh\n"), 0o640) + return handle, nil + } + accessControlApplyFromStage = func(ctx context.Context, logger *logging.Logger, stageRoot string) error { return nil } + + ui := &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{ + {ok: true}, // Apply now + {ok: true}, // Commit + }, + } + rollback := &SafetyBackupResult{BackupPath: "/backups/access-control.tgz"} + err := maybeApplyPVEAccessControlFromClusterBackupWithUI(context.Background(), ui, logger, plan, nil, rollback, stageWithAC, false) + if err != nil { + t.Fatalf("expected nil, got %v", err) + } + if _, err := fakeFS.Stat(markerPath); err == nil || !os.IsNotExist(err) { + t.Fatalf("expected rollback marker removed; stat err=%v", err) + } + if _, err := fakeFS.Stat(scriptPath); err == nil || !os.IsNotExist(err) { + t.Fatalf("expected rollback script removed; stat err=%v", err) + } + }) + + t.Run("rollback requested returns typed error with marker armed", func(t *testing.T) { + markerPath := "/tmp/proxsave/ac.marker" + armedAt := nowRestore() + accessControlArmRollback = func(ctx context.Context, logger *logging.Logger, backupPath string, timeout time.Duration, workDir string) (*accessControlRollbackHandle, error) { + handle := &accessControlRollbackHandle{ + markerPath: markerPath, + scriptPath: "/tmp/proxsave/ac.sh", + logPath: "/tmp/proxsave/ac.log", + armedAt: armedAt, + timeout: timeout, + } + _ = restoreFS.WriteFile(markerPath, []byte("pending\n"), 0o640) + return handle, nil + } + accessControlApplyFromStage = func(ctx context.Context, logger *logging.Logger, stageRoot string) error { return nil } + + ui := &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{ + {ok: true}, // Apply now + {ok: false}, // Rollback + }, + } + rollback := &SafetyBackupResult{BackupPath: "/backups/access-control.tgz"} + err := maybeApplyPVEAccessControlFromClusterBackupWithUI(context.Background(), ui, logger, plan, nil, rollback, stageWithAC, false) + if err == nil { + t.Fatalf("expected error") + } + if !errors.Is(err, 
ErrAccessControlApplyNotCommitted) { + t.Fatalf("expected ErrAccessControlApplyNotCommitted, got %v", err) + } + var typed *AccessControlApplyNotCommittedError + if !errors.As(err, &typed) || typed == nil { + t.Fatalf("expected AccessControlApplyNotCommittedError, got %T", err) + } + if !typed.RollbackArmed || typed.RollbackMarker != markerPath || typed.RollbackLog != "/tmp/proxsave/ac.log" { + t.Fatalf("unexpected error fields: %#v", typed) + } + if typed.RollbackDeadline.IsZero() || !typed.RollbackDeadline.Equal(armedAt.Add(defaultAccessControlRollbackTimeout)) { + t.Fatalf("unexpected RollbackDeadline=%s", typed.RollbackDeadline) + } + if _, err := fakeFS.Stat(markerPath); err != nil { + t.Fatalf("expected marker to still exist, stat err=%v", err) + } + }) + + t.Run("commit prompt abort returns abort error", func(t *testing.T) { + markerPath := "/tmp/proxsave/ac.marker" + accessControlArmRollback = func(ctx context.Context, logger *logging.Logger, backupPath string, timeout time.Duration, workDir string) (*accessControlRollbackHandle, error) { + handle := &accessControlRollbackHandle{ + markerPath: markerPath, + logPath: "/tmp/proxsave/ac.log", + armedAt: nowRestore(), + timeout: timeout, + } + _ = restoreFS.WriteFile(markerPath, []byte("pending\n"), 0o640) + return handle, nil + } + accessControlApplyFromStage = func(ctx context.Context, logger *logging.Logger, stageRoot string) error { return nil } + + ui := &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{ + {ok: true}, // Apply now + {err: input.ErrInputAborted}, // Abort at commit prompt + }, + } + rollback := &SafetyBackupResult{BackupPath: "/backups/access-control.tgz"} + err := maybeApplyPVEAccessControlFromClusterBackupWithUI(context.Background(), ui, logger, plan, nil, rollback, stageWithAC, false) + if err == nil || !errors.Is(err, input.ErrInputAborted) { + t.Fatalf("expected abort error, got %v", err) + } + }) + + t.Run("commit prompt failure returns typed error", func(t *testing.T) { + markerPath := "/tmp/proxsave/ac.marker" + accessControlArmRollback = func(ctx context.Context, logger *logging.Logger, backupPath string, timeout time.Duration, workDir string) (*accessControlRollbackHandle, error) { + handle := &accessControlRollbackHandle{ + markerPath: markerPath, + logPath: "/tmp/proxsave/ac.log", + armedAt: nowRestore(), + timeout: timeout, + } + _ = restoreFS.WriteFile(markerPath, []byte("pending\n"), 0o640) + return handle, nil + } + accessControlApplyFromStage = func(ctx context.Context, logger *logging.Logger, stageRoot string) error { return nil } + + ui := &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{ + {ok: true}, // Apply now + {err: fmt.Errorf("boom")}, // Commit prompt fails + }, + } + rollback := &SafetyBackupResult{BackupPath: "/backups/access-control.tgz"} + err := maybeApplyPVEAccessControlFromClusterBackupWithUI(context.Background(), ui, logger, plan, nil, rollback, stageWithAC, false) + if err == nil { + t.Fatalf("expected error") + } + if !errors.Is(err, ErrAccessControlApplyNotCommitted) { + t.Fatalf("expected ErrAccessControlApplyNotCommitted, got %v", err) + } + if _, err := fakeFS.Stat(markerPath); err != nil { + t.Fatalf("expected marker to still exist, stat err=%v", err) + } + }) + + t.Run("remaining timeout returns typed error without commit prompt", func(t *testing.T) { + markerPath := "/tmp/proxsave/ac.marker" + accessControlArmRollback = func(ctx context.Context, 
logger *logging.Logger, backupPath string, timeout time.Duration, workDir string) (*accessControlRollbackHandle, error) { + handle := &accessControlRollbackHandle{ + markerPath: markerPath, + logPath: "/tmp/proxsave/ac.log", + armedAt: nowRestore().Add(-timeout - time.Second), + timeout: timeout, + } + _ = restoreFS.WriteFile(markerPath, []byte("pending\n"), 0o640) + return handle, nil + } + accessControlApplyFromStage = func(ctx context.Context, logger *logging.Logger, stageRoot string) error { return nil } + + ui := &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{{ok: true}}, // Apply now only + } + rollback := &SafetyBackupResult{BackupPath: "/backups/access-control.tgz"} + err := maybeApplyPVEAccessControlFromClusterBackupWithUI(context.Background(), ui, logger, plan, nil, rollback, stageWithAC, false) + if err == nil || !errors.Is(err, ErrAccessControlApplyNotCommitted) { + t.Fatalf("expected ErrAccessControlApplyNotCommitted, got %v", err) + } + }) + + t.Run("dry run skips apply", func(t *testing.T) { + ui := &scriptedRestoreWorkflowUI{fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, script: nil} + err := maybeApplyPVEAccessControlFromClusterBackupWithUI(context.Background(), ui, logger, plan, nil, nil, stageWithAC, true) + if err != nil { + t.Fatalf("expected nil, got %v", err) + } + }) + + t.Run("non-root skips apply", func(t *testing.T) { + accessControlApplyGeteuid = func() int { return 1000 } + t.Cleanup(func() { accessControlApplyGeteuid = func() int { return 0 } }) + + ui := &scriptedRestoreWorkflowUI{fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, script: nil} + err := maybeApplyPVEAccessControlFromClusterBackupWithUI(context.Background(), ui, logger, plan, nil, nil, stageWithAC, false) + if err != nil { + t.Fatalf("expected nil, got %v", err) + } + }) +} + +func TestMaybeApplyAccessControlWithUI_BranchCoverage(t *testing.T) { + origFS := restoreFS + origCmd := restoreCmd + origTime := restoreTime + origGeteuid := accessControlApplyGeteuid + origMounted := accessControlIsMounted + origRealFS := accessControlIsRealRestoreFS + origArm := accessControlArmRollback + origApply := accessControlApplyFromStage + t.Cleanup(func() { + restoreFS = origFS + restoreCmd = origCmd + restoreTime = origTime + accessControlApplyGeteuid = origGeteuid + accessControlIsMounted = origMounted + accessControlIsRealRestoreFS = origRealFS + accessControlArmRollback = origArm + accessControlApplyFromStage = origApply + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + restoreTime = &FakeTime{Current: time.Unix(100, 0)} + restoreCmd = &FakeCommandRunner{} + + accessControlIsRealRestoreFS = func(fs FS) bool { return true } + accessControlApplyGeteuid = func() int { return 0 } + accessControlIsMounted = func(path string) (bool, error) { return true, nil } + + stageRoot := "/stage" + if err := fakeFS.AddFile(stageRoot+"/etc/pve/user.cfg", []byte("x")); err != nil { + t.Fatalf("add staged user.cfg: %v", err) + } + logger := newTestLogger() + + t.Run("nil plan returns nil", func(t *testing.T) { + if err := maybeApplyAccessControlWithUI(context.Background(), &fakeRestoreWorkflowUI{}, logger, nil, nil, nil, stageRoot, false); err != nil { + t.Fatalf("expected nil, got %v", err) + } + }) + + t.Run("empty stageRoot skips", func(t *testing.T) { + plan := &RestorePlan{NormalCategories: []Category{{ID: "pve_access_control"}}} + if err := maybeApplyAccessControlWithUI(context.Background(), 
&fakeRestoreWorkflowUI{}, logger, plan, nil, nil, " ", false); err != nil { + t.Fatalf("expected nil, got %v", err) + } + }) + + t.Run("no relevant categories returns nil", func(t *testing.T) { + plan := &RestorePlan{NormalCategories: []Category{{ID: "pve_firewall"}}} + if err := maybeApplyAccessControlWithUI(context.Background(), &fakeRestoreWorkflowUI{}, logger, plan, nil, nil, stageRoot, false); err != nil { + t.Fatalf("expected nil, got %v", err) + } + }) + + t.Run("errors when ui missing", func(t *testing.T) { + plan := &RestorePlan{ + SystemType: SystemTypePVE, + ClusterBackup: true, + NormalCategories: []Category{{ID: "pve_access_control"}}, + } + if err := maybeApplyAccessControlWithUI(context.Background(), nil, logger, plan, nil, nil, stageRoot, false); err == nil { + t.Fatalf("expected error") + } + }) + + t.Run("cluster backup path is used", func(t *testing.T) { + markerPath := "/tmp/proxsave/ac.marker" + accessControlArmRollback = func(ctx context.Context, logger *logging.Logger, backupPath string, timeout time.Duration, workDir string) (*accessControlRollbackHandle, error) { + handle := &accessControlRollbackHandle{ + markerPath: markerPath, + logPath: "/tmp/proxsave/ac.log", + armedAt: nowRestore(), + timeout: timeout, + } + _ = restoreFS.WriteFile(markerPath, []byte("pending\n"), 0o640) + return handle, nil + } + accessControlApplyFromStage = func(ctx context.Context, logger *logging.Logger, stageRoot string) error { return nil } + + plan := &RestorePlan{ + SystemType: SystemTypePVE, + ClusterBackup: true, + NormalCategories: []Category{{ID: "pve_access_control"}}, + } + ui := &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{ + {ok: true}, // Apply now + {ok: false}, // Rollback + }, + } + rollback := &SafetyBackupResult{BackupPath: "/backups/access-control.tgz"} + err := maybeApplyAccessControlWithUI(context.Background(), ui, logger, plan, nil, rollback, stageRoot, false) + if err == nil || !errors.Is(err, ErrAccessControlApplyNotCommitted) { + t.Fatalf("expected ErrAccessControlApplyNotCommitted, got %v", err) + } + }) + + t.Run("default path uses stage apply", func(t *testing.T) { + plan := &RestorePlan{ + SystemType: SystemTypePBS, + NormalCategories: []Category{{ID: "pbs_access_control"}}, + } + if err := maybeApplyAccessControlWithUI(context.Background(), &fakeRestoreWorkflowUI{}, logger, plan, nil, nil, stageRoot, false); err != nil { + t.Fatalf("expected nil, got %v", err) + } + }) +} + diff --git a/internal/orchestrator/restore_firewall.go b/internal/orchestrator/restore_firewall.go index 64c7419..eeb5daa 100644 --- a/internal/orchestrator/restore_firewall.go +++ b/internal/orchestrator/restore_firewall.go @@ -17,6 +17,18 @@ const defaultFirewallRollbackTimeout = 180 * time.Second var ErrFirewallApplyNotCommitted = errors.New("firewall configuration not committed") +var ( + firewallApplyGeteuid = os.Geteuid + firewallHostname = os.Hostname + firewallIsMounted = isMounted + firewallIsRealRestoreFS = isRealRestoreFS + + firewallArmRollback = armFirewallRollback + firewallDisarmRollback = disarmFirewallRollback + firewallApplyFromStage = applyPVEFirewallFromStage + firewallRestartService = restartPVEFirewallService +) + type FirewallApplyNotCommittedError struct { RollbackLog string RollbackMarker string @@ -99,7 +111,7 @@ func maybeApplyPVEFirewallWithUI( if ui == nil { return fmt.Errorf("restore UI not available") } - if !isRealRestoreFS(restoreFS) { + if !firewallIsRealRestoreFS(restoreFS) { 
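Aside on the refactor in this file: the hunks below swap direct calls (os.Geteuid, os.Hostname, isMounted, armFirewallRollback, ...) for package-level function variables so the new tests can substitute fakes and restore the originals afterwards. A minimal sketch of the pattern, with illustrative names that are not from this repo:

package seam

import "os"

// geteuid is a test seam: production code always calls the variable,
// never os.Geteuid directly, so a test can swap in a fake.
var geteuid = os.Geteuid

func requiresRoot() bool {
	return geteuid() == 0
}

In a test, the seam is swapped and restored via t.Cleanup, mirroring what the tests in this diff do with firewallApplyGeteuid:

orig := geteuid
t.Cleanup(func() { geteuid = orig })
geteuid = func() int { return 0 } // pretend to run as root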
logger.Debug("Skipping PVE firewall restore: non-system filesystem in use") return nil } @@ -107,7 +119,7 @@ func maybeApplyPVEFirewallWithUI( logger.Info("Dry run enabled: skipping PVE firewall restore") return nil } - if os.Geteuid() != 0 { + if firewallApplyGeteuid() != 0 { logger.Warning("Skipping PVE firewall restore: requires root privileges") return nil } @@ -123,7 +135,7 @@ func maybeApplyPVEFirewallWithUI( } etcPVE := "/etc/pve" - mounted, mountErr := isMounted(etcPVE) + mounted, mountErr := firewallIsMounted(etcPVE) if mountErr != nil { logger.Warning("PVE firewall restore: unable to check pmxcfs mount (%s): %v", etcPVE, mountErr) } @@ -214,26 +226,26 @@ func maybeApplyPVEFirewallWithUI( if rollbackPath != "" { logger.Info("") logger.Info("Arming firewall rollback timer (%ds)...", int(defaultFirewallRollbackTimeout.Seconds())) - rollbackHandle, err = armFirewallRollback(ctx, logger, rollbackPath, defaultFirewallRollbackTimeout, "/tmp/proxsave") + rollbackHandle, err = firewallArmRollback(ctx, logger, rollbackPath, defaultFirewallRollbackTimeout, "/tmp/proxsave") if err != nil { return fmt.Errorf("arm firewall rollback: %w", err) } logger.Info("Firewall rollback log: %s", rollbackHandle.logPath) } - applied, err := applyPVEFirewallFromStage(logger, stageRoot) + applied, err := firewallApplyFromStage(logger, stageRoot) if err != nil { return err } if len(applied) == 0 { logger.Info("PVE firewall restore: no changes applied (stage contained no firewall entries)") if rollbackHandle != nil { - disarmFirewallRollback(ctx, logger, rollbackHandle) + firewallDisarmRollback(ctx, logger, rollbackHandle) } return nil } - if err := restartPVEFirewallService(ctx); err != nil { + if err := firewallRestartService(ctx); err != nil { logger.Warning("PVE firewall restore: reload/restart failed: %v", err) } @@ -242,7 +254,7 @@ func maybeApplyPVEFirewallWithUI( return nil } - remaining := rollbackHandle.remaining(time.Now()) + remaining := rollbackHandle.remaining(nowRestore()) if remaining <= 0 { return buildFirewallApplyNotCommittedError(rollbackHandle) } @@ -264,7 +276,7 @@ func maybeApplyPVEFirewallWithUI( } if commit { - disarmFirewallRollback(ctx, logger, rollbackHandle) + firewallDisarmRollback(ctx, logger, rollbackHandle) logger.Info("Firewall changes committed.") return nil } @@ -309,7 +321,7 @@ func applyPVEFirewallFromStage(logger *logging.Logger, stageRoot string) (applie return applied, err } if ok { - currentNode, _ := os.Hostname() + currentNode, _ := firewallHostname() currentNode = shortHost(currentNode) if strings.TrimSpace(currentNode) == "" { currentNode = "localhost" @@ -331,7 +343,7 @@ func applyPVEFirewallFromStage(logger *logging.Logger, stageRoot string) (applie } func selectStageHostFirewall(logger *logging.Logger, stageRoot string) (path string, sourceNode string, ok bool, err error) { - currentNode, _ := os.Hostname() + currentNode, _ := firewallHostname() currentNode = shortHost(currentNode) if strings.TrimSpace(currentNode) == "" { currentNode = "localhost" @@ -430,7 +442,7 @@ func armFirewallRollback(ctx context.Context, logger *logging.Logger, backupPath markerPath: filepath.Join(baseDir, fmt.Sprintf("firewall_rollback_pending_%s", timestamp)), scriptPath: filepath.Join(baseDir, fmt.Sprintf("firewall_rollback_%s.sh", timestamp)), logPath: filepath.Join(baseDir, fmt.Sprintf("firewall_rollback_%s.log", timestamp)), - armedAt: time.Now(), + armedAt: nowRestore(), timeout: timeout, } diff --git a/internal/orchestrator/restore_firewall_additional_test.go 
b/internal/orchestrator/restore_firewall_additional_test.go new file mode 100644 index 0000000..525c269 --- /dev/null +++ b/internal/orchestrator/restore_firewall_additional_test.go @@ -0,0 +1,2146 @@ +package orchestrator + +import ( + "context" + "errors" + "fmt" + "io/fs" + "os" + "path/filepath" + "strconv" + "strings" + "testing" + "time" + + "github.com/tis24dev/proxsave/internal/input" + "github.com/tis24dev/proxsave/internal/logging" +) + +type readDirFailFS struct { + FS + failPath string + err error +} + +func (f readDirFailFS) ReadDir(path string) ([]os.DirEntry, error) { + if filepath.Clean(path) == filepath.Clean(f.failPath) { + return nil, f.err + } + return f.FS.ReadDir(path) +} + +type readDirOverrideFS struct { + FS + overridePath string + entries []os.DirEntry + err error +} + +func (f readDirOverrideFS) ReadDir(path string) ([]os.DirEntry, error) { + if filepath.Clean(path) == filepath.Clean(f.overridePath) { + if f.err != nil { + return nil, f.err + } + return f.entries, nil + } + return f.FS.ReadDir(path) +} + +type statFailFS struct { + FS + failPath string + err error +} + +func (f statFailFS) Stat(path string) (os.FileInfo, error) { + if filepath.Clean(path) == filepath.Clean(f.failPath) { + return nil, f.err + } + return f.FS.Stat(path) +} + +type readFileFailFS struct { + FS + failPath string + err error +} + +func (f readFileFailFS) ReadFile(path string) ([]byte, error) { + if filepath.Clean(path) == filepath.Clean(f.failPath) { + return nil, f.err + } + return f.FS.ReadFile(path) +} + +type writeFileFailFS struct { + FS + failPath string + err error +} + +func (f writeFileFailFS) WriteFile(path string, data []byte, perm os.FileMode) error { + if filepath.Clean(path) == filepath.Clean(f.failPath) { + return f.err + } + return f.FS.WriteFile(path, data, perm) +} + +type removeFailFS struct { + FS + failPath string + err error +} + +func (f removeFailFS) Remove(path string) error { + if filepath.Clean(path) == filepath.Clean(f.failPath) { + return f.err + } + return f.FS.Remove(path) +} + +type readlinkFailFS struct { + FS + failPath string + err error +} + +func (f readlinkFailFS) Readlink(path string) (string, error) { + if filepath.Clean(path) == filepath.Clean(f.failPath) { + return "", f.err + } + return f.FS.Readlink(path) +} + +type badInfoDirEntry struct { + name string +} + +func (e badInfoDirEntry) Name() string { return e.name } +func (e badInfoDirEntry) IsDir() bool { return false } +func (e badInfoDirEntry) Type() fs.FileMode { return 0 } +func (e badInfoDirEntry) Info() (fs.FileInfo, error) { return nil, fmt.Errorf("boom") } + +type symlinkFailFS struct { + FS + failNewname string + err error +} + +func (f symlinkFailFS) Symlink(oldname, newname string) error { + if filepath.Clean(newname) == filepath.Clean(f.failNewname) { + return f.err + } + return f.FS.Symlink(oldname, newname) +} + +type statFailOnNthFS struct { + FS + path string + failOn int + calls int + err error +} + +func (f *statFailOnNthFS) Stat(path string) (os.FileInfo, error) { + if filepath.Clean(path) == filepath.Clean(f.path) { + f.calls++ + if f.calls >= f.failOn { + return nil, f.err + } + } + return f.FS.Stat(path) +} + +type multiReadDirFS struct { + FS + entries map[string][]os.DirEntry + errors map[string]error +} + +func (f multiReadDirFS) ReadDir(path string) ([]os.DirEntry, error) { + clean := filepath.Clean(path) + if f.errors != nil { + if err, ok := f.errors[clean]; ok { + return nil, err + } + } + if f.entries != nil { + if entries, ok := f.entries[clean]; ok { + return 
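These single-method wrappers are the file's fault-injection building blocks: each one overrides exactly one FS call for one path and delegates everything else to the embedded FS. A sketch of how they compose in a test (reusing the types defined above):

func TestReadDirFailFS_Sketch(t *testing.T) {
	base := NewFakeFS()
	t.Cleanup(func() { _ = os.RemoveAll(base.Root) })
	// Only ReadDir on the chosen path fails; every other call hits base.
	fs := readDirFailFS{FS: base, failPath: "/stage/etc/pve/nodes", err: fmt.Errorf("boom")}
	if _, err := fs.ReadDir("/stage/etc/pve/nodes"); err == nil {
		t.Fatalf("expected injected ReadDir failure")
	}
}

Wrappers can also be stacked (for example statFailFS around readDirFailFS) to fail two different calls in a single scenario.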
entries, nil + } + } + return f.FS.ReadDir(path) +} + +type staticFileInfo struct { + name string + mode fs.FileMode +} + +func (i staticFileInfo) Name() string { return i.name } +func (i staticFileInfo) Size() int64 { return 0 } +func (i staticFileInfo) Mode() fs.FileMode { return i.mode } +func (i staticFileInfo) ModTime() time.Time { return time.Time{} } +func (i staticFileInfo) IsDir() bool { return i.mode.IsDir() } +func (i staticFileInfo) Sys() any { return nil } + +type staticDirEntry struct { + name string + mode fs.FileMode +} + +func (e staticDirEntry) Name() string { return e.name } +func (e staticDirEntry) IsDir() bool { return e.mode.IsDir() } +func (e staticDirEntry) Type() fs.FileMode { return e.mode } +func (e staticDirEntry) Info() (fs.FileInfo, error) { + return staticFileInfo{name: e.name, mode: e.mode}, nil +} + +type scriptedConfirmAction struct { + ok bool + err error +} + +type scriptedRestoreWorkflowUI struct { + *fakeRestoreWorkflowUI + script []scriptedConfirmAction + calls int +} + +func (s *scriptedRestoreWorkflowUI) ConfirmAction(ctx context.Context, title, message, yesLabel, noLabel string, timeout time.Duration, defaultYes bool) (bool, error) { + if s.calls >= len(s.script) { + return false, fmt.Errorf("unexpected ConfirmAction call %d (title=%q)", s.calls+1, strings.TrimSpace(title)) + } + action := s.script[s.calls] + s.calls++ + return action.ok, action.err +} + +func writeExecutable(t *testing.T, dir, name string) string { + t.Helper() + path := filepath.Join(dir, name) + if err := os.WriteFile(path, []byte("#!/bin/sh\nexit 0\n"), 0o755); err != nil { + t.Fatalf("write executable %s: %v", name, err) + } + return path +} + +func TestFirewallApplyNotCommittedError_UnwrapAndMessage(t *testing.T) { + var e *FirewallApplyNotCommittedError + if e.Error() != ErrFirewallApplyNotCommitted.Error() { + t.Fatalf("nil receiver Error()=%q want %q", e.Error(), ErrFirewallApplyNotCommitted.Error()) + } + if !errors.Is(error(e), ErrFirewallApplyNotCommitted) { + t.Fatalf("expected errors.Is(..., ErrFirewallApplyNotCommitted) to be true") + } + if errors.Unwrap(error(e)) != ErrFirewallApplyNotCommitted { + t.Fatalf("expected Unwrap to return ErrFirewallApplyNotCommitted") + } + + e2 := &FirewallApplyNotCommittedError{} + if e2.Error() != ErrFirewallApplyNotCommitted.Error() { + t.Fatalf("Error()=%q want %q", e2.Error(), ErrFirewallApplyNotCommitted.Error()) + } +} + +func TestFirewallRollbackHandle_Remaining(t *testing.T) { + var h *firewallRollbackHandle + if got := h.remaining(time.Now()); got != 0 { + t.Fatalf("nil handle remaining=%s want 0", got) + } + + handle := &firewallRollbackHandle{ + armedAt: time.Unix(100, 0), + timeout: 10 * time.Second, + } + if got := handle.remaining(time.Unix(105, 0)); got != 5*time.Second { + t.Fatalf("remaining=%s want %s", got, 5*time.Second) + } + if got := handle.remaining(time.Unix(999, 0)); got != 0 { + t.Fatalf("remaining=%s want 0", got) + } +} + +func TestBuildFirewallApplyNotCommittedError_PopulatesFields(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + if e := buildFirewallApplyNotCommittedError(nil); e == nil { + t.Fatalf("expected error struct, got nil") + } else if e.RollbackArmed || e.RollbackMarker != "" || e.RollbackLog != "" || !e.RollbackDeadline.IsZero() { + t.Fatalf("unexpected fields for nil handle: %#v", e) + } + + armedAt := time.Unix(10, 0) + handle := &firewallRollbackHandle{ + 
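TestFirewallRollbackHandle_Remaining above pins the semantics of remaining: a nil handle reports zero, otherwise the result is armedAt + timeout - now, floored at zero. A method body consistent with those assertions (a sketch; the actual implementation lives in restore_firewall.go):

func (h *firewallRollbackHandle) remaining(now time.Time) time.Duration {
	if h == nil {
		return 0
	}
	left := h.armedAt.Add(h.timeout).Sub(now)
	if left < 0 {
		return 0
	}
	return left
}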
markerPath: " /tmp/fw.marker \n", + logPath: " /tmp/fw.log\t", + armedAt: armedAt, + timeout: 3 * time.Second, + } + if err := fakeFS.AddFile("/tmp/fw.marker", []byte("pending\n")); err != nil { + t.Fatalf("add marker: %v", err) + } + + e := buildFirewallApplyNotCommittedError(handle) + if e.RollbackMarker != "/tmp/fw.marker" { + t.Fatalf("RollbackMarker=%q", e.RollbackMarker) + } + if e.RollbackLog != "/tmp/fw.log" { + t.Fatalf("RollbackLog=%q", e.RollbackLog) + } + if !e.RollbackArmed { + t.Fatalf("expected RollbackArmed=true") + } + if !e.RollbackDeadline.Equal(armedAt.Add(3 * time.Second)) { + t.Fatalf("RollbackDeadline=%s want %s", e.RollbackDeadline, armedAt.Add(3*time.Second)) + } +} + +func TestBuildFirewallRollbackScript_QuotesPaths(t *testing.T) { + script := buildFirewallRollbackScript("/tmp/marker path", "/tmp/backup's.tar.gz", "/tmp/log path") + if !strings.Contains(script, "MARKER='/tmp/marker path'") { + t.Fatalf("expected MARKER to be quoted, got script:\n%s", script) + } + if !strings.Contains(script, "LOG='/tmp/log path'") { + t.Fatalf("expected LOG to be quoted, got script:\n%s", script) + } + if !strings.Contains(script, "BACKUP='/tmp/backup'\\''s.tar.gz'") { + t.Fatalf("expected BACKUP to escape single quotes, got script:\n%s", script) + } + if !strings.HasSuffix(script, "\n") { + t.Fatalf("expected script to end with newline") + } +} + +func TestCopyFileExact_ReturnsFalseWhenSourceMissingOrDir(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + ok, err := copyFileExact("/missing", "/dest") + if err != nil || ok { + t.Fatalf("copyFileExact missing ok=%v err=%v want ok=false err=nil", ok, err) + } + + if err := fakeFS.AddDir("/srcdir"); err != nil { + t.Fatalf("add dir: %v", err) + } + ok, err = copyFileExact("/srcdir", "/dest") + if err != nil || ok { + t.Fatalf("copyFileExact dir ok=%v err=%v want ok=false err=nil", ok, err) + } +} + +func TestCopyFileExact_PropagatesAtomicWriteFailure(t *testing.T) { + origFS := restoreFS + origTime := restoreTime + t.Cleanup(func() { + restoreFS = origFS + restoreTime = origTime + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + fakeTime := &FakeTime{Current: time.Unix(0, 12345)} + restoreTime = fakeTime + + if err := fakeFS.AddFile("/src/file", []byte("data\n")); err != nil { + t.Fatalf("add src: %v", err) + } + + dest := "/dest/file" + tmpPath := dest + ".proxsave.tmp." 
+ strconv.FormatInt(fakeTime.Current.UnixNano(), 10) + fakeFS.OpenFileErr[filepath.Clean(tmpPath)] = fmt.Errorf("open tmp denied") + + _, err := copyFileExact("/src/file", dest) + if err == nil { + t.Fatalf("expected error") + } +} + +func TestSyncDirExact_CopiesSymlinks(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + if err := fakeFS.AddFile("/stage/target", []byte("x")); err != nil { + t.Fatalf("add target: %v", err) + } + if err := fakeFS.Symlink("/stage/target", "/stage/link"); err != nil { + t.Fatalf("add symlink: %v", err) + } + + applied, err := syncDirExact("/stage", "/dest") + if err != nil { + t.Fatalf("syncDirExact error: %v", err) + } + + destTarget, err := fakeFS.Readlink("/dest/link") + if err != nil { + t.Fatalf("read dest symlink: %v", err) + } + if strings.TrimSpace(destTarget) == "" { + t.Fatalf("expected non-empty symlink target") + } + found := false + for _, p := range applied { + if p == "/dest/link" { + found = true + break + } + } + if !found { + t.Fatalf("expected /dest/link to be reported as applied, got %#v", applied) + } +} + +func TestSelectStageHostFirewall_ErrorsOnReadDirFailure(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + base := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(base.Root) }) + restoreFS = readDirFailFS{FS: base, failPath: "/stage/etc/pve/nodes", err: fmt.Errorf("boom")} + + _, _, _, err := selectStageHostFirewall(newTestLogger(), "/stage") + if err == nil || !strings.Contains(err.Error(), "readdir") { + t.Fatalf("expected readdir error, got %v", err) + } +} + +func TestSelectStageHostFirewall_PicksCurrentNodeWhenPresent(t *testing.T) { + origFS := restoreFS + origHostname := firewallHostname + t.Cleanup(func() { + restoreFS = origFS + firewallHostname = origHostname + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + firewallHostname = func() (string, error) { return "node1.example", nil } + + stageRoot := "/stage" + if err := fakeFS.AddFile(stageRoot+"/etc/pve/nodes/node1/host.fw", []byte("a")); err != nil { + t.Fatalf("add node1 host.fw: %v", err) + } + if err := fakeFS.AddFile(stageRoot+"/etc/pve/nodes/other/host.fw", []byte("b")); err != nil { + t.Fatalf("add other host.fw: %v", err) + } + if err := fakeFS.AddFile(stageRoot+"/etc/pve/nodes/notadir", []byte("c")); err != nil { + t.Fatalf("add notadir: %v", err) + } + + path, sourceNode, ok, err := selectStageHostFirewall(newTestLogger(), stageRoot) + if err != nil { + t.Fatalf("selectStageHostFirewall error: %v", err) + } + if !ok { + t.Fatalf("expected ok=true") + } + if sourceNode != "node1" { + t.Fatalf("sourceNode=%q want %q", sourceNode, "node1") + } + if !strings.HasSuffix(path, "/stage/etc/pve/nodes/node1/host.fw") { + t.Fatalf("unexpected path: %q", path) + } +} + +func TestSelectStageHostFirewall_SkipsWhenMultipleCandidatesNoneMatches(t *testing.T) { + origFS := restoreFS + origHostname := firewallHostname + t.Cleanup(func() { + restoreFS = origFS + firewallHostname = origHostname + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + firewallHostname = func() (string, error) { return "current", nil } + + stageRoot := "/stage" + for _, node := range []string{"nodeA", "nodeB"} { + if err := fakeFS.AddFile(stageRoot+"/etc/pve/nodes/"+node+"/host.fw", []byte("x")); err != nil { + t.Fatalf("add 
host.fw for %s: %v", node, err) + } + } + + path, sourceNode, ok, err := selectStageHostFirewall(newTestLogger(), stageRoot) + if err != nil { + t.Fatalf("selectStageHostFirewall error: %v", err) + } + if ok || path != "" || sourceNode != "" { + t.Fatalf("expected skip, got ok=%v path=%q source=%q", ok, path, sourceNode) + } +} + +func TestRestartPVEFirewallService_CommandFallbacks(t *testing.T) { + origCmd := restoreCmd + t.Cleanup(func() { restoreCmd = origCmd }) + + binDir := t.TempDir() + oldPath := os.Getenv("PATH") + if err := os.Setenv("PATH", binDir); err != nil { + t.Fatalf("set PATH: %v", err) + } + t.Cleanup(func() { _ = os.Setenv("PATH", oldPath) }) + + writeExecutable(t, binDir, "systemctl") + writeExecutable(t, binDir, "pve-firewall") + + t.Run("try-restart ok", func(t *testing.T) { + fake := &FakeCommandRunner{} + restoreCmd = fake + + if err := restartPVEFirewallService(context.Background()); err != nil { + t.Fatalf("restartPVEFirewallService error: %v", err) + } + if got := fake.CallsList(); len(got) != 1 || got[0] != "systemctl try-restart pve-firewall" { + t.Fatalf("unexpected calls: %#v", got) + } + }) + + t.Run("restart ok", func(t *testing.T) { + fake := &FakeCommandRunner{ + Errors: map[string]error{ + "systemctl try-restart pve-firewall": fmt.Errorf("fail"), + }, + } + restoreCmd = fake + + if err := restartPVEFirewallService(context.Background()); err != nil { + t.Fatalf("restartPVEFirewallService error: %v", err) + } + if got := fake.CallsList(); len(got) != 2 || got[0] != "systemctl try-restart pve-firewall" || got[1] != "systemctl restart pve-firewall" { + t.Fatalf("unexpected calls: %#v", got) + } + }) + + t.Run("fallback to pve-firewall", func(t *testing.T) { + fake := &FakeCommandRunner{ + Errors: map[string]error{ + "systemctl try-restart pve-firewall": fmt.Errorf("fail"), + "systemctl restart pve-firewall": fmt.Errorf("fail"), + }, + } + restoreCmd = fake + + if err := restartPVEFirewallService(context.Background()); err != nil { + t.Fatalf("restartPVEFirewallService error: %v", err) + } + calls := fake.CallsList() + if len(calls) != 3 { + t.Fatalf("unexpected calls: %#v", calls) + } + if calls[2] != "pve-firewall restart" { + t.Fatalf("expected fallback call, got %#v", calls) + } + }) + + t.Run("no commands available", func(t *testing.T) { + fake := &FakeCommandRunner{} + restoreCmd = fake + + emptyBin := t.TempDir() + if err := os.Setenv("PATH", emptyBin); err != nil { + t.Fatalf("set PATH: %v", err) + } + + if err := restartPVEFirewallService(context.Background()); err == nil { + t.Fatalf("expected error") + } + }) +} + +func TestArmFirewallRollback_SystemdAndBackgroundPaths(t *testing.T) { + origFS := restoreFS + origCmd := restoreCmd + origTime := restoreTime + t.Cleanup(func() { + restoreFS = origFS + restoreCmd = origCmd + restoreTime = origTime + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + fakeTime := &FakeTime{Current: time.Date(2020, 1, 2, 3, 4, 5, 0, time.UTC)} + restoreTime = fakeTime + + logger := logging.New(logging.GetDefaultLogger().GetLevel(), false) + + t.Run("rejects invalid args", func(t *testing.T) { + if _, err := armFirewallRollback(context.Background(), logger, "", 1*time.Second, "/tmp/proxsave"); err == nil { + t.Fatalf("expected error for empty backupPath") + } + if _, err := armFirewallRollback(context.Background(), logger, "/backup.tgz", 0, "/tmp/proxsave"); err == nil { + t.Fatalf("expected error for invalid timeout") + } + }) + + t.Run("uses systemd-run when 
available", func(t *testing.T) { + binDir := t.TempDir() + oldPath := os.Getenv("PATH") + if err := os.Setenv("PATH", binDir); err != nil { + t.Fatalf("set PATH: %v", err) + } + t.Cleanup(func() { _ = os.Setenv("PATH", oldPath) }) + writeExecutable(t, binDir, "systemd-run") + + fakeCmd := &FakeCommandRunner{} + restoreCmd = fakeCmd + + handle, err := armFirewallRollback(context.Background(), logger, "/backup.tgz", 2*time.Second, "/tmp/proxsave") + if err != nil { + t.Fatalf("armFirewallRollback error: %v", err) + } + if handle == nil || handle.unitName == "" { + t.Fatalf("expected systemd unit name, got %#v", handle) + } + if got := fakeCmd.CallsList(); len(got) != 1 || !strings.HasPrefix(got[0], "systemd-run --unit=proxsave-firewall-rollback-20200102_030405") { + t.Fatalf("unexpected calls: %#v", got) + } + if data, err := fakeFS.ReadFile(handle.markerPath); err != nil || string(data) != "pending\n" { + t.Fatalf("marker read err=%v data=%q", err, string(data)) + } + }) + + t.Run("falls back to background timer on systemd-run failure", func(t *testing.T) { + binDir := t.TempDir() + oldPath := os.Getenv("PATH") + if err := os.Setenv("PATH", binDir); err != nil { + t.Fatalf("set PATH: %v", err) + } + t.Cleanup(func() { _ = os.Setenv("PATH", oldPath) }) + writeExecutable(t, binDir, "systemd-run") + + fakeCmd := &FakeCommandRunner{ + Errors: map[string]error{}, + } + restoreCmd = fakeCmd + + timestamp := fakeTime.Current.Format("20060102_150405") + scriptPath := filepath.Join("/tmp/proxsave", fmt.Sprintf("firewall_rollback_%s.sh", timestamp)) + systemdKey := "systemd-run --unit=proxsave-firewall-rollback-" + timestamp + " --on-active=2s /bin/sh " + scriptPath + fakeCmd.Errors[systemdKey] = fmt.Errorf("fail") + + handle, err := armFirewallRollback(context.Background(), logger, "/backup.tgz", 2*time.Second, "/tmp/proxsave") + if err != nil { + t.Fatalf("armFirewallRollback error: %v", err) + } + if handle == nil { + t.Fatalf("expected handle") + } + if handle.unitName != "" { + t.Fatalf("expected unitName cleared after systemd-run failure, got %q", handle.unitName) + } + + cmd := fmt.Sprintf("nohup sh -c 'sleep %d; /bin/sh %s' >/dev/null 2>&1 &", 2, scriptPath) + wantBackground := "sh -c " + cmd + calls := fakeCmd.CallsList() + if len(calls) != 2 || calls[1] != wantBackground { + t.Fatalf("unexpected calls: %#v", calls) + } + }) + + t.Run("background timer failure returns error", func(t *testing.T) { + emptyBin := t.TempDir() + oldPath := os.Getenv("PATH") + if err := os.Setenv("PATH", emptyBin); err != nil { + t.Fatalf("set PATH: %v", err) + } + t.Cleanup(func() { _ = os.Setenv("PATH", oldPath) }) + + fakeCmd := &FakeCommandRunner{ + Errors: map[string]error{}, + } + restoreCmd = fakeCmd + + timestamp := fakeTime.Current.Format("20060102_150405") + scriptPath := filepath.Join("/tmp/proxsave", fmt.Sprintf("firewall_rollback_%s.sh", timestamp)) + cmd := fmt.Sprintf("nohup sh -c 'sleep %d; /bin/sh %s' >/dev/null 2>&1 &", 1, scriptPath) + backgroundKey := "sh -c " + cmd + fakeCmd.Errors[backgroundKey] = fmt.Errorf("boom") + + if _, err := armFirewallRollback(context.Background(), logger, "/backup.tgz", 1*time.Second, "/tmp/proxsave"); err == nil { + t.Fatalf("expected error") + } + }) +} + +func TestDisarmFirewallRollback_RemovesMarkerAndStopsTimer(t *testing.T) { + origFS := restoreFS + origCmd := restoreCmd + t.Cleanup(func() { + restoreFS = origFS + restoreCmd = origCmd + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + binDir := 
t.TempDir() + oldPath := os.Getenv("PATH") + if err := os.Setenv("PATH", binDir); err != nil { + t.Fatalf("set PATH: %v", err) + } + t.Cleanup(func() { _ = os.Setenv("PATH", oldPath) }) + writeExecutable(t, binDir, "systemctl") + + fakeCmd := &FakeCommandRunner{} + restoreCmd = fakeCmd + + handle := &firewallRollbackHandle{ + markerPath: "/tmp/proxsave/fw.marker", + unitName: "proxsave-firewall-rollback-test", + } + if err := fakeFS.AddFile(handle.markerPath, []byte("pending\n")); err != nil { + t.Fatalf("add marker: %v", err) + } + + disarmFirewallRollback(context.Background(), newTestLogger(), handle) + + if _, err := fakeFS.Stat(handle.markerPath); err == nil || !os.IsNotExist(err) { + t.Fatalf("expected marker removed; stat err=%v", err) + } + + timerUnit := handle.unitName + ".timer" + want1 := "systemctl stop " + timerUnit + want2 := "systemctl reset-failed " + handle.unitName + ".service " + timerUnit + calls := fakeCmd.CallsList() + if len(calls) != 2 || calls[0] != want1 || calls[1] != want2 { + t.Fatalf("unexpected calls: %#v", calls) + } +} + +func TestMaybeApplyPVEFirewallWithUI_CoversUserFlows(t *testing.T) { + origFS := restoreFS + origCmd := restoreCmd + origTime := restoreTime + origGeteuid := firewallApplyGeteuid + origMounted := firewallIsMounted + origRealFS := firewallIsRealRestoreFS + origArm := firewallArmRollback + origDisarm := firewallDisarmRollback + origApply := firewallApplyFromStage + origRestart := firewallRestartService + t.Cleanup(func() { + restoreFS = origFS + restoreCmd = origCmd + restoreTime = origTime + firewallApplyGeteuid = origGeteuid + firewallIsMounted = origMounted + firewallIsRealRestoreFS = origRealFS + firewallArmRollback = origArm + firewallDisarmRollback = origDisarm + firewallApplyFromStage = origApply + firewallRestartService = origRestart + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + restoreTime = &FakeTime{Current: time.Unix(100, 0)} + restoreCmd = &FakeCommandRunner{} + + firewallIsRealRestoreFS = func(fs FS) bool { return true } + firewallApplyGeteuid = func() int { return 0 } + firewallIsMounted = func(path string) (bool, error) { return true, nil } + firewallRestartService = func(ctx context.Context) error { return nil } + + plan := &RestorePlan{ + SystemType: SystemTypePVE, + NormalCategories: []Category{{ID: "pve_firewall"}}, + } + stageRoot := "/stage" + logger := newTestLogger() + + t.Run("errors when ui missing", func(t *testing.T) { + err := maybeApplyPVEFirewallWithUI(context.Background(), nil, logger, plan, nil, nil, stageRoot, false) + if err == nil { + t.Fatalf("expected error") + } + }) + + t.Run("skips when /etc/pve not mounted", func(t *testing.T) { + firewallIsMounted = func(path string) (bool, error) { return false, nil } + t.Cleanup(func() { firewallIsMounted = func(path string) (bool, error) { return true, nil } }) + + ui := &scriptedRestoreWorkflowUI{fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, script: nil} + err := maybeApplyPVEFirewallWithUI(context.Background(), ui, logger, plan, nil, nil, stageRoot, false) + if err != nil { + t.Fatalf("expected nil, got %v", err) + } + }) + + t.Run("skips when stage has no firewall data", func(t *testing.T) { + ui := &scriptedRestoreWorkflowUI{fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, script: nil} + err := maybeApplyPVEFirewallWithUI(context.Background(), ui, logger, plan, nil, nil, stageRoot, false) + if err != nil { + t.Fatalf("expected nil, got %v", err) + } + }) + + t.Run("user skips apply", func(t 
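Disarm order matters here: the pending marker is removed first, so a rollback script that still fires finds nothing to do, and only then is the transient timer stopped and its unit state cleared. For a unit name <u>, the test expects exactly these systemctl calls, in order:

	systemctl stop <u>.timer
	systemctl reset-failed <u>.service <u>.timer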
*testing.T) { + if err := fakeFS.AddDir(stageRoot + "/etc/pve/nodes"); err != nil { + t.Fatalf("add stage nodes: %v", err) + } + ui := &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{{ok: false}}, + } + err := maybeApplyPVEFirewallWithUI(context.Background(), ui, logger, plan, nil, nil, stageRoot, false) + if err != nil { + t.Fatalf("expected nil, got %v", err) + } + }) + + t.Run("missing rollback backup declines full rollback", func(t *testing.T) { + firewallArmRollback = func(ctx context.Context, logger *logging.Logger, backupPath string, timeout time.Duration, workDir string) (*firewallRollbackHandle, error) { + t.Fatalf("unexpected rollback arm") + return nil, nil + } + + ui := &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{ + {ok: true}, // Apply now + {ok: false}, // Skip full rollback + }, + } + safety := &SafetyBackupResult{BackupPath: "/backups/full.tgz"} + err := maybeApplyPVEFirewallWithUI(context.Background(), ui, logger, plan, safety, nil, stageRoot, false) + if err != nil { + t.Fatalf("expected nil, got %v", err) + } + }) + + t.Run("full rollback accepted but no changes applied disarms rollback", func(t *testing.T) { + markerPath := "/tmp/proxsave/fw.marker" + firewallArmRollback = func(ctx context.Context, logger *logging.Logger, backupPath string, timeout time.Duration, workDir string) (*firewallRollbackHandle, error) { + handle := &firewallRollbackHandle{ + markerPath: markerPath, + logPath: "/tmp/proxsave/fw.log", + armedAt: nowRestore(), + timeout: timeout, + } + _ = restoreFS.WriteFile(markerPath, []byte("pending\n"), 0o640) + return handle, nil + } + firewallApplyFromStage = func(logger *logging.Logger, stageRoot string) ([]string, error) { return nil, nil } + + ui := &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{ + {ok: true}, // Apply now + {ok: true}, // Proceed with full rollback + }, + } + safety := &SafetyBackupResult{BackupPath: "/backups/full.tgz"} + err := maybeApplyPVEFirewallWithUI(context.Background(), ui, logger, plan, safety, nil, stageRoot, false) + if err != nil { + t.Fatalf("expected nil, got %v", err) + } + if _, err := fakeFS.Stat(markerPath); err == nil || !os.IsNotExist(err) { + t.Fatalf("expected rollback marker removed; stat err=%v", err) + } + }) + + t.Run("proceed without rollback applies and returns without commit prompt", func(t *testing.T) { + firewallArmRollback = func(ctx context.Context, logger *logging.Logger, backupPath string, timeout time.Duration, workDir string) (*firewallRollbackHandle, error) { + t.Fatalf("unexpected rollback arm") + return nil, nil + } + firewallApplyFromStage = func(logger *logging.Logger, stageRoot string) ([]string, error) { + return []string{"/etc/pve/firewall/cluster.fw"}, nil + } + + ui := &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{ + {ok: true}, // Apply now + {ok: true}, // Proceed without rollback + }, + } + err := maybeApplyPVEFirewallWithUI(context.Background(), ui, logger, plan, nil, nil, stageRoot, false) + if err != nil { + t.Fatalf("expected nil, got %v", err) + } + }) + + t.Run("commit keeps changes and disarms rollback", func(t *testing.T) { + markerPath := "/tmp/proxsave/fw.marker" + firewallArmRollback = func(ctx context.Context, logger *logging.Logger, backupPath string, timeout time.Duration, workDir string) 
(*firewallRollbackHandle, error) { + handle := &firewallRollbackHandle{ + markerPath: markerPath, + logPath: "/tmp/proxsave/fw.log", + armedAt: nowRestore(), + timeout: timeout, + } + _ = restoreFS.WriteFile(markerPath, []byte("pending\n"), 0o640) + return handle, nil + } + firewallApplyFromStage = func(logger *logging.Logger, stageRoot string) ([]string, error) { + return []string{"/etc/pve/firewall/cluster.fw"}, nil + } + + ui := &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{ + {ok: true}, // Apply now + {ok: true}, // Commit + }, + } + rollback := &SafetyBackupResult{BackupPath: "/backups/firewall.tgz"} + err := maybeApplyPVEFirewallWithUI(context.Background(), ui, logger, plan, nil, rollback, stageRoot, false) + if err != nil { + t.Fatalf("expected nil, got %v", err) + } + if _, err := fakeFS.Stat(markerPath); err == nil || !os.IsNotExist(err) { + t.Fatalf("expected rollback marker removed; stat err=%v", err) + } + }) + + t.Run("rollback requested returns typed error with marker armed", func(t *testing.T) { + markerPath := "/tmp/proxsave/fw.marker" + armedAt := nowRestore() + firewallArmRollback = func(ctx context.Context, logger *logging.Logger, backupPath string, timeout time.Duration, workDir string) (*firewallRollbackHandle, error) { + handle := &firewallRollbackHandle{ + markerPath: markerPath, + logPath: "/tmp/proxsave/fw.log", + armedAt: armedAt, + timeout: timeout, + } + _ = restoreFS.WriteFile(markerPath, []byte("pending\n"), 0o640) + return handle, nil + } + firewallApplyFromStage = func(logger *logging.Logger, stageRoot string) ([]string, error) { + return []string{"/etc/pve/firewall/cluster.fw"}, nil + } + + ui := &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{ + {ok: true}, // Apply now + {ok: false}, // Rollback + }, + } + rollback := &SafetyBackupResult{BackupPath: "/backups/firewall.tgz"} + err := maybeApplyPVEFirewallWithUI(context.Background(), ui, logger, plan, nil, rollback, stageRoot, false) + if err == nil { + t.Fatalf("expected error") + } + if !errors.Is(err, ErrFirewallApplyNotCommitted) { + t.Fatalf("expected ErrFirewallApplyNotCommitted, got %v", err) + } + var typed *FirewallApplyNotCommittedError + if !errors.As(err, &typed) || typed == nil { + t.Fatalf("expected FirewallApplyNotCommittedError, got %T", err) + } + if !typed.RollbackArmed || typed.RollbackMarker != markerPath || typed.RollbackLog != "/tmp/proxsave/fw.log" { + t.Fatalf("unexpected error fields: %#v", typed) + } + if typed.RollbackDeadline.IsZero() || !typed.RollbackDeadline.Equal(armedAt.Add(defaultFirewallRollbackTimeout)) { + t.Fatalf("unexpected RollbackDeadline=%s", typed.RollbackDeadline) + } + if _, err := fakeFS.Stat(markerPath); err != nil { + t.Fatalf("expected marker to still exist, stat err=%v", err) + } + }) + + t.Run("commit prompt abort returns abort error", func(t *testing.T) { + markerPath := "/tmp/proxsave/fw.marker" + firewallArmRollback = func(ctx context.Context, logger *logging.Logger, backupPath string, timeout time.Duration, workDir string) (*firewallRollbackHandle, error) { + handle := &firewallRollbackHandle{ + markerPath: markerPath, + logPath: "/tmp/proxsave/fw.log", + armedAt: nowRestore(), + timeout: timeout, + } + _ = restoreFS.WriteFile(markerPath, []byte("pending\n"), 0o640) + return handle, nil + } + firewallApplyFromStage = func(logger *logging.Logger, stageRoot string) ([]string, error) { + return 
[]string{"/etc/pve/firewall/cluster.fw"}, nil + } + + ui := &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{ + {ok: true}, // Apply now + {err: input.ErrInputAborted}, // Abort at commit prompt + }, + } + rollback := &SafetyBackupResult{BackupPath: "/backups/firewall.tgz"} + err := maybeApplyPVEFirewallWithUI(context.Background(), ui, logger, plan, nil, rollback, stageRoot, false) + if err == nil || !errors.Is(err, input.ErrInputAborted) { + t.Fatalf("expected abort error, got %v", err) + } + }) + + t.Run("commit prompt failure returns typed error", func(t *testing.T) { + markerPath := "/tmp/proxsave/fw.marker" + firewallArmRollback = func(ctx context.Context, logger *logging.Logger, backupPath string, timeout time.Duration, workDir string) (*firewallRollbackHandle, error) { + handle := &firewallRollbackHandle{ + markerPath: markerPath, + logPath: "/tmp/proxsave/fw.log", + armedAt: nowRestore(), + timeout: timeout, + } + _ = restoreFS.WriteFile(markerPath, []byte("pending\n"), 0o640) + return handle, nil + } + firewallApplyFromStage = func(logger *logging.Logger, stageRoot string) ([]string, error) { + return []string{"/etc/pve/firewall/cluster.fw"}, nil + } + + ui := &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{ + {ok: true}, // Apply now + {err: fmt.Errorf("boom")}, // Commit prompt fails + }, + } + rollback := &SafetyBackupResult{BackupPath: "/backups/firewall.tgz"} + err := maybeApplyPVEFirewallWithUI(context.Background(), ui, logger, plan, nil, rollback, stageRoot, false) + if err == nil { + t.Fatalf("expected error") + } + if !errors.Is(err, ErrFirewallApplyNotCommitted) { + t.Fatalf("expected ErrFirewallApplyNotCommitted, got %v", err) + } + if _, err := fakeFS.Stat(markerPath); err != nil { + t.Fatalf("expected marker to still exist, stat err=%v", err) + } + }) + + t.Run("remaining timeout returns typed error without commit prompt", func(t *testing.T) { + markerPath := "/tmp/proxsave/fw.marker" + firewallArmRollback = func(ctx context.Context, logger *logging.Logger, backupPath string, timeout time.Duration, workDir string) (*firewallRollbackHandle, error) { + handle := &firewallRollbackHandle{ + markerPath: markerPath, + logPath: "/tmp/proxsave/fw.log", + armedAt: nowRestore().Add(-timeout - time.Second), + timeout: timeout, + } + _ = restoreFS.WriteFile(markerPath, []byte("pending\n"), 0o640) + return handle, nil + } + firewallApplyFromStage = func(logger *logging.Logger, stageRoot string) ([]string, error) { + return []string{"/etc/pve/firewall/cluster.fw"}, nil + } + + ui := &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{{ok: true}}, // Apply now only + } + rollback := &SafetyBackupResult{BackupPath: "/backups/firewall.tgz"} + err := maybeApplyPVEFirewallWithUI(context.Background(), ui, logger, plan, nil, rollback, stageRoot, false) + if err == nil || !errors.Is(err, ErrFirewallApplyNotCommitted) { + t.Fatalf("expected ErrFirewallApplyNotCommitted, got %v", err) + } + }) + + t.Run("dry run skips apply", func(t *testing.T) { + ui := &scriptedRestoreWorkflowUI{fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, script: nil} + err := maybeApplyPVEFirewallWithUI(context.Background(), ui, logger, plan, nil, nil, stageRoot, true) + if err != nil { + t.Fatalf("expected nil, got %v", err) + } + }) + + t.Run("non-root skips apply", func(t *testing.T) { + firewallApplyGeteuid = func() int { 
return 1000 } + t.Cleanup(func() { firewallApplyGeteuid = func() int { return 0 } }) + + ui := &scriptedRestoreWorkflowUI{fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, script: nil} + err := maybeApplyPVEFirewallWithUI(context.Background(), ui, logger, plan, nil, nil, stageRoot, false) + if err != nil { + t.Fatalf("expected nil, got %v", err) + } + }) +} + +func TestMaybeApplyPVEFirewallWithUI_AdditionalBranches(t *testing.T) { + origFS := restoreFS + origTime := restoreTime + origCmd := restoreCmd + origGeteuid := firewallApplyGeteuid + origMounted := firewallIsMounted + origRealFS := firewallIsRealRestoreFS + origArm := firewallArmRollback + origDisarm := firewallDisarmRollback + origApply := firewallApplyFromStage + origRestart := firewallRestartService + t.Cleanup(func() { + restoreFS = origFS + restoreTime = origTime + restoreCmd = origCmd + firewallApplyGeteuid = origGeteuid + firewallIsMounted = origMounted + firewallIsRealRestoreFS = origRealFS + firewallArmRollback = origArm + firewallDisarmRollback = origDisarm + firewallApplyFromStage = origApply + firewallRestartService = origRestart + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + restoreTime = &FakeTime{Current: time.Unix(100, 0)} + restoreCmd = &FakeCommandRunner{} + + firewallIsRealRestoreFS = func(fs FS) bool { return true } + firewallApplyGeteuid = func() int { return 0 } + firewallIsMounted = func(path string) (bool, error) { return true, nil } + firewallRestartService = func(ctx context.Context) error { return nil } + + logger := newTestLogger() + plan := &RestorePlan{SystemType: SystemTypePVE, NormalCategories: []Category{{ID: "pve_firewall"}}} + + t.Run("plan nil returns nil", func(t *testing.T) { + err := maybeApplyPVEFirewallWithUI(context.Background(), nil, logger, nil, nil, nil, "/stage", false) + if err != nil { + t.Fatalf("expected nil, got %v", err) + } + }) + + t.Run("wrong system type returns nil", func(t *testing.T) { + p := &RestorePlan{SystemType: SystemTypePBS, NormalCategories: []Category{{ID: "pve_firewall"}}} + err := maybeApplyPVEFirewallWithUI(context.Background(), nil, logger, p, nil, nil, "/stage", false) + if err != nil { + t.Fatalf("expected nil, got %v", err) + } + }) + + t.Run("missing category returns nil", func(t *testing.T) { + p := &RestorePlan{SystemType: SystemTypePVE, NormalCategories: []Category{{ID: "network"}}} + err := maybeApplyPVEFirewallWithUI(context.Background(), nil, logger, p, nil, nil, "/stage", false) + if err != nil { + t.Fatalf("expected nil, got %v", err) + } + }) + + t.Run("non-real filesystem skips apply", func(t *testing.T) { + firewallIsRealRestoreFS = func(fs FS) bool { return false } + t.Cleanup(func() { firewallIsRealRestoreFS = func(fs FS) bool { return true } }) + + ui := &scriptedRestoreWorkflowUI{fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, script: nil} + err := maybeApplyPVEFirewallWithUI(context.Background(), ui, logger, plan, nil, nil, "/stage", false) + if err != nil { + t.Fatalf("expected nil, got %v", err) + } + }) + + t.Run("blank stage root skips apply", func(t *testing.T) { + ui := &scriptedRestoreWorkflowUI{fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, script: nil} + err := maybeApplyPVEFirewallWithUI(context.Background(), ui, logger, plan, nil, nil, " ", false) + if err != nil { + t.Fatalf("expected nil, got %v", err) + } + }) + + t.Run("cluster restore skips apply", func(t *testing.T) { + p := &RestorePlan{SystemType: SystemTypePVE, NeedsClusterRestore: true, NormalCategories: 
[]Category{{ID: "pve_firewall"}}} + ui := &scriptedRestoreWorkflowUI{fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, script: nil} + err := maybeApplyPVEFirewallWithUI(context.Background(), ui, logger, p, nil, nil, "/stage", false) + if err != nil { + t.Fatalf("expected nil, got %v", err) + } + }) + + t.Run("mount check warning path", func(t *testing.T) { + firewallIsMounted = func(path string) (bool, error) { return true, fmt.Errorf("boom") } + t.Cleanup(func() { firewallIsMounted = func(path string) (bool, error) { return true, nil } }) + + ui := &scriptedRestoreWorkflowUI{fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, script: nil} + err := maybeApplyPVEFirewallWithUI(context.Background(), ui, logger, plan, nil, nil, "/stage", false) + if err != nil { + t.Fatalf("expected nil, got %v", err) + } + }) + + t.Run("apply prompt error returns error", func(t *testing.T) { + if err := fakeFS.AddDir("/stage/etc/pve/nodes"); err != nil { + t.Fatalf("add stage nodes: %v", err) + } + + ui := &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{{err: fmt.Errorf("input fail")}}, + } + err := maybeApplyPVEFirewallWithUI(context.Background(), ui, logger, plan, nil, nil, "/stage", false) + if err == nil { + t.Fatalf("expected error") + } + }) + + t.Run("no rollback declined returns nil", func(t *testing.T) { + ui := &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{ + {ok: true}, // Apply now + {ok: false}, // Skip apply (no rollback) + }, + } + err := maybeApplyPVEFirewallWithUI(context.Background(), ui, logger, plan, nil, nil, "/stage", false) + if err != nil { + t.Fatalf("expected nil, got %v", err) + } + }) + + t.Run("no rollback prompt error returns error", func(t *testing.T) { + ui := &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{ + {ok: true}, // Apply now + {err: fmt.Errorf("boom")}, // No rollback prompt fails + }, + } + err := maybeApplyPVEFirewallWithUI(context.Background(), ui, logger, plan, nil, nil, "/stage", false) + if err == nil { + t.Fatalf("expected error") + } + }) + + t.Run("full rollback prompt error returns error", func(t *testing.T) { + ui := &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{ + {ok: true}, // Apply now + {err: fmt.Errorf("boom")}, // Full rollback prompt fails + }, + } + safety := &SafetyBackupResult{BackupPath: "/backups/full.tgz"} + err := maybeApplyPVEFirewallWithUI(context.Background(), ui, logger, plan, safety, nil, "/stage", false) + if err == nil { + t.Fatalf("expected error") + } + }) + + t.Run("arm rollback error returns wrapped error", func(t *testing.T) { + firewallArmRollback = func(ctx context.Context, logger *logging.Logger, backupPath string, timeout time.Duration, workDir string) (*firewallRollbackHandle, error) { + return nil, fmt.Errorf("arm failed") + } + firewallApplyFromStage = func(logger *logging.Logger, stageRoot string) ([]string, error) { + t.Fatalf("unexpected apply") + return nil, nil + } + + ui := &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{{ok: true}}, + } + rollback := &SafetyBackupResult{BackupPath: "/backups/firewall.tgz"} + err := maybeApplyPVEFirewallWithUI(context.Background(), ui, logger, plan, nil, rollback, "/stage", false) + if err == nil || !strings.Contains(err.Error(), "arm firewall rollback") { + 
t.Fatalf("expected wrapped arm error, got %v", err) + } + }) + + t.Run("apply from stage error returns error", func(t *testing.T) { + firewallApplyFromStage = func(logger *logging.Logger, stageRoot string) ([]string, error) { + return nil, fmt.Errorf("apply failed") + } + ui := &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{ + {ok: true}, // Apply now + {ok: true}, // Proceed without rollback + }, + } + err := maybeApplyPVEFirewallWithUI(context.Background(), ui, logger, plan, nil, nil, "/stage", false) + if err == nil || !strings.Contains(err.Error(), "apply failed") { + t.Fatalf("expected apply error, got %v", err) + } + }) + + t.Run("restart failure logs warning but continues", func(t *testing.T) { + markerPath := "/tmp/proxsave/fw.marker" + firewallArmRollback = func(ctx context.Context, logger *logging.Logger, backupPath string, timeout time.Duration, workDir string) (*firewallRollbackHandle, error) { + handle := &firewallRollbackHandle{ + markerPath: markerPath, + logPath: "/tmp/proxsave/fw.log", + armedAt: nowRestore(), + timeout: timeout, + } + _ = restoreFS.WriteFile(markerPath, []byte("pending\n"), 0o640) + return handle, nil + } + firewallApplyFromStage = func(logger *logging.Logger, stageRoot string) ([]string, error) { + return []string{"/etc/pve/firewall/cluster.fw"}, nil + } + firewallRestartService = func(ctx context.Context) error { return fmt.Errorf("restart failed") } + + ui := &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{ + {ok: true}, // Apply now + {ok: true}, // Commit + }, + } + rollback := &SafetyBackupResult{BackupPath: "/backups/firewall.tgz"} + err := maybeApplyPVEFirewallWithUI(context.Background(), ui, logger, plan, nil, rollback, "/stage", false) + if err != nil { + t.Fatalf("expected nil, got %v", err) + } + if _, err := fakeFS.Stat(markerPath); err == nil || !os.IsNotExist(err) { + t.Fatalf("expected rollback marker removed; stat err=%v", err) + } + }) + + t.Run("commit context canceled returns canceled error", func(t *testing.T) { + markerPath := "/tmp/proxsave/fw.marker" + firewallArmRollback = func(ctx context.Context, logger *logging.Logger, backupPath string, timeout time.Duration, workDir string) (*firewallRollbackHandle, error) { + handle := &firewallRollbackHandle{ + markerPath: markerPath, + logPath: "/tmp/proxsave/fw.log", + armedAt: nowRestore(), + timeout: timeout, + } + _ = restoreFS.WriteFile(markerPath, []byte("pending\n"), 0o640) + return handle, nil + } + firewallApplyFromStage = func(logger *logging.Logger, stageRoot string) ([]string, error) { + return []string{"/etc/pve/firewall/cluster.fw"}, nil + } + + ui := &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{ + {ok: true}, // Apply now + {err: context.Canceled}, // Commit prompt canceled + }, + } + rollback := &SafetyBackupResult{BackupPath: "/backups/firewall.tgz"} + err := maybeApplyPVEFirewallWithUI(context.Background(), ui, logger, plan, nil, rollback, "/stage", false) + if err == nil || !errors.Is(err, context.Canceled) { + t.Fatalf("expected context canceled error, got %v", err) + } + }) +} + +func TestApplyPVEFirewallFromStage_CoversAdditionalPaths(t *testing.T) { + origFS := restoreFS + origHostname := firewallHostname + t.Cleanup(func() { + restoreFS = origFS + firewallHostname = origHostname + }) + + t.Run("blank stage root", func(t *testing.T) { + fakeFS := NewFakeFS() + 
t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + applied, err := applyPVEFirewallFromStage(newTestLogger(), " ") + if err != nil || len(applied) != 0 { + t.Fatalf("expected nil, got applied=%#v err=%v", applied, err) + } + }) + + t.Run("staged firewall as file", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + stageRoot := "/stage" + if err := fakeFS.AddFile(stageRoot+"/etc/pve/firewall", []byte("fw\n")); err != nil { + t.Fatalf("add staged firewall file: %v", err) + } + + applied, err := applyPVEFirewallFromStage(newTestLogger(), stageRoot) + if err != nil { + t.Fatalf("applyPVEFirewallFromStage error: %v", err) + } + found := false + for _, p := range applied { + if p == "/etc/pve/firewall" { + found = true + break + } + } + if !found { + t.Fatalf("expected /etc/pve/firewall in applied paths, got %#v", applied) + } + if got, err := fakeFS.ReadFile("/etc/pve/firewall"); err != nil || string(got) != "fw\n" { + t.Fatalf("dest firewall err=%v data=%q", err, string(got)) + } + }) + + t.Run("firewall stat error returns error", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + stageRoot := "/stage" + stageFirewall := filepath.Join(stageRoot, "etc", "pve", "firewall") + fakeFS.StatErrors[filepath.Clean(stageFirewall)] = fmt.Errorf("boom") + + _, err := applyPVEFirewallFromStage(newTestLogger(), stageRoot) + if err == nil || !strings.Contains(err.Error(), "stat staged firewall config") { + t.Fatalf("expected stat staged firewall config error, got %v", err) + } + }) + + t.Run("host fw selection error bubbles", func(t *testing.T) { + base := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(base.Root) }) + restoreFS = readDirFailFS{FS: base, failPath: "/stage/etc/pve/nodes", err: fmt.Errorf("boom")} + + _, err := applyPVEFirewallFromStage(newTestLogger(), "/stage") + if err == nil || !strings.Contains(err.Error(), "readdir") { + t.Fatalf("expected readdir error, got %v", err) + } + }) + + t.Run("defaults to localhost when hostname empty", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + firewallHostname = func() (string, error) { return " ", nil } + + stageRoot := "/stage" + if err := fakeFS.AddFile(stageRoot+"/etc/pve/nodes/node1/host.fw", []byte("host\n")); err != nil { + t.Fatalf("add staged host.fw: %v", err) + } + + applied, err := applyPVEFirewallFromStage(newTestLogger(), stageRoot) + if err != nil { + t.Fatalf("applyPVEFirewallFromStage error: %v", err) + } + found := false + for _, p := range applied { + if p == "/etc/pve/nodes/localhost/host.fw" { + found = true + break + } + } + if !found { + t.Fatalf("expected mapped host.fw applied, got %#v", applied) + } + if got, err := fakeFS.ReadFile("/etc/pve/nodes/localhost/host.fw"); err != nil || string(got) != "host\n" { + t.Fatalf("dest host.fw err=%v data=%q", err, string(got)) + } + }) +} + +func TestSelectStageHostFirewall_EmptyCases(t *testing.T) { + origFS := restoreFS + origHostname := firewallHostname + t.Cleanup(func() { + restoreFS = origFS + firewallHostname = origHostname + }) + + firewallHostname = func() (string, error) { return "node1", nil } + + t.Run("nodes directory missing", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + path, node, ok, err := selectStageHostFirewall(newTestLogger(), "/stage") + if err 
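The localhost fallback above is the composition of the hostname seam, shortHost trimming, and a blank check. A sketch of the normalization, assuming shortHost drops everything after the first dot:

func currentNodeSketch() string {
	name, _ := firewallHostname() // test seam; os.Hostname in production
	if i := strings.IndexByte(name, '.'); i >= 0 {
		name = name[:i] // assumed shortHost behavior: strip the domain
	}
	if strings.TrimSpace(name) == "" {
		return "localhost"
	}
	return name
}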
!= nil || ok || path != "" || node != "" { + t.Fatalf("expected no selection, got ok=%v path=%q node=%q err=%v", ok, path, node, err) + } + }) + + t.Run("no host.fw candidates", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + if err := fakeFS.AddDir("/stage/etc/pve/nodes/node1"); err != nil { + t.Fatalf("add nodes dir: %v", err) + } + + path, node, ok, err := selectStageHostFirewall(newTestLogger(), "/stage") + if err != nil || ok || path != "" || node != "" { + t.Fatalf("expected no selection, got ok=%v path=%q node=%q err=%v", ok, path, node, err) + } + }) +} + +func TestArmFirewallRollback_DefaultWorkDirAndMinTimeout(t *testing.T) { + origFS := restoreFS + origCmd := restoreCmd + origTime := restoreTime + t.Cleanup(func() { + restoreFS = origFS + restoreCmd = origCmd + restoreTime = origTime + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + fakeTime := &FakeTime{Current: time.Date(2020, 1, 2, 3, 4, 5, 0, time.UTC)} + restoreTime = fakeTime + + emptyBin := t.TempDir() + oldPath := os.Getenv("PATH") + if err := os.Setenv("PATH", emptyBin); err != nil { + t.Fatalf("set PATH: %v", err) + } + t.Cleanup(func() { _ = os.Setenv("PATH", oldPath) }) + + fakeCmd := &FakeCommandRunner{} + restoreCmd = fakeCmd + + handle, err := armFirewallRollback(context.Background(), newTestLogger(), "/backup.tgz", 500*time.Millisecond, " ") + if err != nil { + t.Fatalf("armFirewallRollback error: %v", err) + } + if handle == nil || handle.workDir != "/tmp/proxsave" { + t.Fatalf("unexpected handle: %#v", handle) + } + if data, err := fakeFS.ReadFile(handle.markerPath); err != nil || string(data) != "pending\n" { + t.Fatalf("marker err=%v data=%q", err, string(data)) + } + calls := fakeCmd.CallsList() + if len(calls) != 1 { + t.Fatalf("unexpected calls: %#v", calls) + } + if !strings.Contains(calls[0], "sleep 1; /bin/sh") { + t.Fatalf("expected timeoutSeconds to clamp to 1, got call=%q", calls[0]) + } +} + +func TestArmFirewallRollback_ReturnsErrorOnMkdirAllFailure(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + fakeFS.MkdirAllErr = fmt.Errorf("disk full") + restoreFS = fakeFS + + if _, err := armFirewallRollback(context.Background(), newTestLogger(), "/backup.tgz", 1*time.Second, "/tmp/proxsave"); err == nil { + t.Fatalf("expected error") + } +} + +func TestArmFirewallRollback_ReturnsErrorOnMarkerWriteFailure(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + fakeFS.WriteErr = fmt.Errorf("disk full") + restoreFS = fakeFS + + if _, err := armFirewallRollback(context.Background(), newTestLogger(), "/backup.tgz", 1*time.Second, "/tmp/proxsave"); err == nil || !strings.Contains(err.Error(), "write rollback marker") { + t.Fatalf("expected write rollback marker error, got %v", err) + } +} + +func TestArmFirewallRollback_ReturnsErrorOnScriptWriteFailure(t *testing.T) { + origFS := restoreFS + origTime := restoreTime + t.Cleanup(func() { + restoreFS = origFS + restoreTime = origTime + }) + + base := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(base.Root) }) + + fakeTime := &FakeTime{Current: time.Date(2020, 1, 2, 3, 4, 5, 0, time.UTC)} + restoreTime = fakeTime + timestamp := fakeTime.Current.Format("20060102_150405") + scriptPath := 
filepath.Join("/tmp/proxsave", fmt.Sprintf("firewall_rollback_%s.sh", timestamp)) + + restoreFS = writeFileFailFS{FS: base, failPath: scriptPath, err: fmt.Errorf("disk full")} + + if _, err := armFirewallRollback(context.Background(), newTestLogger(), "/backup.tgz", 1*time.Second, "/tmp/proxsave"); err == nil || !strings.Contains(err.Error(), "write rollback script") { + t.Fatalf("expected write rollback script error, got %v", err) + } +} + +func TestDisarmFirewallRollback_MissingMarkerAndNoSystemctl(t *testing.T) { + origFS := restoreFS + origCmd := restoreCmd + t.Cleanup(func() { + restoreFS = origFS + restoreCmd = origCmd + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + emptyBin := t.TempDir() + oldPath := os.Getenv("PATH") + if err := os.Setenv("PATH", emptyBin); err != nil { + t.Fatalf("set PATH: %v", err) + } + t.Cleanup(func() { _ = os.Setenv("PATH", oldPath) }) + + fakeCmd := &FakeCommandRunner{} + restoreCmd = fakeCmd + + handle := &firewallRollbackHandle{ + markerPath: "/tmp/proxsave/missing.marker", + unitName: "unit", + } + disarmFirewallRollback(context.Background(), newTestLogger(), handle) + + if calls := fakeCmd.CallsList(); len(calls) != 0 { + t.Fatalf("expected no systemctl calls, got %#v", calls) + } +} + +func TestDisarmFirewallRollback_ContinuesOnMarkerRemoveError(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + base := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(base.Root) }) + if err := base.AddFile("/tmp/proxsave/fw.marker", []byte("pending\n")); err != nil { + t.Fatalf("add marker: %v", err) + } + restoreFS = removeFailFS{FS: base, failPath: "/tmp/proxsave/fw.marker", err: fmt.Errorf("perm")} + + handle := &firewallRollbackHandle{ + markerPath: "/tmp/proxsave/fw.marker", + } + disarmFirewallRollback(context.Background(), newTestLogger(), handle) +} + +func TestCopyFileExact_PropagatesStatAndReadFailures(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + base := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(base.Root) }) + if err := base.AddFile("/src/file", []byte("x")); err != nil { + t.Fatalf("add src: %v", err) + } + + restoreFS = statFailFS{FS: base, failPath: "/src/file", err: fmt.Errorf("boom")} + if _, err := copyFileExact("/src/file", "/dest/file"); err == nil { + t.Fatalf("expected stat error") + } + + restoreFS = readFileFailFS{FS: base, failPath: "/src/file", err: os.ErrNotExist} + ok, err := copyFileExact("/src/file", "/dest/file") + if err != nil || ok { + t.Fatalf("expected ok=false err=nil for readfile not exist, got ok=%v err=%v", ok, err) + } + + restoreFS = readFileFailFS{FS: base, failPath: "/src/file", err: fmt.Errorf("read boom")} + if _, err := copyFileExact("/src/file", "/dest/file"); err == nil { + t.Fatalf("expected read error") + } +} + +func TestSyncDirExact_CoversErrorPathsAndPrune(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + t.Run("source missing or not a dir returns nil", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + if applied, err := syncDirExact("/missing", "/dest"); err != nil || len(applied) != 0 { + t.Fatalf("expected nil, got applied=%#v err=%v", applied, err) + } + if err := fakeFS.AddFile("/stagefile", []byte("x")); err != nil { + t.Fatalf("add stagefile: %v", err) + } + if applied, err := syncDirExact("/stagefile", "/dest"); err != nil || len(applied) != 0 { + 
t.Fatalf("expected nil, got applied=%#v err=%v", applied, err) + } + }) + + t.Run("dest exists as file returns error", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + if err := fakeFS.AddDir("/stage"); err != nil { + t.Fatalf("add stage dir: %v", err) + } + if err := fakeFS.AddFile("/dest", []byte("x")); err != nil { + t.Fatalf("add dest file: %v", err) + } + if _, err := syncDirExact("/stage", "/dest"); err == nil || !strings.Contains(err.Error(), "ensure") { + t.Fatalf("expected ensure error, got %v", err) + } + }) + + t.Run("readDir error bubbles", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + if err := fakeFS.AddDir("/stage"); err != nil { + t.Fatalf("add stage dir: %v", err) + } + restoreFS = readDirFailFS{FS: fakeFS, failPath: "/stage", err: fmt.Errorf("boom")} + if _, err := syncDirExact("/stage", "/dest"); err == nil || !strings.Contains(err.Error(), "readdir") { + t.Fatalf("expected readdir error, got %v", err) + } + }) + + t.Run("entry.Info error bubbles", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + if err := fakeFS.AddDir("/stage"); err != nil { + t.Fatalf("add stage dir: %v", err) + } + + restoreFS = readDirOverrideFS{ + FS: fakeFS, + overridePath: "/stage", + entries: []os.DirEntry{badInfoDirEntry{name: "bad"}}, + } + if _, err := syncDirExact("/stage", "/dest"); err == nil || !strings.Contains(err.Error(), "stat /stage/bad") { + t.Fatalf("expected entry.Info error, got %v", err) + } + }) + + t.Run("readlink error bubbles", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + if err := fakeFS.AddFile("/stage/target", []byte("x")); err != nil { + t.Fatalf("add target: %v", err) + } + if err := fakeFS.Symlink("/stage/target", "/stage/link"); err != nil { + t.Fatalf("add symlink: %v", err) + } + + restoreFS = readlinkFailFS{FS: fakeFS, failPath: "/stage/link", err: fmt.Errorf("boom")} + if _, err := syncDirExact("/stage", "/dest"); err == nil || !strings.Contains(err.Error(), "readlink") { + t.Fatalf("expected readlink error, got %v", err) + } + }) + + t.Run("prune remove failure returns error", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + + if err := fakeFS.AddFile("/stage/keep", []byte("x")); err != nil { + t.Fatalf("add keep: %v", err) + } + if err := fakeFS.AddFile("/dest/remove", []byte("x")); err != nil { + t.Fatalf("add extraneous: %v", err) + } + + restoreFS = removeFailFS{FS: fakeFS, failPath: "/dest/remove", err: fmt.Errorf("perm")} + if _, err := syncDirExact("/stage", "/dest"); err == nil || !strings.Contains(err.Error(), "remove") { + t.Fatalf("expected remove error, got %v", err) + } + }) + + t.Run("prunes empty dirs not present in stage", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + if err := fakeFS.AddFile("/stage/keep", []byte("x")); err != nil { + t.Fatalf("add keep: %v", err) + } + if err := fakeFS.AddDir("/dest/oldDir"); err != nil { + t.Fatalf("add oldDir: %v", err) + } + + if _, err := syncDirExact("/stage", "/dest"); err != nil { + t.Fatalf("syncDirExact error: %v", err) + } + if _, err := fakeFS.Stat("/dest/oldDir"); err == nil || !os.IsNotExist(err) { + t.Fatalf("expected oldDir removed; stat err=%v", err) + } + }) +} + 
+func TestDisarmFirewallRollback_NilHandleAndEmptyPaths(t *testing.T) { + disarmFirewallRollback(context.Background(), newTestLogger(), nil) + + handle := &firewallRollbackHandle{ + markerPath: " ", + unitName: " ", + } + disarmFirewallRollback(context.Background(), newTestLogger(), handle) +} + +func TestSelectStageHostFirewall_IgnoresNilAndBlankEntries(t *testing.T) { + origFS := restoreFS + origHostname := firewallHostname + t.Cleanup(func() { + restoreFS = origFS + firewallHostname = origHostname + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + + firewallHostname = func() (string, error) { return "node1", nil } + + stageNodes := "/stage/etc/pve/nodes" + if err := fakeFS.AddDir(stageNodes); err != nil { + t.Fatalf("add stage nodes: %v", err) + } + if err := fakeFS.AddDir(stageNodes + "/ "); err != nil { + t.Fatalf("add blank node dir: %v", err) + } + if err := fakeFS.AddFile(stageNodes+"/node1/host.fw", []byte("host\n")); err != nil { + t.Fatalf("add host.fw: %v", err) + } + + entries, err := fakeFS.ReadDir(stageNodes) + if err != nil { + t.Fatalf("readDir stage nodes: %v", err) + } + entries = append([]os.DirEntry{nil}, entries...) + + restoreFS = readDirOverrideFS{ + FS: fakeFS, + overridePath: stageNodes, + entries: entries, + } + + path, node, ok, err := selectStageHostFirewall(newTestLogger(), "/stage") + if err != nil { + t.Fatalf("selectStageHostFirewall error: %v", err) + } + if !ok || node != "node1" || !strings.HasSuffix(path, "/stage/etc/pve/nodes/node1/host.fw") { + t.Fatalf("unexpected selection ok=%v node=%q path=%q", ok, node, path) + } +} + +func TestApplyPVEFirewallFromStage_PropagatesSyncAndCopyErrors(t *testing.T) { + origFS := restoreFS + origTime := restoreTime + origHostname := firewallHostname + t.Cleanup(func() { + restoreFS = origFS + restoreTime = origTime + firewallHostname = origHostname + }) + + t.Run("syncDirExact failure bubbles", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + stageRoot := "/stage" + if err := fakeFS.AddFile(stageRoot+"/etc/pve/firewall/cluster.fw", []byte("x")); err != nil { + t.Fatalf("add staged firewall: %v", err) + } + if err := fakeFS.AddFile("/etc/pve/firewall", []byte("not a dir")); err != nil { + t.Fatalf("add dest firewall file: %v", err) + } + + _, err := applyPVEFirewallFromStage(newTestLogger(), stageRoot) + if err == nil || !strings.Contains(err.Error(), "ensure /etc/pve/firewall") { + t.Fatalf("expected ensure error, got %v", err) + } + }) + + t.Run("firewall file copy failure bubbles", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + fakeTime := &FakeTime{Current: time.Unix(0, 12345)} + restoreTime = fakeTime + + stageRoot := "/stage" + if err := fakeFS.AddFile(stageRoot+"/etc/pve/firewall", []byte("fw\n")); err != nil { + t.Fatalf("add staged firewall file: %v", err) + } + + tmpPath := "/etc/pve/firewall.proxsave.tmp." 
+ strconv.FormatInt(fakeTime.Current.UnixNano(), 10) + fakeFS.OpenFileErr[filepath.Clean(tmpPath)] = fmt.Errorf("open tmp denied") + + _, err := applyPVEFirewallFromStage(newTestLogger(), stageRoot) + if err == nil || !strings.Contains(err.Error(), "write /etc/pve/firewall") { + t.Fatalf("expected write error, got %v", err) + } + }) + + t.Run("host.fw copy failure bubbles", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + fakeTime := &FakeTime{Current: time.Unix(0, 12345)} + restoreTime = fakeTime + + firewallHostname = func() (string, error) { return "current", nil } + + stageRoot := "/stage" + if err := fakeFS.AddFile(stageRoot+"/etc/pve/nodes/other/host.fw", []byte("host\n")); err != nil { + t.Fatalf("add staged host.fw: %v", err) + } + + destHostFW := "/etc/pve/nodes/current/host.fw" + tmpPath := destHostFW + ".proxsave.tmp." + strconv.FormatInt(fakeTime.Current.UnixNano(), 10) + fakeFS.OpenFileErr[filepath.Clean(tmpPath)] = fmt.Errorf("open tmp denied") + + _, err := applyPVEFirewallFromStage(newTestLogger(), stageRoot) + if err == nil || !strings.Contains(err.Error(), "write "+destHostFW) { + t.Fatalf("expected host.fw write error, got %v", err) + } + }) +} + +func TestSyncDirExact_AdditionalEdgeCases(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + t.Run("stat error bubbles", func(t *testing.T) { + base := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(base.Root) }) + restoreFS = statFailFS{FS: base, failPath: "/stage", err: fmt.Errorf("boom")} + + if _, err := syncDirExact("/stage", "/dest"); err == nil || !strings.Contains(err.Error(), "stat /stage") { + t.Fatalf("expected stat error, got %v", err) + } + }) + + t.Run("walkStage ignores disappeared dir readDir not-exist", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + if err := fakeFS.AddDir("/stage"); err != nil { + t.Fatalf("add stage dir: %v", err) + } + + restoreFS = readDirOverrideFS{ + FS: fakeFS, + overridePath: "/stage", + entries: []os.DirEntry{ + staticDirEntry{name: "sub", mode: fs.ModeDir}, + }, + } + + if _, err := syncDirExact("/stage", "/dest"); err != nil { + t.Fatalf("syncDirExact error: %v", err) + } + if info, err := fakeFS.Stat("/dest/sub"); err != nil || !info.IsDir() { + t.Fatalf("expected /dest/sub directory, err=%v info=%v", err, info) + } + }) + + t.Run("walkStage skips nil/blank/dot entries", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + if err := fakeFS.AddDir("/stage"); err != nil { + t.Fatalf("add stage dir: %v", err) + } + + restoreFS = readDirOverrideFS{ + FS: fakeFS, + overridePath: "/stage", + entries: []os.DirEntry{ + nil, + staticDirEntry{name: " ", mode: 0}, + staticDirEntry{name: ".", mode: 0}, + }, + } + + if applied, err := syncDirExact("/stage", "/dest"); err != nil || len(applied) != 0 { + t.Fatalf("expected nil, got applied=%#v err=%v", applied, err) + } + }) + + t.Run("symlink parent ensure error bubbles", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + + if err := fakeFS.AddDir("/stage"); err != nil { + t.Fatalf("add stage dir: %v", err) + } + if err := fakeFS.Symlink("/stage/target", "/stage/link"); err != nil { + t.Fatalf("add stage symlink: %v", err) + } + if err := fakeFS.AddDir("/dest"); err != nil { + t.Fatalf("add dest dir: %v", err) + } + + restoreFS = &statFailOnNthFS{FS: fakeFS, path: "/dest", failOn: 2, 
err: fmt.Errorf("boom")} + if _, err := syncDirExact("/stage", "/dest"); err == nil || !strings.Contains(err.Error(), "ensure /dest") { + t.Fatalf("expected ensure error, got %v", err) + } + }) + + t.Run("symlink creation error bubbles", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + + if err := fakeFS.AddDir("/stage"); err != nil { + t.Fatalf("add stage dir: %v", err) + } + if err := fakeFS.Symlink("/stage/target", "/stage/link"); err != nil { + t.Fatalf("add stage symlink: %v", err) + } + + restoreFS = symlinkFailFS{FS: fakeFS, failNewname: "/dest/link", err: fmt.Errorf("boom")} + if _, err := syncDirExact("/stage", "/dest"); err == nil || !strings.Contains(err.Error(), "symlink /dest/link") { + t.Fatalf("expected symlink error, got %v", err) + } + }) + + t.Run("directory ensure error bubbles", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + if err := fakeFS.AddDir("/stage/sub"); err != nil { + t.Fatalf("add stage subdir: %v", err) + } + if err := fakeFS.AddFile("/dest/sub", []byte("not a dir")); err != nil { + t.Fatalf("add dest file: %v", err) + } + if _, err := syncDirExact("/stage", "/dest"); err == nil || !strings.Contains(err.Error(), "ensure /dest/sub") { + t.Fatalf("expected ensure /dest/sub error, got %v", err) + } + }) + + t.Run("directory recursion error bubbles", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + if err := fakeFS.AddDir("/stage/sub"); err != nil { + t.Fatalf("add stage subdir: %v", err) + } + restoreFS = readDirFailFS{FS: fakeFS, failPath: "/stage/sub", err: fmt.Errorf("boom")} + + if _, err := syncDirExact("/stage", "/dest"); err == nil || !strings.Contains(err.Error(), "readdir /stage/sub") { + t.Fatalf("expected recursion readdir error, got %v", err) + } + }) + + t.Run("copy file error bubbles", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + if err := fakeFS.AddFile("/stage/file", []byte("x")); err != nil { + t.Fatalf("add stage file: %v", err) + } + restoreFS = readFileFailFS{FS: fakeFS, failPath: "/stage/file", err: fmt.Errorf("read boom")} + + if _, err := syncDirExact("/stage", "/dest"); err == nil || !strings.Contains(err.Error(), "read /stage/file") { + t.Fatalf("expected read error, got %v", err) + } + }) + + t.Run("pruneDest readDir error bubbles", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + if err := fakeFS.AddDir("/stage"); err != nil { + t.Fatalf("add stage dir: %v", err) + } + restoreFS = readDirFailFS{FS: fakeFS, failPath: "/dest", err: fmt.Errorf("boom")} + + if _, err := syncDirExact("/stage", "/dest"); err == nil || !strings.Contains(err.Error(), "readdir /dest") { + t.Fatalf("expected pruneDest readdir error, got %v", err) + } + }) + + t.Run("pruneDest ignores not-exist", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + if err := fakeFS.AddDir("/stage"); err != nil { + t.Fatalf("add stage dir: %v", err) + } + restoreFS = readDirFailFS{FS: fakeFS, failPath: "/dest", err: os.ErrNotExist} + + if _, err := syncDirExact("/stage", "/dest"); err != nil { + t.Fatalf("expected nil, got %v", err) + } + }) + + t.Run("pruneDest skips nil/blank/dot entries", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + if err := fakeFS.AddDir("/stage"); err != nil { + 
t.Fatalf("add stage dir: %v", err) + } + + restoreFS = multiReadDirFS{ + FS: fakeFS, + entries: map[string][]os.DirEntry{ + "/dest": { + nil, + staticDirEntry{name: " ", mode: 0}, + staticDirEntry{name: ".", mode: 0}, + }, + }, + } + + if _, err := syncDirExact("/stage", "/dest"); err != nil { + t.Fatalf("expected nil, got %v", err) + } + }) + + t.Run("pruneDest entry.Info error bubbles", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + if err := fakeFS.AddDir("/stage"); err != nil { + t.Fatalf("add stage dir: %v", err) + } + + restoreFS = multiReadDirFS{ + FS: fakeFS, + entries: map[string][]os.DirEntry{ + "/dest": {badInfoDirEntry{name: "bad"}}, + }, + } + + if _, err := syncDirExact("/stage", "/dest"); err == nil || !strings.Contains(err.Error(), "stat /dest/bad") { + t.Fatalf("expected pruneDest info error, got %v", err) + } + }) + + t.Run("pruneDest recursion error bubbles", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + if err := fakeFS.AddDir("/stage"); err != nil { + t.Fatalf("add stage dir: %v", err) + } + + restoreFS = multiReadDirFS{ + FS: fakeFS, + entries: map[string][]os.DirEntry{ + "/dest": {staticDirEntry{name: "sub", mode: fs.ModeDir}}, + "/dest/sub": nil, + }, + errors: map[string]error{ + "/dest/sub": fmt.Errorf("boom"), + }, + } + + if _, err := syncDirExact("/stage", "/dest"); err == nil || !strings.Contains(err.Error(), "readdir /dest/sub") { + t.Fatalf("expected pruneDest recursion error, got %v", err) + } + }) +} diff --git a/internal/tui/wizard/install.go b/internal/tui/wizard/install.go index f1999ca..66484b9 100644 --- a/internal/tui/wizard/install.go +++ b/internal/tui/wizard/install.go @@ -1,6 +1,7 @@ package wizard import ( + "bufio" "context" "errors" "fmt" @@ -18,6 +19,19 @@ import ( "github.com/tis24dev/proxsave/pkg/utils" ) +type installWizardPrefill struct { + SecondaryEnabled bool + SecondaryPath string + SecondaryLogPath string + CloudEnabled bool + CloudRemote string + CloudLogPath string + FirewallEnabled bool + TelegramEnabled bool + EmailEnabled bool + EncryptionEnabled bool +} + // InstallWizardData holds the collected installation data type InstallWizardData struct { BaseDir string @@ -52,7 +66,7 @@ var ( ) // RunInstallWizard runs the TUI-based installation wizard -func RunInstallWizard(ctx context.Context, configPath string, baseDir string, buildSig string) (*InstallWizardData, error) { +func RunInstallWizard(ctx context.Context, configPath string, baseDir string, buildSig string, baseTemplate string) (*InstallWizardData, error) { defaultFirewallRules := false data := &InstallWizardData{ BaseDir: baseDir, @@ -64,6 +78,8 @@ func RunInstallWizard(ctx context.Context, configPath string, baseDir string, bu app := tui.NewApp() + prefill := deriveInstallWizardPrefill(baseTemplate) + // Build the form form := components.NewForm(app) @@ -94,7 +110,7 @@ func RunInstallWizard(ctx context.Context, configPath string, baseDir string, bu var dropdownOpen bool // Secondary Storage section - var secondaryEnabled bool + secondaryEnabled := prefill.SecondaryEnabled var secondaryPathField, secondaryLogField *tview.InputField secondaryDropdown := tview.NewDropDown(). @@ -110,7 +126,7 @@ func RunInstallWizard(ctx context.Context, configPath string, baseDir string, bu } dropdownOpen = false }). 
- SetCurrentOption(0) + SetCurrentOption(boolToOptionIndex(secondaryEnabled)) secondaryDropdown.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey { if event.Key() == tcell.KeyEnter { @@ -134,18 +150,24 @@ func RunInstallWizard(ctx context.Context, configPath string, baseDir string, bu SetLabel(" └─ Secondary Backup Path"). SetText("/mnt/secondary-backup"). SetFieldWidth(40) - secondaryPathField.SetDisabled(true) + if prefill.SecondaryPath != "" { + secondaryPathField.SetText(prefill.SecondaryPath) + } + secondaryPathField.SetDisabled(!secondaryEnabled) form.Form.AddFormItem(secondaryPathField) secondaryLogField = tview.NewInputField(). SetLabel(" └─ Secondary Log Path"). SetText("/mnt/secondary-backup/logs"). SetFieldWidth(40) - secondaryLogField.SetDisabled(true) + if prefill.SecondaryLogPath != "" { + secondaryLogField.SetText(prefill.SecondaryLogPath) + } + secondaryLogField.SetDisabled(!secondaryEnabled) form.Form.AddFormItem(secondaryLogField) // Cloud Storage section - var cloudEnabled bool + cloudEnabled := prefill.CloudEnabled var rcloneBackupField, rcloneLogField *tview.InputField cloudDropdown := tview.NewDropDown(). @@ -160,7 +182,7 @@ func RunInstallWizard(ctx context.Context, configPath string, baseDir string, bu } dropdownOpen = false }). - SetCurrentOption(0) + SetCurrentOption(boolToOptionIndex(cloudEnabled)) cloudDropdown.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey { if event.Key() == tcell.KeyEnter { @@ -174,7 +196,7 @@ func RunInstallWizard(ctx context.Context, configPath string, baseDir string, bu form.Form.AddFormItem(cloudDropdown) cloudHint := tview.NewInputField(). - SetLabel(" Tip: remotename:path (via 'rclone config'), e.g. myremote:pbs-backups"). + SetLabel(" Tip: remote name (via 'rclone config'), e.g. myremote (or myremote:path)"). SetFieldWidth(0). SetText("") cloudHint.SetDisabled(true) @@ -184,25 +206,31 @@ func RunInstallWizard(ctx context.Context, configPath string, baseDir string, bu SetLabel(" └─ Rclone Backup Remote"). SetText("myremote:pbs-backups"). SetFieldWidth(40) - rcloneBackupField.SetDisabled(true) + if prefill.CloudRemote != "" { + rcloneBackupField.SetText(prefill.CloudRemote) + } + rcloneBackupField.SetDisabled(!cloudEnabled) form.Form.AddFormItem(rcloneBackupField) rcloneLogField = tview.NewInputField(). - SetLabel(" └─ Rclone Log Remote"). + SetLabel(" └─ Rclone Log Path"). SetText("myremote:pbs-logs"). SetFieldWidth(40) - rcloneLogField.SetDisabled(true) + if prefill.CloudLogPath != "" { + rcloneLogField.SetText(prefill.CloudLogPath) + } + rcloneLogField.SetDisabled(!cloudEnabled) form.Form.AddFormItem(rcloneLogField) // Firewall rules backup (system collection) - firewallEnabled := false + firewallEnabled := prefill.FirewallEnabled firewallDropdown := tview.NewDropDown(). SetLabel("Backup Firewall Rules (iptables/nftables)"). SetOptions([]string{"No", "Yes"}, func(option string, index int) { firewallEnabled = (option == "Yes") dropdownOpen = false }). - SetCurrentOption(0) + SetCurrentOption(boolToOptionIndex(firewallEnabled)) firewallDropdown.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey { if event.Key() == tcell.KeyEnter { @@ -216,7 +244,8 @@ func RunInstallWizard(ctx context.Context, configPath string, baseDir string, bu form.Form.AddFormItem(firewallDropdown) // Notifications (header + two toggles) - var telegramEnabled, emailEnabled bool + telegramEnabled := prefill.TelegramEnabled + emailEnabled := prefill.EmailEnabled notificationHeader := tview.NewInputField(). SetLabel("Notifications"). 
SetFieldWidth(0). @@ -230,7 +259,7 @@ func RunInstallWizard(ctx context.Context, configPath string, baseDir string, bu telegramEnabled = (option == "Yes") dropdownOpen = false }). - SetCurrentOption(0) + SetCurrentOption(boolToOptionIndex(telegramEnabled)) telegramDropdown.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey { if event.Key() == tcell.KeyEnter { dropdownOpen = !dropdownOpen @@ -247,7 +276,7 @@ func RunInstallWizard(ctx context.Context, configPath string, baseDir string, bu emailEnabled = (option == "Yes") dropdownOpen = false }). - SetCurrentOption(0) + SetCurrentOption(boolToOptionIndex(emailEnabled)) emailDropdown.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey { if event.Key() == tcell.KeyEnter { dropdownOpen = !dropdownOpen @@ -264,7 +293,7 @@ func RunInstallWizard(ctx context.Context, configPath string, baseDir string, bu SetOptions([]string{"No", "Yes"}, func(option string, index int) { dropdownOpen = false }). - SetCurrentOption(0) + SetCurrentOption(boolToOptionIndex(prefill.EncryptionEnabled)) encryptionDropdown.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey { if event.Key() == tcell.KeyEnter { @@ -312,15 +341,15 @@ func RunInstallWizard(ctx context.Context, configPath string, baseDir string, bu data.EnableCloudStorage = cloudEnabled if cloudEnabled { - data.RcloneBackupRemote = rcloneBackupField.GetText() - data.RcloneLogRemote = rcloneLogField.GetText() + data.RcloneBackupRemote = strings.TrimSpace(rcloneBackupField.GetText()) + data.RcloneLogRemote = strings.TrimSpace(rcloneLogField.GetText()) - // Validate rclone remotes - if !strings.Contains(data.RcloneBackupRemote, ":") { - return fmt.Errorf("rclone backup remote must be in format 'remote:path'") + // Validate rclone inputs (allow both "remote" and "remote:path", logs can also be path-only) + if data.RcloneBackupRemote == "" { + return fmt.Errorf("rclone backup remote cannot be empty") } - if !strings.Contains(data.RcloneLogRemote, ":") { - return fmt.Errorf("rclone log remote must be in format 'remote:path'") + if data.RcloneLogRemote == "" { + return fmt.Errorf("rclone log path cannot be empty") } } @@ -454,6 +483,11 @@ func RunInstallWizard(ctx context.Context, configPath string, baseDir string, bu // If baseTemplate is empty, the embedded default template is used. func ApplyInstallData(baseTemplate string, data *InstallWizardData) (string, error) { template := baseTemplate + editingExisting := strings.TrimSpace(baseTemplate) != "" + existingValues := map[string]string{} + if editingExisting { + existingValues = parseEnvTemplate(baseTemplate) + } if strings.TrimSpace(template) == "" { template = config.DefaultEnvTemplate() } @@ -497,15 +531,23 @@ func ApplyInstallData(baseTemplate string, data *InstallWizardData) (string, err // Apply notifications if data.NotificationMode == "telegram" || data.NotificationMode == "both" { template = setEnvValue(template, "TELEGRAM_ENABLED", "true") - template = setEnvValue(template, "BOT_TELEGRAM_TYPE", "centralized") + // Preserve existing telegram mode when editing an existing config. 
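+ // (Brand-new configs still receive the "centralized" default below.)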
+ if !editingExisting || strings.TrimSpace(existingValues["BOT_TELEGRAM_TYPE"]) == "" { + template = setEnvValue(template, "BOT_TELEGRAM_TYPE", "centralized") + } } else { template = setEnvValue(template, "TELEGRAM_ENABLED", "false") } if data.NotificationMode == "email" || data.NotificationMode == "both" { template = setEnvValue(template, "EMAIL_ENABLED", "true") - template = setEnvValue(template, "EMAIL_DELIVERY_METHOD", "relay") - template = setEnvValue(template, "EMAIL_FALLBACK_SENDMAIL", "true") + // Preserve existing delivery preferences when editing an existing config. + if !editingExisting || strings.TrimSpace(existingValues["EMAIL_DELIVERY_METHOD"]) == "" { + template = setEnvValue(template, "EMAIL_DELIVERY_METHOD", "relay") + } + if !editingExisting || strings.TrimSpace(existingValues["EMAIL_FALLBACK_SENDMAIL"]) == "" { + template = setEnvValue(template, "EMAIL_FALLBACK_SENDMAIL", "true") + } } else { template = setEnvValue(template, "EMAIL_ENABLED", "false") } @@ -560,6 +602,87 @@ func unsetEnvValue(template, key string) string { return strings.Join(out, "\n") } +func boolToOptionIndex(value bool) int { + if value { + return 1 + } + return 0 +} + +func deriveInstallWizardPrefill(baseTemplate string) installWizardPrefill { + out := installWizardPrefill{} + if strings.TrimSpace(baseTemplate) == "" { + return out + } + values := parseEnvTemplate(baseTemplate) + + out.SecondaryEnabled = readTemplateBool(values, "SECONDARY_ENABLED", "ENABLE_SECONDARY_BACKUP") + out.SecondaryPath = readTemplateString(values, "SECONDARY_PATH", "SECONDARY_BACKUP_PATH") + out.SecondaryLogPath = readTemplateString(values, "SECONDARY_LOG_PATH") + + out.CloudEnabled = readTemplateBool(values, "CLOUD_ENABLED", "ENABLE_CLOUD_BACKUP") + out.CloudRemote = readTemplateString(values, "CLOUD_REMOTE", "RCLONE_REMOTE") + out.CloudLogPath = readTemplateString(values, "CLOUD_LOG_PATH") + + out.FirewallEnabled = readTemplateBool(values, "BACKUP_FIREWALL_RULES") + + out.TelegramEnabled = readTemplateBool(values, "TELEGRAM_ENABLED") + out.EmailEnabled = readTemplateBool(values, "EMAIL_ENABLED") + + out.EncryptionEnabled = readTemplateBool(values, "ENCRYPT_ARCHIVE") + + return out +} + +func parseEnvTemplate(template string) map[string]string { + values := make(map[string]string) + + scanner := bufio.NewScanner(strings.NewReader(template)) + for scanner.Scan() { + line := strings.TrimRight(scanner.Text(), "\r") + trimmed := strings.TrimSpace(line) + if utils.IsComment(trimmed) { + continue + } + + key, value, ok := utils.SplitKeyValue(line) + if !ok { + continue + } + if fields := strings.Fields(key); len(fields) >= 2 && fields[0] == "export" { + key = fields[1] + } + key = strings.ToUpper(strings.TrimSpace(key)) + if key == "" { + continue + } + values[key] = strings.TrimSpace(value) + } + + return values +} + +func readTemplateString(values map[string]string, keys ...string) string { + for _, key := range keys { + key = strings.ToUpper(strings.TrimSpace(key)) + if key == "" { + continue + } + if val, ok := values[key]; ok { + return strings.TrimSpace(val) + } + } + return "" +} + +func readTemplateBool(values map[string]string, keys ...string) bool { + raw := readTemplateString(values, keys...) 
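+ // readTemplateString already returns a trimmed value; the TrimSpace below is defensive.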
+ if strings.TrimSpace(raw) == "" { + return false + } + return utils.ParseBool(raw) +} + // CheckExistingConfig checks if config file exists and asks how to proceed func CheckExistingConfig(configPath string, buildSig string) (ExistingConfigAction, error) { if _, err := os.Stat(configPath); err == nil { diff --git a/internal/tui/wizard/post_install_audit_core.go b/internal/tui/wizard/post_install_audit_core.go new file mode 100644 index 0000000..5d9a76a --- /dev/null +++ b/internal/tui/wizard/post_install_audit_core.go @@ -0,0 +1,265 @@ +package wizard + +import ( + "context" + "errors" + "fmt" + "os" + "os/exec" + "regexp" + "sort" + "strings" + "sync" + "time" + + "github.com/tis24dev/proxsave/internal/config" +) + +// PostInstallAuditSuggestion represents an optional feature that appears to be enabled +// but not configured/detected on this system. Users can disable the feature by setting +// the corresponding KEY=false in backup.env. +type PostInstallAuditSuggestion struct { + Key string + Messages []string +} + +var ( + postInstallAuditDisableHintRe = regexp.MustCompile(`(?i)\bset\s+([A-Z0-9_]+)=false\b`) + postInstallAuditANSISGRRe = regexp.MustCompile(`\x1b\[[0-9;]*m`) + + postInstallAuditAllowedKeysOnce sync.Once + postInstallAuditAllowedKeys map[string]struct{} + + postInstallAuditRunner = runPostInstallAuditDryRun +) + +func postInstallAuditAllowedKeysSet() map[string]struct{} { + postInstallAuditAllowedKeysOnce.Do(func() { + allowed := make(map[string]struct{}) + values := parseEnvTemplate(config.DefaultEnvTemplate()) + for key := range values { + key = strings.ToUpper(strings.TrimSpace(key)) + if strings.HasPrefix(key, "BACKUP_") { + allowed[key] = struct{}{} + } + } + postInstallAuditAllowedKeys = allowed + }) + return postInstallAuditAllowedKeys +} + +func runPostInstallAuditDryRun(ctx context.Context, execPath, configPath string) (output string, exitCode int, err error) { + // Run a dry-run with warning-level logs to keep output minimal while still capturing + // all actionable "set KEY=false" hints. + cmd := exec.CommandContext(ctx, execPath, + "--dry-run", + "--log-level", "warning", + "--config", configPath, + ) + out, runErr := cmd.CombinedOutput() + if runErr == nil { + return string(out), 0, nil + } + + var exitErr *exec.ExitError + if errors.As(runErr, &exitErr) { + // Non-zero exit codes are expected when warnings are emitted (exit code 1). + return string(out), exitErr.ExitCode(), nil + } + return string(out), -1, fmt.Errorf("post-install audit dry-run failed: %w", runErr) +} + +// CollectPostInstallDisableSuggestions runs a proxsave dry-run and extracts actionable +// "set KEY=false" hints from the resulting warnings/errors. It only returns keys that: +// - exist in the embedded template, and +// - start with "BACKUP_", and +// - are currently enabled (true) in the provided config file. 
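+// Results are deduplicated per key and returned in sorted order.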
+func CollectPostInstallDisableSuggestions(ctx context.Context, execPath, configPath string) ([]PostInstallAuditSuggestion, error) { + if strings.TrimSpace(execPath) == "" { + return nil, fmt.Errorf("exec path cannot be empty") + } + if strings.TrimSpace(configPath) == "" { + return nil, fmt.Errorf("config path cannot be empty") + } + + configBytes, err := os.ReadFile(configPath) + if err != nil { + return nil, fmt.Errorf("read configuration for audit: %w", err) + } + configValues := parseEnvTemplate(string(configBytes)) + allowed := postInstallAuditAllowedKeysSet() + + auditCtx, cancel := context.WithTimeout(ctx, 90*time.Second) + defer cancel() + + output, _, err := postInstallAuditRunner(auditCtx, execPath, configPath) + if err != nil { + return nil, err + } + + issueLines := extractIssueLinesFromProxsaveOutput(output) + return extractDisableSuggestionsFromIssueLines(issueLines, allowed, configValues), nil +} + +func extractIssueLinesFromProxsaveOutput(output string) []string { + lines := splitNormalizedLines(output) + + // Prefer the end-of-run summary, which is clean (no ANSI) and deduplicated. + const header = "WARNINGS/ERRORS DURING RUN" + inSummary := false + issues := make([]string, 0, 16) + + for _, line := range lines { + if strings.Contains(line, header) { + inSummary = true + continue + } + if !inSummary { + continue + } + trimmed := strings.TrimSpace(line) + if trimmed == "" { + continue + } + if strings.HasPrefix(trimmed, "===========================================") { + break + } + issues = append(issues, trimmed) + } + + if len(issues) > 0 { + return issues + } + + // Fallback: scan the entire output and keep only actionable lines. This is less + // robust because the live log output may contain ANSI codes. + for _, line := range lines { + clean := stripANSI(strings.TrimSpace(line)) + if clean == "" { + continue + } + if postInstallAuditDisableHintRe.MatchString(clean) { + issues = append(issues, clean) + } + } + return issues +} + +func extractDisableSuggestionsFromIssueLines(issueLines []string, allowed map[string]struct{}, configValues map[string]string) []PostInstallAuditSuggestion { + if len(issueLines) == 0 { + return nil + } + if allowed == nil { + allowed = map[string]struct{}{} + } + if configValues == nil { + configValues = map[string]string{} + } + + type msgSet map[string]struct{} + byKey := make(map[string]msgSet) + + for _, raw := range issueLines { + line := strings.TrimSpace(stripANSI(raw)) + if line == "" { + continue + } + + matches := postInstallAuditDisableHintRe.FindAllStringSubmatch(line, -1) + if len(matches) == 0 { + continue + } + + msg := normalizeIssueMessage(line) + if msg == "" { + msg = line + } + + for _, m := range matches { + if len(m) < 2 { + continue + } + key := strings.ToUpper(strings.TrimSpace(m[1])) + if !strings.HasPrefix(key, "BACKUP_") { + continue + } + if _, ok := allowed[key]; !ok { + continue + } + // Only suggest disabling keys that are currently enabled in the config. 
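+ // Keys missing from the config parse as false and are skipped as well.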
+ if !readTemplateBool(configValues, key) { + continue + } + if _, ok := byKey[key]; !ok { + byKey[key] = make(msgSet) + } + byKey[key][msg] = struct{}{} + } + } + + if len(byKey) == 0 { + return nil + } + + keys := make([]string, 0, len(byKey)) + for key := range byKey { + keys = append(keys, key) + } + sort.Strings(keys) + + out := make([]PostInstallAuditSuggestion, 0, len(keys)) + for _, key := range keys { + msgs := make([]string, 0, len(byKey[key])) + for msg := range byKey[key] { + msgs = append(msgs, msg) + } + sort.Strings(msgs) + out = append(out, PostInstallAuditSuggestion{ + Key: key, + Messages: msgs, + }) + } + return out +} + +func normalizeIssueMessage(line string) string { + line = strings.TrimSpace(stripANSI(line)) + if line == "" { + return "" + } + // Prefer to remove "[timestamp] LEVEL" prefix when present. + if strings.HasPrefix(line, "[") { + if idx := strings.Index(line, "]"); idx >= 0 { + rest := strings.TrimSpace(line[idx+1:]) + fields := strings.Fields(rest) + if len(fields) >= 2 { + level := fields[0] + rest = strings.TrimSpace(rest[len(level):]) + if rest != "" { + return rest + } + } + } + } + return line +} + +func splitNormalizedLines(s string) []string { + if strings.TrimSpace(s) == "" { + return nil + } + // Normalize CRLF and split. + s = strings.ReplaceAll(s, "\r\n", "\n") + s = strings.ReplaceAll(s, "\r", "\n") + return strings.Split(s, "\n") +} + +func stripANSI(s string) string { + // Best-effort removal of common ANSI SGR sequences. + // Example: "\x1b[33mWARNING\x1b[0m" + const esc = "\x1b[" + if !strings.Contains(s, esc) { + return s + } + return postInstallAuditANSISGRRe.ReplaceAllString(s, "") +} diff --git a/internal/tui/wizard/post_install_audit_core_test.go b/internal/tui/wizard/post_install_audit_core_test.go new file mode 100644 index 0000000..471efac --- /dev/null +++ b/internal/tui/wizard/post_install_audit_core_test.go @@ -0,0 +1,119 @@ +package wizard + +import ( + "context" + "os" + "path/filepath" + "reflect" + "strings" + "testing" +) + +func TestExtractIssueLinesFromProxsaveOutput_UsesSummaryBlock(t *testing.T) { + output := strings.Join([]string{ + "[2026-02-24 09:12:55] \x1b[33mWARNING\x1b[0m Live warning (colored)", + "===========================================", + "WARNINGS/ERRORS DURING RUN (warnings=1 errors=0)", + "", + "[2026-02-24 09:12:55] WARNING Corosync authkey: not configured. If unused, set BACKUP_CLUSTER_CONFIG=false to disable.", + "===========================================", + "", + }, "\n") + + issues := extractIssueLinesFromProxsaveOutput(output) + if len(issues) != 1 { + t.Fatalf("issues len=%d, want 1: %#v", len(issues), issues) + } + if !strings.Contains(issues[0], "set BACKUP_CLUSTER_CONFIG=false") { + t.Fatalf("unexpected issue line: %q", issues[0]) + } +} + +func TestExtractDisableSuggestionsFromIssueLines_FiltersAllowedAndEnabled(t *testing.T) { + allowed := map[string]struct{}{ + "BACKUP_CLUSTER_CONFIG": {}, + "BACKUP_ZFS_CONFIG": {}, + } + configValues := map[string]string{ + "BACKUP_CLUSTER_CONFIG": "true", + "BACKUP_ZFS_CONFIG": "true", + "BACKUP_CEPH_CONFIG": "false", + } + lines := []string{ + "[2026-02-24 09:12:55] WARNING Corosync authkey: not configured. If unused, set BACKUP_CLUSTER_CONFIG=false to disable.", + "Skipping ZFS collection: not detected. Set BACKUP_ZFS_CONFIG=false to disable.", + "[2026-02-24 09:12:55] WARNING Something else. Set NOT_BACKUP_VAR=false to disable.", + "[2026-02-24 09:12:55] WARNING Ceph not detected. 
If unused, set BACKUP_CEPH_CONFIG=false to disable.", + } + + got := extractDisableSuggestionsFromIssueLines(lines, allowed, configValues) + wantKeys := []string{"BACKUP_CLUSTER_CONFIG", "BACKUP_ZFS_CONFIG"} + if len(got) != len(wantKeys) { + t.Fatalf("got %d suggestions, want %d: %#v", len(got), len(wantKeys), got) + } + for i, key := range wantKeys { + if got[i].Key != key { + t.Fatalf("suggestion[%d].Key=%q, want %q", i, got[i].Key, key) + } + } +} + +func TestCollectPostInstallDisableSuggestions_UsesRunnerAndConfigFilter(t *testing.T) { + tmp := t.TempDir() + configPath := filepath.Join(tmp, "backup.env") + if err := os.WriteFile(configPath, []byte("BACKUP_CLUSTER_CONFIG=true\nBACKUP_ZFS_CONFIG=false\n"), 0o600); err != nil { + t.Fatalf("write config: %v", err) + } + + origRunner := postInstallAuditRunner + t.Cleanup(func() { postInstallAuditRunner = origRunner }) + + postInstallAuditRunner = func(ctx context.Context, execPath, cfgPath string) (string, int, error) { + return strings.Join([]string{ + "===========================================", + "WARNINGS/ERRORS DURING RUN (warnings=1 errors=0)", + "", + "[2026-02-24 09:12:55] WARNING Corosync authkey: not configured. If unused, set BACKUP_CLUSTER_CONFIG=false to disable.", + "===========================================", + }, "\n"), 1, nil + } + + suggestions, err := CollectPostInstallDisableSuggestions(context.Background(), "/fake/proxsave", configPath) + if err != nil { + t.Fatalf("CollectPostInstallDisableSuggestions error: %v", err) + } + if len(suggestions) != 1 { + t.Fatalf("got %d suggestions, want 1: %#v", len(suggestions), suggestions) + } + if suggestions[0].Key != "BACKUP_CLUSTER_CONFIG" { + t.Fatalf("key=%q, want BACKUP_CLUSTER_CONFIG", suggestions[0].Key) + } + if len(suggestions[0].Messages) != 1 || !strings.Contains(suggestions[0].Messages[0], "Corosync authkey") { + t.Fatalf("unexpected messages: %#v", suggestions[0].Messages) + } +} + +func TestNormalizeIssueMessage_RemovesTimestampAndLevel(t *testing.T) { + got := normalizeIssueMessage("[2026-02-24 09:12:55] WARNING hello world") + want := "hello world" + if got != want { + t.Fatalf("got %q, want %q", got, want) + } +} + +func TestStripANSI_RemovesSGRCodes(t *testing.T) { + in := "\x1b[33mWARNING\x1b[0m hello" + got := stripANSI(in) + want := "WARNING hello" + if got != want { + t.Fatalf("got %q, want %q", got, want) + } +} + +func TestSplitNormalizedLines_NormalizesCRLF(t *testing.T) { + got := splitNormalizedLines("a\r\nb\r\n") + want := []string{"a", "b", ""} + if !reflect.DeepEqual(got, want) { + t.Fatalf("got %#v, want %#v", got, want) + } +} diff --git a/internal/tui/wizard/post_install_audit_tui.go b/internal/tui/wizard/post_install_audit_tui.go new file mode 100644 index 0000000..b46587d --- /dev/null +++ b/internal/tui/wizard/post_install_audit_tui.go @@ -0,0 +1,389 @@ +package wizard + +import ( + "context" + "fmt" + "os" + "path/filepath" + "sort" + "strings" + "sync" + + "github.com/gdamore/tcell/v2" + "github.com/rivo/tview" + + "github.com/tis24dev/proxsave/internal/tui" +) + +var ( + postInstallAuditWizardRunner = func(app *tui.App, root, focus tview.Primitive) error { + return app.SetRoot(root, true).SetFocus(focus).Run() + } +) + +type PostInstallAuditResult struct { + // Ran indicates whether the user chose to run the post-install check. + Ran bool + // Suggestions contains the disable suggestions extracted from the dry-run output. + Suggestions []PostInstallAuditSuggestion + // AppliedKeys contains the keys written as KEY=false into backup.env. 
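+ // It stays empty when the user skips the review or selects nothing.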
+ AppliedKeys []string + // CollectErr is set when the dry-run/suggestion collection failed. + CollectErr error +} + +// RunPostInstallAuditWizard runs an optional post-installation check that: +// 1. runs proxsave --dry-run +// 2. extracts actionable "set KEY=false" hints from warnings +// 3. lets the user disable unused BACKUP_* collectors in backup.env +// +// It returns the audit result. Errors are returned only for unexpected failures +// (e.g., UI setup issues). +func RunPostInstallAuditWizard(ctx context.Context, execPath, configPath, buildSig string) (result PostInstallAuditResult, err error) { + app := tui.NewApp() + + titleText := tview.NewTextView(). + SetText("ProxSave - Post-install Check\n\n" + + "Detect optional components that are enabled but not configured on this node.\n" + + "This helps reduce WARNING noise and exit code 1 runs when features are unused.\n"). + SetTextColor(tui.ProxmoxLight). + SetDynamicColors(true) + titleText.SetBorder(false) + + nav := tview.NewTextView(). + SetText("[yellow]Navigation:[white] ↑↓ to move | ENTER/SPACE to toggle | ←→ on buttons | ENTER to select"). + SetTextColor(tcell.ColorWhite). + SetDynamicColors(true). + SetTextAlign(tview.AlignCenter) + nav.SetBorder(false) + + separator := tview.NewTextView(). + SetText(strings.Repeat("─", 80)). + SetTextColor(tui.ProxmoxOrange) + separator.SetBorder(false) + + configPathText := tview.NewTextView(). + SetText(fmt.Sprintf("[yellow]Configuration file:[white] %s", configPath)). + SetTextColor(tcell.ColorWhite). + SetDynamicColors(true). + SetTextAlign(tview.AlignCenter) + configPathText.SetBorder(false) + + buildSigText := tview.NewTextView(). + SetText(fmt.Sprintf("[yellow]Build Signature:[white] %s", buildSig)). + SetTextColor(tcell.ColorWhite). + SetDynamicColors(true). + SetTextAlign(tview.AlignCenter) + buildSigText.SetBorder(false) + + pages := tview.NewPages() + + confirmRun := false + var mu sync.Mutex + var collectedSuggestions []PostInstallAuditSuggestion + var collectErr error + applied := []string{} + confirm := tview.NewModal(). + SetText("Run the post-install check now?\n\n" + + "ProxSave will execute a dry-run and collect WARNING messages that include a hint like:\n" + + " set BACKUP_CLUSTER_CONFIG=false to disable\n\n" + + "You can then choose which optional components to disable.\n"). + AddButtons([]string{"Run check", "Skip"}). + SetDoneFunc(func(_ int, label string) { + confirmRun = (label == "Run check") + if !confirmRun { + app.Stop() + return + } + pages.SwitchToPage("running") + go func() { + suggestions, suggestionErr := CollectPostInstallDisableSuggestions(ctx, execPath, configPath) + app.QueueUpdateDraw(func() { + mu.Lock() + collectedSuggestions = suggestions + collectErr = suggestionErr + mu.Unlock() + if suggestionErr != nil { + showAuditDoneModal(app, pages, "Post-install check failed:\n\n"+suggestionErr.Error()) + return + } + if len(suggestions) == 0 { + showAuditDoneModal(app, pages, "No unused components detected.\n\nNo changes are required.") + return + } + showAuditReview(app, pages, configPath, suggestions, &applied) + }) + }() + }) + + confirm.SetBorder(true). + SetTitle(" Post-install Check "). + SetTitleAlign(tview.AlignCenter). + SetTitleColor(tui.ProxmoxOrange). + SetBorderColor(tui.ProxmoxOrange). + SetBackgroundColor(tcell.ColorBlack) + + running := tview.NewTextView(). + SetText("Running dry-run...\n\nPlease wait. This may take a minute."). + SetTextColor(tcell.ColorWhite). + SetTextAlign(tview.AlignCenter) + running.SetBorder(true). 
+ SetTitle(" Running "). + SetTitleAlign(tview.AlignCenter). + SetTitleColor(tui.WarningYellow). + SetBorderColor(tui.WarningYellow). + SetBackgroundColor(tcell.ColorBlack) + + pages.AddPage("confirm", confirm, true, true) + pages.AddPage("running", running, true, false) + + layout := tview.NewFlex(). + SetDirection(tview.FlexRow). + AddItem(titleText, 5, 0, false). + AddItem(nav, 2, 0, false). + AddItem(separator, 1, 0, false). + AddItem(pages, 0, 1, true). + AddItem(configPathText, 1, 0, false). + AddItem(buildSigText, 1, 0, false) + + layout.SetBorder(true). + SetTitle(" ProxSave "). + SetTitleAlign(tview.AlignCenter). + SetTitleColor(tui.ProxmoxOrange). + SetBorderColor(tui.ProxmoxOrange). + SetBackgroundColor(tcell.ColorBlack) + + if runErr := postInstallAuditWizardRunner(app, layout, confirm); runErr != nil { + return PostInstallAuditResult{}, runErr + } + + result.Ran = confirmRun + mu.Lock() + result.Suggestions = collectedSuggestions + result.CollectErr = collectErr + mu.Unlock() + result.AppliedKeys = applied + return result, nil +} + +func showAuditDoneModal(app *tui.App, pages *tview.Pages, message string) { + done := tview.NewModal(). + SetText(message). + AddButtons([]string{"Continue"}). + SetDoneFunc(func(_ int, _ string) { + app.Stop() + }) + done.SetBorder(true). + SetTitle(" Post-install Check "). + SetTitleAlign(tview.AlignCenter). + SetTitleColor(tui.ProxmoxOrange). + SetBorderColor(tui.ProxmoxOrange). + SetBackgroundColor(tcell.ColorBlack) + + pages.AddAndSwitchToPage("done", done, true) +} + +func showAuditReview(app *tui.App, pages *tview.Pages, configPath string, suggestions []PostInstallAuditSuggestion, applied *[]string) { + if applied == nil { + tmp := []string{} + applied = &tmp + } + + selected := make(map[string]bool, len(suggestions)) + list := tview.NewList(). + ShowSecondaryText(false) + // We render checkbox markers like "[x]" which would otherwise be interpreted + // as style tags by tview and get stripped. + list.SetUseStyleTags(false, false) + + details := tview.NewTextView(). + SetDynamicColors(true). + SetWrap(true). + SetTextAlign(tview.AlignLeft) + details.SetBorder(true). + SetTitle(" Details "). + SetTitleAlign(tview.AlignCenter). + SetTitleColor(tui.ProxmoxOrange). 
+ SetBorderColor(tui.ProxmoxOrange) + + updateListItem := func(index int) { + if index < 0 || index >= len(suggestions) { + return + } + key := suggestions[index].Key + marker := "[ ]" + if selected[key] { + marker = "[x]" + } + list.SetItemText(index, fmt.Sprintf("%s %s", marker, key), "") + } + + updateDetails := func(index int) { + if index < 0 || index >= len(suggestions) { + details.SetText("") + return + } + s := suggestions[index] + var b strings.Builder + b.WriteString("[yellow]Detected warnings:[white]\n\n") + for _, msg := range s.Messages { + b.WriteString("- ") + b.WriteString(msg) + b.WriteString("\n") + } + b.WriteString("\n") + b.WriteString(fmt.Sprintf("If you don’t use this feature, set [yellow]%s=false[white] to disable.\n", s.Key)) + details.SetText(b.String()) + } + + toggle := func(index int) { + if index < 0 || index >= len(suggestions) { + return + } + key := suggestions[index].Key + selected[key] = !selected[key] + updateListItem(index) + updateDetails(index) + } + + for i, s := range suggestions { + selected[s.Key] = false + list.AddItem("", "", 0, nil) + updateListItem(i) + } + + list.SetChangedFunc(func(index int, _ string, _ string, _ rune) { + updateDetails(index) + }) + list.SetSelectedFunc(func(index int, _ string, _ string, _ rune) { + toggle(index) + }) + list.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey { + switch event.Key() { + case tcell.KeyEnter: + // Let SetSelectedFunc handle it. + return event + } + if event.Rune() == ' ' { + toggle(list.GetCurrentItem()) + return nil + } + return event + }) + + if len(suggestions) > 0 { + updateDetails(0) + } + + buttons := tview.NewForm(). + AddButton("Disable selected", func() { + keys := make([]string, 0, len(suggestions)) + for _, s := range suggestions { + if selected[s.Key] { + keys = append(keys, s.Key) + } + } + sort.Strings(keys) + if len(keys) == 0 { + showAuditDoneModal(app, pages, "No changes selected.\n\nNothing was modified.") + return + } + if err := applyAuditDisables(configPath, keys); err != nil { + showAuditDoneModal(app, pages, "Failed to update configuration:\n\n"+err.Error()) + return + } + *applied = keys + showAuditDoneModal(app, pages, fmt.Sprintf("Configuration updated successfully.\n\nDisabled %d feature(s).", len(keys))) + }). + AddButton("Disable all", func() { + for i := range suggestions { + selected[suggestions[i].Key] = true + updateListItem(i) + } + updateDetails(list.GetCurrentItem()) + }). + AddButton("Skip", func() { + app.Stop() + }) + + buttons.SetBorder(true). + SetTitle(" Actions "). + SetTitleAlign(tview.AlignCenter). + SetTitleColor(tui.ProxmoxOrange). + SetBorderColor(tui.ProxmoxOrange). + SetBackgroundColor(tcell.ColorBlack) + + list.SetBorder(true). + SetTitle(" Suggestions "). + SetTitleAlign(tview.AlignCenter). + SetTitleColor(tui.ProxmoxOrange). + SetBorderColor(tui.ProxmoxOrange) + + mid := tview.NewFlex(). + AddItem(list, 0, 1, true). + AddItem(details, 0, 2, false) + + review := tview.NewFlex(). + SetDirection(tview.FlexRow). + AddItem(tview.NewTextView(). + SetText("Select which features to disable. This only changes backup.env flags.\n"). + SetTextColor(tcell.ColorWhite), 2, 0, false). + AddItem(mid, 0, 1, true). + AddItem(buttons, 7, 0, false) + + review.SetBorder(true). + SetTitle(" Review & Disable "). + SetTitleAlign(tview.AlignCenter). + SetTitleColor(tui.ProxmoxOrange). + SetBorderColor(tui.ProxmoxOrange). 
+ SetBackgroundColor(tcell.ColorBlack) + + pages.AddAndSwitchToPage("review", review, true) + app.SetFocus(list) +} + +func applyAuditDisables(configPath string, keys []string) error { + if strings.TrimSpace(configPath) == "" { + return fmt.Errorf("config path cannot be empty") + } + if len(keys) == 0 { + return nil + } + + contentBytes, err := os.ReadFile(configPath) + if err != nil { + return fmt.Errorf("read configuration: %w", err) + } + content := string(contentBytes) + for _, key := range keys { + key = strings.ToUpper(strings.TrimSpace(key)) + if key == "" { + continue + } + content = setEnvValue(content, key, "false") + } + + tmpPath := configPath + ".tmp.audit" + if err := writeConfigFileAtomic(configPath, tmpPath, content); err != nil { + return err + } + return nil +} + +func writeConfigFileAtomic(configPath, tmpPath, content string) error { + dir := filepath.Dir(strings.TrimSpace(configPath)) + if dir == "" || dir == "." { + return fmt.Errorf("invalid configuration path: %q", configPath) + } + if err := os.MkdirAll(dir, 0o700); err != nil { + return fmt.Errorf("failed to create configuration directory: %w", err) + } + if err := os.WriteFile(tmpPath, []byte(content), 0o600); err != nil { + return fmt.Errorf("failed to write configuration file: %w", err) + } + if err := os.Rename(tmpPath, configPath); err != nil { + _ = os.Remove(tmpPath) + return fmt.Errorf("failed to finalize configuration file: %w", err) + } + return nil +}
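+
+// A usage sketch for applyAuditDisables (hypothetical caller; the config path
+// and keys shown are illustrative only):
+//
+//	keys := []string{"BACKUP_CEPH_CONFIG", "BACKUP_ZFS_CONFIG"}
+//	if err := applyAuditDisables("/opt/proxsave/backup.env", keys); err != nil {
+//		fmt.Printf("WARNING: could not disable unused collectors: %v\n", err)
+//	}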