@@ -5,8 +5,11 @@ import (
55 "fmt"
66 "io"
77 "log"
8+ "net/http"
9+ "net/url"
810 "os"
911 "os/signal"
12+ "strings"
1013 "syscall"
1114
1215 "github.com/github/github-mcp-server/pkg/github"
@@ -15,6 +18,7 @@ import (
1518 gogithub "github.com/google/go-github/v69/github"
1619 "github.com/mark3labs/mcp-go/mcp"
1720 "github.com/mark3labs/mcp-go/server"
21+ "github.com/shurcooL/githubv4"
1822 "github.com/sirupsen/logrus"
1923)
2024
@@ -44,25 +48,43 @@ type MCPServerConfig struct {
4448}
4549
4650func NewMCPServer (cfg MCPServerConfig ) (* server.MCPServer , error ) {
47- ghClient := gogithub .NewClient (nil ).WithAuthToken (cfg .Token )
48- ghClient .UserAgent = fmt .Sprintf ("github-mcp-server/%s" , cfg .Version )
49-
50- if cfg .Host != "" {
51- var err error
52- ghClient , err = ghClient .WithEnterpriseURLs (cfg .Host , cfg .Host )
53- if err != nil {
54- return nil , fmt .Errorf ("failed to create GitHub client with host: %w" , err )
55- }
51+ apiHost , err := parseAPIHost (cfg .Host )
52+ if err != nil {
53+ return nil , fmt .Errorf ("failed to parse API host: %w" , err )
5654 }
5755
56+ // Construct our REST client
57+ restClient := gogithub .NewClient (nil ).WithAuthToken (cfg .Token )
58+ restClient .UserAgent = fmt .Sprintf ("github-mcp-server/%s" , cfg .Version )
59+ restClient .BaseURL = apiHost .baseRESTURL
60+ restClient .UploadURL = apiHost .uploadURL
61+
62+ // Construct our GraphQL client
63+ // We're using NewEnterpriseClient here unconditionally as opposed to NewClient because we already
64+ // did the necessary API host parsing so that github.com will return the correct URL anyway.
65+ gqlHTTPClient := & http.Client {
66+ Transport : & bearerAuthTransport {
67+ transport : http .DefaultTransport ,
68+ token : cfg .Token ,
69+ },
70+ } // We're going to wrap the Transport later in beforeInit
71+ gqlClient := githubv4 .NewEnterpriseClient (apiHost .graphqlURL .String (), gqlHTTPClient )
72+
5873 // When a client send an initialize request, update the user agent to include the client info.
5974 beforeInit := func (_ context.Context , _ any , message * mcp.InitializeRequest ) {
60- ghClient . UserAgent = fmt .Sprintf (
75+ userAgent : = fmt .Sprintf (
6176 "github-mcp-server/%s (%s/%s)" ,
6277 cfg .Version ,
6378 message .Params .ClientInfo .Name ,
6479 message .Params .ClientInfo .Version ,
6580 )
81+
82+ restClient .UserAgent = userAgent
83+
84+ gqlHTTPClient .Transport = & userAgentTransport {
85+ transport : gqlHTTPClient .Transport ,
86+ agent : userAgent ,
87+ }
6688 }
6789
6890 hooks := & server.Hooks {
@@ -83,14 +105,19 @@ func NewMCPServer(cfg MCPServerConfig) (*server.MCPServer, error) {
83105 }
84106
85107 getClient := func (_ context.Context ) (* gogithub.Client , error ) {
86- return ghClient , nil // closing over client
108+ return restClient , nil // closing over client
109+ }
110+
111+ getGQLClient := func (_ context.Context ) (* githubv4.Client , error ) {
112+ return gqlClient , nil // closing over client
87113 }
88114
89115 // Create default toolsets
90116 toolsets , err := github .InitToolsets (
91117 enabledToolsets ,
92118 cfg .ReadOnly ,
93119 getClient ,
120+ getGQLClient ,
94121 cfg .Translator ,
95122 )
96123 if err != nil {
@@ -213,3 +240,141 @@ func RunStdioServer(cfg StdioServerConfig) error {
213240
214241 return nil
215242}
243+
244+ type apiHost struct {
245+ baseRESTURL * url.URL
246+ graphqlURL * url.URL
247+ uploadURL * url.URL
248+ }
249+
250+ func newDotcomHost () (apiHost , error ) {
251+ baseRestURL , err := url .Parse ("https://api.github.com/" )
252+ if err != nil {
253+ return apiHost {}, fmt .Errorf ("failed to parse dotcom REST URL: %w" , err )
254+ }
255+
256+ gqlURL , err := url .Parse ("https://api.github.com/graphql" )
257+ if err != nil {
258+ return apiHost {}, fmt .Errorf ("failed to parse dotcom GraphQL URL: %w" , err )
259+ }
260+
261+ uploadURL , err := url .Parse ("https://uploads.github.com" )
262+ if err != nil {
263+ return apiHost {}, fmt .Errorf ("failed to parse dotcom Upload URL: %w" , err )
264+ }
265+
266+ return apiHost {
267+ baseRESTURL : baseRestURL ,
268+ graphqlURL : gqlURL ,
269+ uploadURL : uploadURL ,
270+ }, nil
271+ }
272+
273+ func newGHECHost (hostname string ) (apiHost , error ) {
274+ u , err := url .Parse (hostname )
275+ if err != nil {
276+ return apiHost {}, fmt .Errorf ("failed to parse GHEC URL: %w" , err )
277+ }
278+
279+ // Unsecured GHEC would be an error
280+ if u .Scheme == "http" {
281+ return apiHost {}, fmt .Errorf ("GHEC URL must be HTTPS" )
282+ }
283+
284+ restURL , err := url .Parse (fmt .Sprintf ("https://api.%s/" , u .Hostname ()))
285+ if err != nil {
286+ return apiHost {}, fmt .Errorf ("failed to parse GHEC REST URL: %w" , err )
287+ }
288+
289+ gqlURL , err := url .Parse (fmt .Sprintf ("https://api.%s/graphql" , u .Hostname ()))
290+ if err != nil {
291+ return apiHost {}, fmt .Errorf ("failed to parse GHEC GraphQL URL: %w" , err )
292+ }
293+
294+ uploadURL , err := url .Parse (fmt .Sprintf ("https://uploads.%s" , u .Hostname ()))
295+ if err != nil {
296+ return apiHost {}, fmt .Errorf ("failed to parse GHEC Upload URL: %w" , err )
297+ }
298+
299+ return apiHost {
300+ baseRESTURL : restURL ,
301+ graphqlURL : gqlURL ,
302+ uploadURL : uploadURL ,
303+ }, nil
304+ }
305+
306+ func newGHESHost (hostname string ) (apiHost , error ) {
307+ u , err := url .Parse (hostname )
308+ if err != nil {
309+ return apiHost {}, fmt .Errorf ("failed to parse GHES URL: %w" , err )
310+ }
311+
312+ restURL , err := url .Parse (fmt .Sprintf ("%s://%s/api/v3/" , u .Scheme , u .Hostname ()))
313+ if err != nil {
314+ return apiHost {}, fmt .Errorf ("failed to parse GHES REST URL: %w" , err )
315+ }
316+
317+ gqlURL , err := url .Parse (fmt .Sprintf ("%s://%s/api/graphql" , u .Scheme , u .Hostname ()))
318+ if err != nil {
319+ return apiHost {}, fmt .Errorf ("failed to parse GHES GraphQL URL: %w" , err )
320+ }
321+
322+ uploadURL , err := url .Parse (fmt .Sprintf ("%s://%s/api/uploads/" , u .Scheme , u .Hostname ()))
323+ if err != nil {
324+ return apiHost {}, fmt .Errorf ("failed to parse GHES Upload URL: %w" , err )
325+ }
326+
327+ return apiHost {
328+ baseRESTURL : restURL ,
329+ graphqlURL : gqlURL ,
330+ uploadURL : uploadURL ,
331+ }, nil
332+ }
333+
334+ // Note that this does not handle ports yet, so development environments are out.
335+ func parseAPIHost (s string ) (apiHost , error ) {
336+ if s == "" {
337+ return newDotcomHost ()
338+ }
339+
340+ url , err := url .Parse (s )
341+ if err != nil {
342+ return apiHost {}, fmt .Errorf ("could not parse host as URL: %s" , s )
343+ }
344+
345+ if url .Scheme == "" {
346+ return apiHost {}, fmt .Errorf ("host must have a scheme (http or https): %s" , s )
347+ }
348+
349+ if url .Hostname () == "github.com" {
350+ return newDotcomHost ()
351+ }
352+
353+ if strings .HasSuffix (url .Hostname (), "ghe.com" ) {
354+ return newGHECHost (s )
355+ }
356+
357+ return newGHESHost (s )
358+ }
359+
360+ type userAgentTransport struct {
361+ transport http.RoundTripper
362+ agent string
363+ }
364+
365+ func (t * userAgentTransport ) RoundTrip (req * http.Request ) (* http.Response , error ) {
366+ req = req .Clone (req .Context ())
367+ req .Header .Set ("User-Agent" , t .agent )
368+ return t .transport .RoundTrip (req )
369+ }
370+
371+ type bearerAuthTransport struct {
372+ transport http.RoundTripper
373+ token string
374+ }
375+
376+ func (t * bearerAuthTransport ) RoundTrip (req * http.Request ) (* http.Response , error ) {
377+ req = req .Clone (req .Context ())
378+ req .Header .Set ("Authorization" , "Bearer " + t .token )
379+ return t .transport .RoundTrip (req )
380+ }
0 commit comments