diff --git a/github/gen-iterators.go b/github/gen-iterators.go index 5830f97ddae..a623cf1c151 100644 --- a/github/gen-iterators.go +++ b/github/gen-iterators.go @@ -162,6 +162,20 @@ var useCursorPagination = map[string]bool{ "RepositoriesService.ListHookDeliveries": true, } +// customNames provides custom names for iterator methods where the default methodName + "Iter" would be confusing. +var customNames = map[string]string{ + "RepositoriesService.GetCommit": "ListCommitFiles", + "RepositoriesService.CompareCommits": "ListCommitComparisonFiles", + "RepositoriesService.GetCombinedStatus": "ListCombinedStatus", +} + +// sliceToBeUsedForIteration identifies methods where the wrapper struct contains multiple []*T fields, +// and specifies which field should be used for iteration. +var sliceToBeUsedForIteration = map[string]string{ + "RepositoriesService.GetCommit": "Files", + "RepositoriesService.CompareCommits": "Files", +} + // customTestJSON maps method names to the JSON response they expect in tests. // This is needed for methods that internally unmarshal a wrapper struct // even though they return a slice. @@ -301,7 +315,8 @@ func (t *templateData) processMethods(f *ast.File) error { continue } - if !fd.Name.IsExported() || !strings.HasPrefix(fd.Name.Name, "List") { + methodKey := strings.TrimPrefix(typeToString(fd.Recv.List[0].Type), "*") + "." + fd.Name.Name + if !fd.Name.IsExported() || (!strings.HasPrefix(fd.Name.Name, "List") && customNames[methodKey] == "") { continue } @@ -448,6 +463,13 @@ func (t *templateData) collectMethodInfo(fd *ast.FuncDecl) (*methodInfo, bool) { }, true } +func getIterName(methodInfo *methodInfo, methodName string) string { + if customName, ok := customNames[methodInfo.RecvType+"."+methodName]; ok { + return customName + "Iter" + } + return methodName + "Iter" +} + func (t *templateData) processReturnArrayType(fd *ast.FuncDecl, sliceRet *ast.ArrayType, methodInfo *methodInfo) { testJSON, emptyReturnValue := "[]", "{}" if val, ok := customTestJSON[fd.Name.Name]; ok { @@ -467,7 +489,7 @@ func (t *templateData) processReturnArrayType(fd *ast.FuncDecl, sliceRet *ast.Ar RecvVar: methodInfo.RecvVar, ClientField: methodInfo.ClientField, MethodName: fd.Name.Name, - IterMethod: fd.Name.Name + "Iter", + IterMethod: getIterName(methodInfo, fd.Name.Name), Args: methodInfo.Args, CallArgs: methodInfo.CallArgs, TestCallArgs: methodInfo.TestCallArgs, @@ -496,8 +518,14 @@ func (t *templateData) processReturnStarExpr(fd *ast.FuncDecl, starRet *ast.Star return } - itemsField, itemsType, ok := findSinglePointerSliceField(wrapperDef) - if !ok { + var itemsField, itemsType string + if field, ok := sliceToBeUsedForIteration[methodInfo.RecvType+"."+fd.Name.Name]; ok { + itemsField = field + if itemsType, ok = wrapperDef.Fields[itemsField]; !ok || !strings.HasPrefix(itemsType, "[]*") { + logf("Skipping %v.%v: specified items field %v not found or not of type []*T in wrapper %v", methodInfo.RecvTypeRaw, fd.Name.Name, itemsField, wrapperType) + return + } + } else if itemsField, itemsType, ok = findSinglePointerSliceField(wrapperDef); !ok { logf("Skipping %v.%v: wrapper %v does not contain exactly one []*T field", methodInfo.RecvTypeRaw, fd.Name.Name, wrapperType) return } @@ -525,7 +553,7 @@ func (t *templateData) processReturnStarExpr(fd *ast.FuncDecl, starRet *ast.Star RecvVar: methodInfo.RecvVar, ClientField: methodInfo.ClientField, MethodName: fd.Name.Name, - IterMethod: fd.Name.Name + "Iter", + IterMethod: getIterName(methodInfo, fd.Name.Name), Args: methodInfo.Args, CallArgs: methodInfo.CallArgs, TestCallArgs: methodInfo.TestCallArgs, diff --git a/github/github-iterators.go b/github/github-iterators.go index f33c9edc47c..5568ca79236 100644 --- a/github/github-iterators.go +++ b/github/github-iterators.go @@ -5316,6 +5316,111 @@ func (s *ReactionsService) ListTeamDiscussionReactionsIter(ctx context.Context, } } +// ListCommitComparisonFilesIter returns an iterator that paginates through all results of CompareCommits. +func (s *RepositoriesService) ListCommitComparisonFilesIter(ctx context.Context, owner string, repo string, base string, head string, opts *ListOptions) iter.Seq2[*CommitFile, error] { + return func(yield func(*CommitFile, error) bool) { + // Create a copy of opts to avoid mutating the caller's struct + if opts == nil { + opts = &ListOptions{} + } else { + opts = Ptr(*opts) + } + + for { + results, resp, err := s.CompareCommits(ctx, owner, repo, base, head, opts) + if err != nil { + yield(nil, err) + return + } + + var iterItems []*CommitFile + if results != nil { + iterItems = results.Files + } + for _, item := range iterItems { + if !yield(item, nil) { + return + } + } + + if resp.NextPage == 0 { + break + } + opts.Page = resp.NextPage + } + } +} + +// ListCombinedStatusIter returns an iterator that paginates through all results of GetCombinedStatus. +func (s *RepositoriesService) ListCombinedStatusIter(ctx context.Context, owner string, repo string, ref string, opts *ListOptions) iter.Seq2[*RepoStatus, error] { + return func(yield func(*RepoStatus, error) bool) { + // Create a copy of opts to avoid mutating the caller's struct + if opts == nil { + opts = &ListOptions{} + } else { + opts = Ptr(*opts) + } + + for { + results, resp, err := s.GetCombinedStatus(ctx, owner, repo, ref, opts) + if err != nil { + yield(nil, err) + return + } + + var iterItems []*RepoStatus + if results != nil { + iterItems = results.Statuses + } + for _, item := range iterItems { + if !yield(item, nil) { + return + } + } + + if resp.NextPage == 0 { + break + } + opts.Page = resp.NextPage + } + } +} + +// ListCommitFilesIter returns an iterator that paginates through all results of GetCommit. +func (s *RepositoriesService) ListCommitFilesIter(ctx context.Context, owner string, repo string, sha string, opts *ListOptions) iter.Seq2[*CommitFile, error] { + return func(yield func(*CommitFile, error) bool) { + // Create a copy of opts to avoid mutating the caller's struct + if opts == nil { + opts = &ListOptions{} + } else { + opts = Ptr(*opts) + } + + for { + results, resp, err := s.GetCommit(ctx, owner, repo, sha, opts) + if err != nil { + yield(nil, err) + return + } + + var iterItems []*CommitFile + if results != nil { + iterItems = results.Files + } + for _, item := range iterItems { + if !yield(item, nil) { + return + } + } + + if resp.NextPage == 0 { + break + } + opts.Page = resp.NextPage + } + } +} + // ListIter returns an iterator that paginates through all results of List. func (s *RepositoriesService) ListIter(ctx context.Context, user string, opts *RepositoryListOptions) iter.Seq2[*Repository, error] { return func(yield func(*Repository, error) bool) { diff --git a/github/github-iterators_test.go b/github/github-iterators_test.go index c48f02bcefc..f64e39c9fde 100644 --- a/github/github-iterators_test.go +++ b/github/github-iterators_test.go @@ -11751,6 +11751,222 @@ func TestReactionsService_ListTeamDiscussionReactionsIter(t *testing.T) { } } +func TestRepositoriesService_ListCommitComparisonFilesIter(t *testing.T) { + t.Parallel() + client, mux, _ := setup(t) + var callNum int + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + callNum++ + switch callNum { + case 1: + w.Header().Set("Link", `; rel="next"`) + fmt.Fprint(w, `{"files": [{},{},{}]}`) + case 2: + fmt.Fprint(w, `{"files": [{},{},{},{}]}`) + case 3: + fmt.Fprint(w, `{"files": [{},{}]}`) + case 4: + w.WriteHeader(http.StatusNotFound) + case 5: + fmt.Fprint(w, `{"files": [{},{}]}`) + } + }) + + iter := client.Repositories.ListCommitComparisonFilesIter(t.Context(), "", "", "", "", nil) + var gotItems int + for _, err := range iter { + gotItems++ + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + } + if want := 7; gotItems != want { + t.Errorf("client.Repositories.ListCommitComparisonFilesIter call 1 got %v items; want %v", gotItems, want) + } + + opts := &ListOptions{} + iter = client.Repositories.ListCommitComparisonFilesIter(t.Context(), "", "", "", "", opts) + gotItems = 0 + for _, err := range iter { + gotItems++ + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + } + if want := 2; gotItems != want { + t.Errorf("client.Repositories.ListCommitComparisonFilesIter call 2 got %v items; want %v", gotItems, want) + } + + iter = client.Repositories.ListCommitComparisonFilesIter(t.Context(), "", "", "", "", nil) + gotItems = 0 + for _, err := range iter { + gotItems++ + if err == nil { + t.Error("expected error; got nil") + } + } + if gotItems != 1 { + t.Errorf("client.Repositories.ListCommitComparisonFilesIter call 3 got %v items; want 1 (an error)", gotItems) + } + + iter = client.Repositories.ListCommitComparisonFilesIter(t.Context(), "", "", "", "", nil) + gotItems = 0 + iter(func(item *CommitFile, err error) bool { + gotItems++ + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + return false + }) + if gotItems != 1 { + t.Errorf("client.Repositories.ListCommitComparisonFilesIter call 4 got %v items; want 1 (an error)", gotItems) + } +} + +func TestRepositoriesService_ListCombinedStatusIter(t *testing.T) { + t.Parallel() + client, mux, _ := setup(t) + var callNum int + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + callNum++ + switch callNum { + case 1: + w.Header().Set("Link", `; rel="next"`) + fmt.Fprint(w, `{"statuses": [{},{},{}]}`) + case 2: + fmt.Fprint(w, `{"statuses": [{},{},{},{}]}`) + case 3: + fmt.Fprint(w, `{"statuses": [{},{}]}`) + case 4: + w.WriteHeader(http.StatusNotFound) + case 5: + fmt.Fprint(w, `{"statuses": [{},{}]}`) + } + }) + + iter := client.Repositories.ListCombinedStatusIter(t.Context(), "", "", "", nil) + var gotItems int + for _, err := range iter { + gotItems++ + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + } + if want := 7; gotItems != want { + t.Errorf("client.Repositories.ListCombinedStatusIter call 1 got %v items; want %v", gotItems, want) + } + + opts := &ListOptions{} + iter = client.Repositories.ListCombinedStatusIter(t.Context(), "", "", "", opts) + gotItems = 0 + for _, err := range iter { + gotItems++ + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + } + if want := 2; gotItems != want { + t.Errorf("client.Repositories.ListCombinedStatusIter call 2 got %v items; want %v", gotItems, want) + } + + iter = client.Repositories.ListCombinedStatusIter(t.Context(), "", "", "", nil) + gotItems = 0 + for _, err := range iter { + gotItems++ + if err == nil { + t.Error("expected error; got nil") + } + } + if gotItems != 1 { + t.Errorf("client.Repositories.ListCombinedStatusIter call 3 got %v items; want 1 (an error)", gotItems) + } + + iter = client.Repositories.ListCombinedStatusIter(t.Context(), "", "", "", nil) + gotItems = 0 + iter(func(item *RepoStatus, err error) bool { + gotItems++ + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + return false + }) + if gotItems != 1 { + t.Errorf("client.Repositories.ListCombinedStatusIter call 4 got %v items; want 1 (an error)", gotItems) + } +} + +func TestRepositoriesService_ListCommitFilesIter(t *testing.T) { + t.Parallel() + client, mux, _ := setup(t) + var callNum int + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + callNum++ + switch callNum { + case 1: + w.Header().Set("Link", `; rel="next"`) + fmt.Fprint(w, `{"files": [{},{},{}]}`) + case 2: + fmt.Fprint(w, `{"files": [{},{},{},{}]}`) + case 3: + fmt.Fprint(w, `{"files": [{},{}]}`) + case 4: + w.WriteHeader(http.StatusNotFound) + case 5: + fmt.Fprint(w, `{"files": [{},{}]}`) + } + }) + + iter := client.Repositories.ListCommitFilesIter(t.Context(), "", "", "", nil) + var gotItems int + for _, err := range iter { + gotItems++ + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + } + if want := 7; gotItems != want { + t.Errorf("client.Repositories.ListCommitFilesIter call 1 got %v items; want %v", gotItems, want) + } + + opts := &ListOptions{} + iter = client.Repositories.ListCommitFilesIter(t.Context(), "", "", "", opts) + gotItems = 0 + for _, err := range iter { + gotItems++ + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + } + if want := 2; gotItems != want { + t.Errorf("client.Repositories.ListCommitFilesIter call 2 got %v items; want %v", gotItems, want) + } + + iter = client.Repositories.ListCommitFilesIter(t.Context(), "", "", "", nil) + gotItems = 0 + for _, err := range iter { + gotItems++ + if err == nil { + t.Error("expected error; got nil") + } + } + if gotItems != 1 { + t.Errorf("client.Repositories.ListCommitFilesIter call 3 got %v items; want 1 (an error)", gotItems) + } + + iter = client.Repositories.ListCommitFilesIter(t.Context(), "", "", "", nil) + gotItems = 0 + iter(func(item *CommitFile, err error) bool { + gotItems++ + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + return false + }) + if gotItems != 1 { + t.Errorf("client.Repositories.ListCommitFilesIter call 4 got %v items; want 1 (an error)", gotItems) + } +} + func TestRepositoriesService_ListIter(t *testing.T) { t.Parallel() client, mux, _ := setup(t)