From f3a5be463dc66963c8d20a7ef58c9b4623a9b5fb Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Sun, 16 Nov 2025 00:49:23 +0000 Subject: [PATCH 1/3] feat: Add gRPC redirector This commit introduces a new gRPC redirector that is capable of proxying raw gRPC traffic, including bidirectional streams. The implementation is located in `tavern/internal/redirectors/grpc` and conforms to the existing `Redirector` interface. Key features: - Uses `grpc.UnknownServiceHandler` to transparently handle all gRPC services and methods. - Proxies raw byte frames for maximum performance and to support end-to-end encryption. - Supports bidirectional streaming, which is a requirement for the "ReverseShell" feature. - Includes extensive integration tests covering success paths, context cancellation, and upstream connection failures. The new redirector is registered under the name "grpc" and is now available for use in the application. --- tavern/internal/redirectors/grpc/grpc.go | 100 ++++++++ tavern/internal/redirectors/grpc/grpc_test.go | 217 ++++++++++++++++++ 2 files changed, 317 insertions(+) create mode 100644 tavern/internal/redirectors/grpc/grpc.go create mode 100644 tavern/internal/redirectors/grpc/grpc_test.go diff --git a/tavern/internal/redirectors/grpc/grpc.go b/tavern/internal/redirectors/grpc/grpc.go new file mode 100644 index 000000000..11bc63806 --- /dev/null +++ b/tavern/internal/redirectors/grpc/grpc.go @@ -0,0 +1,100 @@ +package grpc + +import ( + "context" + "fmt" + "io" + "net" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" + "realm.pub/tavern/internal/redirectors" +) + +func init() { + redirectors.Register("grpc", &Redirector{}) +} + +// Redirector is a gRPC redirector. +type Redirector struct{} + +// Redirect implements the redirectors.Redirector interface. +func (r *Redirector) Redirect(ctx context.Context, listenOn string, upstream *grpc.ClientConn) error { + lis, err := net.Listen("tcp", listenOn) + if err != nil { + return fmt.Errorf("failed to listen: %w", err) + } + + s := grpc.NewServer( + grpc.UnknownServiceHandler(r.handler(upstream)), + ) + + go func() { + <-ctx.Done() + s.GracefulStop() + }() + + return s.Serve(lis) +} + +func (r *Redirector) handler(upstream *grpc.ClientConn) grpc.StreamHandler { + return func(srv any, ss grpc.ServerStream) error { + fullMethodName, ok := grpc.MethodFromServerStream(ss) + if !ok { + return status.Errorf(codes.Internal, "failed to get method from server stream") + } + + ctx := ss.Context() + md, ok := metadata.FromIncomingContext(ctx) + if ok { + ctx = metadata.NewOutgoingContext(ctx, md) + } + cs, err := upstream.NewStream(ctx, &grpc.StreamDesc{ + StreamName: fullMethodName, + ServerStreams: true, + ClientStreams: true, + }, fullMethodName, grpc.CallContentSubtype("raw")) + if err != nil { + return fmt.Errorf("failed to create new client stream: %w", err) + } + + errChan := make(chan error, 2) + go r.proxy(ss, cs, errChan) + go r.proxy(cs, ss, errChan) + + err = <-errChan + if err == io.EOF { + err = <-errChan + } + + if err != nil && err != io.EOF { + return err + } + + return nil + } +} + +func (r *Redirector) proxy(from grpc.Stream, to grpc.Stream, errChan chan<- error) { + for { + var msg []byte + if err := from.RecvMsg(&msg); err != nil { + if err == io.EOF { + if cs, ok := to.(grpc.ClientStream); ok { + cs.CloseSend() + } + errChan <- io.EOF + return + } + errChan <- err + return + } + + if err := to.SendMsg(msg); err != nil { + errChan <- err + return + } + } +} diff --git a/tavern/internal/redirectors/grpc/grpc_test.go b/tavern/internal/redirectors/grpc/grpc_test.go new file mode 100644 index 000000000..5cc768189 --- /dev/null +++ b/tavern/internal/redirectors/grpc/grpc_test.go @@ -0,0 +1,217 @@ +package grpc_test + +import ( + "context" + "fmt" + "io" + "net" + "testing" + "time" + + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/encoding" + "google.golang.org/grpc/interop/grpc_testing" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" + + grpcRedirector "realm.pub/tavern/internal/redirectors/grpc" +) + +// setupRawUpstreamServer creates a mock gRPC server that uses a raw codec to manually +// handle requests. This simulates the upstream server that the redirector will connect to. +func setupRawUpstreamServer(t *testing.T) (string, func()) { + t.Helper() + + lis, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + handler := func(srv any, stream grpc.ServerStream) error { + fullMethodName, ok := grpc.MethodFromServerStream(stream) + if !ok { + return status.Errorf(codes.Internal, "failed to get method from server stream") + } + + if fullMethodName != "/grpc.testing.TestService/FullDuplexCall" { + return status.Errorf(codes.Unimplemented, "method not implemented: %s", fullMethodName) + } + + // Manually handle the bidirectional stream + for { + var reqBytes []byte + if err := stream.RecvMsg(&reqBytes); err != nil { + if err == io.EOF { + return nil + } + return err + } + + var req grpc_testing.StreamingOutputCallRequest + if err := proto.Unmarshal(reqBytes, &req); err != nil { + return status.Errorf(codes.Internal, "failed to unmarshal request: %v", err) + } + + resp := &grpc_testing.StreamingOutputCallResponse{Payload: &grpc_testing.Payload{Body: req.Payload.Body}} + respBytes, err := proto.Marshal(resp) + if err != nil { + return status.Errorf(codes.Internal, "failed to marshal response: %v", err) + } + + if err := stream.SendMsg(respBytes); err != nil { + return err + } + } + } + + // The upstream server must also use the raw codec to correctly interpret the proxied stream. + s := grpc.NewServer( + grpc.ForceServerCodec(encoding.GetCodec("raw")), + grpc.UnknownServiceHandler(handler), + ) + + go func() { + if err := s.Serve(lis); err != nil && err != grpc.ErrServerStopped { + t.Logf("test server stopped: %v", err) + } + }() + + return lis.Addr().String(), func() { + s.Stop() + } +} + +func TestRedirector_FullDuplexCall(t *testing.T) { + // 1. Setup the raw upstream test server. + upstreamAddr, upstreamCleanup := setupRawUpstreamServer(t) + defer upstreamCleanup() + + // 2. Setup the redirector to point to the upstream server. + redirector := &grpcRedirector.Redirector{} + redirectorLis, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + addr := redirectorLis.Addr().String() + redirectorLis.Close() + + upstreamConn, err := grpc.Dial(upstreamAddr, grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithDefaultCallOptions(grpc.CallContentSubtype("raw"))) + require.NoError(t, err) + defer upstreamConn.Close() + + go func() { + redirector.Redirect(context.Background(), addr, upstreamConn) + }() + + // 3. Connect a client to the redirector, also using the raw codec. + conn, err := grpc.Dial(addr, grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithDefaultCallOptions(grpc.CallContentSubtype("raw"))) + require.NoError(t, err) + defer conn.Close() + + // 4. Perform a bidirectional streaming call through the redirector. + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + stream, err := conn.NewStream(ctx, &grpc.StreamDesc{ + ServerStreams: true, + ClientStreams: true, + }, "/grpc.testing.TestService/FullDuplexCall") + require.NoError(t, err) + + for i := 0; i < 3; i++ { + body := []byte(fmt.Sprintf("ping-%d", i)) + req := &grpc_testing.StreamingOutputCallRequest{Payload: &grpc_testing.Payload{Body: body}} + reqBytes, err := proto.Marshal(req) + require.NoError(t, err) + + require.NoError(t, stream.SendMsg(reqBytes)) + + var respBytes []byte + require.NoError(t, stream.RecvMsg(&respBytes)) + + var resp grpc_testing.StreamingOutputCallResponse + require.NoError(t, proto.Unmarshal(respBytes, &resp)) + require.Equal(t, body, resp.Payload.Body) + } + + require.NoError(t, stream.CloseSend()) +} + +func TestRedirector_ContextCancellation(t *testing.T) { + redirector := &grpcRedirector.Redirector{} + redirectorLis, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + addr := redirectorLis.Addr().String() + redirectorLis.Close() + + // We don't need a real upstream for this test. + upstreamConn, err := grpc.Dial("localhost:1", grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + defer upstreamConn.Close() + + ctx, cancel := context.WithCancel(context.Background()) + + serverErr := make(chan error) + go func() { + serverErr <- redirector.Redirect(ctx, addr, upstreamConn) + }() + + // Wait a moment for the server to start listening. + time.Sleep(100 * time.Millisecond) + + // Cancel the context, which should trigger GracefulStop. + cancel() + + // The redirector should stop gracefully. + select { + case err = <-serverErr: + // Different versions of gRPC may return either nil or ErrServerStopped on graceful shutdown. + if err != nil { + require.ErrorIs(t, err, grpc.ErrServerStopped, "Redirect should return grpc.ErrServerStopped on graceful stop") + } + case <-time.After(1 * time.Second): + t.Fatal("server did not stop in time") + } +} + +func TestRedirector_UpstreamFailure(t *testing.T) { + // 1. Setup the redirector without a valid upstream. + redirector := &grpcRedirector.Redirector{} + redirectorLis, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + addr := redirectorLis.Addr().String() + redirectorLis.Close() + + // Connect to a non-existent upstream. + upstreamConn, err := grpc.Dial("localhost:1", grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + defer upstreamConn.Close() + + go func() { + redirector.Redirect(context.Background(), addr, upstreamConn) + }() + + // 2. Connect a client to the redirector. + conn, err := grpc.Dial(addr, grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithDefaultCallOptions(grpc.CallContentSubtype("raw"))) + require.NoError(t, err) + defer conn.Close() + + // 3. Attempt a streaming call. + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + stream, err := conn.NewStream(ctx, &grpc.StreamDesc{ + ServerStreams: true, + ClientStreams: true, + }, "/grpc.testing.TestService/FullDuplexCall") + require.NoError(t, err) + + // 4. Attempting to receive a message should fail because the upstream is down. + var respBytes []byte + err = stream.RecvMsg(&respBytes) + + // 5. Verify that the error is a gRPC status error with the code Unavailable. + require.Error(t, err) + s, ok := status.FromError(err) + require.True(t, ok, "error should be a gRPC status error") + require.Equal(t, codes.Unavailable, s.Code(), "error code should be Unavailable") +} From 25eaa457387d3122cb3ab7b959cc88354a614d4a Mon Sep 17 00:00:00 2001 From: KCarretto Date: Sun, 16 Nov 2025 01:11:02 +0000 Subject: [PATCH 2/3] register grpc redirector and make it the default --- tavern/app.go | 3 ++- tavern/internal/redirectors/grpc/grpc.go | 2 ++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/tavern/app.go b/tavern/app.go index de6d4304f..1aaf9c73a 100644 --- a/tavern/app.go +++ b/tavern/app.go @@ -39,6 +39,7 @@ import ( "realm.pub/tavern/internal/www" "realm.pub/tavern/tomes" + _ "realm.pub/tavern/internal/redirectors/grpc" _ "realm.pub/tavern/internal/redirectors/http1" ) @@ -90,7 +91,7 @@ func newApp(ctx context.Context) (app *cli.App) { listenOn = ":8080" } if transport == "" { - transport = "http1" + transport = "grpc" } slog.InfoContext(ctx, "starting redirector", "upstream", upstream, "transport", transport, "listen_on", listenOn) return redirectors.Run(ctx, transport, listenOn, upstream) diff --git a/tavern/internal/redirectors/grpc/grpc.go b/tavern/internal/redirectors/grpc/grpc.go index 11bc63806..6dcba19b7 100644 --- a/tavern/internal/redirectors/grpc/grpc.go +++ b/tavern/internal/redirectors/grpc/grpc.go @@ -8,6 +8,7 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/encoding" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" "realm.pub/tavern/internal/redirectors" @@ -29,6 +30,7 @@ func (r *Redirector) Redirect(ctx context.Context, listenOn string, upstream *gr s := grpc.NewServer( grpc.UnknownServiceHandler(r.handler(upstream)), + grpc.ForceServerCodec(encoding.GetCodec("raw")), ) go func() { From e9020a79800bbc786ca72640457b660e636cb764 Mon Sep 17 00:00:00 2001 From: KCarretto Date: Sun, 16 Nov 2025 01:11:25 +0000 Subject: [PATCH 3/3] register grpc redirector and make it the default --- tavern/app.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tavern/app.go b/tavern/app.go index 1aaf9c73a..a659f2ae5 100644 --- a/tavern/app.go +++ b/tavern/app.go @@ -74,8 +74,8 @@ func newApp(ctx context.Context) (app *cli.App) { }, cli.StringFlag{ Name: "transport", - Usage: "Transport protocol to use for redirector (default: http1)", - Value: "http1", + Usage: "Transport protocol to use for redirector", + Value: "grpc", }, }, Action: func(c *cli.Context) error {