From f62f3d600132a3e4e95138164b9ac98531997810 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sun, 22 Feb 2026 18:51:14 +0100 Subject: [PATCH 01/12] Bump filippo.io/edwards25519 from 1.1.0 to 1.1.1 (#160) Bumps [filippo.io/edwards25519](https://github.com/FiloSottile/edwards25519) from 1.1.0 to 1.1.1. - [Commits](https://github.com/FiloSottile/edwards25519/compare/v1.1.0...v1.1.1) --- updated-dependencies: - dependency-name: filippo.io/edwards25519 dependency-version: 1.1.1 dependency-type: indirect ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 717ff99..4a36bb8 100644 --- a/go.mod +++ b/go.mod @@ -14,7 +14,7 @@ require ( ) require ( - filippo.io/edwards25519 v1.1.0 // indirect + filippo.io/edwards25519 v1.1.1 // indirect filippo.io/hpke v0.4.0 // indirect github.com/gdamore/encoding v1.0.1 // indirect github.com/lucasb-eyer/go-colorful v1.3.0 // indirect diff --git a/go.sum b/go.sum index ab99f6d..3ce3e36 100644 --- a/go.sum +++ b/go.sum @@ -2,8 +2,8 @@ c2sp.org/CCTV/age v0.0.0-20251208015420-e9274a7bdbfd h1:ZLsPO6WdZ5zatV4UfVpr7oAw c2sp.org/CCTV/age v0.0.0-20251208015420-e9274a7bdbfd/go.mod h1:SrHC2C7r5GkDk8R+NFVzYy/sdj0Ypg9htaPXQq5Cqeo= filippo.io/age v1.3.1 h1:hbzdQOJkuaMEpRCLSN1/C5DX74RPcNCk6oqhKMXmZi0= filippo.io/age v1.3.1/go.mod h1:EZorDTYUxt836i3zdori5IJX/v2Lj6kWFU0cfh6C0D4= -filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= -filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= +filippo.io/edwards25519 v1.1.1 h1:YpjwWWlNmGIDyXOn8zLzqiD+9TyIlPhGFG96P39uBpw= +filippo.io/edwards25519 v1.1.1/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= filippo.io/hpke v0.4.0 h1:p575VVQ6ted4pL+it6M00V/f2qTZITO0zgmdKCkd5+A= filippo.io/hpke v0.4.0/go.mod h1:EmAN849/P3qdeK+PCMkDpDm83vRHM5cDipBJ8xbQLVY= github.com/gdamore/encoding v1.0.1 h1:YzKZckdBL6jVt2Gc+5p82qhrGiqMdG/eNs6Wy0u3Uhw= From f0fe6e99abc42a963ecc3482841cff1a707dbe7b Mon Sep 17 00:00:00 2001 From: Damiano <71268257+tis24dev@users.noreply.github.com> Date: Mon, 23 Feb 2026 00:01:50 +0100 Subject: [PATCH 02/12] Remove ENABLE_GO_BACKUP flag and legacy wrappers Remove legacy Go-pipeline compatibility and related dead code. Deleted the prefilter-manual command and removed references to ENABLE_GO_BACKUP from configs and docs. Dropped Config.EnableGoBackup and its tests, cleaned up proxsave logging that referenced the flag. Consolidated bundle creation by removing the package-level createBundle wrapper and updating callers to use Orchestrator.createBundle; removed several legacy/compat helper functions in identity and orchestrator and adjusted unit tests to call the new helpers (encodeProtectedServerIDWithMACs, collectMACCandidates, etc.). Miscellaneous test cleanup: removed obsolete fake FS/test helpers no longer needed. These changes simplify code paths and eliminate obsolete compatibility layers. --- cmd/prefilter-manual/main.go | 59 -------------------- cmd/proxsave/main.go | 7 --- docs/CONFIGURATION.md | 5 -- docs/EXAMPLES.md | 1 - internal/config/config.go | 2 - internal/config/config_test.go | 27 --------- internal/config/templates/backup.env | 1 - internal/identity/identity.go | 19 ------- internal/identity/identity_test.go | 42 ++++++-------- internal/orchestrator/bundle_test.go | 35 ------------ internal/orchestrator/decrypt_test.go | 37 ------------ internal/orchestrator/decrypt_workflow_ui.go | 3 +- internal/orchestrator/orchestrator.go | 6 -- 13 files changed, 20 insertions(+), 224 deletions(-) delete mode 100644 cmd/prefilter-manual/main.go diff --git a/cmd/prefilter-manual/main.go b/cmd/prefilter-manual/main.go deleted file mode 100644 index 6f0d1a8..0000000 --- a/cmd/prefilter-manual/main.go +++ /dev/null @@ -1,59 +0,0 @@ -package main - -import ( - "context" - "flag" - "os" - "path/filepath" - "strings" - - "github.com/tis24dev/proxsave/internal/backup" - "github.com/tis24dev/proxsave/internal/logging" - "github.com/tis24dev/proxsave/internal/types" -) - -func parseLogLevel(raw string) types.LogLevel { - switch strings.ToLower(strings.TrimSpace(raw)) { - case "debug": - return types.LogLevelDebug - case "info", "": - return types.LogLevelInfo - case "warning", "warn": - return types.LogLevelWarning - case "error": - return types.LogLevelError - default: - return types.LogLevelInfo - } -} - -func main() { - var ( - root string - maxSize int64 - levelLabel string - ) - - flag.StringVar(&root, "root", "/tmp/test_prefilter", "Root directory to run prefilter on") - flag.Int64Var(&maxSize, "max-size", 8*1024*1024, "Max file size (bytes) to prefilter") - flag.StringVar(&levelLabel, "log-level", "info", "Log level: debug|info|warn|error") - flag.Parse() - - root = filepath.Clean(strings.TrimSpace(root)) - if root == "" || root == "." { - root = string(os.PathSeparator) - } - - logger := logging.New(parseLogLevel(levelLabel), false) - logger.SetOutput(os.Stdout) - - cfg := backup.OptimizationConfig{ - EnablePrefilter: true, - PrefilterMaxFileSizeBytes: maxSize, - } - - if err := backup.ApplyOptimizations(context.Background(), logger, root, cfg); err != nil { - logger.Error("Prefilter failed: %v", err) - os.Exit(1) - } -} diff --git a/cmd/proxsave/main.go b/cmd/proxsave/main.go index 1d51c69..185e7b6 100644 --- a/cmd/proxsave/main.go +++ b/cmd/proxsave/main.go @@ -1302,13 +1302,6 @@ func run() int { fmt.Println() - if !cfg.EnableGoBackup && !args.Support { - logging.Warning("ENABLE_GO_BACKUP=false is ignored; the Go backup pipeline is always used.") - } else { - logging.Debug("Go backup pipeline enabled") - } - fmt.Println() - // Storage info logging.Info("Storage configuration:") logging.Info(" Primary: %s", formatStorageLabel(cfg.BackupPath, localFS)) diff --git a/docs/CONFIGURATION.md b/docs/CONFIGURATION.md index fddead4..87e655b 100644 --- a/docs/CONFIGURATION.md +++ b/docs/CONFIGURATION.md @@ -46,9 +46,6 @@ Complete reference for all 200+ configuration variables in `configs/backup.env`. # Enable/disable backup system BACKUP_ENABLED=true # true | false -# Enable Go pipeline (vs legacy Bash) -ENABLE_GO_BACKUP=true # true | false - # Colored output in terminal USE_COLOR=true # true | false @@ -922,8 +919,6 @@ METRICS_ENABLED=false # true | false METRICS_PATH=${BASE_DIR}/metrics # Empty = /var/lib/prometheus/node-exporter ``` -> ℹ️ Metrics export is available only for the Go pipeline (`ENABLE_GO_BACKUP=true`). - **Output**: Creates `proxmox_backup.prom` in `METRICS_PATH` with: - Backup duration and start/end timestamps - Archive size and raw bytes collected diff --git a/docs/EXAMPLES.md b/docs/EXAMPLES.md index 0a53278..3f8a6eb 100644 --- a/docs/EXAMPLES.md +++ b/docs/EXAMPLES.md @@ -866,7 +866,6 @@ CLOUD_LOG_PATH= # configs/backup.env SYSTEM_ROOT_PREFIX=/mnt/snapshot-root # points to the alternate root BACKUP_ENABLED=true -ENABLE_GO_BACKUP=true # /etc, /var, /root, /home are resolved under the prefix ``` diff --git a/internal/config/config.go b/internal/config/config.go index 1bea4f2..bb35388 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -33,7 +33,6 @@ type Config struct { DebugLevel types.LogLevel UseColor bool ColorizeStepLogs bool - EnableGoBackup bool ProfilingEnabled bool BaseDir string DryRun bool @@ -423,7 +422,6 @@ func (c *Config) parseOptimizationSettings() { } func (c *Config) parseSecuritySettings() { - c.EnableGoBackup = c.getBoolWithFallback([]string{"ENABLE_GO_BACKUP", "ENABLE_GO_PIPELINE"}, true) c.DisableNetworkPreflight = c.getBool("DISABLE_NETWORK_PREFLIGHT", false) // Base directory diff --git a/internal/config/config_test.go b/internal/config/config_test.go index eb26f81..3fb7ab0 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -427,38 +427,11 @@ func TestConfigDefaults(t *testing.T) { t.Errorf("Default LocalRetentionDays = %d; want 7", cfg.LocalRetentionDays) } - if !cfg.EnableGoBackup { - t.Error("Expected default EnableGoBackup to be true") - } - if cfg.BaseDir != "/defaults/base" { t.Errorf("Default BaseDir = %q; want %q", cfg.BaseDir, "/defaults/base") } } -func TestEnableGoBackupFlag(t *testing.T) { - tmpDir := t.TempDir() - configPath := filepath.Join(tmpDir, "go_pipeline.env") - - content := `ENABLE_GO_BACKUP=false -` - if err := os.WriteFile(configPath, []byte(content), 0644); err != nil { - t.Fatalf("Failed to create test config: %v", err) - } - - cleanup := setBaseDirEnv(t, "/flag/base") - defer cleanup() - - cfg, err := LoadConfig(configPath) - if err != nil { - t.Fatalf("LoadConfig() error = %v", err) - } - - if cfg.EnableGoBackup { - t.Error("Expected EnableGoBackup to be false when explicitly disabled") - } -} - func TestLoadConfigBaseDirFromConfig(t *testing.T) { tmpDir := t.TempDir() configPath := filepath.Join(tmpDir, "base_dir.env") diff --git a/internal/config/templates/backup.env b/internal/config/templates/backup.env index 9c3ec39..2dbf810 100644 --- a/internal/config/templates/backup.env +++ b/internal/config/templates/backup.env @@ -7,7 +7,6 @@ # General settings # ---------------------------------------------------------------------- BACKUP_ENABLED=true -ENABLE_GO_BACKUP=true PROFILING_ENABLED=true # Enable CPU/heap profiling (pprof) for Go pipeline USE_COLOR=true COLORIZE_STEP_LOGS=true # Highlight "Step N/8" lines (requires USE_COLOR=true) diff --git a/internal/identity/identity.go b/internal/identity/identity.go index 0e43740..87ae582 100644 --- a/internal/identity/identity.go +++ b/internal/identity/identity.go @@ -170,11 +170,6 @@ func collectMACCandidates(logger *logging.Logger) ([]macCandidate, []string) { return candidates, macs } -func collectMACAddresses() []string { - _, macs := collectMACCandidates(nil) - return macs -} - func selectPreferredMAC(candidates []macCandidate) (string, string) { var best *macCandidate for i := range candidates { @@ -469,10 +464,6 @@ func buildSystemData(macs []string, logger *logging.Logger) string { return builder.String() } -func encodeProtectedServerID(serverID, primaryMAC string, logger *logging.Logger) (string, error) { - return encodeProtectedServerIDWithMACs(serverID, []string{primaryMAC}, primaryMAC, logger) -} - func encodeProtectedServerIDWithMACs(serverID string, macs []string, primaryMAC string, logger *logging.Logger) (string, error) { logDebug(logger, "Identity: encodeProtectedServerID: start (serverID=%s primaryMAC=%s)", serverID, primaryMAC) keyField := buildIdentityKeyField(macs, primaryMAC, logger) @@ -574,16 +565,6 @@ func decodeProtectedServerID(fileContent, primaryMAC string, logger *logging.Log return serverID, matchedByMAC, nil } -func generateSystemKey(primaryMAC string, logger *logging.Logger) string { - machineID := readMachineID(logger) - hostnamePart := readHostnamePart(logger) - - macPart := strings.ReplaceAll(primaryMAC, ":", "") - key := computeSystemKey(machineID, hostnamePart, macPart) - logDebug(logger, "Identity: generateSystemKey: systemKey=%s", key) - return key -} - func buildIdentityKeyField(macs []string, primaryMAC string, logger *logging.Logger) string { machineID := readMachineID(logger) hostnamePart := readHostnamePart(logger) diff --git a/internal/identity/identity_test.go b/internal/identity/identity_test.go index d7a6354..8271696 100644 --- a/internal/identity/identity_test.go +++ b/internal/identity/identity_test.go @@ -15,11 +15,15 @@ import ( "github.com/tis24dev/proxsave/internal/types" ) +func encodeProtectedServerIDForTest(serverID, primaryMAC string, logger *logging.Logger) (string, error) { + return encodeProtectedServerIDWithMACs(serverID, []string{primaryMAC}, primaryMAC, logger) +} + func TestEncodeDecodeProtectedServerIDRoundTrip(t *testing.T) { const serverID = "1234567890123456" const mac = "aa:bb:cc:dd:ee:ff" - content, err := encodeProtectedServerID(serverID, mac, nil) + content, err := encodeProtectedServerIDForTest(serverID, mac, nil) if err != nil { t.Fatalf("encodeProtectedServerID() error = %v", err) } @@ -57,7 +61,7 @@ func TestDecodeProtectedServerIDAcceptsDifferentMACOnSameHost(t *testing.T) { } const serverID = "1111222233334444" - content, err := encodeProtectedServerID(serverID, "aa:bb:cc:dd:ee:ff", nil) + content, err := encodeProtectedServerIDForTest(serverID, "aa:bb:cc:dd:ee:ff", nil) if err != nil { t.Fatalf("encodeProtectedServerID() error = %v", err) } @@ -95,7 +99,7 @@ func TestDecodeProtectedServerIDRejectsDifferentHost(t *testing.T) { } const serverID = "1111222233334444" - content, err := encodeProtectedServerID(serverID, "aa:bb:cc:dd:ee:ff", nil) + content, err := encodeProtectedServerIDForTest(serverID, "aa:bb:cc:dd:ee:ff", nil) if err != nil { t.Fatalf("encodeProtectedServerID() error = %v", err) } @@ -147,7 +151,7 @@ func TestDecodeProtectedServerIDDetectsCorruptedData(t *testing.T) { const serverID = "5555666677778888" const mac = "aa:aa:aa:aa:aa:aa" - content, err := encodeProtectedServerID(serverID, mac, nil) + content, err := encodeProtectedServerIDForTest(serverID, mac, nil) if err != nil { t.Fatalf("encodeProtectedServerID() error = %v", err) } @@ -204,14 +208,14 @@ func TestDetectUsesExistingIdentityFile(t *testing.T) { } identityPath := filepath.Join(identityDir, identityFileName) - macs := collectMACAddresses() + _, macs := collectMACCandidates(nil) if len(macs) == 0 { t.Skip("no non-loopback MACs available on this system") } primary := macs[0] const serverID = "1234567890123456" - content, err := encodeProtectedServerID(serverID, primary, nil) + content, err := encodeProtectedServerIDForTest(serverID, primary, nil) if err != nil { t.Fatalf("encodeProtectedServerID() error = %v", err) } @@ -354,21 +358,8 @@ func TestFallbackServerIDFormat(t *testing.T) { } } -func TestGenerateSystemKeyStableAndLength(t *testing.T) { - const mac = "aa:bb:cc:dd:ee:ff" - k1 := generateSystemKey(mac, nil) - k2 := generateSystemKey(mac, nil) - - if len(k1) != 16 { - t.Fatalf("generateSystemKey length = %d, want 16", len(k1)) - } - if k1 != k2 { - t.Fatalf("generateSystemKey should be stable, got %q and %q", k1, k2) - } -} - func TestCollectMACAddressesSortedAndUnique(t *testing.T) { - macs := collectMACAddresses() + _, macs := collectMACCandidates(nil) for i := 0; i < len(macs); i++ { if macs[i] == "" { t.Fatalf("unexpected empty MAC at index %d", i) @@ -413,7 +404,7 @@ func TestDecodeProtectedServerIDInvalidPayloadFormat(t *testing.T) { func TestDecodeProtectedServerIDInvalidServerIDFormat(t *testing.T) { const mac = "aa:bb:cc:dd:ee:ff" - content, err := encodeProtectedServerID("AAAAAAAAAAAAAAAA", mac, nil) + content, err := encodeProtectedServerIDForTest("AAAAAAAAAAAAAAAA", mac, nil) if err != nil { t.Fatalf("encodeProtectedServerID() error = %v", err) } @@ -446,7 +437,7 @@ func TestLoadServerIDWithEmptyMACSlice(t *testing.T) { path := filepath.Join(dir, "identity.conf") const serverID = "1234567890123456" - content, err := encodeProtectedServerID(serverID, "", nil) + content, err := encodeProtectedServerIDForTest(serverID, "", nil) if err != nil { t.Fatalf("encodeProtectedServerID() error = %v", err) } @@ -487,7 +478,10 @@ func TestLoadServerIDFailsAllMACs(t *testing.T) { } func encodeProtectedServerIDLegacy(serverID, primaryMAC string) (string, error) { - systemKey := generateSystemKey(primaryMAC, nil) + machineID := readMachineID(nil) + hostnamePart := readHostnamePart(nil) + macPart := strings.ReplaceAll(primaryMAC, ":", "") + systemKey := computeSystemKey(machineID, hostnamePart, macPart) timestamp := time.Unix(1700000000, 0).Unix() data := fmt.Sprintf("%s:%d:%s", serverID, timestamp, systemKey[:systemKeyPrefixLength]) checksum := sha256.Sum256([]byte(data)) @@ -1588,7 +1582,7 @@ func TestDecodeProtectedServerIDWithEmptyMAC(t *testing.T) { } const serverID = "1234567890123456" - content, err := encodeProtectedServerID(serverID, "aa:bb:cc:dd:ee:ff", nil) + content, err := encodeProtectedServerIDForTest(serverID, "aa:bb:cc:dd:ee:ff", nil) if err != nil { t.Fatalf("encodeProtectedServerID() error = %v", err) } diff --git a/internal/orchestrator/bundle_test.go b/internal/orchestrator/bundle_test.go index 9462b4b..280060e 100644 --- a/internal/orchestrator/bundle_test.go +++ b/internal/orchestrator/bundle_test.go @@ -121,41 +121,6 @@ func TestCreateBundle_CreatesValidTarArchive(t *testing.T) { } } -func TestLegacyCreateBundleWrapper_DelegatesToMethod(t *testing.T) { - logger := logging.New(logging.GetDefaultLogger().GetLevel(), false) - tempDir := t.TempDir() - archive := filepath.Join(tempDir, "backup.tar") - - // Minimal associated files required by createBundle - required := map[string]string{ - "": "archive-content", - ".sha256": "checksum", - ".metadata": "metadata-json", - } - for suffix, content := range required { - if err := os.WriteFile(archive+suffix, []byte(content), 0o640); err != nil { - t.Fatalf("write %s: %v", suffix, err) - } - } - - ctx := context.Background() - - // Call legacy wrapper - bundlePath, err := createBundle(ctx, logger, archive) - if err != nil { - t.Fatalf("legacy createBundle returned error: %v", err) - } - - expectedPath := archive + ".bundle.tar" - if bundlePath != expectedPath { - t.Fatalf("bundle path = %s, want %s", bundlePath, expectedPath) - } - - if _, err := os.Stat(bundlePath); err != nil { - t.Fatalf("expected bundle file to exist, got %v", err) - } -} - func TestRemoveAssociatedFiles_RemovesAll(t *testing.T) { logger := logging.New(logging.GetDefaultLogger().GetLevel(), false) tempDir := t.TempDir() diff --git a/internal/orchestrator/decrypt_test.go b/internal/orchestrator/decrypt_test.go index ae17bd2..84fcee1 100644 --- a/internal/orchestrator/decrypt_test.go +++ b/internal/orchestrator/decrypt_test.go @@ -4250,22 +4250,6 @@ exit 0 os.Setenv("PATH", tmp+":"+origPath) defer os.Setenv("PATH", origPath) - // Create a filesystem wrapper that allows download but fails MkdirAll for tempRoot - type fakeMkdirAllFailOnTempRoot struct { - osFS - } - fake := &struct { - osFS - mkdirCalls int - }{} - - // Use osFS with a hook to fail on the second MkdirAll (tempRoot creation) - type osFSWithMkdirHook struct { - osFS - mkdirCalls int - } - hookFS := &osFSWithMkdirHook{} - orig := restoreFS // Use regular osFS - the download will work, then MkdirAll for /tmp/proxsave should succeed // but we can trigger error by making /tmp/proxsave unwritable after download @@ -4289,27 +4273,6 @@ exit 0 } // If download succeeds and extraction succeeds, that's fine - we've tested the path _ = err - _ = fake - _ = hookFS -} - -// fakeChecksumFailFS wraps osFS to make the plain archive unreadable after extraction -// This triggers GenerateChecksum error (lines 670-673) -type fakeChecksumFailFS struct { - osFS - extractDone bool -} - -func (f *fakeChecksumFailFS) OpenFile(path string, flag int, perm os.FileMode) (*os.File, error) { - file, err := os.OpenFile(path, flag, perm) - if err != nil { - return nil, err - } - // After extracting, make the archive unreadable for checksum - if f.extractDone && strings.Contains(path, "proxmox-decrypt") && strings.HasSuffix(path, ".tar.xz") { - os.Chmod(path, 0o000) - } - return file, nil } // fakeStatThenRemoveFS removes the file after stat succeeds diff --git a/internal/orchestrator/decrypt_workflow_ui.go b/internal/orchestrator/decrypt_workflow_ui.go index a57d45a..2ae37d7 100644 --- a/internal/orchestrator/decrypt_workflow_ui.go +++ b/internal/orchestrator/decrypt_workflow_ui.go @@ -361,7 +361,8 @@ func runDecryptWorkflowWithUI(ctx context.Context, cfg *config.Config, logger *l } logger.Info("Creating decrypted bundle...") - bundlePath, err := createBundle(ctx, logger, tempArchivePath) + o := &Orchestrator{logger: logger, fs: osFS{}} + bundlePath, err := o.createBundle(ctx, tempArchivePath) if err != nil { return err } diff --git a/internal/orchestrator/orchestrator.go b/internal/orchestrator/orchestrator.go index c629435..4eb71f7 100644 --- a/internal/orchestrator/orchestrator.go +++ b/internal/orchestrator/orchestrator.go @@ -1192,12 +1192,6 @@ func (o *Orchestrator) removeAssociatedFiles(archivePath string) error { return nil } -// Legacy compatibility wrapper for callers that used the package-level createBundle function. -func createBundle(ctx context.Context, logger *logging.Logger, archivePath string) (string, error) { - o := &Orchestrator{logger: logger, fs: osFS{}, clock: realTimeProvider{}} - return o.createBundle(ctx, archivePath) -} - // encryptArchive was replaced by streaming encryption inside the archiver. // SaveStatsReport writes a JSON report with backup statistics to the log directory. From 7caa1823e3f466cd5af7430f7657d0fbdd832c90 Mon Sep 17 00:00:00 2001 From: Damiano <71268257+tis24dev@users.noreply.github.com> Date: Mon, 23 Feb 2026 01:34:46 +0100 Subject: [PATCH 03/12] Make mount guard functions mockable, add tests Extract direct OS/syscall/fstab calls in mount guard into package-level function variables (e.g. mountGuardGeteuid, mountGuardReadFile, mountGuardMkdirAll, mountGuardSysMount, mountGuardSysUnmount, mountGuardFstabMountpointsSet, mountGuardIsPathOnRootFilesystem, mountGuardParsePBSDatastoreCfg) and update usages to call those variables. This makes the mount-guard logic easily mockable for unit tests. Add extensive tests in internal/orchestrator/mount_guard_more_test.go covering guardDirForTarget, isMounted (mountinfo/proc mounts fallback and error combinations), guardMountPoint behaviors (mkdir, bind, remount, unmount, context handling), and many flows for maybeApplyPBSDatastoreMountGuards including parsing, fstab fallback, mount attempts and timeout handling. Also adjust an existing test case in pbs_mount_guard_test.go to include a /run/media root scenario and remove a redundant check in pbsMountGuardRootForDatastorePath. These changes improve test coverage and reliability without changing runtime behavior. --- internal/orchestrator/mount_guard.go | 43 +- .../orchestrator/mount_guard_more_test.go | 896 ++++++++++++++++++ internal/orchestrator/pbs_mount_guard_test.go | 1 + 3 files changed, 923 insertions(+), 17 deletions(-) create mode 100644 internal/orchestrator/mount_guard_more_test.go diff --git a/internal/orchestrator/mount_guard.go b/internal/orchestrator/mount_guard.go index 037811d..f4ee262 100644 --- a/internal/orchestrator/mount_guard.go +++ b/internal/orchestrator/mount_guard.go @@ -18,6 +18,18 @@ import ( const mountGuardBaseDir = "/var/lib/proxsave/guards" const mountGuardMountAttemptTimeout = 10 * time.Second +var ( + mountGuardGeteuid = os.Geteuid + mountGuardReadFile = os.ReadFile + mountGuardMkdirAll = os.MkdirAll + mountGuardReadDir = os.ReadDir + mountGuardSysMount = syscall.Mount + mountGuardSysUnmount = syscall.Unmount + mountGuardFstabMountpointsSet = fstabMountpointsSet + mountGuardIsPathOnRootFilesystem = isPathOnRootFilesystem + mountGuardParsePBSDatastoreCfg = parsePBSDatastoreCfgBlocks +) + func maybeApplyPBSDatastoreMountGuards(ctx context.Context, logger *logging.Logger, plan *RestorePlan, stageRoot, destRoot string, dryRun bool) error { if plan == nil || plan.SystemType != SystemTypePBS || !plan.HasCategoryID("datastore_pbs") { return nil @@ -44,7 +56,7 @@ func maybeApplyPBSDatastoreMountGuards(ctx context.Context, logger *logging.Logg } return nil } - if os.Geteuid() != 0 { + if mountGuardGeteuid() != 0 { if logger != nil { logger.Warning("Skipping PBS mount guards: requires root privileges") } @@ -64,7 +76,7 @@ func maybeApplyPBSDatastoreMountGuards(ctx context.Context, logger *logging.Logg } normalized, _ := normalizePBSDatastoreCfgContent(string(data)) - blocks, err := parsePBSDatastoreCfgBlocks(normalized) + blocks, err := mountGuardParsePBSDatastoreCfg(normalized) if err != nil { return err } @@ -75,7 +87,7 @@ func maybeApplyPBSDatastoreMountGuards(ctx context.Context, logger *logging.Logg var fstabMounts map[string]struct{} var mountpointCandidates []string currentFstab := filepath.Join(destRoot, "etc", "fstab") - if mounts, err := fstabMountpointsSet(currentFstab); err != nil { + if mounts, err := mountGuardFstabMountpointsSet(currentFstab); err != nil { if logger != nil { logger.Warning("PBS mount guard: unable to parse current fstab %s: %v (continuing without fstab cross-check)", currentFstab, err) } @@ -123,14 +135,14 @@ func maybeApplyPBSDatastoreMountGuards(ctx context.Context, logger *logging.Logg } } - if err := os.MkdirAll(guardTarget, 0o755); err != nil { + if err := mountGuardMkdirAll(guardTarget, 0o755); err != nil { if logger != nil { logger.Warning("PBS mount guard: unable to create mountpoint directory %s: %v", guardTarget, err) } continue } - onRootFS, _, devErr := isPathOnRootFilesystem(guardTarget) + onRootFS, _, devErr := mountGuardIsPathOnRootFilesystem(guardTarget) if devErr != nil { if logger != nil { logger.Warning("PBS mount guard: unable to determine filesystem device for %s: %v", guardTarget, devErr) @@ -158,7 +170,7 @@ func maybeApplyPBSDatastoreMountGuards(ctx context.Context, logger *logging.Logg out, attemptErr := restoreCmd.Run(mountCtx, "mount", guardTarget) cancel() if attemptErr == nil { - onRootFSNow, _, devErrNow := isPathOnRootFilesystem(guardTarget) + onRootFSNow, _, devErrNow := mountGuardIsPathOnRootFilesystem(guardTarget) if devErrNow == nil && !onRootFSNow { if logger != nil { logger.Info("PBS mount guard: mountpoint %s is now mounted (mount attempt succeeded)", guardTarget) @@ -209,7 +221,7 @@ func maybeApplyPBSDatastoreMountGuards(ctx context.Context, logger *logging.Logg protected[guardTarget] = struct{}{} if logger != nil { - if entries, err := os.ReadDir(guardTarget); err == nil && len(entries) > 0 { + if entries, err := mountGuardReadDir(guardTarget); err == nil && len(entries) > 0 { logger.Warning("PBS mount guard: guard mount point %s is not empty (entries=%d)", guardTarget, len(entries)) } logger.Warning("PBS mount guard: %s resolves to root filesystem (mount missing?) — bind-mounted a read-only guard to prevent writes until storage is available", guardTarget) @@ -241,22 +253,22 @@ func guardMountPoint(ctx context.Context, guardTarget string) error { } guardDir := guardDirForTarget(target) - if err := os.MkdirAll(guardDir, 0o755); err != nil { + if err := mountGuardMkdirAll(guardDir, 0o755); err != nil { return fmt.Errorf("mkdir guard dir: %w", err) } - if err := os.MkdirAll(target, 0o755); err != nil { + if err := mountGuardMkdirAll(target, 0o755); err != nil { return fmt.Errorf("mkdir target: %w", err) } // Bind mount guard directory over the mountpoint to avoid writes to the underlying rootfs path. - if err := syscall.Mount(guardDir, target, "", syscall.MS_BIND, ""); err != nil { + if err := mountGuardSysMount(guardDir, target, "", syscall.MS_BIND, ""); err != nil { return fmt.Errorf("bind mount guard: %w", err) } // Make the bind mount read-only to ensure PBS cannot write backup data to the guard directory. remountFlags := uintptr(syscall.MS_BIND | syscall.MS_REMOUNT | syscall.MS_RDONLY | syscall.MS_NODEV | syscall.MS_NOSUID | syscall.MS_NOEXEC) - if err := syscall.Mount("", target, "", remountFlags, ""); err != nil { - _ = syscall.Unmount(target, 0) + if err := mountGuardSysMount("", target, "", remountFlags, ""); err != nil { + _ = mountGuardSysUnmount(target, 0) return fmt.Errorf("remount guard read-only: %w", err) } @@ -274,7 +286,7 @@ func guardDirForTarget(target string) string { } func isMounted(path string) (bool, error) { - data, err := os.ReadFile("/proc/self/mountinfo") + data, err := mountGuardReadFile("/proc/self/mountinfo") if err == nil { return isMountedFromMountinfo(string(data), path), nil } @@ -315,7 +327,7 @@ func isMountedFromMountinfo(mountinfo, path string) bool { } func isMountedFromProcMounts(path string) (bool, error) { - data, err := os.ReadFile("/proc/mounts") + data, err := mountGuardReadFile("/proc/mounts") if err != nil { return false, err } @@ -408,9 +420,6 @@ func pbsMountGuardRootForDatastorePath(path string) string { case strings.HasPrefix(p, "/run/media/"): rest := strings.TrimPrefix(p, "/run/media/") parts := splitPath(rest) - if len(parts) == 0 { - return "" - } if len(parts) == 1 { return filepath.Join("/run/media", parts[0]) } diff --git a/internal/orchestrator/mount_guard_more_test.go b/internal/orchestrator/mount_guard_more_test.go new file mode 100644 index 0000000..4110908 --- /dev/null +++ b/internal/orchestrator/mount_guard_more_test.go @@ -0,0 +1,896 @@ +package orchestrator + +import ( + "context" + "crypto/sha256" + "errors" + "fmt" + "os" + "path/filepath" + "strings" + "syscall" + "testing" + "time" +) + +func TestGuardDirForTarget(t *testing.T) { + t.Parallel() + + target := "/mnt/datastore" + sum := sha256.Sum256([]byte(target)) + id := fmt.Sprintf("%x", sum[:8]) + want := filepath.Join(mountGuardBaseDir, fmt.Sprintf("%s-%s", filepath.Base(target), id)) + if got := guardDirForTarget(target); got != want { + t.Fatalf("guardDirForTarget(%q)=%q want %q", target, got, want) + } + + rootTarget := "/" + sum = sha256.Sum256([]byte(rootTarget)) + id = fmt.Sprintf("%x", sum[:8]) + want = filepath.Join(mountGuardBaseDir, fmt.Sprintf("%s-%s", "guard", id)) + if got := guardDirForTarget(rootTarget); got != want { + t.Fatalf("guardDirForTarget(%q)=%q want %q", rootTarget, got, want) + } +} + +func TestIsMountedFromMountinfo(t *testing.T) { + t.Parallel() + + mountinfo := strings.Join([]string{ + "36 25 0:32 / / rw,relatime - ext4 /dev/sda1 rw", + `37 36 0:33 / /mnt/pbs\040datastore rw,relatime - ext4 /dev/sdb1 rw`, + "bad line", + "", + }, "\n") + + if got := isMountedFromMountinfo(mountinfo, "/"); !got { + t.Fatalf("expected / to be mounted") + } + if got := isMountedFromMountinfo(mountinfo, "/mnt/pbs datastore"); !got { + t.Fatalf("expected escaped mountpoint to match") + } + if got := isMountedFromMountinfo(mountinfo, "/not-mounted"); got { + t.Fatalf("expected /not-mounted to be unmounted") + } + if got := isMountedFromMountinfo(mountinfo, ""); got { + t.Fatalf("expected empty path to be unmounted") + } +} + +func TestFstabMountpointsSet(t *testing.T) { + tmp := filepath.Join(t.TempDir(), "fstab") + content := strings.Join([]string{ + "# comment", + "UUID=abc / ext4 defaults 0 1", + "/dev/sdb1 /mnt/data/ ext4 defaults 0 2", + "/dev/sdc1 /mnt/data2 ext4 defaults 0 2 # inline comment", + "/dev/sdd1 . ext4 defaults 0 0", + "invalidline", + "", + }, "\n") + if err := os.WriteFile(tmp, []byte(content), 0o600); err != nil { + t.Fatalf("write temp fstab: %v", err) + } + + mps, err := fstabMountpointsSet(tmp) + if err != nil { + t.Fatalf("fstabMountpointsSet error: %v", err) + } + + for _, mp := range []string{"/", "/mnt/data", "/mnt/data2"} { + if _, ok := mps[mp]; !ok { + t.Fatalf("expected mountpoint %s to be present", mp) + } + } + if _, ok := mps["."]; ok { + t.Fatalf("expected dot mountpoint to be skipped") + } +} + +func TestFstabMountpointsSet_Error(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + restoreFS = NewFakeFS() + + if _, err := fstabMountpointsSet("/does-not-exist"); err == nil { + t.Fatalf("expected error") + } +} + +func TestSplitPathAndMountRootWithPrefix(t *testing.T) { + t.Parallel() + + if got := splitPath("a//b/ /c/"); strings.Join(got, ",") != "a,b,c" { + t.Fatalf("splitPath unexpected: %#v", got) + } + if got := mountRootWithPrefix("/mnt/datastore/Data1", "/mnt/"); got != "/mnt/datastore" { + t.Fatalf("mountRootWithPrefix got %q want %q", got, "/mnt/datastore") + } + if got := mountRootWithPrefix("/mnt/", "/mnt/"); got != "" { + t.Fatalf("mountRootWithPrefix(/mnt/)=%q want empty", got) + } +} + +func TestSortByLengthDesc(t *testing.T) { + t.Parallel() + + items := []string{"a", "abc", "ab"} + sortByLengthDesc(items) + if len(items) != 3 { + t.Fatalf("unexpected len: %d", len(items)) + } + if !(len(items[0]) >= len(items[1]) && len(items[1]) >= len(items[2])) { + t.Fatalf("expected non-increasing lengths, got %#v", items) + } +} + +func TestFirstFstabMountpointMatch(t *testing.T) { + t.Parallel() + + mountpoints := []string{"/mnt/storage/pbs", "/mnt/storage", "/"} + if got := firstFstabMountpointMatch("/mnt/storage/pbs/ds1/data", mountpoints); got != "/mnt/storage/pbs" { + t.Fatalf("firstFstabMountpointMatch got %q want %q", got, "/mnt/storage/pbs") + } + if got := firstFstabMountpointMatch(" ", mountpoints); got != "" { + t.Fatalf("firstFstabMountpointMatch empty got %q want empty", got) + } +} + +func TestIsMounted_Variants(t *testing.T) { + origReadFile := mountGuardReadFile + t.Cleanup(func() { mountGuardReadFile = origReadFile }) + + t.Run("prefers mountinfo", func(t *testing.T) { + mountGuardReadFile = func(path string) ([]byte, error) { + if path != "/proc/self/mountinfo" { + t.Fatalf("unexpected read path: %s", path) + } + return []byte("1 2 3:4 / /mnt/target rw - ext4 /dev/sda1 rw\n"), nil + } + mounted, err := isMounted("/mnt/target") + if err != nil { + t.Fatalf("isMounted error: %v", err) + } + if !mounted { + t.Fatalf("expected mounted") + } + }) + + t.Run("falls back to proc mounts", func(t *testing.T) { + mountGuardReadFile = func(path string) ([]byte, error) { + switch path { + case "/proc/self/mountinfo": + return nil, os.ErrNotExist + case "/proc/mounts": + return []byte("/dev/sda1 /mnt/target ext4 rw 0 0\n"), nil + default: + t.Fatalf("unexpected read path: %s", path) + return nil, nil + } + } + mounted, err := isMounted("/mnt/target") + if err != nil { + t.Fatalf("isMounted error: %v", err) + } + if !mounted { + t.Fatalf("expected mounted") + } + }) + + t.Run("reports mounts error when mountinfo missing", func(t *testing.T) { + wantErr := errors.New("mounts read failed") + mountGuardReadFile = func(path string) ([]byte, error) { + switch path { + case "/proc/self/mountinfo": + return nil, os.ErrNotExist + case "/proc/mounts": + return nil, wantErr + default: + t.Fatalf("unexpected read path: %s", path) + return nil, nil + } + } + _, err := isMounted("/mnt/target") + if !errors.Is(err, wantErr) { + t.Fatalf("expected mounts error, got %v", err) + } + }) + + t.Run("includes both errors when mountinfo read fails", func(t *testing.T) { + mountErr := errors.New("mountinfo boom") + mountsErr := errors.New("mounts boom") + mountGuardReadFile = func(path string) ([]byte, error) { + switch path { + case "/proc/self/mountinfo": + return nil, mountErr + case "/proc/mounts": + return nil, mountsErr + default: + t.Fatalf("unexpected read path: %s", path) + return nil, nil + } + } + _, err := isMounted("/mnt/target") + if err == nil || !strings.Contains(err.Error(), "mountinfo boom") || !strings.Contains(err.Error(), "mounts boom") { + t.Fatalf("expected combined error, got %v", err) + } + }) +} + +func TestIsMountedFromProcMounts_Parsing(t *testing.T) { + origReadFile := mountGuardReadFile + t.Cleanup(func() { mountGuardReadFile = origReadFile }) + + mountGuardReadFile = func(path string) ([]byte, error) { + if path != "/proc/mounts" { + t.Fatalf("unexpected read path: %s", path) + } + return []byte(strings.Join([]string{ + "", + "invalid", + "/dev/sda1 /mnt/other ext4 rw 0 0", + "", + }, "\n")), nil + } + + mounted, err := isMountedFromProcMounts("/mnt/target") + if err != nil { + t.Fatalf("isMountedFromProcMounts error: %v", err) + } + if mounted { + t.Fatalf("expected unmounted") + } + + mounted, err = isMountedFromProcMounts(" ") + if err != nil { + t.Fatalf("isMountedFromProcMounts empty target error: %v", err) + } + if mounted { + t.Fatalf("expected empty target to be unmounted") + } +} + +func TestGuardMountPoint(t *testing.T) { + origReadFile := mountGuardReadFile + origMkdirAll := mountGuardMkdirAll + origMount := mountGuardSysMount + origUnmount := mountGuardSysUnmount + t.Cleanup(func() { + mountGuardReadFile = origReadFile + mountGuardMkdirAll = origMkdirAll + mountGuardSysMount = origMount + mountGuardSysUnmount = origUnmount + }) + + t.Run("rejects invalid target", func(t *testing.T) { + if err := guardMountPoint(context.Background(), "/"); err == nil { + t.Fatalf("expected error") + } + }) + + t.Run("nil context uses background", func(t *testing.T) { + mountGuardReadFile = func(string) ([]byte, error) { + return []byte("1 2 3:4 / / rw - ext4 /dev/sda1 rw\n"), nil + } + mountGuardMkdirAll = func(string, os.FileMode) error { return nil } + mountGuardSysMount = func(string, string, string, uintptr, string) error { return nil } + mountGuardSysUnmount = func(string, int) error { + t.Fatalf("unexpected unmount call") + return nil + } + + if err := guardMountPoint(nil, "/mnt/nilctx"); err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) + + t.Run("mount status check error", func(t *testing.T) { + mountGuardReadFile = func(string) ([]byte, error) { return nil, errors.New("read failed") } + if err := guardMountPoint(context.Background(), "/mnt/statuserr"); err == nil || !strings.Contains(err.Error(), "check mount status") { + t.Fatalf("expected status check error, got %v", err) + } + }) + + t.Run("mkdir guard dir failure", func(t *testing.T) { + target := "/mnt/mkdir-guard-dir-fail" + guardDir := guardDirForTarget(target) + + mountGuardReadFile = func(string) ([]byte, error) { + return []byte("1 2 3:4 / / rw - ext4 /dev/sda1 rw\n"), nil + } + mountGuardMkdirAll = func(path string, _ os.FileMode) error { + if filepath.Clean(path) == filepath.Clean(guardDir) { + return errors.New("mkdir guard dir failed") + } + return nil + } + mountGuardSysMount = func(string, string, string, uintptr, string) error { + t.Fatalf("unexpected mount call") + return nil + } + + if err := guardMountPoint(context.Background(), target); err == nil || !strings.Contains(err.Error(), "mkdir guard dir") { + t.Fatalf("expected mkdir guard dir error, got %v", err) + } + }) + + t.Run("mkdir target failure", func(t *testing.T) { + target := "/mnt/mkdir-target-fail" + guardDir := guardDirForTarget(target) + + mountGuardReadFile = func(string) ([]byte, error) { + return []byte("1 2 3:4 / / rw - ext4 /dev/sda1 rw\n"), nil + } + mountGuardMkdirAll = func(path string, _ os.FileMode) error { + switch filepath.Clean(path) { + case filepath.Clean(guardDir): + return nil + case filepath.Clean(target): + return errors.New("mkdir target failed") + default: + return nil + } + } + mountGuardSysMount = func(string, string, string, uintptr, string) error { + t.Fatalf("unexpected mount call") + return nil + } + + if err := guardMountPoint(context.Background(), target); err == nil || !strings.Contains(err.Error(), "mkdir target") { + t.Fatalf("expected mkdir target error, got %v", err) + } + }) + + t.Run("returns nil when already mounted", func(t *testing.T) { + mountGuardReadFile = func(path string) ([]byte, error) { + if path != "/proc/self/mountinfo" { + return nil, os.ErrNotExist + } + return []byte("1 2 3:4 / /mnt/already rw - ext4 /dev/sda1 rw\n"), nil + } + mountGuardMkdirAll = func(string, os.FileMode) error { + t.Fatalf("unexpected mkdir call") + return nil + } + mountGuardSysMount = func(string, string, string, uintptr, string) error { + t.Fatalf("unexpected mount call") + return nil + } + + if err := guardMountPoint(context.Background(), "/mnt/already"); err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) + + t.Run("bind mount failure", func(t *testing.T) { + mountGuardReadFile = func(path string) ([]byte, error) { + return []byte("1 2 3:4 / / rw - ext4 /dev/sda1 rw\n"), nil + } + mountGuardMkdirAll = func(string, os.FileMode) error { return nil } + mountGuardSysMount = func(_, _, _ string, flags uintptr, _ string) error { + if flags == syscall.MS_BIND { + return syscall.EPERM + } + return nil + } + mountGuardSysUnmount = func(string, int) error { + t.Fatalf("unexpected unmount call") + return nil + } + + if err := guardMountPoint(context.Background(), "/mnt/failbind"); err == nil || !strings.Contains(err.Error(), "bind mount guard") { + t.Fatalf("expected bind mount error, got %v", err) + } + }) + + t.Run("remount failure unmounts", func(t *testing.T) { + mountGuardReadFile = func(path string) ([]byte, error) { + return []byte("1 2 3:4 / / rw - ext4 /dev/sda1 rw\n"), nil + } + mountGuardMkdirAll = func(string, os.FileMode) error { return nil } + + mountCalls := 0 + mountGuardSysMount = func(_, _, _ string, _ uintptr, _ string) error { + mountCalls++ + if mountCalls == 2 { + return syscall.EPERM + } + return nil + } + + unmountCalls := 0 + mountGuardSysUnmount = func(target string, flags int) error { + unmountCalls++ + if target != "/mnt/failremount" || flags != 0 { + t.Fatalf("unexpected unmount args: target=%s flags=%d", target, flags) + } + return nil + } + + if err := guardMountPoint(context.Background(), "/mnt/failremount"); err == nil || !strings.Contains(err.Error(), "remount guard read-only") { + t.Fatalf("expected remount error, got %v", err) + } + if unmountCalls != 1 { + t.Fatalf("expected 1 unmount call, got %d", unmountCalls) + } + }) + + t.Run("context canceled", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + if err := guardMountPoint(ctx, "/mnt/ctx"); !errors.Is(err, context.Canceled) { + t.Fatalf("expected context canceled, got %v", err) + } + }) +} + +type mountGuardCommandCall struct { + name string + args []string +} + +type mountGuardCommandRunner struct { + run func(ctx context.Context, name string, args ...string) ([]byte, error) + calls []mountGuardCommandCall +} + +func (f *mountGuardCommandRunner) Run(ctx context.Context, name string, args ...string) ([]byte, error) { + f.calls = append(f.calls, mountGuardCommandCall{name: name, args: append([]string{}, args...)}) + if f.run != nil { + return f.run(ctx, name, args...) + } + return nil, nil +} + +func TestMaybeApplyPBSDatastoreMountGuards_EarlyReturns(t *testing.T) { + logger := newTestLogger() + ctx := context.Background() + plan := &RestorePlan{SystemType: SystemTypePBS, NormalCategories: []Category{{ID: "datastore_pbs"}}} + + if err := maybeApplyPBSDatastoreMountGuards(ctx, logger, nil, "/stage", "/", false); err != nil { + t.Fatalf("nil plan error: %v", err) + } + if err := maybeApplyPBSDatastoreMountGuards(ctx, logger, &RestorePlan{SystemType: SystemTypePVE}, "/stage", "/", false); err != nil { + t.Fatalf("wrong system type error: %v", err) + } + if err := maybeApplyPBSDatastoreMountGuards(ctx, logger, &RestorePlan{SystemType: SystemTypePBS}, "/stage", "/", false); err != nil { + t.Fatalf("missing category error: %v", err) + } + if err := maybeApplyPBSDatastoreMountGuards(ctx, logger, plan, "", "/", false); err != nil { + t.Fatalf("empty stageRoot error: %v", err) + } + if err := maybeApplyPBSDatastoreMountGuards(ctx, logger, plan, "/stage", "/not-root", false); err != nil { + t.Fatalf("destRoot not root error: %v", err) + } + if err := maybeApplyPBSDatastoreMountGuards(ctx, logger, plan, "/stage", "/", true); err != nil { + t.Fatalf("dryRun error: %v", err) + } + + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + restoreFS = NewFakeFS() + if err := maybeApplyPBSDatastoreMountGuards(ctx, logger, plan, "/stage", "/", false); err != nil { + t.Fatalf("non-real restoreFS error: %v", err) + } + + origGeteuid := mountGuardGeteuid + t.Cleanup(func() { mountGuardGeteuid = origGeteuid }) + mountGuardGeteuid = func() int { return 1 } + restoreFS = origFS + if err := maybeApplyPBSDatastoreMountGuards(ctx, logger, plan, "/stage", "/", false); err != nil { + t.Fatalf("non-root user error: %v", err) + } +} + +func TestMaybeApplyPBSDatastoreMountGuards_StagedDatastoreCfgHandling(t *testing.T) { + logger := newTestLogger() + ctx := context.Background() + plan := &RestorePlan{SystemType: SystemTypePBS, NormalCategories: []Category{{ID: "datastore_pbs"}}} + + origGeteuid := mountGuardGeteuid + t.Cleanup(func() { mountGuardGeteuid = origGeteuid }) + mountGuardGeteuid = func() int { return 0 } + + stageRoot := t.TempDir() + + // Missing file => no-op. + if err := maybeApplyPBSDatastoreMountGuards(ctx, logger, plan, stageRoot, "/", false); err != nil { + t.Fatalf("missing staged file error: %v", err) + } + + // Non-file error should propagate. + stagePath := filepath.Join(stageRoot, "etc/proxmox-backup/datastore.cfg") + if err := os.MkdirAll(stagePath, 0o755); err != nil { + t.Fatalf("mkdir staged path: %v", err) + } + if err := maybeApplyPBSDatastoreMountGuards(ctx, logger, plan, stageRoot, "/", false); err == nil || !strings.Contains(err.Error(), "read staged datastore.cfg") { + t.Fatalf("expected read staged error, got %v", err) + } + + // Empty file => no-op. + if err := os.RemoveAll(filepath.Dir(stagePath)); err != nil { + t.Fatalf("cleanup staged dir: %v", err) + } + if err := os.MkdirAll(filepath.Dir(stagePath), 0o755); err != nil { + t.Fatalf("mkdir staged dir: %v", err) + } + if err := os.WriteFile(stagePath, []byte(" \n\t"), 0o600); err != nil { + t.Fatalf("write staged file: %v", err) + } + if err := maybeApplyPBSDatastoreMountGuards(ctx, logger, plan, stageRoot, "/", false); err != nil { + t.Fatalf("empty staged content error: %v", err) + } +} + +func TestMaybeApplyPBSDatastoreMountGuards_NoBlocks(t *testing.T) { + logger := newTestLogger() + ctx := context.Background() + plan := &RestorePlan{SystemType: SystemTypePBS, NormalCategories: []Category{{ID: "datastore_pbs"}}} + + origGeteuid := mountGuardGeteuid + t.Cleanup(func() { mountGuardGeteuid = origGeteuid }) + mountGuardGeteuid = func() int { return 0 } + + stageRoot := t.TempDir() + stagePath := filepath.Join(stageRoot, "etc/proxmox-backup/datastore.cfg") + if err := os.MkdirAll(filepath.Dir(stagePath), 0o755); err != nil { + t.Fatalf("mkdir staged dir: %v", err) + } + if err := os.WriteFile(stagePath, []byte("# comment only\n"), 0o600); err != nil { + t.Fatalf("write datastore.cfg: %v", err) + } + + if err := maybeApplyPBSDatastoreMountGuards(ctx, logger, plan, stageRoot, "/", false); err != nil { + t.Fatalf("maybeApplyPBSDatastoreMountGuards error: %v", err) + } +} + +func TestMaybeApplyPBSDatastoreMountGuards_FstabParseErrorContinues(t *testing.T) { + logger := newTestLogger() + ctx := context.Background() + plan := &RestorePlan{SystemType: SystemTypePBS, NormalCategories: []Category{{ID: "datastore_pbs"}}} + + origGeteuid := mountGuardGeteuid + origFstab := mountGuardFstabMountpointsSet + origMkdirAll := mountGuardMkdirAll + origRootFS := mountGuardIsPathOnRootFilesystem + t.Cleanup(func() { + mountGuardGeteuid = origGeteuid + mountGuardFstabMountpointsSet = origFstab + mountGuardMkdirAll = origMkdirAll + mountGuardIsPathOnRootFilesystem = origRootFS + }) + + mountGuardGeteuid = func() int { return 0 } + mountGuardFstabMountpointsSet = func(string) (map[string]struct{}, error) { + return nil, errors.New("fstab parse failed") + } + mountGuardMkdirAll = func(string, os.FileMode) error { return nil } + mountGuardIsPathOnRootFilesystem = func(path string) (bool, string, error) { + return false, filepath.Clean(path), nil + } + + stageRoot := t.TempDir() + stagePath := filepath.Join(stageRoot, "etc/proxmox-backup/datastore.cfg") + if err := os.MkdirAll(filepath.Dir(stagePath), 0o755); err != nil { + t.Fatalf("mkdir staged dir: %v", err) + } + if err := os.WriteFile(stagePath, []byte("datastore: ds\n path /mnt/test/store\n"), 0o600); err != nil { + t.Fatalf("write datastore.cfg: %v", err) + } + + if err := maybeApplyPBSDatastoreMountGuards(ctx, logger, plan, stageRoot, "/", false); err != nil { + t.Fatalf("maybeApplyPBSDatastoreMountGuards error: %v", err) + } +} + +func TestMaybeApplyPBSDatastoreMountGuards_ParseBlocksError(t *testing.T) { + logger := newTestLogger() + ctx := context.Background() + plan := &RestorePlan{SystemType: SystemTypePBS, NormalCategories: []Category{{ID: "datastore_pbs"}}} + + origGeteuid := mountGuardGeteuid + origParse := mountGuardParsePBSDatastoreCfg + t.Cleanup(func() { + mountGuardGeteuid = origGeteuid + mountGuardParsePBSDatastoreCfg = origParse + }) + + mountGuardGeteuid = func() int { return 0 } + wantErr := errors.New("parse blocks failed") + mountGuardParsePBSDatastoreCfg = func(string) ([]pbsDatastoreBlock, error) { + return nil, wantErr + } + + stageRoot := t.TempDir() + stagePath := filepath.Join(stageRoot, "etc/proxmox-backup/datastore.cfg") + if err := os.MkdirAll(filepath.Dir(stagePath), 0o755); err != nil { + t.Fatalf("mkdir staged dir: %v", err) + } + if err := os.WriteFile(stagePath, []byte("datastore: ds\n path /mnt/test/store\n"), 0o600); err != nil { + t.Fatalf("write datastore.cfg: %v", err) + } + + if err := maybeApplyPBSDatastoreMountGuards(ctx, logger, plan, stageRoot, "/", false); !errors.Is(err, wantErr) { + t.Fatalf("expected parse error %v, got %v", wantErr, err) + } +} + +func TestMaybeApplyPBSDatastoreMountGuards_FullFlow(t *testing.T) { + logger := newTestLogger() + ctx := context.Background() + plan := &RestorePlan{SystemType: SystemTypePBS, NormalCategories: []Category{{ID: "datastore_pbs"}}} + + origGeteuid := mountGuardGeteuid + origReadFile := mountGuardReadFile + origMkdirAll := mountGuardMkdirAll + origReadDir := mountGuardReadDir + origMount := mountGuardSysMount + origUnmount := mountGuardSysUnmount + origFstab := mountGuardFstabMountpointsSet + origRootFS := mountGuardIsPathOnRootFilesystem + origCmd := restoreCmd + t.Cleanup(func() { + mountGuardGeteuid = origGeteuid + mountGuardReadFile = origReadFile + mountGuardMkdirAll = origMkdirAll + mountGuardReadDir = origReadDir + mountGuardSysMount = origMount + mountGuardSysUnmount = origUnmount + mountGuardFstabMountpointsSet = origFstab + mountGuardIsPathOnRootFilesystem = origRootFS + restoreCmd = origCmd + }) + + mountGuardGeteuid = func() int { return 0 } + + stageRoot := t.TempDir() + stagePath := filepath.Join(stageRoot, "etc/proxmox-backup/datastore.cfg") + if err := os.MkdirAll(filepath.Dir(stagePath), 0o755); err != nil { + t.Fatalf("mkdir staged dir: %v", err) + } + cfg := strings.Join([]string{ + "datastore: ds-chattrsuccess", + " path /mnt/chattrsuccess/store", + "datastore: ds-invalid", + " path /", + "datastore: ds-nomountstyle", + " path /srv/pbs", + "datastore: ds-storage", + " path /mnt/storage/pbs/ds1/data", + "datastore: ds-media-skip-fstab", + " path /media/USB/PBS", + "datastore: ds-mkdirerr", + " path /mnt/mkdirerr/store", + "datastore: ds-deverr", + " path /mnt/deverr/store", + "datastore: ds-notroot", + " path /mnt/notroot/store", + "datastore: ds-mounted", + " path /mnt/mounted/store", + "datastore: ds-mountok", + " path /mnt/mountok/store", + "datastore: ds-mountok2", + " path /mnt/mountok2/store", + "datastore: ds-chattrfail", + " path /mnt/chattrfail/store", + "datastore: ds-guardok", + " path /mnt/guardok/store", + "datastore: ds-guarddup", + " path /mnt/guardok/other", + "", + }, "\n") + if err := os.WriteFile(stagePath, []byte(cfg), 0o600); err != nil { + t.Fatalf("write datastore.cfg: %v", err) + } + + mountGuardFstabMountpointsSet = func(string) (map[string]struct{}, error) { + return map[string]struct{}{ + "/": {}, + "/srv": {}, + "/mnt/storage": {}, + "/mnt/storage/pbs": {}, + "/mnt/mkdirerr": {}, + "/mnt/deverr": {}, + "/mnt/notroot": {}, + "/mnt/mounted": {}, + "/mnt/mountok": {}, + "/mnt/mountok2": {}, + "/mnt/chattrfail": {}, + "/mnt/chattrsuccess": {}, + "/mnt/guardok": {}, + }, nil + } + + mountGuardMkdirAll = func(path string, _ os.FileMode) error { + if filepath.Clean(path) == "/mnt/mkdirerr" { + return errors.New("mkdir denied") + } + return nil + } + + rootCalls := make(map[string]int) + mountGuardIsPathOnRootFilesystem = func(path string) (bool, string, error) { + path = filepath.Clean(path) + rootCalls[path]++ + switch path { + case "/mnt/deverr": + return false, path, errors.New("stat failed") + case "/mnt/notroot": + return false, path, nil + case "/mnt/mountok": + if rootCalls[path] == 1 { + return true, path, nil + } + return false, path, nil + default: + return true, path, nil + } + } + + mountedTargets := map[string]struct{}{ + "/mnt/mounted": {}, + } + mountinfoReads := 0 + mountsReads := 0 + buildMountinfo := func() string { + var b strings.Builder + for mp := range mountedTargets { + b.WriteString(fmt.Sprintf("1 2 3:4 / %s rw - ext4 /dev/sda1 rw\n", mp)) + } + return b.String() + } + buildProcMounts := func() string { + var b strings.Builder + for mp := range mountedTargets { + b.WriteString(fmt.Sprintf("/dev/sda1 %s ext4 rw 0 0\n", mp)) + } + return b.String() + } + mountGuardReadFile = func(path string) ([]byte, error) { + switch path { + case "/proc/self/mountinfo": + mountinfoReads++ + if mountinfoReads == 1 { + return nil, errors.New("mountinfo read failed") + } + return []byte(buildMountinfo()), nil + case "/proc/mounts": + mountsReads++ + if mountsReads == 1 { + return nil, errors.New("mounts read failed") + } + return []byte(buildProcMounts()), nil + default: + return nil, fmt.Errorf("unexpected read: %s", path) + } + } + + mountGuardReadDir = func(path string) ([]os.DirEntry, error) { + if filepath.Clean(path) == "/mnt/guardok" { + return []os.DirEntry{&fakeDirEntry{name: "nonempty"}}, nil + } + return nil, os.ErrNotExist + } + + mountGuardSysMount = func(_, target, _ string, _ uintptr, _ string) error { + switch filepath.Clean(target) { + case "/mnt/chattrfail", "/mnt/chattrsuccess": + return syscall.EPERM + default: + return nil + } + } + mountGuardSysUnmount = func(string, int) error { return nil } + + cmd := &mountGuardCommandRunner{} + cmd.run = func(_ context.Context, name string, args ...string) ([]byte, error) { + switch name { + case "mount": + if len(args) != 1 { + return nil, fmt.Errorf("unexpected mount args: %v", args) + } + target := filepath.Clean(args[0]) + switch target { + case "/mnt/mountok", "/mnt/mountok2": + if target == "/mnt/mountok2" { + mountedTargets[target] = struct{}{} + } + return nil, nil + case "/mnt/chattrfail": + return []byte(" \n\t"), errors.New("mount failed") + default: + return []byte("mount: failed"), errors.New("mount failed") + } + case "chattr": + if len(args) != 2 || args[0] != "+i" { + return nil, fmt.Errorf("unexpected chattr args: %v", args) + } + target := filepath.Clean(args[1]) + if target == "/mnt/chattrfail" { + return nil, errors.New("chattr failed") + } + return nil, nil + default: + return nil, fmt.Errorf("unexpected command: %s", name) + } + } + restoreCmd = cmd + + if err := maybeApplyPBSDatastoreMountGuards(ctx, logger, plan, stageRoot, "/", false); err != nil { + t.Fatalf("maybeApplyPBSDatastoreMountGuards error: %v", err) + } + + // Ensure the longest fstab mountpoint match wins (/mnt/storage/pbs instead of /mnt/storage). + foundStoragePBS := false + for _, c := range cmd.calls { + if c.name == "mount" && len(c.args) == 1 && filepath.Clean(c.args[0]) == "/mnt/storage/pbs" { + foundStoragePBS = true + break + } + } + if !foundStoragePBS { + t.Fatalf("expected mount attempt for /mnt/storage/pbs, calls=%#v", cmd.calls) + } +} + +func TestMaybeApplyPBSDatastoreMountGuards_MountAttemptTimeout(t *testing.T) { + logger := newTestLogger() + baseCtx, cancel := context.WithDeadline(context.Background(), time.Now().Add(-1*time.Second)) + t.Cleanup(cancel) + plan := &RestorePlan{SystemType: SystemTypePBS, NormalCategories: []Category{{ID: "datastore_pbs"}}} + + origGeteuid := mountGuardGeteuid + origReadFile := mountGuardReadFile + origMkdirAll := mountGuardMkdirAll + origMount := mountGuardSysMount + origUnmount := mountGuardSysUnmount + origFstab := mountGuardFstabMountpointsSet + origRootFS := mountGuardIsPathOnRootFilesystem + origCmd := restoreCmd + t.Cleanup(func() { + mountGuardGeteuid = origGeteuid + mountGuardReadFile = origReadFile + mountGuardMkdirAll = origMkdirAll + mountGuardSysMount = origMount + mountGuardSysUnmount = origUnmount + mountGuardFstabMountpointsSet = origFstab + mountGuardIsPathOnRootFilesystem = origRootFS + restoreCmd = origCmd + }) + + mountGuardGeteuid = func() int { return 0 } + mountGuardReadFile = func(string) ([]byte, error) { return []byte(""), nil } + mountGuardMkdirAll = func(string, os.FileMode) error { return nil } + mountGuardSysMount = func(string, string, string, uintptr, string) error { return nil } + mountGuardSysUnmount = func(string, int) error { return nil } + mountGuardIsPathOnRootFilesystem = func(path string) (bool, string, error) { return true, filepath.Clean(path), nil } + mountGuardFstabMountpointsSet = func(string) (map[string]struct{}, error) { + return map[string]struct{}{"/mnt/timeout": {}}, nil + } + + stageRoot := t.TempDir() + stagePath := filepath.Join(stageRoot, "etc/proxmox-backup/datastore.cfg") + if err := os.MkdirAll(filepath.Dir(stagePath), 0o755); err != nil { + t.Fatalf("mkdir staged dir: %v", err) + } + if err := os.WriteFile(stagePath, []byte("datastore: ds\n path /mnt/timeout/store\n"), 0o600); err != nil { + t.Fatalf("write datastore.cfg: %v", err) + } + + restoreCmd = &mountGuardCommandRunner{ + run: func(ctx context.Context, name string, args ...string) ([]byte, error) { + if name == "mount" { + return nil, ctx.Err() + } + if name == "chattr" { + return nil, ctx.Err() + } + return nil, fmt.Errorf("unexpected command: %s", name) + }, + } + + if err := maybeApplyPBSDatastoreMountGuards(baseCtx, logger, plan, stageRoot, "/", false); err != nil { + t.Fatalf("maybeApplyPBSDatastoreMountGuards error: %v", err) + } +} diff --git a/internal/orchestrator/pbs_mount_guard_test.go b/internal/orchestrator/pbs_mount_guard_test.go index f456170..a9efbc1 100644 --- a/internal/orchestrator/pbs_mount_guard_test.go +++ b/internal/orchestrator/pbs_mount_guard_test.go @@ -13,6 +13,7 @@ func TestPBSMountGuardRootForDatastorePath(t *testing.T) { {name: "mnt nested", in: "/mnt/datastore/Data1", want: "/mnt/datastore"}, {name: "mnt deep", in: "/mnt/Synology_NFS/PBS_Backup", want: "/mnt/Synology_NFS"}, {name: "media", in: "/media/USB/PBS", want: "/media/USB"}, + {name: "run media root", in: "/run/media/root", want: "/run/media/root"}, {name: "run media", in: "/run/media/root/USB/PBS", want: "/run/media/root/USB"}, {name: "not mount style", in: "/srv/pbs", want: ""}, {name: "empty", in: "", want: ""}, From e5774d13f71565f979b1499c12953ef5824d3797 Mon Sep 17 00:00:00 2001 From: Damiano <71268257+tis24dev@users.noreply.github.com> Date: Mon, 23 Feb 2026 15:02:32 +0100 Subject: [PATCH 04/12] Inject Geteuid for PBS API to enable tests Introduce a pbsAPIApplyGeteuid variable (defaulting to os.Geteuid) and use it for the root-privilege check in ensurePBSServicesForAPI to allow overriding in tests. Add a comprehensive test suite (internal/orchestrator/pbs_api_apply_test.go) that exercises PBS API apply functions, error paths, and service checks using fake filesystem and command runner mocks. --- internal/orchestrator/pbs_api_apply.go | 4 +- internal/orchestrator/pbs_api_apply_test.go | 1276 +++++++++++++++++++ 2 files changed, 1279 insertions(+), 1 deletion(-) create mode 100644 internal/orchestrator/pbs_api_apply_test.go diff --git a/internal/orchestrator/pbs_api_apply.go b/internal/orchestrator/pbs_api_apply.go index 830bcc4..9851866 100644 --- a/internal/orchestrator/pbs_api_apply.go +++ b/internal/orchestrator/pbs_api_apply.go @@ -12,6 +12,8 @@ import ( "github.com/tis24dev/proxsave/internal/logging" ) +var pbsAPIApplyGeteuid = os.Geteuid + func normalizeProxmoxCfgKey(key string) string { key = strings.ToLower(strings.TrimSpace(key)) key = strings.ReplaceAll(key, "_", "-") @@ -176,7 +178,7 @@ func ensurePBSServicesForAPI(ctx context.Context, logger *logging.Logger) error if !isRealRestoreFS(restoreFS) { return fmt.Errorf("non-system filesystem in use") } - if os.Geteuid() != 0 { + if pbsAPIApplyGeteuid() != 0 { return fmt.Errorf("requires root privileges") } diff --git a/internal/orchestrator/pbs_api_apply_test.go b/internal/orchestrator/pbs_api_apply_test.go new file mode 100644 index 0000000..f9b3b20 --- /dev/null +++ b/internal/orchestrator/pbs_api_apply_test.go @@ -0,0 +1,1276 @@ +package orchestrator + +import ( + "context" + "errors" + "os" + "reflect" + "strings" + "testing" + + "github.com/tis24dev/proxsave/internal/logging" + "github.com/tis24dev/proxsave/internal/types" +) + +func setupPBSAPIApplyTestDeps(t *testing.T) (stageRoot string, fs *FakeFS, runner *fakeCommandRunner) { + t.Helper() + + origCmd := restoreCmd + origFS := restoreFS + t.Cleanup(func() { + restoreCmd = origCmd + restoreFS = origFS + }) + + fs = NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fs.Root) }) + restoreFS = fs + + runner = &fakeCommandRunner{} + restoreCmd = runner + + return "/stage", fs, runner +} + +func writeStageFile(t *testing.T, fs *FakeFS, stageRoot, relPath, content string, perm os.FileMode) { + t.Helper() + if err := fs.WriteFile(stageRoot+"/"+relPath, []byte(content), perm); err != nil { + t.Fatalf("write staged %s: %v", relPath, err) + } +} + +func TestNormalizeProxmoxCfgKey(t *testing.T) { + tests := []struct { + in string + want string + }{ + {in: " Foo_Bar ", want: "foo-bar"}, + {in: "dns1", want: "dns1"}, + {in: "", want: ""}, + {in: " ", want: ""}, + } + for _, tt := range tests { + if got := normalizeProxmoxCfgKey(tt.in); got != tt.want { + t.Fatalf("normalizeProxmoxCfgKey(%q)=%q want %q", tt.in, got, tt.want) + } + } +} + +func TestBuildProxmoxManagerFlags_SkipsAndNormalizes(t *testing.T) { + entries := []proxmoxNotificationEntry{ + {Key: "HOST", Value: " pbs.example "}, + {Key: "digest", Value: "abc"}, + {Key: "name", Value: "ignored"}, + {Key: "foo_bar", Value: "baz"}, + {Key: "skip_me", Value: "nope"}, + {Key: "", Value: "x"}, + } + got := buildProxmoxManagerFlags(entries, "SKIP_ME") + want := []string{"--host", "pbs.example", "--foo-bar", "baz"} + if !reflect.DeepEqual(got, want) { + t.Fatalf("flags=%v want %v", got, want) + } +} + +func TestBuildProxmoxManagerFlags_EmptyReturnsNil(t *testing.T) { + if got := buildProxmoxManagerFlags(nil); got != nil { + t.Fatalf("expected nil, got %v", got) + } + if got := buildProxmoxManagerFlags([]proxmoxNotificationEntry{}); got != nil { + t.Fatalf("expected nil, got %v", got) + } +} + +func TestPopEntryValue_FirstMatchOnly(t *testing.T) { + entries := []proxmoxNotificationEntry{ + {Key: "Foo_Bar", Value: " first "}, + {Key: "foo-bar", Value: "second"}, + {Key: "x", Value: "y"}, + } + value, remaining, ok := popEntryValue(entries, "foo-bar") + if !ok || value != "first" { + t.Fatalf("ok=%v value=%q want ok=true value=first", ok, value) + } + wantRemaining := []proxmoxNotificationEntry{ + {Key: "foo-bar", Value: "second"}, + {Key: "x", Value: "y"}, + } + if !reflect.DeepEqual(remaining, wantRemaining) { + t.Fatalf("remaining=%v want %v", remaining, wantRemaining) + } +} + +func TestPopEntryValue_NoKeysOrEntries(t *testing.T) { + value, remaining, ok := popEntryValue(nil, "k") + if ok || value != "" || remaining != nil { + t.Fatalf("ok=%v value=%q remaining=%v want ok=false value=\"\" remaining=nil", ok, value, remaining) + } + + entries := []proxmoxNotificationEntry{{Key: "k", Value: "v"}} + value, remaining, ok = popEntryValue(entries) + if ok || value != "" || !reflect.DeepEqual(remaining, entries) { + t.Fatalf("ok=%v value=%q remaining=%v want ok=false value=\"\" remaining=%v", ok, value, remaining, entries) + } +} + +func TestRunPBSManagerRedacted_RedactsFlagsAndIndexes(t *testing.T) { + orig := restoreCmd + t.Cleanup(func() { restoreCmd = orig }) + + runner := &fakeCommandRunner{ + outputs: map[string][]byte{}, + errs: map[string]error{}, + } + restoreCmd = runner + + args := []string{"remote", "create", "r1", "--password", "secret123", "--token", "token-value-123"} + key := "proxmox-backup-manager " + strings.Join(args, " ") + runner.outputs[key] = []byte("boom") + runner.errs[key] = errors.New("exit 1") + + out, err := runPBSManagerRedacted(context.Background(), args, []string{"--password"}, []int{6}) + if string(out) != "boom" { + t.Fatalf("out=%q want %q", string(out), "boom") + } + if err == nil { + t.Fatalf("expected error") + } + msg := err.Error() + if strings.Contains(msg, "secret123") || strings.Contains(msg, "token-value-123") { + t.Fatalf("expected secrets to be redacted, got: %s", msg) + } + if !strings.Contains(msg, "") { + t.Fatalf("expected in error, got: %s", msg) + } +} + +func TestUnwrapPBSJSONData(t *testing.T) { + if got := unwrapPBSJSONData([]byte(" ")); got != nil { + t.Fatalf("expected nil for empty input, got %q", string(got)) + } + + got := string(unwrapPBSJSONData([]byte(" not-json "))) + if got != "not-json" { + t.Fatalf("got %q want %q", got, "not-json") + } + + got = string(unwrapPBSJSONData([]byte(`{"data":[{"id":"a"}]}`))) + if got != `[{"id":"a"}]` { + t.Fatalf("got %q want %q", got, `[{"id":"a"}]`) + } + + got = string(unwrapPBSJSONData([]byte(`{"foo":"bar"}`))) + if got != `{"foo":"bar"}` { + t.Fatalf("got %q want %q", got, `{"foo":"bar"}`) + } +} + +func TestParsePBSListIDs(t *testing.T) { + if _, err := parsePBSListIDs([]byte(`[{"id":"a"}]`)); err == nil { + t.Fatalf("expected error for missing candidate keys") + } + + ids, err := parsePBSListIDs([]byte(" "), "id") + if err != nil || ids != nil { + t.Fatalf("ids=%v err=%v want nil,nil", ids, err) + } + + ids, err = parsePBSListIDs([]byte(`{"data":[{"id":"b"},{"id":"a"},{"id":"a"}]}`), "id") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !reflect.DeepEqual(ids, []string{"a", "b"}) { + t.Fatalf("ids=%v want %v", ids, []string{"a", "b"}) + } + + _, err = parsePBSListIDs([]byte(`{"data":[{"id":123,"name":""}]}`), "id", "name") + if err == nil || !strings.Contains(err.Error(), "failed to parse PBS list row 0") { + t.Fatalf("expected row parse error, got: %v", err) + } + + ids, err = parsePBSListIDs([]byte(`[{"name":"x"}]`), "id", "name") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !reflect.DeepEqual(ids, []string{"x"}) { + t.Fatalf("ids=%v want %v", ids, []string{"x"}) + } + + ids, err = parsePBSListIDs([]byte(`{"data":[{"id":"a"}]}`), " id ", " ", "name") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !reflect.DeepEqual(ids, []string{"a"}) { + t.Fatalf("ids=%v want %v", ids, []string{"a"}) + } + + if _, err := parsePBSListIDs([]byte(`{"data":{}}`), "id"); err == nil { + t.Fatalf("expected unmarshal error for non-array data") + } +} + +func TestPBSAPIApply_ReadStageFileOptionalErrors(t *testing.T) { + logger := logging.New(types.LogLevelDebug, false) + + type applyFunc struct { + name string + relPath string + apply func(context.Context, *logging.Logger, string, bool) error + } + + for _, tt := range []applyFunc{ + {name: "remote", relPath: "etc/proxmox-backup/remote.cfg", apply: applyPBSRemoteCfgViaAPI}, + {name: "s3", relPath: "etc/proxmox-backup/s3.cfg", apply: applyPBSS3CfgViaAPI}, + {name: "datastore", relPath: "etc/proxmox-backup/datastore.cfg", apply: applyPBSDatastoreCfgViaAPI}, + {name: "sync", relPath: "etc/proxmox-backup/sync.cfg", apply: applyPBSSyncCfgViaAPI}, + {name: "verification", relPath: "etc/proxmox-backup/verification.cfg", apply: applyPBSVerificationCfgViaAPI}, + {name: "prune", relPath: "etc/proxmox-backup/prune.cfg", apply: applyPBSPruneCfgViaAPI}, + {name: "traffic-control", relPath: "etc/proxmox-backup/traffic-control.cfg", apply: applyPBSTrafficControlCfgViaAPI}, + } { + tt := tt + t.Run(tt.name, func(t *testing.T) { + stageRoot, fs, _ := setupPBSAPIApplyTestDeps(t) + if err := fs.AddDir(stageRoot + "/" + tt.relPath); err != nil { + t.Fatalf("create staged dir: %v", err) + } + err := tt.apply(context.Background(), logger, stageRoot, false) + if err == nil || !strings.Contains(err.Error(), "read staged") { + t.Fatalf("expected read staged error, got: %v", err) + } + }) + } + + t.Run("node", func(t *testing.T) { + stageRoot, fs, _ := setupPBSAPIApplyTestDeps(t) + if err := fs.AddDir(stageRoot + "/etc/proxmox-backup/node.cfg"); err != nil { + t.Fatalf("create staged dir: %v", err) + } + err := applyPBSNodeCfgViaAPI(context.Background(), stageRoot) + if err == nil || !strings.Contains(err.Error(), "read staged") { + t.Fatalf("expected read staged error, got: %v", err) + } + }) +} + +func TestPBSAPIApply_NoFileBranches(t *testing.T) { + logger := logging.New(types.LogLevelDebug, false) + + type applyFunc struct { + name string + apply func(context.Context, *logging.Logger, string, bool) error + } + + for _, tt := range []applyFunc{ + {name: "datastore", apply: applyPBSDatastoreCfgViaAPI}, + {name: "sync", apply: applyPBSSyncCfgViaAPI}, + {name: "verification", apply: applyPBSVerificationCfgViaAPI}, + {name: "prune", apply: applyPBSPruneCfgViaAPI}, + {name: "traffic-control", apply: applyPBSTrafficControlCfgViaAPI}, + } { + tt := tt + t.Run(tt.name, func(t *testing.T) { + stageRoot, _, runner := setupPBSAPIApplyTestDeps(t) + if err := tt.apply(context.Background(), logger, stageRoot, true); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(runner.calls) != 0 { + t.Fatalf("expected no calls, got %v", runner.calls) + } + }) + } +} + +func TestPBSAPIApply_StrictListCommandErrors(t *testing.T) { + logger := logging.New(types.LogLevelDebug, false) + + tests := []struct { + name string + relPath string + content string + listCmd string + apply func(context.Context, *logging.Logger, string, bool) error + }{ + { + name: "remote", + relPath: "etc/proxmox-backup/remote.cfg", + content: "remote: r1\n host pbs.example\n", + listCmd: "proxmox-backup-manager remote list --output-format=json", + apply: applyPBSRemoteCfgViaAPI, + }, + { + name: "s3", + relPath: "etc/proxmox-backup/s3.cfg", + content: "s3: e1\n endpoint https://s3.example\n", + listCmd: "proxmox-backup-manager s3 endpoint list --output-format=json", + apply: applyPBSS3CfgViaAPI, + }, + { + name: "sync", + relPath: "etc/proxmox-backup/sync.cfg", + content: "sync-job: job1\n remote r1\n store ds1\n", + listCmd: "proxmox-backup-manager sync-job list --output-format=json", + apply: applyPBSSyncCfgViaAPI, + }, + { + name: "verification", + relPath: "etc/proxmox-backup/verification.cfg", + content: "verify-job: v1\n store ds1\n", + listCmd: "proxmox-backup-manager verify-job list --output-format=json", + apply: applyPBSVerificationCfgViaAPI, + }, + { + name: "prune", + relPath: "etc/proxmox-backup/prune.cfg", + content: "prune-job: p1\n store ds1\n", + listCmd: "proxmox-backup-manager prune-job list --output-format=json", + apply: applyPBSPruneCfgViaAPI, + }, + { + name: "traffic-control", + relPath: "etc/proxmox-backup/traffic-control.cfg", + content: "traffic-control: tc1\n rate-in 1000\n", + listCmd: "proxmox-backup-manager traffic-control list --output-format=json", + apply: applyPBSTrafficControlCfgViaAPI, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + stageRoot, fs, runner := setupPBSAPIApplyTestDeps(t) + writeStageFile(t, fs, stageRoot, tt.relPath, tt.content, 0o640) + runner.errs = map[string]error{tt.listCmd: errors.New("boom")} + if err := tt.apply(context.Background(), logger, stageRoot, true); err == nil { + t.Fatalf("expected error") + } + }) + } +} + +func TestPBSAPIApply_StrictListParseErrors(t *testing.T) { + logger := logging.New(types.LogLevelDebug, false) + + tests := []struct { + name string + relPath string + content string + listCmd string + wantErrPrefix string + apply func(context.Context, *logging.Logger, string, bool) error + }{ + { + name: "remote", + relPath: "etc/proxmox-backup/remote.cfg", + content: "remote: r1\n host pbs.example\n", + listCmd: "proxmox-backup-manager remote list --output-format=json", + wantErrPrefix: "parse remote list:", + apply: applyPBSRemoteCfgViaAPI, + }, + { + name: "s3", + relPath: "etc/proxmox-backup/s3.cfg", + content: "s3: e1\n endpoint https://s3.example\n", + listCmd: "proxmox-backup-manager s3 endpoint list --output-format=json", + wantErrPrefix: "parse s3 endpoint list:", + apply: applyPBSS3CfgViaAPI, + }, + { + name: "sync", + relPath: "etc/proxmox-backup/sync.cfg", + content: "sync-job: job1\n remote r1\n store ds1\n", + listCmd: "proxmox-backup-manager sync-job list --output-format=json", + wantErrPrefix: "parse sync-job list:", + apply: applyPBSSyncCfgViaAPI, + }, + { + name: "verification", + relPath: "etc/proxmox-backup/verification.cfg", + content: "verify-job: v1\n store ds1\n", + listCmd: "proxmox-backup-manager verify-job list --output-format=json", + wantErrPrefix: "parse verify-job list:", + apply: applyPBSVerificationCfgViaAPI, + }, + { + name: "prune", + relPath: "etc/proxmox-backup/prune.cfg", + content: "prune-job: p1\n store ds1\n", + listCmd: "proxmox-backup-manager prune-job list --output-format=json", + wantErrPrefix: "parse prune-job list:", + apply: applyPBSPruneCfgViaAPI, + }, + { + name: "traffic-control", + relPath: "etc/proxmox-backup/traffic-control.cfg", + content: "traffic-control: tc1\n rate-in 1000\n", + listCmd: "proxmox-backup-manager traffic-control list --output-format=json", + wantErrPrefix: "parse traffic-control list:", + apply: applyPBSTrafficControlCfgViaAPI, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + stageRoot, fs, runner := setupPBSAPIApplyTestDeps(t) + writeStageFile(t, fs, stageRoot, tt.relPath, tt.content, 0o640) + runner.outputs = map[string][]byte{tt.listCmd: []byte("not-json")} + if err := tt.apply(context.Background(), logger, stageRoot, true); err == nil || !strings.Contains(err.Error(), tt.wantErrPrefix) { + t.Fatalf("expected %q error, got: %v", tt.wantErrPrefix, err) + } + }) + } +} + +func TestPBSAPIApply_CreateUpdateBothFailErrors(t *testing.T) { + logger := logging.New(types.LogLevelDebug, false) + + tests := []struct { + name string + relPath string + content string + create string + update string + apply func(context.Context, *logging.Logger, string, bool) error + }{ + { + name: "sync", + relPath: "etc/proxmox-backup/sync.cfg", + content: "sync-job: job1\n remote r1\n store ds1\n", + create: "proxmox-backup-manager sync-job create job1 --remote r1 --store ds1", + update: "proxmox-backup-manager sync-job update job1 --remote r1 --store ds1", + apply: applyPBSSyncCfgViaAPI, + }, + { + name: "verification", + relPath: "etc/proxmox-backup/verification.cfg", + content: "verify-job: v1\n store ds1\n", + create: "proxmox-backup-manager verify-job create v1 --store ds1", + update: "proxmox-backup-manager verify-job update v1 --store ds1", + apply: applyPBSVerificationCfgViaAPI, + }, + { + name: "prune", + relPath: "etc/proxmox-backup/prune.cfg", + content: "prune-job: p1\n store ds1\n keep-last 3\n", + create: "proxmox-backup-manager prune-job create p1 --store ds1 --keep-last 3", + update: "proxmox-backup-manager prune-job update p1 --store ds1 --keep-last 3", + apply: applyPBSPruneCfgViaAPI, + }, + { + name: "traffic-control", + relPath: "etc/proxmox-backup/traffic-control.cfg", + content: "traffic-control: tc1\n rate-in 1000\n", + create: "proxmox-backup-manager traffic-control create tc1 --rate-in 1000", + update: "proxmox-backup-manager traffic-control update tc1 --rate-in 1000", + apply: applyPBSTrafficControlCfgViaAPI, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + stageRoot, fs, runner := setupPBSAPIApplyTestDeps(t) + writeStageFile(t, fs, stageRoot, tt.relPath, tt.content, 0o640) + runner.errs = map[string]error{ + tt.create: errors.New("create failed"), + tt.update: errors.New("update failed"), + } + err := tt.apply(context.Background(), logger, stageRoot, false) + if err == nil { + t.Fatalf("expected error") + } + if !strings.Contains(err.Error(), "create failed") || !strings.Contains(err.Error(), "update failed") { + t.Fatalf("expected create/update errors, got: %s", err.Error()) + } + }) + } +} + +func TestEnsurePBSServicesForAPI(t *testing.T) { + logger := logging.New(types.LogLevelDebug, false) + + t.Run("non-system-fs", func(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + restoreFS = NewFakeFS() + if err := ensurePBSServicesForAPI(context.Background(), logger); err == nil || !strings.Contains(err.Error(), "non-system filesystem") { + t.Fatalf("expected non-system filesystem error, got: %v", err) + } + }) + + t.Run("non-root", func(t *testing.T) { + origFS := restoreFS + origCmd := restoreCmd + origGeteuid := pbsAPIApplyGeteuid + t.Cleanup(func() { + restoreFS = origFS + restoreCmd = origCmd + pbsAPIApplyGeteuid = origGeteuid + }) + restoreFS = osFS{} + restoreCmd = &fakeCommandRunner{} + pbsAPIApplyGeteuid = func() int { return 1000 } + + if err := ensurePBSServicesForAPI(context.Background(), logger); err == nil || !strings.Contains(err.Error(), "requires root privileges") { + t.Fatalf("expected root privileges error, got: %v", err) + } + }) + + t.Run("pbs-manager-missing", func(t *testing.T) { + origFS := restoreFS + origCmd := restoreCmd + origGeteuid := pbsAPIApplyGeteuid + t.Cleanup(func() { + restoreFS = origFS + restoreCmd = origCmd + pbsAPIApplyGeteuid = origGeteuid + }) + restoreFS = osFS{} + + runner := &fakeCommandRunner{ + errs: map[string]error{ + "proxmox-backup-manager version": errors.New("not found"), + }, + } + restoreCmd = runner + pbsAPIApplyGeteuid = func() int { return 0 } + + if err := ensurePBSServicesForAPI(context.Background(), logger); err == nil || !strings.Contains(err.Error(), "proxmox-backup-manager not available") { + t.Fatalf("expected pbs manager missing error, got: %v", err) + } + }) + + t.Run("start-services-fails", func(t *testing.T) { + origFS := restoreFS + origCmd := restoreCmd + origGeteuid := pbsAPIApplyGeteuid + t.Cleanup(func() { + restoreFS = origFS + restoreCmd = origCmd + pbsAPIApplyGeteuid = origGeteuid + }) + restoreFS = osFS{} + + runner := &fakeCommandRunner{ + errs: map[string]error{ + "which systemctl": errors.New("no systemctl"), + }, + } + restoreCmd = runner + pbsAPIApplyGeteuid = func() int { return 0 } + + if err := ensurePBSServicesForAPI(context.Background(), logger); err == nil || !strings.Contains(err.Error(), "systemctl not available") { + t.Fatalf("expected systemctl not available error, got: %v", err) + } + }) + + t.Run("ok", func(t *testing.T) { + origFS := restoreFS + origCmd := restoreCmd + origGeteuid := pbsAPIApplyGeteuid + t.Cleanup(func() { + restoreFS = origFS + restoreCmd = origCmd + pbsAPIApplyGeteuid = origGeteuid + }) + restoreFS = osFS{} + restoreCmd = &fakeCommandRunner{} + pbsAPIApplyGeteuid = func() int { return 0 } + + if err := ensurePBSServicesForAPI(context.Background(), nil); err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) +} + +func TestApplyPBSRemoteCfgViaAPI_StrictCleanupAndCreate(t *testing.T) { + stageRoot, fs, runner := setupPBSAPIApplyTestDeps(t) + logger := logging.New(types.LogLevelDebug, false) + + writeStageFile(t, fs, stageRoot, "etc/proxmox-backup/remote.cfg", + "remote: r1\n"+ + " HOST pbs1.example\n"+ + " password secret1\n"+ + " foo_bar baz\n"+ + " digest abc\n"+ + " name ignored\n"+ + "\n"+ + "remote: r2\n"+ + " host pbs2.example\n"+ + " username admin\n", + 0o640, + ) + + runner.outputs = map[string][]byte{ + "proxmox-backup-manager remote list --output-format=json": []byte(`{"data":[{"id":"r1"},{"id":"old"}]}`), + } + runner.errs = map[string]error{ + "proxmox-backup-manager remote remove old": errors.New("cannot remove old"), + } + + if err := applyPBSRemoteCfgViaAPI(context.Background(), logger, stageRoot, true); err != nil { + t.Fatalf("applyPBSRemoteCfgViaAPI error: %v", err) + } + + want := []string{ + "proxmox-backup-manager remote list --output-format=json", + "proxmox-backup-manager remote remove old", + "proxmox-backup-manager remote create r1 --host pbs1.example --password secret1 --foo-bar baz", + "proxmox-backup-manager remote create r2 --host pbs2.example --username admin", + } + if !reflect.DeepEqual(runner.calls, want) { + t.Fatalf("calls=%v want %v", runner.calls, want) + } +} + +func TestApplyPBSRemoteCfgViaAPI_NoFileNoCalls(t *testing.T) { + stageRoot, _, runner := setupPBSAPIApplyTestDeps(t) + logger := logging.New(types.LogLevelDebug, false) + + if err := applyPBSRemoteCfgViaAPI(context.Background(), logger, stageRoot, true); err != nil { + t.Fatalf("applyPBSRemoteCfgViaAPI error: %v", err) + } + if len(runner.calls) != 0 { + t.Fatalf("expected no calls, got %v", runner.calls) + } +} + +func TestApplyPBSRemoteCfgViaAPI_CreateFailsThenUpdate(t *testing.T) { + stageRoot, fs, runner := setupPBSAPIApplyTestDeps(t) + logger := logging.New(types.LogLevelDebug, false) + + writeStageFile(t, fs, stageRoot, "etc/proxmox-backup/remote.cfg", + "remote: r1\n"+ + " host pbs.example\n"+ + " password secret1\n", + 0o640, + ) + + runner.errs = map[string]error{ + "proxmox-backup-manager remote create r1 --host pbs.example --password secret1": errors.New("already exists"), + } + + if err := applyPBSRemoteCfgViaAPI(context.Background(), logger, stageRoot, false); err != nil { + t.Fatalf("applyPBSRemoteCfgViaAPI error: %v", err) + } + + want := []string{ + "proxmox-backup-manager remote create r1 --host pbs.example --password secret1", + "proxmox-backup-manager remote update r1 --host pbs.example --password secret1", + } + if !reflect.DeepEqual(runner.calls, want) { + t.Fatalf("calls=%v want %v", runner.calls, want) + } +} + +func TestApplyPBSRemoteCfgViaAPI_RedactsPasswordOnFailure(t *testing.T) { + stageRoot, fs, runner := setupPBSAPIApplyTestDeps(t) + logger := logging.New(types.LogLevelDebug, false) + + writeStageFile(t, fs, stageRoot, "etc/proxmox-backup/remote.cfg", + "remote: r1\n"+ + " host pbs.example\n"+ + " password secret1\n", + 0o640, + ) + + runner.errs = map[string]error{ + "proxmox-backup-manager remote create r1 --host pbs.example --password secret1": errors.New("create failed"), + "proxmox-backup-manager remote update r1 --host pbs.example --password secret1": errors.New("update failed"), + } + + err := applyPBSRemoteCfgViaAPI(context.Background(), logger, stageRoot, false) + if err == nil { + t.Fatalf("expected error") + } + if strings.Contains(err.Error(), "secret1") { + t.Fatalf("expected password to be redacted, got: %s", err.Error()) + } + if !strings.Contains(err.Error(), "") { + t.Fatalf("expected in error, got: %s", err.Error()) + } +} + +func TestApplyPBSS3CfgViaAPI_CreateUpdateAndStrictCleanup(t *testing.T) { + stageRoot, fs, runner := setupPBSAPIApplyTestDeps(t) + logger := logging.New(types.LogLevelDebug, false) + + writeStageFile(t, fs, stageRoot, "etc/proxmox-backup/s3.cfg", + "s3: e1\n"+ + " endpoint https://s3.example\n"+ + " access-key access1\n"+ + " secret-key secret1\n", + 0o640, + ) + + runner.outputs = map[string][]byte{ + "proxmox-backup-manager s3 endpoint list --output-format=json": []byte(`{"data":[{"id":"e1"},{"id":"old"}]}`), + } + runner.errs = map[string]error{ + "proxmox-backup-manager s3 endpoint remove old": errors.New("cannot remove old"), + "proxmox-backup-manager s3 endpoint create e1 --endpoint https://s3.example --access-key access1 --secret-key secret1": errors.New("already exists"), + } + + if err := applyPBSS3CfgViaAPI(context.Background(), logger, stageRoot, true); err != nil { + t.Fatalf("applyPBSS3CfgViaAPI error: %v", err) + } + + want := []string{ + "proxmox-backup-manager s3 endpoint list --output-format=json", + "proxmox-backup-manager s3 endpoint remove old", + "proxmox-backup-manager s3 endpoint create e1 --endpoint https://s3.example --access-key access1 --secret-key secret1", + "proxmox-backup-manager s3 endpoint update e1 --endpoint https://s3.example --access-key access1 --secret-key secret1", + } + if !reflect.DeepEqual(runner.calls, want) { + t.Fatalf("calls=%v want %v", runner.calls, want) + } +} + +func TestApplyPBSS3CfgViaAPI_NoFileNoCalls(t *testing.T) { + stageRoot, _, runner := setupPBSAPIApplyTestDeps(t) + logger := logging.New(types.LogLevelDebug, false) + + if err := applyPBSS3CfgViaAPI(context.Background(), logger, stageRoot, true); err != nil { + t.Fatalf("applyPBSS3CfgViaAPI error: %v", err) + } + if len(runner.calls) != 0 { + t.Fatalf("expected no calls, got %v", runner.calls) + } +} + +func TestApplyPBSS3CfgViaAPI_RedactsKeysOnFailure(t *testing.T) { + stageRoot, fs, runner := setupPBSAPIApplyTestDeps(t) + logger := logging.New(types.LogLevelDebug, false) + + writeStageFile(t, fs, stageRoot, "etc/proxmox-backup/s3.cfg", + "s3: e1\n"+ + " endpoint https://s3.example\n"+ + " access-key access1\n"+ + " secret-key secret1\n", + 0o640, + ) + + runner.errs = map[string]error{ + "proxmox-backup-manager s3 endpoint create e1 --endpoint https://s3.example --access-key access1 --secret-key secret1": errors.New("create failed"), + "proxmox-backup-manager s3 endpoint update e1 --endpoint https://s3.example --access-key access1 --secret-key secret1": errors.New("update failed"), + } + + err := applyPBSS3CfgViaAPI(context.Background(), logger, stageRoot, false) + if err == nil { + t.Fatalf("expected error") + } + if strings.Contains(err.Error(), "access1") || strings.Contains(err.Error(), "secret1") { + t.Fatalf("expected keys to be redacted, got: %s", err.Error()) + } + if !strings.Contains(err.Error(), "") { + t.Fatalf("expected in error, got: %s", err.Error()) + } +} + +func TestApplyPBSDatastoreCfgViaAPI_StrictFullFlow(t *testing.T) { + stageRoot, fs, runner := setupPBSAPIApplyTestDeps(t) + logger := logging.New(types.LogLevelDebug, false) + + writeStageFile(t, fs, stageRoot, "etc/proxmox-backup/datastore.cfg", + "datastore: ds1\n"+ + " path /new1\n"+ + " comment c1\n"+ + "\n"+ + "datastore: ds2\n"+ + " comment missing-path\n"+ + "\n"+ + "datastore: ds3\n"+ + " path /p3\n"+ + " comment c3\n"+ + "\n"+ + "datastore: ds4\n"+ + " path /p4\n"+ + " comment c4\n", + 0o640, + ) + + runner.outputs = map[string][]byte{ + "proxmox-backup-manager datastore list --output-format=json": []byte(`{"data":[{"name":"ds1","path":"/old1"},{"name":"ds3","path":"/p3"},{"name":"ds-old","path":"/old"}]}`), + } + runner.errs = map[string]error{ + "proxmox-backup-manager datastore create ds4 /p4 --comment c4": errors.New("already exists"), + } + + if err := applyPBSDatastoreCfgViaAPI(context.Background(), logger, stageRoot, true); err != nil { + t.Fatalf("applyPBSDatastoreCfgViaAPI error: %v", err) + } + + want := []string{ + "proxmox-backup-manager datastore list --output-format=json", + "proxmox-backup-manager datastore remove ds-old", + "proxmox-backup-manager datastore remove ds1", + "proxmox-backup-manager datastore create ds1 /new1 --comment c1", + "proxmox-backup-manager datastore update ds3 --comment c3", + "proxmox-backup-manager datastore create ds4 /p4 --comment c4", + "proxmox-backup-manager datastore update ds4 --comment c4", + } + if !reflect.DeepEqual(runner.calls, want) { + t.Fatalf("calls=%v want %v", runner.calls, want) + } +} + +func TestApplyPBSDatastoreCfgViaAPI_CurrentPathsFallbacksAndStrictRemoveWarn(t *testing.T) { + stageRoot, fs, runner := setupPBSAPIApplyTestDeps(t) + logger := logging.New(types.LogLevelDebug, false) + + writeStageFile(t, fs, stageRoot, "etc/proxmox-backup/datastore.cfg", + "datastore: ds1\n"+ + " path /p1\n"+ + " comment c1\n", + 0o640, + ) + + runner.outputs = map[string][]byte{ + "proxmox-backup-manager datastore list --output-format=json": []byte( + `{"data":[` + + `{"store":"store1","path":"/ps"},` + + `{"id":"id1","path":"/pi"},` + + `{"path":"/no-id"},` + + `{"name":"ds1","path":"/p1"}` + + `]}`, + ), + } + runner.errs = map[string]error{ + "proxmox-backup-manager datastore remove id1": errors.New("cannot remove id1"), + } + + if err := applyPBSDatastoreCfgViaAPI(context.Background(), logger, stageRoot, true); err != nil { + t.Fatalf("applyPBSDatastoreCfgViaAPI error: %v", err) + } + + want := []string{ + "proxmox-backup-manager datastore list --output-format=json", + "proxmox-backup-manager datastore remove id1", + "proxmox-backup-manager datastore remove store1", + "proxmox-backup-manager datastore update ds1 --comment c1", + } + if !reflect.DeepEqual(runner.calls, want) { + t.Fatalf("calls=%v want %v", runner.calls, want) + } +} + +func TestApplyPBSDatastoreCfgViaAPI_StrictPathMismatchRemoveFails(t *testing.T) { + stageRoot, fs, runner := setupPBSAPIApplyTestDeps(t) + logger := logging.New(types.LogLevelDebug, false) + + writeStageFile(t, fs, stageRoot, "etc/proxmox-backup/datastore.cfg", + "datastore: ds1\n"+ + " path /new\n"+ + " comment c1\n", + 0o640, + ) + + runner.outputs = map[string][]byte{ + "proxmox-backup-manager datastore list --output-format=json": []byte(`{"data":[{"name":"ds1","path":"/old"}]}`), + } + runner.errs = map[string]error{ + "proxmox-backup-manager datastore remove ds1": errors.New("cannot remove ds1"), + } + + err := applyPBSDatastoreCfgViaAPI(context.Background(), logger, stageRoot, true) + if err == nil { + t.Fatalf("expected error") + } + if !strings.Contains(err.Error(), "path mismatch") || !strings.Contains(err.Error(), "remove failed") { + t.Fatalf("unexpected error: %s", err.Error()) + } +} + +func TestApplyPBSDatastoreCfgViaAPI_StrictPathMismatchRecreateFails(t *testing.T) { + stageRoot, fs, runner := setupPBSAPIApplyTestDeps(t) + logger := logging.New(types.LogLevelDebug, false) + + writeStageFile(t, fs, stageRoot, "etc/proxmox-backup/datastore.cfg", + "datastore: ds1\n"+ + " path /new\n"+ + " comment c1\n", + 0o640, + ) + + runner.outputs = map[string][]byte{ + "proxmox-backup-manager datastore list --output-format=json": []byte(`{"data":[{"name":"ds1","path":"/old"}]}`), + } + runner.errs = map[string]error{ + "proxmox-backup-manager datastore create ds1 /new --comment c1": errors.New("create failed"), + } + + err := applyPBSDatastoreCfgViaAPI(context.Background(), logger, stageRoot, true) + if err == nil { + t.Fatalf("expected error") + } + if !strings.Contains(err.Error(), "recreate after path mismatch failed") { + t.Fatalf("unexpected error: %s", err.Error()) + } +} + +func TestApplyPBSDatastoreCfgViaAPI_UpdateFails(t *testing.T) { + stageRoot, fs, runner := setupPBSAPIApplyTestDeps(t) + logger := logging.New(types.LogLevelDebug, false) + + writeStageFile(t, fs, stageRoot, "etc/proxmox-backup/datastore.cfg", + "datastore: ds1\n"+ + " path /p1\n"+ + " comment c1\n", + 0o640, + ) + + runner.outputs = map[string][]byte{ + "proxmox-backup-manager datastore list --output-format=json": []byte(`{"data":[{"name":"ds1","path":"/p1"}]}`), + } + runner.errs = map[string]error{ + "proxmox-backup-manager datastore update ds1 --comment c1": errors.New("update failed"), + } + + err := applyPBSDatastoreCfgViaAPI(context.Background(), logger, stageRoot, false) + if err == nil || !strings.Contains(err.Error(), "update failed") { + t.Fatalf("expected update failed error, got: %v", err) + } +} + +func TestApplyPBSDatastoreCfgViaAPI_CreateAndUpdateFail(t *testing.T) { + stageRoot, fs, runner := setupPBSAPIApplyTestDeps(t) + logger := logging.New(types.LogLevelDebug, false) + + writeStageFile(t, fs, stageRoot, "etc/proxmox-backup/datastore.cfg", + "datastore: ds1\n"+ + " path /p1\n"+ + " comment c1\n", + 0o640, + ) + + runner.errs = map[string]error{ + "proxmox-backup-manager datastore create ds1 /p1 --comment c1": errors.New("create failed"), + "proxmox-backup-manager datastore update ds1 --comment c1": errors.New("update failed"), + } + + err := applyPBSDatastoreCfgViaAPI(context.Background(), logger, stageRoot, false) + if err == nil { + t.Fatalf("expected error") + } + if !strings.Contains(err.Error(), "create failed") || !strings.Contains(err.Error(), "update failed") { + t.Fatalf("expected create/update errors, got: %s", err.Error()) + } +} + +func TestApplyPBSDatastoreCfgViaAPI_NonStrictPathMismatchKeepsPath(t *testing.T) { + stageRoot, fs, runner := setupPBSAPIApplyTestDeps(t) + logger := logging.New(types.LogLevelDebug, false) + + writeStageFile(t, fs, stageRoot, "etc/proxmox-backup/datastore.cfg", + "datastore: ds1\n"+ + " path /new1\n"+ + " comment c1\n", + 0o640, + ) + + runner.outputs = map[string][]byte{ + "proxmox-backup-manager datastore list --output-format=json": []byte(`{"data":[{"name":"ds1","path":"/old1"}]}`), + } + + if err := applyPBSDatastoreCfgViaAPI(context.Background(), logger, stageRoot, false); err != nil { + t.Fatalf("applyPBSDatastoreCfgViaAPI error: %v", err) + } + + want := []string{ + "proxmox-backup-manager datastore list --output-format=json", + "proxmox-backup-manager datastore update ds1 --comment c1", + } + if !reflect.DeepEqual(runner.calls, want) { + t.Fatalf("calls=%v want %v", runner.calls, want) + } +} + +func TestApplyPBSSyncCfgViaAPI_Create(t *testing.T) { + stageRoot, fs, runner := setupPBSAPIApplyTestDeps(t) + logger := logging.New(types.LogLevelDebug, false) + + writeStageFile(t, fs, stageRoot, "etc/proxmox-backup/sync.cfg", + "sync-job: job1\n"+ + " remote r1\n"+ + " store ds1\n", + 0o640, + ) + + if err := applyPBSSyncCfgViaAPI(context.Background(), logger, stageRoot, false); err != nil { + t.Fatalf("applyPBSSyncCfgViaAPI error: %v", err) + } + + want := []string{ + "proxmox-backup-manager sync-job create job1 --remote r1 --store ds1", + } + if !reflect.DeepEqual(runner.calls, want) { + t.Fatalf("calls=%v want %v", runner.calls, want) + } +} + +func TestApplyPBSSyncCfgViaAPI_StrictCleanupAndFallbackUpdate(t *testing.T) { + stageRoot, fs, runner := setupPBSAPIApplyTestDeps(t) + logger := logging.New(types.LogLevelDebug, false) + + writeStageFile(t, fs, stageRoot, "etc/proxmox-backup/sync.cfg", + "sync-job: job1\n"+ + " remote r1\n"+ + " store ds1\n", + 0o640, + ) + + runner.outputs = map[string][]byte{ + "proxmox-backup-manager sync-job list --output-format=json": []byte(`{"data":[{"id":"job1"},{"id":"old"}]}`), + } + runner.errs = map[string]error{ + "proxmox-backup-manager sync-job remove old": errors.New("cannot remove old"), + "proxmox-backup-manager sync-job create job1 --remote r1 --store ds1": errors.New("already exists"), + } + + if err := applyPBSSyncCfgViaAPI(context.Background(), logger, stageRoot, true); err != nil { + t.Fatalf("applyPBSSyncCfgViaAPI error: %v", err) + } + + want := []string{ + "proxmox-backup-manager sync-job list --output-format=json", + "proxmox-backup-manager sync-job remove old", + "proxmox-backup-manager sync-job create job1 --remote r1 --store ds1", + "proxmox-backup-manager sync-job update job1 --remote r1 --store ds1", + } + if !reflect.DeepEqual(runner.calls, want) { + t.Fatalf("calls=%v want %v", runner.calls, want) + } +} + +func TestApplyPBSVerificationCfgViaAPI_Create(t *testing.T) { + stageRoot, fs, runner := setupPBSAPIApplyTestDeps(t) + logger := logging.New(types.LogLevelDebug, false) + + writeStageFile(t, fs, stageRoot, "etc/proxmox-backup/verification.cfg", + "verify-job: v1\n"+ + " store ds1\n", + 0o640, + ) + + if err := applyPBSVerificationCfgViaAPI(context.Background(), logger, stageRoot, false); err != nil { + t.Fatalf("applyPBSVerificationCfgViaAPI error: %v", err) + } + + want := []string{ + "proxmox-backup-manager verify-job create v1 --store ds1", + } + if !reflect.DeepEqual(runner.calls, want) { + t.Fatalf("calls=%v want %v", runner.calls, want) + } +} + +func TestApplyPBSVerificationCfgViaAPI_StrictCleanupAndFallbackUpdate(t *testing.T) { + stageRoot, fs, runner := setupPBSAPIApplyTestDeps(t) + logger := logging.New(types.LogLevelDebug, false) + + writeStageFile(t, fs, stageRoot, "etc/proxmox-backup/verification.cfg", + "verify-job: v1\n"+ + " store ds1\n", + 0o640, + ) + + runner.outputs = map[string][]byte{ + "proxmox-backup-manager verify-job list --output-format=json": []byte(`{"data":[{"id":"v1"},{"id":"old"}]}`), + } + runner.errs = map[string]error{ + "proxmox-backup-manager verify-job remove old": errors.New("cannot remove old"), + "proxmox-backup-manager verify-job create v1 --store ds1": errors.New("already exists"), + } + + if err := applyPBSVerificationCfgViaAPI(context.Background(), logger, stageRoot, true); err != nil { + t.Fatalf("applyPBSVerificationCfgViaAPI error: %v", err) + } + + want := []string{ + "proxmox-backup-manager verify-job list --output-format=json", + "proxmox-backup-manager verify-job remove old", + "proxmox-backup-manager verify-job create v1 --store ds1", + "proxmox-backup-manager verify-job update v1 --store ds1", + } + if !reflect.DeepEqual(runner.calls, want) { + t.Fatalf("calls=%v want %v", runner.calls, want) + } +} + +func TestApplyPBSPruneCfgViaAPI_Create(t *testing.T) { + stageRoot, fs, runner := setupPBSAPIApplyTestDeps(t) + logger := logging.New(types.LogLevelDebug, false) + + writeStageFile(t, fs, stageRoot, "etc/proxmox-backup/prune.cfg", + "prune-job: p1\n"+ + " store ds1\n"+ + " keep-last 3\n", + 0o640, + ) + + if err := applyPBSPruneCfgViaAPI(context.Background(), logger, stageRoot, false); err != nil { + t.Fatalf("applyPBSPruneCfgViaAPI error: %v", err) + } + + want := []string{ + "proxmox-backup-manager prune-job create p1 --store ds1 --keep-last 3", + } + if !reflect.DeepEqual(runner.calls, want) { + t.Fatalf("calls=%v want %v", runner.calls, want) + } +} + +func TestApplyPBSPruneCfgViaAPI_StrictCleanupAndFallbackUpdate(t *testing.T) { + stageRoot, fs, runner := setupPBSAPIApplyTestDeps(t) + logger := logging.New(types.LogLevelDebug, false) + + writeStageFile(t, fs, stageRoot, "etc/proxmox-backup/prune.cfg", + "prune-job: p1\n"+ + " store ds1\n"+ + " keep-last 3\n", + 0o640, + ) + + runner.outputs = map[string][]byte{ + "proxmox-backup-manager prune-job list --output-format=json": []byte(`{"data":[{"id":"old"},{"id":"p1"}]}`), + } + runner.errs = map[string]error{ + "proxmox-backup-manager prune-job remove old": errors.New("cannot remove old"), + "proxmox-backup-manager prune-job create p1 --store ds1 --keep-last 3": errors.New("already exists"), + } + + if err := applyPBSPruneCfgViaAPI(context.Background(), logger, stageRoot, true); err != nil { + t.Fatalf("applyPBSPruneCfgViaAPI error: %v", err) + } + + want := []string{ + "proxmox-backup-manager prune-job list --output-format=json", + "proxmox-backup-manager prune-job remove old", + "proxmox-backup-manager prune-job create p1 --store ds1 --keep-last 3", + "proxmox-backup-manager prune-job update p1 --store ds1 --keep-last 3", + } + if !reflect.DeepEqual(runner.calls, want) { + t.Fatalf("calls=%v want %v", runner.calls, want) + } +} + +func TestApplyPBSTrafficControlCfgViaAPI_StrictCleanupAndCreate(t *testing.T) { + stageRoot, fs, runner := setupPBSAPIApplyTestDeps(t) + logger := logging.New(types.LogLevelDebug, false) + + writeStageFile(t, fs, stageRoot, "etc/proxmox-backup/traffic-control.cfg", + "traffic-control: tc1\n"+ + " rate-in 1000\n", + 0o640, + ) + + runner.outputs = map[string][]byte{ + "proxmox-backup-manager traffic-control list --output-format=json": []byte(`{"data":[{"name":"old"},{"name":"tc1"}]}`), + } + runner.errs = map[string]error{ + "proxmox-backup-manager traffic-control remove old": errors.New("cannot remove old"), + } + + if err := applyPBSTrafficControlCfgViaAPI(context.Background(), logger, stageRoot, true); err != nil { + t.Fatalf("applyPBSTrafficControlCfgViaAPI error: %v", err) + } + + want := []string{ + "proxmox-backup-manager traffic-control list --output-format=json", + "proxmox-backup-manager traffic-control remove old", + "proxmox-backup-manager traffic-control create tc1 --rate-in 1000", + } + if !reflect.DeepEqual(runner.calls, want) { + t.Fatalf("calls=%v want %v", runner.calls, want) + } +} + +func TestApplyPBSTrafficControlCfgViaAPI_CreateFailsThenUpdate(t *testing.T) { + stageRoot, fs, runner := setupPBSAPIApplyTestDeps(t) + logger := logging.New(types.LogLevelDebug, false) + + writeStageFile(t, fs, stageRoot, "etc/proxmox-backup/traffic-control.cfg", + "traffic-control: tc1\n"+ + " rate-in 1000\n", + 0o640, + ) + + runner.errs = map[string]error{ + "proxmox-backup-manager traffic-control create tc1 --rate-in 1000": errors.New("already exists"), + } + + if err := applyPBSTrafficControlCfgViaAPI(context.Background(), logger, stageRoot, false); err != nil { + t.Fatalf("applyPBSTrafficControlCfgViaAPI error: %v", err) + } + + want := []string{ + "proxmox-backup-manager traffic-control create tc1 --rate-in 1000", + "proxmox-backup-manager traffic-control update tc1 --rate-in 1000", + } + if !reflect.DeepEqual(runner.calls, want) { + t.Fatalf("calls=%v want %v", runner.calls, want) + } +} + +func TestApplyPBSNodeCfgViaAPI_UsesFirstSection(t *testing.T) { + stageRoot, fs, runner := setupPBSAPIApplyTestDeps(t) + + writeStageFile(t, fs, stageRoot, "etc/proxmox-backup/node.cfg", + "node: n1\n"+ + " dns1 1.1.1.1\n"+ + " foo_bar baz\n"+ + "\n"+ + "node: n2\n"+ + " dns1 9.9.9.9\n", + 0o640, + ) + + if err := applyPBSNodeCfgViaAPI(context.Background(), stageRoot); err != nil { + t.Fatalf("applyPBSNodeCfgViaAPI error: %v", err) + } + + want := []string{ + "proxmox-backup-manager node update --dns1 1.1.1.1 --foo-bar baz", + } + if !reflect.DeepEqual(runner.calls, want) { + t.Fatalf("calls=%v want %v", runner.calls, want) + } +} + +func TestApplyPBSNodeCfgViaAPI_NoFileAndEmptyAndError(t *testing.T) { + t.Run("no-file", func(t *testing.T) { + stageRoot, _, runner := setupPBSAPIApplyTestDeps(t) + if err := applyPBSNodeCfgViaAPI(context.Background(), stageRoot); err != nil { + t.Fatalf("applyPBSNodeCfgViaAPI error: %v", err) + } + if len(runner.calls) != 0 { + t.Fatalf("expected no calls, got %v", runner.calls) + } + }) + + t.Run("empty-sections", func(t *testing.T) { + stageRoot, fs, runner := setupPBSAPIApplyTestDeps(t) + writeStageFile(t, fs, stageRoot, "etc/proxmox-backup/node.cfg", " \n# comment\n", 0o640) + if err := applyPBSNodeCfgViaAPI(context.Background(), stageRoot); err != nil { + t.Fatalf("applyPBSNodeCfgViaAPI error: %v", err) + } + if len(runner.calls) != 0 { + t.Fatalf("expected no calls, got %v", runner.calls) + } + }) + + t.Run("command-error", func(t *testing.T) { + stageRoot, fs, runner := setupPBSAPIApplyTestDeps(t) + writeStageFile(t, fs, stageRoot, "etc/proxmox-backup/node.cfg", + "node: n1\n"+ + " dns1 1.1.1.1\n", + 0o640, + ) + runner.errs = map[string]error{ + "proxmox-backup-manager node update --dns1 1.1.1.1": errors.New("boom"), + } + if err := applyPBSNodeCfgViaAPI(context.Background(), stageRoot); err == nil { + t.Fatalf("expected error") + } + }) +} From e0495ab2d21b72b91e2ddaa6b55997b41b0fe7ed Mon Sep 17 00:00:00 2001 From: Damiano <71268257+tis24dev@users.noreply.github.com> Date: Mon, 23 Feb 2026 16:25:09 +0100 Subject: [PATCH 05/12] Make firewall restore testable; add tests Introduce function-level variables to allow dependency injection for firewall restore (hostname, geteuid, mount checks, real FS check, rollback arm/disarm, apply from stage, restart service). Replace direct calls (os.Geteuid, os.Hostname, isMounted, isRealRestoreFS, time.Now) with the injectable variants (firewallApplyGeteuid, firewallHostname, firewallIsMounted, firewallIsRealRestoreFS, firewallArmRollback, firewallDisarmRollback, firewallApplyFromStage, firewallRestartService, nowRestore) to improve testability. Also add extensive unit tests in internal/orchestrator/restore_firewall_additional_test.go that exercise many branches of the firewall apply/rollback flow (arm/disarm behavior, marker handling, command fallbacks, symlink/file operations, prompts and error conditions). These changes enable robust testing of firewall restore logic without changing runtime behavior. --- internal/orchestrator/restore_firewall.go | 36 +- .../restore_firewall_additional_test.go | 2146 +++++++++++++++++ 2 files changed, 2170 insertions(+), 12 deletions(-) create mode 100644 internal/orchestrator/restore_firewall_additional_test.go diff --git a/internal/orchestrator/restore_firewall.go b/internal/orchestrator/restore_firewall.go index 64c7419..eeb5daa 100644 --- a/internal/orchestrator/restore_firewall.go +++ b/internal/orchestrator/restore_firewall.go @@ -17,6 +17,18 @@ const defaultFirewallRollbackTimeout = 180 * time.Second var ErrFirewallApplyNotCommitted = errors.New("firewall configuration not committed") +var ( + firewallApplyGeteuid = os.Geteuid + firewallHostname = os.Hostname + firewallIsMounted = isMounted + firewallIsRealRestoreFS = isRealRestoreFS + + firewallArmRollback = armFirewallRollback + firewallDisarmRollback = disarmFirewallRollback + firewallApplyFromStage = applyPVEFirewallFromStage + firewallRestartService = restartPVEFirewallService +) + type FirewallApplyNotCommittedError struct { RollbackLog string RollbackMarker string @@ -99,7 +111,7 @@ func maybeApplyPVEFirewallWithUI( if ui == nil { return fmt.Errorf("restore UI not available") } - if !isRealRestoreFS(restoreFS) { + if !firewallIsRealRestoreFS(restoreFS) { logger.Debug("Skipping PVE firewall restore: non-system filesystem in use") return nil } @@ -107,7 +119,7 @@ func maybeApplyPVEFirewallWithUI( logger.Info("Dry run enabled: skipping PVE firewall restore") return nil } - if os.Geteuid() != 0 { + if firewallApplyGeteuid() != 0 { logger.Warning("Skipping PVE firewall restore: requires root privileges") return nil } @@ -123,7 +135,7 @@ func maybeApplyPVEFirewallWithUI( } etcPVE := "/etc/pve" - mounted, mountErr := isMounted(etcPVE) + mounted, mountErr := firewallIsMounted(etcPVE) if mountErr != nil { logger.Warning("PVE firewall restore: unable to check pmxcfs mount (%s): %v", etcPVE, mountErr) } @@ -214,26 +226,26 @@ func maybeApplyPVEFirewallWithUI( if rollbackPath != "" { logger.Info("") logger.Info("Arming firewall rollback timer (%ds)...", int(defaultFirewallRollbackTimeout.Seconds())) - rollbackHandle, err = armFirewallRollback(ctx, logger, rollbackPath, defaultFirewallRollbackTimeout, "/tmp/proxsave") + rollbackHandle, err = firewallArmRollback(ctx, logger, rollbackPath, defaultFirewallRollbackTimeout, "/tmp/proxsave") if err != nil { return fmt.Errorf("arm firewall rollback: %w", err) } logger.Info("Firewall rollback log: %s", rollbackHandle.logPath) } - applied, err := applyPVEFirewallFromStage(logger, stageRoot) + applied, err := firewallApplyFromStage(logger, stageRoot) if err != nil { return err } if len(applied) == 0 { logger.Info("PVE firewall restore: no changes applied (stage contained no firewall entries)") if rollbackHandle != nil { - disarmFirewallRollback(ctx, logger, rollbackHandle) + firewallDisarmRollback(ctx, logger, rollbackHandle) } return nil } - if err := restartPVEFirewallService(ctx); err != nil { + if err := firewallRestartService(ctx); err != nil { logger.Warning("PVE firewall restore: reload/restart failed: %v", err) } @@ -242,7 +254,7 @@ func maybeApplyPVEFirewallWithUI( return nil } - remaining := rollbackHandle.remaining(time.Now()) + remaining := rollbackHandle.remaining(nowRestore()) if remaining <= 0 { return buildFirewallApplyNotCommittedError(rollbackHandle) } @@ -264,7 +276,7 @@ func maybeApplyPVEFirewallWithUI( } if commit { - disarmFirewallRollback(ctx, logger, rollbackHandle) + firewallDisarmRollback(ctx, logger, rollbackHandle) logger.Info("Firewall changes committed.") return nil } @@ -309,7 +321,7 @@ func applyPVEFirewallFromStage(logger *logging.Logger, stageRoot string) (applie return applied, err } if ok { - currentNode, _ := os.Hostname() + currentNode, _ := firewallHostname() currentNode = shortHost(currentNode) if strings.TrimSpace(currentNode) == "" { currentNode = "localhost" @@ -331,7 +343,7 @@ func applyPVEFirewallFromStage(logger *logging.Logger, stageRoot string) (applie } func selectStageHostFirewall(logger *logging.Logger, stageRoot string) (path string, sourceNode string, ok bool, err error) { - currentNode, _ := os.Hostname() + currentNode, _ := firewallHostname() currentNode = shortHost(currentNode) if strings.TrimSpace(currentNode) == "" { currentNode = "localhost" @@ -430,7 +442,7 @@ func armFirewallRollback(ctx context.Context, logger *logging.Logger, backupPath markerPath: filepath.Join(baseDir, fmt.Sprintf("firewall_rollback_pending_%s", timestamp)), scriptPath: filepath.Join(baseDir, fmt.Sprintf("firewall_rollback_%s.sh", timestamp)), logPath: filepath.Join(baseDir, fmt.Sprintf("firewall_rollback_%s.log", timestamp)), - armedAt: time.Now(), + armedAt: nowRestore(), timeout: timeout, } diff --git a/internal/orchestrator/restore_firewall_additional_test.go b/internal/orchestrator/restore_firewall_additional_test.go new file mode 100644 index 0000000..525c269 --- /dev/null +++ b/internal/orchestrator/restore_firewall_additional_test.go @@ -0,0 +1,2146 @@ +package orchestrator + +import ( + "context" + "errors" + "fmt" + "io/fs" + "os" + "path/filepath" + "strconv" + "strings" + "testing" + "time" + + "github.com/tis24dev/proxsave/internal/input" + "github.com/tis24dev/proxsave/internal/logging" +) + +type readDirFailFS struct { + FS + failPath string + err error +} + +func (f readDirFailFS) ReadDir(path string) ([]os.DirEntry, error) { + if filepath.Clean(path) == filepath.Clean(f.failPath) { + return nil, f.err + } + return f.FS.ReadDir(path) +} + +type readDirOverrideFS struct { + FS + overridePath string + entries []os.DirEntry + err error +} + +func (f readDirOverrideFS) ReadDir(path string) ([]os.DirEntry, error) { + if filepath.Clean(path) == filepath.Clean(f.overridePath) { + if f.err != nil { + return nil, f.err + } + return f.entries, nil + } + return f.FS.ReadDir(path) +} + +type statFailFS struct { + FS + failPath string + err error +} + +func (f statFailFS) Stat(path string) (os.FileInfo, error) { + if filepath.Clean(path) == filepath.Clean(f.failPath) { + return nil, f.err + } + return f.FS.Stat(path) +} + +type readFileFailFS struct { + FS + failPath string + err error +} + +func (f readFileFailFS) ReadFile(path string) ([]byte, error) { + if filepath.Clean(path) == filepath.Clean(f.failPath) { + return nil, f.err + } + return f.FS.ReadFile(path) +} + +type writeFileFailFS struct { + FS + failPath string + err error +} + +func (f writeFileFailFS) WriteFile(path string, data []byte, perm os.FileMode) error { + if filepath.Clean(path) == filepath.Clean(f.failPath) { + return f.err + } + return f.FS.WriteFile(path, data, perm) +} + +type removeFailFS struct { + FS + failPath string + err error +} + +func (f removeFailFS) Remove(path string) error { + if filepath.Clean(path) == filepath.Clean(f.failPath) { + return f.err + } + return f.FS.Remove(path) +} + +type readlinkFailFS struct { + FS + failPath string + err error +} + +func (f readlinkFailFS) Readlink(path string) (string, error) { + if filepath.Clean(path) == filepath.Clean(f.failPath) { + return "", f.err + } + return f.FS.Readlink(path) +} + +type badInfoDirEntry struct { + name string +} + +func (e badInfoDirEntry) Name() string { return e.name } +func (e badInfoDirEntry) IsDir() bool { return false } +func (e badInfoDirEntry) Type() fs.FileMode { return 0 } +func (e badInfoDirEntry) Info() (fs.FileInfo, error) { return nil, fmt.Errorf("boom") } + +type symlinkFailFS struct { + FS + failNewname string + err error +} + +func (f symlinkFailFS) Symlink(oldname, newname string) error { + if filepath.Clean(newname) == filepath.Clean(f.failNewname) { + return f.err + } + return f.FS.Symlink(oldname, newname) +} + +type statFailOnNthFS struct { + FS + path string + failOn int + calls int + err error +} + +func (f *statFailOnNthFS) Stat(path string) (os.FileInfo, error) { + if filepath.Clean(path) == filepath.Clean(f.path) { + f.calls++ + if f.calls >= f.failOn { + return nil, f.err + } + } + return f.FS.Stat(path) +} + +type multiReadDirFS struct { + FS + entries map[string][]os.DirEntry + errors map[string]error +} + +func (f multiReadDirFS) ReadDir(path string) ([]os.DirEntry, error) { + clean := filepath.Clean(path) + if f.errors != nil { + if err, ok := f.errors[clean]; ok { + return nil, err + } + } + if f.entries != nil { + if entries, ok := f.entries[clean]; ok { + return entries, nil + } + } + return f.FS.ReadDir(path) +} + +type staticFileInfo struct { + name string + mode fs.FileMode +} + +func (i staticFileInfo) Name() string { return i.name } +func (i staticFileInfo) Size() int64 { return 0 } +func (i staticFileInfo) Mode() fs.FileMode { return i.mode } +func (i staticFileInfo) ModTime() time.Time { return time.Time{} } +func (i staticFileInfo) IsDir() bool { return i.mode.IsDir() } +func (i staticFileInfo) Sys() any { return nil } + +type staticDirEntry struct { + name string + mode fs.FileMode +} + +func (e staticDirEntry) Name() string { return e.name } +func (e staticDirEntry) IsDir() bool { return e.mode.IsDir() } +func (e staticDirEntry) Type() fs.FileMode { return e.mode } +func (e staticDirEntry) Info() (fs.FileInfo, error) { + return staticFileInfo{name: e.name, mode: e.mode}, nil +} + +type scriptedConfirmAction struct { + ok bool + err error +} + +type scriptedRestoreWorkflowUI struct { + *fakeRestoreWorkflowUI + script []scriptedConfirmAction + calls int +} + +func (s *scriptedRestoreWorkflowUI) ConfirmAction(ctx context.Context, title, message, yesLabel, noLabel string, timeout time.Duration, defaultYes bool) (bool, error) { + if s.calls >= len(s.script) { + return false, fmt.Errorf("unexpected ConfirmAction call %d (title=%q)", s.calls+1, strings.TrimSpace(title)) + } + action := s.script[s.calls] + s.calls++ + return action.ok, action.err +} + +func writeExecutable(t *testing.T, dir, name string) string { + t.Helper() + path := filepath.Join(dir, name) + if err := os.WriteFile(path, []byte("#!/bin/sh\nexit 0\n"), 0o755); err != nil { + t.Fatalf("write executable %s: %v", name, err) + } + return path +} + +func TestFirewallApplyNotCommittedError_UnwrapAndMessage(t *testing.T) { + var e *FirewallApplyNotCommittedError + if e.Error() != ErrFirewallApplyNotCommitted.Error() { + t.Fatalf("nil receiver Error()=%q want %q", e.Error(), ErrFirewallApplyNotCommitted.Error()) + } + if !errors.Is(error(e), ErrFirewallApplyNotCommitted) { + t.Fatalf("expected errors.Is(..., ErrFirewallApplyNotCommitted) to be true") + } + if errors.Unwrap(error(e)) != ErrFirewallApplyNotCommitted { + t.Fatalf("expected Unwrap to return ErrFirewallApplyNotCommitted") + } + + e2 := &FirewallApplyNotCommittedError{} + if e2.Error() != ErrFirewallApplyNotCommitted.Error() { + t.Fatalf("Error()=%q want %q", e2.Error(), ErrFirewallApplyNotCommitted.Error()) + } +} + +func TestFirewallRollbackHandle_Remaining(t *testing.T) { + var h *firewallRollbackHandle + if got := h.remaining(time.Now()); got != 0 { + t.Fatalf("nil handle remaining=%s want 0", got) + } + + handle := &firewallRollbackHandle{ + armedAt: time.Unix(100, 0), + timeout: 10 * time.Second, + } + if got := handle.remaining(time.Unix(105, 0)); got != 5*time.Second { + t.Fatalf("remaining=%s want %s", got, 5*time.Second) + } + if got := handle.remaining(time.Unix(999, 0)); got != 0 { + t.Fatalf("remaining=%s want 0", got) + } +} + +func TestBuildFirewallApplyNotCommittedError_PopulatesFields(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + if e := buildFirewallApplyNotCommittedError(nil); e == nil { + t.Fatalf("expected error struct, got nil") + } else if e.RollbackArmed || e.RollbackMarker != "" || e.RollbackLog != "" || !e.RollbackDeadline.IsZero() { + t.Fatalf("unexpected fields for nil handle: %#v", e) + } + + armedAt := time.Unix(10, 0) + handle := &firewallRollbackHandle{ + markerPath: " /tmp/fw.marker \n", + logPath: " /tmp/fw.log\t", + armedAt: armedAt, + timeout: 3 * time.Second, + } + if err := fakeFS.AddFile("/tmp/fw.marker", []byte("pending\n")); err != nil { + t.Fatalf("add marker: %v", err) + } + + e := buildFirewallApplyNotCommittedError(handle) + if e.RollbackMarker != "/tmp/fw.marker" { + t.Fatalf("RollbackMarker=%q", e.RollbackMarker) + } + if e.RollbackLog != "/tmp/fw.log" { + t.Fatalf("RollbackLog=%q", e.RollbackLog) + } + if !e.RollbackArmed { + t.Fatalf("expected RollbackArmed=true") + } + if !e.RollbackDeadline.Equal(armedAt.Add(3 * time.Second)) { + t.Fatalf("RollbackDeadline=%s want %s", e.RollbackDeadline, armedAt.Add(3*time.Second)) + } +} + +func TestBuildFirewallRollbackScript_QuotesPaths(t *testing.T) { + script := buildFirewallRollbackScript("/tmp/marker path", "/tmp/backup's.tar.gz", "/tmp/log path") + if !strings.Contains(script, "MARKER='/tmp/marker path'") { + t.Fatalf("expected MARKER to be quoted, got script:\n%s", script) + } + if !strings.Contains(script, "LOG='/tmp/log path'") { + t.Fatalf("expected LOG to be quoted, got script:\n%s", script) + } + if !strings.Contains(script, "BACKUP='/tmp/backup'\\''s.tar.gz'") { + t.Fatalf("expected BACKUP to escape single quotes, got script:\n%s", script) + } + if !strings.HasSuffix(script, "\n") { + t.Fatalf("expected script to end with newline") + } +} + +func TestCopyFileExact_ReturnsFalseWhenSourceMissingOrDir(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + ok, err := copyFileExact("/missing", "/dest") + if err != nil || ok { + t.Fatalf("copyFileExact missing ok=%v err=%v want ok=false err=nil", ok, err) + } + + if err := fakeFS.AddDir("/srcdir"); err != nil { + t.Fatalf("add dir: %v", err) + } + ok, err = copyFileExact("/srcdir", "/dest") + if err != nil || ok { + t.Fatalf("copyFileExact dir ok=%v err=%v want ok=false err=nil", ok, err) + } +} + +func TestCopyFileExact_PropagatesAtomicWriteFailure(t *testing.T) { + origFS := restoreFS + origTime := restoreTime + t.Cleanup(func() { + restoreFS = origFS + restoreTime = origTime + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + fakeTime := &FakeTime{Current: time.Unix(0, 12345)} + restoreTime = fakeTime + + if err := fakeFS.AddFile("/src/file", []byte("data\n")); err != nil { + t.Fatalf("add src: %v", err) + } + + dest := "/dest/file" + tmpPath := dest + ".proxsave.tmp." + strconv.FormatInt(fakeTime.Current.UnixNano(), 10) + fakeFS.OpenFileErr[filepath.Clean(tmpPath)] = fmt.Errorf("open tmp denied") + + _, err := copyFileExact("/src/file", dest) + if err == nil { + t.Fatalf("expected error") + } +} + +func TestSyncDirExact_CopiesSymlinks(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + if err := fakeFS.AddFile("/stage/target", []byte("x")); err != nil { + t.Fatalf("add target: %v", err) + } + if err := fakeFS.Symlink("/stage/target", "/stage/link"); err != nil { + t.Fatalf("add symlink: %v", err) + } + + applied, err := syncDirExact("/stage", "/dest") + if err != nil { + t.Fatalf("syncDirExact error: %v", err) + } + + destTarget, err := fakeFS.Readlink("/dest/link") + if err != nil { + t.Fatalf("read dest symlink: %v", err) + } + if strings.TrimSpace(destTarget) == "" { + t.Fatalf("expected non-empty symlink target") + } + found := false + for _, p := range applied { + if p == "/dest/link" { + found = true + break + } + } + if !found { + t.Fatalf("expected /dest/link to be reported as applied, got %#v", applied) + } +} + +func TestSelectStageHostFirewall_ErrorsOnReadDirFailure(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + base := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(base.Root) }) + restoreFS = readDirFailFS{FS: base, failPath: "/stage/etc/pve/nodes", err: fmt.Errorf("boom")} + + _, _, _, err := selectStageHostFirewall(newTestLogger(), "/stage") + if err == nil || !strings.Contains(err.Error(), "readdir") { + t.Fatalf("expected readdir error, got %v", err) + } +} + +func TestSelectStageHostFirewall_PicksCurrentNodeWhenPresent(t *testing.T) { + origFS := restoreFS + origHostname := firewallHostname + t.Cleanup(func() { + restoreFS = origFS + firewallHostname = origHostname + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + firewallHostname = func() (string, error) { return "node1.example", nil } + + stageRoot := "/stage" + if err := fakeFS.AddFile(stageRoot+"/etc/pve/nodes/node1/host.fw", []byte("a")); err != nil { + t.Fatalf("add node1 host.fw: %v", err) + } + if err := fakeFS.AddFile(stageRoot+"/etc/pve/nodes/other/host.fw", []byte("b")); err != nil { + t.Fatalf("add other host.fw: %v", err) + } + if err := fakeFS.AddFile(stageRoot+"/etc/pve/nodes/notadir", []byte("c")); err != nil { + t.Fatalf("add notadir: %v", err) + } + + path, sourceNode, ok, err := selectStageHostFirewall(newTestLogger(), stageRoot) + if err != nil { + t.Fatalf("selectStageHostFirewall error: %v", err) + } + if !ok { + t.Fatalf("expected ok=true") + } + if sourceNode != "node1" { + t.Fatalf("sourceNode=%q want %q", sourceNode, "node1") + } + if !strings.HasSuffix(path, "/stage/etc/pve/nodes/node1/host.fw") { + t.Fatalf("unexpected path: %q", path) + } +} + +func TestSelectStageHostFirewall_SkipsWhenMultipleCandidatesNoneMatches(t *testing.T) { + origFS := restoreFS + origHostname := firewallHostname + t.Cleanup(func() { + restoreFS = origFS + firewallHostname = origHostname + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + firewallHostname = func() (string, error) { return "current", nil } + + stageRoot := "/stage" + for _, node := range []string{"nodeA", "nodeB"} { + if err := fakeFS.AddFile(stageRoot+"/etc/pve/nodes/"+node+"/host.fw", []byte("x")); err != nil { + t.Fatalf("add host.fw for %s: %v", node, err) + } + } + + path, sourceNode, ok, err := selectStageHostFirewall(newTestLogger(), stageRoot) + if err != nil { + t.Fatalf("selectStageHostFirewall error: %v", err) + } + if ok || path != "" || sourceNode != "" { + t.Fatalf("expected skip, got ok=%v path=%q source=%q", ok, path, sourceNode) + } +} + +func TestRestartPVEFirewallService_CommandFallbacks(t *testing.T) { + origCmd := restoreCmd + t.Cleanup(func() { restoreCmd = origCmd }) + + binDir := t.TempDir() + oldPath := os.Getenv("PATH") + if err := os.Setenv("PATH", binDir); err != nil { + t.Fatalf("set PATH: %v", err) + } + t.Cleanup(func() { _ = os.Setenv("PATH", oldPath) }) + + writeExecutable(t, binDir, "systemctl") + writeExecutable(t, binDir, "pve-firewall") + + t.Run("try-restart ok", func(t *testing.T) { + fake := &FakeCommandRunner{} + restoreCmd = fake + + if err := restartPVEFirewallService(context.Background()); err != nil { + t.Fatalf("restartPVEFirewallService error: %v", err) + } + if got := fake.CallsList(); len(got) != 1 || got[0] != "systemctl try-restart pve-firewall" { + t.Fatalf("unexpected calls: %#v", got) + } + }) + + t.Run("restart ok", func(t *testing.T) { + fake := &FakeCommandRunner{ + Errors: map[string]error{ + "systemctl try-restart pve-firewall": fmt.Errorf("fail"), + }, + } + restoreCmd = fake + + if err := restartPVEFirewallService(context.Background()); err != nil { + t.Fatalf("restartPVEFirewallService error: %v", err) + } + if got := fake.CallsList(); len(got) != 2 || got[0] != "systemctl try-restart pve-firewall" || got[1] != "systemctl restart pve-firewall" { + t.Fatalf("unexpected calls: %#v", got) + } + }) + + t.Run("fallback to pve-firewall", func(t *testing.T) { + fake := &FakeCommandRunner{ + Errors: map[string]error{ + "systemctl try-restart pve-firewall": fmt.Errorf("fail"), + "systemctl restart pve-firewall": fmt.Errorf("fail"), + }, + } + restoreCmd = fake + + if err := restartPVEFirewallService(context.Background()); err != nil { + t.Fatalf("restartPVEFirewallService error: %v", err) + } + calls := fake.CallsList() + if len(calls) != 3 { + t.Fatalf("unexpected calls: %#v", calls) + } + if calls[2] != "pve-firewall restart" { + t.Fatalf("expected fallback call, got %#v", calls) + } + }) + + t.Run("no commands available", func(t *testing.T) { + fake := &FakeCommandRunner{} + restoreCmd = fake + + emptyBin := t.TempDir() + if err := os.Setenv("PATH", emptyBin); err != nil { + t.Fatalf("set PATH: %v", err) + } + + if err := restartPVEFirewallService(context.Background()); err == nil { + t.Fatalf("expected error") + } + }) +} + +func TestArmFirewallRollback_SystemdAndBackgroundPaths(t *testing.T) { + origFS := restoreFS + origCmd := restoreCmd + origTime := restoreTime + t.Cleanup(func() { + restoreFS = origFS + restoreCmd = origCmd + restoreTime = origTime + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + fakeTime := &FakeTime{Current: time.Date(2020, 1, 2, 3, 4, 5, 0, time.UTC)} + restoreTime = fakeTime + + logger := logging.New(logging.GetDefaultLogger().GetLevel(), false) + + t.Run("rejects invalid args", func(t *testing.T) { + if _, err := armFirewallRollback(context.Background(), logger, "", 1*time.Second, "/tmp/proxsave"); err == nil { + t.Fatalf("expected error for empty backupPath") + } + if _, err := armFirewallRollback(context.Background(), logger, "/backup.tgz", 0, "/tmp/proxsave"); err == nil { + t.Fatalf("expected error for invalid timeout") + } + }) + + t.Run("uses systemd-run when available", func(t *testing.T) { + binDir := t.TempDir() + oldPath := os.Getenv("PATH") + if err := os.Setenv("PATH", binDir); err != nil { + t.Fatalf("set PATH: %v", err) + } + t.Cleanup(func() { _ = os.Setenv("PATH", oldPath) }) + writeExecutable(t, binDir, "systemd-run") + + fakeCmd := &FakeCommandRunner{} + restoreCmd = fakeCmd + + handle, err := armFirewallRollback(context.Background(), logger, "/backup.tgz", 2*time.Second, "/tmp/proxsave") + if err != nil { + t.Fatalf("armFirewallRollback error: %v", err) + } + if handle == nil || handle.unitName == "" { + t.Fatalf("expected systemd unit name, got %#v", handle) + } + if got := fakeCmd.CallsList(); len(got) != 1 || !strings.HasPrefix(got[0], "systemd-run --unit=proxsave-firewall-rollback-20200102_030405") { + t.Fatalf("unexpected calls: %#v", got) + } + if data, err := fakeFS.ReadFile(handle.markerPath); err != nil || string(data) != "pending\n" { + t.Fatalf("marker read err=%v data=%q", err, string(data)) + } + }) + + t.Run("falls back to background timer on systemd-run failure", func(t *testing.T) { + binDir := t.TempDir() + oldPath := os.Getenv("PATH") + if err := os.Setenv("PATH", binDir); err != nil { + t.Fatalf("set PATH: %v", err) + } + t.Cleanup(func() { _ = os.Setenv("PATH", oldPath) }) + writeExecutable(t, binDir, "systemd-run") + + fakeCmd := &FakeCommandRunner{ + Errors: map[string]error{}, + } + restoreCmd = fakeCmd + + timestamp := fakeTime.Current.Format("20060102_150405") + scriptPath := filepath.Join("/tmp/proxsave", fmt.Sprintf("firewall_rollback_%s.sh", timestamp)) + systemdKey := "systemd-run --unit=proxsave-firewall-rollback-" + timestamp + " --on-active=2s /bin/sh " + scriptPath + fakeCmd.Errors[systemdKey] = fmt.Errorf("fail") + + handle, err := armFirewallRollback(context.Background(), logger, "/backup.tgz", 2*time.Second, "/tmp/proxsave") + if err != nil { + t.Fatalf("armFirewallRollback error: %v", err) + } + if handle == nil { + t.Fatalf("expected handle") + } + if handle.unitName != "" { + t.Fatalf("expected unitName cleared after systemd-run failure, got %q", handle.unitName) + } + + cmd := fmt.Sprintf("nohup sh -c 'sleep %d; /bin/sh %s' >/dev/null 2>&1 &", 2, scriptPath) + wantBackground := "sh -c " + cmd + calls := fakeCmd.CallsList() + if len(calls) != 2 || calls[1] != wantBackground { + t.Fatalf("unexpected calls: %#v", calls) + } + }) + + t.Run("background timer failure returns error", func(t *testing.T) { + emptyBin := t.TempDir() + oldPath := os.Getenv("PATH") + if err := os.Setenv("PATH", emptyBin); err != nil { + t.Fatalf("set PATH: %v", err) + } + t.Cleanup(func() { _ = os.Setenv("PATH", oldPath) }) + + fakeCmd := &FakeCommandRunner{ + Errors: map[string]error{}, + } + restoreCmd = fakeCmd + + timestamp := fakeTime.Current.Format("20060102_150405") + scriptPath := filepath.Join("/tmp/proxsave", fmt.Sprintf("firewall_rollback_%s.sh", timestamp)) + cmd := fmt.Sprintf("nohup sh -c 'sleep %d; /bin/sh %s' >/dev/null 2>&1 &", 1, scriptPath) + backgroundKey := "sh -c " + cmd + fakeCmd.Errors[backgroundKey] = fmt.Errorf("boom") + + if _, err := armFirewallRollback(context.Background(), logger, "/backup.tgz", 1*time.Second, "/tmp/proxsave"); err == nil { + t.Fatalf("expected error") + } + }) +} + +func TestDisarmFirewallRollback_RemovesMarkerAndStopsTimer(t *testing.T) { + origFS := restoreFS + origCmd := restoreCmd + t.Cleanup(func() { + restoreFS = origFS + restoreCmd = origCmd + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + binDir := t.TempDir() + oldPath := os.Getenv("PATH") + if err := os.Setenv("PATH", binDir); err != nil { + t.Fatalf("set PATH: %v", err) + } + t.Cleanup(func() { _ = os.Setenv("PATH", oldPath) }) + writeExecutable(t, binDir, "systemctl") + + fakeCmd := &FakeCommandRunner{} + restoreCmd = fakeCmd + + handle := &firewallRollbackHandle{ + markerPath: "/tmp/proxsave/fw.marker", + unitName: "proxsave-firewall-rollback-test", + } + if err := fakeFS.AddFile(handle.markerPath, []byte("pending\n")); err != nil { + t.Fatalf("add marker: %v", err) + } + + disarmFirewallRollback(context.Background(), newTestLogger(), handle) + + if _, err := fakeFS.Stat(handle.markerPath); err == nil || !os.IsNotExist(err) { + t.Fatalf("expected marker removed; stat err=%v", err) + } + + timerUnit := handle.unitName + ".timer" + want1 := "systemctl stop " + timerUnit + want2 := "systemctl reset-failed " + handle.unitName + ".service " + timerUnit + calls := fakeCmd.CallsList() + if len(calls) != 2 || calls[0] != want1 || calls[1] != want2 { + t.Fatalf("unexpected calls: %#v", calls) + } +} + +func TestMaybeApplyPVEFirewallWithUI_CoversUserFlows(t *testing.T) { + origFS := restoreFS + origCmd := restoreCmd + origTime := restoreTime + origGeteuid := firewallApplyGeteuid + origMounted := firewallIsMounted + origRealFS := firewallIsRealRestoreFS + origArm := firewallArmRollback + origDisarm := firewallDisarmRollback + origApply := firewallApplyFromStage + origRestart := firewallRestartService + t.Cleanup(func() { + restoreFS = origFS + restoreCmd = origCmd + restoreTime = origTime + firewallApplyGeteuid = origGeteuid + firewallIsMounted = origMounted + firewallIsRealRestoreFS = origRealFS + firewallArmRollback = origArm + firewallDisarmRollback = origDisarm + firewallApplyFromStage = origApply + firewallRestartService = origRestart + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + restoreTime = &FakeTime{Current: time.Unix(100, 0)} + restoreCmd = &FakeCommandRunner{} + + firewallIsRealRestoreFS = func(fs FS) bool { return true } + firewallApplyGeteuid = func() int { return 0 } + firewallIsMounted = func(path string) (bool, error) { return true, nil } + firewallRestartService = func(ctx context.Context) error { return nil } + + plan := &RestorePlan{ + SystemType: SystemTypePVE, + NormalCategories: []Category{{ID: "pve_firewall"}}, + } + stageRoot := "/stage" + logger := newTestLogger() + + t.Run("errors when ui missing", func(t *testing.T) { + err := maybeApplyPVEFirewallWithUI(context.Background(), nil, logger, plan, nil, nil, stageRoot, false) + if err == nil { + t.Fatalf("expected error") + } + }) + + t.Run("skips when /etc/pve not mounted", func(t *testing.T) { + firewallIsMounted = func(path string) (bool, error) { return false, nil } + t.Cleanup(func() { firewallIsMounted = func(path string) (bool, error) { return true, nil } }) + + ui := &scriptedRestoreWorkflowUI{fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, script: nil} + err := maybeApplyPVEFirewallWithUI(context.Background(), ui, logger, plan, nil, nil, stageRoot, false) + if err != nil { + t.Fatalf("expected nil, got %v", err) + } + }) + + t.Run("skips when stage has no firewall data", func(t *testing.T) { + ui := &scriptedRestoreWorkflowUI{fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, script: nil} + err := maybeApplyPVEFirewallWithUI(context.Background(), ui, logger, plan, nil, nil, stageRoot, false) + if err != nil { + t.Fatalf("expected nil, got %v", err) + } + }) + + t.Run("user skips apply", func(t *testing.T) { + if err := fakeFS.AddDir(stageRoot + "/etc/pve/nodes"); err != nil { + t.Fatalf("add stage nodes: %v", err) + } + ui := &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{{ok: false}}, + } + err := maybeApplyPVEFirewallWithUI(context.Background(), ui, logger, plan, nil, nil, stageRoot, false) + if err != nil { + t.Fatalf("expected nil, got %v", err) + } + }) + + t.Run("missing rollback backup declines full rollback", func(t *testing.T) { + firewallArmRollback = func(ctx context.Context, logger *logging.Logger, backupPath string, timeout time.Duration, workDir string) (*firewallRollbackHandle, error) { + t.Fatalf("unexpected rollback arm") + return nil, nil + } + + ui := &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{ + {ok: true}, // Apply now + {ok: false}, // Skip full rollback + }, + } + safety := &SafetyBackupResult{BackupPath: "/backups/full.tgz"} + err := maybeApplyPVEFirewallWithUI(context.Background(), ui, logger, plan, safety, nil, stageRoot, false) + if err != nil { + t.Fatalf("expected nil, got %v", err) + } + }) + + t.Run("full rollback accepted but no changes applied disarms rollback", func(t *testing.T) { + markerPath := "/tmp/proxsave/fw.marker" + firewallArmRollback = func(ctx context.Context, logger *logging.Logger, backupPath string, timeout time.Duration, workDir string) (*firewallRollbackHandle, error) { + handle := &firewallRollbackHandle{ + markerPath: markerPath, + logPath: "/tmp/proxsave/fw.log", + armedAt: nowRestore(), + timeout: timeout, + } + _ = restoreFS.WriteFile(markerPath, []byte("pending\n"), 0o640) + return handle, nil + } + firewallApplyFromStage = func(logger *logging.Logger, stageRoot string) ([]string, error) { return nil, nil } + + ui := &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{ + {ok: true}, // Apply now + {ok: true}, // Proceed with full rollback + }, + } + safety := &SafetyBackupResult{BackupPath: "/backups/full.tgz"} + err := maybeApplyPVEFirewallWithUI(context.Background(), ui, logger, plan, safety, nil, stageRoot, false) + if err != nil { + t.Fatalf("expected nil, got %v", err) + } + if _, err := fakeFS.Stat(markerPath); err == nil || !os.IsNotExist(err) { + t.Fatalf("expected rollback marker removed; stat err=%v", err) + } + }) + + t.Run("proceed without rollback applies and returns without commit prompt", func(t *testing.T) { + firewallArmRollback = func(ctx context.Context, logger *logging.Logger, backupPath string, timeout time.Duration, workDir string) (*firewallRollbackHandle, error) { + t.Fatalf("unexpected rollback arm") + return nil, nil + } + firewallApplyFromStage = func(logger *logging.Logger, stageRoot string) ([]string, error) { + return []string{"/etc/pve/firewall/cluster.fw"}, nil + } + + ui := &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{ + {ok: true}, // Apply now + {ok: true}, // Proceed without rollback + }, + } + err := maybeApplyPVEFirewallWithUI(context.Background(), ui, logger, plan, nil, nil, stageRoot, false) + if err != nil { + t.Fatalf("expected nil, got %v", err) + } + }) + + t.Run("commit keeps changes and disarms rollback", func(t *testing.T) { + markerPath := "/tmp/proxsave/fw.marker" + firewallArmRollback = func(ctx context.Context, logger *logging.Logger, backupPath string, timeout time.Duration, workDir string) (*firewallRollbackHandle, error) { + handle := &firewallRollbackHandle{ + markerPath: markerPath, + logPath: "/tmp/proxsave/fw.log", + armedAt: nowRestore(), + timeout: timeout, + } + _ = restoreFS.WriteFile(markerPath, []byte("pending\n"), 0o640) + return handle, nil + } + firewallApplyFromStage = func(logger *logging.Logger, stageRoot string) ([]string, error) { + return []string{"/etc/pve/firewall/cluster.fw"}, nil + } + + ui := &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{ + {ok: true}, // Apply now + {ok: true}, // Commit + }, + } + rollback := &SafetyBackupResult{BackupPath: "/backups/firewall.tgz"} + err := maybeApplyPVEFirewallWithUI(context.Background(), ui, logger, plan, nil, rollback, stageRoot, false) + if err != nil { + t.Fatalf("expected nil, got %v", err) + } + if _, err := fakeFS.Stat(markerPath); err == nil || !os.IsNotExist(err) { + t.Fatalf("expected rollback marker removed; stat err=%v", err) + } + }) + + t.Run("rollback requested returns typed error with marker armed", func(t *testing.T) { + markerPath := "/tmp/proxsave/fw.marker" + armedAt := nowRestore() + firewallArmRollback = func(ctx context.Context, logger *logging.Logger, backupPath string, timeout time.Duration, workDir string) (*firewallRollbackHandle, error) { + handle := &firewallRollbackHandle{ + markerPath: markerPath, + logPath: "/tmp/proxsave/fw.log", + armedAt: armedAt, + timeout: timeout, + } + _ = restoreFS.WriteFile(markerPath, []byte("pending\n"), 0o640) + return handle, nil + } + firewallApplyFromStage = func(logger *logging.Logger, stageRoot string) ([]string, error) { + return []string{"/etc/pve/firewall/cluster.fw"}, nil + } + + ui := &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{ + {ok: true}, // Apply now + {ok: false}, // Rollback + }, + } + rollback := &SafetyBackupResult{BackupPath: "/backups/firewall.tgz"} + err := maybeApplyPVEFirewallWithUI(context.Background(), ui, logger, plan, nil, rollback, stageRoot, false) + if err == nil { + t.Fatalf("expected error") + } + if !errors.Is(err, ErrFirewallApplyNotCommitted) { + t.Fatalf("expected ErrFirewallApplyNotCommitted, got %v", err) + } + var typed *FirewallApplyNotCommittedError + if !errors.As(err, &typed) || typed == nil { + t.Fatalf("expected FirewallApplyNotCommittedError, got %T", err) + } + if !typed.RollbackArmed || typed.RollbackMarker != markerPath || typed.RollbackLog != "/tmp/proxsave/fw.log" { + t.Fatalf("unexpected error fields: %#v", typed) + } + if typed.RollbackDeadline.IsZero() || !typed.RollbackDeadline.Equal(armedAt.Add(defaultFirewallRollbackTimeout)) { + t.Fatalf("unexpected RollbackDeadline=%s", typed.RollbackDeadline) + } + if _, err := fakeFS.Stat(markerPath); err != nil { + t.Fatalf("expected marker to still exist, stat err=%v", err) + } + }) + + t.Run("commit prompt abort returns abort error", func(t *testing.T) { + markerPath := "/tmp/proxsave/fw.marker" + firewallArmRollback = func(ctx context.Context, logger *logging.Logger, backupPath string, timeout time.Duration, workDir string) (*firewallRollbackHandle, error) { + handle := &firewallRollbackHandle{ + markerPath: markerPath, + logPath: "/tmp/proxsave/fw.log", + armedAt: nowRestore(), + timeout: timeout, + } + _ = restoreFS.WriteFile(markerPath, []byte("pending\n"), 0o640) + return handle, nil + } + firewallApplyFromStage = func(logger *logging.Logger, stageRoot string) ([]string, error) { + return []string{"/etc/pve/firewall/cluster.fw"}, nil + } + + ui := &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{ + {ok: true}, // Apply now + {err: input.ErrInputAborted}, // Abort at commit prompt + }, + } + rollback := &SafetyBackupResult{BackupPath: "/backups/firewall.tgz"} + err := maybeApplyPVEFirewallWithUI(context.Background(), ui, logger, plan, nil, rollback, stageRoot, false) + if err == nil || !errors.Is(err, input.ErrInputAborted) { + t.Fatalf("expected abort error, got %v", err) + } + }) + + t.Run("commit prompt failure returns typed error", func(t *testing.T) { + markerPath := "/tmp/proxsave/fw.marker" + firewallArmRollback = func(ctx context.Context, logger *logging.Logger, backupPath string, timeout time.Duration, workDir string) (*firewallRollbackHandle, error) { + handle := &firewallRollbackHandle{ + markerPath: markerPath, + logPath: "/tmp/proxsave/fw.log", + armedAt: nowRestore(), + timeout: timeout, + } + _ = restoreFS.WriteFile(markerPath, []byte("pending\n"), 0o640) + return handle, nil + } + firewallApplyFromStage = func(logger *logging.Logger, stageRoot string) ([]string, error) { + return []string{"/etc/pve/firewall/cluster.fw"}, nil + } + + ui := &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{ + {ok: true}, // Apply now + {err: fmt.Errorf("boom")}, // Commit prompt fails + }, + } + rollback := &SafetyBackupResult{BackupPath: "/backups/firewall.tgz"} + err := maybeApplyPVEFirewallWithUI(context.Background(), ui, logger, plan, nil, rollback, stageRoot, false) + if err == nil { + t.Fatalf("expected error") + } + if !errors.Is(err, ErrFirewallApplyNotCommitted) { + t.Fatalf("expected ErrFirewallApplyNotCommitted, got %v", err) + } + if _, err := fakeFS.Stat(markerPath); err != nil { + t.Fatalf("expected marker to still exist, stat err=%v", err) + } + }) + + t.Run("remaining timeout returns typed error without commit prompt", func(t *testing.T) { + markerPath := "/tmp/proxsave/fw.marker" + firewallArmRollback = func(ctx context.Context, logger *logging.Logger, backupPath string, timeout time.Duration, workDir string) (*firewallRollbackHandle, error) { + handle := &firewallRollbackHandle{ + markerPath: markerPath, + logPath: "/tmp/proxsave/fw.log", + armedAt: nowRestore().Add(-timeout - time.Second), + timeout: timeout, + } + _ = restoreFS.WriteFile(markerPath, []byte("pending\n"), 0o640) + return handle, nil + } + firewallApplyFromStage = func(logger *logging.Logger, stageRoot string) ([]string, error) { + return []string{"/etc/pve/firewall/cluster.fw"}, nil + } + + ui := &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{{ok: true}}, // Apply now only + } + rollback := &SafetyBackupResult{BackupPath: "/backups/firewall.tgz"} + err := maybeApplyPVEFirewallWithUI(context.Background(), ui, logger, plan, nil, rollback, stageRoot, false) + if err == nil || !errors.Is(err, ErrFirewallApplyNotCommitted) { + t.Fatalf("expected ErrFirewallApplyNotCommitted, got %v", err) + } + }) + + t.Run("dry run skips apply", func(t *testing.T) { + ui := &scriptedRestoreWorkflowUI{fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, script: nil} + err := maybeApplyPVEFirewallWithUI(context.Background(), ui, logger, plan, nil, nil, stageRoot, true) + if err != nil { + t.Fatalf("expected nil, got %v", err) + } + }) + + t.Run("non-root skips apply", func(t *testing.T) { + firewallApplyGeteuid = func() int { return 1000 } + t.Cleanup(func() { firewallApplyGeteuid = func() int { return 0 } }) + + ui := &scriptedRestoreWorkflowUI{fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, script: nil} + err := maybeApplyPVEFirewallWithUI(context.Background(), ui, logger, plan, nil, nil, stageRoot, false) + if err != nil { + t.Fatalf("expected nil, got %v", err) + } + }) +} + +func TestMaybeApplyPVEFirewallWithUI_AdditionalBranches(t *testing.T) { + origFS := restoreFS + origTime := restoreTime + origCmd := restoreCmd + origGeteuid := firewallApplyGeteuid + origMounted := firewallIsMounted + origRealFS := firewallIsRealRestoreFS + origArm := firewallArmRollback + origDisarm := firewallDisarmRollback + origApply := firewallApplyFromStage + origRestart := firewallRestartService + t.Cleanup(func() { + restoreFS = origFS + restoreTime = origTime + restoreCmd = origCmd + firewallApplyGeteuid = origGeteuid + firewallIsMounted = origMounted + firewallIsRealRestoreFS = origRealFS + firewallArmRollback = origArm + firewallDisarmRollback = origDisarm + firewallApplyFromStage = origApply + firewallRestartService = origRestart + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + restoreTime = &FakeTime{Current: time.Unix(100, 0)} + restoreCmd = &FakeCommandRunner{} + + firewallIsRealRestoreFS = func(fs FS) bool { return true } + firewallApplyGeteuid = func() int { return 0 } + firewallIsMounted = func(path string) (bool, error) { return true, nil } + firewallRestartService = func(ctx context.Context) error { return nil } + + logger := newTestLogger() + plan := &RestorePlan{SystemType: SystemTypePVE, NormalCategories: []Category{{ID: "pve_firewall"}}} + + t.Run("plan nil returns nil", func(t *testing.T) { + err := maybeApplyPVEFirewallWithUI(context.Background(), nil, logger, nil, nil, nil, "/stage", false) + if err != nil { + t.Fatalf("expected nil, got %v", err) + } + }) + + t.Run("wrong system type returns nil", func(t *testing.T) { + p := &RestorePlan{SystemType: SystemTypePBS, NormalCategories: []Category{{ID: "pve_firewall"}}} + err := maybeApplyPVEFirewallWithUI(context.Background(), nil, logger, p, nil, nil, "/stage", false) + if err != nil { + t.Fatalf("expected nil, got %v", err) + } + }) + + t.Run("missing category returns nil", func(t *testing.T) { + p := &RestorePlan{SystemType: SystemTypePVE, NormalCategories: []Category{{ID: "network"}}} + err := maybeApplyPVEFirewallWithUI(context.Background(), nil, logger, p, nil, nil, "/stage", false) + if err != nil { + t.Fatalf("expected nil, got %v", err) + } + }) + + t.Run("non-real filesystem skips apply", func(t *testing.T) { + firewallIsRealRestoreFS = func(fs FS) bool { return false } + t.Cleanup(func() { firewallIsRealRestoreFS = func(fs FS) bool { return true } }) + + ui := &scriptedRestoreWorkflowUI{fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, script: nil} + err := maybeApplyPVEFirewallWithUI(context.Background(), ui, logger, plan, nil, nil, "/stage", false) + if err != nil { + t.Fatalf("expected nil, got %v", err) + } + }) + + t.Run("blank stage root skips apply", func(t *testing.T) { + ui := &scriptedRestoreWorkflowUI{fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, script: nil} + err := maybeApplyPVEFirewallWithUI(context.Background(), ui, logger, plan, nil, nil, " ", false) + if err != nil { + t.Fatalf("expected nil, got %v", err) + } + }) + + t.Run("cluster restore skips apply", func(t *testing.T) { + p := &RestorePlan{SystemType: SystemTypePVE, NeedsClusterRestore: true, NormalCategories: []Category{{ID: "pve_firewall"}}} + ui := &scriptedRestoreWorkflowUI{fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, script: nil} + err := maybeApplyPVEFirewallWithUI(context.Background(), ui, logger, p, nil, nil, "/stage", false) + if err != nil { + t.Fatalf("expected nil, got %v", err) + } + }) + + t.Run("mount check warning path", func(t *testing.T) { + firewallIsMounted = func(path string) (bool, error) { return true, fmt.Errorf("boom") } + t.Cleanup(func() { firewallIsMounted = func(path string) (bool, error) { return true, nil } }) + + ui := &scriptedRestoreWorkflowUI{fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, script: nil} + err := maybeApplyPVEFirewallWithUI(context.Background(), ui, logger, plan, nil, nil, "/stage", false) + if err != nil { + t.Fatalf("expected nil, got %v", err) + } + }) + + t.Run("apply prompt error returns error", func(t *testing.T) { + if err := fakeFS.AddDir("/stage/etc/pve/nodes"); err != nil { + t.Fatalf("add stage nodes: %v", err) + } + + ui := &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{{err: fmt.Errorf("input fail")}}, + } + err := maybeApplyPVEFirewallWithUI(context.Background(), ui, logger, plan, nil, nil, "/stage", false) + if err == nil { + t.Fatalf("expected error") + } + }) + + t.Run("no rollback declined returns nil", func(t *testing.T) { + ui := &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{ + {ok: true}, // Apply now + {ok: false}, // Skip apply (no rollback) + }, + } + err := maybeApplyPVEFirewallWithUI(context.Background(), ui, logger, plan, nil, nil, "/stage", false) + if err != nil { + t.Fatalf("expected nil, got %v", err) + } + }) + + t.Run("no rollback prompt error returns error", func(t *testing.T) { + ui := &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{ + {ok: true}, // Apply now + {err: fmt.Errorf("boom")}, // No rollback prompt fails + }, + } + err := maybeApplyPVEFirewallWithUI(context.Background(), ui, logger, plan, nil, nil, "/stage", false) + if err == nil { + t.Fatalf("expected error") + } + }) + + t.Run("full rollback prompt error returns error", func(t *testing.T) { + ui := &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{ + {ok: true}, // Apply now + {err: fmt.Errorf("boom")}, // Full rollback prompt fails + }, + } + safety := &SafetyBackupResult{BackupPath: "/backups/full.tgz"} + err := maybeApplyPVEFirewallWithUI(context.Background(), ui, logger, plan, safety, nil, "/stage", false) + if err == nil { + t.Fatalf("expected error") + } + }) + + t.Run("arm rollback error returns wrapped error", func(t *testing.T) { + firewallArmRollback = func(ctx context.Context, logger *logging.Logger, backupPath string, timeout time.Duration, workDir string) (*firewallRollbackHandle, error) { + return nil, fmt.Errorf("arm failed") + } + firewallApplyFromStage = func(logger *logging.Logger, stageRoot string) ([]string, error) { + t.Fatalf("unexpected apply") + return nil, nil + } + + ui := &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{{ok: true}}, + } + rollback := &SafetyBackupResult{BackupPath: "/backups/firewall.tgz"} + err := maybeApplyPVEFirewallWithUI(context.Background(), ui, logger, plan, nil, rollback, "/stage", false) + if err == nil || !strings.Contains(err.Error(), "arm firewall rollback") { + t.Fatalf("expected wrapped arm error, got %v", err) + } + }) + + t.Run("apply from stage error returns error", func(t *testing.T) { + firewallApplyFromStage = func(logger *logging.Logger, stageRoot string) ([]string, error) { + return nil, fmt.Errorf("apply failed") + } + ui := &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{ + {ok: true}, // Apply now + {ok: true}, // Proceed without rollback + }, + } + err := maybeApplyPVEFirewallWithUI(context.Background(), ui, logger, plan, nil, nil, "/stage", false) + if err == nil || !strings.Contains(err.Error(), "apply failed") { + t.Fatalf("expected apply error, got %v", err) + } + }) + + t.Run("restart failure logs warning but continues", func(t *testing.T) { + markerPath := "/tmp/proxsave/fw.marker" + firewallArmRollback = func(ctx context.Context, logger *logging.Logger, backupPath string, timeout time.Duration, workDir string) (*firewallRollbackHandle, error) { + handle := &firewallRollbackHandle{ + markerPath: markerPath, + logPath: "/tmp/proxsave/fw.log", + armedAt: nowRestore(), + timeout: timeout, + } + _ = restoreFS.WriteFile(markerPath, []byte("pending\n"), 0o640) + return handle, nil + } + firewallApplyFromStage = func(logger *logging.Logger, stageRoot string) ([]string, error) { + return []string{"/etc/pve/firewall/cluster.fw"}, nil + } + firewallRestartService = func(ctx context.Context) error { return fmt.Errorf("restart failed") } + + ui := &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{ + {ok: true}, // Apply now + {ok: true}, // Commit + }, + } + rollback := &SafetyBackupResult{BackupPath: "/backups/firewall.tgz"} + err := maybeApplyPVEFirewallWithUI(context.Background(), ui, logger, plan, nil, rollback, "/stage", false) + if err != nil { + t.Fatalf("expected nil, got %v", err) + } + if _, err := fakeFS.Stat(markerPath); err == nil || !os.IsNotExist(err) { + t.Fatalf("expected rollback marker removed; stat err=%v", err) + } + }) + + t.Run("commit context canceled returns canceled error", func(t *testing.T) { + markerPath := "/tmp/proxsave/fw.marker" + firewallArmRollback = func(ctx context.Context, logger *logging.Logger, backupPath string, timeout time.Duration, workDir string) (*firewallRollbackHandle, error) { + handle := &firewallRollbackHandle{ + markerPath: markerPath, + logPath: "/tmp/proxsave/fw.log", + armedAt: nowRestore(), + timeout: timeout, + } + _ = restoreFS.WriteFile(markerPath, []byte("pending\n"), 0o640) + return handle, nil + } + firewallApplyFromStage = func(logger *logging.Logger, stageRoot string) ([]string, error) { + return []string{"/etc/pve/firewall/cluster.fw"}, nil + } + + ui := &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{ + {ok: true}, // Apply now + {err: context.Canceled}, // Commit prompt canceled + }, + } + rollback := &SafetyBackupResult{BackupPath: "/backups/firewall.tgz"} + err := maybeApplyPVEFirewallWithUI(context.Background(), ui, logger, plan, nil, rollback, "/stage", false) + if err == nil || !errors.Is(err, context.Canceled) { + t.Fatalf("expected context canceled error, got %v", err) + } + }) +} + +func TestApplyPVEFirewallFromStage_CoversAdditionalPaths(t *testing.T) { + origFS := restoreFS + origHostname := firewallHostname + t.Cleanup(func() { + restoreFS = origFS + firewallHostname = origHostname + }) + + t.Run("blank stage root", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + applied, err := applyPVEFirewallFromStage(newTestLogger(), " ") + if err != nil || len(applied) != 0 { + t.Fatalf("expected nil, got applied=%#v err=%v", applied, err) + } + }) + + t.Run("staged firewall as file", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + stageRoot := "/stage" + if err := fakeFS.AddFile(stageRoot+"/etc/pve/firewall", []byte("fw\n")); err != nil { + t.Fatalf("add staged firewall file: %v", err) + } + + applied, err := applyPVEFirewallFromStage(newTestLogger(), stageRoot) + if err != nil { + t.Fatalf("applyPVEFirewallFromStage error: %v", err) + } + found := false + for _, p := range applied { + if p == "/etc/pve/firewall" { + found = true + break + } + } + if !found { + t.Fatalf("expected /etc/pve/firewall in applied paths, got %#v", applied) + } + if got, err := fakeFS.ReadFile("/etc/pve/firewall"); err != nil || string(got) != "fw\n" { + t.Fatalf("dest firewall err=%v data=%q", err, string(got)) + } + }) + + t.Run("firewall stat error returns error", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + stageRoot := "/stage" + stageFirewall := filepath.Join(stageRoot, "etc", "pve", "firewall") + fakeFS.StatErrors[filepath.Clean(stageFirewall)] = fmt.Errorf("boom") + + _, err := applyPVEFirewallFromStage(newTestLogger(), stageRoot) + if err == nil || !strings.Contains(err.Error(), "stat staged firewall config") { + t.Fatalf("expected stat staged firewall config error, got %v", err) + } + }) + + t.Run("host fw selection error bubbles", func(t *testing.T) { + base := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(base.Root) }) + restoreFS = readDirFailFS{FS: base, failPath: "/stage/etc/pve/nodes", err: fmt.Errorf("boom")} + + _, err := applyPVEFirewallFromStage(newTestLogger(), "/stage") + if err == nil || !strings.Contains(err.Error(), "readdir") { + t.Fatalf("expected readdir error, got %v", err) + } + }) + + t.Run("defaults to localhost when hostname empty", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + firewallHostname = func() (string, error) { return " ", nil } + + stageRoot := "/stage" + if err := fakeFS.AddFile(stageRoot+"/etc/pve/nodes/node1/host.fw", []byte("host\n")); err != nil { + t.Fatalf("add staged host.fw: %v", err) + } + + applied, err := applyPVEFirewallFromStage(newTestLogger(), stageRoot) + if err != nil { + t.Fatalf("applyPVEFirewallFromStage error: %v", err) + } + found := false + for _, p := range applied { + if p == "/etc/pve/nodes/localhost/host.fw" { + found = true + break + } + } + if !found { + t.Fatalf("expected mapped host.fw applied, got %#v", applied) + } + if got, err := fakeFS.ReadFile("/etc/pve/nodes/localhost/host.fw"); err != nil || string(got) != "host\n" { + t.Fatalf("dest host.fw err=%v data=%q", err, string(got)) + } + }) +} + +func TestSelectStageHostFirewall_EmptyCases(t *testing.T) { + origFS := restoreFS + origHostname := firewallHostname + t.Cleanup(func() { + restoreFS = origFS + firewallHostname = origHostname + }) + + firewallHostname = func() (string, error) { return "node1", nil } + + t.Run("nodes directory missing", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + path, node, ok, err := selectStageHostFirewall(newTestLogger(), "/stage") + if err != nil || ok || path != "" || node != "" { + t.Fatalf("expected no selection, got ok=%v path=%q node=%q err=%v", ok, path, node, err) + } + }) + + t.Run("no host.fw candidates", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + if err := fakeFS.AddDir("/stage/etc/pve/nodes/node1"); err != nil { + t.Fatalf("add nodes dir: %v", err) + } + + path, node, ok, err := selectStageHostFirewall(newTestLogger(), "/stage") + if err != nil || ok || path != "" || node != "" { + t.Fatalf("expected no selection, got ok=%v path=%q node=%q err=%v", ok, path, node, err) + } + }) +} + +func TestArmFirewallRollback_DefaultWorkDirAndMinTimeout(t *testing.T) { + origFS := restoreFS + origCmd := restoreCmd + origTime := restoreTime + t.Cleanup(func() { + restoreFS = origFS + restoreCmd = origCmd + restoreTime = origTime + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + fakeTime := &FakeTime{Current: time.Date(2020, 1, 2, 3, 4, 5, 0, time.UTC)} + restoreTime = fakeTime + + emptyBin := t.TempDir() + oldPath := os.Getenv("PATH") + if err := os.Setenv("PATH", emptyBin); err != nil { + t.Fatalf("set PATH: %v", err) + } + t.Cleanup(func() { _ = os.Setenv("PATH", oldPath) }) + + fakeCmd := &FakeCommandRunner{} + restoreCmd = fakeCmd + + handle, err := armFirewallRollback(context.Background(), newTestLogger(), "/backup.tgz", 500*time.Millisecond, " ") + if err != nil { + t.Fatalf("armFirewallRollback error: %v", err) + } + if handle == nil || handle.workDir != "/tmp/proxsave" { + t.Fatalf("unexpected handle: %#v", handle) + } + if data, err := fakeFS.ReadFile(handle.markerPath); err != nil || string(data) != "pending\n" { + t.Fatalf("marker err=%v data=%q", err, string(data)) + } + calls := fakeCmd.CallsList() + if len(calls) != 1 { + t.Fatalf("unexpected calls: %#v", calls) + } + if !strings.Contains(calls[0], "sleep 1; /bin/sh") { + t.Fatalf("expected timeoutSeconds to clamp to 1, got call=%q", calls[0]) + } +} + +func TestArmFirewallRollback_ReturnsErrorOnMkdirAllFailure(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + fakeFS.MkdirAllErr = fmt.Errorf("disk full") + restoreFS = fakeFS + + if _, err := armFirewallRollback(context.Background(), newTestLogger(), "/backup.tgz", 1*time.Second, "/tmp/proxsave"); err == nil { + t.Fatalf("expected error") + } +} + +func TestArmFirewallRollback_ReturnsErrorOnMarkerWriteFailure(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + fakeFS.WriteErr = fmt.Errorf("disk full") + restoreFS = fakeFS + + if _, err := armFirewallRollback(context.Background(), newTestLogger(), "/backup.tgz", 1*time.Second, "/tmp/proxsave"); err == nil || !strings.Contains(err.Error(), "write rollback marker") { + t.Fatalf("expected write rollback marker error, got %v", err) + } +} + +func TestArmFirewallRollback_ReturnsErrorOnScriptWriteFailure(t *testing.T) { + origFS := restoreFS + origTime := restoreTime + t.Cleanup(func() { + restoreFS = origFS + restoreTime = origTime + }) + + base := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(base.Root) }) + + fakeTime := &FakeTime{Current: time.Date(2020, 1, 2, 3, 4, 5, 0, time.UTC)} + restoreTime = fakeTime + timestamp := fakeTime.Current.Format("20060102_150405") + scriptPath := filepath.Join("/tmp/proxsave", fmt.Sprintf("firewall_rollback_%s.sh", timestamp)) + + restoreFS = writeFileFailFS{FS: base, failPath: scriptPath, err: fmt.Errorf("disk full")} + + if _, err := armFirewallRollback(context.Background(), newTestLogger(), "/backup.tgz", 1*time.Second, "/tmp/proxsave"); err == nil || !strings.Contains(err.Error(), "write rollback script") { + t.Fatalf("expected write rollback script error, got %v", err) + } +} + +func TestDisarmFirewallRollback_MissingMarkerAndNoSystemctl(t *testing.T) { + origFS := restoreFS + origCmd := restoreCmd + t.Cleanup(func() { + restoreFS = origFS + restoreCmd = origCmd + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + emptyBin := t.TempDir() + oldPath := os.Getenv("PATH") + if err := os.Setenv("PATH", emptyBin); err != nil { + t.Fatalf("set PATH: %v", err) + } + t.Cleanup(func() { _ = os.Setenv("PATH", oldPath) }) + + fakeCmd := &FakeCommandRunner{} + restoreCmd = fakeCmd + + handle := &firewallRollbackHandle{ + markerPath: "/tmp/proxsave/missing.marker", + unitName: "unit", + } + disarmFirewallRollback(context.Background(), newTestLogger(), handle) + + if calls := fakeCmd.CallsList(); len(calls) != 0 { + t.Fatalf("expected no systemctl calls, got %#v", calls) + } +} + +func TestDisarmFirewallRollback_ContinuesOnMarkerRemoveError(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + base := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(base.Root) }) + if err := base.AddFile("/tmp/proxsave/fw.marker", []byte("pending\n")); err != nil { + t.Fatalf("add marker: %v", err) + } + restoreFS = removeFailFS{FS: base, failPath: "/tmp/proxsave/fw.marker", err: fmt.Errorf("perm")} + + handle := &firewallRollbackHandle{ + markerPath: "/tmp/proxsave/fw.marker", + } + disarmFirewallRollback(context.Background(), newTestLogger(), handle) +} + +func TestCopyFileExact_PropagatesStatAndReadFailures(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + base := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(base.Root) }) + if err := base.AddFile("/src/file", []byte("x")); err != nil { + t.Fatalf("add src: %v", err) + } + + restoreFS = statFailFS{FS: base, failPath: "/src/file", err: fmt.Errorf("boom")} + if _, err := copyFileExact("/src/file", "/dest/file"); err == nil { + t.Fatalf("expected stat error") + } + + restoreFS = readFileFailFS{FS: base, failPath: "/src/file", err: os.ErrNotExist} + ok, err := copyFileExact("/src/file", "/dest/file") + if err != nil || ok { + t.Fatalf("expected ok=false err=nil for readfile not exist, got ok=%v err=%v", ok, err) + } + + restoreFS = readFileFailFS{FS: base, failPath: "/src/file", err: fmt.Errorf("read boom")} + if _, err := copyFileExact("/src/file", "/dest/file"); err == nil { + t.Fatalf("expected read error") + } +} + +func TestSyncDirExact_CoversErrorPathsAndPrune(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + t.Run("source missing or not a dir returns nil", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + if applied, err := syncDirExact("/missing", "/dest"); err != nil || len(applied) != 0 { + t.Fatalf("expected nil, got applied=%#v err=%v", applied, err) + } + if err := fakeFS.AddFile("/stagefile", []byte("x")); err != nil { + t.Fatalf("add stagefile: %v", err) + } + if applied, err := syncDirExact("/stagefile", "/dest"); err != nil || len(applied) != 0 { + t.Fatalf("expected nil, got applied=%#v err=%v", applied, err) + } + }) + + t.Run("dest exists as file returns error", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + if err := fakeFS.AddDir("/stage"); err != nil { + t.Fatalf("add stage dir: %v", err) + } + if err := fakeFS.AddFile("/dest", []byte("x")); err != nil { + t.Fatalf("add dest file: %v", err) + } + if _, err := syncDirExact("/stage", "/dest"); err == nil || !strings.Contains(err.Error(), "ensure") { + t.Fatalf("expected ensure error, got %v", err) + } + }) + + t.Run("readDir error bubbles", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + if err := fakeFS.AddDir("/stage"); err != nil { + t.Fatalf("add stage dir: %v", err) + } + restoreFS = readDirFailFS{FS: fakeFS, failPath: "/stage", err: fmt.Errorf("boom")} + if _, err := syncDirExact("/stage", "/dest"); err == nil || !strings.Contains(err.Error(), "readdir") { + t.Fatalf("expected readdir error, got %v", err) + } + }) + + t.Run("entry.Info error bubbles", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + if err := fakeFS.AddDir("/stage"); err != nil { + t.Fatalf("add stage dir: %v", err) + } + + restoreFS = readDirOverrideFS{ + FS: fakeFS, + overridePath: "/stage", + entries: []os.DirEntry{badInfoDirEntry{name: "bad"}}, + } + if _, err := syncDirExact("/stage", "/dest"); err == nil || !strings.Contains(err.Error(), "stat /stage/bad") { + t.Fatalf("expected entry.Info error, got %v", err) + } + }) + + t.Run("readlink error bubbles", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + if err := fakeFS.AddFile("/stage/target", []byte("x")); err != nil { + t.Fatalf("add target: %v", err) + } + if err := fakeFS.Symlink("/stage/target", "/stage/link"); err != nil { + t.Fatalf("add symlink: %v", err) + } + + restoreFS = readlinkFailFS{FS: fakeFS, failPath: "/stage/link", err: fmt.Errorf("boom")} + if _, err := syncDirExact("/stage", "/dest"); err == nil || !strings.Contains(err.Error(), "readlink") { + t.Fatalf("expected readlink error, got %v", err) + } + }) + + t.Run("prune remove failure returns error", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + + if err := fakeFS.AddFile("/stage/keep", []byte("x")); err != nil { + t.Fatalf("add keep: %v", err) + } + if err := fakeFS.AddFile("/dest/remove", []byte("x")); err != nil { + t.Fatalf("add extraneous: %v", err) + } + + restoreFS = removeFailFS{FS: fakeFS, failPath: "/dest/remove", err: fmt.Errorf("perm")} + if _, err := syncDirExact("/stage", "/dest"); err == nil || !strings.Contains(err.Error(), "remove") { + t.Fatalf("expected remove error, got %v", err) + } + }) + + t.Run("prunes empty dirs not present in stage", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + if err := fakeFS.AddFile("/stage/keep", []byte("x")); err != nil { + t.Fatalf("add keep: %v", err) + } + if err := fakeFS.AddDir("/dest/oldDir"); err != nil { + t.Fatalf("add oldDir: %v", err) + } + + if _, err := syncDirExact("/stage", "/dest"); err != nil { + t.Fatalf("syncDirExact error: %v", err) + } + if _, err := fakeFS.Stat("/dest/oldDir"); err == nil || !os.IsNotExist(err) { + t.Fatalf("expected oldDir removed; stat err=%v", err) + } + }) +} + +func TestDisarmFirewallRollback_NilHandleAndEmptyPaths(t *testing.T) { + disarmFirewallRollback(context.Background(), newTestLogger(), nil) + + handle := &firewallRollbackHandle{ + markerPath: " ", + unitName: " ", + } + disarmFirewallRollback(context.Background(), newTestLogger(), handle) +} + +func TestSelectStageHostFirewall_IgnoresNilAndBlankEntries(t *testing.T) { + origFS := restoreFS + origHostname := firewallHostname + t.Cleanup(func() { + restoreFS = origFS + firewallHostname = origHostname + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + + firewallHostname = func() (string, error) { return "node1", nil } + + stageNodes := "/stage/etc/pve/nodes" + if err := fakeFS.AddDir(stageNodes); err != nil { + t.Fatalf("add stage nodes: %v", err) + } + if err := fakeFS.AddDir(stageNodes + "/ "); err != nil { + t.Fatalf("add blank node dir: %v", err) + } + if err := fakeFS.AddFile(stageNodes+"/node1/host.fw", []byte("host\n")); err != nil { + t.Fatalf("add host.fw: %v", err) + } + + entries, err := fakeFS.ReadDir(stageNodes) + if err != nil { + t.Fatalf("readDir stage nodes: %v", err) + } + entries = append([]os.DirEntry{nil}, entries...) + + restoreFS = readDirOverrideFS{ + FS: fakeFS, + overridePath: stageNodes, + entries: entries, + } + + path, node, ok, err := selectStageHostFirewall(newTestLogger(), "/stage") + if err != nil { + t.Fatalf("selectStageHostFirewall error: %v", err) + } + if !ok || node != "node1" || !strings.HasSuffix(path, "/stage/etc/pve/nodes/node1/host.fw") { + t.Fatalf("unexpected selection ok=%v node=%q path=%q", ok, node, path) + } +} + +func TestApplyPVEFirewallFromStage_PropagatesSyncAndCopyErrors(t *testing.T) { + origFS := restoreFS + origTime := restoreTime + origHostname := firewallHostname + t.Cleanup(func() { + restoreFS = origFS + restoreTime = origTime + firewallHostname = origHostname + }) + + t.Run("syncDirExact failure bubbles", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + stageRoot := "/stage" + if err := fakeFS.AddFile(stageRoot+"/etc/pve/firewall/cluster.fw", []byte("x")); err != nil { + t.Fatalf("add staged firewall: %v", err) + } + if err := fakeFS.AddFile("/etc/pve/firewall", []byte("not a dir")); err != nil { + t.Fatalf("add dest firewall file: %v", err) + } + + _, err := applyPVEFirewallFromStage(newTestLogger(), stageRoot) + if err == nil || !strings.Contains(err.Error(), "ensure /etc/pve/firewall") { + t.Fatalf("expected ensure error, got %v", err) + } + }) + + t.Run("firewall file copy failure bubbles", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + fakeTime := &FakeTime{Current: time.Unix(0, 12345)} + restoreTime = fakeTime + + stageRoot := "/stage" + if err := fakeFS.AddFile(stageRoot+"/etc/pve/firewall", []byte("fw\n")); err != nil { + t.Fatalf("add staged firewall file: %v", err) + } + + tmpPath := "/etc/pve/firewall.proxsave.tmp." + strconv.FormatInt(fakeTime.Current.UnixNano(), 10) + fakeFS.OpenFileErr[filepath.Clean(tmpPath)] = fmt.Errorf("open tmp denied") + + _, err := applyPVEFirewallFromStage(newTestLogger(), stageRoot) + if err == nil || !strings.Contains(err.Error(), "write /etc/pve/firewall") { + t.Fatalf("expected write error, got %v", err) + } + }) + + t.Run("host.fw copy failure bubbles", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + fakeTime := &FakeTime{Current: time.Unix(0, 12345)} + restoreTime = fakeTime + + firewallHostname = func() (string, error) { return "current", nil } + + stageRoot := "/stage" + if err := fakeFS.AddFile(stageRoot+"/etc/pve/nodes/other/host.fw", []byte("host\n")); err != nil { + t.Fatalf("add staged host.fw: %v", err) + } + + destHostFW := "/etc/pve/nodes/current/host.fw" + tmpPath := destHostFW + ".proxsave.tmp." + strconv.FormatInt(fakeTime.Current.UnixNano(), 10) + fakeFS.OpenFileErr[filepath.Clean(tmpPath)] = fmt.Errorf("open tmp denied") + + _, err := applyPVEFirewallFromStage(newTestLogger(), stageRoot) + if err == nil || !strings.Contains(err.Error(), "write "+destHostFW) { + t.Fatalf("expected host.fw write error, got %v", err) + } + }) +} + +func TestSyncDirExact_AdditionalEdgeCases(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + t.Run("stat error bubbles", func(t *testing.T) { + base := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(base.Root) }) + restoreFS = statFailFS{FS: base, failPath: "/stage", err: fmt.Errorf("boom")} + + if _, err := syncDirExact("/stage", "/dest"); err == nil || !strings.Contains(err.Error(), "stat /stage") { + t.Fatalf("expected stat error, got %v", err) + } + }) + + t.Run("walkStage ignores disappeared dir readDir not-exist", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + if err := fakeFS.AddDir("/stage"); err != nil { + t.Fatalf("add stage dir: %v", err) + } + + restoreFS = readDirOverrideFS{ + FS: fakeFS, + overridePath: "/stage", + entries: []os.DirEntry{ + staticDirEntry{name: "sub", mode: fs.ModeDir}, + }, + } + + if _, err := syncDirExact("/stage", "/dest"); err != nil { + t.Fatalf("syncDirExact error: %v", err) + } + if info, err := fakeFS.Stat("/dest/sub"); err != nil || !info.IsDir() { + t.Fatalf("expected /dest/sub directory, err=%v info=%v", err, info) + } + }) + + t.Run("walkStage skips nil/blank/dot entries", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + if err := fakeFS.AddDir("/stage"); err != nil { + t.Fatalf("add stage dir: %v", err) + } + + restoreFS = readDirOverrideFS{ + FS: fakeFS, + overridePath: "/stage", + entries: []os.DirEntry{ + nil, + staticDirEntry{name: " ", mode: 0}, + staticDirEntry{name: ".", mode: 0}, + }, + } + + if applied, err := syncDirExact("/stage", "/dest"); err != nil || len(applied) != 0 { + t.Fatalf("expected nil, got applied=%#v err=%v", applied, err) + } + }) + + t.Run("symlink parent ensure error bubbles", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + + if err := fakeFS.AddDir("/stage"); err != nil { + t.Fatalf("add stage dir: %v", err) + } + if err := fakeFS.Symlink("/stage/target", "/stage/link"); err != nil { + t.Fatalf("add stage symlink: %v", err) + } + if err := fakeFS.AddDir("/dest"); err != nil { + t.Fatalf("add dest dir: %v", err) + } + + restoreFS = &statFailOnNthFS{FS: fakeFS, path: "/dest", failOn: 2, err: fmt.Errorf("boom")} + if _, err := syncDirExact("/stage", "/dest"); err == nil || !strings.Contains(err.Error(), "ensure /dest") { + t.Fatalf("expected ensure error, got %v", err) + } + }) + + t.Run("symlink creation error bubbles", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + + if err := fakeFS.AddDir("/stage"); err != nil { + t.Fatalf("add stage dir: %v", err) + } + if err := fakeFS.Symlink("/stage/target", "/stage/link"); err != nil { + t.Fatalf("add stage symlink: %v", err) + } + + restoreFS = symlinkFailFS{FS: fakeFS, failNewname: "/dest/link", err: fmt.Errorf("boom")} + if _, err := syncDirExact("/stage", "/dest"); err == nil || !strings.Contains(err.Error(), "symlink /dest/link") { + t.Fatalf("expected symlink error, got %v", err) + } + }) + + t.Run("directory ensure error bubbles", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + if err := fakeFS.AddDir("/stage/sub"); err != nil { + t.Fatalf("add stage subdir: %v", err) + } + if err := fakeFS.AddFile("/dest/sub", []byte("not a dir")); err != nil { + t.Fatalf("add dest file: %v", err) + } + if _, err := syncDirExact("/stage", "/dest"); err == nil || !strings.Contains(err.Error(), "ensure /dest/sub") { + t.Fatalf("expected ensure /dest/sub error, got %v", err) + } + }) + + t.Run("directory recursion error bubbles", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + if err := fakeFS.AddDir("/stage/sub"); err != nil { + t.Fatalf("add stage subdir: %v", err) + } + restoreFS = readDirFailFS{FS: fakeFS, failPath: "/stage/sub", err: fmt.Errorf("boom")} + + if _, err := syncDirExact("/stage", "/dest"); err == nil || !strings.Contains(err.Error(), "readdir /stage/sub") { + t.Fatalf("expected recursion readdir error, got %v", err) + } + }) + + t.Run("copy file error bubbles", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + if err := fakeFS.AddFile("/stage/file", []byte("x")); err != nil { + t.Fatalf("add stage file: %v", err) + } + restoreFS = readFileFailFS{FS: fakeFS, failPath: "/stage/file", err: fmt.Errorf("read boom")} + + if _, err := syncDirExact("/stage", "/dest"); err == nil || !strings.Contains(err.Error(), "read /stage/file") { + t.Fatalf("expected read error, got %v", err) + } + }) + + t.Run("pruneDest readDir error bubbles", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + if err := fakeFS.AddDir("/stage"); err != nil { + t.Fatalf("add stage dir: %v", err) + } + restoreFS = readDirFailFS{FS: fakeFS, failPath: "/dest", err: fmt.Errorf("boom")} + + if _, err := syncDirExact("/stage", "/dest"); err == nil || !strings.Contains(err.Error(), "readdir /dest") { + t.Fatalf("expected pruneDest readdir error, got %v", err) + } + }) + + t.Run("pruneDest ignores not-exist", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + if err := fakeFS.AddDir("/stage"); err != nil { + t.Fatalf("add stage dir: %v", err) + } + restoreFS = readDirFailFS{FS: fakeFS, failPath: "/dest", err: os.ErrNotExist} + + if _, err := syncDirExact("/stage", "/dest"); err != nil { + t.Fatalf("expected nil, got %v", err) + } + }) + + t.Run("pruneDest skips nil/blank/dot entries", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + if err := fakeFS.AddDir("/stage"); err != nil { + t.Fatalf("add stage dir: %v", err) + } + + restoreFS = multiReadDirFS{ + FS: fakeFS, + entries: map[string][]os.DirEntry{ + "/dest": { + nil, + staticDirEntry{name: " ", mode: 0}, + staticDirEntry{name: ".", mode: 0}, + }, + }, + } + + if _, err := syncDirExact("/stage", "/dest"); err != nil { + t.Fatalf("expected nil, got %v", err) + } + }) + + t.Run("pruneDest entry.Info error bubbles", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + if err := fakeFS.AddDir("/stage"); err != nil { + t.Fatalf("add stage dir: %v", err) + } + + restoreFS = multiReadDirFS{ + FS: fakeFS, + entries: map[string][]os.DirEntry{ + "/dest": {badInfoDirEntry{name: "bad"}}, + }, + } + + if _, err := syncDirExact("/stage", "/dest"); err == nil || !strings.Contains(err.Error(), "stat /dest/bad") { + t.Fatalf("expected pruneDest info error, got %v", err) + } + }) + + t.Run("pruneDest recursion error bubbles", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + if err := fakeFS.AddDir("/stage"); err != nil { + t.Fatalf("add stage dir: %v", err) + } + + restoreFS = multiReadDirFS{ + FS: fakeFS, + entries: map[string][]os.DirEntry{ + "/dest": {staticDirEntry{name: "sub", mode: fs.ModeDir}}, + "/dest/sub": nil, + }, + errors: map[string]error{ + "/dest/sub": fmt.Errorf("boom"), + }, + } + + if _, err := syncDirExact("/stage", "/dest"); err == nil || !strings.Contains(err.Error(), "readdir /dest/sub") { + t.Fatalf("expected pruneDest recursion error, got %v", err) + } + }) +} From 09717621cf27c3ef8a8d336b0577dca96e193cdf Mon Sep 17 00:00:00 2001 From: Damiano <71268257+tis24dev@users.noreply.github.com> Date: Tue, 24 Feb 2026 13:10:11 +0100 Subject: [PATCH 06/12] Inject access-control helpers and add tests Introduce package-level function variables (wrappers for os.Geteuid, isMounted, isRealRestoreFS, arm/disarm/apply functions) and use nowRestore() for time-based logic so the access-control apply/rollback flow can be mocked in tests. Adjusted cluster-backup check in maybeApplyPVEAccessControlFromClusterBackupWithUI and replaced direct calls with the injectable variants. Added comprehensive unit tests (restore_access_control_ui_additional_test.go) to exercise rollback arming/disarming, script generation, mounting/root checks, user prompts, timeout/commit branches and error conditions. --- .../orchestrator/restore_access_control_ui.go | 28 +- ...store_access_control_ui_additional_test.go | 847 ++++++++++++++++++ 2 files changed, 866 insertions(+), 9 deletions(-) create mode 100644 internal/orchestrator/restore_access_control_ui_additional_test.go diff --git a/internal/orchestrator/restore_access_control_ui.go b/internal/orchestrator/restore_access_control_ui.go index 7dee769..f4dbe86 100644 --- a/internal/orchestrator/restore_access_control_ui.go +++ b/internal/orchestrator/restore_access_control_ui.go @@ -17,6 +17,16 @@ const defaultAccessControlRollbackTimeout = 180 * time.Second var ErrAccessControlApplyNotCommitted = errors.New("access control changes not committed") +var ( + accessControlApplyGeteuid = os.Geteuid + accessControlIsMounted = isMounted + accessControlIsRealRestoreFS = isRealRestoreFS + + accessControlArmRollback = armAccessControlRollback + accessControlDisarmRollback = disarmAccessControlRollback + accessControlApplyFromStage = applyPVEAccessControlFromStage +) + type AccessControlApplyNotCommittedError struct { RollbackLog string RollbackMarker string @@ -155,7 +165,7 @@ func maybeApplyPVEAccessControlFromClusterBackupWithUI( stageRoot string, dryRun bool, ) (err error) { - if plan == nil || plan.SystemType != SystemTypePVE || !plan.HasCategoryID("pve_access_control") || !plan.ClusterBackup || plan.NeedsClusterRestore { + if plan == nil || plan.SystemType != SystemTypePVE || !plan.HasCategoryID("pve_access_control") || !plan.ClusterBackup { return nil } @@ -165,7 +175,7 @@ func maybeApplyPVEAccessControlFromClusterBackupWithUI( if ui == nil { return fmt.Errorf("restore UI not available") } - if !isRealRestoreFS(restoreFS) { + if !accessControlIsRealRestoreFS(restoreFS) { logger.Debug("Skipping PVE access control apply (cluster backup): non-system filesystem in use") return nil } @@ -173,7 +183,7 @@ func maybeApplyPVEAccessControlFromClusterBackupWithUI( logger.Info("Dry run enabled: skipping PVE access control apply (cluster backup)") return nil } - if os.Geteuid() != 0 { + if accessControlApplyGeteuid() != 0 { logger.Warning("Skipping PVE access control apply (cluster backup): requires root privileges") return nil } @@ -198,7 +208,7 @@ func maybeApplyPVEAccessControlFromClusterBackupWithUI( } etcPVE := "/etc/pve" - mounted, mountErr := isMounted(etcPVE) + mounted, mountErr := accessControlIsMounted(etcPVE) if mountErr != nil { logger.Warning("PVE access control apply: unable to check pmxcfs mount (%s): %v", etcPVE, mountErr) } @@ -279,14 +289,14 @@ func maybeApplyPVEAccessControlFromClusterBackupWithUI( if rollbackPath != "" { logger.Info("") logger.Info("Arming access control rollback timer (%ds)...", int(defaultAccessControlRollbackTimeout.Seconds())) - rollbackHandle, err = armAccessControlRollback(ctx, logger, rollbackPath, defaultAccessControlRollbackTimeout, "/tmp/proxsave") + rollbackHandle, err = accessControlArmRollback(ctx, logger, rollbackPath, defaultAccessControlRollbackTimeout, "/tmp/proxsave") if err != nil { return fmt.Errorf("arm access control rollback: %w", err) } logger.Info("Access control rollback log: %s", rollbackHandle.logPath) } - if err := applyPVEAccessControlFromStage(ctx, logger, stageRoot); err != nil { + if err := accessControlApplyFromStage(ctx, logger, stageRoot); err != nil { return err } @@ -295,7 +305,7 @@ func maybeApplyPVEAccessControlFromClusterBackupWithUI( return nil } - remaining := rollbackHandle.remaining(time.Now()) + remaining := rollbackHandle.remaining(nowRestore()) if remaining <= 0 { return buildAccessControlApplyNotCommittedError(rollbackHandle) } @@ -317,7 +327,7 @@ func maybeApplyPVEAccessControlFromClusterBackupWithUI( } if commit { - disarmAccessControlRollback(ctx, logger, rollbackHandle) + accessControlDisarmRollback(ctx, logger, rollbackHandle) logger.Info("Access control changes committed.") return nil } @@ -353,7 +363,7 @@ func armAccessControlRollback(ctx context.Context, logger *logging.Logger, backu markerPath: filepath.Join(baseDir, fmt.Sprintf("access_control_rollback_pending_%s", timestamp)), scriptPath: filepath.Join(baseDir, fmt.Sprintf("access_control_rollback_%s.sh", timestamp)), logPath: filepath.Join(baseDir, fmt.Sprintf("access_control_rollback_%s.log", timestamp)), - armedAt: time.Now(), + armedAt: nowRestore(), timeout: timeout, } diff --git a/internal/orchestrator/restore_access_control_ui_additional_test.go b/internal/orchestrator/restore_access_control_ui_additional_test.go new file mode 100644 index 0000000..079a35e --- /dev/null +++ b/internal/orchestrator/restore_access_control_ui_additional_test.go @@ -0,0 +1,847 @@ +package orchestrator + +import ( + "context" + "errors" + "fmt" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/tis24dev/proxsave/internal/input" + "github.com/tis24dev/proxsave/internal/logging" +) + +func TestAccessControlApplyNotCommittedError_UnwrapAndMessage(t *testing.T) { + var e *AccessControlApplyNotCommittedError + if e.Error() != ErrAccessControlApplyNotCommitted.Error() { + t.Fatalf("nil receiver Error()=%q want %q", e.Error(), ErrAccessControlApplyNotCommitted.Error()) + } + if !errors.Is(error(e), ErrAccessControlApplyNotCommitted) { + t.Fatalf("expected errors.Is(..., ErrAccessControlApplyNotCommitted) to be true") + } + if errors.Unwrap(error(e)) != ErrAccessControlApplyNotCommitted { + t.Fatalf("expected Unwrap to return ErrAccessControlApplyNotCommitted") + } + + e2 := &AccessControlApplyNotCommittedError{} + if e2.Error() != ErrAccessControlApplyNotCommitted.Error() { + t.Fatalf("Error()=%q want %q", e2.Error(), ErrAccessControlApplyNotCommitted.Error()) + } +} + +func TestAccessControlRollbackHandle_Remaining(t *testing.T) { + var h *accessControlRollbackHandle + if got := h.remaining(time.Now()); got != 0 { + t.Fatalf("nil handle remaining=%s want 0", got) + } + + handle := &accessControlRollbackHandle{ + armedAt: time.Unix(100, 0), + timeout: 10 * time.Second, + } + if got := handle.remaining(time.Unix(105, 0)); got != 5*time.Second { + t.Fatalf("remaining=%s want %s", got, 5*time.Second) + } + if got := handle.remaining(time.Unix(999, 0)); got != 0 { + t.Fatalf("remaining=%s want 0", got) + } +} + +func TestBuildAccessControlApplyNotCommittedError_PopulatesFields(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + if e := buildAccessControlApplyNotCommittedError(nil); e == nil { + t.Fatalf("expected error struct, got nil") + } else if e.RollbackArmed || e.RollbackMarker != "" || e.RollbackLog != "" || !e.RollbackDeadline.IsZero() { + t.Fatalf("unexpected fields for nil handle: %#v", e) + } + + armedAt := time.Unix(10, 0) + handle := &accessControlRollbackHandle{ + markerPath: " /tmp/ac.marker \n", + logPath: " /tmp/ac.log\t", + armedAt: armedAt, + timeout: 3 * time.Second, + } + if err := fakeFS.AddFile("/tmp/ac.marker", []byte("pending\n")); err != nil { + t.Fatalf("add marker: %v", err) + } + + e := buildAccessControlApplyNotCommittedError(handle) + if e.RollbackMarker != "/tmp/ac.marker" { + t.Fatalf("RollbackMarker=%q", e.RollbackMarker) + } + if e.RollbackLog != "/tmp/ac.log" { + t.Fatalf("RollbackLog=%q", e.RollbackLog) + } + if !e.RollbackArmed { + t.Fatalf("expected RollbackArmed=true") + } + if !e.RollbackDeadline.Equal(armedAt.Add(3 * time.Second)) { + t.Fatalf("RollbackDeadline=%s want %s", e.RollbackDeadline, armedAt.Add(3*time.Second)) + } +} + +func TestStageHasPVEAccessControlConfig_DetectsFilesAndErrors(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + ok, err := stageHasPVEAccessControlConfig(" ") + if err != nil || ok { + t.Fatalf("expected ok=false err=nil for empty stageRoot, got ok=%v err=%v", ok, err) + } + + ok, err = stageHasPVEAccessControlConfig("/stage-empty") + if err != nil || ok { + t.Fatalf("expected ok=false err=nil for missing files, got ok=%v err=%v", ok, err) + } + + if err := fakeFS.AddDir("/stage-dir/etc/pve/user.cfg"); err != nil { + t.Fatalf("add staged dir: %v", err) + } + ok, err = stageHasPVEAccessControlConfig("/stage-dir") + if err != nil || ok { + t.Fatalf("expected ok=false err=nil when candidates are dirs, got ok=%v err=%v", ok, err) + } + + if err := fakeFS.AddFile("/stage-hit/etc/pve/user.cfg", []byte("x")); err != nil { + t.Fatalf("add staged file: %v", err) + } + ok, err = stageHasPVEAccessControlConfig("/stage-hit") + if err != nil || !ok { + t.Fatalf("expected ok=true err=nil when stage has access control files, got ok=%v err=%v", ok, err) + } + + fakeFS.StatErrors[filepath.Clean("/stage-err/etc/pve/user.cfg")] = fmt.Errorf("boom") + ok, err = stageHasPVEAccessControlConfig("/stage-err") + if err == nil || ok { + t.Fatalf("expected error, got ok=%v err=%v", ok, err) + } +} + +func TestBuildAccessControlRollbackScript_QuotesPaths(t *testing.T) { + script := buildAccessControlRollbackScript("/tmp/marker path", "/tmp/backup's.tar.gz", "/tmp/log path") + if !strings.Contains(script, "MARKER='/tmp/marker path'") { + t.Fatalf("expected MARKER to be quoted, got script:\n%s", script) + } + if !strings.Contains(script, "LOG='/tmp/log path'") { + t.Fatalf("expected LOG to be quoted, got script:\n%s", script) + } + if !strings.Contains(script, "BACKUP='/tmp/backup'\\''s.tar.gz'") { + t.Fatalf("expected BACKUP to escape single quotes, got script:\n%s", script) + } + if !strings.HasSuffix(script, "\n") { + t.Fatalf("expected script to end with newline") + } +} + +func TestArmAccessControlRollback_SystemdAndBackgroundPaths(t *testing.T) { + origFS := restoreFS + origCmd := restoreCmd + origTime := restoreTime + t.Cleanup(func() { + restoreFS = origFS + restoreCmd = origCmd + restoreTime = origTime + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + fakeTime := &FakeTime{Current: time.Date(2020, 1, 2, 3, 4, 5, 0, time.UTC)} + restoreTime = fakeTime + + logger := logging.New(logging.GetDefaultLogger().GetLevel(), false) + + t.Run("rejects invalid args", func(t *testing.T) { + if _, err := armAccessControlRollback(context.Background(), logger, "", 1*time.Second, "/tmp/proxsave"); err == nil { + t.Fatalf("expected error for empty backupPath") + } + if _, err := armAccessControlRollback(context.Background(), logger, "/backup.tgz", 0, "/tmp/proxsave"); err == nil { + t.Fatalf("expected error for invalid timeout") + } + }) + + t.Run("uses systemd-run when available", func(t *testing.T) { + binDir := t.TempDir() + oldPath := os.Getenv("PATH") + if err := os.Setenv("PATH", binDir); err != nil { + t.Fatalf("set PATH: %v", err) + } + t.Cleanup(func() { _ = os.Setenv("PATH", oldPath) }) + writeExecutable(t, binDir, "systemd-run") + + fakeCmd := &FakeCommandRunner{} + restoreCmd = fakeCmd + + handle, err := armAccessControlRollback(context.Background(), logger, "/backup.tgz", 2*time.Second, "/tmp/proxsave") + if err != nil { + t.Fatalf("armAccessControlRollback error: %v", err) + } + if handle == nil || handle.unitName == "" { + t.Fatalf("expected systemd unit name, got %#v", handle) + } + if got := fakeCmd.CallsList(); len(got) != 1 || !strings.HasPrefix(got[0], "systemd-run --unit=proxsave-access-control-rollback-20200102_030405") { + t.Fatalf("unexpected calls: %#v", got) + } + if data, err := fakeFS.ReadFile(handle.markerPath); err != nil || string(data) != "pending\n" { + t.Fatalf("marker read err=%v data=%q", err, string(data)) + } + }) + + t.Run("falls back to background timer on systemd-run failure", func(t *testing.T) { + binDir := t.TempDir() + oldPath := os.Getenv("PATH") + if err := os.Setenv("PATH", binDir); err != nil { + t.Fatalf("set PATH: %v", err) + } + t.Cleanup(func() { _ = os.Setenv("PATH", oldPath) }) + writeExecutable(t, binDir, "systemd-run") + + fakeCmd := &FakeCommandRunner{ + Errors: map[string]error{}, + } + restoreCmd = fakeCmd + + timestamp := fakeTime.Current.Format("20060102_150405") + scriptPath := filepath.Join("/tmp/proxsave", fmt.Sprintf("access_control_rollback_%s.sh", timestamp)) + systemdKey := "systemd-run --unit=proxsave-access-control-rollback-" + timestamp + " --on-active=2s /bin/sh " + scriptPath + fakeCmd.Errors[systemdKey] = fmt.Errorf("fail") + + handle, err := armAccessControlRollback(context.Background(), logger, "/backup.tgz", 2*time.Second, "/tmp/proxsave") + if err != nil { + t.Fatalf("armAccessControlRollback error: %v", err) + } + if handle == nil { + t.Fatalf("expected handle") + } + if handle.unitName != "" { + t.Fatalf("expected unitName cleared after systemd-run failure, got %q", handle.unitName) + } + + cmd := fmt.Sprintf("nohup sh -c 'sleep %d; /bin/sh %s' >/dev/null 2>&1 &", 2, scriptPath) + wantBackground := "sh -c " + cmd + calls := fakeCmd.CallsList() + if len(calls) != 2 || calls[1] != wantBackground { + t.Fatalf("unexpected calls: %#v", calls) + } + }) + + t.Run("background timer failure returns error", func(t *testing.T) { + emptyBin := t.TempDir() + oldPath := os.Getenv("PATH") + if err := os.Setenv("PATH", emptyBin); err != nil { + t.Fatalf("set PATH: %v", err) + } + t.Cleanup(func() { _ = os.Setenv("PATH", oldPath) }) + + fakeCmd := &FakeCommandRunner{ + Errors: map[string]error{}, + } + restoreCmd = fakeCmd + + timestamp := fakeTime.Current.Format("20060102_150405") + scriptPath := filepath.Join("/tmp/proxsave", fmt.Sprintf("access_control_rollback_%s.sh", timestamp)) + cmd := fmt.Sprintf("nohup sh -c 'sleep %d; /bin/sh %s' >/dev/null 2>&1 &", 1, scriptPath) + backgroundKey := "sh -c " + cmd + fakeCmd.Errors[backgroundKey] = fmt.Errorf("boom") + + if _, err := armAccessControlRollback(context.Background(), logger, "/backup.tgz", 1*time.Second, "/tmp/proxsave"); err == nil { + t.Fatalf("expected error") + } + }) +} + +func TestArmAccessControlRollback_DefaultWorkDirAndMinTimeout(t *testing.T) { + origFS := restoreFS + origCmd := restoreCmd + origTime := restoreTime + t.Cleanup(func() { + restoreFS = origFS + restoreCmd = origCmd + restoreTime = origTime + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + fakeTime := &FakeTime{Current: time.Date(2020, 1, 2, 3, 4, 5, 0, time.UTC)} + restoreTime = fakeTime + + emptyBin := t.TempDir() + oldPath := os.Getenv("PATH") + if err := os.Setenv("PATH", emptyBin); err != nil { + t.Fatalf("set PATH: %v", err) + } + t.Cleanup(func() { _ = os.Setenv("PATH", oldPath) }) + + fakeCmd := &FakeCommandRunner{} + restoreCmd = fakeCmd + + handle, err := armAccessControlRollback(context.Background(), newTestLogger(), "/backup.tgz", 500*time.Millisecond, " ") + if err != nil { + t.Fatalf("armAccessControlRollback error: %v", err) + } + if handle == nil || handle.workDir != "/tmp/proxsave" { + t.Fatalf("unexpected handle: %#v", handle) + } + if data, err := fakeFS.ReadFile(handle.markerPath); err != nil || string(data) != "pending\n" { + t.Fatalf("marker err=%v data=%q", err, string(data)) + } + calls := fakeCmd.CallsList() + if len(calls) != 1 { + t.Fatalf("unexpected calls: %#v", calls) + } + if !strings.Contains(calls[0], "sleep 1; /bin/sh") { + t.Fatalf("expected timeoutSeconds to clamp to 1, got call=%q", calls[0]) + } +} + +func TestArmAccessControlRollback_ReturnsErrorOnMkdirAllFailure(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + fakeFS.MkdirAllErr = fmt.Errorf("disk full") + restoreFS = fakeFS + + if _, err := armAccessControlRollback(context.Background(), newTestLogger(), "/backup.tgz", 1*time.Second, "/tmp/proxsave"); err == nil { + t.Fatalf("expected error") + } +} + +func TestArmAccessControlRollback_ReturnsErrorOnMarkerWriteFailure(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + fakeFS.WriteErr = fmt.Errorf("disk full") + restoreFS = fakeFS + + if _, err := armAccessControlRollback(context.Background(), newTestLogger(), "/backup.tgz", 1*time.Second, "/tmp/proxsave"); err == nil || !strings.Contains(err.Error(), "write rollback marker") { + t.Fatalf("expected write rollback marker error, got %v", err) + } +} + +func TestArmAccessControlRollback_ReturnsErrorOnScriptWriteFailure(t *testing.T) { + origFS := restoreFS + origTime := restoreTime + t.Cleanup(func() { + restoreFS = origFS + restoreTime = origTime + }) + + base := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(base.Root) }) + + fakeTime := &FakeTime{Current: time.Date(2020, 1, 2, 3, 4, 5, 0, time.UTC)} + restoreTime = fakeTime + timestamp := fakeTime.Current.Format("20060102_150405") + scriptPath := filepath.Join("/tmp/proxsave", fmt.Sprintf("access_control_rollback_%s.sh", timestamp)) + + restoreFS = writeFileFailFS{FS: base, failPath: scriptPath, err: fmt.Errorf("disk full")} + + if _, err := armAccessControlRollback(context.Background(), newTestLogger(), "/backup.tgz", 1*time.Second, "/tmp/proxsave"); err == nil || !strings.Contains(err.Error(), "write rollback script") { + t.Fatalf("expected write rollback script error, got %v", err) + } +} + +func TestDisarmAccessControlRollback_RemovesMarkerScriptAndStopsTimer(t *testing.T) { + origFS := restoreFS + origCmd := restoreCmd + t.Cleanup(func() { + restoreFS = origFS + restoreCmd = origCmd + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + binDir := t.TempDir() + oldPath := os.Getenv("PATH") + if err := os.Setenv("PATH", binDir); err != nil { + t.Fatalf("set PATH: %v", err) + } + t.Cleanup(func() { _ = os.Setenv("PATH", oldPath) }) + writeExecutable(t, binDir, "systemctl") + + fakeCmd := &FakeCommandRunner{} + restoreCmd = fakeCmd + + handle := &accessControlRollbackHandle{ + markerPath: "/tmp/proxsave/ac.marker", + scriptPath: "/tmp/proxsave/ac.sh", + unitName: "proxsave-access-control-rollback-test", + } + if err := fakeFS.AddFile(handle.markerPath, []byte("pending\n")); err != nil { + t.Fatalf("add marker: %v", err) + } + if err := fakeFS.AddFile(handle.scriptPath, []byte("#!/bin/sh\n")); err != nil { + t.Fatalf("add script: %v", err) + } + + disarmAccessControlRollback(context.Background(), newTestLogger(), handle) + + if _, err := fakeFS.Stat(handle.markerPath); err == nil || !os.IsNotExist(err) { + t.Fatalf("expected marker removed; stat err=%v", err) + } + if _, err := fakeFS.Stat(handle.scriptPath); err == nil || !os.IsNotExist(err) { + t.Fatalf("expected script removed; stat err=%v", err) + } + + timerUnit := handle.unitName + ".timer" + want1 := "systemctl stop " + timerUnit + want2 := "systemctl reset-failed " + handle.unitName + ".service " + timerUnit + calls := fakeCmd.CallsList() + if len(calls) != 2 || calls[0] != want1 || calls[1] != want2 { + t.Fatalf("unexpected calls: %#v", calls) + } +} + +func TestMaybeApplyPVEAccessControlFromClusterBackupWithUI_CoversUserFlows(t *testing.T) { + origFS := restoreFS + origCmd := restoreCmd + origTime := restoreTime + origGeteuid := accessControlApplyGeteuid + origMounted := accessControlIsMounted + origRealFS := accessControlIsRealRestoreFS + origArm := accessControlArmRollback + origDisarm := accessControlDisarmRollback + origApply := accessControlApplyFromStage + t.Cleanup(func() { + restoreFS = origFS + restoreCmd = origCmd + restoreTime = origTime + accessControlApplyGeteuid = origGeteuid + accessControlIsMounted = origMounted + accessControlIsRealRestoreFS = origRealFS + accessControlArmRollback = origArm + accessControlDisarmRollback = origDisarm + accessControlApplyFromStage = origApply + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + restoreTime = &FakeTime{Current: time.Unix(100, 0)} + restoreCmd = &FakeCommandRunner{} + + accessControlIsRealRestoreFS = func(fs FS) bool { return true } + accessControlApplyGeteuid = func() int { return 0 } + accessControlIsMounted = func(path string) (bool, error) { return true, nil } + + plan := &RestorePlan{ + SystemType: SystemTypePVE, + ClusterBackup: true, + NormalCategories: []Category{{ID: "pve_access_control"}}, + } + stageWithAC := "/stage-ac" + if err := fakeFS.AddFile(stageWithAC+"/etc/pve/user.cfg", []byte("x")); err != nil { + t.Fatalf("add staged user.cfg: %v", err) + } + stageWithoutAC := "/stage-empty" + logger := newTestLogger() + + t.Run("errors when ui missing", func(t *testing.T) { + err := maybeApplyPVEAccessControlFromClusterBackupWithUI(context.Background(), nil, logger, plan, nil, nil, stageWithAC, false) + if err == nil { + t.Fatalf("expected error") + } + }) + + t.Run("skips when /etc/pve not mounted", func(t *testing.T) { + accessControlIsMounted = func(path string) (bool, error) { return false, nil } + t.Cleanup(func() { accessControlIsMounted = func(path string) (bool, error) { return true, nil } }) + + ui := &scriptedRestoreWorkflowUI{fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, script: nil} + err := maybeApplyPVEAccessControlFromClusterBackupWithUI(context.Background(), ui, logger, plan, nil, nil, stageWithAC, false) + if err != nil { + t.Fatalf("expected nil, got %v", err) + } + }) + + t.Run("skips when stage has no access control files", func(t *testing.T) { + ui := &scriptedRestoreWorkflowUI{fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, script: nil} + err := maybeApplyPVEAccessControlFromClusterBackupWithUI(context.Background(), ui, logger, plan, nil, nil, stageWithoutAC, false) + if err != nil { + t.Fatalf("expected nil, got %v", err) + } + }) + + t.Run("user skips apply", func(t *testing.T) { + ui := &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{{ok: false}}, + } + err := maybeApplyPVEAccessControlFromClusterBackupWithUI(context.Background(), ui, logger, plan, nil, nil, stageWithAC, false) + if err != nil { + t.Fatalf("expected nil, got %v", err) + } + }) + + t.Run("missing rollback backup declines full rollback", func(t *testing.T) { + accessControlArmRollback = func(ctx context.Context, logger *logging.Logger, backupPath string, timeout time.Duration, workDir string) (*accessControlRollbackHandle, error) { + t.Fatalf("unexpected rollback arm") + return nil, nil + } + + ui := &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{ + {ok: true}, // Apply now + {ok: false}, // Skip full rollback + }, + } + safety := &SafetyBackupResult{BackupPath: "/backups/full.tgz"} + err := maybeApplyPVEAccessControlFromClusterBackupWithUI(context.Background(), ui, logger, plan, safety, nil, stageWithAC, false) + if err != nil { + t.Fatalf("expected nil, got %v", err) + } + }) + + t.Run("proceed without rollback applies and returns without commit prompt", func(t *testing.T) { + accessControlArmRollback = func(ctx context.Context, logger *logging.Logger, backupPath string, timeout time.Duration, workDir string) (*accessControlRollbackHandle, error) { + t.Fatalf("unexpected rollback arm") + return nil, nil + } + called := false + accessControlApplyFromStage = func(ctx context.Context, logger *logging.Logger, stageRoot string) error { + called = true + return nil + } + + ui := &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{ + {ok: true}, // Apply now + {ok: true}, // Proceed without rollback + }, + } + err := maybeApplyPVEAccessControlFromClusterBackupWithUI(context.Background(), ui, logger, plan, nil, nil, stageWithAC, false) + if err != nil { + t.Fatalf("expected nil, got %v", err) + } + if !called { + t.Fatalf("expected access control apply to be invoked") + } + }) + + t.Run("commit keeps changes and disarms rollback", func(t *testing.T) { + markerPath := "/tmp/proxsave/ac.marker" + scriptPath := "/tmp/proxsave/ac.sh" + accessControlArmRollback = func(ctx context.Context, logger *logging.Logger, backupPath string, timeout time.Duration, workDir string) (*accessControlRollbackHandle, error) { + handle := &accessControlRollbackHandle{ + markerPath: markerPath, + scriptPath: scriptPath, + logPath: "/tmp/proxsave/ac.log", + armedAt: nowRestore(), + timeout: timeout, + } + _ = restoreFS.WriteFile(markerPath, []byte("pending\n"), 0o640) + _ = restoreFS.WriteFile(scriptPath, []byte("#!/bin/sh\n"), 0o640) + return handle, nil + } + accessControlApplyFromStage = func(ctx context.Context, logger *logging.Logger, stageRoot string) error { return nil } + + ui := &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{ + {ok: true}, // Apply now + {ok: true}, // Commit + }, + } + rollback := &SafetyBackupResult{BackupPath: "/backups/access-control.tgz"} + err := maybeApplyPVEAccessControlFromClusterBackupWithUI(context.Background(), ui, logger, plan, nil, rollback, stageWithAC, false) + if err != nil { + t.Fatalf("expected nil, got %v", err) + } + if _, err := fakeFS.Stat(markerPath); err == nil || !os.IsNotExist(err) { + t.Fatalf("expected rollback marker removed; stat err=%v", err) + } + if _, err := fakeFS.Stat(scriptPath); err == nil || !os.IsNotExist(err) { + t.Fatalf("expected rollback script removed; stat err=%v", err) + } + }) + + t.Run("rollback requested returns typed error with marker armed", func(t *testing.T) { + markerPath := "/tmp/proxsave/ac.marker" + armedAt := nowRestore() + accessControlArmRollback = func(ctx context.Context, logger *logging.Logger, backupPath string, timeout time.Duration, workDir string) (*accessControlRollbackHandle, error) { + handle := &accessControlRollbackHandle{ + markerPath: markerPath, + scriptPath: "/tmp/proxsave/ac.sh", + logPath: "/tmp/proxsave/ac.log", + armedAt: armedAt, + timeout: timeout, + } + _ = restoreFS.WriteFile(markerPath, []byte("pending\n"), 0o640) + return handle, nil + } + accessControlApplyFromStage = func(ctx context.Context, logger *logging.Logger, stageRoot string) error { return nil } + + ui := &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{ + {ok: true}, // Apply now + {ok: false}, // Rollback + }, + } + rollback := &SafetyBackupResult{BackupPath: "/backups/access-control.tgz"} + err := maybeApplyPVEAccessControlFromClusterBackupWithUI(context.Background(), ui, logger, plan, nil, rollback, stageWithAC, false) + if err == nil { + t.Fatalf("expected error") + } + if !errors.Is(err, ErrAccessControlApplyNotCommitted) { + t.Fatalf("expected ErrAccessControlApplyNotCommitted, got %v", err) + } + var typed *AccessControlApplyNotCommittedError + if !errors.As(err, &typed) || typed == nil { + t.Fatalf("expected AccessControlApplyNotCommittedError, got %T", err) + } + if !typed.RollbackArmed || typed.RollbackMarker != markerPath || typed.RollbackLog != "/tmp/proxsave/ac.log" { + t.Fatalf("unexpected error fields: %#v", typed) + } + if typed.RollbackDeadline.IsZero() || !typed.RollbackDeadline.Equal(armedAt.Add(defaultAccessControlRollbackTimeout)) { + t.Fatalf("unexpected RollbackDeadline=%s", typed.RollbackDeadline) + } + if _, err := fakeFS.Stat(markerPath); err != nil { + t.Fatalf("expected marker to still exist, stat err=%v", err) + } + }) + + t.Run("commit prompt abort returns abort error", func(t *testing.T) { + markerPath := "/tmp/proxsave/ac.marker" + accessControlArmRollback = func(ctx context.Context, logger *logging.Logger, backupPath string, timeout time.Duration, workDir string) (*accessControlRollbackHandle, error) { + handle := &accessControlRollbackHandle{ + markerPath: markerPath, + logPath: "/tmp/proxsave/ac.log", + armedAt: nowRestore(), + timeout: timeout, + } + _ = restoreFS.WriteFile(markerPath, []byte("pending\n"), 0o640) + return handle, nil + } + accessControlApplyFromStage = func(ctx context.Context, logger *logging.Logger, stageRoot string) error { return nil } + + ui := &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{ + {ok: true}, // Apply now + {err: input.ErrInputAborted}, // Abort at commit prompt + }, + } + rollback := &SafetyBackupResult{BackupPath: "/backups/access-control.tgz"} + err := maybeApplyPVEAccessControlFromClusterBackupWithUI(context.Background(), ui, logger, plan, nil, rollback, stageWithAC, false) + if err == nil || !errors.Is(err, input.ErrInputAborted) { + t.Fatalf("expected abort error, got %v", err) + } + }) + + t.Run("commit prompt failure returns typed error", func(t *testing.T) { + markerPath := "/tmp/proxsave/ac.marker" + accessControlArmRollback = func(ctx context.Context, logger *logging.Logger, backupPath string, timeout time.Duration, workDir string) (*accessControlRollbackHandle, error) { + handle := &accessControlRollbackHandle{ + markerPath: markerPath, + logPath: "/tmp/proxsave/ac.log", + armedAt: nowRestore(), + timeout: timeout, + } + _ = restoreFS.WriteFile(markerPath, []byte("pending\n"), 0o640) + return handle, nil + } + accessControlApplyFromStage = func(ctx context.Context, logger *logging.Logger, stageRoot string) error { return nil } + + ui := &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{ + {ok: true}, // Apply now + {err: fmt.Errorf("boom")}, // Commit prompt fails + }, + } + rollback := &SafetyBackupResult{BackupPath: "/backups/access-control.tgz"} + err := maybeApplyPVEAccessControlFromClusterBackupWithUI(context.Background(), ui, logger, plan, nil, rollback, stageWithAC, false) + if err == nil { + t.Fatalf("expected error") + } + if !errors.Is(err, ErrAccessControlApplyNotCommitted) { + t.Fatalf("expected ErrAccessControlApplyNotCommitted, got %v", err) + } + if _, err := fakeFS.Stat(markerPath); err != nil { + t.Fatalf("expected marker to still exist, stat err=%v", err) + } + }) + + t.Run("remaining timeout returns typed error without commit prompt", func(t *testing.T) { + markerPath := "/tmp/proxsave/ac.marker" + accessControlArmRollback = func(ctx context.Context, logger *logging.Logger, backupPath string, timeout time.Duration, workDir string) (*accessControlRollbackHandle, error) { + handle := &accessControlRollbackHandle{ + markerPath: markerPath, + logPath: "/tmp/proxsave/ac.log", + armedAt: nowRestore().Add(-timeout - time.Second), + timeout: timeout, + } + _ = restoreFS.WriteFile(markerPath, []byte("pending\n"), 0o640) + return handle, nil + } + accessControlApplyFromStage = func(ctx context.Context, logger *logging.Logger, stageRoot string) error { return nil } + + ui := &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{{ok: true}}, // Apply now only + } + rollback := &SafetyBackupResult{BackupPath: "/backups/access-control.tgz"} + err := maybeApplyPVEAccessControlFromClusterBackupWithUI(context.Background(), ui, logger, plan, nil, rollback, stageWithAC, false) + if err == nil || !errors.Is(err, ErrAccessControlApplyNotCommitted) { + t.Fatalf("expected ErrAccessControlApplyNotCommitted, got %v", err) + } + }) + + t.Run("dry run skips apply", func(t *testing.T) { + ui := &scriptedRestoreWorkflowUI{fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, script: nil} + err := maybeApplyPVEAccessControlFromClusterBackupWithUI(context.Background(), ui, logger, plan, nil, nil, stageWithAC, true) + if err != nil { + t.Fatalf("expected nil, got %v", err) + } + }) + + t.Run("non-root skips apply", func(t *testing.T) { + accessControlApplyGeteuid = func() int { return 1000 } + t.Cleanup(func() { accessControlApplyGeteuid = func() int { return 0 } }) + + ui := &scriptedRestoreWorkflowUI{fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, script: nil} + err := maybeApplyPVEAccessControlFromClusterBackupWithUI(context.Background(), ui, logger, plan, nil, nil, stageWithAC, false) + if err != nil { + t.Fatalf("expected nil, got %v", err) + } + }) +} + +func TestMaybeApplyAccessControlWithUI_BranchCoverage(t *testing.T) { + origFS := restoreFS + origCmd := restoreCmd + origTime := restoreTime + origGeteuid := accessControlApplyGeteuid + origMounted := accessControlIsMounted + origRealFS := accessControlIsRealRestoreFS + origArm := accessControlArmRollback + origApply := accessControlApplyFromStage + t.Cleanup(func() { + restoreFS = origFS + restoreCmd = origCmd + restoreTime = origTime + accessControlApplyGeteuid = origGeteuid + accessControlIsMounted = origMounted + accessControlIsRealRestoreFS = origRealFS + accessControlArmRollback = origArm + accessControlApplyFromStage = origApply + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + restoreTime = &FakeTime{Current: time.Unix(100, 0)} + restoreCmd = &FakeCommandRunner{} + + accessControlIsRealRestoreFS = func(fs FS) bool { return true } + accessControlApplyGeteuid = func() int { return 0 } + accessControlIsMounted = func(path string) (bool, error) { return true, nil } + + stageRoot := "/stage" + if err := fakeFS.AddFile(stageRoot+"/etc/pve/user.cfg", []byte("x")); err != nil { + t.Fatalf("add staged user.cfg: %v", err) + } + logger := newTestLogger() + + t.Run("nil plan returns nil", func(t *testing.T) { + if err := maybeApplyAccessControlWithUI(context.Background(), &fakeRestoreWorkflowUI{}, logger, nil, nil, nil, stageRoot, false); err != nil { + t.Fatalf("expected nil, got %v", err) + } + }) + + t.Run("empty stageRoot skips", func(t *testing.T) { + plan := &RestorePlan{NormalCategories: []Category{{ID: "pve_access_control"}}} + if err := maybeApplyAccessControlWithUI(context.Background(), &fakeRestoreWorkflowUI{}, logger, plan, nil, nil, " ", false); err != nil { + t.Fatalf("expected nil, got %v", err) + } + }) + + t.Run("no relevant categories returns nil", func(t *testing.T) { + plan := &RestorePlan{NormalCategories: []Category{{ID: "pve_firewall"}}} + if err := maybeApplyAccessControlWithUI(context.Background(), &fakeRestoreWorkflowUI{}, logger, plan, nil, nil, stageRoot, false); err != nil { + t.Fatalf("expected nil, got %v", err) + } + }) + + t.Run("errors when ui missing", func(t *testing.T) { + plan := &RestorePlan{ + SystemType: SystemTypePVE, + ClusterBackup: true, + NormalCategories: []Category{{ID: "pve_access_control"}}, + } + if err := maybeApplyAccessControlWithUI(context.Background(), nil, logger, plan, nil, nil, stageRoot, false); err == nil { + t.Fatalf("expected error") + } + }) + + t.Run("cluster backup path is used", func(t *testing.T) { + markerPath := "/tmp/proxsave/ac.marker" + accessControlArmRollback = func(ctx context.Context, logger *logging.Logger, backupPath string, timeout time.Duration, workDir string) (*accessControlRollbackHandle, error) { + handle := &accessControlRollbackHandle{ + markerPath: markerPath, + logPath: "/tmp/proxsave/ac.log", + armedAt: nowRestore(), + timeout: timeout, + } + _ = restoreFS.WriteFile(markerPath, []byte("pending\n"), 0o640) + return handle, nil + } + accessControlApplyFromStage = func(ctx context.Context, logger *logging.Logger, stageRoot string) error { return nil } + + plan := &RestorePlan{ + SystemType: SystemTypePVE, + ClusterBackup: true, + NormalCategories: []Category{{ID: "pve_access_control"}}, + } + ui := &scriptedRestoreWorkflowUI{ + fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, + script: []scriptedConfirmAction{ + {ok: true}, // Apply now + {ok: false}, // Rollback + }, + } + rollback := &SafetyBackupResult{BackupPath: "/backups/access-control.tgz"} + err := maybeApplyAccessControlWithUI(context.Background(), ui, logger, plan, nil, rollback, stageRoot, false) + if err == nil || !errors.Is(err, ErrAccessControlApplyNotCommitted) { + t.Fatalf("expected ErrAccessControlApplyNotCommitted, got %v", err) + } + }) + + t.Run("default path uses stage apply", func(t *testing.T) { + plan := &RestorePlan{ + SystemType: SystemTypePBS, + NormalCategories: []Category{{ID: "pbs_access_control"}}, + } + if err := maybeApplyAccessControlWithUI(context.Background(), &fakeRestoreWorkflowUI{}, logger, plan, nil, nil, stageRoot, false); err != nil { + t.Fatalf("expected nil, got %v", err) + } + }) +} + From e4ae7ed9f34d9c378143ba3fa4f296548a1d594a Mon Sep 17 00:00:00 2001 From: Damiano <71268257+tis24dev@users.noreply.github.com> Date: Tue, 24 Feb 2026 14:07:31 +0100 Subject: [PATCH 07/12] Make sys/class/net path configurable Introduce sysClassNetPath variable and use it in collectCurrentNetworkInventory instead of a hard-coded "/sys/class/net" path, allowing tests to override the sysfs location. Also add a comprehensive test file (internal/orchestrator/nic_mapping_additional_test.go) covering archive reading, inventory loading, udev/permanent MAC parsing, NIC mapping computation, planning and applying NIC name repairs, and many edge/error cases using fake FS and command runners. --- internal/orchestrator/nic_mapping.go | 3 +- .../nic_mapping_additional_test.go | 1262 +++++++++++++++++ 2 files changed, 1264 insertions(+), 1 deletion(-) create mode 100644 internal/orchestrator/nic_mapping_additional_test.go diff --git a/internal/orchestrator/nic_mapping.go b/internal/orchestrator/nic_mapping.go index f4fa1d1..f2a0372c 100644 --- a/internal/orchestrator/nic_mapping.go +++ b/internal/orchestrator/nic_mapping.go @@ -20,6 +20,7 @@ import ( const maxArchiveInventoryBytes = 10 << 20 // 10 MiB var nicRepairSequence uint64 +var sysClassNetPath = "/sys/class/net" type archivedNetworkInventory struct { GeneratedAt string `json:"generated_at,omitempty"` @@ -371,7 +372,7 @@ func readArchiveEntry(ctx context.Context, archivePath string, candidates []stri } func collectCurrentNetworkInventory(ctx context.Context) (*archivedNetworkInventory, error) { - sysNet := "/sys/class/net" + sysNet := sysClassNetPath entries, err := os.ReadDir(sysNet) if err != nil { return nil, err diff --git a/internal/orchestrator/nic_mapping_additional_test.go b/internal/orchestrator/nic_mapping_additional_test.go new file mode 100644 index 0000000..c4cb681 --- /dev/null +++ b/internal/orchestrator/nic_mapping_additional_test.go @@ -0,0 +1,1262 @@ +package orchestrator + +import ( + "archive/tar" + "bytes" + "context" + "encoding/json" + "errors" + "os" + "path/filepath" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/tis24dev/proxsave/internal/logging" + "github.com/tis24dev/proxsave/internal/types" +) + +type tarEntry struct { + Name string + Typeflag byte + Mode int64 + Data []byte +} + +type mkdirAllFailFS struct { + FS + failPath string + err error +} + +func (f mkdirAllFailFS) MkdirAll(path string, perm os.FileMode) error { + if filepath.Clean(path) == filepath.Clean(f.failPath) { + return f.err + } + return f.FS.MkdirAll(path, perm) +} + +func writeTarToFakeFS(t *testing.T, fs *FakeFS, archivePath string, entries []tarEntry) { + t.Helper() + + var buf bytes.Buffer + tw := tar.NewWriter(&buf) + for _, entry := range entries { + hdr := &tar.Header{ + Name: entry.Name, + Typeflag: entry.Typeflag, + Mode: entry.Mode, + Size: int64(len(entry.Data)), + } + if entry.Typeflag == tar.TypeDir { + hdr.Size = 0 + } + if err := tw.WriteHeader(hdr); err != nil { + t.Fatalf("WriteHeader %s: %v", entry.Name, err) + } + if hdr.Size > 0 { + if _, err := tw.Write(entry.Data); err != nil { + t.Fatalf("Write %s: %v", entry.Name, err) + } + } + } + if err := tw.Close(); err != nil { + t.Fatalf("tar close: %v", err) + } + if err := fs.WriteFile(archivePath, buf.Bytes(), 0o644); err != nil { + t.Fatalf("write archive: %v", err) + } +} + +func TestNICMappingResult_RenameMapAndDetails(t *testing.T) { + r := nicMappingResult{ + Entries: []nicMappingEntry{ + {OldName: "eno1", NewName: "enp3s0"}, + {OldName: "", NewName: "enp4s0"}, + {OldName: "ens2", NewName: ""}, + {OldName: "eno1", NewName: "enp3s1"}, + }, + } + m := r.RenameMap() + if len(m) != 1 || m["eno1"] != "enp3s1" { + t.Fatalf("RenameMap=%v; want {eno1:enp3s1}", m) + } + + if got := (nicMappingResult{}).Details(); got != "NIC mapping: none" { + t.Fatalf("empty Details=%q", got) + } + + details := nicMappingResult{ + Entries: []nicMappingEntry{ + {OldName: "b", NewName: "B", Method: nicMatchMAC, Identifier: "m2"}, + {OldName: "a", NewName: "A", Method: nicMatchPermanentMAC, Identifier: "m1"}, + }, + }.Details() + want := strings.Join([]string{ + "NIC mapping (backup -> current):", + "- a -> A (permanent_mac=m1)", + "- b -> B (mac=m2)", + }, "\n") + if details != want { + t.Fatalf("Details=%q; want %q", details, want) + } +} + +func TestNICNameConflict_Details(t *testing.T) { + c := nicNameConflict{ + Mapping: nicMappingEntry{ + OldName: "eno1", + NewName: "eth0", + Method: nicMatchMAC, + Identifier: "aa:bb", + }, + Existing: archivedNetworkInterface{ + Name: "eno1", + PermanentMAC: "AA:BB:CC:DD:EE:FF", + MAC: "mac:11:22:33:44:55:66 ", + PCIPath: "/pci/0000:00:1f.6", + }, + } + got := c.Details() + if !strings.Contains(got, "permMAC=aa:bb:cc:dd:ee:ff") || !strings.Contains(got, "mac=11:22:33:44:55:66") || !strings.Contains(got, "pci=/pci/0000:00:1f.6") { + t.Fatalf("Details=%q; want identifiers included", got) + } + if !strings.Contains(got, "but current eno1 exists") { + t.Fatalf("Details=%q; want conflict message", got) + } + + none := nicNameConflict{ + Mapping: nicMappingEntry{OldName: "eno1", NewName: "eth0", Method: nicMatchPCIPath, Identifier: "pci0"}, + Existing: archivedNetworkInterface{ + Name: "eno1", + }, + }.Details() + if !strings.Contains(none, "no identifiers") { + t.Fatalf("Details=%q; want no identifiers", none) + } +} + +func TestNICRepairPlan_HasWork(t *testing.T) { + if (nicRepairPlan{}).HasWork() { + t.Fatalf("expected HasWork=false") + } + if !(nicRepairPlan{SafeMappings: []nicMappingEntry{{OldName: "a", NewName: "b"}}}.HasWork()) { + t.Fatalf("expected HasWork=true with safe mappings") + } + if !(nicRepairPlan{Conflicts: []nicNameConflict{{Mapping: nicMappingEntry{OldName: "a", NewName: "b"}}}}.HasWork()) { + t.Fatalf("expected HasWork=true with conflicts") + } +} + +func TestNICRepairResult_SummaryAndDetails(t *testing.T) { + r := nicRepairResult{SkippedReason: "test"} + if got := r.Summary(); got != "NIC name repair skipped: test" { + t.Fatalf("Summary=%q", got) + } + if r.Applied() { + t.Fatalf("Applied=true; want false") + } + + r = nicRepairResult{} + if got := r.Summary(); got != "NIC name repair: no changes needed" { + t.Fatalf("Summary=%q", got) + } + + r = nicRepairResult{ + ChangedFiles: []string{"/etc/network/interfaces"}, + BackupDir: "/tmp/proxsave/nic_repair_test", + AppliedNICMap: []nicMappingEntry{ + {OldName: "eno1", NewName: "eth0", Method: nicMatchMAC, Identifier: "aa"}, + }, + } + if got := r.Summary(); got != "NIC name repair applied: 1 file(s) updated" { + t.Fatalf("Summary=%q", got) + } + if !r.Applied() { + t.Fatalf("Applied=false; want true") + } + details := r.Details() + if !strings.Contains(details, "Backup of pre-repair files: /tmp/proxsave/nic_repair_test") || !strings.Contains(details, "Updated files:\n- /etc/network/interfaces") { + t.Fatalf("Details=%q; want backup and updated files", details) + } + if !strings.Contains(details, "NIC mapping (backup -> current):") || !strings.Contains(details, "- eno1 -> eth0 (mac=aa)") { + t.Fatalf("Details=%q; want mapping details", details) + } +} + +func TestReadArchiveEntry_ErrorsAndNotFound(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + archivePath := "/backup.tar" + writeTarToFakeFS(t, fakeFS, archivePath, []tarEntry{ + {Name: "./other.txt", Typeflag: tar.TypeReg, Mode: 0o644, Data: []byte("ok")}, + }) + + _, _, err := readArchiveEntry(context.Background(), archivePath, []string{"./missing.txt"}, 16) + if !errors.Is(err, os.ErrNotExist) { + t.Fatalf("err=%v; want os.ErrNotExist", err) + } + + _, _, err = readArchiveEntry(context.Background(), "/does-not-exist.tar", []string{"./x"}, 16) + if !errors.Is(err, os.ErrNotExist) { + t.Fatalf("err=%v; want open os.ErrNotExist", err) + } + + cctx, cancel := context.WithCancel(context.Background()) + cancel() + _, _, err = readArchiveEntry(cctx, archivePath, []string{"./other.txt"}, 16) + if !errors.Is(err, context.Canceled) { + t.Fatalf("err=%v; want context.Canceled", err) + } + + if err := fakeFS.WriteFile("/backup.zip", []byte("not a tar"), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + _, _, err = readArchiveEntry(context.Background(), "/backup.zip", []string{"./other.txt"}, 16) + if err == nil || !strings.Contains(err.Error(), "unsupported archive format") { + t.Fatalf("err=%v; want unsupported archive format", err) + } + + writeTarToFakeFS(t, fakeFS, "/nonregular.tar", []tarEntry{ + {Name: "./entry", Typeflag: tar.TypeDir, Mode: 0o755}, + }) + _, _, err = readArchiveEntry(context.Background(), "/nonregular.tar", []string{"./entry"}, 16) + if err == nil || !strings.Contains(err.Error(), "not a regular file") { + t.Fatalf("err=%v; want not a regular file", err) + } + + writeTarToFakeFS(t, fakeFS, "/toolarge.tar", []tarEntry{ + {Name: "./big", Typeflag: tar.TypeReg, Mode: 0o644, Data: []byte("0123456789")}, + }) + _, _, err = readArchiveEntry(context.Background(), "/toolarge.tar", []string{"./big"}, 4) + if err == nil || !strings.Contains(err.Error(), "too large") { + t.Fatalf("err=%v; want too large", err) + } +} + +func TestLoadBackupNetworkInventoryFromArchive_SuccessAndBadJSON(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + type invJSON struct { + Interfaces []archivedNetworkInterface `json:"interfaces"` + } + payload, err := json.Marshal(invJSON{ + Interfaces: []archivedNetworkInterface{{Name: "eth0", MAC: "aa:bb"}}, + }) + if err != nil { + t.Fatalf("marshal: %v", err) + } + writeTarToFakeFS(t, fakeFS, "/inv.tar", []tarEntry{ + {Name: "./commands/network_inventory.json", Typeflag: tar.TypeReg, Mode: 0o644, Data: payload}, + }) + + inv, used, err := loadBackupNetworkInventoryFromArchive(context.Background(), "/inv.tar") + if err != nil { + t.Fatalf("load: %v", err) + } + if used != "./commands/network_inventory.json" { + t.Fatalf("used=%q", used) + } + if inv == nil || len(inv.Interfaces) != 1 || inv.Interfaces[0].Name != "eth0" { + t.Fatalf("inv=%+v", inv) + } + + writeTarToFakeFS(t, fakeFS, "/bad.tar", []tarEntry{ + {Name: "./commands/network_inventory.json", Typeflag: tar.TypeReg, Mode: 0o644, Data: []byte("{")}, + }) + _, _, err = loadBackupNetworkInventoryFromArchive(context.Background(), "/bad.tar") + if err == nil || !strings.Contains(err.Error(), "parse network inventory json") { + t.Fatalf("err=%v; want parse network inventory json", err) + } +} + +func TestParseAndReadPermanentMAC(t *testing.T) { + output := "some header\nPermanent Address: AA:BB:CC:DD:EE:FF \n" + if got := parsePermanentMAC(output); got != "aa:bb:cc:dd:ee:ff" { + t.Fatalf("parsePermanentMAC=%q", got) + } + if got := parsePermanentMAC("nope"); got != "" { + t.Fatalf("parsePermanentMAC=%q; want empty", got) + } + + origCmd := restoreCmd + t.Cleanup(func() { restoreCmd = origCmd }) + cmd := &FakeCommandRunner{ + Outputs: map[string][]byte{ + "ethtool -P eth0": []byte(output), + }, + } + restoreCmd = cmd + + got, err := readPermanentMAC(context.Background(), "eth0") + if err != nil { + t.Fatalf("readPermanentMAC: %v", err) + } + if got != "aa:bb:cc:dd:ee:ff" { + t.Fatalf("readPermanentMAC=%q", got) + } + + cmd.Errors = map[string]error{"ethtool -P eth0": errors.New("boom")} + _, err = readPermanentMAC(context.Background(), "eth0") + if err == nil { + t.Fatalf("expected error") + } +} + +func TestReadUdevProperties(t *testing.T) { + origCmd := restoreCmd + t.Cleanup(func() { restoreCmd = origCmd }) + + cmd := &FakeCommandRunner{ + Outputs: map[string][]byte{ + "udevadm info -q property -p /sys/class/net/eth0": []byte(strings.Join([]string{ + "ID_SERIAL=abc", + "ID_PATH= pci-0000:00:1f.6 ", + "BADLINE", + "FOO=", + "=bar", + "", + }, "\n")), + }, + } + restoreCmd = cmd + + props, err := readUdevProperties(context.Background(), "/sys/class/net/eth0") + if err != nil { + t.Fatalf("readUdevProperties: %v", err) + } + if props["ID_SERIAL"] != "abc" || props["ID_PATH"] != "pci-0000:00:1f.6" { + t.Fatalf("props=%v", props) + } + if _, ok := props["FOO"]; ok { + t.Fatalf("expected empty-value key to be skipped: %v", props) + } + + cmd.Errors = map[string]error{"udevadm info -q property -p /sys/class/net/eth0": errors.New("boom")} + _, err = readUdevProperties(context.Background(), "/sys/class/net/eth0") + if err == nil { + t.Fatalf("expected error") + } +} + +func TestReadTrimmedLine(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "line") + if err := os.WriteFile(path, []byte(" HELLO WORLD \n"), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + + if got := readTrimmedLine(path, 0); got != "HELLO WORLD" { + t.Fatalf("readTrimmedLine=%q", got) + } + if got := readTrimmedLine(path, 5); got != "HELLO" { + t.Fatalf("readTrimmedLine=%q", got) + } + if got := readTrimmedLine(filepath.Join(dir, "missing"), 5); got != "" { + t.Fatalf("readTrimmedLine=%q; want empty", got) + } +} + +func TestCollectCurrentNetworkInventory_WithFakeSysfs(t *testing.T) { + origSys := sysClassNetPath + t.Cleanup(func() { sysClassNetPath = origSys }) + t.Setenv("PATH", "") + + root := t.TempDir() + netDir := filepath.Join(root, "sys/class/net") + if err := os.MkdirAll(netDir, 0o755); err != nil { + t.Fatalf("mkdir netDir: %v", err) + } + sysClassNetPath = netDir + if err := os.MkdirAll(filepath.Join(netDir, " "), 0o755); err != nil { + t.Fatalf("mkdir blank-name entry: %v", err) + } + + writePhysical := func(name, mac, pciDevice, driver string) { + t.Helper() + ifaceDir := filepath.Join(netDir, name) + if err := os.MkdirAll(ifaceDir, 0o755); err != nil { + t.Fatalf("mkdir ifaceDir: %v", err) + } + if err := os.WriteFile(filepath.Join(ifaceDir, "address"), []byte(mac+"\n"), 0o644); err != nil { + t.Fatalf("write address: %v", err) + } + + devDir := filepath.Join(root, "devices", pciDevice) + driverDir := filepath.Join(root, "drivers", driver) + if err := os.MkdirAll(devDir, 0o755); err != nil { + t.Fatalf("mkdir devDir: %v", err) + } + if err := os.MkdirAll(driverDir, 0o755); err != nil { + t.Fatalf("mkdir driverDir: %v", err) + } + if err := os.Symlink(devDir, filepath.Join(ifaceDir, "device")); err != nil { + t.Fatalf("symlink device: %v", err) + } + if err := os.Symlink(driverDir, filepath.Join(devDir, "driver")); err != nil { + t.Fatalf("symlink driver: %v", err) + } + } + writePhysical("eth0", "MAC:AA:BB:CC:DD:EE:01", "pci0000:00/0000:00:1f.6", "e1000") + writePhysical("eno1", "aa:bb:cc:dd:ee:02", "pci0000:00/0000:00:1c.0", "igb") + + ifaceLo := filepath.Join(netDir, "lo") + if err := os.MkdirAll(ifaceLo, 0o755); err != nil { + t.Fatalf("mkdir lo: %v", err) + } + if err := os.WriteFile(filepath.Join(ifaceLo, "address"), []byte("00:00:00:00:00:00\n"), 0o644); err != nil { + t.Fatalf("write lo address: %v", err) + } + + virtualTarget := filepath.Join(root, "devices/virtual/net/vmbr0") + if err := os.MkdirAll(virtualTarget, 0o755); err != nil { + t.Fatalf("mkdir virtualTarget: %v", err) + } + if err := os.WriteFile(filepath.Join(virtualTarget, "address"), []byte("aa:aa:aa:aa:aa:aa\n"), 0o644); err != nil { + t.Fatalf("write vmbr0 address: %v", err) + } + if err := os.Symlink(virtualTarget, filepath.Join(netDir, "vmbr0")); err != nil { + t.Fatalf("symlink vmbr0: %v", err) + } + + inv, err := collectCurrentNetworkInventory(context.Background()) + if err != nil { + t.Fatalf("collect: %v", err) + } + if inv == nil || len(inv.Interfaces) != 4 { + t.Fatalf("inv=%+v", inv) + } + if inv.Interfaces[0].Name != "eno1" || inv.Interfaces[1].Name != "eth0" || inv.Interfaces[2].Name != "lo" || inv.Interfaces[3].Name != "vmbr0" { + t.Fatalf("sorted names=%v", []string{inv.Interfaces[0].Name, inv.Interfaces[1].Name, inv.Interfaces[2].Name, inv.Interfaces[3].Name}) + } + var gotEth0, gotVmbr0 archivedNetworkInterface + for _, iface := range inv.Interfaces { + switch iface.Name { + case "eth0": + gotEth0 = iface + case "vmbr0": + gotVmbr0 = iface + } + } + if gotEth0.MAC != "aa:bb:cc:dd:ee:01" { + t.Fatalf("eth0 MAC=%q", gotEth0.MAC) + } + if gotEth0.Driver != "e1000" { + t.Fatalf("eth0 Driver=%q", gotEth0.Driver) + } + if !strings.Contains(gotEth0.PCIPath, "devices/pci0000:00/0000:00:1f.6") { + t.Fatalf("eth0 PCIPath=%q", gotEth0.PCIPath) + } + if !gotVmbr0.IsVirtual { + t.Fatalf("vmbr0 IsVirtual=false") + } +} + +func TestPlanAndApplyNICNameRepair_WithFakeInventory(t *testing.T) { + origFS := restoreFS + origTime := restoreTime + origSys := sysClassNetPath + origSeq := atomic.LoadUint64(&nicRepairSequence) + t.Cleanup(func() { + restoreFS = origFS + restoreTime = origTime + sysClassNetPath = origSys + atomic.StoreUint64(&nicRepairSequence, origSeq) + }) + t.Setenv("PATH", "") + + // Fake current inventory via fake sysfs. + root := t.TempDir() + netDir := filepath.Join(root, "sys/class/net") + if err := os.MkdirAll(netDir, 0o755); err != nil { + t.Fatalf("mkdir netDir: %v", err) + } + sysClassNetPath = netDir + mustWriteFile := func(path, content string) { + t.Helper() + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + if err := os.WriteFile(path, []byte(content), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + } + mustWriteFile(filepath.Join(netDir, "eth0/address"), "aa:bb:cc:dd:ee:01\n") + mustWriteFile(filepath.Join(netDir, "eno1/address"), "aa:bb:cc:dd:ee:02\n") + mustWriteFile(filepath.Join(netDir, "lo/address"), "00:00:00:00:00:00\n") + + // Fake backup archive. + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + restoreTime = &FakeTime{Current: time.Date(2025, 1, 1, 1, 2, 3, 0, time.UTC)} + + backupInv := archivedNetworkInventory{ + GeneratedAt: time.Now().Format(time.RFC3339), + Hostname: "backup-host", + Interfaces: []archivedNetworkInterface{ + {Name: "eno1", MAC: "aa:bb:cc:dd:ee:01"}, // maps to current eth0 -> conflict (current eno1 exists) + {Name: "ens20", MAC: "aa:bb:cc:dd:ee:02"}, // maps to current eno1 -> safe + }, + } + payload, err := json.Marshal(backupInv) + if err != nil { + t.Fatalf("marshal: %v", err) + } + writeTarToFakeFS(t, fakeFS, "/backup.tar", []tarEntry{ + {Name: "./commands/network_inventory.json", Typeflag: tar.TypeReg, Mode: 0o644, Data: payload}, + }) + + plan, err := planNICNameRepair(context.Background(), "/backup.tar") + if err != nil { + t.Fatalf("plan: %v", err) + } + if plan == nil { + t.Fatalf("plan=nil") + } + if plan.SkippedReason != "" { + t.Fatalf("SkippedReason=%q", plan.SkippedReason) + } + if plan.Mapping.BackupSourcePath != "./commands/network_inventory.json" { + t.Fatalf("BackupSourcePath=%q", plan.Mapping.BackupSourcePath) + } + if len(plan.SafeMappings) != 1 || plan.SafeMappings[0].OldName != "ens20" || plan.SafeMappings[0].NewName != "eno1" { + t.Fatalf("SafeMappings=%+v", plan.SafeMappings) + } + if len(plan.Conflicts) != 1 || plan.Conflicts[0].Mapping.OldName != "eno1" || plan.Conflicts[0].Mapping.NewName != "eth0" { + t.Fatalf("Conflicts=%+v", plan.Conflicts) + } + + logger := logging.New(types.LogLevelDebug, false) + logger.SetOutput(bytes.NewBuffer(nil)) + + // Prepare config to exercise both safe mapping and conflict mapping. + if err := fakeFS.WriteFile("/etc/network/interfaces", []byte(strings.Join([]string{ + "auto ens20", + "iface ens20 inet manual", + "auto eno1", + "iface eno1 inet manual", + "", + }, "\n")), 0o644); err != nil { + t.Fatalf("write interfaces: %v", err) + } + + // includeConflicts=false: applies only safe mapping (ens2 -> eno1). + res, err := applyNICNameRepair(logger, plan, false) + if err != nil { + t.Fatalf("apply: %v", err) + } + if res == nil || res.SkippedReason != "" { + t.Fatalf("result=%+v", res) + } + if len(res.ChangedFiles) != 1 || res.ChangedFiles[0] != "/etc/network/interfaces" { + t.Fatalf("ChangedFiles=%v", res.ChangedFiles) + } + data, err := fakeFS.ReadFile("/etc/network/interfaces") + if err != nil { + t.Fatalf("read: %v", err) + } + if strings.Contains(string(data), "ens20") { + t.Fatalf("expected ens20 to be replaced:\n%s", string(data)) + } + if !strings.Contains(string(data), "auto eno1") { + t.Fatalf("expected eno1 to remain:\n%s", string(data)) + } + + // includeConflicts=true: also applies eno1 -> eth0. + if err := fakeFS.WriteFile("/etc/network/interfaces", []byte(strings.Join([]string{ + "auto ens20", + "iface ens20 inet manual", + "auto eno1", + "iface eno1 inet manual", + "", + }, "\n")), 0o644); err != nil { + t.Fatalf("rewrite interfaces: %v", err) + } + res, err = applyNICNameRepair(logger, plan, true) + if err != nil { + t.Fatalf("apply conflicts: %v", err) + } + data, err = fakeFS.ReadFile("/etc/network/interfaces") + if err != nil { + t.Fatalf("read: %v", err) + } + if strings.Contains(string(data), "ens20") || strings.Contains(string(data), "auto eno1\n") { + t.Fatalf("expected ens20 and eno1 to be replaced:\n%s", string(data)) + } + if !strings.Contains(string(data), "auto eth0") { + t.Fatalf("expected eth0:\n%s", string(data)) + } +} + +func TestPlanNICNameRepair_SkipAndErrorBranches(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + plan, err := planNICNameRepair(context.Background(), " ") + if err != nil || plan == nil || plan.SkippedReason == "" { + t.Fatalf("plan=%+v err=%v", plan, err) + } + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + writeTarToFakeFS(t, fakeFS, "/missing_inv.tar", []tarEntry{ + {Name: "./unrelated", Typeflag: tar.TypeReg, Mode: 0o644, Data: []byte("x")}, + }) + plan, err = planNICNameRepair(context.Background(), "/missing_inv.tar") + if err != nil || plan == nil || !strings.Contains(plan.SkippedReason, "backup does not include network inventory") { + t.Fatalf("plan=%+v err=%v", plan, err) + } + + if err := fakeFS.WriteFile("/bad.zip", []byte("x"), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + _, err = planNICNameRepair(context.Background(), "/bad.zip") + if err == nil || !strings.Contains(err.Error(), "unsupported archive format") { + t.Fatalf("err=%v; want unsupported archive format", err) + } +} + +func TestApplyNICNameRepair_SkipBranches(t *testing.T) { + logger := logging.New(types.LogLevelDebug, false) + logger.SetOutput(bytes.NewBuffer(nil)) + + res, err := applyNICNameRepair(logger, nil, false) + if err != nil || res == nil || res.SkippedReason == "" { + t.Fatalf("res=%+v err=%v", res, err) + } + + res, err = applyNICNameRepair(logger, &nicRepairPlan{SkippedReason: "nope"}, false) + if err != nil || res == nil || !strings.Contains(res.SkippedReason, "nope") { + t.Fatalf("res=%+v err=%v", res, err) + } + + res, err = applyNICNameRepair(logger, &nicRepairPlan{ + Conflicts: []nicNameConflict{{Mapping: nicMappingEntry{OldName: "a", NewName: "b"}}}, + }, false) + if err != nil || res == nil || !strings.Contains(res.SkippedReason, "conflicting NIC mappings") { + t.Fatalf("res=%+v err=%v", res, err) + } + + res, err = applyNICNameRepair(logger, &nicRepairPlan{ + SafeMappings: []nicMappingEntry{ + {OldName: "eno1", NewName: "eno1"}, + }, + Conflicts: []nicNameConflict{ + {Mapping: nicMappingEntry{OldName: "eno2", NewName: "eth0"}}, + }, + }, false) + if err != nil || res == nil || res.SkippedReason != "conflicting NIC mappings detected; skipped by user" { + t.Fatalf("res=%+v err=%v", res, err) + } +} + +func TestMappingHelpersAndEdgeCases(t *testing.T) { + if !computeNICMapping(nil, nil).IsEmpty() { + t.Fatalf("expected empty mapping for nil inventories") + } + + if !hasStableUdevIdentifiers(map[string]string{"ID_SERIAL": " abc "}) { + t.Fatalf("expected stable udev identifiers") + } + if hasStableUdevIdentifiers(map[string]string{"ID_SERIAL": " "}) { + t.Fatalf("expected false for blank udev values") + } + + if shouldAddMapping("a", "a", map[string]struct{}{}) { + t.Fatalf("expected false for old==new") + } + if !shouldAddMapping("a", "b", nil) { + t.Fatalf("expected true when usedCurrent nil") + } + if shouldAddMapping("a", "b", map[string]struct{}{"b": {}}) { + t.Fatalf("expected false when usedCurrent already contains new") + } + + if isCandidatePhysicalNIC(archivedNetworkInterface{Name: "lo", MAC: "aa"}) { + t.Fatalf("expected lo to be non-candidate") + } + if isCandidatePhysicalNIC(archivedNetworkInterface{Name: "eth0", IsVirtual: true, MAC: "aa"}) { + t.Fatalf("expected virtual to be non-candidate") + } + if isCandidatePhysicalNIC(archivedNetworkInterface{Name: "eth0"}) { + t.Fatalf("expected no-identifiers to be non-candidate") + } + if !isCandidatePhysicalNIC(archivedNetworkInterface{Name: "eth0", UdevProps: map[string]string{"ID_PATH": "pci-1"}}) { + t.Fatalf("expected udev identifiers to make candidate") + } + + if out, changed := applyInterfaceRenameMap("", map[string]string{"a": "b"}); out != "" || changed { + t.Fatalf("applyInterfaceRenameMap unexpected result: out=%q changed=%v", out, changed) + } + if out, changed := applyInterfaceRenameMap("auto a\n", map[string]string{}); out != "auto a\n" || changed { + t.Fatalf("applyInterfaceRenameMap unexpected result: out=%q changed=%v", out, changed) + } + + if out, changed := replaceInterfaceToken("", "a", "b"); out != "" || changed { + t.Fatalf("replaceInterfaceToken unexpected: out=%q changed=%v", out, changed) + } + if _, changed := replaceInterfaceToken("auto a\n", "a", "a"); changed { + t.Fatalf("replaceInterfaceToken should not change when old==new") + } + + cases := map[byte]bool{ + 'a': true, + 'Z': true, + '0': true, + '_': true, + '-': true, + '.': false, + ' ': false, + } + for ch, want := range cases { + if got := isIfaceNameChar(ch); got != want { + t.Fatalf("isIfaceNameChar(%q)=%v want %v", string(ch), got, want) + } + } + + backup := &archivedNetworkInventory{ + Interfaces: []archivedNetworkInterface{ + {Name: "eno1", MAC: "aa:aa", PCIPath: "/pci/1"}, + }, + } + current := &archivedNetworkInventory{ + Interfaces: []archivedNetworkInterface{ + {Name: "eth0", MAC: "aa:aa", PCIPath: "/pci/1"}, + {Name: "eth1", MAC: "aa:aa", PCIPath: "/pci/2"}, + }, + } + got := computeNICMapping(backup, current) + if got.IsEmpty() || got.Entries[0].Method != nicMatchPCIPath { + t.Fatalf("got=%+v; want pci_path match due to MAC dupes", got) + } + + backup = &archivedNetworkInventory{ + Interfaces: []archivedNetworkInterface{ + {Name: "eno1", MAC: "bb:bb"}, + {Name: "ens2", MAC: "bb:bb"}, + }, + } + current = &archivedNetworkInventory{ + Interfaces: []archivedNetworkInterface{ + {Name: "eth0", MAC: "bb:bb"}, + }, + } + got = computeNICMapping(backup, current) + if len(got.Entries) != 1 { + t.Fatalf("Entries=%+v; want 1 due to usedCurrent", got.Entries) + } +} + +func TestPlanNICNameRepair_NoMappingAndCurrentInventoryError(t *testing.T) { + origFS := restoreFS + origSys := sysClassNetPath + t.Cleanup(func() { + restoreFS = origFS + sysClassNetPath = origSys + }) + t.Setenv("PATH", "") + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + t.Run("no mapping", func(t *testing.T) { + root := t.TempDir() + netDir := filepath.Join(root, "sys/class/net") + if err := os.MkdirAll(netDir, 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + sysClassNetPath = netDir + if err := os.MkdirAll(filepath.Join(netDir, "eth0"), 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + if err := os.WriteFile(filepath.Join(netDir, "eth0/address"), []byte("aa:bb:cc:dd:ee:01\n"), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + + backupInv := archivedNetworkInventory{ + Interfaces: []archivedNetworkInterface{ + {Name: "eth0", MAC: "aa:bb:cc:dd:ee:01"}, + }, + } + payload, err := json.Marshal(backupInv) + if err != nil { + t.Fatalf("marshal: %v", err) + } + writeTarToFakeFS(t, fakeFS, "/nomap.tar", []tarEntry{ + {Name: "./commands/network_inventory.json", Typeflag: tar.TypeReg, Mode: 0o644, Data: payload}, + }) + + plan, err := planNICNameRepair(context.Background(), "/nomap.tar") + if err != nil { + t.Fatalf("plan: %v", err) + } + if plan == nil || !strings.Contains(plan.SkippedReason, "no NIC rename mapping found") { + t.Fatalf("plan=%+v", plan) + } + }) + + t.Run("current inventory error", func(t *testing.T) { + sysClassNetPath = filepath.Join(t.TempDir(), "does-not-exist") + + backupInv := archivedNetworkInventory{ + Interfaces: []archivedNetworkInterface{ + {Name: "eno1", MAC: "aa:bb"}, + }, + } + payload, err := json.Marshal(backupInv) + if err != nil { + t.Fatalf("marshal: %v", err) + } + writeTarToFakeFS(t, fakeFS, "/inv.tar", []tarEntry{ + {Name: "./commands/network_inventory.json", Typeflag: tar.TypeReg, Mode: 0o644, Data: payload}, + }) + + _, err = planNICNameRepair(context.Background(), "/inv.tar") + if err == nil || !strings.Contains(err.Error(), "collect current network inventory") { + t.Fatalf("err=%v; want collect current network inventory", err) + } + }) +} + +func TestApplyNICNameRepair_NoMatchesAndNoRenamesSelected(t *testing.T) { + origFS := restoreFS + origTime := restoreTime + t.Cleanup(func() { + restoreFS = origFS + restoreTime = origTime + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + restoreTime = &FakeTime{Current: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)} + + if err := fakeFS.WriteFile("/etc/network/interfaces", []byte("auto lo\niface lo inet loopback\n"), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + + res, err := applyNICNameRepair(nil, &nicRepairPlan{ + SafeMappings: []nicMappingEntry{ + {OldName: "eno1", NewName: "eth0", Method: nicMatchMAC, Identifier: "aa"}, + }, + }, false) + if err != nil { + t.Fatalf("apply: %v", err) + } + if res == nil || !strings.Contains(res.SkippedReason, "no matching interface names found") { + t.Fatalf("res=%+v", res) + } + + res, err = applyNICNameRepair(nil, &nicRepairPlan{ + SafeMappings: []nicMappingEntry{ + {OldName: "eno1", NewName: "eno1"}, + }, + }, false) + if err != nil { + t.Fatalf("apply: %v", err) + } + if res == nil || res.SkippedReason != "no NIC renames selected" { + t.Fatalf("res=%+v", res) + } +} + +func TestApplyNICNameRepair_PropagatesRewriteErrors(t *testing.T) { + origFS := restoreFS + origTime := restoreTime + t.Cleanup(func() { + restoreFS = origFS + restoreTime = origTime + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + restoreTime = &FakeTime{Current: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)} + + if err := fakeFS.WriteFile("/etc/network/interfaces", []byte("auto eno1\n"), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + fakeFS.MkdirAllErr = errors.New("disk full") + + res, err := applyNICNameRepair(nil, &nicRepairPlan{ + SafeMappings: []nicMappingEntry{{OldName: "eno1", NewName: "eth0"}}, + }, false) + if err == nil || res != nil { + t.Fatalf("res=%+v err=%v; want nil result and error", res, err) + } +} + +func TestReadArchiveEntry_CorruptTar(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + if err := fakeFS.WriteFile("/corrupt.tar", []byte("not a tar"), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + _, _, err := readArchiveEntry(context.Background(), "/corrupt.tar", []string{"./x"}, 16) + if err == nil || errors.Is(err, os.ErrNotExist) { + t.Fatalf("err=%v; want tar parse error", err) + } +} + +func TestReadArchiveEntry_TruncatedEntryReadError(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + var buf bytes.Buffer + tw := tar.NewWriter(&buf) + data := []byte("0123456789") + if err := tw.WriteHeader(&tar.Header{Name: "./x", Mode: 0o644, Size: int64(len(data))}); err != nil { + t.Fatalf("WriteHeader: %v", err) + } + if _, err := tw.Write(data); err != nil { + t.Fatalf("Write: %v", err) + } + if err := tw.Close(); err != nil { + t.Fatalf("Close: %v", err) + } + + full := buf.Bytes() + if len(full) <= 512+5 { + t.Fatalf("unexpected tar size: %d", len(full)) + } + truncated := full[:512+5] + if err := fakeFS.WriteFile("/trunc.tar", truncated, 0o644); err != nil { + t.Fatalf("write: %v", err) + } + + _, _, err := readArchiveEntry(context.Background(), "/trunc.tar", []string{"./x"}, 16) + if err == nil || errors.Is(err, os.ErrNotExist) { + t.Fatalf("err=%v; want read error", err) + } +} + +func TestCollectCurrentNetworkInventory_UdevAndPermanentMAC(t *testing.T) { + origSys := sysClassNetPath + origCmd := restoreCmd + t.Cleanup(func() { + sysClassNetPath = origSys + restoreCmd = origCmd + }) + + // Make commandAvailable() succeed. + binDir := t.TempDir() + for _, name := range []string{"udevadm", "ethtool"} { + path := filepath.Join(binDir, name) + if err := os.WriteFile(path, []byte("#!/bin/sh\nexit 0\n"), 0o755); err != nil { + t.Fatalf("write %s: %v", name, err) + } + } + t.Setenv("PATH", binDir) + + root := t.TempDir() + netDir := filepath.Join(root, "sys/class/net") + if err := os.MkdirAll(netDir, 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + sysClassNetPath = netDir + + // eth0 directory. + if err := os.MkdirAll(filepath.Join(netDir, "eth0"), 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + if err := os.WriteFile(filepath.Join(netDir, "eth0/address"), []byte("aa:bb:cc:dd:ee:01\n"), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + + // eth1 symlink (non-virtual). + eth1Target := filepath.Join(root, "devices/pci0000:00/eth1") + if err := os.MkdirAll(eth1Target, 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + if err := os.WriteFile(filepath.Join(eth1Target, "address"), []byte("aa:bb:cc:dd:ee:02\n"), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + if err := os.Symlink(eth1Target, filepath.Join(netDir, "eth1")); err != nil { + t.Fatalf("symlink: %v", err) + } + + cmd := &FakeCommandRunner{ + Outputs: map[string][]byte{ + "udevadm info -q property -p " + filepath.Join(netDir, "eth0"): []byte("ID_SERIAL=abc\n"), + "ethtool -P eth0": []byte("permanent address: AA:BB:CC:DD:EE:FF\n"), + "udevadm info -q property -p " + filepath.Join(netDir, "eth1"): []byte("ID_SERIAL=ignored\n"), + "ethtool -P eth1": []byte("permanent address: 11:22:33:44:55:66\n"), + }, + Errors: map[string]error{ + "udevadm info -q property -p " + filepath.Join(netDir, "eth1"): errors.New("boom"), + "ethtool -P eth1": errors.New("boom"), + }, + } + restoreCmd = cmd + + inv, err := collectCurrentNetworkInventory(context.Background()) + if err != nil { + t.Fatalf("collect: %v", err) + } + var eth0, eth1 archivedNetworkInterface + for _, iface := range inv.Interfaces { + switch iface.Name { + case "eth0": + eth0 = iface + case "eth1": + eth1 = iface + } + } + if eth0.PermanentMAC != "aa:bb:cc:dd:ee:ff" { + t.Fatalf("eth0 PermanentMAC=%q", eth0.PermanentMAC) + } + if eth0.UdevProps == nil || eth0.UdevProps["ID_SERIAL"] != "abc" { + t.Fatalf("eth0 UdevProps=%v", eth0.UdevProps) + } + if eth1.IsVirtual { + t.Fatalf("eth1 IsVirtual=true; want false") + } + if eth1.PermanentMAC != "" || eth1.UdevProps != nil { + t.Fatalf("eth1 should ignore cmd errors: %+v", eth1) + } +} + +func TestRewriteIfupdownConfigFiles_EdgeAndErrorPaths(t *testing.T) { + origFS := restoreFS + origTime := restoreTime + t.Cleanup(func() { + restoreFS = origFS + restoreTime = origTime + }) + + restoreTime = &FakeTime{Current: time.Date(2025, 1, 1, 1, 2, 3, 0, time.UTC)} + + t.Run("empty rename map", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + paths, dir, err := rewriteIfupdownConfigFiles(nil, map[string]string{}) + if err != nil || len(paths) != 0 || dir != "" { + t.Fatalf("paths=%v dir=%q err=%v", paths, dir, err) + } + }) + + t.Run("interfaces.d missing but update succeeds", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + if err := fakeFS.WriteFile("/etc/network/interfaces", []byte("auto eno1\niface eno1 inet manual\n"), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + paths, _, err := rewriteIfupdownConfigFiles(nil, map[string]string{"eno1": "eth0"}) + if err != nil || len(paths) != 1 { + t.Fatalf("paths=%v err=%v", paths, err) + } + }) + + t.Run("interfaces.d entries skipped", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + if err := fakeFS.WriteFile("/etc/network/interfaces", []byte("auto eno1\n"), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + if err := fakeFS.MkdirAll("/etc/network/interfaces.d/subdir", 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + if err := fakeFS.WriteFile("/etc/network/interfaces.d/ ", []byte("auto eno1\n"), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + if err := fakeFS.WriteFile("/etc/network/interfaces.d/extra", []byte("auto eno1\n"), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + + paths, _, err := rewriteIfupdownConfigFiles(nil, map[string]string{"eno1": "eth0"}) + if err != nil || len(paths) != 2 { + t.Fatalf("paths=%v err=%v", paths, err) + } + }) + + t.Run("stat failure skips", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + base := statFailFS{FS: fakeFS, failPath: "/etc/network/interfaces", err: errors.New("boom")} + restoreFS = base + + paths, dir, err := rewriteIfupdownConfigFiles(nil, map[string]string{"eno1": "eth0"}) + if err != nil || len(paths) != 0 || dir != "" { + t.Fatalf("paths=%v dir=%q err=%v", paths, dir, err) + } + }) + + t.Run("not regular file skipped", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + if err := fakeFS.MkdirAll("/etc/network/interfaces", 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + paths, dir, err := rewriteIfupdownConfigFiles(nil, map[string]string{"eno1": "eth0"}) + if err != nil || len(paths) != 0 || dir != "" { + t.Fatalf("paths=%v dir=%q err=%v", paths, dir, err) + } + }) + + t.Run("read failure skipped", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + if err := fakeFS.WriteFile("/etc/network/interfaces", []byte("auto eno1\n"), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + restoreFS = readFileFailFS{FS: fakeFS, failPath: "/etc/network/interfaces", err: errors.New("boom")} + + paths, dir, err := rewriteIfupdownConfigFiles(nil, map[string]string{"eno1": "eth0"}) + if err != nil || len(paths) != 0 || dir != "" { + t.Fatalf("paths=%v dir=%q err=%v", paths, dir, err) + } + }) + + t.Run("base dir create fails", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + if err := fakeFS.WriteFile("/etc/network/interfaces", []byte("auto eno1\n"), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + fakeFS.MkdirAllErr = errors.New("disk full") + + _, _, err := rewriteIfupdownConfigFiles(nil, map[string]string{"eno1": "eth0"}) + if err == nil || !strings.Contains(err.Error(), "create nic repair base directory") { + t.Fatalf("err=%v; want create nic repair base directory", err) + } + }) + + t.Run("write updated fails", func(t *testing.T) { + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + if err := fakeFS.WriteFile("/etc/network/interfaces", []byte("auto eno1\n"), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + restoreFS = writeFileFailFS{FS: fakeFS, failPath: "/etc/network/interfaces", err: errors.New("disk full")} + + _, _, err := rewriteIfupdownConfigFiles(nil, map[string]string{"eno1": "eth0"}) + if err == nil || !strings.Contains(err.Error(), "write updated file") { + t.Fatalf("err=%v; want write updated file", err) + } + }) +} + +func TestRewriteIfupdownConfigFiles_BackupStageErrors(t *testing.T) { + origFS := restoreFS + origTime := restoreTime + origSeq := atomic.LoadUint64(&nicRepairSequence) + t.Cleanup(func() { + restoreFS = origFS + restoreTime = origTime + atomic.StoreUint64(&nicRepairSequence, origSeq) + }) + + restoreTime = &FakeTime{Current: time.Date(2025, 1, 1, 1, 2, 3, 0, time.UTC)} + expectedBackupDir := "/tmp/proxsave/nic_repair_20250101_010203_1" + expectedBackupPath := filepath.Join(expectedBackupDir, "etc/network/interfaces") + expectedBackupPathDir := filepath.Dir(expectedBackupPath) + + t.Run("create backup dir fails", func(t *testing.T) { + atomic.StoreUint64(&nicRepairSequence, 0) + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + if err := fakeFS.WriteFile("/etc/network/interfaces", []byte("auto eno1\n"), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + restoreFS = mkdirAllFailFS{FS: fakeFS, failPath: expectedBackupDir, err: errors.New("boom")} + + _, _, err := rewriteIfupdownConfigFiles(nil, map[string]string{"eno1": "eth0"}) + if err == nil || !strings.Contains(err.Error(), "create nic repair backup directory") { + t.Fatalf("err=%v; want create nic repair backup directory", err) + } + }) + + t.Run("create backup path dir fails", func(t *testing.T) { + atomic.StoreUint64(&nicRepairSequence, 0) + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + if err := fakeFS.WriteFile("/etc/network/interfaces", []byte("auto eno1\n"), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + restoreFS = mkdirAllFailFS{FS: fakeFS, failPath: expectedBackupPathDir, err: errors.New("boom")} + + _, _, err := rewriteIfupdownConfigFiles(nil, map[string]string{"eno1": "eth0"}) + if err == nil || !strings.Contains(err.Error(), "create backup directory for") { + t.Fatalf("err=%v; want create backup directory for", err) + } + }) + + t.Run("write backup file fails", func(t *testing.T) { + atomic.StoreUint64(&nicRepairSequence, 0) + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + if err := fakeFS.WriteFile("/etc/network/interfaces", []byte("auto eno1\n"), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + restoreFS = writeFileFailFS{FS: fakeFS, failPath: expectedBackupPath, err: errors.New("disk full")} + + _, _, err := rewriteIfupdownConfigFiles(nil, map[string]string{"eno1": "eth0"}) + if err == nil || !strings.Contains(err.Error(), "write backup file") { + t.Fatalf("err=%v; want write backup file", err) + } + }) +} + +func TestMapToEntriesAndTokenBoundary(t *testing.T) { + if got := mapToEntries(map[string]string{}); got != nil { + t.Fatalf("mapToEntries=%v; want nil", got) + } + got := mapToEntries(map[string]string{"b": "B", "a": "A"}) + if len(got) != 2 || got[0].OldName != "a" || got[1].OldName != "b" { + t.Fatalf("entries=%+v", got) + } + + if isTokenBoundary("abc", -1, "a") { + t.Fatalf("expected false for negative idx") + } + if isTokenBoundary("abc", 2, "zz") { + t.Fatalf("expected false for token overflow") + } + if isTokenBoundary("xeno1", 1, "eno1") { + t.Fatalf("expected false for iface-char prefix") + } + if isTokenBoundary("eno10", 0, "eno1") { + t.Fatalf("expected false for iface-char suffix") + } + if !isTokenBoundary("eno1", 0, "eno1") { + t.Fatalf("expected true for token covering full string") + } + + if out, changed := applyInterfaceRenameMap("auto a\n", map[string]string{"a": "a"}); out != "auto a\n" || changed { + t.Fatalf("applyInterfaceRenameMap unexpected: out=%q changed=%v", out, changed) + } +} From 37c3f0133f080c450c7c948e7bb1878ff1e440c6 Mon Sep 17 00:00:00 2001 From: Damiano <71268257+tis24dev@users.noreply.github.com> Date: Tue, 24 Feb 2026 15:01:10 +0100 Subject: [PATCH 08/12] Add hooks for PBS staged apply and tests Introduce hookable function variables in pbs_staged_apply.go (for euid, isRealRestoreFS and all PBS API/apply functions) and switch maybeApplyPBSConfigsFromStage to use them so the staged-apply logic can be tested without touching the real system/API. Add two comprehensive test files (pbs_staged_apply_additional_test.go and pbs_staged_apply_maybeapply_test.go) that exercise parsing/validation, file-based fallbacks, datastore deferral logic, atomic write error handling, job/tape config application, permission checks, and various edge cases. These changes enable robust unit testing of PBS staged config application while keeping runtime behavior unchanged. --- internal/orchestrator/pbs_staged_apply.go | 37 +- .../pbs_staged_apply_additional_test.go | 1031 +++++++++++++++++ .../pbs_staged_apply_maybeapply_test.go | 407 +++++++ 3 files changed, 1464 insertions(+), 11 deletions(-) create mode 100644 internal/orchestrator/pbs_staged_apply_additional_test.go create mode 100644 internal/orchestrator/pbs_staged_apply_maybeapply_test.go diff --git a/internal/orchestrator/pbs_staged_apply.go b/internal/orchestrator/pbs_staged_apply.go index 0b84e8f..265910c 100644 --- a/internal/orchestrator/pbs_staged_apply.go +++ b/internal/orchestrator/pbs_staged_apply.go @@ -12,6 +12,21 @@ import ( "github.com/tis24dev/proxsave/internal/logging" ) +// Hookable functions for testing staged PBS apply logic without touching the real system/API. +var ( + pbsStagedApplyIsRealRestoreFSFn = isRealRestoreFS + pbsStagedApplyGeteuidFn = os.Geteuid + pbsStagedApplyEnsurePBSServicesForAPIFn = ensurePBSServicesForAPI + pbsStagedApplyTrafficControlCfgViaAPIFn = applyPBSTrafficControlCfgViaAPI + pbsStagedApplyNodeCfgViaAPIFn = applyPBSNodeCfgViaAPI + pbsStagedApplyS3CfgViaAPIFn = applyPBSS3CfgViaAPI + pbsStagedApplyDatastoreCfgViaAPIFn = applyPBSDatastoreCfgViaAPI + pbsStagedApplyRemoteCfgViaAPIFn = applyPBSRemoteCfgViaAPI + pbsStagedApplySyncCfgViaAPIFn = applyPBSSyncCfgViaAPI + pbsStagedApplyVerificationCfgViaAPIFn = applyPBSVerificationCfgViaAPI + pbsStagedApplyPruneCfgViaAPIFn = applyPBSPruneCfgViaAPI +) + func maybeApplyPBSConfigsFromStage(ctx context.Context, logger *logging.Logger, plan *RestorePlan, stageRoot string, dryRun bool) (err error) { if plan == nil || plan.SystemType != SystemTypePBS { return nil @@ -31,11 +46,11 @@ func maybeApplyPBSConfigsFromStage(ctx context.Context, logger *logging.Logger, logger.Info("Dry run enabled: skipping staged PBS config apply") return nil } - if !isRealRestoreFS(restoreFS) { + if !pbsStagedApplyIsRealRestoreFSFn(restoreFS) { logger.Debug("Skipping staged PBS config apply: non-system filesystem in use") return nil } - if os.Geteuid() != 0 { + if pbsStagedApplyGeteuidFn() != 0 { logger.Warning("Skipping staged PBS config apply: requires root privileges") return nil } @@ -47,7 +62,7 @@ func maybeApplyPBSConfigsFromStage(ctx context.Context, logger *logging.Logger, needsAPI := plan.HasCategoryID("pbs_host") || plan.HasCategoryID("datastore_pbs") || plan.HasCategoryID("pbs_remotes") || plan.HasCategoryID("pbs_jobs") apiAvailable := false if needsAPI { - if err := ensurePBSServicesForAPI(ctx, logger); err != nil { + if err := pbsStagedApplyEnsurePBSServicesForAPIFn(ctx, logger); err != nil { if allowFileFallback { logger.Warning("PBS API apply unavailable; falling back to file-based staged apply where possible: %v", err) } else { @@ -73,14 +88,14 @@ func maybeApplyPBSConfigsFromStage(ctx context.Context, logger *logging.Logger, } if apiAvailable { - if err := applyPBSTrafficControlCfgViaAPI(ctx, logger, stageRoot, strict); err != nil { + if err := pbsStagedApplyTrafficControlCfgViaAPIFn(ctx, logger, stageRoot, strict); err != nil { logger.Warning("PBS API apply: traffic-control failed: %v", err) if allowFileFallback { logger.Warning("PBS staged apply: falling back to file-based traffic-control.cfg") _ = applyPBSConfigFileFromStage(ctx, logger, stageRoot, "etc/proxmox-backup/traffic-control.cfg") } } - if err := applyPBSNodeCfgViaAPI(ctx, stageRoot); err != nil { + if err := pbsStagedApplyNodeCfgViaAPIFn(ctx, stageRoot); err != nil { logger.Warning("PBS API apply: node config failed: %v", err) if allowFileFallback { logger.Warning("PBS staged apply: falling back to file-based node.cfg") @@ -103,14 +118,14 @@ func maybeApplyPBSConfigsFromStage(ctx context.Context, logger *logging.Logger, if plan.HasCategoryID("datastore_pbs") { if apiAvailable { - if err := applyPBSS3CfgViaAPI(ctx, logger, stageRoot, strict); err != nil { + if err := pbsStagedApplyS3CfgViaAPIFn(ctx, logger, stageRoot, strict); err != nil { logger.Warning("PBS API apply: s3.cfg failed: %v", err) if allowFileFallback { logger.Warning("PBS staged apply: falling back to file-based s3.cfg") _ = applyPBSS3CfgFromStage(ctx, logger, stageRoot) } } - if err := applyPBSDatastoreCfgViaAPI(ctx, logger, stageRoot, strict); err != nil { + if err := pbsStagedApplyDatastoreCfgViaAPIFn(ctx, logger, stageRoot, strict); err != nil { logger.Warning("PBS API apply: datastore.cfg failed: %v", err) if allowFileFallback { logger.Warning("PBS staged apply: falling back to file-based datastore.cfg") @@ -131,7 +146,7 @@ func maybeApplyPBSConfigsFromStage(ctx context.Context, logger *logging.Logger, if plan.HasCategoryID("pbs_remotes") { if apiAvailable { - if err := applyPBSRemoteCfgViaAPI(ctx, logger, stageRoot, strict); err != nil { + if err := pbsStagedApplyRemoteCfgViaAPIFn(ctx, logger, stageRoot, strict); err != nil { logger.Warning("PBS API apply: remote.cfg failed: %v", err) if allowFileFallback { logger.Warning("PBS staged apply: falling back to file-based remote.cfg") @@ -149,21 +164,21 @@ func maybeApplyPBSConfigsFromStage(ctx context.Context, logger *logging.Logger, if plan.HasCategoryID("pbs_jobs") { if apiAvailable { - if err := applyPBSSyncCfgViaAPI(ctx, logger, stageRoot, strict); err != nil { + if err := pbsStagedApplySyncCfgViaAPIFn(ctx, logger, stageRoot, strict); err != nil { logger.Warning("PBS API apply: sync jobs failed: %v", err) if allowFileFallback { logger.Warning("PBS staged apply: falling back to file-based job configs") _ = applyPBSJobConfigsFromStage(ctx, logger, stageRoot) } } - if err := applyPBSVerificationCfgViaAPI(ctx, logger, stageRoot, strict); err != nil { + if err := pbsStagedApplyVerificationCfgViaAPIFn(ctx, logger, stageRoot, strict); err != nil { logger.Warning("PBS API apply: verification jobs failed: %v", err) if allowFileFallback { logger.Warning("PBS staged apply: falling back to file-based job configs") _ = applyPBSJobConfigsFromStage(ctx, logger, stageRoot) } } - if err := applyPBSPruneCfgViaAPI(ctx, logger, stageRoot, strict); err != nil { + if err := pbsStagedApplyPruneCfgViaAPIFn(ctx, logger, stageRoot, strict); err != nil { logger.Warning("PBS API apply: prune jobs failed: %v", err) if allowFileFallback { logger.Warning("PBS staged apply: falling back to file-based job configs") diff --git a/internal/orchestrator/pbs_staged_apply_additional_test.go b/internal/orchestrator/pbs_staged_apply_additional_test.go new file mode 100644 index 0000000..80cb253 --- /dev/null +++ b/internal/orchestrator/pbs_staged_apply_additional_test.go @@ -0,0 +1,1031 @@ +package orchestrator + +import ( + "context" + "errors" + "fmt" + "os" + "path/filepath" + "strings" + "testing" + "time" +) + +func TestPBSConfigHasHeader_AcceptsAndRejectsExpectedForms(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + content string + want bool + }{ + { + name: "HeaderWithSpaceSeparatedName", + content: strings.Join([]string{ + "# comment", + "", + "remote: pbs1", + " host 10.0.0.10", + }, "\n"), + want: true, + }, + { + name: "HeaderWithInlineName", + content: strings.Join([]string{ + "remote:pbs1", + " host 10.0.0.10", + }, "\n"), + want: true, + }, + { + name: "RejectsHeaderWithoutName", + content: strings.Join([]string{ + "datastore:", + " path /mnt/datastore", + }, "\n"), + want: false, + }, + { + name: "RejectsInvalidKeyCharacters", + content: "foo.bar: baz\n", + want: false, + }, + { + name: "RejectsEmptyKey", + content: ": x\n", + want: false, + }, + { + name: "RejectsOnlyComments", + content: "# comment\n# still comment\n", + want: false, + }, + { + name: "AcceptsDashAndUnderscore", + content: "key-with-dash: v\nkey_with_underscore: v\n", + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := pbsConfigHasHeader(tt.content); got != tt.want { + t.Fatalf("pbsConfigHasHeader()=%v want %v", got, tt.want) + } + }) + } +} + +func TestMaybeApplyPBSConfigsFromStage_EarlyReturns(t *testing.T) { + ctx := context.Background() + logger := newTestLogger() + + if err := maybeApplyPBSConfigsFromStage(ctx, logger, nil, "/stage", false); err != nil { + t.Fatalf("nil plan: expected nil error, got %v", err) + } + + planWrongSystem := &RestorePlan{SystemType: SystemTypePVE, NormalCategories: []Category{{ID: "pbs_host"}}} + if err := maybeApplyPBSConfigsFromStage(ctx, logger, planWrongSystem, "/stage", false); err != nil { + t.Fatalf("wrong system type: expected nil error, got %v", err) + } + + planNoCategories := &RestorePlan{SystemType: SystemTypePBS, NormalCategories: []Category{{ID: "unrelated"}}} + if err := maybeApplyPBSConfigsFromStage(ctx, logger, planNoCategories, "/stage", false); err != nil { + t.Fatalf("no pbs categories: expected nil error, got %v", err) + } + + plan := &RestorePlan{SystemType: SystemTypePBS, NormalCategories: []Category{{ID: "pbs_host"}}} + if err := maybeApplyPBSConfigsFromStage(ctx, logger, plan, " ", false); err != nil { + t.Fatalf("blank stageRoot: expected nil error, got %v", err) + } + if err := maybeApplyPBSConfigsFromStage(ctx, logger, plan, "/stage", true); err != nil { + t.Fatalf("dryRun: expected nil error, got %v", err) + } + + origFS := restoreFS + fakeFS := NewFakeFS() + t.Cleanup(func() { + restoreFS = origFS + _ = os.RemoveAll(fakeFS.Root) + }) + restoreFS = fakeFS + if err := maybeApplyPBSConfigsFromStage(ctx, logger, plan, "/stage", false); err != nil { + t.Fatalf("non-real FS: expected nil error, got %v", err) + } +} + +func TestApplyPBSConfigFileFromStage_SkipsMissingFile(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + if err := applyPBSConfigFileFromStage(context.Background(), newTestLogger(), "/stage", "etc/proxmox-backup/s3.cfg"); err != nil { + t.Fatalf("applyPBSConfigFileFromStage: %v", err) + } + if _, err := fakeFS.Stat("/etc/proxmox-backup/s3.cfg"); err == nil { + t.Fatalf("expected s3.cfg to not be created") + } else if !os.IsNotExist(err) { + t.Fatalf("stat s3.cfg: %v", err) + } +} + +func TestApplyPBSConfigFileFromStage_SkipsInvalidHeader_LeavesTargetUnchanged(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + const destPath = "/etc/proxmox-backup/remote.cfg" + existing := "remote: old\n host 1.2.3.4\n" + if err := fakeFS.WriteFile(destPath, []byte(existing), 0o640); err != nil { + t.Fatalf("write existing remote.cfg: %v", err) + } + + stageRoot := "/stage" + staged := "this is not a PBS config file\n(no section header)\n" + if err := fakeFS.WriteFile(filepath.Join(stageRoot, "etc/proxmox-backup/remote.cfg"), []byte(staged), 0o640); err != nil { + t.Fatalf("write staged remote.cfg: %v", err) + } + + if err := applyPBSConfigFileFromStage(context.Background(), newTestLogger(), stageRoot, "etc/proxmox-backup/remote.cfg"); err != nil { + t.Fatalf("applyPBSConfigFileFromStage: %v", err) + } + + got, err := fakeFS.ReadFile(destPath) + if err != nil { + t.Fatalf("read dest remote.cfg: %v", err) + } + if string(got) != existing { + t.Fatalf("dest remote.cfg changed unexpectedly: got=%q want=%q", string(got), existing) + } +} + +func TestApplyPBSConfigFileFromStage_ReturnsErrorOnAtomicWriteFailure(t *testing.T) { + origFS := restoreFS + origTime := restoreTime + t.Cleanup(func() { + restoreFS = origFS + restoreTime = origTime + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + restoreTime = &FakeTime{Current: time.Unix(123, 456)} + + stageRoot := "/stage" + rel := "etc/proxmox-backup/s3.cfg" + staged := "s3: r1\n bucket test\n" + if err := fakeFS.WriteFile(filepath.Join(stageRoot, rel), []byte(staged), 0o640); err != nil { + t.Fatalf("write staged s3.cfg: %v", err) + } + + destPath := filepath.Join(string(os.PathSeparator), filepath.FromSlash(rel)) + tmpPath := fmt.Sprintf("%s.proxsave.tmp.%d", destPath, nowRestore().UnixNano()) + fakeFS.OpenFileErr[filepath.Clean(tmpPath)] = errors.New("forced OpenFile error") + + if err := applyPBSConfigFileFromStage(context.Background(), newTestLogger(), stageRoot, rel); err == nil { + t.Fatalf("expected error, got nil") + } +} + +func TestApplyPBSS3CfgFromStage_WritesS3Cfg(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + stageRoot := "/stage" + content := "s3: r1\n bucket test\n" + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/s3.cfg", []byte(content), 0o640); err != nil { + t.Fatalf("write staged s3.cfg: %v", err) + } + + if err := applyPBSS3CfgFromStage(context.Background(), newTestLogger(), stageRoot); err != nil { + t.Fatalf("applyPBSS3CfgFromStage: %v", err) + } + + if _, err := fakeFS.Stat("/etc/proxmox-backup/s3.cfg"); err != nil { + t.Fatalf("expected s3.cfg to exist: %v", err) + } +} + +func TestApplyPBSConfigFileFromStage_PropagatesReadError(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + stageRoot := "/stage" + rel := "etc/proxmox-backup/remote.cfg" + if err := fakeFS.MkdirAll(filepath.Join(stageRoot, rel), 0o755); err != nil { + t.Fatalf("mkdir staged remote.cfg dir: %v", err) + } + + if err := applyPBSConfigFileFromStage(context.Background(), newTestLogger(), stageRoot, rel); err == nil { + t.Fatalf("expected error, got nil") + } +} + +func TestLoadPBSDatastoreCfgFromInventory_FallsBackToDatastoreList(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + stageRoot := "/stage" + inventory := `{"datastores":[{"name":"DS1","path":"/mnt/ds1","comment":"primary"},{"name":"DS2","path":"/mnt/ds2","comment":""}]}` + if err := fakeFS.WriteFile(stageRoot+"/var/lib/proxsave-info/commands/pbs/pbs_datastore_inventory.json", []byte(inventory), 0o640); err != nil { + t.Fatalf("write inventory: %v", err) + } + + content, src, err := loadPBSDatastoreCfgFromInventory(stageRoot) + if err != nil { + t.Fatalf("loadPBSDatastoreCfgFromInventory: %v", err) + } + if src != "pbs_datastore_inventory.json.datastores" { + t.Fatalf("src=%q", src) + } + + blocks, err := parsePBSDatastoreCfgBlocks(content) + if err != nil { + t.Fatalf("parsePBSDatastoreCfgBlocks: %v", err) + } + if len(blocks) != 2 { + t.Fatalf("expected 2 blocks, got %d", len(blocks)) + } + paths := map[string]string{} + for _, b := range blocks { + paths[b.Name] = b.Path + } + if paths["DS1"] != "/mnt/ds1" { + t.Fatalf("DS1 path=%q", paths["DS1"]) + } + if paths["DS2"] != "/mnt/ds2" { + t.Fatalf("DS2 path=%q", paths["DS2"]) + } + if !strings.Contains(content, "comment primary") { + t.Fatalf("expected DS1 comment in generated content") + } +} + +func TestLoadPBSDatastoreCfgFromInventory_PropagatesErrors(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + stageRoot := "/stage" + inventoryPath := stageRoot + "/var/lib/proxsave-info/commands/pbs/pbs_datastore_inventory.json" + + if err := fakeFS.WriteFile(inventoryPath, []byte(" \n"), 0o640); err != nil { + t.Fatalf("write empty inventory: %v", err) + } + if _, _, err := loadPBSDatastoreCfgFromInventory(stageRoot); err == nil { + t.Fatalf("expected error for empty inventory") + } + + if err := fakeFS.WriteFile(inventoryPath, []byte("not-json"), 0o640); err != nil { + t.Fatalf("write invalid inventory: %v", err) + } + if _, _, err := loadPBSDatastoreCfgFromInventory(stageRoot); err == nil { + t.Fatalf("expected error for invalid JSON") + } + + if err := fakeFS.WriteFile(inventoryPath, []byte(`{"datastores":[{"name":"","path":"","comment":""}]}`), 0o640); err != nil { + t.Fatalf("write unusable inventory: %v", err) + } + if _, _, err := loadPBSDatastoreCfgFromInventory(stageRoot); err == nil { + t.Fatalf("expected error for unusable inventory") + } +} + +func TestDetectPBSDatastoreCfgDuplicateKeys_DetectsDuplicateKeys(t *testing.T) { + t.Parallel() + + blocks := []pbsDatastoreBlock{{ + Name: "DS1", + Lines: []string{ + "datastore: DS1", + "# comment", + "", + " path /mnt/a", + " path /mnt/b", + }, + }} + if reason := detectPBSDatastoreCfgDuplicateKeys(blocks); reason == "" { + t.Fatalf("expected duplicate key detection") + } +} + +func TestDetectPBSDatastoreCfgDuplicateKeys_AllowsUniqueKeys(t *testing.T) { + t.Parallel() + + blocks := []pbsDatastoreBlock{{ + Name: "DS1", + Lines: []string{ + "datastore: DS1", + " comment one", + " path /mnt/a", + }, + }} + if reason := detectPBSDatastoreCfgDuplicateKeys(blocks); reason != "" { + t.Fatalf("expected no duplicates, got %q", reason) + } +} + +func TestParsePBSDatastoreCfgBlocks_IgnoresGarbageAndHandlesMissingNames(t *testing.T) { + t.Parallel() + + content := strings.Join([]string{ + "path /should/be/ignored", + "datastore:", + " path /also/ignored", + "datastore: DS1", + "# keep comment", + "path /mnt/ds1", + "", + "datastore: DS2:", + " path /mnt/ds2", + "", + }, "\n") + + blocks, err := parsePBSDatastoreCfgBlocks(content) + if err != nil { + t.Fatalf("parsePBSDatastoreCfgBlocks: %v", err) + } + if len(blocks) != 2 { + t.Fatalf("expected 2 blocks, got %d", len(blocks)) + } + if blocks[0].Name != "DS1" || blocks[0].Path != "/mnt/ds1" { + t.Fatalf("block[0]=%+v", blocks[0]) + } + if blocks[1].Name != "DS2" || blocks[1].Path != "/mnt/ds2" { + t.Fatalf("block[1]=%+v", blocks[1]) + } + if gotLines := strings.Join(blocks[0].Lines, "\n"); !strings.Contains(gotLines, "# keep comment") { + t.Fatalf("expected DS1 block to retain comment line; got=%q", gotLines) + } +} + +func TestParsePBSDatastoreCfgBlocks_DropsEmptyNamedBlocks(t *testing.T) { + t.Parallel() + + content := strings.Join([]string{ + "datastore: :", + " path /mnt/ignored", + "", + "datastore: DS1", + " path /mnt/ds1", + "", + }, "\n") + + blocks, err := parsePBSDatastoreCfgBlocks(content) + if err != nil { + t.Fatalf("parsePBSDatastoreCfgBlocks: %v", err) + } + if len(blocks) != 1 { + t.Fatalf("expected 1 block, got %d", len(blocks)) + } + if blocks[0].Name != "DS1" { + t.Fatalf("block[0].Name=%q", blocks[0].Name) + } +} + +func TestShouldApplyPBSDatastoreBlock_CoversCommonBranches(t *testing.T) { + t.Parallel() + + if ok, reason := shouldApplyPBSDatastoreBlock(pbsDatastoreBlock{Name: "ds", Path: "/"}, newTestLogger()); ok || !strings.Contains(reason, "invalid") { + t.Fatalf("expected invalid path rejection, got ok=%v reason=%q", ok, reason) + } + + dsDir := t.TempDir() + if err := os.MkdirAll(filepath.Join(dsDir, ".chunks"), 0o755); err != nil { + t.Fatalf("mkdir .chunks: %v", err) + } + if err := os.WriteFile(filepath.Join(dsDir, ".chunks", "c1"), []byte("x"), 0o644); err != nil { + t.Fatalf("write chunk: %v", err) + } + if ok, reason := shouldApplyPBSDatastoreBlock(pbsDatastoreBlock{Name: "ds", Path: dsDir}, newTestLogger()); !ok { + t.Fatalf("expected hasData datastore to be applied, got ok=false reason=%q", reason) + } + + tooLong := "/" + strings.Repeat("a", 5000) + if ok, reason := shouldApplyPBSDatastoreBlock(pbsDatastoreBlock{Name: "ds", Path: tooLong}, newTestLogger()); ok || !strings.Contains(reason, "inspection failed") { + t.Fatalf("expected inspection failure, got ok=%v reason=%q", ok, reason) + } +} + +func TestWriteDeferredPBSDatastoreCfg_EmptyInputIsNoop(t *testing.T) { + t.Parallel() + + if path, err := writeDeferredPBSDatastoreCfg(nil); err != nil { + t.Fatalf("err=%v", err) + } else if path != "" { + t.Fatalf("expected empty path, got %q", path) + } +} + +func TestWriteDeferredPBSDatastoreCfg_WritesFile(t *testing.T) { + origFS := restoreFS + origTime := restoreTime + t.Cleanup(func() { + restoreFS = origFS + restoreTime = origTime + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + restoreTime = &FakeTime{Current: time.Date(2026, 1, 2, 3, 4, 5, 0, time.UTC)} + + blocks := []pbsDatastoreBlock{{ + Name: "DS1", + Path: "/mnt/ds1", + Lines: []string{"datastore: DS1", " path /mnt/ds1"}, + }} + + path, err := writeDeferredPBSDatastoreCfg(blocks) + if err != nil { + t.Fatalf("writeDeferredPBSDatastoreCfg: %v", err) + } + if path == "" { + t.Fatalf("expected non-empty path") + } + + raw, err := fakeFS.ReadFile(path) + if err != nil { + t.Fatalf("read deferred file: %v", err) + } + if !strings.Contains(string(raw), "datastore: DS1") { + t.Fatalf("unexpected deferred content: %q", string(raw)) + } +} + +func TestWriteDeferredPBSDatastoreCfg_PropagatesMkdirError(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + fakeFS.MkdirAllErr = errors.New("forced mkdir error") + + _, err := writeDeferredPBSDatastoreCfg([]pbsDatastoreBlock{{Name: "DS1", Path: "/mnt/ds1", Lines: []string{"datastore: DS1", " path /mnt/ds1"}}}) + if err == nil { + t.Fatalf("expected error, got nil") + } +} + +func TestWriteDeferredPBSDatastoreCfg_PropagatesWriteError(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + fakeFS.WriteErr = errors.New("forced write error") + + _, err := writeDeferredPBSDatastoreCfg([]pbsDatastoreBlock{{Name: "DS1", Path: "/mnt/ds1", Lines: []string{"datastore: DS1", " path /mnt/ds1"}}}) + if err == nil { + t.Fatalf("expected error, got nil") + } +} + +func TestWriteDeferredPBSDatastoreCfg_MultipleBlocksAddsSeparator(t *testing.T) { + origFS := restoreFS + origTime := restoreTime + t.Cleanup(func() { + restoreFS = origFS + restoreTime = origTime + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + restoreTime = &FakeTime{Current: time.Date(2026, 1, 2, 3, 4, 6, 0, time.UTC)} + + blocks := []pbsDatastoreBlock{ + {Name: "DS1", Lines: []string{"datastore: DS1", " path /mnt/ds1"}}, + {Name: "DS2", Lines: []string{"datastore: DS2", " path /mnt/ds2"}}, + } + + path, err := writeDeferredPBSDatastoreCfg(blocks) + if err != nil { + t.Fatalf("writeDeferredPBSDatastoreCfg: %v", err) + } + + raw, err := fakeFS.ReadFile(path) + if err != nil { + t.Fatalf("read deferred file: %v", err) + } + if !strings.Contains(string(raw), "datastore: DS1\n path /mnt/ds1\n\ndatastore: DS2\n path /mnt/ds2") { + t.Fatalf("expected blank line separator between blocks; got=%q", string(raw)) + } +} + +func TestApplyPBSDatastoreCfgFromStage_SkipsMissingStagedFile(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + if err := applyPBSDatastoreCfgFromStage(context.Background(), newTestLogger(), "/stage"); err != nil { + t.Fatalf("applyPBSDatastoreCfgFromStage: %v", err) + } + if _, err := fakeFS.Stat("/etc/proxmox-backup/datastore.cfg"); err == nil { + t.Fatalf("expected datastore.cfg not to be created") + } +} + +func TestApplyPBSDatastoreCfgFromStage_RemovesTargetWhenStagedEmpty(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + if err := fakeFS.WriteFile("/etc/proxmox-backup/datastore.cfg", []byte("datastore: old\n path /mnt/old\n"), 0o640); err != nil { + t.Fatalf("write existing datastore.cfg: %v", err) + } + if err := fakeFS.WriteFile("/stage/etc/proxmox-backup/datastore.cfg", []byte(" \n"), 0o640); err != nil { + t.Fatalf("write staged datastore.cfg: %v", err) + } + + if err := applyPBSDatastoreCfgFromStage(context.Background(), newTestLogger(), "/stage"); err != nil { + t.Fatalf("applyPBSDatastoreCfgFromStage: %v", err) + } + if _, err := fakeFS.Stat("/etc/proxmox-backup/datastore.cfg"); err == nil { + t.Fatalf("expected datastore.cfg removed") + } +} + +func TestApplyPBSDatastoreCfgFromStage_DefersUnsafeAndAppliesSafe(t *testing.T) { + origFS := restoreFS + origTime := restoreTime + t.Cleanup(func() { + restoreFS = origFS + restoreTime = origTime + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + restoreTime = &FakeTime{Current: time.Date(2026, 2, 3, 4, 5, 6, 0, time.UTC)} + + safeDir := t.TempDir() + unsafeDir := t.TempDir() + if err := os.WriteFile(filepath.Join(unsafeDir, "unexpected"), []byte("x"), 0o644); err != nil { + t.Fatalf("write unexpected file: %v", err) + } + + stageRoot := "/stage" + staged := strings.Join([]string{ + "datastore: Safe", + fmt.Sprintf("path %s", safeDir), + "", + "datastore: Unsafe", + fmt.Sprintf("path %s", unsafeDir), + "", + }, "\n") + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/datastore.cfg", []byte(staged), 0o640); err != nil { + t.Fatalf("write staged datastore.cfg: %v", err) + } + + if err := applyPBSDatastoreCfgFromStage(context.Background(), newTestLogger(), stageRoot); err != nil { + t.Fatalf("applyPBSDatastoreCfgFromStage: %v", err) + } + + out, err := fakeFS.ReadFile("/etc/proxmox-backup/datastore.cfg") + if err != nil { + t.Fatalf("read applied datastore.cfg: %v", err) + } + if !strings.Contains(string(out), "datastore: Safe") { + t.Fatalf("expected Safe datastore in output: %q", string(out)) + } + if strings.Contains(string(out), "datastore: Unsafe") { + t.Fatalf("did not expect Unsafe datastore in output: %q", string(out)) + } + + entries, err := fakeFS.ReadDir("/tmp/proxsave") + if err != nil { + t.Fatalf("readdir deferred dir: %v", err) + } + if len(entries) != 1 { + t.Fatalf("expected 1 deferred file, got %d", len(entries)) + } + deferredPath := filepath.Join("/tmp/proxsave", entries[0].Name()) + deferred, err := fakeFS.ReadFile(deferredPath) + if err != nil { + t.Fatalf("read deferred file: %v", err) + } + if !strings.Contains(string(deferred), "datastore: Unsafe") { + t.Fatalf("expected Unsafe datastore deferred: %q", string(deferred)) + } +} + +func TestApplyPBSDatastoreCfgFromStage_AllDeferredLeavesTargetUnchanged(t *testing.T) { + origFS := restoreFS + origTime := restoreTime + t.Cleanup(func() { + restoreFS = origFS + restoreTime = origTime + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + restoreTime = &FakeTime{Current: time.Date(2026, 2, 3, 4, 5, 7, 0, time.UTC)} + + existing := "datastore: Existing\n path /mnt/existing\n" + if err := fakeFS.WriteFile("/etc/proxmox-backup/datastore.cfg", []byte(existing), 0o640); err != nil { + t.Fatalf("write existing datastore.cfg: %v", err) + } + + unsafeDir := t.TempDir() + if err := os.WriteFile(filepath.Join(unsafeDir, "unexpected"), []byte("x"), 0o644); err != nil { + t.Fatalf("write unexpected file: %v", err) + } + + stageRoot := "/stage" + staged := strings.Join([]string{ + "datastore: UnsafeOnly", + fmt.Sprintf("path %s", unsafeDir), + "", + }, "\n") + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/datastore.cfg", []byte(staged), 0o640); err != nil { + t.Fatalf("write staged datastore.cfg: %v", err) + } + + if err := applyPBSDatastoreCfgFromStage(context.Background(), newTestLogger(), stageRoot); err != nil { + t.Fatalf("applyPBSDatastoreCfgFromStage: %v", err) + } + + got, err := fakeFS.ReadFile("/etc/proxmox-backup/datastore.cfg") + if err != nil { + t.Fatalf("read datastore.cfg: %v", err) + } + if string(got) != existing { + t.Fatalf("datastore.cfg changed unexpectedly: got=%q want=%q", string(got), existing) + } +} + +func TestApplyPBSDatastoreCfgFromStage_DuplicateKeysWithoutInventoryLeavesTargetUnchanged(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + existing := "datastore: Existing\n path /mnt/existing\n" + if err := fakeFS.WriteFile("/etc/proxmox-backup/datastore.cfg", []byte(existing), 0o640); err != nil { + t.Fatalf("write existing datastore.cfg: %v", err) + } + + stageRoot := "/stage" + staged := strings.Join([]string{ + "datastore: Broken", + " path /mnt/a", + " path /mnt/b", + "", + }, "\n") + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/datastore.cfg", []byte(staged), 0o640); err != nil { + t.Fatalf("write staged datastore.cfg: %v", err) + } + + if err := applyPBSDatastoreCfgFromStage(context.Background(), newTestLogger(), stageRoot); err != nil { + t.Fatalf("applyPBSDatastoreCfgFromStage: %v", err) + } + + got, err := fakeFS.ReadFile("/etc/proxmox-backup/datastore.cfg") + if err != nil { + t.Fatalf("read datastore.cfg: %v", err) + } + if string(got) != existing { + t.Fatalf("datastore.cfg changed unexpectedly: got=%q want=%q", string(got), existing) + } +} + +func TestApplyPBSDatastoreCfgFromStage_SkipsWhenNoBlocksDetected(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + stageRoot := "/stage" + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/datastore.cfg", []byte("# only comments\n\n# nothing else\n"), 0o640); err != nil { + t.Fatalf("write staged datastore.cfg: %v", err) + } + + if err := applyPBSDatastoreCfgFromStage(context.Background(), newTestLogger(), stageRoot); err != nil { + t.Fatalf("applyPBSDatastoreCfgFromStage: %v", err) + } + if _, err := fakeFS.Stat("/etc/proxmox-backup/datastore.cfg"); err == nil { + t.Fatalf("expected datastore.cfg not to be created") + } +} + +func TestApplyPBSDatastoreCfgFromStage_PropagatesReadError(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + stageRoot := "/stage" + if err := fakeFS.MkdirAll(stageRoot+"/etc/proxmox-backup/datastore.cfg", 0o755); err != nil { + t.Fatalf("mkdir staged datastore.cfg dir: %v", err) + } + + if err := applyPBSDatastoreCfgFromStage(context.Background(), newTestLogger(), stageRoot); err == nil { + t.Fatalf("expected error, got nil") + } +} + +func TestApplyPBSDatastoreCfgFromStage_ContinuesWhenDeferredWriteFails(t *testing.T) { + origFS := restoreFS + origTime := restoreTime + t.Cleanup(func() { + restoreFS = origFS + restoreTime = origTime + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + restoreTime = &FakeTime{Current: time.Unix(10, 0)} + + safeDir := t.TempDir() + unsafeDir := t.TempDir() + if err := os.WriteFile(filepath.Join(unsafeDir, "unexpected"), []byte("x"), 0o644); err != nil { + t.Fatalf("write unexpected file: %v", err) + } + + stageRoot := "/stage" + staged := strings.Join([]string{ + "datastore: Safe", + fmt.Sprintf("path %s", safeDir), + "", + "datastore: Unsafe", + fmt.Sprintf("path %s", unsafeDir), + "", + }, "\n") + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/datastore.cfg", []byte(staged), 0o640); err != nil { + t.Fatalf("write staged datastore.cfg: %v", err) + } + + // Fail deferred file writes but still allow atomic apply. + fakeFS.WriteErr = errors.New("forced write error") + + if err := applyPBSDatastoreCfgFromStage(context.Background(), newTestLogger(), stageRoot); err != nil { + t.Fatalf("applyPBSDatastoreCfgFromStage: %v", err) + } + + out, err := fakeFS.ReadFile("/etc/proxmox-backup/datastore.cfg") + if err != nil { + t.Fatalf("read applied datastore.cfg: %v", err) + } + if !strings.Contains(string(out), "datastore: Safe") || strings.Contains(string(out), "datastore: Unsafe") { + t.Fatalf("unexpected apply result: %q", string(out)) + } + + if entries, err := fakeFS.ReadDir("/tmp/proxsave"); err != nil { + t.Fatalf("readdir /tmp/proxsave: %v", err) + } else if len(entries) != 0 { + t.Fatalf("expected no deferred files due to forced write error, got %d", len(entries)) + } +} + +func TestApplyPBSDatastoreCfgFromStage_ReturnsErrorOnAtomicWriteFailure(t *testing.T) { + origFS := restoreFS + origTime := restoreTime + t.Cleanup(func() { + restoreFS = origFS + restoreTime = origTime + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + restoreTime = &FakeTime{Current: time.Unix(10, 0)} + + safeDir := t.TempDir() + stageRoot := "/stage" + staged := strings.Join([]string{ + "datastore: DS1", + fmt.Sprintf("path %s", safeDir), + "", + }, "\n") + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/datastore.cfg", []byte(staged), 0o640); err != nil { + t.Fatalf("write staged datastore.cfg: %v", err) + } + + dest := "/etc/proxmox-backup/datastore.cfg" + tmp := fmt.Sprintf("%s.proxsave.tmp.%d", dest, nowRestore().UnixNano()) + fakeFS.OpenFileErr[filepath.Clean(tmp)] = errors.New("forced OpenFile error") + + if err := applyPBSDatastoreCfgFromStage(context.Background(), newTestLogger(), stageRoot); err == nil { + t.Fatalf("expected error, got nil") + } +} + +func TestApplyPBSJobConfigsFromStage_WritesAllJobConfigs(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + stageRoot := "/stage" + for name, content := range map[string]string{ + "sync.cfg": "sync: job1\n remote r1\n", + "verification.cfg": "verification: v1\n datastore ds1\n", + "prune.cfg": "prune: p1\n keep-last 1\n", + } { + if err := fakeFS.WriteFile(filepath.Join(stageRoot, "etc/proxmox-backup", name), []byte(content), 0o640); err != nil { + t.Fatalf("write staged %s: %v", name, err) + } + } + + if err := applyPBSJobConfigsFromStage(context.Background(), newTestLogger(), stageRoot); err != nil { + t.Fatalf("applyPBSJobConfigsFromStage: %v", err) + } + + for _, name := range []string{"sync.cfg", "verification.cfg", "prune.cfg"} { + dest := filepath.Join("/etc/proxmox-backup", name) + if _, err := fakeFS.Stat(dest); err != nil { + t.Fatalf("expected %s to exist: %v", dest, err) + } + } +} + +func TestApplyPBSJobConfigsFromStage_ContinuesOnApplyErrors(t *testing.T) { + origFS := restoreFS + origTime := restoreTime + t.Cleanup(func() { + restoreFS = origFS + restoreTime = origTime + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + restoreTime = &FakeTime{Current: time.Unix(10, 0)} + + stageRoot := "/stage" + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/sync.cfg", []byte("sync: job1\n remote r1\n"), 0o640); err != nil { + t.Fatalf("write staged sync.cfg: %v", err) + } + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/verification.cfg", []byte("verification: v1\n datastore ds1\n"), 0o640); err != nil { + t.Fatalf("write staged verification.cfg: %v", err) + } + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/prune.cfg", []byte("prune: p1\n keep-last 1\n"), 0o640); err != nil { + t.Fatalf("write staged prune.cfg: %v", err) + } + + relFail := "etc/proxmox-backup/verification.cfg" + destFail := filepath.Join(string(os.PathSeparator), filepath.FromSlash(relFail)) + tmpFail := fmt.Sprintf("%s.proxsave.tmp.%d", destFail, nowRestore().UnixNano()) + fakeFS.OpenFileErr[filepath.Clean(tmpFail)] = errors.New("forced OpenFile error") + + if err := applyPBSJobConfigsFromStage(context.Background(), newTestLogger(), stageRoot); err != nil { + t.Fatalf("applyPBSJobConfigsFromStage: %v", err) + } + + if _, err := fakeFS.Stat("/etc/proxmox-backup/sync.cfg"); err != nil { + t.Fatalf("expected sync.cfg to exist: %v", err) + } + if _, err := fakeFS.Stat("/etc/proxmox-backup/prune.cfg"); err != nil { + t.Fatalf("expected prune.cfg to exist: %v", err) + } + if _, err := fakeFS.Stat("/etc/proxmox-backup/verification.cfg"); err == nil { + t.Fatalf("expected verification.cfg not to be created due to forced error") + } +} + +func TestApplyPBSTapeConfigsFromStage_WritesConfigsAndSensitiveKeys(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + stageRoot := "/stage" + for name, content := range map[string]string{ + "tape.cfg": "drive: d1\n path /dev/nst0\n", + "tape-job.cfg": "tape-job: job1\n drive d1\n", + "media-pool.cfg": "media-pool: pool1\n retention 30\n", + } { + if err := fakeFS.WriteFile(filepath.Join(stageRoot, "etc/proxmox-backup", name), []byte(content), 0o640); err != nil { + t.Fatalf("write staged %s: %v", name, err) + } + } + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/tape-encryption-keys.json", []byte(`{"keys":[{"fingerprint":"abc"}]}`), 0o640); err != nil { + t.Fatalf("write staged tape-encryption-keys.json: %v", err) + } + + if err := applyPBSTapeConfigsFromStage(context.Background(), newTestLogger(), stageRoot); err != nil { + t.Fatalf("applyPBSTapeConfigsFromStage: %v", err) + } + + for _, name := range []string{"tape.cfg", "tape-job.cfg", "media-pool.cfg"} { + dest := filepath.Join("/etc/proxmox-backup", name) + if _, err := fakeFS.Stat(dest); err != nil { + t.Fatalf("expected %s to exist: %v", dest, err) + } + } + + if info, err := fakeFS.Stat("/etc/proxmox-backup/tape-encryption-keys.json"); err != nil { + t.Fatalf("stat tape-encryption-keys.json: %v", err) + } else if info.Mode().Perm() != 0o600 { + t.Fatalf("tape-encryption-keys.json mode=%#o want %#o", info.Mode().Perm(), 0o600) + } +} + +func TestApplyPBSTapeConfigsFromStage_ContinuesOnErrors(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + stageRoot := "/stage" + // Force applyPBSConfigFileFromStage and applySensitiveFileFromStage errors (ReadFile on directory). + if err := fakeFS.MkdirAll(stageRoot+"/etc/proxmox-backup/tape.cfg", 0o755); err != nil { + t.Fatalf("mkdir staged tape.cfg dir: %v", err) + } + if err := fakeFS.MkdirAll(stageRoot+"/etc/proxmox-backup/tape-encryption-keys.json", 0o755); err != nil { + t.Fatalf("mkdir staged tape-encryption-keys.json dir: %v", err) + } + + if err := applyPBSTapeConfigsFromStage(context.Background(), newTestLogger(), stageRoot); err != nil { + t.Fatalf("applyPBSTapeConfigsFromStage: %v", err) + } +} + +func TestRemoveIfExists_IgnoresMissingAndRemovesExisting(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + if err := removeIfExists("/etc/proxmox-backup/missing.cfg"); err != nil { + t.Fatalf("removeIfExists missing: %v", err) + } + + if err := fakeFS.WriteFile("/etc/proxmox-backup/existing.cfg", []byte("x"), 0o640); err != nil { + t.Fatalf("write existing.cfg: %v", err) + } + if err := removeIfExists("/etc/proxmox-backup/existing.cfg"); err != nil { + t.Fatalf("removeIfExists existing: %v", err) + } + if _, err := fakeFS.Stat("/etc/proxmox-backup/existing.cfg"); err == nil { + t.Fatalf("expected existing.cfg removed") + } +} + +func TestRemoveIfExists_PropagatesNonExistErrors(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + // Make Remove fail with a non-ENOENT error. + if err := fakeFS.MkdirAll("/etc/proxmox-backup/dir", 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + if err := fakeFS.WriteFile("/etc/proxmox-backup/dir/file", []byte("x"), 0o640); err != nil { + t.Fatalf("write: %v", err) + } + + if err := removeIfExists("/etc/proxmox-backup/dir"); err == nil { + t.Fatalf("expected error, got nil") + } +} diff --git a/internal/orchestrator/pbs_staged_apply_maybeapply_test.go b/internal/orchestrator/pbs_staged_apply_maybeapply_test.go new file mode 100644 index 0000000..3c4b5d1 --- /dev/null +++ b/internal/orchestrator/pbs_staged_apply_maybeapply_test.go @@ -0,0 +1,407 @@ +package orchestrator + +import ( + "context" + "errors" + "os" + "testing" + + "github.com/tis24dev/proxsave/internal/logging" +) + +func TestMaybeApplyPBSConfigsFromStage_SkipsWhenNonRoot(t *testing.T) { + origFS := restoreFS + origIsReal := pbsStagedApplyIsRealRestoreFSFn + origGeteuid := pbsStagedApplyGeteuidFn + t.Cleanup(func() { + restoreFS = origFS + pbsStagedApplyIsRealRestoreFSFn = origIsReal + pbsStagedApplyGeteuidFn = origGeteuid + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + pbsStagedApplyIsRealRestoreFSFn = func(FS) bool { return true } + pbsStagedApplyGeteuidFn = func() int { return 1000 } + + stageRoot := "/stage" + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/acme/accounts.cfg", []byte("account: a1\n foo bar\n"), 0o640); err != nil { + t.Fatalf("write staged accounts.cfg: %v", err) + } + + plan := &RestorePlan{ + SystemType: SystemTypePBS, + PBSRestoreBehavior: PBSRestoreBehaviorClean, + NormalCategories: []Category{{ID: "pbs_host"}}, + } + if err := maybeApplyPBSConfigsFromStage(context.Background(), newTestLogger(), plan, stageRoot, false); err != nil { + t.Fatalf("maybeApplyPBSConfigsFromStage: %v", err) + } + + if _, err := fakeFS.Stat("/etc/proxmox-backup/acme/accounts.cfg"); err == nil { + t.Fatalf("expected no writes when non-root") + } +} + +func TestMaybeApplyPBSConfigsFromStage_CleanMode_ApiUnavailableFallsBackToFiles(t *testing.T) { + origFS := restoreFS + origIsReal := pbsStagedApplyIsRealRestoreFSFn + origGeteuid := pbsStagedApplyGeteuidFn + origEnsure := pbsStagedApplyEnsurePBSServicesForAPIFn + t.Cleanup(func() { + restoreFS = origFS + pbsStagedApplyIsRealRestoreFSFn = origIsReal + pbsStagedApplyGeteuidFn = origGeteuid + pbsStagedApplyEnsurePBSServicesForAPIFn = origEnsure + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + pbsStagedApplyIsRealRestoreFSFn = func(FS) bool { return true } + pbsStagedApplyGeteuidFn = func() int { return 0 } + pbsStagedApplyEnsurePBSServicesForAPIFn = func(context.Context, *logging.Logger) error { + return errors.New("forced API unavailable") + } + + stageRoot := "/stage" + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/acme/accounts.cfg", []byte("account: a1\n foo bar\n"), 0o640); err != nil { + t.Fatalf("write staged accounts.cfg: %v", err) + } + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/traffic-control.cfg", []byte("traffic-control: tc1\n rate 10mbit\n"), 0o640); err != nil { + t.Fatalf("write staged traffic-control.cfg: %v", err) + } + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/node.cfg", []byte("node: n1\n description test\n"), 0o640); err != nil { + t.Fatalf("write staged node.cfg: %v", err) + } + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/s3.cfg", []byte("s3: r1\n bucket test\n"), 0o640); err != nil { + t.Fatalf("write staged s3.cfg: %v", err) + } + + safeDir := t.TempDir() + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/datastore.cfg", []byte("datastore: DS1\npath "+safeDir+"\n"), 0o640); err != nil { + t.Fatalf("write staged datastore.cfg: %v", err) + } + + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/remote.cfg", []byte("remote: r1\n host 10.0.0.10\n"), 0o640); err != nil { + t.Fatalf("write staged remote.cfg: %v", err) + } + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/sync.cfg", []byte("sync: job1\n remote r1\n"), 0o640); err != nil { + t.Fatalf("write staged sync.cfg: %v", err) + } + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/verification.cfg", []byte("verification: v1\n datastore DS1\n"), 0o640); err != nil { + t.Fatalf("write staged verification.cfg: %v", err) + } + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/prune.cfg", []byte("prune: p1\n keep-last 1\n"), 0o640); err != nil { + t.Fatalf("write staged prune.cfg: %v", err) + } + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/tape.cfg", []byte("drive: d1\n path /dev/nst0\n"), 0o640); err != nil { + t.Fatalf("write staged tape.cfg: %v", err) + } + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/tape-job.cfg", []byte("tape-job: job1\n drive d1\n"), 0o640); err != nil { + t.Fatalf("write staged tape-job.cfg: %v", err) + } + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/media-pool.cfg", []byte("media-pool: pool1\n retention 30\n"), 0o640); err != nil { + t.Fatalf("write staged media-pool.cfg: %v", err) + } + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/tape-encryption-keys.json", []byte(`{"keys":[{"fingerprint":"abc"}]}`), 0o640); err != nil { + t.Fatalf("write staged tape-encryption-keys.json: %v", err) + } + + plan := &RestorePlan{ + SystemType: SystemTypePBS, + PBSRestoreBehavior: PBSRestoreBehaviorClean, + NormalCategories: []Category{ + {ID: "pbs_host"}, + {ID: "datastore_pbs"}, + {ID: "pbs_remotes"}, + {ID: "pbs_jobs"}, + {ID: "pbs_tape"}, + }, + } + + if err := maybeApplyPBSConfigsFromStage(context.Background(), newTestLogger(), plan, stageRoot, false); err != nil { + t.Fatalf("maybeApplyPBSConfigsFromStage: %v", err) + } + + for _, path := range []string{ + "/etc/proxmox-backup/acme/accounts.cfg", + "/etc/proxmox-backup/traffic-control.cfg", + "/etc/proxmox-backup/node.cfg", + "/etc/proxmox-backup/s3.cfg", + "/etc/proxmox-backup/datastore.cfg", + "/etc/proxmox-backup/remote.cfg", + "/etc/proxmox-backup/sync.cfg", + "/etc/proxmox-backup/verification.cfg", + "/etc/proxmox-backup/prune.cfg", + "/etc/proxmox-backup/tape-encryption-keys.json", + } { + if _, err := fakeFS.Stat(path); err != nil { + t.Fatalf("expected %s to exist: %v", path, err) + } + } +} + +func TestMaybeApplyPBSConfigsFromStage_MergeMode_ApiUnavailableSkipsApiCategories(t *testing.T) { + origFS := restoreFS + origIsReal := pbsStagedApplyIsRealRestoreFSFn + origGeteuid := pbsStagedApplyGeteuidFn + origEnsure := pbsStagedApplyEnsurePBSServicesForAPIFn + t.Cleanup(func() { + restoreFS = origFS + pbsStagedApplyIsRealRestoreFSFn = origIsReal + pbsStagedApplyGeteuidFn = origGeteuid + pbsStagedApplyEnsurePBSServicesForAPIFn = origEnsure + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + pbsStagedApplyIsRealRestoreFSFn = func(FS) bool { return true } + pbsStagedApplyGeteuidFn = func() int { return 0 } + pbsStagedApplyEnsurePBSServicesForAPIFn = func(context.Context, *logging.Logger) error { + return errors.New("forced API unavailable") + } + + stageRoot := "/stage" + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/acme/accounts.cfg", []byte("account: a1\n foo bar\n"), 0o640); err != nil { + t.Fatalf("write staged accounts.cfg: %v", err) + } + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/node.cfg", []byte("node: n1\n description test\n"), 0o640); err != nil { + t.Fatalf("write staged node.cfg: %v", err) + } + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/s3.cfg", []byte("s3: r1\n bucket test\n"), 0o640); err != nil { + t.Fatalf("write staged s3.cfg: %v", err) + } + + plan := &RestorePlan{ + SystemType: SystemTypePBS, + PBSRestoreBehavior: PBSRestoreBehaviorMerge, + NormalCategories: []Category{ + {ID: "pbs_host"}, + {ID: "datastore_pbs"}, + {ID: "pbs_remotes"}, + {ID: "pbs_jobs"}, + }, + } + + if err := maybeApplyPBSConfigsFromStage(context.Background(), newTestLogger(), plan, stageRoot, false); err != nil { + t.Fatalf("maybeApplyPBSConfigsFromStage: %v", err) + } + + if _, err := fakeFS.Stat("/etc/proxmox-backup/acme/accounts.cfg"); err != nil { + t.Fatalf("expected accounts.cfg to exist: %v", err) + } + + // Merge mode requires API for these categories, so they must not be file-applied. + for _, path := range []string{ + "/etc/proxmox-backup/node.cfg", + "/etc/proxmox-backup/s3.cfg", + "/etc/proxmox-backup/datastore.cfg", + "/etc/proxmox-backup/remote.cfg", + "/etc/proxmox-backup/sync.cfg", + } { + if _, err := fakeFS.Stat(path); err == nil { + t.Fatalf("did not expect %s to be applied in merge mode without API", path) + } + } +} + +func TestMaybeApplyPBSConfigsFromStage_ApiErrorsTriggerFallbackOnlyInCleanMode(t *testing.T) { + origFS := restoreFS + origIsReal := pbsStagedApplyIsRealRestoreFSFn + origGeteuid := pbsStagedApplyGeteuidFn + origEnsure := pbsStagedApplyEnsurePBSServicesForAPIFn + origTraffic := pbsStagedApplyTrafficControlCfgViaAPIFn + origNode := pbsStagedApplyNodeCfgViaAPIFn + origS3 := pbsStagedApplyS3CfgViaAPIFn + origDS := pbsStagedApplyDatastoreCfgViaAPIFn + origRemote := pbsStagedApplyRemoteCfgViaAPIFn + origSync := pbsStagedApplySyncCfgViaAPIFn + origVerify := pbsStagedApplyVerificationCfgViaAPIFn + origPrune := pbsStagedApplyPruneCfgViaAPIFn + t.Cleanup(func() { + restoreFS = origFS + pbsStagedApplyIsRealRestoreFSFn = origIsReal + pbsStagedApplyGeteuidFn = origGeteuid + pbsStagedApplyEnsurePBSServicesForAPIFn = origEnsure + pbsStagedApplyTrafficControlCfgViaAPIFn = origTraffic + pbsStagedApplyNodeCfgViaAPIFn = origNode + pbsStagedApplyS3CfgViaAPIFn = origS3 + pbsStagedApplyDatastoreCfgViaAPIFn = origDS + pbsStagedApplyRemoteCfgViaAPIFn = origRemote + pbsStagedApplySyncCfgViaAPIFn = origSync + pbsStagedApplyVerificationCfgViaAPIFn = origVerify + pbsStagedApplyPruneCfgViaAPIFn = origPrune + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + pbsStagedApplyIsRealRestoreFSFn = func(FS) bool { return true } + pbsStagedApplyGeteuidFn = func() int { return 0 } + pbsStagedApplyEnsurePBSServicesForAPIFn = func(context.Context, *logging.Logger) error { return nil } + + var strictArgsClean []bool + var strictArgsMerge []bool + strictSink := func(strict bool) { + strictArgsClean = append(strictArgsClean, strict) + } + pbsStagedApplyTrafficControlCfgViaAPIFn = func(_ context.Context, _ *logging.Logger, _ string, strict bool) error { + strictSink(strict) + return errors.New("forced API error") + } + pbsStagedApplyNodeCfgViaAPIFn = func(context.Context, string) error { return errors.New("forced API error") } + pbsStagedApplyS3CfgViaAPIFn = func(_ context.Context, _ *logging.Logger, _ string, strict bool) error { + strictSink(strict) + return errors.New("forced API error") + } + pbsStagedApplyDatastoreCfgViaAPIFn = func(_ context.Context, _ *logging.Logger, _ string, strict bool) error { + strictSink(strict) + return errors.New("forced API error") + } + pbsStagedApplyRemoteCfgViaAPIFn = func(_ context.Context, _ *logging.Logger, _ string, strict bool) error { + strictSink(strict) + return errors.New("forced API error") + } + pbsStagedApplySyncCfgViaAPIFn = func(_ context.Context, _ *logging.Logger, _ string, strict bool) error { + strictSink(strict) + return errors.New("forced API error") + } + pbsStagedApplyVerificationCfgViaAPIFn = func(_ context.Context, _ *logging.Logger, _ string, strict bool) error { + strictSink(strict) + return errors.New("forced API error") + } + pbsStagedApplyPruneCfgViaAPIFn = func(_ context.Context, _ *logging.Logger, _ string, strict bool) error { + strictSink(strict) + return errors.New("forced API error") + } + + stageRoot := "/stage" + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/acme/accounts.cfg", []byte("account: a1\n foo bar\n"), 0o640); err != nil { + t.Fatalf("write staged accounts.cfg: %v", err) + } + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/traffic-control.cfg", []byte("traffic-control: tc1\n rate 10mbit\n"), 0o640); err != nil { + t.Fatalf("write staged traffic-control.cfg: %v", err) + } + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/node.cfg", []byte("node: n1\n description test\n"), 0o640); err != nil { + t.Fatalf("write staged node.cfg: %v", err) + } + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/s3.cfg", []byte("s3: r1\n bucket test\n"), 0o640); err != nil { + t.Fatalf("write staged s3.cfg: %v", err) + } + + safeDir := t.TempDir() + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/datastore.cfg", []byte("datastore: DS1\npath "+safeDir+"\n"), 0o640); err != nil { + t.Fatalf("write staged datastore.cfg: %v", err) + } + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/remote.cfg", []byte("remote: r1\n host 10.0.0.10\n"), 0o640); err != nil { + t.Fatalf("write staged remote.cfg: %v", err) + } + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/sync.cfg", []byte("sync: job1\n remote r1\n"), 0o640); err != nil { + t.Fatalf("write staged sync.cfg: %v", err) + } + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/verification.cfg", []byte("verification: v1\n datastore DS1\n"), 0o640); err != nil { + t.Fatalf("write staged verification.cfg: %v", err) + } + if err := fakeFS.WriteFile(stageRoot+"/etc/proxmox-backup/prune.cfg", []byte("prune: p1\n keep-last 1\n"), 0o640); err != nil { + t.Fatalf("write staged prune.cfg: %v", err) + } + + planClean := &RestorePlan{ + SystemType: SystemTypePBS, + PBSRestoreBehavior: PBSRestoreBehaviorClean, + NormalCategories: []Category{ + {ID: "pbs_host"}, + {ID: "datastore_pbs"}, + {ID: "pbs_remotes"}, + {ID: "pbs_jobs"}, + }, + } + if err := maybeApplyPBSConfigsFromStage(context.Background(), newTestLogger(), planClean, stageRoot, false); err != nil { + t.Fatalf("maybeApplyPBSConfigsFromStage clean: %v", err) + } + + for _, path := range []string{ + "/etc/proxmox-backup/traffic-control.cfg", + "/etc/proxmox-backup/node.cfg", + "/etc/proxmox-backup/s3.cfg", + "/etc/proxmox-backup/datastore.cfg", + "/etc/proxmox-backup/remote.cfg", + "/etc/proxmox-backup/sync.cfg", + } { + if _, err := fakeFS.Stat(path); err != nil { + t.Fatalf("expected %s to exist in clean fallback mode: %v", path, err) + } + } + + if len(strictArgsClean) == 0 { + t.Fatalf("expected strict API calls in clean mode") + } + for _, strict := range strictArgsClean { + if !strict { + t.Fatalf("expected strict=true in clean mode, got false") + } + } + + // In merge mode, the same API failures must not trigger file-based fallbacks. + strictSink = func(strict bool) { + strictArgsMerge = append(strictArgsMerge, strict) + } + fakeFS2 := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS2.Root) }) + restoreFS = fakeFS2 + if err := fakeFS2.WriteFile(stageRoot+"/etc/proxmox-backup/acme/accounts.cfg", []byte("account: a1\n foo bar\n"), 0o640); err != nil { + t.Fatalf("write staged accounts.cfg (merge): %v", err) + } + if err := fakeFS2.WriteFile(stageRoot+"/etc/proxmox-backup/node.cfg", []byte("node: n1\n description test\n"), 0o640); err != nil { + t.Fatalf("write staged node.cfg (merge): %v", err) + } + if err := fakeFS2.WriteFile(stageRoot+"/etc/proxmox-backup/s3.cfg", []byte("s3: r1\n bucket test\n"), 0o640); err != nil { + t.Fatalf("write staged s3.cfg (merge): %v", err) + } + if err := fakeFS2.WriteFile(stageRoot+"/etc/proxmox-backup/remote.cfg", []byte("remote: r1\n host 10.0.0.10\n"), 0o640); err != nil { + t.Fatalf("write staged remote.cfg (merge): %v", err) + } + if err := fakeFS2.WriteFile(stageRoot+"/etc/proxmox-backup/sync.cfg", []byte("sync: job1\n remote r1\n"), 0o640); err != nil { + t.Fatalf("write staged sync.cfg (merge): %v", err) + } + + planMerge := &RestorePlan{ + SystemType: SystemTypePBS, + PBSRestoreBehavior: PBSRestoreBehaviorMerge, + NormalCategories: []Category{ + {ID: "pbs_host"}, + {ID: "datastore_pbs"}, + {ID: "pbs_remotes"}, + {ID: "pbs_jobs"}, + }, + } + if err := maybeApplyPBSConfigsFromStage(context.Background(), newTestLogger(), planMerge, stageRoot, false); err != nil { + t.Fatalf("maybeApplyPBSConfigsFromStage merge: %v", err) + } + for _, path := range []string{ + "/etc/proxmox-backup/node.cfg", + "/etc/proxmox-backup/s3.cfg", + "/etc/proxmox-backup/remote.cfg", + "/etc/proxmox-backup/sync.cfg", + } { + if _, err := fakeFS2.Stat(path); err == nil { + t.Fatalf("did not expect %s to be file-applied in merge mode", path) + } + } + + if len(strictArgsMerge) == 0 { + t.Fatalf("expected strict API calls in merge mode") + } + for _, strict := range strictArgsMerge { + if strict { + t.Fatalf("expected strict=false in merge mode, got true") + } + } +} From eb2f074faa831c8c1d85cc19a4c914b4656f2a29 Mon Sep 17 00:00:00 2001 From: Damiano <71268257+tis24dev@users.noreply.github.com> Date: Tue, 24 Feb 2026 18:20:30 +0100 Subject: [PATCH 09/12] Prefill TUI install wizard from template Pass baseTemplate into the TUI install flow and prefill form fields when editing an existing env template. Add installWizardPrefill and deriveInstallWizardPrefill to parse env templates (parseEnvTemplate, readTemplateString, readTemplateBool) and set initial values for secondary/cloud/firewall/notifications/encryption. Set dropdown defaults via boolToOptionIndex, trim input values, and tighten rclone validation to reject empty backup/log entries. Preserve existing TELEGRAM and EMAIL delivery preferences when applying install data to an existing template. Misc: adjust some field labels, add bufio import, and update the cmd/proxsave call to forward baseTemplate. --- cmd/proxsave/install_tui.go | 2 +- internal/tui/wizard/install.go | 177 ++++++++++++++++++++++++++++----- 2 files changed, 151 insertions(+), 28 deletions(-) diff --git a/cmd/proxsave/install_tui.go b/cmd/proxsave/install_tui.go index 7d6d8ee..d61f4b0 100644 --- a/cmd/proxsave/install_tui.go +++ b/cmd/proxsave/install_tui.go @@ -93,7 +93,7 @@ func runInstallTUI(ctx context.Context, configPath string, bootstrap *logging.Bo if !skipConfigWizard { // Run the wizard logging.DebugStepBootstrap(bootstrap, "install workflow (tui)", "running install wizard") - wizardData, err = wizard.RunInstallWizard(ctx, configPath, baseDir, buildSig) + wizardData, err = wizard.RunInstallWizard(ctx, configPath, baseDir, buildSig, baseTemplate) if err != nil { if errors.Is(err, wizard.ErrInstallCancelled) { return wrapInstallError(errInteractiveAborted) diff --git a/internal/tui/wizard/install.go b/internal/tui/wizard/install.go index f1999ca..66484b9 100644 --- a/internal/tui/wizard/install.go +++ b/internal/tui/wizard/install.go @@ -1,6 +1,7 @@ package wizard import ( + "bufio" "context" "errors" "fmt" @@ -18,6 +19,19 @@ import ( "github.com/tis24dev/proxsave/pkg/utils" ) +type installWizardPrefill struct { + SecondaryEnabled bool + SecondaryPath string + SecondaryLogPath string + CloudEnabled bool + CloudRemote string + CloudLogPath string + FirewallEnabled bool + TelegramEnabled bool + EmailEnabled bool + EncryptionEnabled bool +} + // InstallWizardData holds the collected installation data type InstallWizardData struct { BaseDir string @@ -52,7 +66,7 @@ var ( ) // RunInstallWizard runs the TUI-based installation wizard -func RunInstallWizard(ctx context.Context, configPath string, baseDir string, buildSig string) (*InstallWizardData, error) { +func RunInstallWizard(ctx context.Context, configPath string, baseDir string, buildSig string, baseTemplate string) (*InstallWizardData, error) { defaultFirewallRules := false data := &InstallWizardData{ BaseDir: baseDir, @@ -64,6 +78,8 @@ func RunInstallWizard(ctx context.Context, configPath string, baseDir string, bu app := tui.NewApp() + prefill := deriveInstallWizardPrefill(baseTemplate) + // Build the form form := components.NewForm(app) @@ -94,7 +110,7 @@ func RunInstallWizard(ctx context.Context, configPath string, baseDir string, bu var dropdownOpen bool // Secondary Storage section - var secondaryEnabled bool + secondaryEnabled := prefill.SecondaryEnabled var secondaryPathField, secondaryLogField *tview.InputField secondaryDropdown := tview.NewDropDown(). @@ -110,7 +126,7 @@ func RunInstallWizard(ctx context.Context, configPath string, baseDir string, bu } dropdownOpen = false }). - SetCurrentOption(0) + SetCurrentOption(boolToOptionIndex(secondaryEnabled)) secondaryDropdown.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey { if event.Key() == tcell.KeyEnter { @@ -134,18 +150,24 @@ func RunInstallWizard(ctx context.Context, configPath string, baseDir string, bu SetLabel(" └─ Secondary Backup Path"). SetText("/mnt/secondary-backup"). SetFieldWidth(40) - secondaryPathField.SetDisabled(true) + if prefill.SecondaryPath != "" { + secondaryPathField.SetText(prefill.SecondaryPath) + } + secondaryPathField.SetDisabled(!secondaryEnabled) form.Form.AddFormItem(secondaryPathField) secondaryLogField = tview.NewInputField(). SetLabel(" └─ Secondary Log Path"). SetText("/mnt/secondary-backup/logs"). SetFieldWidth(40) - secondaryLogField.SetDisabled(true) + if prefill.SecondaryLogPath != "" { + secondaryLogField.SetText(prefill.SecondaryLogPath) + } + secondaryLogField.SetDisabled(!secondaryEnabled) form.Form.AddFormItem(secondaryLogField) // Cloud Storage section - var cloudEnabled bool + cloudEnabled := prefill.CloudEnabled var rcloneBackupField, rcloneLogField *tview.InputField cloudDropdown := tview.NewDropDown(). @@ -160,7 +182,7 @@ func RunInstallWizard(ctx context.Context, configPath string, baseDir string, bu } dropdownOpen = false }). - SetCurrentOption(0) + SetCurrentOption(boolToOptionIndex(cloudEnabled)) cloudDropdown.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey { if event.Key() == tcell.KeyEnter { @@ -174,7 +196,7 @@ func RunInstallWizard(ctx context.Context, configPath string, baseDir string, bu form.Form.AddFormItem(cloudDropdown) cloudHint := tview.NewInputField(). - SetLabel(" Tip: remotename:path (via 'rclone config'), e.g. myremote:pbs-backups"). + SetLabel(" Tip: remote name (via 'rclone config'), e.g. myremote (or myremote:path)"). SetFieldWidth(0). SetText("") cloudHint.SetDisabled(true) @@ -184,25 +206,31 @@ func RunInstallWizard(ctx context.Context, configPath string, baseDir string, bu SetLabel(" └─ Rclone Backup Remote"). SetText("myremote:pbs-backups"). SetFieldWidth(40) - rcloneBackupField.SetDisabled(true) + if prefill.CloudRemote != "" { + rcloneBackupField.SetText(prefill.CloudRemote) + } + rcloneBackupField.SetDisabled(!cloudEnabled) form.Form.AddFormItem(rcloneBackupField) rcloneLogField = tview.NewInputField(). - SetLabel(" └─ Rclone Log Remote"). + SetLabel(" └─ Rclone Log Path"). SetText("myremote:pbs-logs"). SetFieldWidth(40) - rcloneLogField.SetDisabled(true) + if prefill.CloudLogPath != "" { + rcloneLogField.SetText(prefill.CloudLogPath) + } + rcloneLogField.SetDisabled(!cloudEnabled) form.Form.AddFormItem(rcloneLogField) // Firewall rules backup (system collection) - firewallEnabled := false + firewallEnabled := prefill.FirewallEnabled firewallDropdown := tview.NewDropDown(). SetLabel("Backup Firewall Rules (iptables/nftables)"). SetOptions([]string{"No", "Yes"}, func(option string, index int) { firewallEnabled = (option == "Yes") dropdownOpen = false }). - SetCurrentOption(0) + SetCurrentOption(boolToOptionIndex(firewallEnabled)) firewallDropdown.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey { if event.Key() == tcell.KeyEnter { @@ -216,7 +244,8 @@ func RunInstallWizard(ctx context.Context, configPath string, baseDir string, bu form.Form.AddFormItem(firewallDropdown) // Notifications (header + two toggles) - var telegramEnabled, emailEnabled bool + telegramEnabled := prefill.TelegramEnabled + emailEnabled := prefill.EmailEnabled notificationHeader := tview.NewInputField(). SetLabel("Notifications"). SetFieldWidth(0). @@ -230,7 +259,7 @@ func RunInstallWizard(ctx context.Context, configPath string, baseDir string, bu telegramEnabled = (option == "Yes") dropdownOpen = false }). - SetCurrentOption(0) + SetCurrentOption(boolToOptionIndex(telegramEnabled)) telegramDropdown.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey { if event.Key() == tcell.KeyEnter { dropdownOpen = !dropdownOpen @@ -247,7 +276,7 @@ func RunInstallWizard(ctx context.Context, configPath string, baseDir string, bu emailEnabled = (option == "Yes") dropdownOpen = false }). - SetCurrentOption(0) + SetCurrentOption(boolToOptionIndex(emailEnabled)) emailDropdown.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey { if event.Key() == tcell.KeyEnter { dropdownOpen = !dropdownOpen @@ -264,7 +293,7 @@ func RunInstallWizard(ctx context.Context, configPath string, baseDir string, bu SetOptions([]string{"No", "Yes"}, func(option string, index int) { dropdownOpen = false }). - SetCurrentOption(0) + SetCurrentOption(boolToOptionIndex(prefill.EncryptionEnabled)) encryptionDropdown.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey { if event.Key() == tcell.KeyEnter { @@ -312,15 +341,15 @@ func RunInstallWizard(ctx context.Context, configPath string, baseDir string, bu data.EnableCloudStorage = cloudEnabled if cloudEnabled { - data.RcloneBackupRemote = rcloneBackupField.GetText() - data.RcloneLogRemote = rcloneLogField.GetText() + data.RcloneBackupRemote = strings.TrimSpace(rcloneBackupField.GetText()) + data.RcloneLogRemote = strings.TrimSpace(rcloneLogField.GetText()) - // Validate rclone remotes - if !strings.Contains(data.RcloneBackupRemote, ":") { - return fmt.Errorf("rclone backup remote must be in format 'remote:path'") + // Validate rclone inputs (allow both "remote" and "remote:path", logs can also be path-only) + if data.RcloneBackupRemote == "" { + return fmt.Errorf("rclone backup remote cannot be empty") } - if !strings.Contains(data.RcloneLogRemote, ":") { - return fmt.Errorf("rclone log remote must be in format 'remote:path'") + if data.RcloneLogRemote == "" { + return fmt.Errorf("rclone log path cannot be empty") } } @@ -454,6 +483,11 @@ func RunInstallWizard(ctx context.Context, configPath string, baseDir string, bu // If baseTemplate is empty, the embedded default template is used. func ApplyInstallData(baseTemplate string, data *InstallWizardData) (string, error) { template := baseTemplate + editingExisting := strings.TrimSpace(baseTemplate) != "" + existingValues := map[string]string{} + if editingExisting { + existingValues = parseEnvTemplate(baseTemplate) + } if strings.TrimSpace(template) == "" { template = config.DefaultEnvTemplate() } @@ -497,15 +531,23 @@ func ApplyInstallData(baseTemplate string, data *InstallWizardData) (string, err // Apply notifications if data.NotificationMode == "telegram" || data.NotificationMode == "both" { template = setEnvValue(template, "TELEGRAM_ENABLED", "true") - template = setEnvValue(template, "BOT_TELEGRAM_TYPE", "centralized") + // Preserve existing telegram mode when editing an existing config. + if !editingExisting || strings.TrimSpace(existingValues["BOT_TELEGRAM_TYPE"]) == "" { + template = setEnvValue(template, "BOT_TELEGRAM_TYPE", "centralized") + } } else { template = setEnvValue(template, "TELEGRAM_ENABLED", "false") } if data.NotificationMode == "email" || data.NotificationMode == "both" { template = setEnvValue(template, "EMAIL_ENABLED", "true") - template = setEnvValue(template, "EMAIL_DELIVERY_METHOD", "relay") - template = setEnvValue(template, "EMAIL_FALLBACK_SENDMAIL", "true") + // Preserve existing delivery preferences when editing an existing config. + if !editingExisting || strings.TrimSpace(existingValues["EMAIL_DELIVERY_METHOD"]) == "" { + template = setEnvValue(template, "EMAIL_DELIVERY_METHOD", "relay") + } + if !editingExisting || strings.TrimSpace(existingValues["EMAIL_FALLBACK_SENDMAIL"]) == "" { + template = setEnvValue(template, "EMAIL_FALLBACK_SENDMAIL", "true") + } } else { template = setEnvValue(template, "EMAIL_ENABLED", "false") } @@ -560,6 +602,87 @@ func unsetEnvValue(template, key string) string { return strings.Join(out, "\n") } +func boolToOptionIndex(value bool) int { + if value { + return 1 + } + return 0 +} + +func deriveInstallWizardPrefill(baseTemplate string) installWizardPrefill { + out := installWizardPrefill{} + if strings.TrimSpace(baseTemplate) == "" { + return out + } + values := parseEnvTemplate(baseTemplate) + + out.SecondaryEnabled = readTemplateBool(values, "SECONDARY_ENABLED", "ENABLE_SECONDARY_BACKUP") + out.SecondaryPath = readTemplateString(values, "SECONDARY_PATH", "SECONDARY_BACKUP_PATH") + out.SecondaryLogPath = readTemplateString(values, "SECONDARY_LOG_PATH") + + out.CloudEnabled = readTemplateBool(values, "CLOUD_ENABLED", "ENABLE_CLOUD_BACKUP") + out.CloudRemote = readTemplateString(values, "CLOUD_REMOTE", "RCLONE_REMOTE") + out.CloudLogPath = readTemplateString(values, "CLOUD_LOG_PATH") + + out.FirewallEnabled = readTemplateBool(values, "BACKUP_FIREWALL_RULES") + + out.TelegramEnabled = readTemplateBool(values, "TELEGRAM_ENABLED") + out.EmailEnabled = readTemplateBool(values, "EMAIL_ENABLED") + + out.EncryptionEnabled = readTemplateBool(values, "ENCRYPT_ARCHIVE") + + return out +} + +func parseEnvTemplate(template string) map[string]string { + values := make(map[string]string) + + scanner := bufio.NewScanner(strings.NewReader(template)) + for scanner.Scan() { + line := strings.TrimRight(scanner.Text(), "\r") + trimmed := strings.TrimSpace(line) + if utils.IsComment(trimmed) { + continue + } + + key, value, ok := utils.SplitKeyValue(line) + if !ok { + continue + } + if fields := strings.Fields(key); len(fields) >= 2 && fields[0] == "export" { + key = fields[1] + } + key = strings.ToUpper(strings.TrimSpace(key)) + if key == "" { + continue + } + values[key] = strings.TrimSpace(value) + } + + return values +} + +func readTemplateString(values map[string]string, keys ...string) string { + for _, key := range keys { + key = strings.ToUpper(strings.TrimSpace(key)) + if key == "" { + continue + } + if val, ok := values[key]; ok { + return strings.TrimSpace(val) + } + } + return "" +} + +func readTemplateBool(values map[string]string, keys ...string) bool { + raw := readTemplateString(values, keys...) + if strings.TrimSpace(raw) == "" { + return false + } + return utils.ParseBool(raw) +} + // CheckExistingConfig checks if config file exists and asks how to proceed func CheckExistingConfig(configPath string, buildSig string) (ExistingConfigAction, error) { if _, err := os.Stat(configPath); err == nil { From 41f280f5a70147bf70b03cc4e8bbd0061e283ffb Mon Sep 17 00:00:00 2001 From: Damiano <71268257+tis24dev@users.noreply.github.com> Date: Tue, 24 Feb 2026 19:14:21 +0100 Subject: [PATCH 10/12] Docs: installer prompts for existing config Clarify installer behavior when a configuration file already exists and refine wizard prompts. Adds TUI options (Overwrite / Edit existing / Keep & exit), documents CLI mode overwrite prompt (choosing No keeps the file and skips the wizard), and notes that cron schedule selection (HH:MM) is TUI-only. Also clarifies cloud storage/rclone guidance and adjusts wizard step numbering in CLI_REFERENCE.md and INSTALL.md. --- docs/CLI_REFERENCE.md | 7 ++++++- docs/INSTALL.md | 8 +++++++- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/docs/CLI_REFERENCE.md b/docs/CLI_REFERENCE.md index d69cdc1..456db0b 100644 --- a/docs/CLI_REFERENCE.md +++ b/docs/CLI_REFERENCE.md @@ -130,6 +130,10 @@ Some interactive commands support two interface modes: **Use `--cli` when**: TUI rendering issues occur or advanced debugging is needed. +**Existing configuration**: +- If the configuration file already exists, the **TUI wizard** prompts you to **Overwrite**, **Edit existing** (uses the current file as base and pre-fills the wizard fields), or **Keep & exit**. +- In **CLI mode** (`--cli`), you will be prompted to overwrite; choosing "No" keeps the file and skips the configuration wizard. + **Wizard workflow**: 1. Generates/updates the configuration file (`configs/backup.env` by default) 2. Optionally configures secondary storage @@ -137,7 +141,8 @@ Some interactive commands support two interface modes: 4. Optionally enables firewall rules collection (`BACKUP_FIREWALL_RULES=false` by default) 5. Optionally sets up notifications (Telegram, Email; Email defaults to `EMAIL_DELIVERY_METHOD=relay`) 6. Optionally configures encryption (AGE setup) -7. Finalizes installation (symlinks, cron migration, permission checks) +7. (TUI) Optionally selects a cron time (HH:MM) for the `proxsave` cron entry +8. Finalizes installation (symlinks, cron migration, permission checks) ### Configuration Upgrade diff --git a/docs/INSTALL.md b/docs/INSTALL.md index fc45a23..fe996cd 100644 --- a/docs/INSTALL.md +++ b/docs/INSTALL.md @@ -204,14 +204,20 @@ The installation wizard creates your configuration file interactively: ./build/proxsave --new-install ``` +If the configuration file already exists, the **TUI wizard** will ask whether to: +- **Overwrite** (start from the embedded template) +- **Edit existing** (use the current file as base and pre-fill the wizard fields) +- **Keep & exit** (leave the file untouched and exit) + **Wizard prompts:** 1. **Configuration file path**: Default `configs/backup.env` (accepts absolute or relative paths within repo) 2. **Secondary storage**: Optional path for backup/log copies -3. **Cloud storage**: Optional rclone remote configuration +3. **Cloud storage (rclone)**: Optional rclone configuration (supports `CLOUD_REMOTE` as a remote name (recommended) or legacy `remote:path`; `CLOUD_LOG_PATH` supports path-only (recommended) or `otherremote:/path`) 4. **Firewall rules**: Optional firewall rules collection toggle (`BACKUP_FIREWALL_RULES=false` by default; supports iptables/nftables) 5. **Notifications**: Enable Telegram (centralized) and Email notifications (wizard defaults to `EMAIL_DELIVERY_METHOD=relay`; you can switch to `sendmail` or `pmf` later) 6. **Encryption**: AGE encryption setup (runs sub-wizard immediately if enabled) +7. **Cron schedule**: Choose cron time (HH:MM) for the `proxsave` cron entry (TUI mode only) **Features:** From 9d6d87d8280c51d67f65f409411ca782f7491595 Mon Sep 17 00:00:00 2001 From: Damiano <71268257+tis24dev@users.noreply.github.com> Date: Tue, 24 Feb 2026 22:22:00 +0100 Subject: [PATCH 11/12] Add optional post-install audit (dry-run) Introduce an optional post-install audit that runs a proxsave --dry-run to detect unused BACKUP_* collectors and offer to disable them. Changes: - CLI: runPostInstallAuditCLI prompts the user, runs the dry-run, parses actionable "set KEY=false" hints, and can update the config (atomic tmp file, keys sorted). - TUI: integrate RunPostInstallAuditWizard into the install TUI flow and provide an interactive review UI to disable selected suggestions. - Core: add internal/tui/wizard/post_install_audit_core.go to run the dry-run and extract/normalize actionable warning lines; implement suggestion collection and filtering only for allowed BACKUP_* keys that are currently enabled. - TUI wizard: add internal/tui/wizard/post_install_audit_tui.go implementing the interactive review, applyAuditDisables, and atomic write helper. - Tests: add unit tests for parsing/collection logic in post_install_audit_core_test.go. - Docs: update CLI_REFERENCE.md and INSTALL.md to document the optional post-install dry-run audit. The audit is non-blocking on failures (warnings/errors are logged but do not fail installation) and keeps changes explicit by requiring user confirmation before modifying backup.env. --- cmd/proxsave/install.go | 83 ++++ cmd/proxsave/install_tui.go | 8 + docs/CLI_REFERENCE.md | 3 +- docs/INSTALL.md | 2 + .../tui/wizard/post_install_audit_core.go | 265 +++++++++++++ .../wizard/post_install_audit_core_test.go | 119 ++++++ internal/tui/wizard/post_install_audit_tui.go | 364 ++++++++++++++++++ 7 files changed, 843 insertions(+), 1 deletion(-) create mode 100644 internal/tui/wizard/post_install_audit_core.go create mode 100644 internal/tui/wizard/post_install_audit_core_test.go create mode 100644 internal/tui/wizard/post_install_audit_tui.go diff --git a/cmd/proxsave/install.go b/cmd/proxsave/install.go index d95e881..f3c72f9 100644 --- a/cmd/proxsave/install.go +++ b/cmd/proxsave/install.go @@ -9,6 +9,7 @@ import ( "os" "os/exec" "path/filepath" + "sort" "strings" "github.com/tis24dev/proxsave/internal/config" @@ -88,6 +89,14 @@ func runInstall(ctx context.Context, configPath string, bootstrap *logging.Boots return err } + // Optional post-install audit: run a dry-run and offer to disable unused collectors. + if !skipConfigWizard { + logging.DebugStepBootstrap(bootstrap, "install workflow (cli)", "post-install audit") + if err := runPostInstallAuditCLI(ctx, reader, execInfo.ExecPath, configPath, bootstrap); err != nil { + return err + } + } + logging.DebugStepBootstrap(bootstrap, "install workflow (cli)", "finalizing symlinks and cron") runPostInstallSymlinksAndCron(ctx, baseDir, execInfo, bootstrap) @@ -108,6 +117,80 @@ func runInstall(ctx context.Context, configPath string, bootstrap *logging.Boots return nil } +func runPostInstallAuditCLI(ctx context.Context, reader *bufio.Reader, execPath, configPath string, bootstrap *logging.BootstrapLogger) error { + fmt.Println("\n--- Post-install check (optional) ---") + run, err := promptYesNo(ctx, reader, "Run a dry-run to detect unused components and reduce warnings? [Y/n]: ", true) + if err != nil { + return wrapInstallError(err) + } + if !run { + return nil + } + + if bootstrap != nil { + bootstrap.Info("Running post-install dry-run audit (this may take a minute)") + } + + suggestions, err := wizard.CollectPostInstallDisableSuggestions(ctx, execPath, configPath) + if err != nil { + fmt.Printf("WARNING: Post-install check failed (non-blocking): %v\n", err) + return nil + } + if len(suggestions) == 0 { + fmt.Println("No unused components detected. No changes required.") + return nil + } + + fmt.Printf("Detected %d unused/optional component(s) that may cause WARNINGs.\n", len(suggestions)) + for _, s := range suggestions { + reason := "" + if len(s.Messages) > 0 { + reason = strings.TrimSpace(s.Messages[0]) + } + if reason != "" { + fmt.Printf(" - %s: %s\n", s.Key, reason) + } else { + fmt.Printf(" - %s\n", s.Key) + } + } + fmt.Println() + + disableAll, err := promptYesNo(ctx, reader, "Disable all suggested components now (set KEY=false)? [y/N]: ", false) + if err != nil { + return wrapInstallError(err) + } + if !disableAll { + fmt.Println("No changes applied. You can disable unused components later by editing backup.env.") + return nil + } + + contentBytes, err := os.ReadFile(configPath) + if err != nil { + fmt.Printf("ERROR: Unable to update configuration (read failed): %v\n", err) + return nil + } + content := string(contentBytes) + + keys := make([]string, 0, len(suggestions)) + for _, s := range suggestions { + keys = append(keys, s.Key) + } + sort.Strings(keys) + for _, key := range keys { + content = setEnvValue(content, key, "false") + } + + tmpAuditPath := configPath + ".tmp.audit" + defer cleanupTempConfig(tmpAuditPath) + if err := writeConfigFile(configPath, tmpAuditPath, content); err != nil { + fmt.Printf("ERROR: Unable to update configuration (write failed): %v\n", err) + return nil + } + + fmt.Printf("✓ Updated %s: disabled %d component(s).\n", configPath, len(keys)) + return nil +} + func runNewInstall(ctx context.Context, configPath string, bootstrap *logging.BootstrapLogger, useCLI bool) (err error) { done := logging.DebugStartBootstrap(bootstrap, "new-install workflow", "config=%s", configPath) defer func() { done(err) }() diff --git a/cmd/proxsave/install_tui.go b/cmd/proxsave/install_tui.go index d61f4b0..3567e06 100644 --- a/cmd/proxsave/install_tui.go +++ b/cmd/proxsave/install_tui.go @@ -179,6 +179,14 @@ func runInstallTUI(ctx context.Context, configPath string, bootstrap *logging.Bo bootstrap.Info("IMPORTANT: Keep your passphrase/private key offline and secure!") } + // Optional post-install audit: run a dry-run and offer to disable unused collectors + // based on actionable warning hints like "set BACKUP_*=false to disable". + if _, auditErr := wizard.RunPostInstallAuditWizard(ctx, execInfo.ExecPath, configPath, buildSig); auditErr != nil { + if bootstrap != nil { + bootstrap.Warning("Post-install check failed (non-blocking): %v", auditErr) + } + } + // Clean up legacy bash-based symlinks if bootstrap != nil { bootstrap.Info("Cleaning up legacy bash-based symlinks (if present)") diff --git a/docs/CLI_REFERENCE.md b/docs/CLI_REFERENCE.md index 456db0b..46d19c2 100644 --- a/docs/CLI_REFERENCE.md +++ b/docs/CLI_REFERENCE.md @@ -142,7 +142,8 @@ Some interactive commands support two interface modes: 5. Optionally sets up notifications (Telegram, Email; Email defaults to `EMAIL_DELIVERY_METHOD=relay`) 6. Optionally configures encryption (AGE setup) 7. (TUI) Optionally selects a cron time (HH:MM) for the `proxsave` cron entry -8. Finalizes installation (symlinks, cron migration, permission checks) +8. Optionally runs a post-install dry-run audit and offers to disable unused collectors (actionable hints like `set BACKUP_*=false to disable`) +9. Finalizes installation (symlinks, cron migration, permission checks) ### Configuration Upgrade diff --git a/docs/INSTALL.md b/docs/INSTALL.md index fe996cd..a9144e6 100644 --- a/docs/INSTALL.md +++ b/docs/INSTALL.md @@ -218,6 +218,7 @@ If the configuration file already exists, the **TUI wizard** will ask whether to 5. **Notifications**: Enable Telegram (centralized) and Email notifications (wizard defaults to `EMAIL_DELIVERY_METHOD=relay`; you can switch to `sendmail` or `pmf` later) 6. **Encryption**: AGE encryption setup (runs sub-wizard immediately if enabled) 7. **Cron schedule**: Choose cron time (HH:MM) for the `proxsave` cron entry (TUI mode only) +8. **Post-install check (optional)**: Runs `proxsave --dry-run` and shows actionable warnings like `set BACKUP_*=false to disable`, allowing you to disable unused collectors and reduce WARNING noise **Features:** @@ -225,6 +226,7 @@ If the configuration file already exists, the **TUI wizard** will ask whether to - Template comment preservation - Creates all necessary directories with proper permissions (0700) - Immediate AGE key generation if encryption is enabled +- Optional post-install audit to disable unused collectors (keeps changes explicit; nothing is disabled silently) After completion, edit `configs/backup.env` manually for advanced options. diff --git a/internal/tui/wizard/post_install_audit_core.go b/internal/tui/wizard/post_install_audit_core.go new file mode 100644 index 0000000..5d9a76a --- /dev/null +++ b/internal/tui/wizard/post_install_audit_core.go @@ -0,0 +1,265 @@ +package wizard + +import ( + "context" + "errors" + "fmt" + "os" + "os/exec" + "regexp" + "sort" + "strings" + "sync" + "time" + + "github.com/tis24dev/proxsave/internal/config" +) + +// PostInstallAuditSuggestion represents an optional feature that appears to be enabled +// but not configured/detected on this system. Users can disable the feature by setting +// the corresponding KEY=false in backup.env. +type PostInstallAuditSuggestion struct { + Key string + Messages []string +} + +var ( + postInstallAuditDisableHintRe = regexp.MustCompile(`(?i)\bset\s+([A-Z0-9_]+)=false\b`) + postInstallAuditANSISGRRe = regexp.MustCompile(`\x1b\[[0-9;]*m`) + + postInstallAuditAllowedKeysOnce sync.Once + postInstallAuditAllowedKeys map[string]struct{} + + postInstallAuditRunner = runPostInstallAuditDryRun +) + +func postInstallAuditAllowedKeysSet() map[string]struct{} { + postInstallAuditAllowedKeysOnce.Do(func() { + allowed := make(map[string]struct{}) + values := parseEnvTemplate(config.DefaultEnvTemplate()) + for key := range values { + key = strings.ToUpper(strings.TrimSpace(key)) + if strings.HasPrefix(key, "BACKUP_") { + allowed[key] = struct{}{} + } + } + postInstallAuditAllowedKeys = allowed + }) + return postInstallAuditAllowedKeys +} + +func runPostInstallAuditDryRun(ctx context.Context, execPath, configPath string) (output string, exitCode int, err error) { + // Run a dry-run with warning-level logs to keep output minimal while still capturing + // all actionable "set KEY=false" hints. + cmd := exec.CommandContext(ctx, execPath, + "--dry-run", + "--log-level", "warning", + "--config", configPath, + ) + out, runErr := cmd.CombinedOutput() + if runErr == nil { + return string(out), 0, nil + } + + var exitErr *exec.ExitError + if errors.As(runErr, &exitErr) { + // Non-zero exit codes are expected when warnings are emitted (exit code 1). + return string(out), exitErr.ExitCode(), nil + } + return string(out), -1, fmt.Errorf("post-install audit dry-run failed: %w", runErr) +} + +// CollectPostInstallDisableSuggestions runs a proxsave dry-run and extracts actionable +// "set KEY=false" hints from the resulting warnings/errors. It only returns keys that: +// - exist in the embedded template, and +// - start with "BACKUP_", and +// - are currently enabled (true) in the provided config file. +func CollectPostInstallDisableSuggestions(ctx context.Context, execPath, configPath string) ([]PostInstallAuditSuggestion, error) { + if strings.TrimSpace(execPath) == "" { + return nil, fmt.Errorf("exec path cannot be empty") + } + if strings.TrimSpace(configPath) == "" { + return nil, fmt.Errorf("config path cannot be empty") + } + + configBytes, err := os.ReadFile(configPath) + if err != nil { + return nil, fmt.Errorf("read configuration for audit: %w", err) + } + configValues := parseEnvTemplate(string(configBytes)) + allowed := postInstallAuditAllowedKeysSet() + + auditCtx, cancel := context.WithTimeout(ctx, 90*time.Second) + defer cancel() + + output, _, err := postInstallAuditRunner(auditCtx, execPath, configPath) + if err != nil { + return nil, err + } + + issueLines := extractIssueLinesFromProxsaveOutput(output) + return extractDisableSuggestionsFromIssueLines(issueLines, allowed, configValues), nil +} + +func extractIssueLinesFromProxsaveOutput(output string) []string { + lines := splitNormalizedLines(output) + + // Prefer the end-of-run summary, which is clean (no ANSI) and deduplicated. + const header = "WARNINGS/ERRORS DURING RUN" + inSummary := false + issues := make([]string, 0, 16) + + for _, line := range lines { + if strings.Contains(line, header) { + inSummary = true + continue + } + if !inSummary { + continue + } + trimmed := strings.TrimSpace(line) + if trimmed == "" { + continue + } + if strings.HasPrefix(trimmed, "===========================================") { + break + } + issues = append(issues, trimmed) + } + + if len(issues) > 0 { + return issues + } + + // Fallback: scan the entire output and keep only actionable lines. This is less + // robust because the live log output may contain ANSI codes. + for _, line := range lines { + clean := stripANSI(strings.TrimSpace(line)) + if clean == "" { + continue + } + if postInstallAuditDisableHintRe.MatchString(clean) { + issues = append(issues, clean) + } + } + return issues +} + +func extractDisableSuggestionsFromIssueLines(issueLines []string, allowed map[string]struct{}, configValues map[string]string) []PostInstallAuditSuggestion { + if len(issueLines) == 0 { + return nil + } + if allowed == nil { + allowed = map[string]struct{}{} + } + if configValues == nil { + configValues = map[string]string{} + } + + type msgSet map[string]struct{} + byKey := make(map[string]msgSet) + + for _, raw := range issueLines { + line := strings.TrimSpace(stripANSI(raw)) + if line == "" { + continue + } + + matches := postInstallAuditDisableHintRe.FindAllStringSubmatch(line, -1) + if len(matches) == 0 { + continue + } + + msg := normalizeIssueMessage(line) + if msg == "" { + msg = line + } + + for _, m := range matches { + if len(m) < 2 { + continue + } + key := strings.ToUpper(strings.TrimSpace(m[1])) + if !strings.HasPrefix(key, "BACKUP_") { + continue + } + if _, ok := allowed[key]; !ok { + continue + } + // Only suggest disabling keys that are currently enabled in the config. + if !readTemplateBool(configValues, key) { + continue + } + if _, ok := byKey[key]; !ok { + byKey[key] = make(msgSet) + } + byKey[key][msg] = struct{}{} + } + } + + if len(byKey) == 0 { + return nil + } + + keys := make([]string, 0, len(byKey)) + for key := range byKey { + keys = append(keys, key) + } + sort.Strings(keys) + + out := make([]PostInstallAuditSuggestion, 0, len(keys)) + for _, key := range keys { + msgs := make([]string, 0, len(byKey[key])) + for msg := range byKey[key] { + msgs = append(msgs, msg) + } + sort.Strings(msgs) + out = append(out, PostInstallAuditSuggestion{ + Key: key, + Messages: msgs, + }) + } + return out +} + +func normalizeIssueMessage(line string) string { + line = strings.TrimSpace(stripANSI(line)) + if line == "" { + return "" + } + // Prefer to remove "[timestamp] LEVEL" prefix when present. + if strings.HasPrefix(line, "[") { + if idx := strings.Index(line, "]"); idx >= 0 { + rest := strings.TrimSpace(line[idx+1:]) + fields := strings.Fields(rest) + if len(fields) >= 2 { + level := fields[0] + rest = strings.TrimSpace(rest[len(level):]) + if rest != "" { + return rest + } + } + } + } + return line +} + +func splitNormalizedLines(s string) []string { + if strings.TrimSpace(s) == "" { + return nil + } + // Normalize CRLF and split. + s = strings.ReplaceAll(s, "\r\n", "\n") + s = strings.ReplaceAll(s, "\r", "\n") + return strings.Split(s, "\n") +} + +func stripANSI(s string) string { + // Best-effort removal of common ANSI SGR sequences. + // Example: "\x1b[33mWARNING\x1b[0m" + const esc = "\x1b[" + if !strings.Contains(s, esc) { + return s + } + return postInstallAuditANSISGRRe.ReplaceAllString(s, "") +} diff --git a/internal/tui/wizard/post_install_audit_core_test.go b/internal/tui/wizard/post_install_audit_core_test.go new file mode 100644 index 0000000..471efac --- /dev/null +++ b/internal/tui/wizard/post_install_audit_core_test.go @@ -0,0 +1,119 @@ +package wizard + +import ( + "context" + "os" + "path/filepath" + "reflect" + "strings" + "testing" +) + +func TestExtractIssueLinesFromProxsaveOutput_UsesSummaryBlock(t *testing.T) { + output := strings.Join([]string{ + "[2026-02-24 09:12:55] \x1b[33mWARNING\x1b[0m Live warning (colored)", + "===========================================", + "WARNINGS/ERRORS DURING RUN (warnings=1 errors=0)", + "", + "[2026-02-24 09:12:55] WARNING Corosync authkey: not configured. If unused, set BACKUP_CLUSTER_CONFIG=false to disable.", + "===========================================", + "", + }, "\n") + + issues := extractIssueLinesFromProxsaveOutput(output) + if len(issues) != 1 { + t.Fatalf("issues len=%d, want 1: %#v", len(issues), issues) + } + if !strings.Contains(issues[0], "set BACKUP_CLUSTER_CONFIG=false") { + t.Fatalf("unexpected issue line: %q", issues[0]) + } +} + +func TestExtractDisableSuggestionsFromIssueLines_FiltersAllowedAndEnabled(t *testing.T) { + allowed := map[string]struct{}{ + "BACKUP_CLUSTER_CONFIG": {}, + "BACKUP_ZFS_CONFIG": {}, + } + configValues := map[string]string{ + "BACKUP_CLUSTER_CONFIG": "true", + "BACKUP_ZFS_CONFIG": "true", + "BACKUP_CEPH_CONFIG": "false", + } + lines := []string{ + "[2026-02-24 09:12:55] WARNING Corosync authkey: not configured. If unused, set BACKUP_CLUSTER_CONFIG=false to disable.", + "Skipping ZFS collection: not detected. Set BACKUP_ZFS_CONFIG=false to disable.", + "[2026-02-24 09:12:55] WARNING Something else. Set NOT_BACKUP_VAR=false to disable.", + "[2026-02-24 09:12:55] WARNING Ceph not detected. If unused, set BACKUP_CEPH_CONFIG=false to disable.", + } + + got := extractDisableSuggestionsFromIssueLines(lines, allowed, configValues) + wantKeys := []string{"BACKUP_CLUSTER_CONFIG", "BACKUP_ZFS_CONFIG"} + if len(got) != len(wantKeys) { + t.Fatalf("got %d suggestions, want %d: %#v", len(got), len(wantKeys), got) + } + for i, key := range wantKeys { + if got[i].Key != key { + t.Fatalf("suggestion[%d].Key=%q, want %q", i, got[i].Key, key) + } + } +} + +func TestCollectPostInstallDisableSuggestions_UsesRunnerAndConfigFilter(t *testing.T) { + tmp := t.TempDir() + configPath := filepath.Join(tmp, "backup.env") + if err := os.WriteFile(configPath, []byte("BACKUP_CLUSTER_CONFIG=true\nBACKUP_ZFS_CONFIG=false\n"), 0o600); err != nil { + t.Fatalf("write config: %v", err) + } + + origRunner := postInstallAuditRunner + t.Cleanup(func() { postInstallAuditRunner = origRunner }) + + postInstallAuditRunner = func(ctx context.Context, execPath, cfgPath string) (string, int, error) { + return strings.Join([]string{ + "===========================================", + "WARNINGS/ERRORS DURING RUN (warnings=1 errors=0)", + "", + "[2026-02-24 09:12:55] WARNING Corosync authkey: not configured. If unused, set BACKUP_CLUSTER_CONFIG=false to disable.", + "===========================================", + }, "\n"), 1, nil + } + + suggestions, err := CollectPostInstallDisableSuggestions(context.Background(), "/fake/proxsave", configPath) + if err != nil { + t.Fatalf("CollectPostInstallDisableSuggestions error: %v", err) + } + if len(suggestions) != 1 { + t.Fatalf("got %d suggestions, want 1: %#v", len(suggestions), suggestions) + } + if suggestions[0].Key != "BACKUP_CLUSTER_CONFIG" { + t.Fatalf("key=%q, want BACKUP_CLUSTER_CONFIG", suggestions[0].Key) + } + if len(suggestions[0].Messages) != 1 || !strings.Contains(suggestions[0].Messages[0], "Corosync authkey") { + t.Fatalf("unexpected messages: %#v", suggestions[0].Messages) + } +} + +func TestNormalizeIssueMessage_RemovesTimestampAndLevel(t *testing.T) { + got := normalizeIssueMessage("[2026-02-24 09:12:55] WARNING hello world") + want := "hello world" + if got != want { + t.Fatalf("got %q, want %q", got, want) + } +} + +func TestStripANSI_RemovesSGRCodes(t *testing.T) { + in := "\x1b[33mWARNING\x1b[0m hello" + got := stripANSI(in) + want := "WARNING hello" + if got != want { + t.Fatalf("got %q, want %q", got, want) + } +} + +func TestSplitNormalizedLines_NormalizesCRLF(t *testing.T) { + got := splitNormalizedLines("a\r\nb\r\n") + want := []string{"a", "b", ""} + if !reflect.DeepEqual(got, want) { + t.Fatalf("got %#v, want %#v", got, want) + } +} diff --git a/internal/tui/wizard/post_install_audit_tui.go b/internal/tui/wizard/post_install_audit_tui.go new file mode 100644 index 0000000..2b1cd54 --- /dev/null +++ b/internal/tui/wizard/post_install_audit_tui.go @@ -0,0 +1,364 @@ +package wizard + +import ( + "context" + "fmt" + "os" + "path/filepath" + "sort" + "strings" + + "github.com/gdamore/tcell/v2" + "github.com/rivo/tview" + + "github.com/tis24dev/proxsave/internal/tui" +) + +var ( + postInstallAuditWizardRunner = func(app *tui.App, root, focus tview.Primitive) error { + return app.SetRoot(root, true).SetFocus(focus).Run() + } +) + +// RunPostInstallAuditWizard runs an optional post-installation check that: +// 1. runs proxsave --dry-run +// 2. extracts actionable "set KEY=false" hints from warnings +// 3. lets the user disable unused BACKUP_* collectors in backup.env +// +// It returns the list of applied keys (may be empty). Errors are returned only for +// unexpected failures (e.g., UI setup issues). +func RunPostInstallAuditWizard(ctx context.Context, execPath, configPath, buildSig string) (applied []string, err error) { + app := tui.NewApp() + + titleText := tview.NewTextView(). + SetText("ProxSave - Post-install Check\n\n" + + "Detect optional components that are enabled but not configured on this node.\n" + + "This helps reduce WARNING noise and exit code 1 runs when features are unused.\n"). + SetTextColor(tui.ProxmoxLight). + SetDynamicColors(true) + titleText.SetBorder(false) + + nav := tview.NewTextView(). + SetText("[yellow]Navigation:[white] ↑↓ to move | ENTER/SPACE to toggle | ←→ on buttons | ENTER to select"). + SetTextColor(tcell.ColorWhite). + SetDynamicColors(true). + SetTextAlign(tview.AlignCenter) + nav.SetBorder(false) + + separator := tview.NewTextView(). + SetText(strings.Repeat("─", 80)). + SetTextColor(tui.ProxmoxOrange) + separator.SetBorder(false) + + configPathText := tview.NewTextView(). + SetText(fmt.Sprintf("[yellow]Configuration file:[white] %s", configPath)). + SetTextColor(tcell.ColorWhite). + SetDynamicColors(true). + SetTextAlign(tview.AlignCenter) + configPathText.SetBorder(false) + + buildSigText := tview.NewTextView(). + SetText(fmt.Sprintf("[yellow]Build Signature:[white] %s", buildSig)). + SetTextColor(tcell.ColorWhite). + SetDynamicColors(true). + SetTextAlign(tview.AlignCenter) + buildSigText.SetBorder(false) + + pages := tview.NewPages() + + confirmRun := false + confirm := tview.NewModal(). + SetText("Run the post-install check now?\n\n" + + "ProxSave will execute a dry-run and collect WARNING messages that include a hint like:\n" + + " set BACKUP_CLUSTER_CONFIG=false to disable\n\n" + + "You can then choose which optional components to disable.\n"). + AddButtons([]string{"Run check", "Skip"}). + SetDoneFunc(func(_ int, label string) { + confirmRun = (label == "Run check") + if !confirmRun { + app.Stop() + return + } + pages.SwitchToPage("running") + go func() { + suggestions, collectErr := CollectPostInstallDisableSuggestions(ctx, execPath, configPath) + app.QueueUpdateDraw(func() { + if collectErr != nil { + showAuditDoneModal(app, pages, "Post-install check failed:\n\n"+collectErr.Error()) + return + } + if len(suggestions) == 0 { + showAuditDoneModal(app, pages, "No unused components detected.\n\nNo changes are required.") + return + } + showAuditReview(app, pages, configPath, suggestions, &applied) + }) + }() + }) + + confirm.SetBorder(true). + SetTitle(" Post-install Check "). + SetTitleAlign(tview.AlignCenter). + SetTitleColor(tui.ProxmoxOrange). + SetBorderColor(tui.ProxmoxOrange). + SetBackgroundColor(tcell.ColorBlack) + + running := tview.NewTextView(). + SetText("Running dry-run...\n\nPlease wait. This may take a minute."). + SetTextColor(tcell.ColorWhite). + SetTextAlign(tview.AlignCenter) + running.SetBorder(true). + SetTitle(" Running "). + SetTitleAlign(tview.AlignCenter). + SetTitleColor(tui.WarningYellow). + SetBorderColor(tui.WarningYellow). + SetBackgroundColor(tcell.ColorBlack) + + pages.AddPage("confirm", confirm, true, true) + pages.AddPage("running", running, true, false) + + layout := tview.NewFlex(). + SetDirection(tview.FlexRow). + AddItem(titleText, 5, 0, false). + AddItem(nav, 2, 0, false). + AddItem(separator, 1, 0, false). + AddItem(pages, 0, 1, true). + AddItem(configPathText, 1, 0, false). + AddItem(buildSigText, 1, 0, false) + + layout.SetBorder(true). + SetTitle(" ProxSave "). + SetTitleAlign(tview.AlignCenter). + SetTitleColor(tui.ProxmoxOrange). + SetBorderColor(tui.ProxmoxOrange). + SetBackgroundColor(tcell.ColorBlack) + + if runErr := postInstallAuditWizardRunner(app, layout, confirm); runErr != nil { + return nil, runErr + } + // If user skipped, confirmRun is false; applied will be nil/empty. + _ = confirmRun + return applied, nil +} + +func showAuditDoneModal(app *tui.App, pages *tview.Pages, message string) { + done := tview.NewModal(). + SetText(message). + AddButtons([]string{"Continue"}). + SetDoneFunc(func(_ int, _ string) { + app.Stop() + }) + done.SetBorder(true). + SetTitle(" Post-install Check "). + SetTitleAlign(tview.AlignCenter). + SetTitleColor(tui.ProxmoxOrange). + SetBorderColor(tui.ProxmoxOrange). + SetBackgroundColor(tcell.ColorBlack) + + pages.AddAndSwitchToPage("done", done, true) +} + +func showAuditReview(app *tui.App, pages *tview.Pages, configPath string, suggestions []PostInstallAuditSuggestion, applied *[]string) { + if applied == nil { + tmp := []string{} + applied = &tmp + } + + selected := make(map[string]bool, len(suggestions)) + list := tview.NewList(). + ShowSecondaryText(false) + // We render checkbox markers like "[x]" which would otherwise be interpreted + // as style tags by tview and get stripped. + list.SetUseStyleTags(false, false) + + details := tview.NewTextView(). + SetDynamicColors(true). + SetWrap(true). + SetTextAlign(tview.AlignLeft) + details.SetBorder(true). + SetTitle(" Details "). + SetTitleAlign(tview.AlignCenter). + SetTitleColor(tui.ProxmoxOrange). + SetBorderColor(tui.ProxmoxOrange) + + updateListItem := func(index int) { + if index < 0 || index >= len(suggestions) { + return + } + key := suggestions[index].Key + marker := "[ ]" + if selected[key] { + marker = "[x]" + } + list.SetItemText(index, fmt.Sprintf("%s %s", marker, key), "") + } + + updateDetails := func(index int) { + if index < 0 || index >= len(suggestions) { + details.SetText("") + return + } + s := suggestions[index] + var b strings.Builder + b.WriteString("[yellow]Detected warnings:[white]\n\n") + for _, msg := range s.Messages { + b.WriteString("- ") + b.WriteString(msg) + b.WriteString("\n") + } + b.WriteString("\n") + b.WriteString(fmt.Sprintf("If you don’t use this feature, set [yellow]%s=false[white] to disable.\n", s.Key)) + details.SetText(b.String()) + } + + toggle := func(index int) { + if index < 0 || index >= len(suggestions) { + return + } + key := suggestions[index].Key + selected[key] = !selected[key] + updateListItem(index) + updateDetails(index) + } + + for i, s := range suggestions { + selected[s.Key] = false + list.AddItem("", "", 0, nil) + updateListItem(i) + } + + list.SetChangedFunc(func(index int, _ string, _ string, _ rune) { + updateDetails(index) + }) + list.SetSelectedFunc(func(index int, _ string, _ string, _ rune) { + toggle(index) + }) + list.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey { + switch event.Key() { + case tcell.KeyEnter: + // Let SetSelectedFunc handle it. + return event + } + if event.Rune() == ' ' { + toggle(list.GetCurrentItem()) + return nil + } + return event + }) + + if len(suggestions) > 0 { + updateDetails(0) + } + + buttons := tview.NewForm(). + AddButton("Disable selected", func() { + keys := make([]string, 0, len(suggestions)) + for _, s := range suggestions { + if selected[s.Key] { + keys = append(keys, s.Key) + } + } + sort.Strings(keys) + if len(keys) == 0 { + showAuditDoneModal(app, pages, "No changes selected.\n\nNothing was modified.") + return + } + if err := applyAuditDisables(configPath, keys); err != nil { + showAuditDoneModal(app, pages, "Failed to update configuration:\n\n"+err.Error()) + return + } + *applied = keys + showAuditDoneModal(app, pages, fmt.Sprintf("Configuration updated successfully.\n\nDisabled %d feature(s).", len(keys))) + }). + AddButton("Disable all", func() { + for i := range suggestions { + selected[suggestions[i].Key] = true + updateListItem(i) + } + updateDetails(list.GetCurrentItem()) + }). + AddButton("Skip", func() { + app.Stop() + }) + + buttons.SetBorder(true). + SetTitle(" Actions "). + SetTitleAlign(tview.AlignCenter). + SetTitleColor(tui.ProxmoxOrange). + SetBorderColor(tui.ProxmoxOrange). + SetBackgroundColor(tcell.ColorBlack) + + list.SetBorder(true). + SetTitle(" Suggestions "). + SetTitleAlign(tview.AlignCenter). + SetTitleColor(tui.ProxmoxOrange). + SetBorderColor(tui.ProxmoxOrange) + + mid := tview.NewFlex(). + AddItem(list, 0, 1, true). + AddItem(details, 0, 2, false) + + review := tview.NewFlex(). + SetDirection(tview.FlexRow). + AddItem(tview.NewTextView(). + SetText("Select which features to disable. This only changes backup.env flags.\n"). + SetTextColor(tcell.ColorWhite), 2, 0, false). + AddItem(mid, 0, 1, true). + AddItem(buttons, 7, 0, false) + + review.SetBorder(true). + SetTitle(" Review & Disable "). + SetTitleAlign(tview.AlignCenter). + SetTitleColor(tui.ProxmoxOrange). + SetBorderColor(tui.ProxmoxOrange). + SetBackgroundColor(tcell.ColorBlack) + + pages.AddAndSwitchToPage("review", review, true) + app.SetFocus(list) +} + +func applyAuditDisables(configPath string, keys []string) error { + if strings.TrimSpace(configPath) == "" { + return fmt.Errorf("config path cannot be empty") + } + if len(keys) == 0 { + return nil + } + + contentBytes, err := os.ReadFile(configPath) + if err != nil { + return fmt.Errorf("read configuration: %w", err) + } + content := string(contentBytes) + for _, key := range keys { + key = strings.ToUpper(strings.TrimSpace(key)) + if key == "" { + continue + } + content = setEnvValue(content, key, "false") + } + + tmpPath := configPath + ".tmp.audit" + if err := writeConfigFileAtomic(configPath, tmpPath, content); err != nil { + return err + } + return nil +} + +func writeConfigFileAtomic(configPath, tmpPath, content string) error { + dir := filepath.Dir(strings.TrimSpace(configPath)) + if dir == "" || dir == "." { + return fmt.Errorf("invalid configuration path: %q", configPath) + } + if err := os.MkdirAll(dir, 0o700); err != nil { + return fmt.Errorf("failed to create configuration directory: %w", err) + } + if err := os.WriteFile(tmpPath, []byte(content), 0o600); err != nil { + return fmt.Errorf("failed to write configuration file: %w", err) + } + if err := os.Rename(tmpPath, configPath); err != nil { + _ = os.Remove(tmpPath) + return fmt.Errorf("failed to finalize configuration file: %w", err) + } + return nil +} From 0b30441654dffee1053a607416eab0f81fd0f938 Mon Sep 17 00:00:00 2001 From: Damiano <71268257+tis24dev@users.noreply.github.com> Date: Wed, 25 Feb 2026 00:09:27 +0100 Subject: [PATCH 12/12] Enhance post-install audit logging and UI Improve the post-install audit flow and session logging across CLI and TUI. - cmd/proxsave/install.go: Add bootstrap logging for audit steps and failures, switch CLI prompt from a single "disable all" to per-key prompts, include suggested/disabled keys in messages and in the final config update output, and add warnings on read/write failures. - cmd/proxsave/install_tui.go: Log detailed audit result summary via bootstrap (skipped, failures, suggestions, applied disables). - internal/tui/wizard/post_install_audit_tui.go: Introduce PostInstallAuditResult (Ran, Suggestions, AppliedKeys, CollectErr) and return it from RunPostInstallAuditWizard; collect suggestions asynchronously with synchronization and propagate collection errors to the UI. - docs/CLI_REFERENCE.md & docs/INSTALL.md: Document TUI vs CLI audit behavior and add note about install session logs under /tmp/proxsave/install-*.log. These changes provide better observability of the post-install audit, make the CLI interaction less all-or-nothing by allowing per-key selection, and surface suggestions and applied disables in installer logs. --- cmd/proxsave/install.go | 58 ++++++++++++++++--- cmd/proxsave/install_tui.go | 25 +++++++- docs/CLI_REFERENCE.md | 4 +- docs/INSTALL.md | 1 + internal/tui/wizard/post_install_audit_tui.go | 45 ++++++++++---- 5 files changed, 112 insertions(+), 21 deletions(-) diff --git a/cmd/proxsave/install.go b/cmd/proxsave/install.go index f3c72f9..b781773 100644 --- a/cmd/proxsave/install.go +++ b/cmd/proxsave/install.go @@ -124,24 +124,40 @@ func runPostInstallAuditCLI(ctx context.Context, reader *bufio.Reader, execPath, return wrapInstallError(err) } if !run { + if bootstrap != nil { + bootstrap.Info("Post-install audit: skipped by user") + } return nil } if bootstrap != nil { - bootstrap.Info("Running post-install dry-run audit (this may take a minute)") + bootstrap.Info("Post-install audit: running dry-run (this may take a minute)") } suggestions, err := wizard.CollectPostInstallDisableSuggestions(ctx, execPath, configPath) if err != nil { fmt.Printf("WARNING: Post-install check failed (non-blocking): %v\n", err) + if bootstrap != nil { + bootstrap.Warning("Post-install audit failed (non-blocking): %v", err) + } return nil } if len(suggestions) == 0 { fmt.Println("No unused components detected. No changes required.") + if bootstrap != nil { + bootstrap.Info("Post-install audit: no unused components detected") + } return nil } fmt.Printf("Detected %d unused/optional component(s) that may cause WARNINGs.\n", len(suggestions)) + if bootstrap != nil { + keys := make([]string, 0, len(suggestions)) + for _, s := range suggestions { + keys = append(keys, s.Key) + } + bootstrap.Info("Post-install audit: suggested disables (%d): %s", len(keys), strings.Join(keys, ", ")) + } for _, s := range suggestions { reason := "" if len(s.Messages) > 0 { @@ -155,26 +171,46 @@ func runPostInstallAuditCLI(ctx context.Context, reader *bufio.Reader, execPath, } fmt.Println() - disableAll, err := promptYesNo(ctx, reader, "Disable all suggested components now (set KEY=false)? [y/N]: ", false) + disableAny, err := promptYesNo(ctx, reader, "Disable any of the suggested components now (set KEY=false)? [y/N]: ", false) if err != nil { return wrapInstallError(err) } - if !disableAll { + if !disableAny { fmt.Println("No changes applied. You can disable unused components later by editing backup.env.") + if bootstrap != nil { + bootstrap.Info("Post-install audit: no disables applied") + } + return nil + } + + keys := make([]string, 0, len(suggestions)) + for _, s := range suggestions { + disable, err := promptYesNo(ctx, reader, fmt.Sprintf("Disable %s? [y/N]: ", s.Key), false) + if err != nil { + return wrapInstallError(err) + } + if disable { + keys = append(keys, s.Key) + } + } + if len(keys) == 0 { + fmt.Println("No changes selected. Nothing was modified.") + if bootstrap != nil { + bootstrap.Info("Post-install audit: no disables selected") + } return nil } contentBytes, err := os.ReadFile(configPath) if err != nil { fmt.Printf("ERROR: Unable to update configuration (read failed): %v\n", err) + if bootstrap != nil { + bootstrap.Warning("Post-install audit: unable to update configuration (read failed): %v", err) + } return nil } content := string(contentBytes) - keys := make([]string, 0, len(suggestions)) - for _, s := range suggestions { - keys = append(keys, s.Key) - } sort.Strings(keys) for _, key := range keys { content = setEnvValue(content, key, "false") @@ -184,10 +220,16 @@ func runPostInstallAuditCLI(ctx context.Context, reader *bufio.Reader, execPath, defer cleanupTempConfig(tmpAuditPath) if err := writeConfigFile(configPath, tmpAuditPath, content); err != nil { fmt.Printf("ERROR: Unable to update configuration (write failed): %v\n", err) + if bootstrap != nil { + bootstrap.Warning("Post-install audit: unable to update configuration (write failed): %v", err) + } return nil } - fmt.Printf("✓ Updated %s: disabled %d component(s).\n", configPath, len(keys)) + fmt.Printf("✓ Updated %s: disabled %d component(s): %s\n", configPath, len(keys), strings.Join(keys, ", ")) + if bootstrap != nil { + bootstrap.Info("Post-install audit: disabled (%d): %s", len(keys), strings.Join(keys, ", ")) + } return nil } diff --git a/cmd/proxsave/install_tui.go b/cmd/proxsave/install_tui.go index 3567e06..9f4c1cd 100644 --- a/cmd/proxsave/install_tui.go +++ b/cmd/proxsave/install_tui.go @@ -181,9 +181,30 @@ func runInstallTUI(ctx context.Context, configPath string, bootstrap *logging.Bo // Optional post-install audit: run a dry-run and offer to disable unused collectors // based on actionable warning hints like "set BACKUP_*=false to disable". - if _, auditErr := wizard.RunPostInstallAuditWizard(ctx, execInfo.ExecPath, configPath, buildSig); auditErr != nil { - if bootstrap != nil { + auditRes, auditErr := wizard.RunPostInstallAuditWizard(ctx, execInfo.ExecPath, configPath, buildSig) + if bootstrap != nil { + if auditErr != nil { bootstrap.Warning("Post-install check failed (non-blocking): %v", auditErr) + } else { + switch { + case !auditRes.Ran: + bootstrap.Info("Post-install audit: skipped by user") + case auditRes.CollectErr != nil: + bootstrap.Warning("Post-install audit failed (non-blocking): %v", auditRes.CollectErr) + case len(auditRes.Suggestions) == 0: + bootstrap.Info("Post-install audit: no unused components detected") + default: + keys := make([]string, 0, len(auditRes.Suggestions)) + for _, s := range auditRes.Suggestions { + keys = append(keys, s.Key) + } + bootstrap.Info("Post-install audit: suggested disables (%d): %s", len(keys), strings.Join(keys, ", ")) + if len(auditRes.AppliedKeys) > 0 { + bootstrap.Info("Post-install audit: disabled (%d): %s", len(auditRes.AppliedKeys), strings.Join(auditRes.AppliedKeys, ", ")) + } else { + bootstrap.Info("Post-install audit: no disables applied") + } + } } } diff --git a/docs/CLI_REFERENCE.md b/docs/CLI_REFERENCE.md index 46d19c2..a5c2602 100644 --- a/docs/CLI_REFERENCE.md +++ b/docs/CLI_REFERENCE.md @@ -142,9 +142,11 @@ Some interactive commands support two interface modes: 5. Optionally sets up notifications (Telegram, Email; Email defaults to `EMAIL_DELIVERY_METHOD=relay`) 6. Optionally configures encryption (AGE setup) 7. (TUI) Optionally selects a cron time (HH:MM) for the `proxsave` cron entry -8. Optionally runs a post-install dry-run audit and offers to disable unused collectors (actionable hints like `set BACKUP_*=false to disable`) +8. Optionally runs a post-install dry-run audit and offers to disable unused collectors (TUI: checklist; CLI: per-key prompts; actionable hints like `set BACKUP_*=false to disable`) 9. Finalizes installation (symlinks, cron migration, permission checks) +**Install log**: The installer writes a session log under `/tmp/proxsave/install-*.log` (includes post-install audit suggestions and any accepted disables). + ### Configuration Upgrade ```bash diff --git a/docs/INSTALL.md b/docs/INSTALL.md index a9144e6..eae383d 100644 --- a/docs/INSTALL.md +++ b/docs/INSTALL.md @@ -227,6 +227,7 @@ If the configuration file already exists, the **TUI wizard** will ask whether to - Creates all necessary directories with proper permissions (0700) - Immediate AGE key generation if encryption is enabled - Optional post-install audit to disable unused collectors (keeps changes explicit; nothing is disabled silently) +- Install session log under `/tmp/proxsave/install-*.log` (includes post-install audit suggestions and any accepted disables) After completion, edit `configs/backup.env` manually for advanced options. diff --git a/internal/tui/wizard/post_install_audit_tui.go b/internal/tui/wizard/post_install_audit_tui.go index 2b1cd54..b46587d 100644 --- a/internal/tui/wizard/post_install_audit_tui.go +++ b/internal/tui/wizard/post_install_audit_tui.go @@ -7,6 +7,7 @@ import ( "path/filepath" "sort" "strings" + "sync" "github.com/gdamore/tcell/v2" "github.com/rivo/tview" @@ -20,14 +21,25 @@ var ( } ) +type PostInstallAuditResult struct { + // Ran indicates whether the user chose to run the post-install check. + Ran bool + // Suggestions contains the disable suggestions extracted from the dry-run output. + Suggestions []PostInstallAuditSuggestion + // AppliedKeys contains the keys written as KEY=false into backup.env. + AppliedKeys []string + // CollectErr is set when the dry-run/suggestion collection failed. + CollectErr error +} + // RunPostInstallAuditWizard runs an optional post-installation check that: // 1. runs proxsave --dry-run // 2. extracts actionable "set KEY=false" hints from warnings // 3. lets the user disable unused BACKUP_* collectors in backup.env // -// It returns the list of applied keys (may be empty). Errors are returned only for -// unexpected failures (e.g., UI setup issues). -func RunPostInstallAuditWizard(ctx context.Context, execPath, configPath, buildSig string) (applied []string, err error) { +// It returns the audit result. Errors are returned only for unexpected failures +// (e.g., UI setup issues). +func RunPostInstallAuditWizard(ctx context.Context, execPath, configPath, buildSig string) (result PostInstallAuditResult, err error) { app := tui.NewApp() titleText := tview.NewTextView(). @@ -67,6 +79,10 @@ func RunPostInstallAuditWizard(ctx context.Context, execPath, configPath, buildS pages := tview.NewPages() confirmRun := false + var mu sync.Mutex + var collectedSuggestions []PostInstallAuditSuggestion + var collectErr error + applied := []string{} confirm := tview.NewModal(). SetText("Run the post-install check now?\n\n" + "ProxSave will execute a dry-run and collect WARNING messages that include a hint like:\n" + @@ -81,10 +97,14 @@ func RunPostInstallAuditWizard(ctx context.Context, execPath, configPath, buildS } pages.SwitchToPage("running") go func() { - suggestions, collectErr := CollectPostInstallDisableSuggestions(ctx, execPath, configPath) + suggestions, suggestionErr := CollectPostInstallDisableSuggestions(ctx, execPath, configPath) app.QueueUpdateDraw(func() { - if collectErr != nil { - showAuditDoneModal(app, pages, "Post-install check failed:\n\n"+collectErr.Error()) + mu.Lock() + collectedSuggestions = suggestions + collectErr = suggestionErr + mu.Unlock() + if suggestionErr != nil { + showAuditDoneModal(app, pages, "Post-install check failed:\n\n"+suggestionErr.Error()) return } if len(suggestions) == 0 { @@ -134,11 +154,16 @@ func RunPostInstallAuditWizard(ctx context.Context, execPath, configPath, buildS SetBackgroundColor(tcell.ColorBlack) if runErr := postInstallAuditWizardRunner(app, layout, confirm); runErr != nil { - return nil, runErr + return PostInstallAuditResult{}, runErr } - // If user skipped, confirmRun is false; applied will be nil/empty. - _ = confirmRun - return applied, nil + + result.Ran = confirmRun + mu.Lock() + result.Suggestions = collectedSuggestions + result.CollectErr = collectErr + mu.Unlock() + result.AppliedKeys = applied + return result, nil } func showAuditDoneModal(app *tui.App, pages *tview.Pages, message string) {