-
Notifications
You must be signed in to change notification settings - Fork 374
Expand file tree
/
Copy pathfetch.go
More file actions
225 lines (193 loc) · 7.94 KB
/
fetch.go
File metadata and controls
225 lines (193 loc) · 7.94 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
package cli
import (
"context"
"fmt"
"os"
"regexp"
"strings"
"time"
"github.com/github/gh-aw/pkg/console"
"github.com/github/gh-aw/pkg/logger"
"github.com/github/gh-aw/pkg/parser"
)
// Package-level state shared by the workflow fetching functions in this file.
var (
	// remoteWorkflowLog is the debug logger for this file ("cli:remote_workflow").
	remoteWorkflowLog = logger.New("cli:remote_workflow")

	// The GitHub helpers and the retry wait are reached through variables rather than
	// called directly — presumably so tests can substitute fakes (verify against tests).
	resolveRefToSHAForHost        = parser.ResolveRefToSHAForHost
	downloadFileFromGitHubForHost = parser.DownloadFileFromGitHubForHost
	waitBeforeSHAResolutionRetry  = sleepForSHAResolutionRetry

	// shaResolutionRetryDelays is the wait before each retry of a transient ref-to-SHA
	// failure; its length is the number of retries after the first attempt.
	shaResolutionRetryDelays = []time.Duration{
		1 * time.Second,
		3 * time.Second,
		9 * time.Second,
	}

	// transientHTTP5xxPattern matches an HTTP 5xx status inside lower-cased error text.
	transientHTTP5xxPattern = regexp.MustCompile(`http 5\d{2}`)
)
// FetchedWorkflow is the result of fetching a workflow file from either the local
// filesystem or a GitHub repository: the raw file bytes plus where they came from.
// It is the single result type returned for both local and remote sources.
type FetchedWorkflow struct {
Content []byte // Raw bytes of the workflow file, exactly as read from disk or downloaded
CommitSHA string // Commit SHA the remote ref was resolved to before downloading; always "" for local workflows
IsLocal bool // True when the file was read from the local filesystem rather than fetched from GitHub
SourcePath string // Path the content actually came from: the local path, or the repo path that succeeded (possibly a "workflows/" or ".github/workflows/" fallback)
}
// FetchWorkflowFromSourceWithContext loads the workflow file described by spec.
//
// When isLocalWorkflowPath reports spec.WorkflowPath as local, the file is read from
// disk. Otherwise it is fetched from GitHub, and ctx can cancel the waits between
// ref-to-SHA resolution retries (for example, on Ctrl-C).
func FetchWorkflowFromSourceWithContext(ctx context.Context, spec *WorkflowSpec, verbose bool) (*FetchedWorkflow, error) {
	remoteWorkflowLog.Printf("Fetching workflow from source: spec=%s", spec.String())

	if !isLocalWorkflowPath(spec.WorkflowPath) {
		// Anything that is not a local path is treated as a GitHub-hosted workflow.
		return fetchRemoteWorkflow(ctx, spec, verbose)
	}
	return fetchLocalWorkflow(spec, verbose)
}
// fetchLocalWorkflow reads spec.WorkflowPath from the local filesystem.
//
// The result has IsLocal set to true, an empty CommitSHA (a file on disk has no commit
// to pin), and SourcePath equal to spec.WorkflowPath. A missing file is reported as
// "not found". Any other read failure (for example, permission denied, or the path is a
// directory) is reported as a read failure so the message does not claim the file is
// missing. Both errors wrap the underlying os error.
func fetchLocalWorkflow(spec *WorkflowSpec, verbose bool) (*FetchedWorkflow, error) {
	if verbose {
		fmt.Fprintln(os.Stderr, console.FormatInfoMessage("Reading local workflow: "+spec.WorkflowPath))
	}

	content, err := os.ReadFile(spec.WorkflowPath)
	if err != nil {
		if os.IsNotExist(err) {
			return nil, fmt.Errorf("local workflow '%s' not found: %w", spec.WorkflowPath, err)
		}
		return nil, fmt.Errorf("failed to read local workflow '%s': %w", spec.WorkflowPath, err)
	}

	return &FetchedWorkflow{
		Content:    content,
		CommitSHA:  "", // Local workflows don't have a commit SHA
		IsLocal:    true,
		SourcePath: spec.WorkflowPath,
	}, nil
}
// fetchRemoteWorkflow downloads spec.WorkflowPath from the GitHub repository named by
// spec.RepoSlug ("owner/repo") on spec.Host, and records the commit SHA its ref resolved to.
//
// spec.Version is used as the ref and defaults to "main" when empty. The ref is first
// resolved to a commit SHA via resolveCommitSHAWithRetries; if that fails, the fetch is
// aborted. The file is then downloaded. If the direct download fails and the path is a
// bare name with no "/" (e.g. "my-workflow"), "workflows/<name>.md" and then
// ".github/workflows/<name>.md" are tried. The returned SourcePath is the path that
// actually succeeded. If no candidate succeeds, the error from the original path is
// returned, wrapped.
func fetchRemoteWorkflow(ctx context.Context, spec *WorkflowSpec, verbose bool) (*FetchedWorkflow, error) {
	remoteWorkflowLog.Printf("Fetching remote workflow: repo=%s, path=%s, version=%s",
		spec.RepoSlug, spec.WorkflowPath, spec.Version)

	// The slug must be "owner/repo" with both halves present. Reject it up front rather
	// than sending a request with an empty owner or repo to the API.
	parts := strings.SplitN(spec.RepoSlug, "/", 2)
	if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
		return nil, fmt.Errorf("invalid repository slug: %s", spec.RepoSlug)
	}
	owner, repo := parts[0], parts[1]

	ref := spec.Version
	if ref == "" {
		ref = "main" // Default to main branch
		remoteWorkflowLog.Print("No version specified, defaulting to 'main'")
	}

	if verbose {
		fmt.Fprintln(os.Stderr, console.FormatInfoMessage(fmt.Sprintf("Fetching %s/%s/%s@%s...", owner, repo, spec.WorkflowPath, ref)))
	}

	// Resolve the ref to a commit SHA for source tracking.
	commitSHA, err := resolveCommitSHAWithRetries(ctx, owner, repo, ref, spec.WorkflowPath, spec.Host, verbose)
	if err != nil {
		return nil, err
	}

	if verbose {
		// Show the usual 7-character abbreviation, but never slice past the end of an
		// unexpectedly short SHA (which would panic).
		displaySHA := commitSHA
		if len(displaySHA) > 7 {
			displaySHA = displaySHA[:7]
		}
		fmt.Fprintln(os.Stderr, console.FormatInfoMessage("Resolved to commit: "+displaySHA))
	}

	// NOTE(review): the file is downloaded at ref, not at commitSHA. If the ref moves
	// between resolution and download, Content may not match CommitSHA. Consider
	// downloading at commitSHA instead.
	sourcePath := spec.WorkflowPath
	content, err := downloadFileFromGitHubForHost(owner, repo, sourcePath, ref, spec.Host)
	if err != nil {
		fallbackFound := false
		// Bare names such as "my-workflow" or "my-workflow.md" have no directory
		// component, so look for them in the conventional workflow directories.
		if !strings.Contains(spec.WorkflowPath, "/") {
			for _, prefix := range []string{"workflows/", ".github/workflows/"} {
				altPath := prefix + spec.WorkflowPath
				if !strings.HasSuffix(altPath, ".md") {
					altPath += ".md"
				}
				remoteWorkflowLog.Printf("Direct path failed, trying: %s", altPath)
				altContent, altErr := downloadFileFromGitHubForHost(owner, repo, altPath, ref, spec.Host)
				if altErr != nil {
					continue
				}
				content, sourcePath, fallbackFound = altContent, altPath, true
				break
			}
		}
		if !fallbackFound {
			// Report the failure for the path the user asked for, not for the fallbacks.
			return nil, fmt.Errorf("failed to download workflow from %s/%s/%s@%s: %w", owner, repo, spec.WorkflowPath, ref, err)
		}
	}

	if verbose {
		fmt.Fprintln(os.Stderr, console.FormatSuccessMessage(fmt.Sprintf("Downloaded workflow (%d bytes)", len(content))))
	}

	return &FetchedWorkflow{
		Content:    content,
		CommitSHA:  commitSHA,
		IsLocal:    false,
		SourcePath: sourcePath,
	}, nil
}
// resolveCommitSHAWithRetries resolves ref in owner/repo (on host) to a commit SHA.
//
// The resolver is called up to len(shaResolutionRetryDelays)+1 times. After a failure
// that isTransientSHAResolutionError classifies as transient, the function waits the next
// delay from shaResolutionRetryDelays via waitBeforeSHAResolutionRetry (which receives ctx)
// and tries again. Any other failure is returned immediately. ctx only reaches the waits
// between attempts; the resolver call itself is not given a context. workflowPath is used
// only to build the suggested "gh aw add ...@<40-char-sha>" command included in every
// error message. Returned errors wrap the underlying resolver or wait error.
func resolveCommitSHAWithRetries(ctx context.Context, owner, repo, ref, workflowPath, host string, verbose bool) (string, error) {
	// Every failure suggests pinning an explicit SHA, which skips ref resolution entirely.
	retryCommand := fmt.Sprintf("gh aw add %s/%s/%s@<40-char-sha>", owner, repo, workflowPath)
	attempts := len(shaResolutionRetryDelays) + 1

	var lastErr error
	for attempt := 1; attempt <= attempts; attempt++ {
		commitSHA, err := resolveRefToSHAForHost(owner, repo, ref, host)
		if err == nil {
			remoteWorkflowLog.Printf("Resolved ref %s to SHA: %s", ref, commitSHA)
			return commitSHA, nil
		}
		lastErr = err
		remoteWorkflowLog.Printf("Failed to resolve ref %s to SHA (attempt %d/%d): %v", ref, attempt, attempts, err)

		if !isTransientSHAResolutionError(err) {
			return "", fmt.Errorf(
				"failed to resolve '%s' to commit SHA for '%s/%s'. Expected the GitHub API to return a commit SHA for the ref. Try: %s: %w",
				ref, owner, repo, retryCommand, err,
			)
		}
		if attempt == attempts {
			break // Out of retries; report exhaustion below.
		}

		delay := shaResolutionRetryDelays[attempt-1]
		if verbose {
			fmt.Fprintln(os.Stderr, console.FormatWarningMessage(
				fmt.Sprintf("Transient SHA resolution failure for '%s' (attempt %d/%d). Retrying in %s...", ref, attempt, attempts, delay),
			))
		}
		if waitErr := waitBeforeSHAResolutionRetry(ctx, delay); waitErr != nil {
			// Keep the API failure that triggered the retry in the message. Otherwise a
			// cancelled wait hides why resolution was being retried in the first place.
			return "", fmt.Errorf(
				"failed to resolve '%s' to commit SHA because retry wait was cancelled for '%s/%s' (last error: %v). Expected the GitHub API to return a commit SHA for the ref. Try: %s: %w",
				ref, owner, repo, lastErr, retryCommand, waitErr,
			)
		}
	}

	return "", fmt.Errorf(
		"failed to resolve '%s' to commit SHA after %d retries for '%s/%s'. Expected the GitHub API to return a commit SHA for the ref. Check rate limits or try: %s: %w",
		ref, len(shaResolutionRetryDelays), owner, repo, retryCommand, lastErr,
	)
}
// sleepForSHAResolutionRetry blocks until delay has elapsed or ctx is done, whichever
// happens first. It returns nil after a full delay, or ctx.Err() if the context finished
// first. The timer is always stopped on return.
func sleepForSHAResolutionRetry(ctx context.Context, delay time.Duration) error {
	t := time.NewTimer(delay)
	defer t.Stop()

	select {
	case <-t.C:
	case <-ctx.Done():
		return ctx.Err()
	}
	return nil
}
// transientEOFPattern matches "eof" as a standalone word in lower-cased error text
// (e.g. "unexpected eof" or `...: eof`). It does not match words that merely contain
// those letters, such as a ref or user name like "geoff".
var transientEOFPattern = regexp.MustCompile(`\beof\b`)

// isTransientSHAResolutionError reports whether a ref-to-SHA failure looks transient and
// worth retrying: rate limiting (HTTP 429 / "rate limit"), timeouts, temporary network
// failures (connection reset/refused, EOF), or an HTTP 5xx status. Classification is done
// by case-insensitive matching on err.Error(). Anything else, including a nil error, is
// treated as not transient, so the caller fails immediately.
func isTransientSHAResolutionError(err error) bool {
	if err == nil {
		return false
	}
	errorText := strings.ToLower(err.Error())

	transientMarkers := [...]string{
		"http 429",
		"rate limit",
		"timeout",
		"timed out",
		"context deadline exceeded",
		"temporary",
		"connection reset",
		"connection refused",
	}
	for _, marker := range transientMarkers {
		if strings.Contains(errorText, marker) {
			return true
		}
	}

	// Error text often embeds user-supplied names (refs, repos), so "eof" is matched only
	// as a whole word. A plain substring check would treat "... not found: geoff-fix" as
	// transient and retry a permanent failure.
	return transientEOFPattern.MatchString(errorText) || transientHTTP5xxPattern.MatchString(errorText)
}