Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 1 addition & 32 deletions tavern/internal/c2/api_claim_tasks.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,11 @@ import (
"encoding/json"
"fmt"
"log/slog"
"net"
"strings"
"time"

"github.com/prometheus/client_golang/prometheus"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/status"
"realm.pub/tavern/internal/c2/c2pb"
"realm.pub/tavern/internal/c2/epb"
Expand All @@ -37,37 +34,9 @@ func init() {
prometheus.MustRegister(metricHostCallbacksTotal)
}

func getRemoteIP(ctx context.Context) string {
p, ok := peer.FromContext(ctx)
if !ok {
return "unknown"
}

host, _, err := net.SplitHostPort(p.Addr.String())
if err != nil {
return "unknown"
}

return host
}

func getClientIP(ctx context.Context) string {
md, ok := metadata.FromIncomingContext(ctx)
if ok {
if forwardedFor, exists := md["x-forwarded-for"]; exists && len(forwardedFor) > 0 {
// X-Forwarded-For is a comma-separated list, the first IP is the original client
clientIP := strings.Split(forwardedFor[0], ",")[0]
return strings.TrimSpace(clientIP)
}
}

// Fallback to peer address
return getRemoteIP(ctx)
}

func (srv *Server) ClaimTasks(ctx context.Context, req *c2pb.ClaimTasksRequest) (*c2pb.ClaimTasksResponse, error) {
now := time.Now()
clientIP := getClientIP(ctx)
clientIP := GetClientIP(ctx)

// Validate input
if req.Beacon == nil {
Expand Down
9 changes: 9 additions & 0 deletions tavern/internal/c2/ip.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package c2

import (
"net"
)

func validateIP(ipaddr string) bool {
return net.ParseIP(ipaddr) != nil || ipaddr == "unknown"
}
53 changes: 53 additions & 0 deletions tavern/internal/c2/server.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
package c2

import (
"context"
"log/slog"
"net"
"strings"

"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
"realm.pub/tavern/internal/c2/c2pb"
"realm.pub/tavern/internal/ent"
"realm.pub/tavern/internal/http/stream"
Expand All @@ -20,3 +27,49 @@ func New(graph *ent.Client, mux *stream.Mux) *Server {
mux: mux,
}
}

func getRemoteIP(ctx context.Context) string {
p, ok := peer.FromContext(ctx)
if !ok {
return "unknown"
}

host, _, err := net.SplitHostPort(p.Addr.String())
if err != nil {
return "unknown"
}

return host
}

func GetClientIP(ctx context.Context) string {
md, ok := metadata.FromIncomingContext(ctx)
if ok {
if redirectedFor, exists := md["x-redirected-for"]; exists && len(redirectedFor) > 0 {
clientIP := strings.TrimSpace(redirectedFor[0])
if validateIP(clientIP) {
return clientIP
} else {
slog.Error("bad x-redirected-for ip", "ip", clientIP)
}
}
if forwardedFor, exists := md["x-forwarded-for"]; exists && len(forwardedFor) > 0 {
// X-Forwarded-For is a comma-separated list, the first IP is the original client
clientIP := strings.TrimSpace(strings.Split(forwardedFor[0], ",")[0])
if validateIP(clientIP) {
return clientIP
} else {
slog.Error("bad x-forwarded-for ip", "ip", clientIP)
}
}
}

// Fallback to peer address
remoteIp := getRemoteIP(ctx)
if validateIP(remoteIp) {
return remoteIp
} else {
slog.Error("Bad remote IP", "ip", remoteIp)
}
return "unknown"
}
160 changes: 160 additions & 0 deletions tavern/internal/c2/server_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
package c2

import (
"context"
"net"
"testing"

"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
)

func TestGetClientIP(t *testing.T) {
tests := []struct {
name string
setupContext func() context.Context
expectedIP string
}{
{
name: "X-Forwarded-For_Only",
setupContext: func() context.Context {
ctx := context.Background()
md := metadata.New(map[string]string{
"x-forwarded-for": "203.0.113.42",
})
return metadata.NewIncomingContext(ctx, md)
},
expectedIP: "203.0.113.42",
},
{
name: "X-Redirected-For_With_X-Forwarded-For",
setupContext: func() context.Context {
ctx := context.Background()
md := metadata.New(map[string]string{
"x-forwarded-for": "203.0.113.42",
"x-redirected-for": "198.51.100.99",
})
return metadata.NewIncomingContext(ctx, md)
},
expectedIP: "198.51.100.99",
},
{
name: "Neither_Header_Set_Uses_Peer_IP",
setupContext: func() context.Context {
ctx := context.Background()
p := &peer.Peer{
Addr: &net.TCPAddr{
IP: net.ParseIP("1.1.1.1"),
Port: 12345,
},
}
return peer.NewContext(ctx, p)
},
expectedIP: "1.1.1.1",
},
{
name: "X-Forwarded-For_With_Multiple_IPs",
setupContext: func() context.Context {
ctx := context.Background()
md := metadata.New(map[string]string{
"x-forwarded-for": "203.0.113.42, 198.51.100.1, 192.0.2.5",
})
return metadata.NewIncomingContext(ctx, md)
},
expectedIP: "203.0.113.42",
},
{
name: "X-Forwarded-For_With_Whitespace",
setupContext: func() context.Context {
ctx := context.Background()
md := metadata.New(map[string]string{
"x-forwarded-for": " 203.0.113.42 ",
})
return metadata.NewIncomingContext(ctx, md)
},
expectedIP: "203.0.113.42",
},
{
name: "X-Redirected-For_Precedence_Over_Peer",
setupContext: func() context.Context {
ctx := context.Background()
p := &peer.Peer{
Addr: &net.TCPAddr{
IP: net.ParseIP("1.1.1.1"),
Port: 12345,
},
}
ctx = peer.NewContext(ctx, p)
md := metadata.New(map[string]string{
"x-redirected-for": "198.51.100.99",
})
return metadata.NewIncomingContext(ctx, md)
},
expectedIP: "198.51.100.99",
},
{
name: "Invalid_X-Forwarded-For_Fallback_To_Peer",
setupContext: func() context.Context {
ctx := context.Background()
p := &peer.Peer{
Addr: &net.TCPAddr{
IP: net.ParseIP("1.1.1.1"),
Port: 12345,
},
}
ctx = peer.NewContext(ctx, p)
md := metadata.New(map[string]string{
"x-forwarded-for": "invalid-ip-address",
})
return metadata.NewIncomingContext(ctx, md)
},
expectedIP: "1.1.1.1",
},
{
name: "No_Metadata_No_Peer_Returns_Unknown",
setupContext: func() context.Context {
return context.Background()
},
expectedIP: "unknown",
},
{
name: "Malformed_X-Redirected-For_Returns_As_Is",
setupContext: func() context.Context {
ctx := context.Background()
p := &peer.Peer{
Addr: &net.TCPAddr{
IP: net.ParseIP("1.1.1.1"),
Port: 12345,
},
}
ctx = peer.NewContext(ctx, p)
md := metadata.New(map[string]string{
"x-redirected-for": "not-an-ip",
})
return metadata.NewIncomingContext(ctx, md)
},
expectedIP: "1.1.1.1",
},
{
name: "Malformed_X-Forwarded-For_Without_Peer_Returns_Unknown",
setupContext: func() context.Context {
ctx := context.Background()
md := metadata.New(map[string]string{
"x-forwarded-for": "not-an-ip",
})
return metadata.NewIncomingContext(ctx, md)
},
expectedIP: "unknown",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := tt.setupContext()
result := GetClientIP(ctx)
if result != tt.expectedIP {
t.Errorf("GetClientIP() = %v, want %v", result, tt.expectedIP)
}
})
}
}
10 changes: 9 additions & 1 deletion tavern/internal/redirectors/grpc/grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"google.golang.org/grpc/encoding"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"realm.pub/tavern/internal/c2"
"realm.pub/tavern/internal/redirectors"
)

Expand Down Expand Up @@ -49,10 +50,17 @@ func (r *Redirector) handler(upstream *grpc.ClientConn) grpc.StreamHandler {
}

ctx := ss.Context()
// Get the client's remote IP address
clientIP := c2.GetClientIP(ctx)

md, ok := metadata.FromIncomingContext(ctx)
if ok {
ctx = metadata.NewOutgoingContext(ctx, md)
ctx = metadata.NewOutgoingContext(ctx, md.Copy())
}

// Set x-redirected-for header with the client IP
ctx = redirectors.SetRedirectedForHeader(ctx, clientIP)

cs, err := upstream.NewStream(ctx, &grpc.StreamDesc{
StreamName: fullMethodName,
ServerStreams: true,
Expand Down
37 changes: 37 additions & 0 deletions tavern/internal/redirectors/http1/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@ import (
"fmt"
"io"
"log/slog"
"net"
"net/http"
"strings"

"google.golang.org/grpc"
"realm.pub/tavern/internal/redirectors"
)

func handleFetchAssetStreaming(w http.ResponseWriter, r *http.Request, conn *grpc.ClientConn) {
Expand Down Expand Up @@ -159,11 +162,42 @@ func handleReportFileStreaming(w http.ResponseWriter, r *http.Request, conn *grp
}
}

func getClientIP(r *http.Request) string {
if forwardedFor := r.Header.Get("x-forwarded-for"); len(forwardedFor) > 0 {
// X-Forwarded-For is a comma-separated list, the first IP is the original client
clientIp := strings.TrimSpace(strings.Split(forwardedFor, ",")[0])
if validateIP(clientIp) {
return clientIp
} else {
slog.Error("bad forwarded for ip", "ip", clientIp)
}
}

// Fallback to RemoteAddr
host, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
slog.Error("failed to parse remote addr", "ip", r.RemoteAddr)
}

host = strings.TrimSpace(host)
if validateIP(host) {
return host
} else {
slog.Error("bad remote ip", "ip", host)
}
return "unknown"
}

func validateIP(ipaddr string) bool {
return net.ParseIP(ipaddr) != nil || ipaddr == "unknown"
}
func handleHTTPRequest(w http.ResponseWriter, r *http.Request, conn *grpc.ClientConn) {
if !requirePOST(w, r) {
return
}

clientIp := getClientIP(r)

methodName := r.URL.Path
if methodName == "" {
http.Error(w, "Method name required in path", http.StatusBadRequest)
Expand All @@ -180,6 +214,9 @@ func handleHTTPRequest(w http.ResponseWriter, r *http.Request, conn *grpc.Client
ctx, cancel := createRequestContext(unaryTimeout)
defer cancel()

// Set x-redirected-for header with the client IP
ctx = redirectors.SetRedirectedForHeader(ctx, clientIp)

var responseBody []byte
err := conn.Invoke(
ctx,
Expand Down
Loading
Loading