Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 89 additions & 16 deletions pkg/cli/update_check.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@ import (

"github.com/cli/go-gh/v2/pkg/api"
"github.com/github/gh-aw/pkg/console"
"github.com/github/gh-aw/pkg/constants"
"github.com/github/gh-aw/pkg/logger"
"github.com/github/gh-aw/pkg/workflow"
"golang.org/x/mod/semver"
)

var updateCheckLog = logger.New("cli:update_check")
Expand Down Expand Up @@ -99,6 +101,10 @@ func isRunningAsMCPServer() bool {
var (
// getLastCheckFilePathFunc allows overriding in tests
getLastCheckFilePathFunc = getLastCheckFilePathImpl
// getLatestReleaseFunc allows overriding in tests
getLatestReleaseFunc = getLatestRelease
// getLatestAWFReleaseFunc allows overriding in tests
getLatestAWFReleaseFunc = getLatestAWFRelease
)

// getLastCheckFilePath returns the path to the last check timestamp file
Expand Down Expand Up @@ -138,7 +144,7 @@ func updateLastCheckTime() {
}
}

// checkForUpdates checks if a newer version of gh-aw is available
// checkForUpdates checks if a newer version of gh-aw or gh-aw-firewall is available.
// This function is non-blocking and ignores all errors (connectivity, API, etc.)
func checkForUpdates(noCheckUpdate bool, verbose bool) {
// Quick check if we should even attempt the update check
Expand All @@ -158,8 +164,15 @@ func checkForUpdates(noCheckUpdate bool, verbose bool) {
return
}

// Check gh-aw and gh-aw-firewall for updates concurrently
checkForGhAwUpdates(currentVersion, verbose)
checkForAWFUpdates()
}

// checkForGhAwUpdates checks if a newer version of gh-aw is available and notifies the user.
func checkForGhAwUpdates(currentVersion string, verbose bool) {
// Query GitHub API for latest release
latestVersion, err := getLatestRelease()
latestVersion, err := getLatestReleaseFunc()
if err != nil {
// Silently ignore errors - update check should never fail the command
updateCheckLog.Printf("Error checking for updates (ignoring): %v", err)
Expand All @@ -171,28 +184,21 @@ func checkForUpdates(noCheckUpdate bool, verbose bool) {
return
}

// Compare versions
if latestVersion == currentVersion {
if verbose {
updateCheckLog.Print("gh-aw is up to date")
}
return
}
// Ensure versions have 'v' prefix for semver comparison
current := ensureVPrefix(currentVersion)
latest := ensureVPrefix(latestVersion)

// Normalize versions for comparison (remove 'v' prefix)
currentVersionNormalized := strings.TrimPrefix(currentVersion, "v")
latestVersionNormalized := strings.TrimPrefix(latestVersion, "v")
cmp := semver.Compare(current, latest)

if currentVersionNormalized == latestVersionNormalized {
if cmp == 0 {
if verbose {
updateCheckLog.Print("gh-aw is up to date (version format differs)")
updateCheckLog.Print("gh-aw is up to date")
}
return
}

// Check if we're on a newer version (development/prerelease)
// Simple heuristic: if current version sorts after latest, we might be on a dev version
if currentVersionNormalized > latestVersionNormalized {
if cmp > 0 {
updateCheckLog.Printf("Current version (%s) appears newer than latest release (%s), skipping notification", currentVersion, latestVersion)
return
}
Expand Down Expand Up @@ -226,6 +232,73 @@ func getLatestRelease() (string, error) {
return release.TagName, nil
}

// getLatestAWFRelease queries GitHub API for the latest release of gh-aw-firewall
func getLatestAWFRelease() (string, error) {
updateCheckLog.Print("Querying GitHub API for latest gh-aw-firewall release...")

client, err := api.NewRESTClient(api.ClientOptions{})
if err != nil {
return "", fmt.Errorf("failed to create GitHub client: %w", err)
}

var release Release
err = client.Get("repos/github/gh-aw-firewall/releases/latest", &release)
if err != nil {
return "", fmt.Errorf("failed to query latest gh-aw-firewall release: %w", err)
}

updateCheckLog.Printf("Latest gh-aw-firewall release: %s", release.TagName)
return release.TagName, nil
}

// checkForAWFUpdates checks if a newer version of gh-aw-firewall is available
// compared to the bundled default version. Errors are silently ignored.
func checkForAWFUpdates() {
bundledVersion := string(constants.DefaultFirewallVersion)

latestVersion, err := getLatestAWFReleaseFunc()
if err != nil {
updateCheckLog.Printf("Error checking for gh-aw-firewall updates (ignoring): %v", err)
return
}

if latestVersion == "" {
updateCheckLog.Print("Could not determine latest gh-aw-firewall version")
return
}

// Ensure versions have 'v' prefix for semver comparison
bundled := ensureVPrefix(bundledVersion)
latest := ensureVPrefix(latestVersion)

cmp := semver.Compare(bundled, latest)

if cmp == 0 {
updateCheckLog.Print("gh-aw-firewall is up to date")
return
}

// If bundled version is already newer, skip
if cmp > 0 {
updateCheckLog.Printf("Bundled gh-aw-firewall (%s) appears newer than latest release (%s), skipping notification", bundledVersion, latestVersion)
return
}

// A newer AWF version is available – updating gh-aw will pick it up
updateCheckLog.Printf("Newer gh-aw-firewall available: %s (bundled: %s)", latestVersion, bundledVersion)
fmt.Fprintln(os.Stderr, console.FormatInfoMessage(fmt.Sprintf("A new version of gh-aw-firewall is available: %s (bundled: %s)", latestVersion, bundledVersion)))
fmt.Fprintln(os.Stderr, console.FormatInfoMessage("Update with: gh extension upgrade github/gh-aw"))
fmt.Fprintln(os.Stderr, "")
}

// ensureVPrefix ensures a version string starts with 'v' as required by golang.org/x/mod/semver
func ensureVPrefix(version string) string {
if !strings.HasPrefix(version, "v") {
return "v" + version
}
return version
}

// CheckForUpdatesAsync performs update check in background (best effort)
// This is called from compile command and should never block or fail the compilation
// The context can be used to cancel the update check if the program is shutting down
Expand Down
176 changes: 176 additions & 0 deletions pkg/cli/update_check_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,15 @@
package cli

import (
"bytes"
"context"
"os"
"path/filepath"
"strings"
"testing"
"time"

"github.com/github/gh-aw/pkg/constants"
)

func TestShouldCheckForUpdate(t *testing.T) {
Expand Down Expand Up @@ -338,3 +342,175 @@ func TestCheckForUpdatesAsync_ContextCancellation(t *testing.T) {
// Note: The check might still run if it started before cancellation,
// so we just verify no panics occurred
}

func TestCheckForGhAwUpdates_NewVersionAvailable(t *testing.T) {
origGetLatestRelease := getLatestReleaseFunc
defer func() { getLatestReleaseFunc = origGetLatestRelease }()

getLatestReleaseFunc = func() (string, error) {
return "v9.9.9", nil
}

// Capture stderr output
oldStderr := os.Stderr
r, w, _ := os.Pipe()
os.Stderr = w

checkForGhAwUpdates("v1.0.0", false)

w.Close()
os.Stderr = oldStderr

var buf bytes.Buffer
_, _ = buf.ReadFrom(r)
output := buf.String()

if !strings.Contains(output, "v9.9.9") {
t.Errorf("expected output to mention v9.9.9, got: %s", output)
}
if !strings.Contains(output, "v1.0.0") {
t.Errorf("expected output to mention v1.0.0, got: %s", output)
}
if !strings.Contains(output, "gh extension upgrade github/gh-aw") {
t.Errorf("expected update command in output, got: %s", output)
}
}

func TestCheckForGhAwUpdates_AlreadyUpToDate(t *testing.T) {
origGetLatestRelease := getLatestReleaseFunc
defer func() { getLatestReleaseFunc = origGetLatestRelease }()

getLatestReleaseFunc = func() (string, error) {
return "v1.0.0", nil
}

oldStderr := os.Stderr
r, w, _ := os.Pipe()
os.Stderr = w

checkForGhAwUpdates("v1.0.0", false)

w.Close()
os.Stderr = oldStderr

var buf bytes.Buffer
_, _ = buf.ReadFrom(r)
output := buf.String()

if strings.Contains(output, "gh extension upgrade") {
t.Errorf("expected no update message when already up to date, got: %s", output)
}
}

func TestCheckForGhAwUpdates_CurrentNewerThanLatest(t *testing.T) {
origGetLatestRelease := getLatestReleaseFunc
defer func() { getLatestReleaseFunc = origGetLatestRelease }()

getLatestReleaseFunc = func() (string, error) {
return "v1.0.0", nil
}

oldStderr := os.Stderr
r, w, _ := os.Pipe()
os.Stderr = w

checkForGhAwUpdates("v9.9.9", false)

w.Close()
os.Stderr = oldStderr

var buf bytes.Buffer
_, _ = buf.ReadFrom(r)
output := buf.String()

if strings.Contains(output, "gh extension upgrade") {
t.Errorf("expected no update message when current is newer, got: %s", output)
}
}

func TestCheckForAWFUpdates_NewVersionAvailable(t *testing.T) {
origGetLatestAWFRelease := getLatestAWFReleaseFunc
defer func() { getLatestAWFReleaseFunc = origGetLatestAWFRelease }()

getLatestAWFReleaseFunc = func() (string, error) {
return "v9.9.9", nil
}

oldStderr := os.Stderr
r, w, _ := os.Pipe()
os.Stderr = w

checkForAWFUpdates()

w.Close()
os.Stderr = oldStderr

var buf bytes.Buffer
_, _ = buf.ReadFrom(r)
output := buf.String()

bundled := string(constants.DefaultFirewallVersion)

if !strings.Contains(output, "v9.9.9") {
t.Errorf("expected output to mention v9.9.9, got: %s", output)
}
if !strings.Contains(output, bundled) {
t.Errorf("expected output to mention bundled version %s, got: %s", bundled, output)
}
if !strings.Contains(output, "gh extension upgrade github/gh-aw") {
t.Errorf("expected update command in output, got: %s", output)
}
}

func TestCheckForAWFUpdates_AlreadyUpToDate(t *testing.T) {
origGetLatestAWFRelease := getLatestAWFReleaseFunc
defer func() { getLatestAWFReleaseFunc = origGetLatestAWFRelease }()

bundled := string(constants.DefaultFirewallVersion)
getLatestAWFReleaseFunc = func() (string, error) {
return bundled, nil
}

oldStderr := os.Stderr
r, w, _ := os.Pipe()
os.Stderr = w

checkForAWFUpdates()

w.Close()
os.Stderr = oldStderr

var buf bytes.Buffer
_, _ = buf.ReadFrom(r)
output := buf.String()

if strings.Contains(output, "gh extension upgrade") {
t.Errorf("expected no update message when AWF is up to date, got: %s", output)
}
}

func TestCheckForAWFUpdates_BundledNewerThanLatest(t *testing.T) {
origGetLatestAWFRelease := getLatestAWFReleaseFunc
defer func() { getLatestAWFReleaseFunc = origGetLatestAWFRelease }()

getLatestAWFReleaseFunc = func() (string, error) {
return "v0.0.1", nil
}

oldStderr := os.Stderr
r, w, _ := os.Pipe()
os.Stderr = w

checkForAWFUpdates()

w.Close()
os.Stderr = oldStderr

var buf bytes.Buffer
_, _ = buf.ReadFrom(r)
output := buf.String()

if strings.Contains(output, "gh extension upgrade") {
t.Errorf("expected no update message when bundled is newer, got: %s", output)
}
}