diff --git a/docs/_docs/admin-guide/tavern.md b/docs/_docs/admin-guide/tavern.md index 69bbff201..66f4d99c9 100644 --- a/docs/_docs/admin-guide/tavern.md +++ b/docs/_docs/admin-guide/tavern.md @@ -119,7 +119,7 @@ By default, Tavern does not export metrics. You may use the below environment co | Env Var | Description | Default | Required | | ------- | ----------- | ------- | -------- | | ENABLE_METRICS | Set to any value to enable the "/metrics" endpoint. | Disabled | No | -| HTTP_METRICS_LISTEN_ADDR | Listen address for the metrics HTTP server, it must be different than the value of `HTTP_LISTEN_ADDR`. | `127.0.0.1:8080` | No | +| HTTP_METRICS_LISTEN_ADDR | Listen address for the metrics HTTP server, it must be different than the value of `HTTP_LISTEN_ADDR`. | `127.0.0.1:8000` | No | ### Secrets @@ -293,38 +293,6 @@ DISABLE_DEFAULT_TOMES=1 go run ./tavern Running Tavern with the `ENABLE_PPROF` environment variable set will enable performance profiling information to be collected and accessible. This should never be set for a production deployment as it will be unauthenticated and may provide access to sensitive information, it is intended for development purposes only. Read more on how to use `pprof` with tavern in the [Developer Guide](/dev-guide/tavern#performance-profiling). -### Logging - -The following environment variables configure Tavern's logging behavior. - -| Env Var | Description | Default | Required | -| ------- | ----------- | ------- | -------- | -| ENABLE_DEBUG_LOGGING | Emit verbose debug logs to help troubleshoot issues. | Disabled | No | -| ENABLE_JSON_LOGGING | Emit logs in JSON format for easier parsing by log aggregators. | Disabled | No | -| ENABLE_INSTANCE_ID_LOGGING | Include the tavern instance id in log messages. | Disabled | No | -| ENABLE_GRAPHQL_RAW_QUERY_LOGGING | Include the raw GraphQL query in graphql log messages. | Disabled | No | - -### Google Cloud & PubSub - -The following environment variables are available for configuring Google Cloud Platform integration and PubSub messaging for shell I/O. - -| Env Var | Description | Default | Required | -| ------- | ----------- | ------- | -------- | -| GCP_PROJECT_ID | The project id tavern is deployed in for Google Cloud Platform deployments. | N/A | No | -| GCP_PUBSUB_KEEP_ALIVE_INTERVAL_MS | Interval to publish no-op pubsub messages to help avoid gcppubsub coldstart latency. | 1000 | No | -| PUBSUB_TOPIC_SHELL_INPUT | The topic to publish shell input to. | `mem://shell_input` | No | -| PUBSUB_SUBSCRIPTION_SHELL_INPUT | The subscription to receive shell input from. | `mem://shell_input` | No | -| PUBSUB_TOPIC_SHELL_OUTPUT | The topic to publish shell output to. | `mem://shell_output` | No | -| PUBSUB_SUBSCRIPTION_SHELL_OUTPUT | The subscription to receive shell output from. | `mem://shell_output` | No | - -### Testing - -The following environment variables are used for testing purposes. - -| Env Var | Description | Default | Required | -| ------- | ----------- | ------- | -------- | -| ENABLE_TEST_RUN_AND_EXIT | Start the application, but exit immediately after. Useful for testing startup. | Disabled | No | - ## Build and publish tavern container If you want to deploy tavern without using the published version you'll have to build and publish your own container. diff --git a/docs/_docs/user-guide/eldritch.md b/docs/_docs/user-guide/eldritch.md index d52fc7e29..1ddc71fcf 100644 --- a/docs/_docs/user-guide/eldritch.md +++ b/docs/_docs/user-guide/eldritch.md @@ -9,6 +9,7 @@ permalink: user-guide/eldritch 🚨 **DEPRECATION WARNING:** Eldritch v1 will soon be deprecated and replaced with v2 🚨 + Eldritch is a Pythonic red team Domain Specific Language (DSL) based on [starlark](https://github.com/facebookexperimental/starlark-rust). It uses and supports most python syntax and basic functionality such as list comprehension, string operations (`lower()`, `join()`, `replace()`, etc.), and built-in methods (`any()`, `dir()`, `sorted()`, etc.). For more details on the supported functionality not listed here, please consult the [Starlark Spec Reference](https://github.com/bazelbuild/starlark/blob/master/spec.md), but for the most part you can treat this like basic Python with extra red team functionality. Eldritch is a small interpreter that can be embedded into a c2 agent as it is with Golem and Imix. @@ -233,18 +234,6 @@ The assets.read method returns a UTF-8 string representation of the asset ## Crypto -### crypto.aes_decrypt (V2-Only) - -`crypto.aes_decrypt(key: Bytes, iv: Bytes, data: Bytes) -> Bytes` - -The crypto.aes_decrypt method decrypts the given data using AES (CBC mode). The key must be 16, 24, or 32 bytes, and the IV must be 16 bytes. - -### crypto.aes_encrypt (V2-Only) - -`crypto.aes_encrypt(key: Bytes, iv: Bytes, data: Bytes) -> Bytes` - -The crypto.aes_encrypt method encrypts the given data using AES (CBC mode). The key must be 16, 24, or 32 bytes, and the IV must be 16 bytes. - ### crypto.aes_decrypt_file `crypto.aes_decrypt_file(src: str, dst: str, key: str) -> None` @@ -454,11 +443,11 @@ Here is an example of the Dict layout: The file.mkdir method will make a new directory at `path`. If the parent directory does not exist or the directory cannot be created, it will error; unless the `parent` parameter is passed as `True`. -### file.move +### file.moveto -`file.move(src: str, dst: str) -> None` +`file.moveto(src: str, dst: str) -> None` -The file.move method moves or renames a file or directory from `src` to `dst`. If the `dst` directory or file exists it will be deleted before being replaced to ensure consistency across systems. +The file.moveto method moves a file or directory from `src` to `dst`. If the `dst` directory or file exists it will be deleted before being replaced to ensure consistency across systems. ### file.parent_dir @@ -525,11 +514,11 @@ If the destination file doesn't exist it will be created (if the parent director The `args` dictionary currently supports values of: `int`, `str`, and `List`. `autoescape` when `True` will perform HTML character escapes according to the [OWASP XSS guidelines](https://cheatsheetseries.owasp.org/cheatsheets/Cross_Site_Scripting_Prevention_Cheat_Sheet.html) -### file.timestomp (V2-Only) +### file.timestomp -`file.timestomp(path: str, mtime: Option, atime: Option, ctime: Option, ref_file: Option) -> None` +`file.timestomp(src: str, dst: str) -> None` -The file.timestomp method modifies the timestamps of a file. It can update the modification time (`mtime`), access time (`atime`), and creation time (`ctime`) to specific values (epoch integer or string). Alternatively, if `ref_file` is provided, the timestamps from that file will be copied to `path`. +Unimplemented. ### file.write @@ -565,15 +554,15 @@ The http.download method downloads a file at the URI specified in `uri` t ### http.get -`http.get(uri: str, query_params: Option>, headers: Option>, allow_insecure: Option) -> Dict` +`http.get(uri: str, query_params: Option>, headers: Option>, allow_insecure: Option) -> str` -The http.get method sends an HTTP GET request to the URI specified in `uri` with the optional query paramters specified in `query_params` and headers specified in `headers`. It returns a dictionary containing the `status_code` (int), `body` (Bytes), and `headers` (Dict). Note: in order to conform with HTTP2+ all header names are transmuted to lowercase. +The http.get method sends an HTTP GET request to the URI specified in `uri` with the optional query paramters specified in `query_params` and headers specified in `headers`, then return the response body as a string. Note: in order to conform with HTTP2+ all header names are transmuted to lowercase. ### http.post -`http.post(uri: str, body: Option, form: Option>, headers: Option>, allow_insecure: Option) -> Dict` +`http.post(uri: str, body: Option, form: Option>, headers: Option>, allow_insecure: Option) -> str` -The http.post method sends an HTTP POST request to the URI specified in `uri` with the optional request body specified by `body`, form paramters specified in `form`, and headers specified in `headers`. It returns a dictionary containing the `status_code` (int), `body` (Bytes), and `headers` (Dict). Note: in order to conform with HTTP2+ all header names are transmuted to lowercase. Other Note: if a `body` and a `form` are supplied the value of `body` will be used. +The http.post method sends an HTTP POST request to the URI specified in `uri` with the optional request body specified by `body`, form paramters specified in `form`, and headers specified in `headers`, then return the response body as a string. Note: in order to conform with HTTP2+ all header names are transmuted to lowercase. Other Note: if a `body` and a `form` are supplied the value of `body` will be used. --- @@ -609,6 +598,12 @@ $> pivot.arp_scan(["192.168.1.1/32"]) [] ``` +### pivot.bind_proxy + +`pivot.bind_proxy(listen_address: str, listen_port: int, username: str, password: str ) -> None` + +The pivot.bind_proxy method is being proposed to provide users another option when trying to connect and pivot within an environment. This function will start a SOCKS5 proxy on the specified port and interface, with the specified username and password (if provided). + ### pivot.ncat `pivot.ncat(address: str, port: int, data: str, protocol: str ) -> str` @@ -617,9 +612,15 @@ The pivot.ncat method allows a user to send arbitrary data over TCP/UDP t `protocol` must be `tcp`, or `udp` anything else will return an error `Protocol not supported please use: udp or tcp.`. +### pivot.port_forward + +`pivot.port_forward(listen_address: str, listen_port: int, forward_address: str, forward_port: int, str: protocol ) -> None` + +The pivot.port_forward method is being proposed to provide socat like functionality by forwarding traffic from a port on a local machine to a port on a different machine allowing traffic to be relayed. + ### pivot.port_scan -`pivot.port_scan(target_cidrs: List, ports: List, protocol: str, timeout: int, fd_limit: Option) -> List` +`pivot.port_scan(target_cidrs: List, ports: List, protocol: str, timeout: int) -> List` The pivot.port_scan method allows users to scan TCP/UDP ports within the eldritch language. Inputs: @@ -628,7 +629,6 @@ Inputs: - `ports` can be a list of any number of integers between 1 and 65535. - `protocol` must be: `tcp` or `udp`. These are the only supported options. - `timeout` is the number of seconds a scan will wait without a response before it's marked as `timeout` -- `fd_limit` is the optional limit on concurrent file descriptors (defaults to 64). Results will be in the format: @@ -661,11 +661,11 @@ NOTE: Windows scans against `localhost`/`127.0.0.1` can behave unexpectedly or e The **pivot.reverse_shell_pty** method spawns the provided command in a cross-platform PTY and opens a reverse shell over the agent's current transport (e.g. gRPC). If no command is provided, Windows will use `cmd.exe`. On other platforms, `/bin/bash` is used as a default, but if it does not exist then `/bin/sh` is used. -### pivot.reverse_shell_repl (V2-Only) +### pivot.smb_exec -`pivot.reverse_shell_repl() -> None` +`pivot.smb_exec(target: str, port: int, username: str, password: str, hash: str, command: str) -> str` -The **pivot.reverse_shell_repl** method spawns a generic Eldritch REPL reverse shell over the agent's current transport. This is useful when a PTY cannot be spawned. +The pivot.smb_exec method is being proposed to allow users a way to move between hosts running smb. ### pivot.ssh_copy @@ -808,12 +808,6 @@ The random library is designed to enable generation of cryptogrphically secure r The random.bool method returns a randomly sourced boolean value. -### random.bytes (V2-Only) - -`random.bytes(len: int) -> List` - -The random.bytes method returns a list of random bytes of the specified length. - ### random.int `random.int(min: i32, max: i32) -> i32` @@ -825,12 +819,6 @@ The random.int method returns randomly generated integer value between a `random.string(length: uint, charset: Optional) -> str` The random.string method returns a randomly generated string of the specified length. If `charset` is not provided defaults to [Alphanumeric](https://docs.rs/rand_distr/latest/rand_distr/struct.Alphanumeric.html). Warning, the string is stored entirely in memory so exceptionally large files (multiple megabytes) can lead to performance issues. -### random.uuid (V2-Only) - -`random.uuid() -> str` - -The random.uuid method returns a randomly generated UUID (v4). - --- ## Regex @@ -914,13 +902,12 @@ If your dll_bytes array contains a value greater than u8::MAX it will cause the ### sys.exec -`sys.exec(path: str, args: List, disown: Optional, env_vars: Option>, input: Option) -> Dict` +`sys.exec(path: str, args: List, disown: Optional, env_vars: Option>) -> Dict` The sys.exec method executes a program specified with `path` and passes the `args` list. On *nix systems disown will run the process in the background disowned from the agent. This is done through double forking. On Windows systems disown will run the process with detached stdin and stdout such that it won't block the tomes execution. The `env_vars` will be a map of environment variables to be added to the process of the execution. -The `input` parameter (V2-Only) allows you to pass a string to the process's stdin. ```python sys.exec("/bin/bash",["-c", "whoami"]) @@ -1246,6 +1233,6 @@ The time.now method returns the time since UNIX EPOCH (Jan 01 1970). This ### time.sleep -`time.sleep(secs: int)` +`time.sleep(secs: float)` The time.sleep method sleeps the task for the given number of seconds. diff --git a/implants/lib/eldritchv2/eldritch-core/src/interpreter/eval/access.rs b/implants/lib/eldritchv2/eldritch-core/src/interpreter/eval/access.rs index bc8347401..84afa9202 100644 --- a/implants/lib/eldritchv2/eldritch-core/src/interpreter/eval/access.rs +++ b/implants/lib/eldritchv2/eldritch-core/src/interpreter/eval/access.rs @@ -107,28 +107,6 @@ pub(crate) fn evaluate_index( } Ok(Value::String(chars[true_idx as usize].to_string())) } - Value::Bytes(b) => { - let idx_int = match idx_val { - Value::Int(i) => i, - _ => { - return interp.error( - EldritchErrorKind::TypeError, - "bytes indices must be integers", - index.span, - ); - } - }; - let len = b.len() as i64; - let true_idx = if idx_int < 0 { len + idx_int } else { idx_int }; - if true_idx < 0 || true_idx as usize >= b.len() { - return interp.error( - EldritchErrorKind::IndexError, - "Bytes index out of range", - span, - ); - } - Ok(Value::Int(b[true_idx as usize] as i64)) - } _ => interp.error( EldritchErrorKind::TypeError, &format!("'{}' object is not subscriptable", get_type_name(&obj_val)), @@ -271,28 +249,6 @@ pub(crate) fn evaluate_slice( } Ok(Value::String(result_chars.into_iter().collect())) } - Value::Bytes(b) => { - let len = b.len() as i64; - let (i, j) = adjust_slice_indices(len, &start_val_opt, &stop_val_opt, step_val); - let mut result_bytes = Vec::new(); - let mut curr = i; - if step_val > 0 { - while curr < j { - if curr >= 0 && curr < len { - result_bytes.push(b[curr as usize]); - } - curr += step_val; - } - } else { - while curr > j { - if curr >= 0 && curr < len { - result_bytes.push(b[curr as usize]); - } - curr += step_val; - } - } - Ok(Value::Bytes(result_bytes)) - } _ => interp.error( EldritchErrorKind::TypeError, &format!("'{}' object is not subscriptable", get_type_name(&obj_val)), diff --git a/implants/lib/eldritchv2/eldritch-core/tests/bytes_subscript.rs b/implants/lib/eldritchv2/eldritch-core/tests/bytes_subscript.rs deleted file mode 100644 index 7770b85e2..000000000 --- a/implants/lib/eldritchv2/eldritch-core/tests/bytes_subscript.rs +++ /dev/null @@ -1,36 +0,0 @@ -use eldritch_core::{Interpreter, Value}; - -#[test] -fn test_bytes_subscript() { - let mut interp = Interpreter::new(); - let code = r#" -b = b"hello world" -a = b[0] -b_slice = b[0:5] -b_slice_step = b[::2] -"#; - - interp.interpret(code).unwrap(); - - // Check results by interpreting expressions that return the values - let a = interp.interpret("a").unwrap(); - if let Value::Int(i) = a { - assert_eq!(i, 104); // 'h' - } else { - panic!("b[0] should be Int, got {:?}", a); - } - - let b_slice = interp.interpret("b_slice").unwrap(); - if let Value::Bytes(b) = b_slice { - assert_eq!(b, b"hello".to_vec()); - } else { - panic!("b[0:5] should be Bytes, got {:?}", b_slice); - } - - let b_slice_step = interp.interpret("b_slice_step").unwrap(); - if let Value::Bytes(b) = b_slice_step { - assert_eq!(b, b"hlowrd".to_vec()); - } else { - panic!("b[::2] should be Bytes, got {:?}", b_slice_step); - } -} diff --git a/implants/lib/eldritchv2/eldritch-core/tests/slicing_coverage.rs b/implants/lib/eldritchv2/eldritch-core/tests/slicing_coverage.rs new file mode 100644 index 000000000..4ec650332 --- /dev/null +++ b/implants/lib/eldritchv2/eldritch-core/tests/slicing_coverage.rs @@ -0,0 +1,134 @@ +mod assert; + +#[test] +fn test_list_slicing_basic() { + assert::pass( + r#" + l = [0, 1, 2, 3, 4, 5] + assert_eq(l[0:6], l) + assert_eq(l[:], l) + assert_eq(l[0:3], [0, 1, 2]) + assert_eq(l[3:], [3, 4, 5]) + assert_eq(l[:3], [0, 1, 2]) + assert_eq(l[3:6], [3, 4, 5]) + "#, + ); +} + +#[test] +fn test_list_slicing_steps() { + assert::pass( + r#" + l = [0, 1, 2, 3, 4, 5] + assert_eq(l[::2], [0, 2, 4]) + assert_eq(l[1::2], [1, 3, 5]) + assert_eq(l[::3], [0, 3]) + assert_eq(l[::100], [0]) + "#, + ); +} + +#[test] +fn test_list_slicing_negative_indices() { + assert::pass( + r#" + l = [0, 1, 2, 3, 4, 5] + assert_eq(l[-1], 5) + assert_eq(l[-2], 4) + assert_eq(l[:-1], [0, 1, 2, 3, 4]) + assert_eq(l[-3:], [3, 4, 5]) + assert_eq(l[-3:-1], [3, 4]) + "#, + ); +} + +#[test] +fn test_list_slicing_negative_steps() { + assert::pass( + r#" + l = [0, 1, 2, 3, 4, 5] + assert_eq(l[::-1], [5, 4, 3, 2, 1, 0]) + assert_eq(l[::-2], [5, 3, 1]) + assert_eq(l[4:2:-1], [4, 3]) + assert_eq(l[2:4:-1], []) + "#, + ); +} + +#[test] +fn test_list_slicing_empty_result_edge_cases() { + assert::pass( + r#" + l = [0, 1, 2, 3, 4, 5] + # Start > Stop with positive step + assert_eq(l[4:2], []) + # Start < Stop with negative step + assert_eq(l[2:4:-1], []) + # Out of bounds start (positive) + assert_eq(l[100:], []) + # Out of bounds stop (negative) + assert_eq(l[:-100], []) + "#, + ); +} + +#[test] +fn test_list_slicing_out_of_bounds() { + assert::pass( + r#" + l = [0, 1, 2] + assert_eq(l[0:100], [0, 1, 2]) + assert_eq(l[-100:], [0, 1, 2]) + assert_eq(l[-100:-50], []) + "#, + ); +} + +#[test] +fn test_string_slicing_extended() { + assert::pass( + r#" + s = "012345" + assert_eq(s[::2], "024") + assert_eq(s[::-1], "543210") + assert_eq(s[-3:], "345") + assert_eq(s[100:], "") + assert_eq(s[-100:], "012345") + + # Empty string + assert_eq(""[:], "") + assert_eq(""[::-1], "") + "#, + ); +} + +#[test] +fn test_tuple_slicing_extended() { + assert::pass( + r#" + t = (0, 1, 2, 3, 4, 5) + assert_eq(t[::2], (0, 2, 4)) + assert_eq(t[::-1], (5, 4, 3, 2, 1, 0)) + assert_eq(t[100:], ()) + "#, + ); +} + +#[test] +fn test_bytes_slicing_not_supported() { + assert::fail( + r#" + b = b"012345" + b[::2] + "#, + "'bytes' object is not subscriptable", + ); +} + +#[test] +fn test_slicing_zero_step_error() { + assert::fail("l = [1]; l[::0]", "slice step cannot be zero"); + assert::fail("s = 'a'; s[::0]", "slice step cannot be zero"); + assert::fail("t = (1,); t[::0]", "slice step cannot be zero"); + assert::fail("b = b'a'; b[::0]", "slice step cannot be zero"); +} diff --git a/tavern/internal/auth/context_edge_cases_test.go b/tavern/internal/auth/context_edge_cases_test.go new file mode 100644 index 000000000..cb1b4c1ee --- /dev/null +++ b/tavern/internal/auth/context_edge_cases_test.go @@ -0,0 +1,53 @@ +package auth_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "realm.pub/tavern/internal/auth" + "realm.pub/tavern/internal/ent/enttest" +) + +func TestContextFromTokens_Invalid(t *testing.T) { + // Setup Dependencies + var ( + driverName = "sqlite3" + dataSourceName = "file:ent?mode=memory&cache=shared&_fk=1" + ) + graph := enttest.Open(t, driverName, dataSourceName, enttest.WithOptions()) + defer graph.Close() + + // Test ContextFromSessionToken with invalid token + t.Run("ContextFromSessionToken_NotFound", func(t *testing.T) { + // Pass a nil context first to check if it panics? No, Background is fine. + ctx, err := auth.ContextFromSessionToken(context.Background(), graph, "invalid-token") + require.Error(t, err) + // Usually ent returns "ent: user not found" + assert.Contains(t, err.Error(), "user not found") + // The returned context should be nil if error? + // Looking at context.go: if err != nil { return nil, err } + assert.Nil(t, ctx) + }) + + // Test ContextFromAccessToken with invalid token + t.Run("ContextFromAccessToken_NotFound", func(t *testing.T) { + ctx, err := auth.ContextFromAccessToken(context.Background(), graph, "invalid-token") + require.Error(t, err) + assert.Contains(t, err.Error(), "user not found") + assert.Nil(t, ctx) + }) +} + +func TestContextHelpers_EdgeCases(t *testing.T) { + ctx := context.Background() + + t.Run("EmptyContext", func(t *testing.T) { + assert.Nil(t, auth.IdentityFromContext(ctx)) + assert.Nil(t, auth.UserFromContext(ctx)) + assert.False(t, auth.IsAuthenticatedContext(ctx)) + assert.False(t, auth.IsActivatedContext(ctx)) + assert.False(t, auth.IsAdminContext(ctx)) + }) +} diff --git a/tavern/internal/http/stream/gcp_coldstart_test.go b/tavern/internal/http/stream/gcp_coldstart_test.go index acfc3eb88..45722466e 100644 --- a/tavern/internal/http/stream/gcp_coldstart_test.go +++ b/tavern/internal/http/stream/gcp_coldstart_test.go @@ -2,7 +2,6 @@ package stream_test import ( "context" - "fmt" "testing" "time" @@ -17,19 +16,18 @@ func TestPreventPubSubColdStarts_ValidInterval(t *testing.T) { defer cancel() // Create a mock topic and subscription. - topicName := fmt.Sprintf("mem://valid-%d", time.Now().UnixNano()) - topic, err := pubsub.OpenTopic(ctx, topicName) + topic, err := pubsub.OpenTopic(ctx, "mem://valid") if err != nil { t.Fatalf("Failed to open topic: %v", err) } defer topic.Shutdown(ctx) - sub, err := pubsub.OpenSubscription(ctx, topicName) + sub, err := pubsub.OpenSubscription(ctx, "mem://valid") if err != nil { t.Fatalf("Failed to open subscription: %v", err) } defer sub.Shutdown(ctx) - go stream.PreventPubSubColdStarts(ctx, 50*time.Millisecond, topicName, topicName) + go stream.PreventPubSubColdStarts(ctx, 50*time.Millisecond, "mem://valid", "mem://valid") // Expect to receive a message msg, err := sub.Receive(ctx) @@ -45,19 +43,18 @@ func TestPreventPubSubColdStarts_ZeroInterval(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() - topicName := fmt.Sprintf("mem://zero-%d", time.Now().UnixNano()) - topic, err := pubsub.OpenTopic(ctx, topicName) + topic, err := pubsub.OpenTopic(ctx, "mem://zero") if err != nil { t.Fatalf("Failed to open topic: %v", err) } defer topic.Shutdown(ctx) - sub, err := pubsub.OpenSubscription(ctx, topicName) + sub, err := pubsub.OpenSubscription(ctx, "mem://zero") if err != nil { t.Fatalf("Failed to open subscription: %v", err) } defer sub.Shutdown(ctx) - go stream.PreventPubSubColdStarts(ctx, 0, topicName, topicName) + go stream.PreventPubSubColdStarts(ctx, 0, "mem://zero", "mem://zero") // Expect to not receive a message and for the context to timeout _, err = sub.Receive(ctx) @@ -69,19 +66,18 @@ func TestPreventPubSubColdStarts_SubMillisecondInterval(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() - topicName := fmt.Sprintf("mem://sub-%d", time.Now().UnixNano()) - topic, err := pubsub.OpenTopic(ctx, topicName) + topic, err := pubsub.OpenTopic(ctx, "mem://sub") if err != nil { t.Fatalf("Failed to open topic: %v", err) } defer topic.Shutdown(ctx) - sub, err := pubsub.OpenSubscription(ctx, topicName) + sub, err := pubsub.OpenSubscription(ctx, "mem://sub") if err != nil { t.Fatalf("Failed to open subscription: %v", err) } defer sub.Shutdown(ctx) - go stream.PreventPubSubColdStarts(ctx, 1*time.Microsecond, topicName, topicName) + go stream.PreventPubSubColdStarts(ctx, 1*time.Microsecond, "mem://sub", "mem://sub") // Expect to receive a message msg, err := sub.Receive(ctx) diff --git a/tavern/internal/http/stream/mux_test.go b/tavern/internal/http/stream/mux_test.go index de466ed4f..38be5b97a 100644 --- a/tavern/internal/http/stream/mux_test.go +++ b/tavern/internal/http/stream/mux_test.go @@ -2,7 +2,6 @@ package stream_test import ( "context" - "fmt" "testing" "time" @@ -18,11 +17,10 @@ func TestMux(t *testing.T) { defer cancel() // Setup Topic and Subscription - topicName := fmt.Sprintf("mem://mux-test-%d", time.Now().UnixNano()) - topic, err := pubsub.OpenTopic(ctx, topicName) + topic, err := pubsub.OpenTopic(ctx, "mem://mux-test") require.NoError(t, err) defer topic.Shutdown(ctx) - sub, err := pubsub.OpenSubscription(ctx, topicName) + sub, err := pubsub.OpenSubscription(ctx, "mem://mux-test") require.NoError(t, err) defer sub.Shutdown(ctx) @@ -39,6 +37,9 @@ func TestMux(t *testing.T) { mux.Register(stream2) defer mux.Unregister(stream2) + // Give the mux a moment to register the streams + time.Sleep(50 * time.Millisecond) + // Send a message for stream1 err = topic.Send(ctx, &pubsub.Message{ Body: []byte("hello stream 1"), diff --git a/tavern/internal/http/stream/stream_test.go b/tavern/internal/http/stream/stream_test.go index 3d6fb02cd..7a837e0ee 100644 --- a/tavern/internal/http/stream/stream_test.go +++ b/tavern/internal/http/stream/stream_test.go @@ -17,11 +17,10 @@ func TestStream_SendMessage(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - topicName := fmt.Sprintf("mem://stream-test-send-%d", time.Now().UnixNano()) - topic, err := pubsub.OpenTopic(ctx, topicName) + topic, err := pubsub.OpenTopic(ctx, "mem://stream-test-send") require.NoError(t, err) defer topic.Shutdown(ctx) - sub, err := pubsub.OpenSubscription(ctx, topicName) + sub, err := pubsub.OpenSubscription(ctx, "mem://stream-test-send") require.NoError(t, err) defer sub.Shutdown(ctx) diff --git a/tavern/internal/http/stream/websocket_test.go b/tavern/internal/http/stream/websocket_test.go index d3cc7770f..8e67d5421 100644 --- a/tavern/internal/http/stream/websocket_test.go +++ b/tavern/internal/http/stream/websocket_test.go @@ -2,7 +2,6 @@ package stream_test import ( "context" - "fmt" "net/http/httptest" "strconv" "strings" @@ -30,20 +29,18 @@ func TestNewShellHandler(t *testing.T) { defer cancel() // Topic for messages going TO the websocket (server -> shell) - outputTopicName := fmt.Sprintf("mem://websocket-output-%d", time.Now().UnixNano()) - outputTopic, err := pubsub.OpenTopic(ctx, outputTopicName) + outputTopic, err := pubsub.OpenTopic(ctx, "mem://websocket-output") require.NoError(t, err) defer outputTopic.Shutdown(ctx) - outputSub, err := pubsub.OpenSubscription(ctx, outputTopicName) + outputSub, err := pubsub.OpenSubscription(ctx, "mem://websocket-output") require.NoError(t, err) defer outputSub.Shutdown(ctx) // Topic for messages coming FROM the websocket (shell -> server) - inputTopicName := fmt.Sprintf("mem://websocket-input-%d", time.Now().UnixNano()) - inputTopic, err := pubsub.OpenTopic(ctx, inputTopicName) + inputTopic, err := pubsub.OpenTopic(ctx, "mem://websocket-input") require.NoError(t, err) defer inputTopic.Shutdown(ctx) - inputSub, err := pubsub.OpenSubscription(ctx, inputTopicName) + inputSub, err := pubsub.OpenSubscription(ctx, "mem://websocket-input") require.NoError(t, err) defer inputSub.Shutdown(ctx) diff --git a/tavern/internal/redirectors/grpc/grpc_test.go b/tavern/internal/redirectors/grpc/grpc_test.go index 6eb690d6c..5cc768189 100644 --- a/tavern/internal/redirectors/grpc/grpc_test.go +++ b/tavern/internal/redirectors/grpc/grpc_test.go @@ -155,8 +155,8 @@ func TestRedirector_ContextCancellation(t *testing.T) { serverErr <- redirector.Redirect(ctx, addr, upstreamConn) }() - // Wait for the server to start listening. - waitForServer(t, addr) + // Wait a moment for the server to start listening. + time.Sleep(100 * time.Millisecond) // Cancel the context, which should trigger GracefulStop. cancel() @@ -215,17 +215,3 @@ func TestRedirector_UpstreamFailure(t *testing.T) { require.True(t, ok, "error should be a gRPC status error") require.Equal(t, codes.Unavailable, s.Code(), "error code should be Unavailable") } - -func waitForServer(t *testing.T, addr string) { - t.Helper() - deadline := time.Now().Add(5 * time.Second) - for time.Now().Before(deadline) { - conn, err := net.DialTimeout("tcp", addr, 100*time.Millisecond) - if err == nil { - conn.Close() - return - } - time.Sleep(10 * time.Millisecond) - } - t.Fatalf("server did not start listening on %s", addr) -}