diff --git a/tavern/internal/c2/api_claim_tasks.go b/tavern/internal/c2/api_claim_tasks.go index 46a9ffc8f..6489f4368 100644 --- a/tavern/internal/c2/api_claim_tasks.go +++ b/tavern/internal/c2/api_claim_tasks.go @@ -5,14 +5,11 @@ import ( "encoding/json" "fmt" "log/slog" - "net" "strings" "time" "github.com/prometheus/client_golang/prometheus" "google.golang.org/grpc/codes" - "google.golang.org/grpc/metadata" - "google.golang.org/grpc/peer" "google.golang.org/grpc/status" "realm.pub/tavern/internal/c2/c2pb" "realm.pub/tavern/internal/c2/epb" @@ -37,37 +34,9 @@ func init() { prometheus.MustRegister(metricHostCallbacksTotal) } -func getRemoteIP(ctx context.Context) string { - p, ok := peer.FromContext(ctx) - if !ok { - return "unknown" - } - - host, _, err := net.SplitHostPort(p.Addr.String()) - if err != nil { - return "unknown" - } - - return host -} - -func getClientIP(ctx context.Context) string { - md, ok := metadata.FromIncomingContext(ctx) - if ok { - if forwardedFor, exists := md["x-forwarded-for"]; exists && len(forwardedFor) > 0 { - // X-Forwarded-For is a comma-separated list, the first IP is the original client - clientIP := strings.Split(forwardedFor[0], ",")[0] - return strings.TrimSpace(clientIP) - } - } - - // Fallback to peer address - return getRemoteIP(ctx) -} - func (srv *Server) ClaimTasks(ctx context.Context, req *c2pb.ClaimTasksRequest) (*c2pb.ClaimTasksResponse, error) { now := time.Now() - clientIP := getClientIP(ctx) + clientIP := GetClientIP(ctx) // Validate input if req.Beacon == nil { diff --git a/tavern/internal/c2/ip.go b/tavern/internal/c2/ip.go new file mode 100644 index 000000000..98271235e --- /dev/null +++ b/tavern/internal/c2/ip.go @@ -0,0 +1,9 @@ +package c2 + +import ( + "net" +) + +func validateIP(ipaddr string) bool { + return net.ParseIP(ipaddr) != nil || ipaddr == "unknown" +} diff --git a/tavern/internal/c2/server.go b/tavern/internal/c2/server.go index 42558857a..1db74d1c4 100644 --- a/tavern/internal/c2/server.go +++ b/tavern/internal/c2/server.go @@ -1,6 +1,13 @@ package c2 import ( + "context" + "log/slog" + "net" + "strings" + + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/peer" "realm.pub/tavern/internal/c2/c2pb" "realm.pub/tavern/internal/ent" "realm.pub/tavern/internal/http/stream" @@ -20,3 +27,49 @@ func New(graph *ent.Client, mux *stream.Mux) *Server { mux: mux, } } + +func getRemoteIP(ctx context.Context) string { + p, ok := peer.FromContext(ctx) + if !ok { + return "unknown" + } + + host, _, err := net.SplitHostPort(p.Addr.String()) + if err != nil { + return "unknown" + } + + return host +} + +func GetClientIP(ctx context.Context) string { + md, ok := metadata.FromIncomingContext(ctx) + if ok { + if redirectedFor, exists := md["x-redirected-for"]; exists && len(redirectedFor) > 0 { + clientIP := strings.TrimSpace(redirectedFor[0]) + if validateIP(clientIP) { + return clientIP + } else { + slog.Error("bad x-redirected-for ip", "ip", clientIP) + } + } + if forwardedFor, exists := md["x-forwarded-for"]; exists && len(forwardedFor) > 0 { + // X-Forwarded-For is a comma-separated list, the first IP is the original client + clientIP := strings.TrimSpace(strings.Split(forwardedFor[0], ",")[0]) + if validateIP(clientIP) { + return clientIP + } else { + slog.Error("bad x-forwarded-for ip", "ip", clientIP) + } + } + } + + // Fallback to peer address + remoteIp := getRemoteIP(ctx) + if validateIP(remoteIp) { + return remoteIp + } else { + slog.Error("Bad remote IP", "ip", remoteIp) + } + return "unknown" +} diff --git a/tavern/internal/c2/server_test.go b/tavern/internal/c2/server_test.go new file mode 100644 index 000000000..78cd38f3c --- /dev/null +++ b/tavern/internal/c2/server_test.go @@ -0,0 +1,160 @@ +package c2 + +import ( + "context" + "net" + "testing" + + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/peer" +) + +func TestGetClientIP(t *testing.T) { + tests := []struct { + name string + setupContext func() context.Context + expectedIP string + }{ + { + name: "X-Forwarded-For_Only", + setupContext: func() context.Context { + ctx := context.Background() + md := metadata.New(map[string]string{ + "x-forwarded-for": "203.0.113.42", + }) + return metadata.NewIncomingContext(ctx, md) + }, + expectedIP: "203.0.113.42", + }, + { + name: "X-Redirected-For_With_X-Forwarded-For", + setupContext: func() context.Context { + ctx := context.Background() + md := metadata.New(map[string]string{ + "x-forwarded-for": "203.0.113.42", + "x-redirected-for": "198.51.100.99", + }) + return metadata.NewIncomingContext(ctx, md) + }, + expectedIP: "198.51.100.99", + }, + { + name: "Neither_Header_Set_Uses_Peer_IP", + setupContext: func() context.Context { + ctx := context.Background() + p := &peer.Peer{ + Addr: &net.TCPAddr{ + IP: net.ParseIP("1.1.1.1"), + Port: 12345, + }, + } + return peer.NewContext(ctx, p) + }, + expectedIP: "1.1.1.1", + }, + { + name: "X-Forwarded-For_With_Multiple_IPs", + setupContext: func() context.Context { + ctx := context.Background() + md := metadata.New(map[string]string{ + "x-forwarded-for": "203.0.113.42, 198.51.100.1, 192.0.2.5", + }) + return metadata.NewIncomingContext(ctx, md) + }, + expectedIP: "203.0.113.42", + }, + { + name: "X-Forwarded-For_With_Whitespace", + setupContext: func() context.Context { + ctx := context.Background() + md := metadata.New(map[string]string{ + "x-forwarded-for": " 203.0.113.42 ", + }) + return metadata.NewIncomingContext(ctx, md) + }, + expectedIP: "203.0.113.42", + }, + { + name: "X-Redirected-For_Precedence_Over_Peer", + setupContext: func() context.Context { + ctx := context.Background() + p := &peer.Peer{ + Addr: &net.TCPAddr{ + IP: net.ParseIP("1.1.1.1"), + Port: 12345, + }, + } + ctx = peer.NewContext(ctx, p) + md := metadata.New(map[string]string{ + "x-redirected-for": "198.51.100.99", + }) + return metadata.NewIncomingContext(ctx, md) + }, + expectedIP: "198.51.100.99", + }, + { + name: "Invalid_X-Forwarded-For_Fallback_To_Peer", + setupContext: func() context.Context { + ctx := context.Background() + p := &peer.Peer{ + Addr: &net.TCPAddr{ + IP: net.ParseIP("1.1.1.1"), + Port: 12345, + }, + } + ctx = peer.NewContext(ctx, p) + md := metadata.New(map[string]string{ + "x-forwarded-for": "invalid-ip-address", + }) + return metadata.NewIncomingContext(ctx, md) + }, + expectedIP: "1.1.1.1", + }, + { + name: "No_Metadata_No_Peer_Returns_Unknown", + setupContext: func() context.Context { + return context.Background() + }, + expectedIP: "unknown", + }, + { + name: "Malformed_X-Redirected-For_Returns_As_Is", + setupContext: func() context.Context { + ctx := context.Background() + p := &peer.Peer{ + Addr: &net.TCPAddr{ + IP: net.ParseIP("1.1.1.1"), + Port: 12345, + }, + } + ctx = peer.NewContext(ctx, p) + md := metadata.New(map[string]string{ + "x-redirected-for": "not-an-ip", + }) + return metadata.NewIncomingContext(ctx, md) + }, + expectedIP: "1.1.1.1", + }, + { + name: "Malformed_X-Forwarded-For_Without_Peer_Returns_Unknown", + setupContext: func() context.Context { + ctx := context.Background() + md := metadata.New(map[string]string{ + "x-forwarded-for": "not-an-ip", + }) + return metadata.NewIncomingContext(ctx, md) + }, + expectedIP: "unknown", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := tt.setupContext() + result := GetClientIP(ctx) + if result != tt.expectedIP { + t.Errorf("GetClientIP() = %v, want %v", result, tt.expectedIP) + } + }) + } +} diff --git a/tavern/internal/redirectors/grpc/grpc.go b/tavern/internal/redirectors/grpc/grpc.go index 6dcba19b7..dfb721e3e 100644 --- a/tavern/internal/redirectors/grpc/grpc.go +++ b/tavern/internal/redirectors/grpc/grpc.go @@ -11,6 +11,7 @@ import ( "google.golang.org/grpc/encoding" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" + "realm.pub/tavern/internal/c2" "realm.pub/tavern/internal/redirectors" ) @@ -49,10 +50,17 @@ func (r *Redirector) handler(upstream *grpc.ClientConn) grpc.StreamHandler { } ctx := ss.Context() + // Get the client's remote IP address + clientIP := c2.GetClientIP(ctx) + md, ok := metadata.FromIncomingContext(ctx) if ok { - ctx = metadata.NewOutgoingContext(ctx, md) + ctx = metadata.NewOutgoingContext(ctx, md.Copy()) } + + // Set x-redirected-for header with the client IP + ctx = redirectors.SetRedirectedForHeader(ctx, clientIP) + cs, err := upstream.NewStream(ctx, &grpc.StreamDesc{ StreamName: fullMethodName, ServerStreams: true, diff --git a/tavern/internal/redirectors/http1/handlers.go b/tavern/internal/redirectors/http1/handlers.go index a6c54ce26..790ae9972 100644 --- a/tavern/internal/redirectors/http1/handlers.go +++ b/tavern/internal/redirectors/http1/handlers.go @@ -4,9 +4,12 @@ import ( "fmt" "io" "log/slog" + "net" "net/http" + "strings" "google.golang.org/grpc" + "realm.pub/tavern/internal/redirectors" ) func handleFetchAssetStreaming(w http.ResponseWriter, r *http.Request, conn *grpc.ClientConn) { @@ -159,11 +162,42 @@ func handleReportFileStreaming(w http.ResponseWriter, r *http.Request, conn *grp } } +func getClientIP(r *http.Request) string { + if forwardedFor := r.Header.Get("x-forwarded-for"); len(forwardedFor) > 0 { + // X-Forwarded-For is a comma-separated list, the first IP is the original client + clientIp := strings.TrimSpace(strings.Split(forwardedFor, ",")[0]) + if validateIP(clientIp) { + return clientIp + } else { + slog.Error("bad forwarded for ip", "ip", clientIp) + } + } + + // Fallback to RemoteAddr + host, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + slog.Error("failed to parse remote addr", "ip", r.RemoteAddr) + } + + host = strings.TrimSpace(host) + if validateIP(host) { + return host + } else { + slog.Error("bad remote ip", "ip", host) + } + return "unknown" +} + +func validateIP(ipaddr string) bool { + return net.ParseIP(ipaddr) != nil || ipaddr == "unknown" +} func handleHTTPRequest(w http.ResponseWriter, r *http.Request, conn *grpc.ClientConn) { if !requirePOST(w, r) { return } + clientIp := getClientIP(r) + methodName := r.URL.Path if methodName == "" { http.Error(w, "Method name required in path", http.StatusBadRequest) @@ -180,6 +214,9 @@ func handleHTTPRequest(w http.ResponseWriter, r *http.Request, conn *grpc.Client ctx, cancel := createRequestContext(unaryTimeout) defer cancel() + // Set x-redirected-for header with the client IP + ctx = redirectors.SetRedirectedForHeader(ctx, clientIp) + var responseBody []byte err := conn.Invoke( ctx, diff --git a/tavern/internal/redirectors/http1/handlers_test.go b/tavern/internal/redirectors/http1/handlers_test.go new file mode 100644 index 000000000..7ace4e9ab --- /dev/null +++ b/tavern/internal/redirectors/http1/handlers_test.go @@ -0,0 +1,111 @@ +package http1 + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func TestGetClientIP(t *testing.T) { + tests := []struct { + name string + setupRequest func() *http.Request + expectedIP string + }{ + { + name: "X-Forwarded-For_Set", + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodPost, "/test", nil) + req.Header.Set("X-Forwarded-For", "203.0.113.42") + req.RemoteAddr = "192.0.2.1:12345" + return req + }, + expectedIP: "203.0.113.42", + }, + { + name: "X-Forwarded-For_Not_Set_IPv4", + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodPost, "/test", nil) + req.RemoteAddr = "1.1.1.1:12345" + return req + }, + expectedIP: "1.1.1.1", + }, + { + name: "X-Forwarded-For_Not_Set_IPv6", + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodPost, "/test", nil) + req.RemoteAddr = "[2001:db8::1]:12345" + return req + }, + expectedIP: "2001:db8::1", + }, + { + name: "X-Forwarded-For_Not_Set_IPv6_Localhost", + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodPost, "/test", nil) + req.RemoteAddr = "[::1]:5000" + return req + }, + expectedIP: "::1", + }, + { + name: "X-Forwarded-For_Empty_Falls_Back_To_RemoteAddr", + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodPost, "/test", nil) + req.Header.Set("X-Forwarded-For", "") + req.RemoteAddr = "1.1.1.1:12345" + return req + }, + expectedIP: "1.1.1.1", + }, + { + name: "X-Forwarded-For_With_Multiple_IPs", + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodPost, "/test", nil) + req.Header.Set("X-Forwarded-For", "203.0.113.42, 198.51.100.1, 192.0.2.5") + req.RemoteAddr = "192.0.2.1:12345" + return req + }, + expectedIP: "203.0.113.42", + }, + { + name: "X-Forwarded-For_Malformed_IP", + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodPost, "/test", nil) + req.Header.Set("X-Forwarded-For", "not-an-ip") + req.RemoteAddr = "1.1.1.1" + return req + }, + expectedIP: "unknown", + }, + { + name: "RemoteAddr_Without_Port", + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodPost, "/test", nil) + req.RemoteAddr = "1.1.1.1" + return req + }, + expectedIP: "unknown", + }, + { + name: "RemoteAddr_Multiple_Colons_IPv6_With_Port", + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodPost, "/test", nil) + req.RemoteAddr = "[2001:db8::1]:5000" + return req + }, + expectedIP: "2001:db8::1", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := tt.setupRequest() + result := getClientIP(req) + if result != tt.expectedIP { + t.Errorf("getClientIP() = %v, want %v", result, tt.expectedIP) + } + }) + } +} diff --git a/tavern/internal/redirectors/metadata.go b/tavern/internal/redirectors/metadata.go new file mode 100644 index 000000000..fcefef29b --- /dev/null +++ b/tavern/internal/redirectors/metadata.go @@ -0,0 +1,25 @@ +package redirectors + +import ( + "context" + "log/slog" + + "google.golang.org/grpc/metadata" +) + +// SetRedirectedForHeader sets the x-redirected-for header in the outgoing context metadata +// with the provided client IP address. This header is used to track the original client IP +// through the redirector chain. +func SetRedirectedForHeader(ctx context.Context, clientIP string) context.Context { + if clientIP == "" { + return ctx + } + + outMd, _ := metadata.FromOutgoingContext(ctx) + if outMd == nil { + outMd = metadata.New(nil) + } + outMd.Set("x-redirected-for", clientIP) + slog.Info("Setting redirected-for header", "clientIP", clientIP) + return metadata.NewOutgoingContext(ctx, outMd) +}