diff --git a/auth/jwt/transport.go b/auth/jwt/transport.go index 6d02f7dc5..7be7db417 100644 --- a/auth/jwt/transport.go +++ b/auth/jwt/transport.go @@ -44,10 +44,10 @@ func FromHTTPContext() http.RequestFunc { // ToGRPCContext moves JWT token from grpc metadata to context. Particularly // userful for servers. -func ToGRPCContext() grpc.RequestFunc { - return func(ctx context.Context, md *metadata.MD) context.Context { +func ToGRPCContext() grpc.ServerRequestFunc { + return func(ctx context.Context, md metadata.MD) context.Context { // capital "Key" is illegal in HTTP/2. - authHeader, ok := (*md)["authorization"] + authHeader, ok := md["authorization"] if !ok { return ctx } @@ -63,7 +63,7 @@ func ToGRPCContext() grpc.RequestFunc { // FromGRPCContext moves JWT token from context to grpc metadata. Particularly // useful for clients. -func FromGRPCContext() grpc.RequestFunc { +func FromGRPCContext() grpc.ClientRequestFunc { return func(ctx context.Context, md *metadata.MD) context.Context { token, ok := ctx.Value(JWTTokenContextKey).(string) if ok { diff --git a/auth/jwt/transport_test.go b/auth/jwt/transport_test.go index 8b8922a6a..b04d76feb 100644 --- a/auth/jwt/transport_test.go +++ b/auth/jwt/transport_test.go @@ -69,7 +69,7 @@ func TestToGRPCContext(t *testing.T) { reqFunc := ToGRPCContext() // No Authorization header is passed - ctx := reqFunc(context.Background(), &md) + ctx := reqFunc(context.Background(), md) token := ctx.Value(JWTTokenContextKey) if token != nil { t.Error("Context should not contain a JWT Token") @@ -77,7 +77,7 @@ func TestToGRPCContext(t *testing.T) { // Invalid Authorization header is passed md["authorization"] = []string{fmt.Sprintf("%s", signedKey)} - ctx = reqFunc(context.Background(), &md) + ctx = reqFunc(context.Background(), md) token = ctx.Value(JWTTokenContextKey) if token != nil { t.Error("Context should not contain a JWT Token") @@ -85,7 +85,7 @@ func TestToGRPCContext(t *testing.T) { // Authorization header is correct md["authorization"] = []string{fmt.Sprintf("Bearer %s", signedKey)} - ctx = reqFunc(context.Background(), &md) + ctx = reqFunc(context.Background(), md) token, ok := ctx.Value(JWTTokenContextKey).(string) if !ok { t.Fatal("JWT Token not passed to context correctly") diff --git a/examples/addsvc/cmd/addsvc/main.go b/examples/addsvc/cmd/addsvc/main.go index 842f34bec..3fba0ada3 100644 --- a/examples/addsvc/cmd/addsvc/main.go +++ b/examples/addsvc/cmd/addsvc/main.go @@ -222,7 +222,7 @@ func main() { return } - srv := addsvc.MakeGRPCServer(ctx, endpoints, tracer, logger) + srv := addsvc.MakeGRPCServer(endpoints, tracer, logger) s := grpc.NewServer() pb.RegisterAddServer(s, srv) diff --git a/examples/addsvc/transport_grpc.go b/examples/addsvc/transport_grpc.go index 21e60bc4f..dcfc03a05 100644 --- a/examples/addsvc/transport_grpc.go +++ b/examples/addsvc/transport_grpc.go @@ -16,20 +16,18 @@ import ( ) // MakeGRPCServer makes a set of endpoints available as a gRPC AddServer. -func MakeGRPCServer(ctx context.Context, endpoints Endpoints, tracer stdopentracing.Tracer, logger log.Logger) pb.AddServer { +func MakeGRPCServer(endpoints Endpoints, tracer stdopentracing.Tracer, logger log.Logger) pb.AddServer { options := []grpctransport.ServerOption{ grpctransport.ServerErrorLogger(logger), } return &grpcServer{ sum: grpctransport.NewServer( - ctx, endpoints.SumEndpoint, DecodeGRPCSumRequest, EncodeGRPCSumResponse, append(options, grpctransport.ServerBefore(opentracing.FromGRPCRequest(tracer, "Sum", logger)))..., ), concat: grpctransport.NewServer( - ctx, endpoints.ConcatEndpoint, DecodeGRPCConcatRequest, EncodeGRPCConcatResponse, diff --git a/tracing/opentracing/grpc.go b/tracing/opentracing/grpc.go index 56eb143f5..fa4544009 100644 --- a/tracing/opentracing/grpc.go +++ b/tracing/opentracing/grpc.go @@ -32,10 +32,10 @@ func ToGRPCRequest(tracer opentracing.Tracer, logger log.Logger) func(ctx contex // `operationName` accordingly. If no trace could be found in `req`, the Span // will be a trace root. The Span is incorporated in the returned Context and // can be retrieved with opentracing.SpanFromContext(ctx). -func FromGRPCRequest(tracer opentracing.Tracer, operationName string, logger log.Logger) func(ctx context.Context, md *metadata.MD) context.Context { - return func(ctx context.Context, md *metadata.MD) context.Context { +func FromGRPCRequest(tracer opentracing.Tracer, operationName string, logger log.Logger) func(ctx context.Context, md metadata.MD) context.Context { + return func(ctx context.Context, md metadata.MD) context.Context { var span opentracing.Span - wireContext, err := tracer.Extract(opentracing.TextMap, metadataReaderWriter{md}) + wireContext, err := tracer.Extract(opentracing.TextMap, metadataReaderWriter{&md}) if err != nil && err != opentracing.ErrSpanContextNotFound { logger.Log("err", err) } diff --git a/tracing/opentracing/grpc_test.go b/tracing/opentracing/grpc_test.go index 96a834fd8..3d07a14aa 100644 --- a/tracing/opentracing/grpc_test.go +++ b/tracing/opentracing/grpc_test.go @@ -41,7 +41,7 @@ func TestTraceGRPCRequestRoundtrip(t *testing.T) { // Use FromGRPCRequest to verify that we can join with the trace given MD. fromGRPCFunc := kitot.FromGRPCRequest(tracer, "joined", logger) - joinCtx := fromGRPCFunc(afterCtx, &md) + joinCtx := fromGRPCFunc(afterCtx, md) joinedSpan := opentracing.SpanFromContext(joinCtx).(*mocktracer.MockSpan) joinedContext := joinedSpan.Context().(mocktracer.MockSpanContext) diff --git a/transport/grpc/_grpc_test/client.go b/transport/grpc/_grpc_test/client.go new file mode 100644 index 000000000..1e0c8a78e --- /dev/null +++ b/transport/grpc/_grpc_test/client.go @@ -0,0 +1,50 @@ +package test + +import ( + "context" + + "google.golang.org/grpc" + + "github.com/go-kit/kit/endpoint" + grpctransport "github.com/go-kit/kit/transport/grpc" + "github.com/go-kit/kit/transport/grpc/_grpc_test/pb" +) + +type clientBinding struct { + test endpoint.Endpoint +} + +func (c *clientBinding) Test(ctx context.Context, a string, b int64) (context.Context, string, error) { + response, err := c.test(ctx, TestRequest{A: a, B: b}) + if err != nil { + return nil, "", err + } + r := response.(*TestResponse) + return r.Ctx, r.V, nil +} + +func NewClient(cc *grpc.ClientConn) Service { + return &clientBinding{ + test: grpctransport.NewClient( + cc, + "pb.Test", + "Test", + encodeRequest, + decodeResponse, + &pb.TestResponse{}, + grpctransport.ClientBefore( + injectCorrelationID, + ), + grpctransport.ClientBefore( + displayClientRequestHeaders, + ), + grpctransport.ClientAfter( + displayClientResponseHeaders, + displayClientResponseTrailers, + ), + grpctransport.ClientAfter( + extractConsumedCorrelationID, + ), + ).Endpoint(), + } +} diff --git a/transport/grpc/_grpc_test/context_metadata.go b/transport/grpc/_grpc_test/context_metadata.go new file mode 100644 index 000000000..0769325e2 --- /dev/null +++ b/transport/grpc/_grpc_test/context_metadata.go @@ -0,0 +1,141 @@ +package test + +import ( + "context" + "fmt" + + "google.golang.org/grpc/metadata" +) + +type metaContext string + +const ( + correlationID metaContext = "correlation-id" + responseHDR metaContext = "my-response-header" + responseTRLR metaContext = "my-response-trailer" + correlationIDTRLR metaContext = "correlation-id-consumed" +) + +/* client before functions */ + +func injectCorrelationID(ctx context.Context, md *metadata.MD) context.Context { + if hdr, ok := ctx.Value(correlationID).(string); ok { + fmt.Printf("\tClient found correlationID %q in context, set metadata header\n", hdr) + (*md)[string(correlationID)] = append((*md)[string(correlationID)], hdr) + } + return ctx +} + +func displayClientRequestHeaders(ctx context.Context, md *metadata.MD) context.Context { + if len(*md) > 0 { + fmt.Println("\tClient >> Request Headers:") + for key, val := range *md { + fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1]) + } + } + return ctx +} + +/* server before functions */ + +func extractCorrelationID(ctx context.Context, md metadata.MD) context.Context { + if hdr, ok := md[string(correlationID)]; ok { + cID := hdr[len(hdr)-1] + ctx = context.WithValue(ctx, correlationID, cID) + fmt.Printf("\tServer received correlationID %q in metadata header, set context\n", cID) + } + return ctx +} + +func displayServerRequestHeaders(ctx context.Context, md metadata.MD) context.Context { + if len(md) > 0 { + fmt.Println("\tServer << Request Headers:") + for key, val := range md { + fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1]) + } + } + return ctx +} + +/* server after functions */ + +func injectResponseHeader(ctx context.Context, md *metadata.MD, _ *metadata.MD) context.Context { + *md = metadata.Join(*md, metadata.Pairs(string(responseHDR), "has-a-value")) + return ctx +} + +func displayServerResponseHeaders(ctx context.Context, md *metadata.MD, _ *metadata.MD) context.Context { + if len(*md) > 0 { + fmt.Println("\tServer >> Response Headers:") + for key, val := range *md { + fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1]) + } + } + return ctx +} + +func injectResponseTrailer(ctx context.Context, _ *metadata.MD, md *metadata.MD) context.Context { + *md = metadata.Join(*md, metadata.Pairs(string(responseTRLR), "has-a-value-too")) + return ctx +} + +func injectConsumedCorrelationID(ctx context.Context, _ *metadata.MD, md *metadata.MD) context.Context { + if hdr, ok := ctx.Value(correlationID).(string); ok { + fmt.Printf("\tServer found correlationID %q in context, set consumed trailer\n", hdr) + *md = metadata.Join(*md, metadata.Pairs(string(correlationIDTRLR), hdr)) + } + return ctx +} + +func displayServerResponseTrailers(ctx context.Context, _ *metadata.MD, md *metadata.MD) context.Context { + if len(*md) > 0 { + fmt.Println("\tServer >> Response Trailers:") + for key, val := range *md { + fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1]) + } + } + return ctx +} + +/* client after functions */ + +func displayClientResponseHeaders(ctx context.Context, md metadata.MD, _ metadata.MD) context.Context { + if len(md) > 0 { + fmt.Println("\tClient << Response Headers:") + for key, val := range md { + fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1]) + } + } + return ctx +} + +func displayClientResponseTrailers(ctx context.Context, _ metadata.MD, md metadata.MD) context.Context { + if len(md) > 0 { + fmt.Println("\tClient << Response Trailers:") + for key, val := range md { + fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1]) + } + } + return ctx +} + +func extractConsumedCorrelationID(ctx context.Context, _ metadata.MD, md metadata.MD) context.Context { + if hdr, ok := md[string(correlationIDTRLR)]; ok { + fmt.Printf("\tClient received consumed correlationID %q in metadata trailer, set context\n", hdr[len(hdr)-1]) + ctx = context.WithValue(ctx, correlationIDTRLR, hdr[len(hdr)-1]) + } + return ctx +} + +/* CorrelationID context handlers */ + +func SetCorrelationID(ctx context.Context, v string) context.Context { + return context.WithValue(ctx, correlationID, v) +} + +func GetConsumedCorrelationID(ctx context.Context) string { + if trlr, ok := ctx.Value(correlationIDTRLR).(string); ok { + return trlr + } + return "" +} diff --git a/transport/grpc/_grpc_test/pb/generate.go b/transport/grpc/_grpc_test/pb/generate.go new file mode 100644 index 000000000..aa20bb664 --- /dev/null +++ b/transport/grpc/_grpc_test/pb/generate.go @@ -0,0 +1,3 @@ +package pb + +//go:generate protoc test.proto --go_out=plugins=grpc:. diff --git a/transport/grpc/_grpc_test/pb/test.pb.go b/transport/grpc/_grpc_test/pb/test.pb.go new file mode 100644 index 000000000..97d29bb1e --- /dev/null +++ b/transport/grpc/_grpc_test/pb/test.pb.go @@ -0,0 +1,167 @@ +// Code generated by protoc-gen-go. +// source: test.proto +// DO NOT EDIT! + +/* +Package pb is a generated protocol buffer package. + +It is generated from these files: + test.proto + +It has these top-level messages: + TestRequest + TestResponse +*/ +package pb + +import proto "github.com/golang/protobuf/proto" +import fmt "fmt" +import math "math" + +import ( + context "golang.org/x/net/context" + grpc "google.golang.org/grpc" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package + +type TestRequest struct { + A string `protobuf:"bytes,1,opt,name=a" json:"a,omitempty"` + B int64 `protobuf:"varint,2,opt,name=b" json:"b,omitempty"` +} + +func (m *TestRequest) Reset() { *m = TestRequest{} } +func (m *TestRequest) String() string { return proto.CompactTextString(m) } +func (*TestRequest) ProtoMessage() {} +func (*TestRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} } + +func (m *TestRequest) GetA() string { + if m != nil { + return m.A + } + return "" +} + +func (m *TestRequest) GetB() int64 { + if m != nil { + return m.B + } + return 0 +} + +type TestResponse struct { + V string `protobuf:"bytes,1,opt,name=v" json:"v,omitempty"` +} + +func (m *TestResponse) Reset() { *m = TestResponse{} } +func (m *TestResponse) String() string { return proto.CompactTextString(m) } +func (*TestResponse) ProtoMessage() {} +func (*TestResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{1} } + +func (m *TestResponse) GetV() string { + if m != nil { + return m.V + } + return "" +} + +func init() { + proto.RegisterType((*TestRequest)(nil), "pb.TestRequest") + proto.RegisterType((*TestResponse)(nil), "pb.TestResponse") +} + +// Reference imports to suppress errors if they are not otherwise used. +var _ context.Context +var _ grpc.ClientConn + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +const _ = grpc.SupportPackageIsVersion4 + +// Client API for Test service + +type TestClient interface { + Test(ctx context.Context, in *TestRequest, opts ...grpc.CallOption) (*TestResponse, error) +} + +type testClient struct { + cc *grpc.ClientConn +} + +func NewTestClient(cc *grpc.ClientConn) TestClient { + return &testClient{cc} +} + +func (c *testClient) Test(ctx context.Context, in *TestRequest, opts ...grpc.CallOption) (*TestResponse, error) { + out := new(TestResponse) + err := grpc.Invoke(ctx, "/pb.Test/Test", in, out, c.cc, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +// Server API for Test service + +type TestServer interface { + Test(context.Context, *TestRequest) (*TestResponse, error) +} + +func RegisterTestServer(s *grpc.Server, srv TestServer) { + s.RegisterService(&_Test_serviceDesc, srv) +} + +func _Test_Test_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(TestRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(TestServer).Test(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/pb.Test/Test", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(TestServer).Test(ctx, req.(*TestRequest)) + } + return interceptor(ctx, in, info, handler) +} + +var _Test_serviceDesc = grpc.ServiceDesc{ + ServiceName: "pb.Test", + HandlerType: (*TestServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "Test", + Handler: _Test_Test_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, + Metadata: "test.proto", +} + +func init() { proto.RegisterFile("test.proto", fileDescriptor0) } + +var fileDescriptor0 = []byte{ + // 129 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xe2, 0x2a, 0x49, 0x2d, 0x2e, + 0xd1, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x2a, 0x48, 0x52, 0xd2, 0xe4, 0xe2, 0x0e, 0x49, + 0x2d, 0x2e, 0x09, 0x4a, 0x2d, 0x2c, 0x4d, 0x2d, 0x2e, 0x11, 0xe2, 0xe1, 0x62, 0x4c, 0x94, 0x60, + 0x54, 0x60, 0xd4, 0xe0, 0x0c, 0x62, 0x4c, 0x04, 0xf1, 0x92, 0x24, 0x98, 0x14, 0x18, 0x35, 0x98, + 0x83, 0x18, 0x93, 0x94, 0x64, 0xb8, 0x78, 0x20, 0x4a, 0x8b, 0x0b, 0xf2, 0xf3, 0x8a, 0x53, 0x41, + 0xb2, 0x65, 0x30, 0xb5, 0x65, 0x46, 0xc6, 0x5c, 0x2c, 0x20, 0x59, 0x21, 0x6d, 0x28, 0xcd, 0xaf, + 0x57, 0x90, 0xa4, 0x87, 0x64, 0xb4, 0x94, 0x00, 0x42, 0x00, 0x62, 0x80, 0x12, 0x43, 0x12, 0x1b, + 0xd8, 0x21, 0xc6, 0x80, 0x00, 0x00, 0x00, 0xff, 0xff, 0x49, 0xfc, 0xd8, 0xf1, 0x96, 0x00, 0x00, + 0x00, +} diff --git a/transport/grpc/_grpc_test/pb/test.proto b/transport/grpc/_grpc_test/pb/test.proto new file mode 100644 index 000000000..6a3555e3c --- /dev/null +++ b/transport/grpc/_grpc_test/pb/test.proto @@ -0,0 +1,16 @@ +syntax = "proto3"; + +package pb; + +service Test { + rpc Test (TestRequest) returns (TestResponse) {} +} + +message TestRequest { + string a = 1; + int64 b = 2; +} + +message TestResponse { + string v = 1; +} diff --git a/transport/grpc/_grpc_test/request_response.go b/transport/grpc/_grpc_test/request_response.go new file mode 100644 index 000000000..269703d39 --- /dev/null +++ b/transport/grpc/_grpc_test/request_response.go @@ -0,0 +1,27 @@ +package test + +import ( + "context" + + "github.com/go-kit/kit/transport/grpc/_grpc_test/pb" +) + +func encodeRequest(ctx context.Context, req interface{}) (interface{}, error) { + r := req.(TestRequest) + return &pb.TestRequest{A: r.A, B: r.B}, nil +} + +func decodeRequest(ctx context.Context, req interface{}) (interface{}, error) { + r := req.(*pb.TestRequest) + return TestRequest{A: r.A, B: r.B}, nil +} + +func encodeResponse(ctx context.Context, resp interface{}) (interface{}, error) { + r := resp.(*TestResponse) + return &pb.TestResponse{V: r.V}, nil +} + +func decodeResponse(ctx context.Context, resp interface{}) (interface{}, error) { + r := resp.(*pb.TestResponse) + return &TestResponse{V: r.V, Ctx: ctx}, nil +} diff --git a/transport/grpc/_grpc_test/server.go b/transport/grpc/_grpc_test/server.go new file mode 100644 index 000000000..49e70a91f --- /dev/null +++ b/transport/grpc/_grpc_test/server.go @@ -0,0 +1,70 @@ +package test + +import ( + "context" + "fmt" + + oldcontext "golang.org/x/net/context" + + "github.com/go-kit/kit/endpoint" + grpctransport "github.com/go-kit/kit/transport/grpc" + "github.com/go-kit/kit/transport/grpc/_grpc_test/pb" +) + +type service struct{} + +func (service) Test(ctx context.Context, a string, b int64) (context.Context, string, error) { + return nil, fmt.Sprintf("%s = %d", a, b), nil +} + +func NewService() Service { + return service{} +} + +func makeTestEndpoint(svc Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(TestRequest) + newCtx, v, err := svc.Test(ctx, req.A, req.B) + return &TestResponse{ + V: v, + Ctx: newCtx, + }, err + } +} + +type serverBinding struct { + test grpctransport.Handler +} + +func (b *serverBinding) Test(ctx oldcontext.Context, req *pb.TestRequest) (*pb.TestResponse, error) { + _, response, err := b.test.ServeGRPC(ctx, req) + if err != nil { + return nil, err + } + return response.(*pb.TestResponse), nil +} + +func NewBinding(svc Service) *serverBinding { + return &serverBinding{ + test: grpctransport.NewServer( + makeTestEndpoint(svc), + decodeRequest, + encodeResponse, + grpctransport.ServerBefore( + extractCorrelationID, + ), + grpctransport.ServerBefore( + displayServerRequestHeaders, + ), + grpctransport.ServerAfter( + injectResponseHeader, + injectResponseTrailer, + injectConsumedCorrelationID, + ), + grpctransport.ServerAfter( + displayServerResponseHeaders, + displayServerResponseTrailers, + ), + ), + } +} diff --git a/transport/grpc/_grpc_test/service.go b/transport/grpc/_grpc_test/service.go new file mode 100644 index 000000000..536b27c0b --- /dev/null +++ b/transport/grpc/_grpc_test/service.go @@ -0,0 +1,17 @@ +package test + +import "context" + +type Service interface { + Test(ctx context.Context, a string, b int64) (context.Context, string, error) +} + +type TestRequest struct { + A string + B int64 +} + +type TestResponse struct { + Ctx context.Context + V string +} diff --git a/transport/grpc/client.go b/transport/grpc/client.go index 7ab43c647..c0faa2b36 100644 --- a/transport/grpc/client.go +++ b/transport/grpc/client.go @@ -21,7 +21,8 @@ type Client struct { enc EncodeRequestFunc dec DecodeResponseFunc grpcReply reflect.Type - before []RequestFunc + before []ClientRequestFunc + after []ClientResponseFunc } // NewClient constructs a usable Client for a single remote endpoint. @@ -53,7 +54,8 @@ func NewClient( reflect.ValueOf(grpcReply), ).Interface(), ), - before: []RequestFunc{}, + before: []ClientRequestFunc{}, + after: []ClientResponseFunc{}, } for _, option := range options { option(c) @@ -66,8 +68,15 @@ type ClientOption func(*Client) // ClientBefore sets the RequestFuncs that are applied to the outgoing gRPC // request before it's invoked. -func ClientBefore(before ...RequestFunc) ClientOption { - return func(c *Client) { c.before = before } +func ClientBefore(before ...ClientRequestFunc) ClientOption { + return func(c *Client) { c.before = append(c.before, before...) } +} + +// ClientAfter sets the ClientResponseFuncs that are applied to the incoming +// gRPC response prior to it being decoded. This is useful for obtaining +// response metadata and adding onto the context prior to decoding. +func ClientAfter(after ...ClientResponseFunc) ClientOption { + return func(c *Client) { c.after = append(c.after, after...) } } // Endpoint returns a usable endpoint that will invoke the gRPC specified by the @@ -88,11 +97,19 @@ func (c Client) Endpoint() endpoint.Endpoint { } ctx = metadata.NewContext(ctx, *md) + var header, trailer metadata.MD grpcReply := reflect.New(c.grpcReply).Interface() - if err = grpc.Invoke(ctx, c.method, req, grpcReply, c.client); err != nil { + if err = grpc.Invoke( + ctx, c.method, req, grpcReply, c.client, + grpc.Header(&header), grpc.Trailer(&trailer), + ); err != nil { return nil, err } + for _, f := range c.after { + ctx = f(ctx, header, trailer) + } + response, err := c.dec(ctx, grpcReply) if err != nil { return nil, err diff --git a/transport/grpc/client_test.go b/transport/grpc/client_test.go new file mode 100644 index 000000000..e4cac1d8c --- /dev/null +++ b/transport/grpc/client_test.go @@ -0,0 +1,59 @@ +package grpc_test + +import ( + "context" + "fmt" + "net" + "testing" + + "google.golang.org/grpc" + + test "github.com/go-kit/kit/transport/grpc/_grpc_test" + "github.com/go-kit/kit/transport/grpc/_grpc_test/pb" +) + +const ( + hostPort string = "localhost:8002" +) + +func TestGRPCClient(t *testing.T) { + var ( + server = grpc.NewServer() + service = test.NewService() + ) + + sc, err := net.Listen("tcp", hostPort) + if err != nil { + t.Fatalf("unable to listen: %+v", err) + } + defer server.GracefulStop() + + go func() { + pb.RegisterTestServer(server, test.NewBinding(service)) + _ = server.Serve(sc) + }() + + cc, err := grpc.Dial(hostPort, grpc.WithInsecure()) + if err != nil { + t.Fatalf("unable to Dial: %+v", err) + } + + client := test.NewClient(cc) + + var ( + a = "the answer to life the universe and everything" + b = int64(42) + cID = "request-1" + ctx = test.SetCorrelationID(context.Background(), cID) + ) + + responseCTX, v, err := client.Test(ctx, a, b) + + if want, have := fmt.Sprintf("%s = %d", a, b), v; want != have { + t.Fatalf("want %q, have %q", want, have) + } + + if want, have := cID, test.GetConsumedCorrelationID(responseCTX); want != have { + t.Fatalf("want %q, have %q", want, have) + } +} diff --git a/transport/grpc/request_response_funcs.go b/transport/grpc/request_response_funcs.go index aa88ca65e..8d072ede7 100644 --- a/transport/grpc/request_response_funcs.go +++ b/transport/grpc/request_response_funcs.go @@ -12,30 +12,53 @@ const ( binHdrSuffix = "-bin" ) -// RequestFunc may take information from an gRPC request and put it into a -// request context. In Servers, BeforeFuncs are executed prior to invoking the -// endpoint. In Clients, BeforeFuncs are executed after creating the request -// but prior to invoking the gRPC client. -type RequestFunc func(context.Context, *metadata.MD) context.Context +// ClientRequestFunc may take information from context and use it to construct +// metadata headers to be transported to the server. ClientRequestFuncs are +// executed after creating the request but prior to sending the gRPC request to +// the server. +type ClientRequestFunc func(context.Context, *metadata.MD) context.Context -// ResponseFunc may take information from a request context and use it to -// manipulate the gRPC metadata header. ResponseFuncs are only executed in -// servers, after invoking the endpoint but prior to writing a response. -type ResponseFunc func(context.Context, *metadata.MD) +// ServerRequestFunc may take information from the received metadata header and +// use it to place items in the request scoped context. ServerRequestFuncs are +// executed prior to invoking the endpoint. +type ServerRequestFunc func(context.Context, metadata.MD) context.Context + +// ServerResponseFunc may take information from a request context and use it to +// manipulate the gRPC response metadata headers and trailers. ResponseFuncs are +// only executed in servers, after invoking the endpoint but prior to writing a +// response. +type ServerResponseFunc func(ctx context.Context, header *metadata.MD, trailer *metadata.MD) context.Context + +// ClientResponseFunc may take information from a gRPC metadata header and/or +// trailer and make the responses available for consumption. ClientResponseFuncs +// are only executed in clients, after a request has been made, but prior to it +// being decoded. +type ClientResponseFunc func(ctx context.Context, header metadata.MD, trailer metadata.MD) context.Context + +// SetRequestHeader returns a ClientRequestFunc that sets the specified metadata +// key-value pair. +func SetRequestHeader(key, val string) ClientRequestFunc { + return func(ctx context.Context, md *metadata.MD) context.Context { + key, val := EncodeKeyValue(key, val) + (*md)[key] = append((*md)[key], val) + return ctx + } +} // SetResponseHeader returns a ResponseFunc that sets the specified metadata // key-value pair. -func SetResponseHeader(key, val string) ResponseFunc { - return func(_ context.Context, md *metadata.MD) { +func SetResponseHeader(key, val string) ServerResponseFunc { + return func(ctx context.Context, md *metadata.MD, _ *metadata.MD) context.Context { key, val := EncodeKeyValue(key, val) (*md)[key] = append((*md)[key], val) + return ctx } } -// SetRequestHeader returns a RequestFunc that sets the specified metadata +// SetResponseTrailer returns a ResponseFunc that sets the specified metadata // key-value pair. -func SetRequestHeader(key, val string) RequestFunc { - return func(ctx context.Context, md *metadata.MD) context.Context { +func SetResponseTrailer(key, val string) ServerResponseFunc { + return func(ctx context.Context, _ *metadata.MD, md *metadata.MD) context.Context { key, val := EncodeKeyValue(key, val) (*md)[key] = append((*md)[key], val) return ctx diff --git a/transport/grpc/server.go b/transport/grpc/server.go index 742c1a086..b14d7d8db 100644 --- a/transport/grpc/server.go +++ b/transport/grpc/server.go @@ -1,9 +1,8 @@ package grpc import ( - "context" - oldcontext "golang.org/x/net/context" + "google.golang.org/grpc" "google.golang.org/grpc/metadata" "github.com/go-kit/kit/endpoint" @@ -19,12 +18,11 @@ type Handler interface { // Server wraps an endpoint and implements grpc.Handler. type Server struct { - ctx context.Context e endpoint.Endpoint dec DecodeRequestFunc enc EncodeResponseFunc - before []RequestFunc - after []ResponseFunc + before []ServerRequestFunc + after []ServerResponseFunc logger log.Logger } @@ -34,14 +32,12 @@ type Server struct { // definitions to individual handlers. Request and response objects are from the // caller business domain, not gRPC request and reply types. func NewServer( - ctx context.Context, e endpoint.Endpoint, dec DecodeRequestFunc, enc EncodeResponseFunc, options ...ServerOption, ) *Server { s := &Server{ - ctx: ctx, e: e, dec: dec, enc: enc, @@ -58,13 +54,13 @@ type ServerOption func(*Server) // ServerBefore functions are executed on the HTTP request object before the // request is decoded. -func ServerBefore(before ...RequestFunc) ServerOption { +func ServerBefore(before ...ServerRequestFunc) ServerOption { return func(s *Server) { s.before = append(s.before, before...) } } // ServerAfter functions are executed on the HTTP response writer after the // endpoint is invoked, but before anything is written to the client. -func ServerAfter(after ...ResponseFunc) ServerOption { +func ServerAfter(after ...ServerResponseFunc) ServerOption { return func(s *Server) { s.after = append(s.after, after...) } } @@ -75,46 +71,53 @@ func ServerErrorLogger(logger log.Logger) ServerOption { } // ServeGRPC implements the Handler interface. -func (s Server) ServeGRPC(grpcCtx oldcontext.Context, req interface{}) (oldcontext.Context, interface{}, error) { - ctx := s.ctx - +func (s Server) ServeGRPC(ctx oldcontext.Context, req interface{}) (oldcontext.Context, interface{}, error) { // Retrieve gRPC metadata. - md, ok := metadata.FromContext(grpcCtx) + md, ok := metadata.FromContext(ctx) if !ok { md = metadata.MD{} } for _, f := range s.before { - ctx = f(ctx, &md) + ctx = f(ctx, md) } - // Store potentially updated metadata in the gRPC context. - grpcCtx = metadata.NewContext(grpcCtx, md) - - request, err := s.dec(grpcCtx, req) + request, err := s.dec(ctx, req) if err != nil { s.logger.Log("err", err) - return grpcCtx, nil, err + return ctx, nil, err } response, err := s.e(ctx, request) if err != nil { s.logger.Log("err", err) - return grpcCtx, nil, err + return ctx, nil, err } + var mdHeader, mdTrailer metadata.MD for _, f := range s.after { - f(ctx, &md) + ctx = f(ctx, &mdHeader, &mdTrailer) } - // Store potentially updated metadata in the gRPC context. - grpcCtx = metadata.NewContext(grpcCtx, md) - - grpcResp, err := s.enc(grpcCtx, response) + grpcResp, err := s.enc(ctx, response) if err != nil { s.logger.Log("err", err) - return grpcCtx, nil, err + return ctx, nil, err + } + + if len(mdHeader) > 0 { + if err = grpc.SendHeader(ctx, mdHeader); err != nil { + s.logger.Log("err", err) + return ctx, nil, err + } + } + + if len(mdTrailer) > 0 { + if err = grpc.SetTrailer(ctx, mdTrailer); err != nil { + s.logger.Log("err", err) + return ctx, nil, err + } } - return grpcCtx, grpcResp, nil + return ctx, grpcResp, nil } diff --git a/transport/http/client.go b/transport/http/client.go index 494797583..08f1b886e 100644 --- a/transport/http/client.go +++ b/transport/http/client.go @@ -62,14 +62,14 @@ func SetClient(client *http.Client) ClientOption { // ClientBefore sets the RequestFuncs that are applied to the outgoing HTTP // request before it's invoked. func ClientBefore(before ...RequestFunc) ClientOption { - return func(c *Client) { c.before = before } + return func(c *Client) { c.before = append(c.before, before...) } } // ClientAfter sets the ClientResponseFuncs applied to the incoming HTTP // request prior to it being decoded. This is useful for obtaining anything off // of the response and adding onto the context prior to decoding. func ClientAfter(after ...ClientResponseFunc) ClientOption { - return func(c *Client) { c.after = after } + return func(c *Client) { c.after = append(c.after, after...) } } // BufferedStream sets whether the Response.Body is left open, allowing it