diff --git a/internal/github/github.go b/internal/github/github.go index cea2c8c..21b3b70 100644 --- a/internal/github/github.go +++ b/internal/github/github.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/json" "fmt" + "math" "github.com/cli/go-gh/v2/pkg/api" graphql "github.com/cli/shurcooL-graphql" @@ -319,6 +320,11 @@ func (c *Client) FindPRDetailsForBranch(branch string) (*PRDetails, error) { // FindPRByNumber fetches a pull request by its number. func (c *Client) FindPRByNumber(number int) (*PullRequest, error) { + gqlNumber, err := toGraphQLInt(number) + if err != nil { + return nil, err + } + var query struct { Repository struct { PullRequest struct { @@ -339,7 +345,7 @@ func (c *Client) FindPRByNumber(number int) (*PullRequest, error) { variables := map[string]interface{}{ "owner": graphql.String(c.owner), "name": graphql.String(c.repo), - "number": graphql.Int(number), + "number": gqlNumber, } if err := c.gql.Query("FindPRByNumber", &query, variables); err != nil { @@ -364,6 +370,13 @@ func (c *Client) FindPRByNumber(number int) (*PullRequest, error) { }, nil } +func toGraphQLInt(n int) (graphql.Int, error) { + if n < math.MinInt32 || n > math.MaxInt32 { + return 0, fmt.Errorf("number %d is out of GraphQL Int range", n) + } + return graphql.Int(n), nil +} + type RemoteStack struct { ID int `json:"id"` PullRequests []int `json:"pull_requests"` diff --git a/internal/github/github_test.go b/internal/github/github_test.go index 4efee87..29814bc 100644 --- a/internal/github/github_test.go +++ b/internal/github/github_test.go @@ -3,6 +3,7 @@ package github import ( "testing" + graphql "github.com/cli/shurcooL-graphql" "github.com/stretchr/testify/assert" ) @@ -46,3 +47,16 @@ func TestPullRequest_IsQueued(t *testing.T) { assert.False(t, pr.IsQueued()) }) } + +func TestToGraphQLInt(t *testing.T) { + t.Run("in range", func(t *testing.T) { + got, err := toGraphQLInt(123) + assert.NoError(t, err) + assert.Equal(t, graphql.Int(123), got) + }) + + t.Run("out of range", func(t *testing.T) { + _, err := toGraphQLInt(1 << 40) + assert.Error(t, err) + }) +}