diff --git a/transport/http/client.go b/transport/http/client.go index 92d3292fc..eca566300 100644 --- a/transport/http/client.go +++ b/transport/http/client.go @@ -12,9 +12,14 @@ import ( "github.com/go-kit/kit/endpoint" ) +// HTTPClient is an interface that models *http.Client. +type HTTPClient interface { + Do(req *http.Request) (*http.Response, error) +} + // Client wraps a URL and provides a method that implements endpoint.Endpoint. type Client struct { - client *http.Client + client HTTPClient method string tgt *url.URL enc EncodeRequestFunc @@ -54,7 +59,7 @@ type ClientOption func(*Client) // SetClient sets the underlying HTTP client used for requests. // By default, http.DefaultClient is used. -func SetClient(client *http.Client) ClientOption { +func SetClient(client HTTPClient) ClientOption { return func(c *Client) { c.client = client } } diff --git a/transport/http/client_test.go b/transport/http/client_test.go index d66381000..e31d201cc 100644 --- a/transport/http/client_test.go +++ b/transport/http/client_test.go @@ -1,6 +1,7 @@ package http_test import ( + "bytes" "context" "io" "io/ioutil" @@ -252,6 +253,43 @@ func TestEncodeJSONRequest(t *testing.T) { } } +func TestSetClient(t *testing.T) { + var ( + encode = func(context.Context, *http.Request, interface{}) error { return nil } + decode = func(_ context.Context, r *http.Response) (interface{}, error) { + t, err := ioutil.ReadAll(r.Body) + if err != nil { + return nil, err + } + return string(t), nil + } + ) + + testHttpClient := httpClientFunc(func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Request: req, + Body: ioutil.NopCloser(bytes.NewBufferString("hello, world!")), + }, nil + }) + + client := httptransport.NewClient( + "GET", + &url.URL{}, + encode, + decode, + httptransport.SetClient(testHttpClient), + ).Endpoint() + + resp, err := client(context.Background(), nil) + if err != nil { + t.Fatal(err) + } + if r, ok := resp.(string); !ok || r != "hello, world!" { + t.Fatal("Expected response to be 'hello, world!' string") + } +} + func mustParse(s string) *url.URL { u, err := url.Parse(s) if err != nil { @@ -265,3 +303,9 @@ type enhancedRequest struct { } func (e enhancedRequest) Headers() http.Header { return http.Header{"X-Edward": []string{"Snowden"}} } + +type httpClientFunc func(req *http.Request) (*http.Response, error) + +func (f httpClientFunc) Do(req *http.Request) (*http.Response, error) { + return f(req) +} diff --git a/transport/http/jsonrpc/client.go b/transport/http/jsonrpc/client.go index de17fd5bd..4d8fe303d 100644 --- a/transport/http/jsonrpc/client.go +++ b/transport/http/jsonrpc/client.go @@ -15,7 +15,7 @@ import ( // Client wraps a JSON RPC method and provides a method that implements endpoint.Endpoint. type Client struct { - client *http.Client + client httptransport.HTTPClient // JSON RPC endpoint URL tgt *url.URL @@ -86,7 +86,7 @@ type ClientOption func(*Client) // SetClient sets the underlying HTTP client used for requests. // By default, http.DefaultClient is used. -func SetClient(client *http.Client) ClientOption { +func SetClient(client httptransport.HTTPClient) ClientOption { return func(c *Client) { c.client = client } }