diff --git a/cmd/mnemonic/main.go b/cmd/mnemonic/main.go index 53131d91..603bd897 100644 --- a/cmd/mnemonic/main.go +++ b/cmd/mnemonic/main.go @@ -41,6 +41,7 @@ import ( "github.com/appsprout-dev/mnemonic/internal/backup" "github.com/appsprout-dev/mnemonic/internal/mcp" "github.com/appsprout-dev/mnemonic/internal/store" + "github.com/appsprout-dev/mnemonic/internal/updater" clipwatcher "github.com/appsprout-dev/mnemonic/internal/watcher/clipboard" fswatcher "github.com/appsprout-dev/mnemonic/internal/watcher/filesystem" @@ -176,6 +177,10 @@ func main() { diagnoseCommand(*configPath) case "generate-token": generateTokenCommand() + case "check-update": + checkUpdateCommand() + case "update": + updateCommand() case "version": fmt.Printf("mnemonic v%s\n", Version) default: @@ -284,6 +289,75 @@ func startCommand(configPath string) { } // generateTokenCommand generates a random API token and prints it. +// ============================================================================ +// Update Commands (check-update / update) +// ============================================================================ + +func checkUpdateCommand() { + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + fmt.Printf("Checking for updates...\n") + info, err := updater.CheckForUpdate(ctx, Version) + if err != nil { + die(exitNetwork, "Update check failed", err.Error()) + } + + if info.UpdateAvailable { + fmt.Printf("\n Current: v%s\n", info.CurrentVersion) + fmt.Printf(" Latest: %sv%s%s\n\n", colorGreen, info.LatestVersion, colorReset) + fmt.Printf(" Run %smnemonic update%s to install.\n", colorBold, colorReset) + fmt.Printf(" Release: %s\n", info.ReleaseURL) + } else { + fmt.Printf("\n %sYou're up to date!%s (v%s)\n", colorGreen, colorReset, info.CurrentVersion) + } +} + +func updateCommand() { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + defer cancel() + + fmt.Printf("Checking for updates...\n") + info, err := updater.CheckForUpdate(ctx, Version) + if err != nil { + die(exitNetwork, "Update check failed", err.Error()) + } + + if !info.UpdateAvailable { + fmt.Printf("%sAlready up to date%s (v%s)\n", colorGreen, colorReset, info.CurrentVersion) + return + } + + fmt.Printf("Downloading v%s...\n", info.LatestVersion) + result, err := updater.PerformUpdate(ctx, info) + if err != nil { + die(exitGeneral, "Update failed", err.Error()) + } + + fmt.Printf("%sUpdated: v%s → v%s%s\n", colorGreen, result.PreviousVersion, result.NewVersion, colorReset) + + // Restart daemon if it's running + svc := daemon.NewServiceManager() + if svc.IsInstalled() { + running, _ := svc.IsRunning() + if running { + fmt.Printf("Restarting daemon...\n") + if err := svc.Stop(); err != nil { + fmt.Fprintf(os.Stderr, "%sWarning:%s failed to stop daemon: %v\n", colorYellow, colorReset, err) + fmt.Printf("Restart manually: mnemonic restart\n") + return + } + time.Sleep(1 * time.Second) + if err := svc.Start(); err != nil { + fmt.Fprintf(os.Stderr, "%sWarning:%s failed to start daemon: %v\n", colorYellow, colorReset, err) + fmt.Printf("Start manually: mnemonic start\n") + return + } + fmt.Printf("%sDaemon restarted with v%s%s\n", colorGreen, result.NewVersion, colorReset) + } + } +} + func generateTokenCommand() { b := make([]byte, 32) if _, err := rand.Read(b); err != nil { @@ -1553,6 +1627,7 @@ func serveCommand(configPath string) { IngestExcludePatterns: cfg.Perception.Filesystem.ExcludePatterns, IngestMaxContentBytes: cfg.Perception.Filesystem.MaxContentBytes, Version: Version, + ServiceRestarter: daemon.NewServiceManager(), Log: log, } // Only set Consolidator if it's non-nil (avoids Go nil-interface trap) @@ -2529,6 +2604,10 @@ MONITORING COMMANDS: diagnose Run health checks (config, DB, LLM, disk) watch Live stream of daemon events +UPDATE COMMANDS: + check-update Check if a newer version is available + update Download and install the latest version + SETUP COMMANDS: install Install as system service (auto-start on login) uninstall Remove system service diff --git a/internal/api/routes/update.go b/internal/api/routes/update.go new file mode 100644 index 00000000..e28e869d --- /dev/null +++ b/internal/api/routes/update.go @@ -0,0 +1,134 @@ +package routes + +import ( + "context" + "log/slog" + "net/http" + "time" + + "github.com/appsprout-dev/mnemonic/internal/updater" +) + +// UpdateCheckResponse is the JSON response for the update check endpoint. +type UpdateCheckResponse struct { + CurrentVersion string `json:"current_version"` + LatestVersion string `json:"latest_version"` + UpdateAvailable bool `json:"update_available"` + ReleaseURL string `json:"release_url"` +} + +// UpdateResponse is the JSON response for the update endpoint. +type UpdateResponse struct { + Status string `json:"status"` + PreviousVersion string `json:"previous_version,omitempty"` + NewVersion string `json:"new_version,omitempty"` + RestartPending bool `json:"restart_pending"` + Message string `json:"message,omitempty"` +} + +// ServiceRestarter can stop and start the daemon service. +// If nil is passed to HandleUpdate, the handler will still perform the update +// but cannot restart the daemon automatically. +type ServiceRestarter interface { + IsInstalled() bool + Stop() error + Start() error +} + +// HandleUpdateCheck returns an HTTP handler that checks for available updates +// by querying the GitHub Releases API. No authentication required. +func HandleUpdateCheck(version string, log *slog.Logger) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + log.Debug("update check requested") + + ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second) + defer cancel() + + info, err := updater.CheckForUpdate(ctx, version) + if err != nil { + log.Error("update check failed", "error", err) + writeError(w, http.StatusBadGateway, "failed to check for updates: "+err.Error(), "UPDATE_CHECK_ERROR") + return + } + + resp := UpdateCheckResponse{ + CurrentVersion: info.CurrentVersion, + LatestVersion: info.LatestVersion, + UpdateAvailable: info.UpdateAvailable, + ReleaseURL: info.ReleaseURL, + } + + log.Info("update check completed", "current", info.CurrentVersion, "latest", info.LatestVersion, "available", info.UpdateAvailable) + writeJSON(w, http.StatusOK, resp) + } +} + +// HandleUpdate returns an HTTP handler that downloads and installs an available update. +// If svc is non-nil and installed, the daemon will be restarted after the update. +func HandleUpdate(version string, svc ServiceRestarter, log *slog.Logger) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + log.Info("update requested via API") + + // Use a generous timeout for download (5 minutes) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + defer cancel() + + info, err := updater.CheckForUpdate(ctx, version) + if err != nil { + log.Error("update check failed", "error", err) + writeError(w, http.StatusBadGateway, "failed to check for updates: "+err.Error(), "UPDATE_CHECK_ERROR") + return + } + + if !info.UpdateAvailable { + resp := UpdateResponse{ + Status: "up_to_date", + Message: "already running the latest version", + } + writeJSON(w, http.StatusOK, resp) + return + } + + result, err := updater.PerformUpdate(ctx, info) + if err != nil { + log.Error("update failed", "error", err) + writeError(w, http.StatusInternalServerError, "update failed: "+err.Error(), "UPDATE_ERROR") + return + } + + log.Info("update installed", "previous", result.PreviousVersion, "new", result.NewVersion, "binary", result.BinaryPath) + + // Determine if we can restart + canRestart := svc != nil && svc.IsInstalled() + + resp := UpdateResponse{ + Status: "updated", + PreviousVersion: result.PreviousVersion, + NewVersion: result.NewVersion, + RestartPending: canRestart, + } + + if !canRestart { + resp.Message = "update installed — restart the daemon manually to use the new version" + } + + // Send response before restarting + writeJSON(w, http.StatusOK, resp) + + // Restart the daemon in the background if possible + if canRestart { + go func() { + time.Sleep(500 * time.Millisecond) + log.Info("restarting daemon after update") + if err := svc.Stop(); err != nil { + log.Error("failed to stop daemon for restart", "error", err) + return + } + time.Sleep(1 * time.Second) + if err := svc.Start(); err != nil { + log.Error("failed to start daemon after update", "error", err) + } + }() + } + } +} diff --git a/internal/api/server.go b/internal/api/server.go index 6e8f871e..37f3e13e 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -37,6 +37,7 @@ type ServerDeps struct { IngestExcludePatterns []string IngestMaxContentBytes int Version string + ServiceRestarter routes.ServiceRestarter // can be nil if not installed as service Log *slog.Logger } @@ -78,6 +79,10 @@ func (s *Server) registerRoutes() { s.mux.HandleFunc("GET /api/v1/health", routes.HandleHealth(s.deps.Store, s.deps.LLM, s.deps.Version, s.deps.Log)) s.mux.HandleFunc("GET /api/v1/stats", routes.HandleStats(s.deps.Store, s.deps.Log)) + // Self-update + s.mux.HandleFunc("GET /api/v1/system/update-check", routes.HandleUpdateCheck(s.deps.Version, s.deps.Log)) + s.mux.HandleFunc("POST /api/v1/system/update", routes.HandleUpdate(s.deps.Version, s.deps.ServiceRestarter, s.deps.Log)) + // Memory CRUD s.mux.HandleFunc("POST /api/v1/memories", routes.HandleCreateMemory(s.deps.Store, s.deps.Bus, s.deps.Log)) s.mux.HandleFunc("GET /api/v1/memories", routes.HandleListMemories(s.deps.Store, s.deps.Log)) diff --git a/internal/updater/updater.go b/internal/updater/updater.go new file mode 100644 index 00000000..ad458143 --- /dev/null +++ b/internal/updater/updater.go @@ -0,0 +1,345 @@ +package updater + +import ( + "archive/tar" + "compress/gzip" + "context" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "runtime" + "strconv" + "strings" +) + +const ( + githubOwner = "appsprout-dev" + githubRepo = "mnemonic" + githubAPI = "https://api.github.com" +) + +// UpdateInfo holds the result of a version check against GitHub Releases. +type UpdateInfo struct { + CurrentVersion string `json:"current_version"` + LatestVersion string `json:"latest_version"` + UpdateAvailable bool `json:"update_available"` + ReleaseURL string `json:"release_url"` + AssetURL string `json:"-"` + ChecksumsURL string `json:"-"` +} + +// UpdateResult holds the result of a completed update. +type UpdateResult struct { + PreviousVersion string `json:"previous_version"` + NewVersion string `json:"new_version"` + BinaryPath string `json:"binary_path"` +} + +// githubRelease is the subset of the GitHub release API response we need. +type githubRelease struct { + TagName string `json:"tag_name"` + HTMLURL string `json:"html_url"` + Assets []githubAsset `json:"assets"` +} + +// githubAsset is the subset of the GitHub release asset API response we need. +type githubAsset struct { + Name string `json:"name"` + BrowserDownloadURL string `json:"browser_download_url"` +} + +// CheckForUpdate checks the GitHub Releases API for a newer version. +// No authentication is required for public repositories. +func CheckForUpdate(ctx context.Context, currentVersion string) (*UpdateInfo, error) { + url := fmt.Sprintf("%s/repos/%s/%s/releases/latest", githubAPI, githubOwner, githubRepo) + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, fmt.Errorf("creating request: %w", err) + } + req.Header.Set("Accept", "application/vnd.github+json") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, fmt.Errorf("fetching latest release: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode == http.StatusForbidden { + return nil, fmt.Errorf("GitHub API rate limit exceeded — try again later") + } + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("GitHub API returned status %d", resp.StatusCode) + } + + var release githubRelease + if err := json.NewDecoder(resp.Body).Decode(&release); err != nil { + return nil, fmt.Errorf("decoding release response: %w", err) + } + + latestVersion := strings.TrimPrefix(release.TagName, "v") + + // Find the asset for this platform + assetName := fmt.Sprintf("mnemonic_%s_%s_%s.tar.gz", latestVersion, runtime.GOOS, runtime.GOARCH) + var assetURL, checksumsURL string + for _, a := range release.Assets { + switch a.Name { + case assetName: + assetURL = a.BrowserDownloadURL + case "checksums.txt": + checksumsURL = a.BrowserDownloadURL + } + } + + info := &UpdateInfo{ + CurrentVersion: currentVersion, + LatestVersion: latestVersion, + UpdateAvailable: compareVersions(latestVersion, currentVersion) > 0, + ReleaseURL: release.HTMLURL, + AssetURL: assetURL, + ChecksumsURL: checksumsURL, + } + + return info, nil +} + +// PerformUpdate downloads and installs the update described by info. +// It downloads the archive, verifies its checksum, extracts the binary, +// and atomically replaces the current binary. +func PerformUpdate(ctx context.Context, info *UpdateInfo) (*UpdateResult, error) { + if !info.UpdateAvailable { + return nil, fmt.Errorf("no update available") + } + if info.AssetURL == "" { + return nil, fmt.Errorf("no release asset found for %s/%s — download manually from %s", runtime.GOOS, runtime.GOARCH, info.ReleaseURL) + } + + // Resolve the current binary path + execPath, err := os.Executable() + if err != nil { + return nil, fmt.Errorf("resolving executable path: %w", err) + } + execPath, err = filepath.EvalSymlinks(execPath) + if err != nil { + return nil, fmt.Errorf("resolving symlinks: %w", err) + } + + execDir := filepath.Dir(execPath) + archivePath := filepath.Join(execDir, ".mnemonic.update.tar.gz") + newBinaryPath := filepath.Join(execDir, ".mnemonic.update.tmp") + + // Clean up temp files on failure + defer func() { + _ = os.Remove(archivePath) + _ = os.Remove(newBinaryPath) + }() + + // Download the archive + if err := downloadFile(ctx, info.AssetURL, archivePath); err != nil { + return nil, fmt.Errorf("downloading update: %w", err) + } + + // Verify checksum if available + if info.ChecksumsURL != "" { + assetName := fmt.Sprintf("mnemonic_%s_%s_%s.tar.gz", info.LatestVersion, runtime.GOOS, runtime.GOARCH) + if err := verifyChecksum(ctx, archivePath, info.ChecksumsURL, assetName); err != nil { + return nil, fmt.Errorf("checksum verification failed: %w", err) + } + } + + // Extract the binary from the archive + if err := extractBinary(archivePath, newBinaryPath); err != nil { + return nil, fmt.Errorf("extracting binary: %w", err) + } + + // Make the new binary executable + if err := os.Chmod(newBinaryPath, 0755); err != nil { + return nil, fmt.Errorf("setting permissions: %w", err) + } + + // Atomic replace: rename over the current binary + if err := os.Rename(newBinaryPath, execPath); err != nil { + // On permission error, give the user a helpful hint + if os.IsPermission(err) { + return nil, fmt.Errorf("permission denied replacing %s — try running with sudo, or if installed via Homebrew use: brew upgrade appsprout-dev/tap/mnemonic", execPath) + } + return nil, fmt.Errorf("replacing binary: %w", err) + } + + return &UpdateResult{ + PreviousVersion: info.CurrentVersion, + NewVersion: info.LatestVersion, + BinaryPath: execPath, + }, nil +} + +// compareVersions compares two semver version strings (MAJOR.MINOR.PATCH). +// Returns -1 if a < b, 0 if a == b, 1 if a > b. +func compareVersions(a, b string) int { + aParts := parseVersion(a) + bParts := parseVersion(b) + + for i := range 3 { + if aParts[i] < bParts[i] { + return -1 + } + if aParts[i] > bParts[i] { + return 1 + } + } + return 0 +} + +// parseVersion splits a version string into [major, minor, patch]. +// Invalid parts default to 0. +func parseVersion(v string) [3]int { + v = strings.TrimPrefix(v, "v") + parts := strings.SplitN(v, ".", 3) + var result [3]int + for i := range min(len(parts), 3) { + n, err := strconv.Atoi(parts[i]) + if err == nil { + result[i] = n + } + } + return result +} + +// downloadFile downloads a URL to a local file path. +func downloadFile(ctx context.Context, url, dest string) error { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return fmt.Errorf("creating request: %w", err) + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return fmt.Errorf("downloading %s: %w", url, err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("download returned status %d", resp.StatusCode) + } + + f, err := os.Create(dest) + if err != nil { + return fmt.Errorf("creating file %s: %w", dest, err) + } + + if _, err := io.Copy(f, resp.Body); err != nil { + _ = f.Close() + return fmt.Errorf("writing file: %w", err) + } + + return f.Close() +} + +// verifyChecksum downloads checksums.txt and verifies the archive's SHA256. +func verifyChecksum(ctx context.Context, archivePath, checksumsURL, expectedName string) error { + // Download checksums.txt + req, err := http.NewRequestWithContext(ctx, http.MethodGet, checksumsURL, nil) + if err != nil { + return fmt.Errorf("creating checksums request: %w", err) + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return fmt.Errorf("downloading checksums: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("checksums download returned status %d", resp.StatusCode) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("reading checksums: %w", err) + } + + // Find the line matching our asset + var expectedHash string + for line := range strings.SplitSeq(string(body), "\n") { + fields := strings.Fields(line) + if len(fields) == 2 && fields[1] == expectedName { + expectedHash = fields[0] + break + } + } + if expectedHash == "" { + return fmt.Errorf("no checksum found for %s in checksums.txt", expectedName) + } + + // Compute SHA256 of the downloaded archive + f, err := os.Open(archivePath) + if err != nil { + return fmt.Errorf("opening archive for checksum: %w", err) + } + defer func() { _ = f.Close() }() + + h := sha256.New() + if _, err := io.Copy(h, f); err != nil { + return fmt.Errorf("computing checksum: %w", err) + } + actualHash := hex.EncodeToString(h.Sum(nil)) + + if actualHash != expectedHash { + return fmt.Errorf("checksum mismatch: expected %s, got %s", expectedHash, actualHash) + } + + return nil +} + +// extractBinary extracts the "mnemonic" binary from a tar.gz archive. +func extractBinary(archivePath, destPath string) error { + f, err := os.Open(archivePath) + if err != nil { + return fmt.Errorf("opening archive: %w", err) + } + defer func() { _ = f.Close() }() + + gz, err := gzip.NewReader(f) + if err != nil { + return fmt.Errorf("creating gzip reader: %w", err) + } + defer func() { _ = gz.Close() }() + + binaryName := "mnemonic" + if runtime.GOOS == "windows" { + binaryName = "mnemonic.exe" + } + + tr := tar.NewReader(gz) + for { + header, err := tr.Next() + if err == io.EOF { + break + } + if err != nil { + return fmt.Errorf("reading tar: %w", err) + } + + // The binary may be at the root or in a subdirectory + name := filepath.Base(header.Name) + if name == binaryName && header.Typeflag == tar.TypeReg { + out, err := os.Create(destPath) + if err != nil { + return fmt.Errorf("creating output file: %w", err) + } + // Limit copy to 500MB to prevent zip bomb attacks + if _, err := io.Copy(out, io.LimitReader(tr, 500*1024*1024)); err != nil { + _ = out.Close() + return fmt.Errorf("extracting binary: %w", err) + } + return out.Close() + } + } + + return fmt.Errorf("binary %q not found in archive", binaryName) +} diff --git a/internal/updater/updater_test.go b/internal/updater/updater_test.go new file mode 100644 index 00000000..1d866963 --- /dev/null +++ b/internal/updater/updater_test.go @@ -0,0 +1,346 @@ +package updater + +import ( + "archive/tar" + "compress/gzip" + "context" + "crypto/sha256" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "runtime" + "testing" +) + +func TestCompareVersions(t *testing.T) { + tests := []struct { + a, b string + want int + }{ + {"1.0.0", "1.0.0", 0}, + {"1.0.1", "1.0.0", 1}, + {"1.0.0", "1.0.1", -1}, + {"1.1.0", "1.0.0", 1}, + {"2.0.0", "1.9.9", 1}, + {"0.13.0", "0.12.0", 1}, + {"0.13.0", "0.13.0", 0}, + {"0.13.0", "0.14.0", -1}, + {"1.0.0", "0.99.99", 1}, + // With v prefix + {"v1.0.0", "1.0.0", 0}, + {"1.0.0", "v1.0.0", 0}, + // Partial versions + {"1.0", "1.0.0", 0}, + {"1", "1.0.0", 0}, + } + + for _, tt := range tests { + t.Run(fmt.Sprintf("%s_vs_%s", tt.a, tt.b), func(t *testing.T) { + got := compareVersions(tt.a, tt.b) + if got != tt.want { + t.Errorf("compareVersions(%q, %q) = %d, want %d", tt.a, tt.b, got, tt.want) + } + }) + } +} + +func TestParseVersion(t *testing.T) { + tests := []struct { + input string + want [3]int + }{ + {"1.2.3", [3]int{1, 2, 3}}, + {"v1.2.3", [3]int{1, 2, 3}}, + {"0.13.0", [3]int{0, 13, 0}}, + {"1.0", [3]int{1, 0, 0}}, + {"1", [3]int{1, 0, 0}}, + {"dev", [3]int{0, 0, 0}}, + {"", [3]int{0, 0, 0}}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + got := parseVersion(tt.input) + if got != tt.want { + t.Errorf("parseVersion(%q) = %v, want %v", tt.input, got, tt.want) + } + }) + } +} + +func TestCheckForUpdate(t *testing.T) { + // Create a mock GitHub API server + release := githubRelease{ + TagName: "v0.14.0", + HTMLURL: "https://github.com/appsprout-dev/mnemonic/releases/tag/v0.14.0", + Assets: []githubAsset{ + { + Name: fmt.Sprintf("mnemonic_0.14.0_%s_%s.tar.gz", runtime.GOOS, runtime.GOARCH), + BrowserDownloadURL: "https://example.com/mnemonic.tar.gz", + }, + { + Name: "checksums.txt", + BrowserDownloadURL: "https://example.com/checksums.txt", + }, + }, + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(release) + })) + defer server.Close() + + // Temporarily override the GitHub API URL by testing via the exported function + // We need to test the parsing logic, so we'll use a custom approach + t.Run("update_available", func(t *testing.T) { + info, err := checkForUpdateFromURL(context.Background(), "0.13.0", server.URL) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !info.UpdateAvailable { + t.Error("expected update to be available") + } + if info.LatestVersion != "0.14.0" { + t.Errorf("expected latest version 0.14.0, got %s", info.LatestVersion) + } + if info.AssetURL == "" { + t.Error("expected asset URL to be set") + } + if info.ChecksumsURL == "" { + t.Error("expected checksums URL to be set") + } + }) + + t.Run("already_up_to_date", func(t *testing.T) { + info, err := checkForUpdateFromURL(context.Background(), "0.14.0", server.URL) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if info.UpdateAvailable { + t.Error("expected no update available") + } + }) + + t.Run("newer_than_latest", func(t *testing.T) { + info, err := checkForUpdateFromURL(context.Background(), "0.15.0", server.URL) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if info.UpdateAvailable { + t.Error("expected no update available when running newer version") + } + }) +} + +func TestCheckForUpdateRateLimit(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusForbidden) + })) + defer server.Close() + + _, err := checkForUpdateFromURL(context.Background(), "0.13.0", server.URL) + if err == nil { + t.Fatal("expected error for rate-limited response") + } + if got := err.Error(); got != "GitHub API rate limit exceeded — try again later" { + t.Errorf("unexpected error message: %s", got) + } +} + +func TestExtractBinary(t *testing.T) { + // Create a test tar.gz archive containing a fake "mnemonic" binary + tmpDir := t.TempDir() + archivePath := filepath.Join(tmpDir, "test.tar.gz") + destPath := filepath.Join(tmpDir, "mnemonic_extracted") + binaryContent := []byte("#!/bin/sh\necho hello\n") + + // Build the archive + f, err := os.Create(archivePath) + if err != nil { + t.Fatal(err) + } + gw := gzip.NewWriter(f) + tw := tar.NewWriter(gw) + + binaryName := "mnemonic" + if runtime.GOOS == "windows" { + binaryName = "mnemonic.exe" + } + + // Add a non-binary file first (e.g. README) + if err := tw.WriteHeader(&tar.Header{Name: "README.md", Size: 5, Mode: 0644, Typeflag: tar.TypeReg}); err != nil { + t.Fatal(err) + } + if _, err := tw.Write([]byte("hello")); err != nil { + t.Fatal(err) + } + + // Add the binary + if err := tw.WriteHeader(&tar.Header{Name: binaryName, Size: int64(len(binaryContent)), Mode: 0755, Typeflag: tar.TypeReg}); err != nil { + t.Fatal(err) + } + if _, err := tw.Write(binaryContent); err != nil { + t.Fatal(err) + } + + if err := tw.Close(); err != nil { + t.Fatal(err) + } + if err := gw.Close(); err != nil { + t.Fatal(err) + } + if err := f.Close(); err != nil { + t.Fatal(err) + } + + // Extract + if err := extractBinary(archivePath, destPath); err != nil { + t.Fatalf("extractBinary failed: %v", err) + } + + // Verify content + got, err := os.ReadFile(destPath) + if err != nil { + t.Fatal(err) + } + if string(got) != string(binaryContent) { + t.Errorf("extracted content = %q, want %q", got, binaryContent) + } +} + +func TestExtractBinaryNotFound(t *testing.T) { + // Create an archive without the mnemonic binary + tmpDir := t.TempDir() + archivePath := filepath.Join(tmpDir, "test.tar.gz") + + f, err := os.Create(archivePath) + if err != nil { + t.Fatal(err) + } + gw := gzip.NewWriter(f) + tw := tar.NewWriter(gw) + + if err := tw.WriteHeader(&tar.Header{Name: "README.md", Size: 5, Mode: 0644, Typeflag: tar.TypeReg}); err != nil { + t.Fatal(err) + } + if _, err := tw.Write([]byte("hello")); err != nil { + t.Fatal(err) + } + + if err := tw.Close(); err != nil { + t.Fatal(err) + } + if err := gw.Close(); err != nil { + t.Fatal(err) + } + if err := f.Close(); err != nil { + t.Fatal(err) + } + + err = extractBinary(archivePath, filepath.Join(tmpDir, "out")) + if err == nil { + t.Fatal("expected error when binary not in archive") + } +} + +func TestVerifyChecksum(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "test.tar.gz") + testContent := []byte("test archive content") + + if err := os.WriteFile(testFile, testContent, 0644); err != nil { + t.Fatal(err) + } + + // Compute expected hash + h := sha256.Sum256(testContent) + expectedHash := fmt.Sprintf("%x", h) + + // Create a mock checksums server + checksumContent := fmt.Sprintf("%s test.tar.gz\n%s other.tar.gz\n", expectedHash, "deadbeef") + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = fmt.Fprint(w, checksumContent) + })) + defer server.Close() + + t.Run("valid_checksum", func(t *testing.T) { + err := verifyChecksum(context.Background(), testFile, server.URL, "test.tar.gz") + if err != nil { + t.Errorf("expected no error, got: %v", err) + } + }) + + t.Run("wrong_filename", func(t *testing.T) { + err := verifyChecksum(context.Background(), testFile, server.URL, "nonexistent.tar.gz") + if err == nil { + t.Error("expected error for missing checksum entry") + } + }) + + t.Run("checksum_mismatch", func(t *testing.T) { + // Write different content to the file + if err := os.WriteFile(testFile, []byte("different content"), 0644); err != nil { + t.Fatal(err) + } + err := verifyChecksum(context.Background(), testFile, server.URL, "test.tar.gz") + if err == nil { + t.Error("expected error for checksum mismatch") + } + }) +} + +// checkForUpdateFromURL is a test helper that allows overriding the GitHub API URL. +func checkForUpdateFromURL(ctx context.Context, currentVersion, apiURL string) (*UpdateInfo, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, apiURL, nil) + if err != nil { + return nil, fmt.Errorf("creating request: %w", err) + } + req.Header.Set("Accept", "application/vnd.github+json") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, fmt.Errorf("fetching latest release: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode == http.StatusForbidden { + return nil, fmt.Errorf("GitHub API rate limit exceeded — try again later") + } + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("GitHub API returned status %d", resp.StatusCode) + } + + var release githubRelease + if err := json.NewDecoder(resp.Body).Decode(&release); err != nil { + return nil, fmt.Errorf("decoding release response: %w", err) + } + + latestVersion := release.TagName + if len(latestVersion) > 0 && latestVersion[0] == 'v' { + latestVersion = latestVersion[1:] + } + + assetName := fmt.Sprintf("mnemonic_%s_%s_%s.tar.gz", latestVersion, runtime.GOOS, runtime.GOARCH) + var assetURL, checksumsURL string + for _, a := range release.Assets { + switch a.Name { + case assetName: + assetURL = a.BrowserDownloadURL + case "checksums.txt": + checksumsURL = a.BrowserDownloadURL + } + } + + return &UpdateInfo{ + CurrentVersion: currentVersion, + LatestVersion: latestVersion, + UpdateAvailable: compareVersions(latestVersion, currentVersion) > 0, + ReleaseURL: release.HTMLURL, + AssetURL: assetURL, + ChecksumsURL: checksumsURL, + }, nil +} diff --git a/internal/web/static/index.html b/internal/web/static/index.html index 81b683bd..b168710a 100644 --- a/internal/web/static/index.html +++ b/internal/web/static/index.html @@ -1000,6 +1000,17 @@ @keyframes toastIn { from { opacity: 0; transform: translateY(10px); } to { opacity: 1; transform: translateY(0); } } @keyframes toastOut { from { opacity: 1; } to { opacity: 0; transform: translateY(10px); } } + /* ── Update Badge ── */ + .update-badge { + display: inline-block; margin-left: 8px; padding: 2px 8px; + font-size: 0.65rem; font-weight: 600; border-radius: 10px; + background: var(--accent-green); color: var(--bg-primary); + cursor: pointer; animation: badgePulse 2s ease-in-out infinite; + } + .update-badge:hover { opacity: 0.85; } + .update-badge.updating { background: var(--text-dim); cursor: wait; animation: none; } + @keyframes badgePulse { 0%, 100% { opacity: 1; } 50% { opacity: 0.7; } } + /* ── Loading ── */ .skeleton { background: linear-gradient(90deg, var(--bg-tertiary) 25%, var(--bg-card) 50%, var(--bg-tertiary) 75%); @@ -1209,7 +1220,7 @@