Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions transport/http/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,14 @@ import (
"github.com/go-kit/kit/endpoint"
)

// HTTPClient is an interface that models *http.Client.
type HTTPClient interface {
Do(req *http.Request) (*http.Response, error)
}

// Client wraps a URL and provides a method that implements endpoint.Endpoint.
type Client struct {
client *http.Client
client HTTPClient
method string
tgt *url.URL
enc EncodeRequestFunc
Expand Down Expand Up @@ -54,7 +59,7 @@ type ClientOption func(*Client)

// SetClient sets the underlying HTTP client used for requests.
// By default, http.DefaultClient is used.
func SetClient(client *http.Client) ClientOption {
func SetClient(client HTTPClient) ClientOption {
return func(c *Client) { c.client = client }
}

Expand Down
44 changes: 44 additions & 0 deletions transport/http/client_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package http_test

import (
"bytes"
"context"
"io"
"io/ioutil"
Expand Down Expand Up @@ -252,6 +253,43 @@ func TestEncodeJSONRequest(t *testing.T) {
}
}

func TestSetClient(t *testing.T) {
var (
encode = func(context.Context, *http.Request, interface{}) error { return nil }
decode = func(_ context.Context, r *http.Response) (interface{}, error) {
t, err := ioutil.ReadAll(r.Body)
if err != nil {
return nil, err
}
return string(t), nil
}
)

testHttpClient := httpClientFunc(func(req *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusOK,
Request: req,
Body: ioutil.NopCloser(bytes.NewBufferString("hello, world!")),
}, nil
})

client := httptransport.NewClient(
"GET",
&url.URL{},
encode,
decode,
httptransport.SetClient(testHttpClient),
).Endpoint()

resp, err := client(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
if r, ok := resp.(string); !ok || r != "hello, world!" {
t.Fatal("Expected response to be 'hello, world!' string")
}
}

func mustParse(s string) *url.URL {
u, err := url.Parse(s)
if err != nil {
Expand All @@ -265,3 +303,9 @@ type enhancedRequest struct {
}

func (e enhancedRequest) Headers() http.Header { return http.Header{"X-Edward": []string{"Snowden"}} }

type httpClientFunc func(req *http.Request) (*http.Response, error)

func (f httpClientFunc) Do(req *http.Request) (*http.Response, error) {
return f(req)
}
4 changes: 2 additions & 2 deletions transport/http/jsonrpc/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (

// Client wraps a JSON RPC method and provides a method that implements endpoint.Endpoint.
type Client struct {
client *http.Client
client httptransport.HTTPClient

// JSON RPC endpoint URL
tgt *url.URL
Expand Down Expand Up @@ -86,7 +86,7 @@ type ClientOption func(*Client)

// SetClient sets the underlying HTTP client used for requests.
// By default, http.DefaultClient is used.
func SetClient(client *http.Client) ClientOption {
func SetClient(client httptransport.HTTPClient) ClientOption {
return func(c *Client) { c.client = client }
}

Expand Down