diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index a95c84455..920efda25 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -2,7 +2,7 @@ name: Release env: CARGO_TERM_COLOR: always - IMIX_CALLBACK_URI: http://127.0.0.1 + IMIX_CALLBACK_URI: http://127.0.0.1:8000 on: workflow_dispatch: ~ diff --git a/AGENTS.md b/AGENTS.md index 269bb78f5..53f7e0ef7 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -1,15 +1,21 @@ +# Agents + Welcome to our repository! Most commands need to be run from the root of the project directory (e.g. where this `AGENT.md` file is located) -# Project Structure - * `tavern/` includes our Golang server implementation, which hosts a GraphQL API, User-Interface (typescript), and a gRPC API used by agents. - * Our user interface is located in `tavern/internal/www` and we managed dependencies within that directory using `npm` - * `implants/` contains Rust code that is deployed to target machines, such as our agent located in `implants/imix`. +## Project Structure + +* `tavern/` includes our Golang server implementation, which hosts a GraphQL API, User-Interface (typescript), and a gRPC API used by agents. + * Our user interface is located in `tavern/internal/www` and we managed dependencies within that directory using `npm` +* `implants/` contains Rust code that is deployed to target machines, such as our agent located in `implants/imix`. + +## Golang Tests -# Golang Tests To run all Golang tests in our repository, please run `go test ./...` from the project root. -# Code Generation +## Code Generation + This project heavily relies on generated code. When making changes to ent schemas, GraphQL, or user interface / front-end changes in the `tavern/internal/www/` directory, you will need to run `go generate ./...` from the project root directory to re-generate some critical files. -# Additional Documentation +## Additional Documentation + Our user-facing documentation for the project is located in `docs/_docs` and can be referenced for additional information. diff --git a/docs/_docs/admin-guide/tavern.md b/docs/_docs/admin-guide/tavern.md index e7e014527..66f4d99c9 100644 --- a/docs/_docs/admin-guide/tavern.md +++ b/docs/_docs/admin-guide/tavern.md @@ -110,7 +110,7 @@ By default Tavern only supports GRPC connections directly to the server. To Enab ### Webserver -By default, Tavern will listen on `0.0.0.0:80`. If you ever wish to change this bind address then simply supply it to the `HTTP_LISTEN_ADDR` environment variable. +By default, Tavern will listen on `0.0.0.0:8000`. If you ever wish to change this bind address then simply supply it to the `HTTP_LISTEN_ADDR` environment variable. ### Metrics @@ -119,7 +119,7 @@ By default, Tavern does not export metrics. You may use the below environment co | Env Var | Description | Default | Required | | ------- | ----------- | ------- | -------- | | ENABLE_METRICS | Set to any value to enable the "/metrics" endpoint. | Disabled | No | -| HTTP_METRICS_LISTEN_ADDR | Listen address for the metrics HTTP server, it must be different than the value of `HTTP_LISTEN_ADDR`. | `127.0.0.1:8080` | No | +| HTTP_METRICS_LISTEN_ADDR | Listen address for the metrics HTTP server, it must be different than the value of `HTTP_LISTEN_ADDR`. | `127.0.0.1:8000` | No | ### Secrets @@ -236,7 +236,7 @@ func main() { defer cancel() // Setup your Tavern URL (e.g. from env vars) - tavernURL := "http://127.0.0.1" + tavernURL := "http://127.0.0.1:8000" // Configure Browser (uses the default system browser) browser := auth.BrowserFunc(browser.OpenURL) @@ -280,7 +280,7 @@ Running Tavern with the `DISABLE_DEFAULT_TOMES` environment variable set will di ```sh DISABLE_DEFAULT_TOMES=1 go run ./tavern -2024/03/02 01:32:22 [WARN] No value for 'HTTP_LISTEN_ADDR' provided, defaulting to 0.0.0.0:80 +2024/03/02 01:32:22 [WARN] No value for 'HTTP_LISTEN_ADDR' provided, defaulting to 0.0.0.0:8000 2024/03/02 01:32:22 [WARN] MySQL is not configured, using SQLite 2024/03/02 01:32:22 [WARN] OAuth is not configured, authentication disabled 2024/03/02 01:32:22 [WARN] No value for 'DB_MAX_IDLE_CONNS' provided, defaulting to 10 diff --git a/docs/_docs/dev-guide/tavern.md b/docs/_docs/dev-guide/tavern.md index 43dc5be61..c5b6ae8e6 100644 --- a/docs/_docs/dev-guide/tavern.md +++ b/docs/_docs/dev-guide/tavern.md @@ -89,7 +89,7 @@ apt install -y graphviz ### Collect a Profile 1. Start Tavern with profiling enabled: `ENABLE_PPROF=1 go run ./tavern`. -2. Collect a Profile in desired format (e.g. png): `go tool pprof -png -seconds=10 http://127.0.0.1:80/debug/pprof/allocs?seconds=10 > .pprof/allocs.png` +2. Collect a Profile in desired format (e.g. png): `go tool pprof -png -seconds=10 http://127.0.0.1:8000/debug/pprof/allocs?seconds=10 > .pprof/allocs.png` a. Replace "allocs" with the [name of the profile](https://pkg.go.dev/runtime/pprof#Profile) to collect. b. Replace the value of seconds with the amount of time you need to reproduce performance issues. c. Read more about the available profiling URL parameters [here](https://pkg.go.dev/net/http/pprof#hdr-Parameters). @@ -113,27 +113,27 @@ The reverse shell system is designed to be highly scalable and resilient. It use The reverse shell system is composed of the following components: -* **gRPC Server**: The gRPC server is the entry point for the agent. It exposes the `ReverseShell` service, which is a bidirectional gRPC stream. The agent connects to this service to initiate a reverse shell session. -* **WebSocket Server**: The WebSocket server is the entry point for the user. It exposes a WebSocket endpoint that the user can connect to to interact with the reverse shell. -* **Pub/Sub Messaging System**: The pub/sub messaging system is the backbone of the reverse shell. It's used to decouple the gRPC server and the WebSocket server, and to provide a reliable and scalable way to transport messages between them. The system uses two topics: - * **Input Topic**: The input topic is used to send messages from the user (via the WebSocket) to the agent (via the gRPC stream). - * **Output Topic**: The output topic is used to send messages from the agent (via the gRPC stream) to the user (via the WebSocket). -* **Mux**: The `Mux` is a multiplexer that sits between the pub/sub system and the gRPC/WebSocket servers. It's responsible for routing messages between the two. There are two `Mux` instances: - * **wsMux**: The `wsMux` is used by the WebSocket server. It subscribes to the output topic and publishes to the input topic. - * **grpcMux**: The `grpcMux` is used by the gRPC server. It subscribes to the input topic and publishes to the output topic. -* **Stream**: A `Stream` represents a single reverse shell session. It's responsible for managing the connection between the `Mux` and the gRPC/WebSocket client. -* **sessionBuffer**: The `sessionBuffer` is used to order messages within a `Stream`. This is important because multiple users can be connected to the same shell session, and their messages need to be delivered in the correct order. +* **gRPC Server**: The gRPC server is the entry point for the agent. It exposes the `ReverseShell` service, which is a bidirectional gRPC stream. The agent connects to this service to initiate a reverse shell session. +* **WebSocket Server**: The WebSocket server is the entry point for the user. It exposes a WebSocket endpoint that the user can connect to to interact with the reverse shell. +* **Pub/Sub Messaging System**: The pub/sub messaging system is the backbone of the reverse shell. It's used to decouple the gRPC server and the WebSocket server, and to provide a reliable and scalable way to transport messages between them. The system uses two topics: + * **Input Topic**: The input topic is used to send messages from the user (via the WebSocket) to the agent (via the gRPC stream). + * **Output Topic**: The output topic is used to send messages from the agent (via the gRPC stream) to the user (via the WebSocket). +* **Mux**: The `Mux` is a multiplexer that sits between the pub/sub system and the gRPC/WebSocket servers. It's responsible for routing messages between the two. There are two `Mux` instances: + * **wsMux**: The `wsMux` is used by the WebSocket server. It subscribes to the output topic and publishes to the input topic. + * **grpcMux**: The `grpcMux` is used by the gRPC server. It subscribes to the input topic and publishes to the output topic. +* **Stream**: A `Stream` represents a single reverse shell session. It's responsible for managing the connection between the `Mux` and the gRPC/WebSocket client. +* **sessionBuffer**: The `sessionBuffer` is used to order messages within a `Stream`. This is important because multiple users can be connected to the same shell session, and their messages need to be delivered in the correct order. #### Communication Flow -1. The agent connects to the `ReverseShell` gRPC service. -2. The gRPC server creates a new `Shell` entity, a new `Stream`, and registers the `Stream` with the `grpcMux`. -3. The user connects to the WebSocket endpoint. -4. The WebSocket server creates a new `Stream` and registers it with the `wsMux`. -5. When the user sends a message, it's sent to the WebSocket server, which publishes it to the input topic via the `wsMux`. -6. The `grpcMux` receives the message from the input topic and sends it to the agent via the gRPC stream. -7. When the agent sends a message, it's sent to the gRPC server, which publishes it to the output topic via the `grpcMux`. -8. The `wsMux` receives the message from the output topic and sends it to the user via the WebSocket. +1. The agent connects to the `ReverseShell` gRPC service. +2. The gRPC server creates a new `Shell` entity, a new `Stream`, and registers the `Stream` with the `grpcMux`. +3. The user connects to the WebSocket endpoint. +4. The WebSocket server creates a new `Stream` and registers it with the `wsMux`. +5. When the user sends a message, it's sent to the WebSocket server, which publishes it to the input topic via the `wsMux`. +6. The `grpcMux` receives the message from the input topic and sends it to the agent via the gRPC stream. +7. When the agent sends a message, it's sent to the gRPC server, which publishes it to the output topic via the `grpcMux`. +8. The `wsMux` receives the message from the output topic and sends it to the user via the WebSocket. #### Distributed Architecture diff --git a/docs/_docs/user-guide/getting-started.md b/docs/_docs/user-guide/getting-started.md index 7e40a9c78..22b93b17a 100644 --- a/docs/_docs/user-guide/getting-started.md +++ b/docs/_docs/user-guide/getting-started.md @@ -67,7 +67,7 @@ These configurations can be controlled via Environment Variables at `imix` compi ### Quests -Now it's time to provide our [Beacon](/user-guide/terminology#beacon) it's first [Task](/user-guide/terminology#task). We do this, by creating a [Quest](/user-guide/terminology#quest) in the UI, which represents a collection of [Tasks](/user-guide/terminology#task) across one or more [Hosts](/user-guide/terminology#host). Let's open our UI, which should be available at [http://127.0.0.1:80/](http://127.0.0.1:80/). +Now it's time to provide our [Beacon](/user-guide/terminology#beacon) it's first [Task](/user-guide/terminology#task). We do this, by creating a [Quest](/user-guide/terminology#quest) in the UI, which represents a collection of [Tasks](/user-guide/terminology#task) across one or more [Hosts](/user-guide/terminology#host). Let's open our UI, which should be available at [http://127.0.0.1:8000/](http://127.0.0.1:8000/). #### Beacon Selection diff --git a/docs/_docs/user-guide/imix.md b/docs/_docs/user-guide/imix.md index 26264b5b5..b8a94f561 100644 --- a/docs/_docs/user-guide/imix.md +++ b/docs/_docs/user-guide/imix.md @@ -15,8 +15,8 @@ Imix has compile-time configuration, that may be specified using environment var | Env Var | Description | Default | Required | | ------- | ----------- | ------- | -------- | -| IMIX_CALLBACK_URI | URI for initial callbacks (must specify a scheme, e.g. `http://`) | `http://127.0.0.1:80` | No | -| IMIX_SERVER_PUBKEY | The public key for the tavern server. | - | Yes | +| IMIX_CALLBACK_URI | URI for initial callbacks (must specify a scheme, e.g. `http://`) | `http://127.0.0.1:8000` | No | +| IMIX_SERVER_PUBKEY | The public key for the tavern server (obtain from server using `curl $IMIX_CALLBACK_URI/status`). | - | Yes | | IMIX_CALLBACK_INTERVAL | Duration between callbacks, in seconds. | `5` | No | | IMIX_RETRY_INTERVAL | Duration to wait before restarting the agent loop if an error occurs, in seconds. | `5` | No | | IMIX_PROXY_URI | Overide system settings for proxy URI over HTTP(S) (must specify a scheme, e.g. `https://`) | No proxy | No | @@ -91,11 +91,12 @@ Building in the dev container limits variables that might cause issues and is th **Imix requires a server public key so it can encrypt messsages to and from the server check the server log for `level=INFO msg="public key: "`. This base64 encoded string should be passed to the agent using the environment variable `IMIX_SERVER_PUBKEY`** ## Optional build flags + These flags are passed to cargo build Eg.: `cargo build --release --bin imix --bin imix --target=x86_64-unknown-linux-musl --features foo-bar` - `--features grpc-doh` - Enable DNS over HTTP using cloudflare DNS for the grpc transport -- `--features http --no-default-features` - Changes the default grpc transport to use HTTP/1.1. Requires running the http redirector. +- `--features http1 --no-default-features` - Changes the default grpc transport to use HTTP/1.1. Requires running the http redirector. ### Linux diff --git a/implants/lib/pb/src/config.rs b/implants/lib/pb/src/config.rs index caaadd653..b1269f3d0 100644 --- a/implants/lib/pb/src/config.rs +++ b/implants/lib/pb/src/config.rs @@ -19,7 +19,7 @@ macro_rules! callback_uri { () => { match option_env!("IMIX_CALLBACK_URI") { Some(uri) => uri, - None => "http://127.0.0.1:80", + None => "http://127.0.0.1:8000", } }; } @@ -36,7 +36,7 @@ macro_rules! proxy_uri { /* * Compile-time constant for the agent callback URI, derived from the IMIX_CALLBACK_URI environment variable during compilation. - * Defaults to "http://127.0.0.1:80/grpc" if this is unset. + * Defaults to "http://127.0.0.1:8000/grpc" if this is unset. */ pub const CALLBACK_URI: &str = callback_uri!(); diff --git a/tavern/app.go b/tavern/app.go index 9619d047a..de6d4304f 100644 --- a/tavern/app.go +++ b/tavern/app.go @@ -34,46 +34,82 @@ import ( "realm.pub/tavern/internal/graphql" tavernhttp "realm.pub/tavern/internal/http" "realm.pub/tavern/internal/http/stream" - "realm.pub/tavern/internal/redirector" + "realm.pub/tavern/internal/redirectors" "realm.pub/tavern/internal/secrets" "realm.pub/tavern/internal/www" "realm.pub/tavern/tomes" + + _ "realm.pub/tavern/internal/redirectors/http1" ) func init() { configureLogging() } -func newApp(ctx context.Context, options ...func(*Config)) (app *cli.App) { +func newApp(ctx context.Context) (app *cli.App) { app = cli.NewApp() app.Name = "tavern" app.Description = "Teamserver implementation for Realm, see https://docs.realm.pub for more details" app.Usage = "Time for an Adventure!" app.Version = Version - app.Action = cli.ActionFunc(func(*cli.Context) error { - return run(ctx, options...) - }) + app.Action = func(c *cli.Context) error { + return runTavern( + ctx, + ConfigureHTTPServerFromEnv(), + ConfigureMySQLFromEnv(), + ConfigureOAuthFromEnv("/oauth/authorize"), + ) + } app.Commands = []cli.Command{ { - Name: "redirector", - Usage: "Run a redirector connecting agents using a specific transport to the server", - Subcommands: []cli.Command{ - { - Name: "http1", - Usage: "Run an HTTP/1.1 redirector", - Action: func(cCtx *cli.Context) error { - // Convert main.Config options to redirector.Config options - redirectorOptions := []func(*redirector.Config){ - func(cfg *redirector.Config) { - // Apply main Config to get server settings - mainCfg := &Config{} - for _, opt := range options { - opt(mainCfg) - } - cfg.SetServer(mainCfg.srv) - }, + Name: "redirector", + Usage: "Run a redirector connecting agents using a specific transport to the server", + ArgsUsage: "[upstream_address]", + Flags: []cli.Flag{ + cli.StringFlag{ + Name: "listen", + Usage: "Address to listen on for incoming redirector traffic (default: :8080)", + Value: ":8080", + }, + cli.StringFlag{ + Name: "transport", + Usage: "Transport protocol to use for redirector (default: http1)", + Value: "http1", + }, + }, + Action: func(c *cli.Context) error { + var ( + upstream = c.Args().First() + listenOn = c.String("listen") + transport = c.String("transport") + ) + if upstream == "" { + return fmt.Errorf("gRPC upstream address is required (first argument)") + } + if listenOn == "" { + listenOn = ":8080" + } + if transport == "" { + transport = "http1" + } + slog.InfoContext(ctx, "starting redirector", "upstream", upstream, "transport", transport, "listen_on", listenOn) + return redirectors.Run(ctx, transport, listenOn, upstream) + }, + Subcommands: cli.Commands{ + cli.Command{ + Name: "list", + Usage: "List available redirectors", + Action: func(c *cli.Context) error { + redirectorNames := redirectors.List() + if len(redirectorNames) == 0 { + fmt.Println("No redirectors registered") + return nil + } + fmt.Println("Available redirectors:") + for _, name := range redirectorNames { + fmt.Printf("- %s\n", name) } - return redirector.HTTPRedirectorRun(ctx, cCtx.Args().First(), redirectorOptions...) + return nil }, }, }, @@ -82,8 +118,7 @@ func newApp(ctx context.Context, options ...func(*Config)) (app *cli.App) { return } - -func run(ctx context.Context, options ...func(*Config)) error { +func runTavern(ctx context.Context, options ...func(*Config)) error { srv, err := NewServer(ctx, options...) if err != nil { return err @@ -475,7 +510,6 @@ func newGRPCHandler(client *ent.Client, grpcShellMux *stream.Mux) http.Handler { return } - if contentType := r.Header.Get("Content-Type"); !strings.HasPrefix(contentType, "application/grpc") { http.Error(w, "must specify Content-Type application/grpc", http.StatusBadRequest) return diff --git a/tavern/config.go b/tavern/config.go index 6626f6c88..4364e447a 100644 --- a/tavern/config.go +++ b/tavern/config.go @@ -44,7 +44,7 @@ var ( // EnvHTTPListenAddr sets the address (ip:port) for tavern's HTTP server to bind to. // EnvHTTPMetricsAddr sets the address (ip:port) for the HTTP metrics server to bind to. - EnvHTTPListenAddr = EnvString{"HTTP_LISTEN_ADDR", "0.0.0.0:80"} + EnvHTTPListenAddr = EnvString{"HTTP_LISTEN_ADDR", "0.0.0.0:8000"} EnvHTTPMetricsListenAddr = EnvString{"HTTP_METRICS_LISTEN_ADDR", "127.0.0.1:8080"} // EnvOAuthClientID set to configure OAuth Client ID. diff --git a/tavern/internal/cryptocodec/cryptocodec.go b/tavern/internal/cryptocodec/cryptocodec.go index 4591dca0d..4b3e7164b 100644 --- a/tavern/internal/cryptocodec/cryptocodec.go +++ b/tavern/internal/cryptocodec/cryptocodec.go @@ -6,7 +6,6 @@ import ( "crypto/rand" "errors" "fmt" - "log" "log/slog" "runtime/debug" "strconv" @@ -23,8 +22,9 @@ var session_pub_keys = NewSyncMap() // This size limits the number of concurrent connections each server can handle. // I can't imagine a single server handling more than 10k connections at once but just in case. const LRUCACHE_SIZE = 10480 + type SyncMap struct { - Map *lru.Cache[int, []byte] // Example data map + Map *lru.Cache[int, []byte] // Example data map } func NewSyncMap() *SyncMap { @@ -45,7 +45,6 @@ func (s *SyncMap) String() string { return res } - func (s *SyncMap) Load(key int) ([]byte, bool) { return s.Map.Get(key) } @@ -68,8 +67,8 @@ func castBytesToBufSlice(buf []byte) (mem.BufferSlice, error) { } func init() { - log.Println("[INFO] Loading xchacha20-poly1305") encoding.RegisterCodecV2(StreamDecryptCodec{}) + slog.Debug("[cryptocodec] application-layer cryptography registered xchacha20-poly1305 gRPC codec") } type StreamDecryptCodec struct { @@ -228,9 +227,9 @@ func (csvc *CryptoSvc) Encrypt(in_arr []byte) []byte { } type GoidTrace struct { - Id int + Id int ParentId int - Others []int + Others []int } func goAllIds() (GoidTrace, error) { @@ -248,9 +247,9 @@ func goAllIds() (GoidTrace, error) { } } res := GoidTrace{ - Id: ids[0], + Id: ids[0], ParentId: ids[1], - Others: ids[2:], + Others: ids[2:], } return res, nil } diff --git a/tavern/internal/redirector/grpc_frame.go b/tavern/internal/redirector/grpc_frame.go deleted file mode 100644 index 68adb2352..000000000 --- a/tavern/internal/redirector/grpc_frame.go +++ /dev/null @@ -1,60 +0,0 @@ -package redirector - -import "encoding/binary" - -// FrameHeaderSize is the size of gRPC frame header: [compression_flag(1)][length(4)] -const FrameHeaderSize = 5 - -// FrameHeader represents a gRPC wire protocol frame header -type FrameHeader struct { - CompressionFlag uint8 - MessageLength uint32 -} - -// NewFrameHeader creates a new frame header with no compression -func NewFrameHeader(messageLength uint32) FrameHeader { - return FrameHeader{ - CompressionFlag: 0x00, - MessageLength: messageLength, - } -} - -// Encode encodes the frame header to a 5-byte array -func (h FrameHeader) Encode() [5]byte { - var header [5]byte - header[0] = h.CompressionFlag - binary.BigEndian.PutUint32(header[1:], h.MessageLength) - return header -} - -// TryDecode attempts to decode a frame header from the buffer -// Returns (header, ok) where ok indicates if enough data was available -func TryDecode(buffer []byte) (FrameHeader, bool) { - if len(buffer) < FrameHeaderSize { - return FrameHeader{}, false - } - - return FrameHeader{ - CompressionFlag: buffer[0], - MessageLength: binary.BigEndian.Uint32(buffer[1:5]), - }, true -} - -// ExtractFrame extracts a complete gRPC frame from the buffer -// Returns (header, message, remainingBuffer, ok) -func ExtractFrame(buffer []byte) (FrameHeader, []byte, []byte, bool) { - header, ok := TryDecode(buffer) - if !ok { - return FrameHeader{}, nil, buffer, false - } - - totalSize := FrameHeaderSize + int(header.MessageLength) - if len(buffer) < totalSize { - return FrameHeader{}, nil, buffer, false - } - - message := buffer[FrameHeaderSize:totalSize] - remaining := buffer[totalSize:] - - return header, message, remaining, true -} diff --git a/tavern/internal/redirector/redirector_test.go b/tavern/internal/redirector/redirector_test.go deleted file mode 100644 index 59f7ca9df..000000000 --- a/tavern/internal/redirector/redirector_test.go +++ /dev/null @@ -1,897 +0,0 @@ -package redirector - -import ( - "bytes" - "fmt" - "io" - "net/http" - "net/http/httptest" - "strings" - "testing" - "time" - - "google.golang.org/grpc" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/credentials/insecure" - "google.golang.org/grpc/status" -) - -// TestRawCodecMarshal tests marshaling bytes with RawCodec -func TestRawCodecMarshal(t *testing.T) { - codec := RawCodec{} - testData := []byte("test data") - - result, err := codec.Marshal(testData) - if err != nil { - t.Fatalf("Marshal failed: %v", err) - } - if !bytes.Equal(result, testData) { - t.Errorf("Expected %v, got %v", testData, result) - } -} - -// TestRawCodecMarshalInvalidType tests marshaling with invalid type -func TestRawCodecMarshalInvalidType(t *testing.T) { - codec := RawCodec{} - invalidData := "string data" - - _, err := codec.Marshal(invalidData) - if err == nil { - t.Error("Expected error when marshaling non-bytes type") - } - if !strings.Contains(err.Error(), "failed to marshal") { - t.Errorf("Expected 'failed to marshal' in error, got: %v", err) - } -} - -// TestRawCodecUnmarshal tests unmarshaling bytes with RawCodec -func TestRawCodecUnmarshal(t *testing.T) { - codec := RawCodec{} - testData := []byte("test data") - var result []byte - - err := codec.Unmarshal(testData, &result) - if err != nil { - t.Fatalf("Unmarshal failed: %v", err) - } - if !bytes.Equal(result, testData) { - t.Errorf("Expected %v, got %v", testData, result) - } -} - -// TestRawCodecUnmarshalInvalidType tests unmarshaling with invalid type -func TestRawCodecUnmarshalInvalidType(t *testing.T) { - codec := RawCodec{} - testData := []byte("test data") - var result string - - err := codec.Unmarshal(testData, &result) - if err == nil { - t.Error("Expected error when unmarshaling to non-*[]byte type") - } - if !strings.Contains(err.Error(), "failed to unmarshal") { - t.Errorf("Expected 'failed to unmarshal' in error, got: %v", err) - } -} - -// TestRawCodecName tests the codec name -func TestRawCodecName(t *testing.T) { - codec := RawCodec{} - name := codec.Name() - if name != "raw" { - t.Errorf("Expected codec name 'raw', got '%s'", name) - } -} - -// TestConfigSetServer tests setting the HTTP server in config -func TestConfigSetServer(t *testing.T) { - config := &Config{} - server := &http.Server{Addr: ":8080"} - - config.SetServer(server) - - if config.srv != server { - t.Error("SetServer did not properly set the server") - } - if config.srv.Addr != ":8080" { - t.Errorf("Expected server address ':8080', got '%s'", config.srv.Addr) - } -} - -// TestRequirePOSTSuccess tests requirePOST with valid POST request -func TestRequirePOSTSuccess(t *testing.T) { - req := httptest.NewRequest("POST", "/test", nil) - w := httptest.NewRecorder() - - result := requirePOST(w, req) - - if !result { - t.Error("requirePOST returned false for POST request") - } - // httptest.ResponseRecorder records a 200 status by default when WriteHeader is not explicitly called - // This is expected behavior -} - -// TestRequirePOSTMethodNotAllowed tests requirePOST with GET request -func TestRequirePOSTMethodNotAllowed(t *testing.T) { - req := httptest.NewRequest("GET", "/test", nil) - w := httptest.NewRecorder() - - result := requirePOST(w, req) - - if result { - t.Error("requirePOST returned true for GET request") - } - if w.Code != http.StatusMethodNotAllowed { - t.Errorf("Expected status %d, got %d", http.StatusMethodNotAllowed, w.Code) - } -} - -// TestRequirePOSTVariousMethods tests requirePOST with various HTTP methods -func TestRequirePOSTVariousMethods(t *testing.T) { - methods := []string{"GET", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"} - - for _, method := range methods { - req := httptest.NewRequest(method, "/test", nil) - w := httptest.NewRecorder() - - result := requirePOST(w, req) - - if result { - t.Errorf("requirePOST returned true for %s request", method) - } - if w.Code != http.StatusMethodNotAllowed { - t.Errorf("Expected status %d for %s, got %d", http.StatusMethodNotAllowed, method, w.Code) - } - } -} - -// TestReadRequestBodySuccess tests reading valid request body -func TestReadRequestBodySuccess(t *testing.T) { - testData := []byte("test request body") - req := httptest.NewRequest("POST", "/test", bytes.NewReader(testData)) - w := httptest.NewRecorder() - - result, ok := readRequestBody(w, req) - - if !ok { - t.Error("readRequestBody returned false for valid body") - } - if !bytes.Equal(result, testData) { - t.Errorf("Expected body %v, got %v", testData, result) - } -} - -// TestReadRequestBodyEmpty tests reading empty request body -func TestReadRequestBodyEmpty(t *testing.T) { - req := httptest.NewRequest("POST", "/test", bytes.NewReader([]byte{})) - w := httptest.NewRecorder() - - result, ok := readRequestBody(w, req) - - if !ok { - t.Error("readRequestBody returned false for empty body") - } - if len(result) != 0 { - t.Errorf("Expected empty body, got %v", result) - } -} - -// TestReadRequestBodyLarge tests reading large request body -func TestReadRequestBodyLarge(t *testing.T) { - largeData := make([]byte, 1024*1024) // 1MB - for i := range largeData { - largeData[i] = byte(i % 256) - } - req := httptest.NewRequest("POST", "/test", bytes.NewReader(largeData)) - w := httptest.NewRecorder() - - result, ok := readRequestBody(w, req) - - if !ok { - t.Error("readRequestBody returned false for large body") - } - if !bytes.Equal(result, largeData) { - t.Errorf("Expected body size %d, got %d", len(largeData), len(result)) - } -} - -// TestSetGRPCResponseHeaders tests setting proper gRPC response headers -func TestSetGRPCResponseHeaders(t *testing.T) { - w := httptest.NewRecorder() - - setGRPCResponseHeaders(w) - - if w.Header().Get("Content-Type") != "application/grpc" { - t.Errorf("Expected Content-Type 'application/grpc', got '%s'", - w.Header().Get("Content-Type")) - } - if w.Code != http.StatusOK { - t.Errorf("Expected status code %d, got %d", http.StatusOK, w.Code) - } -} - -// TestGetFlusherSuccess tests getting a flusher from ResponseWriter -func TestGetFlusherSuccess(t *testing.T) { - w := httptest.NewRecorder() - - flusher, ok := getFlusher(w) - - if !ok { - t.Error("getFlusher returned false for httptest.ResponseRecorder") - } - if flusher == nil { - t.Error("Expected non-nil flusher") - } -} - -// TestGetFlusherWithErrorWriter tests getFlusher with writer that doesn't support flushing -func TestGetFlusherWithErrorWriter(t *testing.T) { - var buf bytes.Buffer - w := &nonFlushingWriter{buf: &buf} - - header := w.Header() - if header == nil { - t.Error("Expected non-nil header") - } -} - -// nonFlushingWriter is a mock ResponseWriter that doesn't implement Flusher -type nonFlushingWriter struct { - buf *bytes.Buffer - header http.Header -} - -func (n *nonFlushingWriter) Header() http.Header { - if n.header == nil { - n.header = make(http.Header) - } - return n.header -} - -func (n *nonFlushingWriter) Write(b []byte) (int, error) { - return n.buf.Write(b) -} - -func (n *nonFlushingWriter) WriteHeader(statusCode int) {} - -// TestHandleStreamError tests error handling for gRPC stream errors -func TestHandleStreamError(t *testing.T) { - w := httptest.NewRecorder() - testErr := status.Error(codes.Internal, "test error") - - handleStreamError(w, "Test message", testErr) - - if w.Code != http.StatusInternalServerError { - t.Errorf("Expected status code %d, got %d", http.StatusInternalServerError, w.Code) - } - if !strings.Contains(w.Body.String(), "Test message") { - t.Errorf("Expected 'Test message' in error response, got: %s", w.Body.String()) - } -} - -// TestCreateRequestContext tests context creation with timeout -func TestCreateRequestContext(t *testing.T) { - timeout := 5 * time.Second - ctx, cancel := createRequestContext(timeout) - defer cancel() - - if ctx == nil { - t.Error("Expected non-nil context") - } - - // Check that context has a deadline - deadline, ok := ctx.Deadline() - if !ok { - t.Error("Expected context to have deadline") - } - - // Check that deadline is approximately correct (within 1 second) - now := time.Now() - expectedDeadline := now.Add(timeout) - diff := expectedDeadline.Sub(deadline) - if diff < -1*time.Second || diff > 1*time.Second { - t.Errorf("Deadline not set correctly. Expected ~%v, got %v", expectedDeadline, deadline) - } -} - -// TestCreateRequestContextCancellation tests that cancel function works -func TestCreateRequestContextCancellation(t *testing.T) { - ctx, cancel := createRequestContext(30 * time.Second) - - select { - case <-ctx.Done(): - t.Error("Context should not be done before cancellation") - default: - } - - cancel() - - // Give a small window for the context to be cancelled - select { - case <-ctx.Done(): - // Expected - case <-time.After(1 * time.Second): - t.Error("Context should be cancelled") - } -} - -// TestHandleHTTPRequestInvalidMethod tests handleHTTPRequest with non-POST method -func TestHandleHTTPRequestInvalidMethod(t *testing.T) { - conn := setupTestGRPCConnection(t) - defer conn.Close() - - req := httptest.NewRequest("GET", "/test.Method", nil) - w := httptest.NewRecorder() - - handleHTTPRequest(w, req, conn) - - if w.Code != http.StatusMethodNotAllowed { - t.Errorf("Expected status %d, got %d", http.StatusMethodNotAllowed, w.Code) - } -} - -// TestHandleHTTPRequestEmptyPath tests handleHTTPRequest with empty path -func TestHandleHTTPRequestEmptyPath(t *testing.T) { - conn := setupTestGRPCConnection(t) - defer conn.Close() - - req := httptest.NewRequest("POST", "/", bytes.NewReader([]byte("data"))) - w := httptest.NewRecorder() - - // This will attempt to call an empty method on the gRPC server - // It should fail, but let's ensure it handles gracefully - handleHTTPRequest(w, req, conn) - - // Should return an error status code - if w.Code == http.StatusOK { - t.Errorf("Expected error status code, got %d", w.Code) - } -} - -// TestHandleHTTPRequestValidRequest tests handleHTTPRequest with valid request -func TestHandleHTTPRequestValidRequest(t *testing.T) { - // This test would require a running gRPC server, which is complex to set up - // For now, we'll test the structure - conn := setupTestGRPCConnection(t) - defer conn.Close() - - req := httptest.NewRequest("POST", "/test.Method", bytes.NewReader([]byte("test"))) - w := httptest.NewRecorder() - - // This test depends on having a gRPC server, so it will fail gracefully - handleHTTPRequest(w, req, conn) - - // Verify that we got an error response (since no server is running) - if w.Code != http.StatusInternalServerError { - t.Errorf("Expected status code %d for unavailable server, got %d", - http.StatusInternalServerError, w.Code) - } -} - -// TestHandleHTTPRequestReadBodyError simulates a scenario where body reading would fail -func TestHandleHTTPRequestReadBodyError(t *testing.T) { - conn := setupTestGRPCConnection(t) - defer conn.Close() - - // Create a request with an error-inducing body - req := httptest.NewRequest("POST", "/test.Method", newBrokenReader()) - w := httptest.NewRecorder() - - handleHTTPRequest(w, req, conn) - - if w.Code != http.StatusBadRequest { - t.Errorf("Expected status %d, got %d", http.StatusBadRequest, w.Code) - } -} - -// brokenReader is a reader that always returns an error -type brokenReader struct{} - -func (br *brokenReader) Read(p []byte) (n int, err error) { - return 0, io.ErrUnexpectedEOF -} - -func newBrokenReader() io.Reader { - return &brokenReader{} -} - -// TestHandleHTTPRequestConstants verifies timeout constants -func TestHandleHTTPRequestConstants(t *testing.T) { - if streamingTimeout == 0 || unaryTimeout == 0 { - t.Error("Timeout constants should be non-zero") - } - if bufferCapacity == 0 || readChunkSize == 0 { - t.Error("Buffer size constants should be non-zero") - } - if streamingTimeout <= unaryTimeout { - t.Error("Streaming timeout should be greater than unary timeout") - } - if bufferCapacity <= readChunkSize { - t.Errorf("Buffer capacity (%d) should be greater than read chunk size (%d)", - bufferCapacity, readChunkSize) - } -} - -// setupTestGRPCConnection creates a test gRPC connection to localhost -// This will fail if no server is running, which is expected for most tests -func setupTestGRPCConnection(t *testing.T) *grpc.ClientConn { - // Try to connect to a non-existent server - // This will create a connection in a failed state, which is fine for testing handlers - conn, err := grpc.NewClient( - "localhost:0", - grpc.WithTransportCredentials(insecure.NewCredentials()), - ) - if err != nil && !strings.Contains(err.Error(), "context deadline exceeded") { - // Some errors are acceptable during setup - } - return conn -} - -// TestMultipleRawCodecOperations tests sequence of marshal/unmarshal operations -func TestMultipleRawCodecOperations(t *testing.T) { - codec := RawCodec{} - testCases := [][]byte{ - []byte(""), - []byte("a"), - []byte("hello"), - []byte("test\x00data\x01with\xffbinary"), - make([]byte, 1024), - } - - for _, testData := range testCases { - marshaled, err := codec.Marshal(testData) - if err != nil { - t.Fatalf("Marshal failed for %v: %v", testData, err) - } - - var unmarshaled []byte - err = codec.Unmarshal(marshaled, &unmarshaled) - if err != nil { - t.Fatalf("Unmarshal failed for %v: %v", marshaled, err) - } - - if !bytes.Equal(unmarshaled, testData) { - t.Errorf("Roundtrip failed: original %v, result %v", testData, unmarshaled) - } - } -} - - - -// TestRawCodecEdgeCases tests RawCodec with edge case inputs -func TestRawCodecEdgeCases(t *testing.T) { - codec := RawCodec{} - - // Test with nil - _, err := codec.Marshal(nil) - if err == nil { - t.Error("Expected error when marshaling nil") - } - - // Test with pointer to bytes - data := []byte("test") - _, err = codec.Marshal(&data) - if err == nil { - t.Error("Expected error when marshaling pointer to bytes (not bytes directly)") - } - - // Test with empty bytes - result, err := codec.Marshal([]byte{}) - if err != nil { - t.Fatalf("Marshal empty bytes failed: %v", err) - } - if len(result) != 0 { - t.Errorf("Expected empty result for empty bytes, got %d bytes", len(result)) - } -} - -// TestRequirePOSTAllowsCorrectMethod tests that only POST is allowed -func TestRequirePOSTAllowsCorrectMethod(t *testing.T) { - req := httptest.NewRequest("POST", "/", nil) - w := httptest.NewRecorder() - - if !requirePOST(w, req) { - t.Error("requirePOST should allow POST method") - } -} - -// TestReadRequestBodyClosesBody tests that request body is closed after reading -func TestReadRequestBodyClosesBody(t *testing.T) { - reader := &trackingReader{buf: bytes.NewReader([]byte("test"))} - req := httptest.NewRequest("POST", "/", reader) - w := httptest.NewRecorder() - - _, ok := readRequestBody(w, req) - if !ok { - t.Error("readRequestBody returned false") - } - if !reader.closed { - t.Error("Expected request body to be closed") - } -} - -// trackingReader wraps an io.Reader and tracks if it was closed -type trackingReader struct { - buf *bytes.Reader - closed bool -} - -func (tr *trackingReader) Read(p []byte) (n int, err error) { - return tr.buf.Read(p) -} - -func (tr *trackingReader) Close() error { - tr.closed = true - return nil -} - - - -// TestConfigurationIntegration tests Config with various server settings -func TestConfigurationIntegration(t *testing.T) { - tests := []struct { - name string - addr string - }{ - {"localhost IPv4", "127.0.0.1:8080"}, - {"localhost IPv6", "[::1]:8080"}, - {"dynamic port", "localhost:0"}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - config := &Config{} - server := &http.Server{Addr: tt.addr} - - config.SetServer(server) - - if config.srv.Addr != tt.addr { - t.Errorf("Expected address '%s', got '%s'", tt.addr, config.srv.Addr) - } - }) - } -} - -// TestHTTPHandlersWithRecorder tests HTTP handlers with httptest.ResponseRecorder -func TestHTTPHandlersWithRecorder(t *testing.T) { - conn := setupTestGRPCConnection(t) - defer conn.Close() - - tests := []struct { - name string - method string - path string - expectedStatus int - }{ - {"GET not allowed", "GET", "/test.Method", http.StatusMethodNotAllowed}, - {"PUT not allowed", "PUT", "/test.Method", http.StatusMethodNotAllowed}, - {"DELETE not allowed", "DELETE", "/test.Method", http.StatusMethodNotAllowed}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - req := httptest.NewRequest(tt.method, tt.path, bytes.NewReader([]byte("data"))) - w := httptest.NewRecorder() - - handleHTTPRequest(w, req, conn) - - if w.Code != tt.expectedStatus { - t.Errorf("Expected status %d, got %d", tt.expectedStatus, w.Code) - } - }) - } -} - -// TestHandleStreamErrorWithVariousErrors tests error handling with different error types -func TestHandleStreamErrorWithVariousErrors(t *testing.T) { - tests := []struct { - name string - message string - err error - }{ - {"Internal error", "Internal error", status.Error(codes.Internal, "server error")}, - {"Unavailable", "Service unavailable", status.Error(codes.Unavailable, "server down")}, - {"Deadline exceeded", "Timeout", status.Error(codes.DeadlineExceeded, "too slow")}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - w := httptest.NewRecorder() - handleStreamError(w, tt.message, tt.err) - - if w.Code != http.StatusInternalServerError { - t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, w.Code) - } - if !strings.Contains(w.Body.String(), tt.message) { - t.Errorf("Expected '%s' in error message, got: %s", tt.message, w.Body.String()) - } - }) - } -} - -// TestLargeDataHandling tests handling of large data payloads -func TestLargeDataHandling(t *testing.T) { - sizes := []int{ - 1024, // 1KB - 64 * 1024, // 64KB (buffer size) - 256 * 1024, // 256KB - 1024 * 1024, // 1MB - } - - for _, size := range sizes { - t.Run(fmt.Sprintf("size_%d", size), func(t *testing.T) { - largeData := make([]byte, size) - for i := range largeData { - largeData[i] = byte(i % 256) - } - - req := httptest.NewRequest("POST", "/test", bytes.NewReader(largeData)) - w := httptest.NewRecorder() - - result, ok := readRequestBody(w, req) - if !ok { - t.Errorf("Failed to read body of size %d", size) - return - } - if len(result) != size { - t.Errorf("Expected body size %d, got %d", size, len(result)) - } - if !bytes.Equal(result, largeData) { - t.Errorf("Body content mismatch for size %d", size) - } - }) - } -} - -// TestBufferCapacityConstants validates buffer capacity configuration -func TestBufferCapacityConstants(t *testing.T) { - // Ensure buffer capacity is reasonable for streaming - if bufferCapacity < readChunkSize { - t.Errorf("Buffer capacity (%d) should be >= read chunk size (%d)", - bufferCapacity, readChunkSize) - } - - // Ensure timeouts are properly configured - if streamingTimeout <= 0 || unaryTimeout <= 0 { - t.Error("Timeout constants should be positive") - } - - // Streaming should generally allow more time than unary operations - if streamingTimeout <= unaryTimeout { - t.Logf("Warning: streamingTimeout (%v) should typically be > unaryTimeout (%v)", - streamingTimeout, unaryTimeout) - } -} - -// TestRawCodecBinaryData tests RawCodec with various binary patterns -func TestRawCodecBinaryData(t *testing.T) { - codec := RawCodec{} - binaryPatterns := [][]byte{ - {0x00, 0xFF, 0xAA, 0x55}, // Alternating patterns - {0xFF, 0xFF, 0xFF, 0xFF}, // All ones - {0x00, 0x00, 0x00, 0x00}, // All zeros - make([]byte, 256), // All zeros (larger) - } - - for _, pattern := range binaryPatterns { - marshaled, err := codec.Marshal(pattern) - if err != nil { - t.Fatalf("Marshal failed for pattern %v: %v", pattern, err) - } - - var unmarshaled []byte - err = codec.Unmarshal(marshaled, &unmarshaled) - if err != nil { - t.Fatalf("Unmarshal failed: %v", err) - } - - if !bytes.Equal(unmarshaled, pattern) { - t.Errorf("Pattern mismatch: expected %v, got %v", pattern, unmarshaled) - } - } -} - -// TestReadRequestBodyMultipleCalls tests that body can only be read once -func TestReadRequestBodyMultipleCalls(t *testing.T) { - testData := []byte("test data") - req := httptest.NewRequest("POST", "/test", bytes.NewReader(testData)) - w := httptest.NewRecorder() - - // First read should succeed - result1, ok1 := readRequestBody(w, req) - if !ok1 { - t.Fatal("First read failed") - } - if !bytes.Equal(result1, testData) { - t.Errorf("First read got unexpected data") - } - - // Second read should fail (body is closed) - req2 := httptest.NewRequest("POST", "/test", bytes.NewReader(testData)) - w2 := httptest.NewRecorder() - result2, ok2 := readRequestBody(w2, req2) - if !ok2 { - t.Error("Second read should still work with fresh request") - } - if !bytes.Equal(result2, testData) { - t.Errorf("Second read got unexpected data") - } -} - -// TestSetGRPCResponseHeadersIdempotent tests that setting headers multiple times is safe -func TestSetGRPCResponseHeadersIdempotent(t *testing.T) { - w := httptest.NewRecorder() - - // Set headers twice - setGRPCResponseHeaders(w) - setGRPCResponseHeaders(w) - - // Should still have correct headers - if w.Header().Get("Content-Type") != "application/grpc" { - t.Errorf("Content-Type should be 'application/grpc'") - } - if w.Code != http.StatusOK { - t.Errorf("Status code should be %d", http.StatusOK) - } -} - -// TestContextDeadlineAccuracy tests context deadline accuracy within tolerance -func TestContextDeadlineAccuracy(t *testing.T) { - tests := []time.Duration{ - 1 * time.Second, - 5 * time.Second, - 30 * time.Second, - 60 * time.Second, - } - - for _, timeout := range tests { - t.Run(fmt.Sprintf("timeout_%v", timeout), func(t *testing.T) { - before := time.Now() - ctx, cancel := createRequestContext(timeout) - defer cancel() - - deadline, ok := ctx.Deadline() - if !ok { - t.Fatal("Expected context to have deadline") - } - - actual := deadline.Sub(before) - // Allow 1 second tolerance for execution time - if actual < timeout-1*time.Second || actual > timeout+1*time.Second { - t.Errorf("Deadline inaccurate: expected ~%v, got %v", timeout, actual) - } - }) - } -} - -// TestReadRequestBodyMeasureSize tests body size measurement across various sizes -func TestReadRequestBodyMeasureSize(t *testing.T) { - sizes := []int{0, 1, 10, 100, 1024, 10*1024, 100*1024} - - for _, size := range sizes { - t.Run(fmt.Sprintf("size_%d", size), func(t *testing.T) { - testData := make([]byte, size) - for i := range testData { - testData[i] = byte(i % 256) - } - - req := httptest.NewRequest("POST", "/test", bytes.NewReader(testData)) - w := httptest.NewRecorder() - - result, ok := readRequestBody(w, req) - if !ok { - t.Errorf("Failed to read body of size %d", size) - return - } - - if len(result) != size { - t.Errorf("Expected size %d, got %d", size, len(result)) - } - }) - } -} - -// TestHTTPResponseWriterBehavior tests ResponseWriter behavior in different scenarios -func TestHTTPResponseWriterBehavior(t *testing.T) { - w := httptest.NewRecorder() - - // Test that WriteHeader is idempotent - w.WriteHeader(http.StatusOK) - w.WriteHeader(http.StatusInternalServerError) // Should be ignored - if w.Code != http.StatusOK { - t.Errorf("Expected first WriteHeader to take effect") - } -} - -// TestFlusherInterface tests that the Flusher interface is correctly used -func TestFlusherInterface(t *testing.T) { - w := httptest.NewRecorder() - flusher, ok := getFlusher(w) - if !ok { - t.Fatal("getFlusher should succeed with httptest.ResponseRecorder") - } - - // Verify Flusher has Flush method (doesn't panic) - flusher.Flush() - - // Write data and flush - w.WriteHeader(http.StatusOK) - w.Write([]byte("test")) - flusher.Flush() - - if w.Body.String() != "test" { - t.Errorf("Expected body 'test', got '%s'", w.Body.String()) - } -} - -// TestRequestBodyWithSpecialChars tests body reading with special characters -func TestRequestBodyWithSpecialChars(t *testing.T) { - specialData := []byte{ - 0x00, 0x01, 0x02, 0x03, // Null and control chars - 'h', 'e', 'l', 'l', 'o', // ASCII - 0xC3, 0xA9, // UTF-8 (é) - 0xFF, 0xFE, // High bytes - } - - req := httptest.NewRequest("POST", "/test", bytes.NewReader(specialData)) - w := httptest.NewRecorder() - - result, ok := readRequestBody(w, req) - if !ok { - t.Fatal("Failed to read body with special characters") - } - - if !bytes.Equal(result, specialData) { - t.Errorf("Body content mismatch: expected %v, got %v", specialData, result) - } -} - -// TestConfigurationValidation tests config field validation -func TestConfigurationValidation(t *testing.T) { - config := &Config{} - - // Initially nil - if config.srv != nil { - t.Error("Config.srv should be nil initially") - } - - // Set and verify - server := &http.Server{Addr: ":8080"} - config.SetServer(server) - - if config.srv == nil { - t.Error("Config.srv should not be nil after SetServer") - } - if config.srv != server { - t.Error("Config.srv should be the exact server object passed") - } -} - -// TestErrorMessageFormatting tests that error messages are properly formatted -func TestErrorMessageFormatting(t *testing.T) { - tests := []struct { - name string - message string - err error - }{ - {"empty message", "", fmt.Errorf("test error")}, - {"long message", strings.Repeat("a", 100), fmt.Errorf("test error")}, - {"special chars", "Test: \n\r\t", fmt.Errorf("test error")}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - w := httptest.NewRecorder() - handleStreamError(w, tt.message, tt.err) - - if w.Code != http.StatusInternalServerError { - t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, w.Code) - } - // Error message should contain the message - if len(tt.message) > 0 && !strings.Contains(w.Body.String(), tt.message) { - t.Errorf("Expected '%s' in error response", tt.message) - } - }) - } -} diff --git a/tavern/internal/redirectors/grpc_codec.go b/tavern/internal/redirectors/grpc_codec.go new file mode 100644 index 000000000..b0ac15ab7 --- /dev/null +++ b/tavern/internal/redirectors/grpc_codec.go @@ -0,0 +1,33 @@ +package redirectors + +import ( + "fmt" + + "google.golang.org/grpc/encoding" +) + +func init() { + encoding.RegisterCodec(gRPCRawCodec{}) +} + +// gRPCRawCodec passes through raw bytes without marshaling/unmarshaling +type gRPCRawCodec struct{} + +func (gRPCRawCodec) Marshal(v any) ([]byte, error) { + if b, ok := v.([]byte); ok { + return b, nil + } + return nil, fmt.Errorf("failed to marshal, message is %T", v) +} + +func (gRPCRawCodec) Unmarshal(data []byte, v any) error { + if b, ok := v.(*[]byte); ok { + *b = data + return nil + } + return fmt.Errorf("failed to unmarshal, message is %T", v) +} + +func (gRPCRawCodec) Name() string { + return "raw" +} diff --git a/tavern/internal/redirectors/grpc_upstream.go b/tavern/internal/redirectors/grpc_upstream.go new file mode 100644 index 000000000..98d4bd531 --- /dev/null +++ b/tavern/internal/redirectors/grpc_upstream.go @@ -0,0 +1,63 @@ +package redirectors + +import ( + "context" + "crypto/tls" + "fmt" + "net" + "net/url" + + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/insecure" +) + +// ConnectToUpstream creates a gRPC client connection to the upstream address +func ConnectToUpstream(upstream string) (*grpc.ClientConn, error) { + // Parse host:port to determine if TLS should be used + url, err := url.Parse(upstream) + if err != nil { + return nil, fmt.Errorf("failed to parse upstream address: %v", err) + } + + // Default to TLS on 443 + var ( + tc = credentials.NewTLS(&tls.Config{}) + port = "443" + ) + + // If scheme is http, use insecure credentials and default to port 80 + if url.Scheme == "http" { + port = "80" + tc = insecure.NewCredentials() + } + + // If port is specified, use it + if url.Port() != "" { + port = url.Port() + } + + return grpc.NewClient( + url.Host, + grpc.WithTransportCredentials(tc), + grpc.WithContextDialer(func(ctx context.Context, _ string) (net.Conn, error) { + // Resolve using IPv4 only (A records, not AAAA records) + ips, err := net.DefaultResolver.LookupIP(ctx, "ip4", url.Hostname()) + if err != nil { + return nil, err + } + if len(ips) == 0 { + return nil, fmt.Errorf("no IPv4 addresses found for %s", url.Hostname()) + } + + // Force IPv4 by using "tcp4" instead of "tcp" + dialer := &net.Dialer{} + tcpConn, err := dialer.DialContext(ctx, "tcp4", net.JoinHostPort(ips[0].String(), port)) + if err != nil { + return nil, err + } + + return tcpConn, nil + }), + ) +} diff --git a/tavern/internal/redirectors/http1/grpc_frame.go b/tavern/internal/redirectors/http1/grpc_frame.go new file mode 100644 index 000000000..fe4fc0542 --- /dev/null +++ b/tavern/internal/redirectors/http1/grpc_frame.go @@ -0,0 +1,60 @@ +package http1 + +import "encoding/binary" + +// frameHeaderSize is the size of gRPC frame header: [compression_flag(1)][length(4)] +const frameHeaderSize = 5 + +// frameHeader represents a gRPC wire protocol frame header +type frameHeader struct { + CompressionFlag uint8 + MessageLength uint32 +} + +// NewFrameHeader creates a new frame header with no compression +func newFrameHeader(messageLength uint32) frameHeader { + return frameHeader{ + CompressionFlag: 0x00, + MessageLength: messageLength, + } +} + +// Encode encodes the frame header to a 5-byte array +func (h frameHeader) Encode() [5]byte { + var header [5]byte + header[0] = h.CompressionFlag + binary.BigEndian.PutUint32(header[1:], h.MessageLength) + return header +} + +// tryDecode attempts to decode a frame header from the buffer +// Returns (header, ok) where ok indicates if enough data was available +func tryDecode(buffer []byte) (frameHeader, bool) { + if len(buffer) < frameHeaderSize { + return frameHeader{}, false + } + + return frameHeader{ + CompressionFlag: buffer[0], + MessageLength: binary.BigEndian.Uint32(buffer[1:5]), + }, true +} + +// extractFrame extracts a complete gRPC frame from the buffer +// Returns (header, message, remainingBuffer, ok) +func extractFrame(buffer []byte) (frameHeader, []byte, []byte, bool) { + header, ok := tryDecode(buffer) + if !ok { + return frameHeader{}, nil, buffer, false + } + + totalSize := frameHeaderSize + int(header.MessageLength) + if len(buffer) < totalSize { + return frameHeader{}, nil, buffer, false + } + + message := buffer[frameHeaderSize:totalSize] + remaining := buffer[totalSize:] + + return header, message, remaining, true +} diff --git a/tavern/internal/redirector/grpc_frame_test.go b/tavern/internal/redirectors/http1/grpc_frame_test.go similarity index 88% rename from tavern/internal/redirector/grpc_frame_test.go rename to tavern/internal/redirectors/http1/grpc_frame_test.go index f7c89ff05..7f6cc826d 100644 --- a/tavern/internal/redirector/grpc_frame_test.go +++ b/tavern/internal/redirectors/http1/grpc_frame_test.go @@ -1,11 +1,11 @@ -package redirector +package http1 import ( "testing" ) func TestFrameHeaderNew(t *testing.T) { - header := NewFrameHeader(1234) + header := newFrameHeader(1234) if header.CompressionFlag != 0x00 { t.Errorf("Expected compression flag 0x00, got 0x%02x", header.CompressionFlag) } @@ -15,7 +15,7 @@ func TestFrameHeaderNew(t *testing.T) { } func TestFrameHeaderEncode(t *testing.T) { - header := NewFrameHeader(0x12345678) + header := newFrameHeader(0x12345678) encoded := header.Encode() if len(encoded) != 5 { @@ -33,7 +33,7 @@ func TestFrameHeaderEncode(t *testing.T) { func TestTryDecodeSuccess(t *testing.T) { buffer := []byte{0x00, 0x00, 0x00, 0x01, 0x00} - header, ok := TryDecode(buffer) + header, ok := tryDecode(buffer) if !ok { t.Fatal("Expected successful decode") } @@ -48,7 +48,7 @@ func TestTryDecodeSuccess(t *testing.T) { func TestTryDecodeInsufficientData(t *testing.T) { buffer := []byte{0x00, 0x01, 0x02} // Only 3 bytes - _, ok := TryDecode(buffer) + _, ok := tryDecode(buffer) if ok { t.Error("Expected decode to fail with insufficient data") } @@ -60,7 +60,7 @@ func TestExtractFrameSuccess(t *testing.T) { buffer := []byte{0x00, 0x00, 0x00, 0x00, 0x0A} buffer = append(buffer, []byte("0123456789")...) - header, message, remaining, ok := ExtractFrame(buffer) + header, message, remaining, ok := extractFrame(buffer) if !ok { t.Fatal("Expected successful frame extraction") } @@ -84,7 +84,7 @@ func TestExtractFrameIncomplete(t *testing.T) { buffer := []byte{0x00, 0x00, 0x00, 0x00, 0x0A} buffer = append(buffer, []byte("01234")...) - _, _, remaining, ok := ExtractFrame(buffer) + _, _, remaining, ok := extractFrame(buffer) if ok { t.Error("Expected frame extraction to fail with incomplete data") } @@ -103,7 +103,7 @@ func TestExtractFrameMultiple(t *testing.T) { buffer = append(buffer, []byte("BBB")...) // Extract first frame - header1, msg1, remaining1, ok1 := ExtractFrame(buffer) + header1, msg1, remaining1, ok1 := extractFrame(buffer) if !ok1 { t.Fatal("Expected first frame extraction to succeed") } @@ -115,7 +115,7 @@ func TestExtractFrameMultiple(t *testing.T) { } // Extract second frame - header2, msg2, remaining2, ok2 := ExtractFrame(remaining1) + header2, msg2, remaining2, ok2 := extractFrame(remaining1) if !ok2 { t.Fatal("Expected second frame extraction to succeed") } @@ -127,7 +127,7 @@ func TestExtractFrameMultiple(t *testing.T) { } // No more frames - _, _, _, ok3 := ExtractFrame(remaining2) + _, _, _, ok3 := extractFrame(remaining2) if ok3 { t.Error("Expected no more frames to extract") } @@ -139,7 +139,7 @@ func TestExtractFrameMultiple(t *testing.T) { func TestExtractFrameZeroLength(t *testing.T) { buffer := []byte{0x00, 0x00, 0x00, 0x00, 0x00} - header, message, _, ok := ExtractFrame(buffer) + header, message, _, ok := extractFrame(buffer) if !ok { t.Fatal("Expected successful extraction of zero-length frame") } @@ -152,10 +152,10 @@ func TestExtractFrameZeroLength(t *testing.T) { } func TestFrameHeaderMaxLength(t *testing.T) { - header := NewFrameHeader(0xFFFFFFFF) // uint32 max + header := newFrameHeader(0xFFFFFFFF) // uint32 max encoded := header.Encode() - decoded, ok := TryDecode(encoded[:]) + decoded, ok := tryDecode(encoded[:]) if !ok { t.Fatal("Expected successful decode of max length header") } @@ -167,7 +167,7 @@ func TestFrameHeaderMaxLength(t *testing.T) { func TestFrameHeaderCompressionFlag(t *testing.T) { buffer := []byte{0x01, 0x00, 0x00, 0x00, 0x00} // compression flag = 1 - header, ok := TryDecode(buffer) + header, ok := tryDecode(buffer) if !ok { t.Fatal("Expected successful decode") } @@ -179,7 +179,7 @@ func TestFrameHeaderCompressionFlag(t *testing.T) { func TestFrameHeaderPartialFrameAcrossReads(t *testing.T) { // Simulate first chunk: partial header buffer := []byte{0x00, 0x00} - _, _, _, ok1 := ExtractFrame(buffer) + _, _, _, ok1 := extractFrame(buffer) if ok1 { t.Error("Expected extraction to fail with partial header") } @@ -187,14 +187,14 @@ func TestFrameHeaderPartialFrameAcrossReads(t *testing.T) { // Simulate second chunk: rest of header + partial data buffer = append(buffer, []byte{0x00, 0x00, 0x05}...) // Complete header now buffer = append(buffer, []byte("AB")...) // Partial data - _, _, _, ok2 := ExtractFrame(buffer) + _, _, _, ok2 := extractFrame(buffer) if ok2 { t.Error("Expected extraction to fail with partial data") } // Simulate third chunk: rest of data buffer = append(buffer, []byte("CDE")...) - header, message, _, ok3 := ExtractFrame(buffer) + header, message, _, ok3 := extractFrame(buffer) if !ok3 { t.Fatal("Expected successful extraction after receiving complete data") } @@ -207,9 +207,9 @@ func TestFrameHeaderPartialFrameAcrossReads(t *testing.T) { } func TestFrameHeaderRoundtrip(t *testing.T) { - original := NewFrameHeader(42) + original := newFrameHeader(42) encoded := original.Encode() - decoded, ok := TryDecode(encoded[:]) + decoded, ok := tryDecode(encoded[:]) if !ok { t.Fatal("Expected successful decode") @@ -230,7 +230,7 @@ func TestExtractFrameWithTrailingData(t *testing.T) { buffer = append(buffer, []byte("AAAAA")...) buffer = append(buffer, []byte("EXTRA")...) // Trailing data - header, message, remaining, ok := ExtractFrame(buffer) + header, message, remaining, ok := extractFrame(buffer) if !ok { t.Fatal("Expected successful frame extraction") } diff --git a/tavern/internal/redirector/grpc_stream.go b/tavern/internal/redirectors/http1/grpc_stream.go similarity index 83% rename from tavern/internal/redirector/grpc_stream.go rename to tavern/internal/redirectors/http1/grpc_stream.go index ce8fe2d6d..0aaf2f8b4 100644 --- a/tavern/internal/redirector/grpc_stream.go +++ b/tavern/internal/redirectors/http1/grpc_stream.go @@ -1,4 +1,4 @@ -package redirector +package http1 import ( "context" @@ -8,7 +8,7 @@ import ( // streamConfig represents gRPC stream configuration type streamConfig struct { - Desc grpc.StreamDesc + Desc grpc.StreamDesc MethodPath string } @@ -16,11 +16,11 @@ type streamConfig struct { var ( fetchAssetStream = streamConfig{ Desc: grpc.StreamDesc{ - StreamName: "FetchAsset", + StreamName: "FetchAsset", ServerStreams: true, ClientStreams: false, }, - MethodPath: "/c2.C2/FetchAsset", + MethodPath: "/c2.C2/FetchAsset", } reportFileStream = streamConfig{ @@ -29,7 +29,7 @@ var ( ServerStreams: false, ClientStreams: true, }, - MethodPath: "/c2.C2/ReportFile", + MethodPath: "/c2.C2/ReportFile", } ) diff --git a/tavern/internal/redirector/grpc_stream_test.go b/tavern/internal/redirectors/http1/grpc_stream_test.go similarity index 95% rename from tavern/internal/redirector/grpc_stream_test.go rename to tavern/internal/redirectors/http1/grpc_stream_test.go index 2a90181a4..368ce97f1 100644 --- a/tavern/internal/redirector/grpc_stream_test.go +++ b/tavern/internal/redirectors/http1/grpc_stream_test.go @@ -1,4 +1,4 @@ -package redirector +package http1 import ( "testing" @@ -63,10 +63,10 @@ func TestCreateStreamWithContext(t *testing.T) { // TestStreamConfigProperties tests properties of stream configurations func TestStreamConfigProperties(t *testing.T) { configs := []struct { - name string - cfg streamConfig - expectServer bool - expectClient bool + name string + cfg streamConfig + expectServer bool + expectClient bool }{ {"FetchAsset", fetchAssetStream, true, false}, {"ReportFile", reportFileStream, false, true}, diff --git a/tavern/internal/redirector/redirector.go b/tavern/internal/redirectors/http1/handlers.go similarity index 51% rename from tavern/internal/redirector/redirector.go rename to tavern/internal/redirectors/http1/handlers.go index 5c3de4f58..a6c54ce26 100644 --- a/tavern/internal/redirector/redirector.go +++ b/tavern/internal/redirectors/http1/handlers.go @@ -1,137 +1,14 @@ -package redirector +package http1 import ( - "context" - "crypto/tls" "fmt" "io" - "log" "log/slog" - "net" "net/http" - "net/url" "google.golang.org/grpc" - "google.golang.org/grpc/credentials" - "google.golang.org/grpc/credentials/insecure" - "google.golang.org/grpc/encoding" ) -// Config holds configuration for the HTTP redirector -type Config struct { - srv *http.Server -} - -// SetServer sets the HTTP server configuration -func (c *Config) SetServer(srv *http.Server) { - c.srv = srv -} - -// RawCodec passes through raw bytes without marshaling/unmarshaling -type RawCodec struct{} - -func (RawCodec) Marshal(v any) ([]byte, error) { - if b, ok := v.([]byte); ok { - return b, nil - } - return nil, fmt.Errorf("failed to marshal, message is %T", v) -} - -func (RawCodec) Unmarshal(data []byte, v any) error { - if b, ok := v.(*[]byte); ok { - *b = data - return nil - } - return fmt.Errorf("failed to unmarshal, message is %T", v) -} - -func (RawCodec) Name() string { - return "raw" -} - -func init() { - encoding.RegisterCodec(RawCodec{}) -} - -// HTTPRedirectorRun starts an HTTP/1.1 to gRPC proxy/redirector -func HTTPRedirectorRun(ctx context.Context, upstream string, options ...func(*Config)) error { - // Initialize Config - cfg := &Config{} - for _, opt := range options { - opt(cfg) - } - - // Parse host:port to determine if TLS should be used - url, err := url.Parse(upstream) - if err != nil { - return fmt.Errorf("failed to parse upstream address: %v", err) - } - - tc := credentials.NewTLS(&tls.Config{}) - port := url.Port() - if port == "" { - port = "443" - if(url.Scheme == "http") { - port = "80" - tc = insecure.NewCredentials() - } - } - - - conn, err := grpc.NewClient( - url.Host, - grpc.WithTransportCredentials(tc), - grpc.WithContextDialer(func(ctx context.Context, _ string) (net.Conn, error) { - // Resolve using IPv4 only (A records, not AAAA records) - ips, err := net.DefaultResolver.LookupIP(ctx, "ip4", url.Hostname()) - if err != nil { - return nil, err - } - if len(ips) == 0 { - return nil, fmt.Errorf("no IPv4 addresses found for %s", url.Hostname()) - } - - // Force IPv4 by using "tcp4" instead of "tcp" - dialer := &net.Dialer{} - tcpConn, err := dialer.DialContext(ctx, "tcp4", net.JoinHostPort(ips[0].String(), port)) - if err != nil { - return nil, err - } - - return tcpConn, nil - }), - - ) - if err != nil { - log.Fatalf("Failed to connect to gRPC server: %v", err) - } - defer conn.Close() - - mux := http.NewServeMux() - mux.HandleFunc("/c2.C2/FetchAsset", func(w http.ResponseWriter, r *http.Request) { - handleFetchAssetStreaming(w, r, conn) - }) - mux.HandleFunc("/c2.C2/ReportFile", func(w http.ResponseWriter, r *http.Request) { - handleReportFileStreaming(w, r, conn) - }) - mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - handleHTTPRequest(w, r, conn) - }) - - server := &http.Server{ - Addr: cfg.srv.Addr, - Handler: mux, - } - - - slog.Info(fmt.Sprintf("HTTP/1.1 proxy listening on %s, forwarding to gRPC server at %s\n", server.Addr, upstream)) - if err := server.ListenAndServe(); err != nil { - log.Fatalf("Failed to start HTTP server: %v", err) - } - - return nil -} - func handleFetchAssetStreaming(w http.ResponseWriter, r *http.Request, conn *grpc.ClientConn) { if !requirePOST(w, r) { return @@ -142,7 +19,7 @@ func handleFetchAssetStreaming(w http.ResponseWriter, r *http.Request, conn *grp return } - slog.Info(fmt.Sprintf("[HTTP -> gRPC Streaming] Method: /c2.C2/FetchAsset, Body size: %d bytes\n", len(requestBody))) + slog.Debug(fmt.Sprintf("[HTTP1 -> gRPC Streaming] Method: /c2.C2/FetchAsset, Body size: %d bytes\n", len(requestBody))) ctx, cancel := createRequestContext(streamingTimeout) defer cancel() @@ -180,7 +57,7 @@ func handleFetchAssetStreaming(w http.ResponseWriter, r *http.Request, conn *grp break } if err != nil { - slog.Debug(fmt.Sprintf("[gRPC Stream Error] Failed to receive message: %v\n", err)) + slog.Error(fmt.Sprintf("[gRPC Stream Error] Failed to receive message: %v\n", err)) return } @@ -189,22 +66,22 @@ func handleFetchAssetStreaming(w http.ResponseWriter, r *http.Request, conn *grp slog.Debug(fmt.Sprintf("[gRPC Stream] Received chunk %d: %d bytes\n", chunkCount, len(responseChunk))) // Write gRPC frame header - frameHeader := NewFrameHeader(uint32(len(responseChunk))) + frameHeader := newFrameHeader(uint32(len(responseChunk))) encodedHeader := frameHeader.Encode() if _, err := w.Write(encodedHeader[:]); err != nil { - slog.Debug(fmt.Sprintf("[HTTP Write Error] Failed to write frame header: %v\n", err)) + slog.Error(fmt.Sprintf("[HTTP Write Error] Failed to write frame header: %v\n", err)) return } if _, err := w.Write(responseChunk); err != nil { - slog.Debug(fmt.Sprintf("[HTTP Write Error] Failed to write chunk: %v\n", err)) + slog.Error(fmt.Sprintf("[HTTP Write Error] Failed to write chunk: %v\n", err)) return } flusher.Flush() } - slog.Debug(fmt.Sprintf("[gRPC -> HTTP] Streamed %d chunks, total %d bytes\n", chunkCount, totalBytes)) + slog.Debug(fmt.Sprintf("[gRPC -> HTTP1] Streamed %d chunks, total %d bytes\n", chunkCount, totalBytes)) } func handleReportFileStreaming(w http.ResponseWriter, r *http.Request, conn *grpc.ClientConn) { @@ -212,7 +89,7 @@ func handleReportFileStreaming(w http.ResponseWriter, r *http.Request, conn *grp return } - slog.Info(("[HTTP -> gRPC Client Streaming] Method: /c2.C2/ReportFile\n")) + slog.Debug(("[HTTP1 -> gRPC Client Streaming] Method: /c2.C2/ReportFile\n")) ctx, cancel := createRequestContext(streamingTimeout) defer cancel() @@ -235,7 +112,7 @@ func handleReportFileStreaming(w http.ResponseWriter, r *http.Request, conn *grp // Process complete gRPC frames from buffer for { - header, message, remaining, ok := ExtractFrame(buffer) + header, message, remaining, ok := extractFrame(buffer) if !ok { break } @@ -274,7 +151,7 @@ func handleReportFileStreaming(w http.ResponseWriter, r *http.Request, conn *grp return } - slog.Debug(fmt.Sprintf("[gRPC -> HTTP] Response size: %d bytes\n", len(responseBody))) + slog.Debug(fmt.Sprintf("[gRPC -> HTTP1] Response size: %d bytes\n", len(responseBody))) setGRPCResponseHeaders(w) if _, err := w.Write(responseBody); err != nil { @@ -298,7 +175,7 @@ func handleHTTPRequest(w http.ResponseWriter, r *http.Request, conn *grpc.Client return } - slog.Info(fmt.Sprintf("[HTTP -> gRPC] Method: %s, Body size: %d bytes\n", methodName, len(requestBody))) + slog.Debug(fmt.Sprintf("[HTTP1 -> gRPC] Method: %s, Body size: %d bytes\n", methodName, len(requestBody))) ctx, cancel := createRequestContext(unaryTimeout) defer cancel() @@ -319,7 +196,7 @@ func handleHTTPRequest(w http.ResponseWriter, r *http.Request, conn *grpc.Client return } - slog.Debug(fmt.Sprintf("[gRPC -> HTTP] Response size: %d bytes\n", len(responseBody))) + slog.Debug(fmt.Sprintf("[gRPC -> HTTP1] Response size: %d bytes\n", len(responseBody))) setGRPCResponseHeaders(w) if _, err := w.Write(responseBody); err != nil { diff --git a/tavern/internal/redirector/http_helpers.go b/tavern/internal/redirectors/http1/http.go similarity index 98% rename from tavern/internal/redirector/http_helpers.go rename to tavern/internal/redirectors/http1/http.go index 6b153ddd3..b7b16fc8a 100644 --- a/tavern/internal/redirector/http_helpers.go +++ b/tavern/internal/redirectors/http1/http.go @@ -1,4 +1,4 @@ -package redirector +package http1 import ( "context" diff --git a/tavern/internal/redirectors/http1/redirector.go b/tavern/internal/redirectors/http1/redirector.go new file mode 100644 index 000000000..a6b70b64f --- /dev/null +++ b/tavern/internal/redirectors/http1/redirector.go @@ -0,0 +1,40 @@ +package http1 + +import ( + "context" + "log/slog" + "net/http" + + "google.golang.org/grpc" + "realm.pub/tavern/internal/redirectors" +) + +func init() { + redirectors.Register("http1", &Redirector{}) +} + +// A Redirector implementation which receives HTTP/1.1 traffic locally and +// sends gRPC traffic to the upstream destination. +type Redirector struct{} + +// Redirect starts the redirector, listening for traffic locally and forwarding to the upstream +func (r *Redirector) Redirect(ctx context.Context, listenOn string, upstream *grpc.ClientConn) error { + mux := http.NewServeMux() + mux.HandleFunc("/c2.C2/FetchAsset", func(w http.ResponseWriter, r *http.Request) { + handleFetchAssetStreaming(w, r, upstream) + }) + mux.HandleFunc("/c2.C2/ReportFile", func(w http.ResponseWriter, r *http.Request) { + handleReportFileStreaming(w, r, upstream) + }) + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + handleHTTPRequest(w, r, upstream) + }) + + srv := &http.Server{ + Addr: listenOn, + Handler: mux, + } + + slog.Info("HTTP/1.1 redirector started", "listen_on", listenOn) + return srv.ListenAndServe() +} diff --git a/tavern/internal/redirectors/redirector.go b/tavern/internal/redirectors/redirector.go new file mode 100644 index 000000000..b49a5e31d --- /dev/null +++ b/tavern/internal/redirectors/redirector.go @@ -0,0 +1,68 @@ +package redirectors + +import ( + "context" + "fmt" + "log/slog" + "maps" + "slices" + "sync" + + "google.golang.org/grpc" +) + +var ( + mu sync.RWMutex + redirectors = make(map[string]Redirector) +) + +// A Redirector for traffic to an upstream gRPC server +type Redirector interface { + Redirect(ctx context.Context, listenOn string, upstream *grpc.ClientConn) error +} + +// Register makes a Redirector available by the provided name. +// If Register is called twice with the same name or if driver is nil,it panics. +func Register(name string, redirector Redirector) { + mu.Lock() + defer mu.Unlock() + if redirector == nil { + panic("redirectors: Register redirector is nil") + } + if _, dup := redirectors[name]; dup { + panic("redirectors: Register called twice for redirector " + name) + } + redirectors[name] = redirector +} + +// List returns a sorted list of the names of the registered redirectors. +func List() []string { + mu.RLock() + defer mu.RUnlock() + return slices.Sorted(maps.Keys(redirectors)) +} + +// Run starts the redirector with the given name, connecting to the specified upstream address +func Run(ctx context.Context, name string, listenOn string, upstreamAddr string) error { + // Get the Redirector + mu.RLock() + redirector, exists := redirectors[name] + mu.RUnlock() + if !exists || redirector == nil { + return fmt.Errorf("redirector %q not found", name) + } + + // Connect to the upstream gRPC server + upstream, err := ConnectToUpstream(upstreamAddr) + if err != nil { + return fmt.Errorf("failed to connect to upstream: %v", err) + } + defer func() { + slog.DebugContext(ctx, "redirectors: closing connection to upstream grpc", "redirector_name", name) + upstream.Close() + }() + slog.DebugContext(ctx, "redirectors: connected to upstream grpc", "redirector_name", name, "upstream_addr", upstreamAddr) + + // Start the redirector + return redirector.Redirect(ctx, listenOn, upstream) +} diff --git a/tavern/main.go b/tavern/main.go index a74e2dbed..16b0dc6a4 100644 --- a/tavern/main.go +++ b/tavern/main.go @@ -14,11 +14,7 @@ import ( func main() { ctx := context.Background() - app := newApp(ctx, - ConfigureHTTPServerFromEnv(), - ConfigureMySQLFromEnv(), - ConfigureOAuthFromEnv("/oauth/authorize"), - ) + app := newApp(ctx) if err := app.Run(os.Args); err != nil && !errors.Is(err, http.ErrServerClosed) { log.Fatalf("fatal error: %v", err) }