diff --git a/cmd/proxsave/install.go b/cmd/proxsave/install.go index 0128864..d368262 100644 --- a/cmd/proxsave/install.go +++ b/cmd/proxsave/install.go @@ -838,12 +838,6 @@ func isInstallAbortedError(err error) bool { return false } -// clearImmutableAttributes attempts to remove immutable flags (chattr -i) so deletion can proceed. -// It logs warnings on failure but does not return an error, since removal will report issues later. -func clearImmutableAttributes(target string, bootstrap *logging.BootstrapLogger) { - _ = clearImmutableAttributesWithContext(context.Background(), target, bootstrap) -} - func clearImmutableAttributesWithContext(ctx context.Context, target string, bootstrap *logging.BootstrapLogger) error { if ctx == nil { ctx = context.Background() diff --git a/cmd/proxsave/install_existing_config.go b/cmd/proxsave/install_existing_config.go index a243d1f..3bc16c5 100644 --- a/cmd/proxsave/install_existing_config.go +++ b/cmd/proxsave/install_existing_config.go @@ -26,9 +26,16 @@ type existingConfigDecision struct { } func promptExistingConfigModeCLI(ctx context.Context, reader *bufio.Reader, configPath string) (existingConfigMode, error) { + if ctx == nil { + ctx = context.Background() + } + info, err := os.Stat(configPath) if err != nil { if os.IsNotExist(err) { + if err := ctx.Err(); err != nil { + return existingConfigCancel, err + } return existingConfigOverwrite, nil } return existingConfigCancel, fmt.Errorf("failed to access configuration file: %w", err) @@ -53,12 +60,24 @@ func promptExistingConfigModeCLI(ctx context.Context, reader *bufio.Reader, conf case "": fallthrough case "3": + if err := ctx.Err(); err != nil { + return existingConfigCancel, err + } return existingConfigKeepContinue, nil case "1": + if err := ctx.Err(); err != nil { + return existingConfigCancel, err + } return existingConfigOverwrite, nil case "2": + if err := ctx.Err(); err != nil { + return existingConfigCancel, err + } return existingConfigEdit, nil case "0": + if err := ctx.Err(); err != nil { + return existingConfigCancel, err + } return existingConfigCancel, nil default: fmt.Println("Please enter 1, 2, 3 or 0.") diff --git a/cmd/proxsave/install_existing_config_test.go b/cmd/proxsave/install_existing_config_test.go index 8de7a58..dd5f358 100644 --- a/cmd/proxsave/install_existing_config_test.go +++ b/cmd/proxsave/install_existing_config_test.go @@ -21,6 +21,20 @@ func TestPromptExistingConfigModeCLIMissingFileDefaultsToOverwrite(t *testing.T) } } +func TestPromptExistingConfigModeCLIMissingFileRespectsCanceledContext(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + missing := filepath.Join(t.TempDir(), "missing.env") + mode, err := promptExistingConfigModeCLI(ctx, bufio.NewReader(strings.NewReader("")), missing) + if !errors.Is(err, context.Canceled) { + t.Fatalf("expected context canceled error, got %v", err) + } + if mode != existingConfigCancel { + t.Fatalf("expected cancel mode, got %v", mode) + } +} + func TestPromptExistingConfigModeCLIOptions(t *testing.T) { cfgFile := createTempFile(t, "EXISTING=1\n") tests := []struct { diff --git a/cmd/proxsave/new_install.go b/cmd/proxsave/new_install.go index 7e8ae2c..357c21c 100644 --- a/cmd/proxsave/new_install.go +++ b/cmd/proxsave/new_install.go @@ -56,6 +56,7 @@ func formatNewInstallPreservedEntries(entries []string) string { formatted := make([]string, 0, len(entries)) for _, entry := range entries { trimmed := strings.TrimSpace(entry) + trimmed = strings.TrimRight(trimmed, "/") if trimmed == "" { continue } diff --git a/cmd/proxsave/new_install_test.go b/cmd/proxsave/new_install_test.go index de43452..2daaa26 100644 --- a/cmd/proxsave/new_install_test.go +++ b/cmd/proxsave/new_install_test.go @@ -120,6 +120,11 @@ func TestFormatNewInstallPreservedEntries(t *testing.T) { entries: []string{"", " ", "\t"}, want: "(none)", }, + { + name: "normalizes trailing slashes", + entries: []string{"env/", "build//", " identity/// ", "/"}, + want: "env/ build/ identity/", + }, } for _, tt := range tests { diff --git a/docs/CONFIGURATION.md b/docs/CONFIGURATION.md index cba8da5..d9997b2 100644 --- a/docs/CONFIGURATION.md +++ b/docs/CONFIGURATION.md @@ -846,6 +846,7 @@ If `EMAIL_ENABLED` is omitted, the default remains `false`. The legacy alias `EM - Allowed values for `EMAIL_DELIVERY_METHOD` are: `relay`, `sendmail`, `pmf` (invalid values will skip Email with a warning). - `EMAIL_FALLBACK_SENDMAIL` is a historical name (kept for compatibility). When `EMAIL_DELIVERY_METHOD=relay`, it enables fallback to **pmf** (it will not fall back to `/usr/sbin/sendmail`). - `relay` requires a real mailbox recipient and blocks `root@…` recipients; set `EMAIL_RECIPIENT` to a non-root mailbox if needed. +- When logs say the relay "accepted request", it means the worker and upstream email API accepted the submission. It does **not** guarantee final inbox delivery (the message may still bounce, be deferred, or land in spam later). - If `EMAIL_RECIPIENT` is empty, ProxSave auto-detects the recipient from the `root@pam` user: - **PVE**: Proxmox API via `pvesh get /access/users/root@pam` → fallback to `pveum user list` → fallback to `/etc/pve/user.cfg` - **PBS**: `proxmox-backup-manager user list` → fallback to `/etc/proxmox-backup/user.cfg` diff --git a/docs/TROUBLESHOOTING.md b/docs/TROUBLESHOOTING.md index dda54ba..4290e98 100644 --- a/docs/TROUBLESHOOTING.md +++ b/docs/TROUBLESHOOTING.md @@ -571,6 +571,22 @@ If Email is enabled but you don't see it being dispatched, ensure `EMAIL_DELIVER - Relay blocks `root@…` recipients; use a real non-root mailbox for `EMAIL_RECIPIENT`. - If `EMAIL_FALLBACK_SENDMAIL=true`, ProxSave will fall back to `EMAIL_DELIVERY_METHOD=pmf` when the relay fails. - Check the proxsave logs for `email-relay` warnings/errors. +- `Email relay accepted request ...` means the relay accepted the submission. It does **not** guarantee final inbox delivery; later provider-side failures/bounces are outside the ProxSave process. + +Common relay auth/forbidden errors: + +- `authentication failed (HTTP 401): missing bearer token`: the relay did not receive the `Authorization: Bearer ...` header. +- `authentication failed (HTTP 401): missing signature`: the relay did not receive the `X-Signature` header. +- `forbidden (HTTP 403): invalid token`: the bearer token is wrong or not allowed by the worker. +- `forbidden (HTTP 403): HMAC signature validation failed`: the request body and `X-Signature` do not match the worker's `HMAC_SECRET`. +- `forbidden (HTTP 403): missing or invalid script version`: the relay rejected `X-Script-Version` (it must be semantic-version-like, e.g. `1.2.3`). +- `forbidden (HTTP 403): from address override not allowed`: the client attempted to override the worker-managed sender address. + +If you operate your own relay worker: + +- The worker-side env var `MAC_LIMIT_IP_WHITELIST` can bypass the per-server daily MAC quota for trusted source IPs. +- Example: `MAC_LIMIT_IP_WHITELIST=86.56.17.99` +- This bypass affects only the daily MAC quota. It does **not** disable bearer-token checks, HMAC validation, IP burst limits, or token limits. Quick checks for auto-detect: diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 0f2ec7c..97eef63 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -11,6 +11,24 @@ import ( func setBaseDirEnv(t *testing.T, value string) { t.Helper() + if value == "" { + original, hadOriginal := os.LookupEnv("BASE_DIR") + if err := os.Unsetenv("BASE_DIR"); err != nil { + t.Fatalf("Unsetenv(BASE_DIR) failed: %v", err) + } + t.Cleanup(func() { + if hadOriginal { + if err := os.Setenv("BASE_DIR", original); err != nil { + t.Fatalf("restore BASE_DIR failed: %v", err) + } + return + } + if err := os.Unsetenv("BASE_DIR"); err != nil { + t.Fatalf("cleanup Unsetenv(BASE_DIR) failed: %v", err) + } + }) + return + } t.Setenv("BASE_DIR", value) } diff --git a/internal/config/env_mutation.go b/internal/config/env_mutation.go index 5ade607..1fca9a8 100644 --- a/internal/config/env_mutation.go +++ b/internal/config/env_mutation.go @@ -1,19 +1,38 @@ package config import ( + "strconv" "strings" "github.com/tis24dev/proxsave/pkg/utils" ) +func encodeEnvValue(value string) string { + value = strings.TrimSpace(value) + if value == "" { + return "" + } + + if !strings.ContainsAny(value, "# \t\r\n\"'") { + return value + } + if !strings.Contains(value, "'") { + return "'" + value + "'" + } + if !strings.Contains(value, `"`) { + return `"` + value + `"` + } + return strconv.Quote(value) +} + // ApplySecondaryStorageSettings writes the canonical secondary-storage state // into an env template. Disabled secondary storage always clears both // SECONDARY_PATH and SECONDARY_LOG_PATH so the saved config matches user intent. func ApplySecondaryStorageSettings(template string, enabled bool, secondaryPath string, secondaryLogPath string) string { if enabled { template = utils.SetEnvValue(template, "SECONDARY_ENABLED", "true") - template = utils.SetEnvValue(template, "SECONDARY_PATH", strings.TrimSpace(secondaryPath)) - template = utils.SetEnvValue(template, "SECONDARY_LOG_PATH", strings.TrimSpace(secondaryLogPath)) + template = utils.SetEnvValue(template, "SECONDARY_PATH", encodeEnvValue(secondaryPath)) + template = utils.SetEnvValue(template, "SECONDARY_LOG_PATH", encodeEnvValue(secondaryLogPath)) return template } diff --git a/internal/config/env_mutation_test.go b/internal/config/env_mutation_test.go index 6b434b6..abd5ffd 100644 --- a/internal/config/env_mutation_test.go +++ b/internal/config/env_mutation_test.go @@ -1,74 +1,112 @@ package config import ( + "os" + "path/filepath" "strings" "testing" ) -func TestApplySecondaryStorageSettingsEnabled(t *testing.T) { - template := "SECONDARY_ENABLED=false\nSECONDARY_PATH=\nSECONDARY_LOG_PATH=\n" +func parseMutatedEnvTemplate(t *testing.T, template string) (map[string]string, map[string]int) { + t.Helper() - got := ApplySecondaryStorageSettings(template, true, " /mnt/secondary ", " /mnt/secondary/log ") + values := make(map[string]string) + counts := make(map[string]int) + + for _, line := range strings.Split(template, "\n") { + trimmed := strings.TrimSpace(line) + if trimmed == "" || strings.HasPrefix(trimmed, "#") { + continue + } - for _, needle := range []string{ - "SECONDARY_ENABLED=true", - "SECONDARY_PATH=/mnt/secondary", - "SECONDARY_LOG_PATH=/mnt/secondary/log", - } { - if !strings.Contains(got, needle) { - t.Fatalf("expected %q in template:\n%s", needle, got) + parts := strings.SplitN(line, "=", 2) + if len(parts) != 2 { + t.Fatalf("invalid env line %q in template:\n%s", line, template) } + + key := strings.TrimSpace(parts[0]) + value := parts[1] + counts[key]++ + values[key] = value } + + return values, counts +} + +func assertMutatedEnvValue(t *testing.T, values map[string]string, counts map[string]int, key, want string) { + t.Helper() + + if got := counts[key]; got != 1 { + t.Fatalf("%s occurrences = %d; want 1", key, got) + } + if got := values[key]; got != want { + t.Fatalf("%s = %q; want %q", key, got, want) + } +} + +func TestApplySecondaryStorageSettingsEnabled(t *testing.T) { + template := "SECONDARY_ENABLED=false\nSECONDARY_PATH=\nSECONDARY_LOG_PATH=\n" + + got := ApplySecondaryStorageSettings(template, true, " /mnt/secondary ", " /mnt/secondary/log ") + values, counts := parseMutatedEnvTemplate(t, got) + assertMutatedEnvValue(t, values, counts, "SECONDARY_ENABLED", "true") + assertMutatedEnvValue(t, values, counts, "SECONDARY_PATH", "/mnt/secondary") + assertMutatedEnvValue(t, values, counts, "SECONDARY_LOG_PATH", "/mnt/secondary/log") } func TestApplySecondaryStorageSettingsEnabledWithEmptyLogPath(t *testing.T) { template := "SECONDARY_ENABLED=false\nSECONDARY_PATH=\nSECONDARY_LOG_PATH=/old/log\n" got := ApplySecondaryStorageSettings(template, true, "/mnt/secondary", "") - - for _, needle := range []string{ - "SECONDARY_ENABLED=true", - "SECONDARY_PATH=/mnt/secondary", - "SECONDARY_LOG_PATH=", - } { - if !strings.Contains(got, needle) { - t.Fatalf("expected %q in template:\n%s", needle, got) - } - } + values, counts := parseMutatedEnvTemplate(t, got) + assertMutatedEnvValue(t, values, counts, "SECONDARY_ENABLED", "true") + assertMutatedEnvValue(t, values, counts, "SECONDARY_PATH", "/mnt/secondary") + assertMutatedEnvValue(t, values, counts, "SECONDARY_LOG_PATH", "") } func TestApplySecondaryStorageSettingsDisabledClearsValues(t *testing.T) { template := "SECONDARY_ENABLED=true\nSECONDARY_PATH=/mnt/old-secondary\nSECONDARY_LOG_PATH=/mnt/old-secondary/logs\n" got := ApplySecondaryStorageSettings(template, false, "/ignored", "/ignored/logs") - - for _, needle := range []string{ - "SECONDARY_ENABLED=false", - "SECONDARY_PATH=", - "SECONDARY_LOG_PATH=", - } { - if !strings.Contains(got, needle) { - t.Fatalf("expected %q in template:\n%s", needle, got) - } - } - if strings.Contains(got, "/mnt/old-secondary") { - t.Fatalf("expected old secondary values to be cleared:\n%s", got) - } + values, counts := parseMutatedEnvTemplate(t, got) + assertMutatedEnvValue(t, values, counts, "SECONDARY_ENABLED", "false") + assertMutatedEnvValue(t, values, counts, "SECONDARY_PATH", "") + assertMutatedEnvValue(t, values, counts, "SECONDARY_LOG_PATH", "") } func TestApplySecondaryStorageSettingsDisabledAppendsCanonicalState(t *testing.T) { template := "BACKUP_ENABLED=true\n" got := ApplySecondaryStorageSettings(template, false, "", "") + values, counts := parseMutatedEnvTemplate(t, got) + assertMutatedEnvValue(t, values, counts, "BACKUP_ENABLED", "true") + assertMutatedEnvValue(t, values, counts, "SECONDARY_ENABLED", "false") + assertMutatedEnvValue(t, values, counts, "SECONDARY_PATH", "") + assertMutatedEnvValue(t, values, counts, "SECONDARY_LOG_PATH", "") +} - for _, needle := range []string{ - "BACKUP_ENABLED=true", - "SECONDARY_ENABLED=false", - "SECONDARY_PATH=", - "SECONDARY_LOG_PATH=", - } { - if !strings.Contains(got, needle) { - t.Fatalf("expected %q in template:\n%s", needle, got) - } +func TestApplySecondaryStorageSettingsQuotesUnsafePaths(t *testing.T) { + template := "SECONDARY_ENABLED=false\nSECONDARY_PATH=\nSECONDARY_LOG_PATH=\n" + + got := ApplySecondaryStorageSettings(template, true, " /mnt/secondary #1 ", " /mnt/secondary/log dir ") + values, counts := parseMutatedEnvTemplate(t, got) + assertMutatedEnvValue(t, values, counts, "SECONDARY_ENABLED", "true") + assertMutatedEnvValue(t, values, counts, "SECONDARY_PATH", "'/mnt/secondary #1'") + assertMutatedEnvValue(t, values, counts, "SECONDARY_LOG_PATH", "'/mnt/secondary/log dir'") + + configPath := filepath.Join(t.TempDir(), "backup.env") + if err := os.WriteFile(configPath, []byte(got), 0o600); err != nil { + t.Fatalf("write config: %v", err) + } + + raw, err := parseEnvFile(configPath) + if err != nil { + t.Fatalf("parseEnvFile() error = %v", err) + } + if gotPath := raw["SECONDARY_PATH"]; gotPath != "/mnt/secondary #1" { + t.Fatalf("SECONDARY_PATH = %q; want %q", gotPath, "/mnt/secondary #1") + } + if gotLogPath := raw["SECONDARY_LOG_PATH"]; gotLogPath != "/mnt/secondary/log dir" { + t.Fatalf("SECONDARY_LOG_PATH = %q; want %q", gotLogPath, "/mnt/secondary/log dir") } } diff --git a/internal/config/migration_test.go b/internal/config/migration_test.go index 5457aaa..a4bd4fb 100644 --- a/internal/config/migration_test.go +++ b/internal/config/migration_test.go @@ -167,6 +167,32 @@ BACKUP_USER=backup BACKUP_GROUP=backup ` +func overrideTemplateValues(template string, overrides map[string]string) string { + lines := strings.Split(template, "\n") + applied := make(map[string]bool, len(overrides)) + + for i, line := range lines { + key, _, ok := strings.Cut(line, "=") + if !ok { + continue + } + value, ok := overrides[key] + if !ok { + continue + } + lines[i] = key + "=" + value + applied[key] = true + } + + for key := range overrides { + if !applied[key] { + panic("template key not found: " + key) + } + } + + return strings.Join(lines, "\n") +} + func TestMigrateLegacyEnvCreatesConfigAndKeepsValues(t *testing.T) { withTemplate(t, baseInstallTemplate, func() { tmpDir := t.TempDir() @@ -249,15 +275,11 @@ func TestMigrateLegacyEnvCreatesBackupWhenOverwriting(t *testing.T) { }) } -const invalidPermissionsTemplate = `BACKUP_ENABLED=true -BACKUP_PATH=/default/backup -LOG_PATH=/default/log -SECONDARY_ENABLED=false -CLOUD_ENABLED=false -SET_BACKUP_PERMISSIONS=true -BACKUP_USER= -BACKUP_GROUP= -` +var invalidPermissionsTemplate = overrideTemplateValues(baseInstallTemplate, map[string]string{ + "SET_BACKUP_PERMISSIONS": "true", + "BACKUP_USER": "", + "BACKUP_GROUP": "", +}) func TestMigrateLegacyEnvRollsBackOnValidationFailure(t *testing.T) { withTemplate(t, invalidPermissionsTemplate, func() { diff --git a/internal/identity/identity.go b/internal/identity/identity.go index 2e2621e..db1f742 100644 --- a/internal/identity/identity.go +++ b/internal/identity/identity.go @@ -39,8 +39,11 @@ type Info struct { } var ( - hostnameFunc = os.Hostname - readFirstLineFunc = readFirstLine + hostnameFunc = os.Hostname + readFirstLineFunc = readFirstLine + writeIdentityFileWithContextWriteFile = os.WriteFile + writeIdentityFileWithContextChmod = os.Chmod + writeIdentityFileWithContextSetImmutable = setImmutableAttributeWithContext ) // Detect resolves the server identity (ID + MAC address) and ensures persistence. @@ -810,28 +813,39 @@ func writeIdentityFile(path, content string, logger *logging.Logger) error { return writeIdentityFileWithContext(context.Background(), path, content, logger) } -func writeIdentityFileWithContext(ctx context.Context, path, content string, logger *logging.Logger) error { +func writeIdentityFileWithContext(ctx context.Context, path, content string, logger *logging.Logger) (err error) { if ctx == nil { ctx = context.Background() } logDebug(logger, "Identity: writeIdentityFile: start path=%s contentBytes=%d", path, len(content)) // Ensure file is writable even if immutable was previously set - if err := setImmutableAttributeWithContext(ctx, path, false, logger); err != nil { + if err := writeIdentityFileWithContextSetImmutable(ctx, path, false, logger); err != nil { return err } + defer func() { + lockErr := writeIdentityFileWithContextSetImmutable(context.Background(), path, true, logger) + if lockErr == nil { + return + } + logDebug(logger, "Identity: writeIdentityFile: failed to restore immutable attribute: %v", lockErr) + if err == nil { + err = lockErr + } + }() - if err := os.WriteFile(path, []byte(content), 0o600); err != nil { - logDebug(logger, "Identity: writeIdentityFile: os.WriteFile failed: %v", err) + if err := ctx.Err(); err != nil { + logDebug(logger, "Identity: writeIdentityFile: context canceled before write for %s: %v", path, err) return err } - if err := os.Chmod(path, 0o600); err != nil { - logDebug(logger, "Identity: writeIdentityFile: os.Chmod failed: %v", err) + if err := writeIdentityFileWithContextWriteFile(path, []byte(content), 0o600); err != nil { + logDebug(logger, "Identity: writeIdentityFile: os.WriteFile failed: %v", err) return err } - if err := setImmutableAttributeWithContext(ctx, path, true, logger); err != nil { + if err := writeIdentityFileWithContextChmod(path, 0o600); err != nil { + logDebug(logger, "Identity: writeIdentityFile: os.Chmod failed: %v", err) return err } diff --git a/internal/identity/identity_test.go b/internal/identity/identity_test.go index 2171156..9ceac33 100644 --- a/internal/identity/identity_test.go +++ b/internal/identity/identity_test.go @@ -424,6 +424,135 @@ func TestWriteIdentityFileCreatesFileWith0600(t *testing.T) { } } +func TestWriteIdentityFileWithContext_RelocksOnCanceledContextBeforeWrite(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "id.conf") + const initialContent = "initial" + if err := os.WriteFile(path, []byte(initialContent), 0o600); err != nil { + t.Fatalf("seed identity file: %v", err) + } + + origSetImmutable := writeIdentityFileWithContextSetImmutable + origWriteFile := writeIdentityFileWithContextWriteFile + origChmod := writeIdentityFileWithContextChmod + t.Cleanup(func() { + writeIdentityFileWithContextSetImmutable = origSetImmutable + writeIdentityFileWithContextWriteFile = origWriteFile + writeIdentityFileWithContextChmod = origChmod + }) + + ctx, cancel := context.WithCancel(context.Background()) + type immutableCall struct { + ctx context.Context + enable bool + } + var calls []immutableCall + writeIdentityFileWithContextSetImmutable = func(callCtx context.Context, path string, enable bool, logger *logging.Logger) error { + calls = append(calls, immutableCall{ctx: callCtx, enable: enable}) + if !enable { + cancel() + } + return nil + } + writeIdentityFileWithContextWriteFile = func(path string, data []byte, perm os.FileMode) error { + t.Fatal("writeIdentityFileWithContext should not write after context cancellation") + return nil + } + writeIdentityFileWithContextChmod = func(path string, mode os.FileMode) error { + t.Fatal("writeIdentityFileWithContext should not chmod after context cancellation") + return nil + } + + err := writeIdentityFileWithContext(ctx, path, "updated", nil) + if !errors.Is(err, context.Canceled) { + t.Fatalf("err=%v; want %v", err, context.Canceled) + } + if len(calls) != 2 { + t.Fatalf("immutable call count = %d, want 2", len(calls)) + } + if calls[0].ctx != ctx || calls[0].enable { + t.Fatalf("first immutable call = %+v, want unlock with original ctx", calls[0]) + } + if !calls[1].enable { + t.Fatalf("second immutable call = %+v, want relock", calls[1]) + } + if calls[1].ctx == ctx { + t.Fatalf("expected relock to use non-cancelable context") + } + if calls[1].ctx.Err() != nil { + t.Fatalf("expected relock context to be active, got %v", calls[1].ctx.Err()) + } + + data, err := os.ReadFile(path) + if err != nil { + t.Fatalf("read identity file: %v", err) + } + if string(data) != initialContent { + t.Fatalf("file content = %q, want %q", string(data), initialContent) + } +} + +func TestWriteIdentityFileWithContext_RelocksOnWriteError(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "id.conf") + const initialContent = "initial" + if err := os.WriteFile(path, []byte(initialContent), 0o600); err != nil { + t.Fatalf("seed identity file: %v", err) + } + + origSetImmutable := writeIdentityFileWithContextSetImmutable + origWriteFile := writeIdentityFileWithContextWriteFile + origChmod := writeIdentityFileWithContextChmod + t.Cleanup(func() { + writeIdentityFileWithContextSetImmutable = origSetImmutable + writeIdentityFileWithContextWriteFile = origWriteFile + writeIdentityFileWithContextChmod = origChmod + }) + + type immutableCall struct { + ctx context.Context + enable bool + } + var calls []immutableCall + writeIdentityFileWithContextSetImmutable = func(callCtx context.Context, path string, enable bool, logger *logging.Logger) error { + calls = append(calls, immutableCall{ctx: callCtx, enable: enable}) + return nil + } + writeErr := errors.New("write failed") + writeIdentityFileWithContextWriteFile = func(path string, data []byte, perm os.FileMode) error { + return writeErr + } + writeIdentityFileWithContextChmod = func(path string, mode os.FileMode) error { + t.Fatal("writeIdentityFileWithContext should not chmod after write failure") + return nil + } + + err := writeIdentityFileWithContext(context.Background(), path, "updated", nil) + if !errors.Is(err, writeErr) { + t.Fatalf("err=%v; want %v", err, writeErr) + } + if len(calls) != 2 { + t.Fatalf("immutable call count = %d, want 2", len(calls)) + } + if calls[0].enable { + t.Fatalf("first immutable call = %+v, want unlock", calls[0]) + } + if !calls[1].enable { + t.Fatalf("second immutable call = %+v, want relock", calls[1]) + } + if calls[1].ctx.Err() != nil { + t.Fatalf("expected relock context to be active, got %v", calls[1].ctx.Err()) + } + + data, err := os.ReadFile(path) + if err != nil { + t.Fatalf("read identity file: %v", err) + } + if string(data) != initialContent { + t.Fatalf("file content = %q, want %q", string(data), initialContent) + } +} + func TestHexToDecimalValidAndInvalid(t *testing.T) { if got := hexToDecimal("ff"); got != "255" { t.Fatalf("hexToDecimal(\"ff\") = %q, want %q", got, "255") diff --git a/internal/notify/email_relay.go b/internal/notify/email_relay.go index 8d99bd9..dd17a19 100644 --- a/internal/notify/email_relay.go +++ b/internal/notify/email_relay.go @@ -52,6 +52,7 @@ type EmailRelayPayload struct { // EmailRelayResponse represents the response from the cloud worker type EmailRelayResponse struct { Success bool `json:"success"` + Code string `json:"code,omitempty"` Message string `json:"message,omitempty"` Error string `json:"error,omitempty"` } @@ -139,26 +140,44 @@ func sendViaCloudRelay( continue } + // Log raw response body for all status codes (aids future diagnosis) + logger.Debug("Cloud relay: HTTP %d response (%d bytes): %s", resp.StatusCode, len(body), string(body)) + // Handle HTTP status codes switch resp.StatusCode { case 200: - // Success + // Parse response body to verify actual delivery success. + // Empty body or non-JSON is treated as success for backward compatibility. + if len(body) > 0 { + var apiResp EmailRelayResponse + if err := json.Unmarshal(body, &apiResp); err == nil && !apiResp.Success { + errMsg := apiResp.Error + if errMsg == "" { + errMsg = apiResp.Message + } + if errMsg == "" { + errMsg = "relay returned success=false with no details" + } + return fmt.Errorf("cloud relay rejected email (HTTP 200 but success=false): %s", errMsg) + } + } logger.Debug("Cloud relay: email sent successfully") return nil case 400: // Bad request - don't retry - var apiResp EmailRelayResponse - _ = json.Unmarshal(body, &apiResp) - return fmt.Errorf("bad request (HTTP 400): %s", apiResp.Error) + apiResp := parseRelayResponse(body) + return fmt.Errorf("bad request (HTTP 400): %s", relayResponseDetail(apiResp, "bad request")) case 401: // Authentication failed - don't retry - return fmt.Errorf("authentication failed (HTTP 401): invalid token") + apiResp := parseRelayResponse(body) + return fmt.Errorf("authentication failed (HTTP 401): %s", classifyUnauthorizedRelayError(apiResp)) case 403: - // Forbidden - HMAC validation failed - don't retry - return fmt.Errorf("forbidden (HTTP 403): HMAC signature validation failed") + // Forbidden - classify specific cause when available. + apiResp := parseRelayResponse(body) + return fmt.Errorf("forbidden (HTTP 403): %s", classifyForbiddenRelayError(apiResp)) case 429: // Rate limit exceeded - show detailed message @@ -218,6 +237,75 @@ func generateHMACSignature(payload []byte, secret string) string { return hex.EncodeToString(h.Sum(nil)) } +func parseRelayResponse(body []byte) EmailRelayResponse { + var apiResp EmailRelayResponse + _ = json.Unmarshal(body, &apiResp) + return apiResp +} + +func relayResponseDetail(apiResp EmailRelayResponse, fallback string) string { + detail := strings.TrimSpace(apiResp.Error) + if detail == "" { + detail = strings.TrimSpace(apiResp.Message) + } + if detail == "" { + detail = fallback + } + return detail +} + +func classifyUnauthorizedRelayError(apiResp EmailRelayResponse) string { + switch strings.TrimSpace(apiResp.Code) { + case "MISSING_SIGNATURE": + return "missing signature" + case "MISSING_TOKEN": + return "missing bearer token" + } + + detail := relayResponseDetail(apiResp, "unauthorized") + lower := strings.ToLower(detail) + + switch { + case strings.Contains(lower, "missing signature"): + return "missing signature" + case strings.Contains(lower, "missing bearer token"): + return "missing bearer token" + default: + return detail + } +} + +func classifyForbiddenRelayError(apiResp EmailRelayResponse) string { + switch strings.TrimSpace(apiResp.Code) { + case "INVALID_SIGNATURE": + return "HMAC signature validation failed" + case "INVALID_TOKEN": + return "invalid token" + case "INVALID_SCRIPT_VERSION": + return "missing or invalid script version" + case "FROM_OVERRIDE_ATTEMPT": + return "from address override not allowed" + } + + detail := relayResponseDetail(apiResp, "forbidden") + lower := strings.ToLower(detail) + + switch { + case lower == "forbidden": + return "invalid token" + case strings.Contains(lower, "signature"): + return "HMAC signature validation failed" + case strings.Contains(lower, "script version"): + return "missing or invalid script version" + case strings.Contains(lower, "from address override"): + return "from address override not allowed" + case strings.Contains(lower, "invalid token"): + return "invalid token" + default: + return detail + } +} + // isQuotaLimit returns true if the rate-limit detail clearly indicates a quota cap // (e.g., daily per-server quota) that won't succeed with retries. func isQuotaLimit(detail string) bool { diff --git a/internal/notify/email_relay_test.go b/internal/notify/email_relay_test.go index 97468d4..94a9b14 100644 --- a/internal/notify/email_relay_test.go +++ b/internal/notify/email_relay_test.go @@ -8,6 +8,7 @@ import ( "errors" "net/http" "net/http/httptest" + "strings" "sync/atomic" "testing" "time" @@ -112,12 +113,13 @@ func TestFormatCloudPathDisplay(t *testing.T) { func TestSendViaCloudRelay_StatusHandling(t *testing.T) { type testCase struct { - name string - statusCode int - body string - config CloudRelayConfig - expectErr bool - expectCalls int + name string + statusCode int + body string + config CloudRelayConfig + expectErr bool + expectErrContains string + expectCalls int } baseConfig := CloudRelayConfig{ @@ -133,10 +135,17 @@ func TestSendViaCloudRelay_StatusHandling(t *testing.T) { tests := []testCase{ {name: "200-success", statusCode: 200, body: `{"success":true}`, config: baseConfig, expectErr: false, expectCalls: 1}, - {name: "400-bad-request", statusCode: 400, body: `{"error":"bad payload"}`, config: baseConfig, expectErr: true, expectCalls: 1}, - {name: "401-auth", statusCode: 401, body: ``, config: baseConfig, expectErr: true, expectCalls: 1}, - {name: "403-forbidden", statusCode: 403, body: ``, config: baseConfig, expectErr: true, expectCalls: 1}, - {name: "429-quota", statusCode: 429, body: `{"message":"quota per server exceeded"}`, config: baseConfig, expectErr: true, expectCalls: 1}, + {name: "200-success-false", statusCode: 200, body: `{"success":false,"error":"downstream email service failed"}`, config: baseConfig, expectErr: true, expectErrContains: "downstream email service failed", expectCalls: 1}, + {name: "200-empty-body", statusCode: 200, body: ``, config: baseConfig, expectErr: false, expectCalls: 1}, + {name: "200-invalid-json", statusCode: 200, body: `OK`, config: baseConfig, expectErr: false, expectCalls: 1}, + {name: "400-bad-request", statusCode: 400, body: `{"error":"bad payload"}`, config: baseConfig, expectErr: true, expectErrContains: "bad payload", expectCalls: 1}, + {name: "401-missing-signature-code", statusCode: 401, body: `{"code":"MISSING_SIGNATURE","error":"Missing signature"}`, config: baseConfig, expectErr: true, expectErrContains: "missing signature", expectCalls: 1}, + {name: "403-invalid-signature-code", statusCode: 403, body: `{"code":"INVALID_SIGNATURE","error":"Invalid signature"}`, config: baseConfig, expectErr: true, expectErrContains: "HMAC signature validation failed", expectCalls: 1}, + {name: "403-invalid-token-code", statusCode: 403, body: `{"code":"INVALID_TOKEN","error":"Invalid token"}`, config: baseConfig, expectErr: true, expectErrContains: "invalid token", expectCalls: 1}, + {name: "403-invalid-script-version-code", statusCode: 403, body: `{"code":"INVALID_SCRIPT_VERSION","error":"Missing or invalid script version"}`, config: baseConfig, expectErr: true, expectErrContains: "script version", expectCalls: 1}, + {name: "403-from-override-code", statusCode: 403, body: `{"code":"FROM_OVERRIDE_ATTEMPT","error":"From address override not allowed"}`, config: baseConfig, expectErr: true, expectErrContains: "from address override not allowed", expectCalls: 1}, + {name: "403-legacy-forbidden", statusCode: 403, body: `{"error":"Forbidden"}`, config: baseConfig, expectErr: true, expectErrContains: "invalid token", expectCalls: 1}, + {name: "429-quota", statusCode: 429, body: `{"message":"quota per server exceeded"}`, config: baseConfig, expectErr: true, expectErrContains: "rate limit exceeded", expectCalls: 1}, } for _, tt := range tests { @@ -169,6 +178,11 @@ func TestSendViaCloudRelay_StatusHandling(t *testing.T) { if !tt.expectErr && err != nil { t.Fatalf("unexpected error: %v", err) } + if tt.expectErrContains != "" { + if err == nil || !strings.Contains(err.Error(), tt.expectErrContains) { + t.Fatalf("expected error containing %q, got %v", tt.expectErrContains, err) + } + } if callCount != tt.expectCalls { t.Fatalf("expected %d calls, got %d", tt.expectCalls, callCount) } diff --git a/internal/notify/webhook_test.go b/internal/notify/webhook_test.go index 841c828..8550333 100644 --- a/internal/notify/webhook_test.go +++ b/internal/notify/webhook_test.go @@ -8,6 +8,7 @@ import ( "net/http" "net/http/httptest" "strings" + "sync/atomic" "testing" "time" @@ -280,15 +281,18 @@ func TestWebhookNotifier_Send_Success(t *testing.T) { func TestWebhookNotifier_SendToEndpoint_StopsRetryingWhenContextCanceled(t *testing.T) { logger := logging.New(types.LogLevelDebug, false) - attempts := 0 ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + var attempts atomic.Int32 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - attempts++ - if attempts == 1 { - cancel() + if attempts.Add(1) == 1 { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(`{"error":"temporary"}`)) + time.AfterFunc(10*time.Millisecond, cancel) + return } w.WriteHeader(http.StatusInternalServerError) - w.Write([]byte(`{"error":"temporary"}`)) + _, _ = w.Write([]byte(`{"error":"temporary"}`)) })) defer server.Close() @@ -297,7 +301,7 @@ func TestWebhookNotifier_SendToEndpoint_StopsRetryingWhenContextCanceled(t *test DefaultFormat: "generic", Timeout: 30, MaxRetries: 3, - RetryDelay: 0, + RetryDelay: 1, Endpoints: []config.WebhookEndpoint{ { Name: "test-webhook", @@ -318,8 +322,8 @@ func TestWebhookNotifier_SendToEndpoint_StopsRetryingWhenContextCanceled(t *test if !errors.Is(err, context.Canceled) { t.Fatalf("expected context cancellation error, got %v", err) } - if attempts != 1 { - t.Fatalf("expected 1 attempt after cancellation, got %d", attempts) + if got := attempts.Load(); got != 1 { + t.Fatalf("expected 1 attempt after cancellation, got %d", got) } } diff --git a/internal/orchestrator/decrypt_tui_e2e_helpers_test.go b/internal/orchestrator/decrypt_tui_e2e_helpers_test.go index e7b78c1..84cf85d 100644 --- a/internal/orchestrator/decrypt_tui_e2e_helpers_test.go +++ b/internal/orchestrator/decrypt_tui_e2e_helpers_test.go @@ -10,6 +10,7 @@ import ( "io" "os" "path/filepath" + "strings" "sync" "testing" "time" @@ -26,11 +27,31 @@ import ( var decryptTUIE2EMu sync.Mutex +type notifyingSimulationScreen struct { + tcell.SimulationScreen + notify func() +} + +func (s *notifyingSimulationScreen) Show() { + s.SimulationScreen.Show() + if s.notify != nil { + s.notify() + } +} + +func (s *notifyingSimulationScreen) Sync() { + s.SimulationScreen.Sync() + if s.notify != nil { + s.notify() + } +} + type timedSimKey struct { - Key tcell.Key - R rune - Mod tcell.ModMask - Wait time.Duration + Key tcell.Key + R rune + Mod tcell.ModMask + Wait time.Duration + WaitForText string } type decryptTUIFixture struct { @@ -61,34 +82,86 @@ func withTimedSimAppSequence(t *testing.T, keys []timedSimKey) { decryptTUIE2EMu.Unlock() }) - screen := tcell.NewSimulationScreen("UTF-8") - if err := screen.Init(); err != nil { + baseScreen := tcell.NewSimulationScreen("UTF-8") + if err := baseScreen.Init(); err != nil { t.Fatalf("screen.Init: %v", err) } - screen.SetSize(120, 40) + baseScreen.SetSize(120, 40) + + type timedSimScreenState struct { + signature string + text string + } + + screenStateCh := make(chan struct{}, 1) + var appMu sync.RWMutex + var currentApp *tui.App + screen := ¬ifyingSimulationScreen{ + SimulationScreen: baseScreen, + notify: func() { + select { + case screenStateCh <- struct{}{}: + default: + } + }, + } var once sync.Once newTUIApp = func() *tui.App { app := tui.NewApp() + appMu.Lock() + currentApp = app + appMu.Unlock() app.SetScreen(screen) once.Do(func() { injectWG.Add(1) go func() { defer injectWG.Done() + var lastInjectedState string + + currentScreenState := func() timedSimScreenState { + appMu.RLock() + app := currentApp + appMu.RUnlock() + + var focus any + if app != nil { + focus = app.GetFocus() + } + + return timedSimScreenState{ + signature: timedSimScreenStateSignature(screen, focus), + text: timedSimScreenText(screen), + } + } + + waitForScreenText := func(expected string) bool { + expected = strings.TrimSpace(expected) + for { + current := currentScreenState() + if current.signature != "" { + if (expected == "" || strings.Contains(current.text, expected)) && + (lastInjectedState == "" || current.signature != lastInjectedState) { + return true + } + } - for _, k := range keys { - if k.Wait > 0 { - timer := time.NewTimer(k.Wait) select { case <-done: - if !timer.Stop() { - <-timer.C - } + return false + case <-screenStateCh: + } + } + } + + for _, k := range keys { + if k.Wait > 0 { + if !waitForScreenText(k.WaitForText) { return - case <-timer.C: } } + current := currentScreenState() mod := k.Mod if mod == 0 { mod = tcell.ModNone @@ -99,6 +172,7 @@ func withTimedSimAppSequence(t *testing.T, keys []timedSimKey) { default: } screen.InjectKey(k.Key, k.R, mod) + lastInjectedState = current.signature } }() }) @@ -107,6 +181,42 @@ func withTimedSimAppSequence(t *testing.T, keys []timedSimKey) { } } +func timedSimScreenStateSignature(screen tcell.SimulationScreen, focus any) string { + cells, width, height := screen.GetContents() + cursorX, cursorY, cursorVisible := screen.GetCursor() + + sum := sha256.New() + fmt.Fprintf(sum, "size:%d:%d cursor:%d:%d:%t focus:%T:%p\n", width, height, cursorX, cursorY, cursorVisible, focus, focus) + for _, cell := range cells { + fg, bg, attr := cell.Style.Decompose() + fmt.Fprintf(sum, "%x/%d/%d/%d;", cell.Bytes, fg, bg, attr) + } + return hex.EncodeToString(sum.Sum(nil)) +} + +func timedSimScreenText(screen tcell.SimulationScreen) string { + cells, width, height := screen.GetContents() + if width <= 0 || height <= 0 || len(cells) < width*height { + return "" + } + + var b strings.Builder + for y := 0; y < height; y++ { + row := make([]byte, 0, width) + for x := 0; x < width; x++ { + cell := cells[y*width+x] + if len(cell.Bytes) == 0 { + row = append(row, ' ') + continue + } + row = append(row, cell.Bytes...) + } + b.WriteString(strings.TrimRight(string(row), " ")) + b.WriteByte('\n') + } + return b.String() +} + func createDecryptTUIEncryptedFixture(t *testing.T) *decryptTUIFixture { t.Helper() @@ -202,23 +312,24 @@ func createDecryptTUIEncryptedFixture(t *testing.T) *decryptTUIFixture { func successDecryptTUISequence(secret string) []timedSimKey { keys := []timedSimKey{ - {Key: tcell.KeyEnter, Wait: 1 * time.Second}, - {Key: tcell.KeyEnter, Wait: 750 * time.Millisecond}, + {Key: tcell.KeyEnter, Wait: 1 * time.Second, WaitForText: "Select backup source"}, + {Key: tcell.KeyEnter, Wait: 750 * time.Millisecond, WaitForText: "Select backup"}, } for _, r := range secret { keys = append(keys, timedSimKey{ - Key: tcell.KeyRune, - R: r, - Wait: 35 * time.Millisecond, + Key: tcell.KeyRune, + R: r, + Wait: 35 * time.Millisecond, + WaitForText: "Decrypt key", }) } keys = append(keys, - timedSimKey{Key: tcell.KeyTab, Wait: 150 * time.Millisecond}, - timedSimKey{Key: tcell.KeyEnter, Wait: 100 * time.Millisecond}, - timedSimKey{Key: tcell.KeyTab, Wait: 500 * time.Millisecond}, - timedSimKey{Key: tcell.KeyEnter, Wait: 100 * time.Millisecond}, + timedSimKey{Key: tcell.KeyTab, Wait: 150 * time.Millisecond, WaitForText: "Decrypt key"}, + timedSimKey{Key: tcell.KeyEnter, Wait: 100 * time.Millisecond, WaitForText: "Decrypt key"}, + timedSimKey{Key: tcell.KeyTab, Wait: 500 * time.Millisecond, WaitForText: "Destination directory"}, + timedSimKey{Key: tcell.KeyEnter, Wait: 100 * time.Millisecond, WaitForText: "Destination directory"}, ) return keys @@ -226,11 +337,11 @@ func successDecryptTUISequence(secret string) []timedSimKey { func abortDecryptTUISequence() []timedSimKey { return []timedSimKey{ - {Key: tcell.KeyEnter, Wait: 1 * time.Second}, - {Key: tcell.KeyEnter, Wait: 750 * time.Millisecond}, - {Key: tcell.KeyRune, R: '0', Wait: 500 * time.Millisecond}, - {Key: tcell.KeyTab, Wait: 150 * time.Millisecond}, - {Key: tcell.KeyEnter, Wait: 100 * time.Millisecond}, + {Key: tcell.KeyEnter, Wait: 1 * time.Second, WaitForText: "Select backup source"}, + {Key: tcell.KeyEnter, Wait: 750 * time.Millisecond, WaitForText: "Select backup"}, + {Key: tcell.KeyRune, R: '0', Wait: 500 * time.Millisecond, WaitForText: "Decrypt key"}, + {Key: tcell.KeyTab, Wait: 150 * time.Millisecond, WaitForText: "Decrypt key"}, + {Key: tcell.KeyEnter, Wait: 100 * time.Millisecond, WaitForText: "Decrypt key"}, } } diff --git a/internal/orchestrator/decrypt_workflow_ui.go b/internal/orchestrator/decrypt_workflow_ui.go index 7adc905..2cbd988 100644 --- a/internal/orchestrator/decrypt_workflow_ui.go +++ b/internal/orchestrator/decrypt_workflow_ui.go @@ -6,6 +6,7 @@ import ( "fmt" "os" "path/filepath" + "reflect" "strings" "filippo.io/age" @@ -15,11 +16,25 @@ import ( "github.com/tis24dev/proxsave/internal/logging" ) +func isNilInterface(v any) bool { + if v == nil { + return true + } + + rv := reflect.ValueOf(v) + switch rv.Kind() { + case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Pointer, reflect.Slice: + return rv.IsNil() + default: + return false + } +} + func selectBackupCandidateWithUI(ctx context.Context, ui BackupSelectionUI, cfg *config.Config, logger *logging.Logger, requireEncrypted bool) (candidate *decryptCandidate, err error) { done := logging.DebugStart(logger, "select backup candidate (ui)", "requireEncrypted=%v", requireEncrypted) defer func() { done(err) }() - if ui == nil { + if isNilInterface(ui) { return nil, fmt.Errorf("backup selection UI not available") } @@ -105,6 +120,10 @@ func selectBackupCandidateWithUI(ctx context.Context, ui BackupSelectionUI, cfg } func ensureWritablePathWithUI(ctx context.Context, ui DecryptWorkflowUI, targetPath, description string) (string, error) { + if isNilInterface(ui) { + return "", fmt.Errorf("decrypt workflow UI not available") + } + current := filepath.Clean(targetPath) failure := "" @@ -184,7 +203,7 @@ func preparePlainBundleWithUI(ctx context.Context, cand *decryptCandidate, versi if cand == nil || cand.Manifest == nil { return nil, fmt.Errorf("invalid backup candidate") } - if ui == nil { + if isNilInterface(ui) { return nil, fmt.Errorf("decrypt workflow UI not available") } @@ -202,7 +221,7 @@ func runDecryptWorkflowWithUI(ctx context.Context, cfg *config.Config, logger *l if logger == nil { logger = logging.GetDefaultLogger() } - if ui == nil { + if isNilInterface(ui) { return fmt.Errorf("decrypt workflow UI not available") } done := logging.DebugStart(logger, "decrypt workflow (ui)", "version=%s", version) diff --git a/internal/orchestrator/decrypt_workflow_ui_test.go b/internal/orchestrator/decrypt_workflow_ui_test.go index 27a07c7..bea77a4 100644 --- a/internal/orchestrator/decrypt_workflow_ui_test.go +++ b/internal/orchestrator/decrypt_workflow_ui_test.go @@ -273,6 +273,73 @@ func TestPreparePlainBundleWithUIRejectsMissingUI(t *testing.T) { } } +func TestSelectBackupCandidateWithUIRejectsTypedNilUI(t *testing.T) { + logger := logging.New(types.LogLevelError, false) + cfg := &config.Config{} + + var typedNil *fakeDecryptWorkflowUI + var ui BackupSelectionUI = typedNil + + _, err := selectBackupCandidateWithUI(context.Background(), ui, cfg, logger, false) + if err == nil { + t.Fatal("expected error for typed-nil UI") + } + if got, want := err.Error(), "backup selection UI not available"; got != want { + t.Fatalf("error=%q, want %q", got, want) + } +} + +func TestEnsureWritablePathWithUIRejectsTypedNilUI(t *testing.T) { + var typedNil *fakeDecryptWorkflowUI + var ui DecryptWorkflowUI = typedNil + + _, err := ensureWritablePathWithUI(context.Background(), ui, mustCreateExistingFile(t), "bundle") + if err == nil { + t.Fatal("expected error for typed-nil UI") + } + if got, want := err.Error(), "decrypt workflow UI not available"; got != want { + t.Fatalf("error=%q, want %q", got, want) + } +} + +func TestPreparePlainBundleWithUIRejectsTypedNilUI(t *testing.T) { + logger := logging.New(types.LogLevelError, false) + tmp := t.TempDir() + rawArchive := filepath.Join(tmp, "backup.tar") + rawMetadata := rawArchive + ".metadata" + rawChecksum := rawArchive + ".sha256" + + if err := os.WriteFile(rawArchive, []byte("payload-data"), 0o640); err != nil { + t.Fatalf("write archive: %v", err) + } + if err := os.WriteFile(rawMetadata, []byte(`{"manifest":true}`), 0o640); err != nil { + t.Fatalf("write metadata: %v", err) + } + if err := os.WriteFile(rawChecksum, checksumLineForBytes("backup.tar", []byte("payload-data")), 0o640); err != nil { + t.Fatalf("write checksum: %v", err) + } + + cand := &decryptCandidate{ + Manifest: &backup.Manifest{ + ArchivePath: rawArchive, + EncryptionMode: "none", + CreatedAt: time.Now(), + Hostname: "node1", + }, + Source: sourceRaw, + RawArchivePath: rawArchive, + RawMetadataPath: rawMetadata, + RawChecksumPath: rawChecksum, + DisplayBase: "test-backup", + } + + var typedNil *countingSecretPrompter + + if _, err := preparePlainBundleWithUI(context.Background(), cand, "1.0.0", logger, typedNil); err == nil { + t.Fatal("expected error for typed-nil UI") + } +} + func TestRunDecryptWorkflowWithUIRejectsMissingUI(t *testing.T) { logger := logging.New(types.LogLevelError, false) cfg := &config.Config{} @@ -286,6 +353,22 @@ func TestRunDecryptWorkflowWithUIRejectsMissingUI(t *testing.T) { } } +func TestRunDecryptWorkflowWithUIRejectsTypedNilUI(t *testing.T) { + logger := logging.New(types.LogLevelError, false) + cfg := &config.Config{} + + var typedNil *fakeDecryptWorkflowUI + var ui DecryptWorkflowUI = typedNil + + err := runDecryptWorkflowWithUI(context.Background(), cfg, logger, "1.0.0", ui) + if err == nil { + t.Fatal("expected error for typed-nil UI") + } + if got, want := err.Error(), "decrypt workflow UI not available"; got != want { + t.Fatalf("error=%q, want %q", got, want) + } +} + func mustCreateExistingFile(t *testing.T) string { t.Helper() diff --git a/internal/orchestrator/restore_ha_additional_test.go b/internal/orchestrator/restore_ha_additional_test.go index bca75e9..4bee40e 100644 --- a/internal/orchestrator/restore_ha_additional_test.go +++ b/internal/orchestrator/restore_ha_additional_test.go @@ -384,25 +384,32 @@ func TestDisarmHARollback_RemovesMarkerAndStopsTimer(t *testing.T) { } func TestMaybeApplyPVEHAWithUI_BranchCoverage(t *testing.T) { - env := setupHATestEnv(t) - stageWithHA := env.stageRoot + "/etc/pve/ha/resources.cfg" - if err := env.fs.AddFile(stageWithHA, []byte("res\n")); err != nil { - t.Fatalf("add staged HA config: %v", err) + newEnv := func(t *testing.T) *haTestEnv { + t.Helper() + env := setupHATestEnv(t) + stageWithHA := env.stageRoot + "/etc/pve/ha/resources.cfg" + if err := env.fs.AddFile(stageWithHA, []byte("res\n")); err != nil { + t.Fatalf("add staged HA config: %v", err) + } + return env } t.Run("nil plan returns nil", func(t *testing.T) { + env := newEnv(t) if err := maybeApplyPVEHAWithUI(context.Background(), &fakeRestoreWorkflowUI{}, env.logger, nil, nil, nil, env.stageRoot, false); err != nil { t.Fatalf("expected nil, got %v", err) } }) t.Run("errors when ui missing", func(t *testing.T) { + env := newEnv(t) if err := maybeApplyPVEHAWithUI(context.Background(), nil, env.logger, env.plan, nil, nil, env.stageRoot, false); err == nil { t.Fatalf("expected error") } }) t.Run("skips on non-system restore fs", func(t *testing.T) { + env := newEnv(t) haIsRealRestoreFS = func(fs FS) bool { return false } t.Cleanup(func() { haIsRealRestoreFS = func(fs FS) bool { return true } }) @@ -412,6 +419,7 @@ func TestMaybeApplyPVEHAWithUI_BranchCoverage(t *testing.T) { }) t.Run("dry run, non-root, empty stage and cluster restore all skip", func(t *testing.T) { + env := newEnv(t) if err := maybeApplyPVEHAWithUI(context.Background(), &fakeRestoreWorkflowUI{}, env.logger, env.plan, nil, nil, env.stageRoot, true); err != nil { t.Fatalf("expected nil on dry run, got %v", err) } @@ -434,6 +442,7 @@ func TestMaybeApplyPVEHAWithUI_BranchCoverage(t *testing.T) { }) t.Run("skips when stage has no HA config or mount unavailable", func(t *testing.T) { + env := newEnv(t) if err := maybeApplyPVEHAWithUI(context.Background(), &fakeRestoreWorkflowUI{}, env.logger, env.plan, nil, nil, "/empty", false); err != nil { t.Fatalf("expected nil when stage has no HA config, got %v", err) } @@ -446,6 +455,7 @@ func TestMaybeApplyPVEHAWithUI_BranchCoverage(t *testing.T) { }) t.Run("stage detection and initial prompt errors are propagated", func(t *testing.T) { + env := newEnv(t) restoreFS = statFailFS{ FS: env.fs, failPath: env.stageRoot + "/etc/pve/ha/resources.cfg", @@ -472,6 +482,7 @@ func TestMaybeApplyPVEHAWithUI_BranchCoverage(t *testing.T) { }) t.Run("user skips apply", func(t *testing.T) { + env := newEnv(t) ui := &scriptedRestoreWorkflowUI{ fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, script: []scriptedConfirmAction{{ok: false}}, @@ -482,6 +493,7 @@ func TestMaybeApplyPVEHAWithUI_BranchCoverage(t *testing.T) { }) t.Run("proceed without rollback applies and returns", func(t *testing.T) { + env := newEnv(t) haArmRollback = func(ctx context.Context, logger *logging.Logger, backupPath string, timeout time.Duration, workDir string) (*haRollbackHandle, error) { t.Fatalf("unexpected rollback arm") return nil, nil @@ -508,6 +520,7 @@ func TestMaybeApplyPVEHAWithUI_BranchCoverage(t *testing.T) { }) t.Run("full rollback and no rollback prompts can be declined or fail", func(t *testing.T) { + env := newEnv(t) ui := &scriptedRestoreWorkflowUI{ fakeRestoreWorkflowUI: &fakeRestoreWorkflowUI{}, script: []scriptedConfirmAction{ @@ -555,6 +568,7 @@ func TestMaybeApplyPVEHAWithUI_BranchCoverage(t *testing.T) { }) t.Run("full rollback backup is used when HA rollback backup missing", func(t *testing.T) { + env := newEnv(t) markerPath := "/tmp/proxsave/ha-full.marker" disarmed := false haArmRollback = func(ctx context.Context, logger *logging.Logger, backupPath string, timeout time.Duration, workDir string) (*haRollbackHandle, error) { @@ -596,6 +610,7 @@ func TestMaybeApplyPVEHAWithUI_BranchCoverage(t *testing.T) { }) t.Run("no changes applied disarms rollback", func(t *testing.T) { + env := newEnv(t) markerPath := "/tmp/proxsave/ha-empty.marker" disarmed := false haArmRollback = func(ctx context.Context, logger *logging.Logger, backupPath string, timeout time.Duration, workDir string) (*haRollbackHandle, error) { @@ -630,6 +645,7 @@ func TestMaybeApplyPVEHAWithUI_BranchCoverage(t *testing.T) { }) t.Run("apply errors are propagated", func(t *testing.T) { + env := newEnv(t) haApplyFromStage = func(logger *logging.Logger, stageRoot string) ([]string, error) { return nil, fmt.Errorf("boom") } @@ -647,9 +663,13 @@ func TestMaybeApplyPVEHAWithUI_BranchCoverage(t *testing.T) { } func TestMaybeApplyPVEHAWithUI_CommitOutcomes(t *testing.T) { - env := setupHATestEnv(t) - if err := env.fs.AddFile(env.stageRoot+"/etc/pve/ha/resources.cfg", []byte("res\n")); err != nil { - t.Fatalf("add staged HA config: %v", err) + newEnv := func(t *testing.T) *haTestEnv { + t.Helper() + env := setupHATestEnv(t) + if err := env.fs.AddFile(env.stageRoot+"/etc/pve/ha/resources.cfg", []byte("res\n")); err != nil { + t.Fatalf("add staged HA config: %v", err) + } + return env } baseRollback := &SafetyBackupResult{BackupPath: "/backups/ha.tgz"} @@ -667,6 +687,7 @@ func TestMaybeApplyPVEHAWithUI_CommitOutcomes(t *testing.T) { } t.Run("rollback choice returns typed error", func(t *testing.T) { + env := newEnv(t) markerPath := "/tmp/proxsave/ha-rollback.marker" haArmRollback = func(ctx context.Context, logger *logging.Logger, backupPath string, timeout time.Duration, workDir string) (*haRollbackHandle, error) { return makeHandle(markerPath, nowRestore()), nil @@ -696,6 +717,7 @@ func TestMaybeApplyPVEHAWithUI_CommitOutcomes(t *testing.T) { }) t.Run("commit prompt abort returns abort error", func(t *testing.T) { + env := newEnv(t) haArmRollback = func(ctx context.Context, logger *logging.Logger, backupPath string, timeout time.Duration, workDir string) (*haRollbackHandle, error) { return makeHandle("/tmp/proxsave/ha-abort.marker", nowRestore()), nil } @@ -717,6 +739,7 @@ func TestMaybeApplyPVEHAWithUI_CommitOutcomes(t *testing.T) { }) t.Run("commit prompt failure returns typed error", func(t *testing.T) { + env := newEnv(t) haArmRollback = func(ctx context.Context, logger *logging.Logger, backupPath string, timeout time.Duration, workDir string) (*haRollbackHandle, error) { return makeHandle("/tmp/proxsave/ha-fail.marker", nowRestore()), nil } @@ -738,6 +761,7 @@ func TestMaybeApplyPVEHAWithUI_CommitOutcomes(t *testing.T) { }) t.Run("expired rollback handle returns typed error without commit prompt", func(t *testing.T) { + env := newEnv(t) haArmRollback = func(ctx context.Context, logger *logging.Logger, backupPath string, timeout time.Duration, workDir string) (*haRollbackHandle, error) { return makeHandle("/tmp/proxsave/ha-expired.marker", nowRestore().Add(-defaultHARollbackTimeout-time.Second)), nil } @@ -759,6 +783,7 @@ func TestMaybeApplyPVEHAWithUI_CommitOutcomes(t *testing.T) { }) t.Run("arm rollback failure is wrapped", func(t *testing.T) { + env := newEnv(t) haArmRollback = func(ctx context.Context, logger *logging.Logger, backupPath string, timeout time.Duration, workDir string) (*haRollbackHandle, error) { return nil, fmt.Errorf("boom") } diff --git a/internal/orchestrator/restore_tui.go b/internal/orchestrator/restore_tui.go index 48618c2..a33ee85 100644 --- a/internal/orchestrator/restore_tui.go +++ b/internal/orchestrator/restore_tui.go @@ -813,11 +813,13 @@ func promptYesNoTUIWithCountdown(ctx context.Context, logger *logging.Logger, ti page := buildRestoreWizardPage(title, configPath, buildSig, content) form.SetParentView(page) - stopCh := make(chan struct{}) - defer close(stopCh) + stopTicker := func() {} if timeout > 0 { + stopCh := make(chan struct{}) + done := make(chan struct{}) go func() { + defer close(done) ticker := time.NewTicker(1 * time.Second) defer ticker.Stop() for { @@ -832,14 +834,25 @@ func promptYesNoTUIWithCountdown(ctx context.Context, logger *logging.Logger, ti app.Stop() return } + select { + case <-stopCh: + return + default: + } app.QueueUpdateDraw(func() { updateCountdown() }) } } }() + stopTicker = func() { + close(stopCh) + <-done + } } app.SetRoot(page, true).SetFocus(form.Form) - if err := app.RunWithContext(ctx); err != nil { + err := app.RunWithContext(ctx) + stopTicker() + if err != nil { return false, err } if timedOut { diff --git a/internal/orchestrator/restore_tui_simulation_test.go b/internal/orchestrator/restore_tui_simulation_test.go index 519dd64..bb25536 100644 --- a/internal/orchestrator/restore_tui_simulation_test.go +++ b/internal/orchestrator/restore_tui_simulation_test.go @@ -63,14 +63,18 @@ func TestShowRestorePlanTUI_CancelReturnsAborted(t *testing.T) { } func TestConfirmRestoreTUI_ConfirmedAndOverwriteReturnsTrue(t *testing.T) { + expectedCtx := context.WithValue(context.Background(), struct{}{}, "confirm-restore") restore := stubPromptYesNo(func(ctx context.Context, title, configPath, buildSig, message, yesLabel, noLabel string) (bool, error) { + if ctx != expectedCtx { + t.Fatalf("stub received unexpected context: got %v want %v", ctx, expectedCtx) + } return true, nil }) defer restore() withSimApp(t, []tcell.Key{tcell.KeyEnter}) - ok, err := confirmRestoreTUI(context.Background(), "/tmp/config.env", "sig") + ok, err := confirmRestoreTUI(expectedCtx, "/tmp/config.env", "sig") if err != nil { t.Fatalf("confirmRestoreTUI error: %v", err) } @@ -80,14 +84,18 @@ func TestConfirmRestoreTUI_ConfirmedAndOverwriteReturnsTrue(t *testing.T) { } func TestConfirmRestoreTUI_OverwriteDeclinedReturnsFalse(t *testing.T) { + expectedCtx := context.WithValue(context.Background(), struct{}{}, "overwrite-declined") restore := stubPromptYesNo(func(ctx context.Context, title, configPath, buildSig, message, yesLabel, noLabel string) (bool, error) { + if ctx != expectedCtx { + t.Fatalf("stub received unexpected context: got %v want %v", ctx, expectedCtx) + } return false, nil }) defer restore() withSimApp(t, []tcell.Key{tcell.KeyEnter}) - ok, err := confirmRestoreTUI(context.Background(), "/tmp/config.env", "sig") + ok, err := confirmRestoreTUI(expectedCtx, "/tmp/config.env", "sig") if err != nil { t.Fatalf("confirmRestoreTUI error: %v", err) } diff --git a/internal/orchestrator/workflow_ui_cli.go b/internal/orchestrator/workflow_ui_cli.go index ec73455..9c069ec 100644 --- a/internal/orchestrator/workflow_ui_cli.go +++ b/internal/orchestrator/workflow_ui_cli.go @@ -135,7 +135,7 @@ func (u *cliWorkflowUI) ResolveExistingPath(ctx context.Context, path, descripti } trimmed, err := validateDistinctNewPathInput(newPath, current) if err != nil { - fmt.Println(err.Error()) + fmt.Fprintln(os.Stderr, err.Error()) continue } return PathDecisionNewPath, filepath.Clean(trimmed), nil diff --git a/internal/orchestrator/workflow_ui_cli_test.go b/internal/orchestrator/workflow_ui_cli_test.go index bad0856..11810c0 100644 --- a/internal/orchestrator/workflow_ui_cli_test.go +++ b/internal/orchestrator/workflow_ui_cli_test.go @@ -19,9 +19,6 @@ func captureCLIStdout(t *testing.T, fn func()) (captured string) { t.Fatalf("os.Pipe: %v", err) } os.Stdout = w - t.Cleanup(func() { - os.Stdout = oldStdout - }) var buf bytes.Buffer done := make(chan struct{}) @@ -41,6 +38,34 @@ func captureCLIStdout(t *testing.T, fn func()) (captured string) { return } +func captureCLIStderr(t *testing.T, fn func()) (captured string) { + t.Helper() + + oldStderr := os.Stderr + r, w, err := os.Pipe() + if err != nil { + t.Fatalf("os.Pipe: %v", err) + } + os.Stderr = w + + var buf bytes.Buffer + done := make(chan struct{}) + go func() { + _, _ = io.Copy(&buf, r) + close(done) + }() + defer func() { + os.Stderr = oldStderr + _ = w.Close() + <-done + _ = r.Close() + captured = buf.String() + }() + + fn() + return +} + func TestCLIWorkflowUIResolveExistingPath_RejectsEquivalentNormalizedPath(t *testing.T) { reader := bufio.NewReader(strings.NewReader("2\n/tmp/out/\n2\n /tmp/out/../alt \n")) ui := newCLIWorkflowUI(reader, nil) @@ -50,7 +75,7 @@ func TestCLIWorkflowUIResolveExistingPath_RejectsEquivalentNormalizedPath(t *tes newPath string err error ) - output := captureCLIStdout(t, func() { + stderrOutput := captureCLIStderr(t, func() { decision, newPath, err = ui.ResolveExistingPath(context.Background(), "/tmp/out", "archive", "") }) if err != nil { @@ -62,8 +87,8 @@ func TestCLIWorkflowUIResolveExistingPath_RejectsEquivalentNormalizedPath(t *tes if newPath != "/tmp/alt" { t.Fatalf("newPath=%q, want %q", newPath, "/tmp/alt") } - if !strings.Contains(output, "path must be different from existing path") { - t.Fatalf("expected validation message in output, got %q", output) + if !strings.Contains(stderrOutput, "path must be different from existing path") { + t.Fatalf("expected validation message in stderr, got %q", stderrOutput) } } diff --git a/internal/storage/backup_files.go b/internal/storage/backup_files.go index 2aaa0c4..f7e5e8b 100644 --- a/internal/storage/backup_files.go +++ b/internal/storage/backup_files.go @@ -13,17 +13,27 @@ func trimBundleSuffix(path string) (string, bool) { return path, false } +func normalizeBundleBasePath(path string) string { + for { + trimmed, ok := trimBundleSuffix(path) + if !ok { + return path + } + path = trimmed + } +} + // bundlePathFor returns the canonical bundle path for either a raw archive path // or a path that already points to a bundle. func bundlePathFor(path string) string { - base, _ := trimBundleSuffix(path) - return base + bundleSuffix + return normalizeBundleBasePath(path) + bundleSuffix } // buildBackupCandidatePaths returns the list of files that belong to a backup. // When includeBundle is true, both the bundle and the legacy single-file layout // are included so retention can clean up either form. func buildBackupCandidatePaths(base string, includeBundle bool) []string { + base = normalizeBundleBasePath(base) seen := make(map[string]struct{}) add := func(path string) bool { if path == "" { diff --git a/internal/storage/backup_files_test.go b/internal/storage/backup_files_test.go new file mode 100644 index 0000000..a307ddc --- /dev/null +++ b/internal/storage/backup_files_test.go @@ -0,0 +1,56 @@ +package storage + +import ( + "reflect" + "testing" +) + +func TestBundlePathForNormalizesRepeatedBundleSuffixes(t *testing.T) { + got := bundlePathFor("backup.tar.zst.bundle.tar.bundle.tar") + want := "backup.tar.zst.bundle.tar" + if got != want { + t.Fatalf("bundlePathFor() = %q, want %q", got, want) + } +} + +func TestBuildBackupCandidatePathsNormalizesBundleInput(t *testing.T) { + tests := []struct { + name string + base string + includeBundle bool + want []string + }{ + { + name: "bundle included", + base: "backup.tar.zst.bundle.tar.bundle.tar", + includeBundle: true, + want: []string{ + "backup.tar.zst.bundle.tar", + "backup.tar.zst", + "backup.tar.zst.sha256", + "backup.tar.zst.metadata", + "backup.tar.zst.metadata.sha256", + }, + }, + { + name: "legacy only", + base: "backup.tar.zst.bundle.tar", + includeBundle: false, + want: []string{ + "backup.tar.zst", + "backup.tar.zst.sha256", + "backup.tar.zst.metadata", + "backup.tar.zst.metadata.sha256", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := buildBackupCandidatePaths(tt.base, tt.includeBundle) + if !reflect.DeepEqual(got, tt.want) { + t.Fatalf("buildBackupCandidatePaths() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/internal/storage/cloud.go b/internal/storage/cloud.go index 17e07ed..8d306ad 100644 --- a/internal/storage/cloud.go +++ b/internal/storage/cloud.go @@ -377,7 +377,7 @@ func (c *CloudStorage) checkRemoteAccessible(ctx context.Context) error { waitTime := cloudRetryBackoff(attempt) c.logger.Debug("Cloud remote check attempt %d/%d failed: %v (retrying in %v)", attempt, maxAttempts, err, waitTime) - if err := c.waitForRetry(timeoutCtx, waitTime); err != nil { + if err := c.callWaitForRetry(timeoutCtx, waitTime); err != nil { if parentErr := ctx.Err(); parentErr != nil { return parentErr } @@ -797,7 +797,7 @@ func (c *CloudStorage) uploadWithRetry(ctx context.Context, localFile, remoteFil if attempt < c.config.RcloneRetries { waitTime := cloudRetryBackoff(attempt) c.logger.Debug("Waiting %v before retry...", waitTime) - if err := c.waitForRetry(ctx, waitTime); err != nil { + if err := c.callWaitForRetry(ctx, waitTime); err != nil { return err } } @@ -1765,6 +1765,13 @@ func (c *CloudStorage) exec(ctx context.Context, name string, args ...string) ([ return defaultExecCommand(ctx, name, args...) } +func (c *CloudStorage) callWaitForRetry(ctx context.Context, d time.Duration) error { + if c.waitForRetry != nil { + return c.waitForRetry(ctx, d) + } + return nil +} + func defaultExecCommand(ctx context.Context, name string, args ...string) ([]byte, error) { cmd := exec.CommandContext(ctx, name, args...) return cmd.CombinedOutput() diff --git a/internal/storage/cloud_test.go b/internal/storage/cloud_test.go index c938a2b..e9fc948 100644 --- a/internal/storage/cloud_test.go +++ b/internal/storage/cloud_test.go @@ -316,6 +316,33 @@ func TestCloudStorageUploadWithRetryEventuallySucceeds(t *testing.T) { } } +func TestCloudStorageUploadWithRetryNilWaitForRetryIsNoOp(t *testing.T) { + cfg := &config.Config{ + CloudEnabled: true, + CloudRemote: "remote", + RcloneRetries: 2, + RcloneTimeoutOperation: 5, + } + cs := newCloudStorageForTest(cfg) + + queue := &commandQueue{ + t: t, + queue: []queuedResponse{ + {name: "rclone", err: errors.New("copy failed")}, + {name: "rclone", out: "ok"}, + }, + } + cs.execCommand = queue.exec + cs.waitForRetry = nil + + if err := cs.uploadWithRetry(context.Background(), "/tmp/local.tar", "remote:local.tar"); err != nil { + t.Fatalf("uploadWithRetry() error = %v", err) + } + if len(queue.calls) != 2 { + t.Fatalf("expected 2 upload attempts, got %d", len(queue.calls)) + } +} + func TestCloudStorageUploadWithRetryUsesCappedBackoff(t *testing.T) { cfg := &config.Config{ CloudEnabled: true, @@ -1251,6 +1278,33 @@ func TestCloudStorageCheckWithNetworkErrorNoFallback(t *testing.T) { } } +func TestCloudStorageCheckRemoteAccessibleNilWaitForRetryIsNoOp(t *testing.T) { + cfg := &config.Config{ + CloudEnabled: true, + CloudRemote: "remote", + CloudWriteHealthCheck: false, + RcloneTimeoutConnection: 30, + } + cs := newCloudStorageForTest(cfg) + + queue := &commandQueue{ + t: t, + queue: []queuedResponse{ + {name: "rclone", err: errors.New("exit 1"), out: "dial tcp: connection refused"}, + {name: "rclone", out: "ok"}, + }, + } + cs.execCommand = queue.exec + cs.waitForRetry = nil + + if err := cs.checkRemoteAccessible(context.Background()); err != nil { + t.Fatalf("checkRemoteAccessible() error = %v", err) + } + if len(queue.calls) != 2 { + t.Fatalf("expected 2 remote check attempts, got %d", len(queue.calls)) + } +} + func TestCloudStorageCheckRemoteAccessibleReturnsContextErrorDuringBackoff(t *testing.T) { cfg := &config.Config{ CloudEnabled: true,